Skip to content

Commit 77971c2

Browse files
committed
fix up
1 parent de73f39 commit 77971c2

File tree

7 files changed

+52
-80
lines changed

7 files changed

+52
-80
lines changed

npsr/hwy.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ using hn::TFromV;
2424
using hn::VFromD;
2525
constexpr bool kNativeFMA = HWY_NATIVE_FMA != 0;
2626

27-
HWY_ATTR void DummyToSuppressUnusedWarning() {}
27+
inline HWY_ATTR void DummyToSuppressUnusedWarning() {}
2828
} // namespace npsr::HWY_NAMESPACE
2929
HWY_AFTER_NAMESPACE();
3030

npsr/lut-inl.h

Lines changed: 8 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,8 @@
55
#define NPSR_LUT_INL_H_
66
#endif
77

8+
#include <tuple>
9+
810
#include "npsr/hwy.h"
911

1012
HWY_BEFORE_NAMESPACE();
@@ -107,22 +109,16 @@ class Lut {
107109
using D = Rebind<T, DU>;
108110
const D d;
109111

110-
#if !HWY_HAVE_SCALABLE
111-
constexpr size_t kLanes = Lanes(du);
112-
if constexpr (kLanes == kCols) {
112+
HWY_LANES_CONSTEXPR size_t kLanes = Lanes(du);
113+
if HWY_LANES_CONSTEXPR (kLanes == kCols) {
113114
// Vector size matches table width - use single table lookup
114115
const auto ind = IndicesFromVec(d, idx);
115116
LoadX1_<Off>(ind, out...);
116-
} else if constexpr (kLanes * 2 == kCols) {
117+
} else if HWY_LANES_CONSTEXPR (kLanes * 2 == kCols) {
117118
// Vector size is half table width - use two table lookup
118119
const auto ind = IndicesFromVec(d, idx);
119120
LoadX2_<Off>(ind, out...);
120-
}
121-
#else
122-
if constexpr (0) {
123-
}
124-
#endif
125-
else {
121+
} else {
126122
// Fallback to gather for other configurations
127123
LoadGather_<Off>(idx, out...);
128124
}
@@ -135,8 +131,8 @@ class Lut {
135131
using D = DFromV<OutV0>;
136132
const D d;
137133

138-
const OutV0 lut0 = Load(d, row_ + Off);
139-
out0 = TableLookupLanes(d, lut0, ind);
134+
const OutV0 lut0 = LoadU(d, row_ + Off);
135+
out0 = TableLookupLanes(lut0, ind);
140136

141137
if constexpr (sizeof...(OutV) > 0) {
142138
LoadX1_<Off + kCols>(ind, out...);

npsr/precise.h

Lines changed: 25 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -27,15 +27,6 @@ constexpr auto kNoSpecialCases = _NoSpecialCases{};
2727
constexpr auto kNoExceptions = _NoExceptions{};
2828
constexpr auto kLowAccuracy = _LowAccuracy{};
2929

30-
// Rounding mode control
31-
// Forces a specific rounding mode during computation
32-
struct Round {
33-
struct _Force {
34-
static constexpr const char* kName = "kForce";
35-
};
36-
static constexpr auto kForce = _Force{};
37-
};
38-
3930
// Subnormal (denormal) number handling modes
4031
// Controls how the CPU handles numbers smaller than the minimum normalized
4132
// value
@@ -52,12 +43,34 @@ struct Subnormal {
5243

5344
// Floating-point exception flags
5445
// These match the standard C library FE_* macros
55-
struct FPExceptions {
46+
class FPExceptions {
47+
public:
5648
static constexpr auto kNone = 0;
5749
static constexpr auto kInvalid = FE_INVALID;
5850
static constexpr auto kDivByZero = FE_DIVBYZERO;
5951
static constexpr auto kOverflow = FE_OVERFLOW;
6052
static constexpr auto kUnderflow = FE_UNDERFLOW;
53+
54+
void Raise(int errors) noexcept { mask_ |= errors; }
55+
56+
protected:
57+
void Load() noexcept {
58+
loaded_ = std::fegetexceptflag(&saved_, FE_ALL_EXCEPT) == 0;
59+
}
60+
61+
~FPExceptions() noexcept {
62+
if (loaded_) {
63+
std::fesetexceptflag(&saved_, FE_ALL_EXCEPT);
64+
}
65+
if (mask_ != kNone) {
66+
std::feraiseexcept(mask_);
67+
}
68+
}
69+
70+
private:
71+
bool loaded_ = false;
72+
int mask_ = kNone;
73+
std::fexcept_t saved_;
6174
};
6275

6376
/**
@@ -84,7 +97,6 @@ struct FPExceptions {
8497
* - kNoLargeArgument: Skip extended precision reduction for large arguments
8598
* - kNoSpecialCases: Skip NaN/Inf handling (assumes finite inputs)
8699
* - kNoExceptions: Disable FP exception tracking for better performance
87-
* - Round::kForce: Force round-to-nearest mode
88100
* - Subnormal::kDAZ/kFTZ: Flush subnormals to zero for performance
89101
* - Subnormal::kIEEE754: Strict IEEE 754 compliance (default if DAZ/FTZ not
90102
* specified)
@@ -109,23 +121,13 @@ struct FPExceptions {
109121
* ```
110122
*/
111123
template <typename... Args>
112-
class Precise {
124+
class Precise : public FPExceptions {
113125
public:
114126
// Default constructor saves current FP state
115127
Precise() noexcept {
116128
// Save exception flags unless disabled
117129
if constexpr (!kNoExceptions) {
118-
fegetexceptflag(&_exceptions, FE_ALL_EXCEPT);
119-
}
120-
121-
// Force rounding mode if requested
122-
if constexpr (kRoundForce) {
123-
_rounding_mode = fegetround();
124-
int new_mode = _NewRoundingMode();
125-
if (_rounding_mode != new_mode) {
126-
_retrieve_rounding_mode = true;
127-
fesetround(new_mode);
128-
}
130+
FPExceptions::Load();
129131
}
130132
}
131133

@@ -136,33 +138,8 @@ class Precise {
136138
// This constructor exists to enable Precise{tag1, tag2, ...} syntax
137139
}
138140

139-
// Restore saved exception flags to FP environment
140-
void FlushExceptions() noexcept {
141-
if constexpr (!kNoExceptions) {
142-
fesetexceptflag(&_exceptions, FE_ALL_EXCEPT);
143-
}
144-
}
145-
146-
// Record that an exception occurred (will be raised on destruction)
147-
void Raise(int errors) noexcept {
148-
static_assert(!kNoExceptions,
149-
"Cannot raise exceptions in NoExceptions mode");
150-
_exceptions |= errors;
151-
}
152-
153-
// Destructor restores original FP state
154-
~Precise() noexcept {
155-
FlushExceptions();
156-
if constexpr (kRoundForce) {
157-
if (_retrieve_rounding_mode) {
158-
fesetround(_rounding_mode);
159-
}
160-
}
161-
}
162-
163141
// Compile-time configuration queries
164142
// These allow algorithms to optimize based on precision requirements
165-
166143
static constexpr bool kNoExceptions = (is_same_v<_NoExceptions, Args> || ...);
167144
static constexpr bool kNoLargeArgument =
168145
(is_same_v<_NoLargeArgument, Args> || ...);
@@ -176,9 +153,6 @@ class Precise {
176153
static constexpr bool kSpecialCases = !kNoSpecialCases;
177154
static constexpr bool kExceptions = !kNoExceptions;
178155

179-
// Rounding mode configuration
180-
static constexpr bool kRoundForce = (is_same_v<Round::_Force, Args> || ...);
181-
182156
// Subnormal handling configuration
183157
static constexpr bool kDAZ = (is_same_v<Subnormal::_DAZ, Args> || ...);
184158
static constexpr bool kFTZ = (is_same_v<Subnormal::_FTZ, Args> || ...);
@@ -193,16 +167,6 @@ class Precise {
193167

194168
// Default to IEEE754 if no subnormal mode specified
195169
static constexpr bool kIEEE754 = _kIEEE754 || !(kDAZ || kFTZ);
196-
197-
private:
198-
// Currently only supports round-to-nearest mode
199-
// Could be extended to support other modes (toward zero, up, down)
200-
int _NewRoundingMode() const { return FE_TONEAREST; }
201-
202-
// Saved floating-point state
203-
int _rounding_mode = 0;
204-
bool _retrieve_rounding_mode = false;
205-
fexcept_t _exceptions; // Saved exception flags
206170
}; // class Precise
207171

208172
// Deduction guides for convenient construction

npsr/trig/high-inl.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ HWY_INLINE V High(V x) {
5656
// N' = N - 0.5
5757
n = Sub(n, Set(d, 0.5f));
5858
}
59-
auto WideCal = [](VW nh, VW xh_abs) -> VW {
59+
auto WideCal = [](const VW &nh, const VW &xh_abs) -> VW {
6060
const DFromV<VW> dw;
6161
constexpr auto kPiPrec35 = data::kPiPrec35<true>;
6262
VW r = NegMulAdd(nh, Set(dw, kPiPrec35[0]), xh_abs);

npsr/trig/low-inl.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -120,13 +120,13 @@ HWY_API V Low(V x) {
120120
r_lo = NegMulAdd(n, Set(d, kPi[3]), r_lo);
121121
}
122122

123-
if (kIsSingle) {
123+
if constexpr (kIsSingle) {
124124
r = r_lo;
125125
}
126126
V r2 = Mul(r, r);
127127
V poly = PolyLow<IS_COS>(r, r2);
128128

129-
if (!kIsSingle) {
129+
if constexpr (!kIsSingle) {
130130
V r2_corr = Mul(r2, r_lo);
131131
poly = MulAdd(r2_corr, poly, r_lo);
132132
}

tools/sollya/cli.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
from pathlib import Path
1111
from typing import Final
1212
from dataclasses import dataclass
13-
from itertools import batched
1413

1514

1615
# ANSI color codes for terminal output

tools/sollya/core.sol

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -120,16 +120,29 @@ procedure PrettyJoin(pList, pSfx, pSep, pLineEvery) {
120120
return r;
121121
};
122122

123-
// C array formatting procedures
123+
// Ensure zeros are represented as 0.0 for C++ template deduction
124+
procedure FixZero(pList) {
125+
var r, i;
126+
r = [||];
127+
for i in pList do {
128+
if (i == 0) then {
129+
r = r :. "0.0"; // Ensure zero is represented as 0.0
130+
} else {
131+
r = r :. i;
132+
};
133+
};
134+
return r;
135+
};
124136

137+
// C array formatting procedures
125138
// Generate C array initializer
126139
procedure CArray(pList, pLineEvery) {
127140
return "{\n" @ PrettyJoin(pList, "", ", ", pLineEvery) @ "}";
128141
};
129142

130143
// Generate C array with type-specific suffix (e.g., "f" for float)
131144
procedure CArrayT(pT, pList, pLineEvery) {
132-
return "{\n" @ PrettyJoin(pList, pT.kCSFX, ", ", pLineEvery) @ "}";
145+
return "{\n" @ PrettyJoin(FixZero(pList), pT.kCSFX, ", ", pLineEvery) @ "}";
133146
};
134147

135148
// Generate C array with unsigned integer suffix

0 commit comments

Comments
 (0)