1- #pragma once
1+ #pragma once
2+ #include < cassert>
3+ #include < cstddef>
4+ #include < immintrin.h>
5+ #include < memory>
6+ #include < type_traits>
7+
8+ namespace tensorium_RG {
9+
10+ #ifndef TENSORIUM_ALIGN
11+ #define TENSORIUM_ALIGN 64
12+ #endif
13+
14+ #ifndef TENSORIUM_SIMD_WIDTH_F32
15+ #if defined(__AVX512F__)
16+ #define TENSORIUM_SIMD_WIDTH_F32 16
17+ #elif defined(__AVX2__)
18+ #define TENSORIUM_SIMD_WIDTH_F32 8
19+ #else
20+ #define TENSORIUM_SIMD_WIDTH_F32 4
21+ #endif
22+ #endif
23+
24+ template <typename T> inline size_t pad_simd (size_t n) {
25+ const size_t w =
26+ (std::is_same<T, float >::value ? TENSORIUM_SIMD_WIDTH_F32 : 8 );
27+ return ((n + w - 1 ) / w) * w;
28+ }
29+
30+ template <typename T> struct AlignedDeleter {
31+ void operator ()(T *p) const noexcept {
32+ ::operator delete[] (p, std::align_val_t (TENSORIUM_ALIGN));
33+ }
34+ };
35+
36+ template <typename T>
37+ using aligned_unique_ptr = std::unique_ptr<T[], AlignedDeleter<T>>;
38+
39+ template <typename T> aligned_unique_ptr<T> aligned_alloc_n (size_t n) {
40+ return aligned_unique_ptr<T>(static_cast <T *>(
41+ ::operator new [](n * sizeof (T), std::align_val_t (TENSORIUM_ALIGN))));
42+ }
43+
44+ struct GridDims {
45+ size_t nx, ny, nz;
46+ size_t ng;
47+ };
48+
49+ template <typename T> struct Strides {
50+ size_t sx, sy, sz;
51+ size_t nx_tot, ny_tot, nz_tot;
52+ };
53+
54+ template <typename T> struct Field3D {
55+ aligned_unique_ptr<T> data;
56+ Strides<T> st;
57+
58+ inline T *ptr () const noexcept { return data.get (); }
59+ inline size_t idx (size_t i, size_t j, size_t k) const noexcept {
60+ return i * st.sx + j * st.sy + k * st.sz ;
61+ }
62+ };
63+
64+ template <typename T>
65+ void halo_periodic (Field3D<T> &f, const GridDims &D, const Strides<T> &st) {
66+ const size_t I0 = D.ng , I1 = D.ng + D.nx ;
67+ const size_t J0 = D.ng , J1 = D.ng + D.ny ;
68+ const size_t K0 = D.ng , K1 = D.ng + D.nz ;
69+
70+ for (size_t g = 0 ; g < D.ng ; ++g) {
71+ size_t isrc = I1 - 1 - g;
72+ size_t idst = I0 - 1 - g;
73+ for (size_t j = J0; j < J1; ++j)
74+ for (size_t k = K0; k < K1; ++k)
75+ f.ptr ()[idst * st.sx + j * st.sy + k] =
76+ f.ptr ()[isrc * st.sx + j * st.sy + k];
77+ }
78+
79+ for (size_t g = 0 ; g < D.ng ; ++g) {
80+ size_t isrc = I0 + g;
81+ size_t idst = I1 + g;
82+ for (size_t j = J0; j < J1; ++j)
83+ for (size_t k = K0; k < K1; ++k)
84+ f.ptr ()[idst * st.sx + j * st.sy + k] =
85+ f.ptr ()[isrc * st.sx + j * st.sy + k];
86+ }
87+
88+ for (size_t g = 0 ; g < D.ng ; ++g) {
89+ size_t jsrc = J1 - 1 - g, jdst = J0 - 1 - g;
90+ for (size_t i = I0; i < I1; ++i)
91+ for (size_t k = K0; k < K1; ++k)
92+ f.ptr ()[i * st.sx + jdst * st.sy + k] =
93+ f.ptr ()[i * st.sx + jsrc * st.sy + k];
94+ }
95+ for (size_t g = 0 ; g < D.ng ; ++g) {
96+ size_t jsrc = J0 + g, jdst = J1 + g;
97+ for (size_t i = I0; i < I1; ++i)
98+ for (size_t k = K0; k < K1; ++k)
99+ f.ptr ()[i * st.sx + jdst * st.sy + k] =
100+ f.ptr ()[i * st.sx + jsrc * st.sy + k];
101+ }
102+
103+ for (size_t g = 0 ; g < D.ng ; ++g) {
104+ size_t ksrc = K1 - 1 - g, kdst = K0 - 1 - g;
105+ for (size_t i = I0; i < I1; ++i)
106+ for (size_t j = J0; j < J1; ++j)
107+ f.ptr ()[i * st.sx + j * st.sy + kdst] =
108+ f.ptr ()[i * st.sx + j * st.sy + ksrc];
109+ }
110+ for (size_t g = 0 ; g < D.ng ; ++g) {
111+ size_t ksrc = K0 + g, kdst = K1 + g;
112+ for (size_t i = I0; i < I1; ++i)
113+ for (size_t j = J0; j < J1; ++j)
114+ f.ptr ()[i * st.sx + j * st.sy + kdst] =
115+ f.ptr ()[i * st.sx + j * st.sy + ksrc];
116+ }
117+ }
118+
119+ enum Sym6 { XX = 0 , XY = 1 , XZ = 2 , YY = 3 , YZ = 4 , ZZ = 5 };
120+
121+ template <typename T> class BSSNGridSoA {
122+ public:
123+ GridDims dims;
124+ Strides<T> st;
125+
126+ Field3D<T> alpha, chi;
127+
128+ Field3D<T> beta[3 ], tildeGamma[3 ], contractedGamma[3 ];
129+
130+ Field3D<T> gamma_ij[6 ], gamma_ij_inv[6 ];
131+ Field3D<T> gamma_tilde[6 ], gamma_tilde_inv[6 ];
132+ Field3D<T> A_tilde[6 ], K_ij[6 ];
133+
134+ Field3D<T> d_beta[3 ][3 ];
135+ Field3D<T> d_gamma[6 ][3 ];
136+
137+ T dx, dy, dz;
138+
139+ BSSNGridSoA (size_t nx, size_t ny, size_t nz, size_t ng, T dx_, T dy_, T dz_)
140+ : dims{nx, ny, nz, ng}, dx(dx_), dy(dy_), dz(dz_) {
141+ const size_t nx_tot = pad_simd<T>(nx + 2 * ng);
142+ const size_t ny_tot = ny + 2 * ng;
143+ const size_t nz_tot = pad_simd<T>(nz + 2 * ng);
144+
145+ st.nx_tot = nx_tot;
146+ st.ny_tot = ny_tot;
147+ st.nz_tot = nz_tot;
148+ st.sz = 1 ;
149+ st.sy = nz_tot;
150+ st.sx = ny_tot * nz_tot;
151+
152+ auto alloc_field = [&](Field3D<T> &f) {
153+ const size_t N = nx_tot * ny_tot * nz_tot;
154+ f.data = aligned_alloc_n<T>(N);
155+ f.st = st;
156+ };
157+
158+ alloc_field (alpha);
159+ alloc_field (chi);
160+ for (int i = 0 ; i < 3 ; ++i) {
161+ alloc_field (beta[i]);
162+ alloc_field (tildeGamma[i]);
163+ alloc_field (contractedGamma[i]);
164+ }
165+ for (int s = 0 ; s < 6 ; ++s) {
166+ alloc_field (gamma_ij[s]);
167+ alloc_field (gamma_ij_inv[s]);
168+ alloc_field (gamma_tilde[s]);
169+ alloc_field (gamma_tilde_inv[s]);
170+ alloc_field (A_tilde[s]);
171+ alloc_field (K_ij[s]);
172+ }
173+ for (int i = 0 ; i < 3 ; ++i)
174+ for (int j = 0 ; j < 3 ; ++j)
175+ alloc_field (d_beta[i][j]);
176+ for (int s = 0 ; s < 6 ; ++s)
177+ for (int i = 0 ; i < 3 ; ++i)
178+ alloc_field (d_gamma[s][i]);
179+ }
180+
181+ inline void domain_bounds (size_t &i0, size_t &i1, size_t &j0, size_t &j1,
182+ size_t &k0, size_t &k1) const noexcept {
183+ i0 = dims.ng ;
184+ i1 = dims.ng + dims.nx ;
185+ j0 = dims.ng ;
186+ j1 = dims.ng + dims.ny ;
187+ k0 = dims.ng ;
188+ k1 = dims.ng + dims.nz ;
189+ }
190+
191+ T x0 = 0 , y0 = 0 , z0 = 0 ;
192+ inline void coords (size_t i, size_t j, size_t k, T &x, T &y,
193+ T &z) const noexcept {
194+ x = x0 + (i - dims.ng ) * dx;
195+ y = y0 + (j - dims.ng ) * dy;
196+ z = z0 + (k - dims.ng ) * dz;
197+ }
198+ };
199+
200+ template <typename T>
201+ inline void store_sym6 (Field3D<T> *f6, size_t idx, T xx, T xy, T xz, T yy, T yz,
202+ T zz) {
203+ f6[XX].ptr ()[idx] = xx;
204+ f6[XY].ptr ()[idx] = xy;
205+ f6[XZ].ptr ()[idx] = xz;
206+ f6[YY].ptr ()[idx] = yy;
207+ f6[YZ].ptr ()[idx] = yz;
208+ f6[ZZ].ptr ()[idx] = zz;
209+ }
210+
211+ template <typename T>
212+ inline void load_sym6 (Field3D<T> *f6, size_t idx, T &xx, T &xy, T &xz, T &yy,
213+ T &yz, T &zz) {
214+ xx = f6[XX].ptr ()[idx];
215+ xy = f6[XY].ptr ()[idx];
216+ xz = f6[XZ].ptr ()[idx];
217+ yy = f6[YY].ptr ()[idx];
218+ yz = f6[YZ].ptr ()[idx];
219+ zz = f6[ZZ].ptr ()[idx];
220+ }
221+
222+
223+ template <typename T> inline Field3D<T> make_field (const Strides<T> &st) {
224+ Field3D<T> f;
225+ f.st = st;
226+ const size_t N = st.nx_tot * st.ny_tot * st.nz_tot ;
227+ f.data = aligned_alloc_n<T>(N);
228+ return f;
229+ }
230+
231+ struct BoundaryPeriodic {
232+ template <typename T>
233+ static void apply (Field3D<T> &f, const GridDims &D, const Strides<T> &st) {
234+ halo_periodic (f, D, st);
235+ }
236+ };
237+
238+ struct BoundaryClamp {
239+ template <typename T>
240+ static void apply (Field3D<T> &f, const GridDims &D, const Strides<T> &st) {
241+ const size_t I0 = D.ng , I1 = D.ng + D.nx , J0 = D.ng , J1 = D.ng + D.ny ,
242+ K0 = D.ng , K1 = D.ng + D.nz ;
243+ for (size_t g = 0 ; g < D.ng ; ++g) {
244+ size_t idst = I0 - 1 - g, isrc = I0;
245+ for (size_t j = J0; j < J1; ++j)
246+ for (size_t k = K0; k < K1; ++k)
247+ f.ptr ()[idst * st.sx + j * st.sy + k] =
248+ f.ptr ()[isrc * st.sx + j * st.sy + k];
249+ }
250+ for (size_t g = 0 ; g < D.ng ; ++g) {
251+ size_t idst = I1 + g, isrc = I1 - 1 ;
252+ for (size_t j = J0; j < J1; ++j)
253+ for (size_t k = K0; k < K1; ++k)
254+ f.ptr ()[idst * st.sx + j * st.sy + k] =
255+ f.ptr ()[isrc * st.sx + j * st.sy + k];
256+ }
257+ for (size_t g = 0 ; g < D.ng ; ++g) {
258+ size_t jdst = J0 - 1 - g, jsrc = J0;
259+ for (size_t i = I0; i < I1; ++i)
260+ for (size_t k = K0; k < K1; ++k)
261+ f.ptr ()[i * st.sx + jdst * st.sy + k] =
262+ f.ptr ()[i * st.sx + jsrc * st.sy + k];
263+ }
264+ for (size_t g = 0 ; g < D.ng ; ++g) {
265+ size_t jdst = J1 + g, jsrc = J1 - 1 ;
266+ for (size_t i = I0; i < I1; ++i)
267+ for (size_t k = K0; k < K1; ++k)
268+ f.ptr ()[i * st.sx + jdst * st.sy + k] =
269+ f.ptr ()[i * st.sx + jsrc * st.sy + k];
270+ }
271+ for (size_t g = 0 ; g < D.ng ; ++g) {
272+ size_t kdst = K0 - 1 - g, ksrc = K0;
273+ for (size_t i = I0; i < I1; ++i)
274+ for (size_t j = J0; j < J1; ++j)
275+ f.ptr ()[i * st.sx + j * st.sy + kdst] =
276+ f.ptr ()[i * st.sx + j * st.sy + ksrc];
277+ }
278+ for (size_t g = 0 ; g < D.ng ; ++g) {
279+ size_t kdst = K1 + g, ksrc = K1 - 1 ;
280+ for (size_t i = I0; i < I1; ++i)
281+ for (size_t j = J0; j < J1; ++j)
282+ f.ptr ()[i * st.sx + j * st.sy + kdst] =
283+ f.ptr ()[i * st.sx + j * st.sy + ksrc];
284+ }
285+ }
286+ };
287+
288+ template <class Boundary , typename T>
289+ inline void apply_halos_grid (BSSNGridSoA<T> &G) {
290+ auto &D = G.dims ;
291+ auto &st = G.st ;
292+ Boundary::apply (G.alpha , D, st);
293+ Boundary::apply (G.chi , D, st);
294+ for (int c = 0 ; c < 3 ; ++c) {
295+ Boundary::apply (G.beta [c], D, st);
296+ Boundary::apply (G.tildeGamma [c], D, st);
297+ Boundary::apply (G.contractedGamma [c], D, st);
298+ }
299+ for (int s = 0 ; s < 6 ; ++s) {
300+ Boundary::apply (G.gamma_ij [s], D, st);
301+ Boundary::apply (G.gamma_ij_inv [s], D, st);
302+ Boundary::apply (G.gamma_tilde [s], D, st);
303+ Boundary::apply (G.gamma_tilde_inv [s], D, st);
304+ Boundary::apply (G.A_tilde [s], D, st);
305+ Boundary::apply (G.K_ij [s], D, st);
306+ }
307+ }
308+ } // namespace tensorium_RG
0 commit comments