Commit c1f580e (1 parent: f383510)

does not compile atm, need to fix some issues

File tree: 6 files changed, +142 −50 lines

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -49,7 +49,6 @@ if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64" OR CMAKE_SYSTEM_PROCESSOR MATCHES "am
 
 elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "arm64" OR CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
     message(STATUS "Configuring for Apple Silicon ARM64: disabling AVX flags")
-    # Apple Clang has no AVX/FMA support → stick with the generic ARM optimizations
     set(CMAKE_CXX_FLAGS "-O3 -mcpu=apple-m1 -Wno-ignored-attributes")
 else()
     message(WARNING "Unknown architecture (${CMAKE_SYSTEM_PROCESSOR}); using generic optimization flags.")
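
Note: the Matrix.hpp hunk below branches on TENSORIUM_X86 / TENSORIUM_ARM, but this commit never shows where those macros get defined (plausibly part of the "does not compile atm" in the message). A minimal sketch of how the CMake branches above could export them (the wiring is assumed, not taken from the repo):

# Hypothetical wiring: export the architecture macros the C++ code tests for.
if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|amd64")
    add_compile_definitions(TENSORIUM_X86)
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "arm64|aarch64")
    add_compile_definitions(TENSORIUM_ARM)
endif()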

Tests/test.hpp

Lines changed: 0 additions & 1 deletion
@@ -5,7 +5,6 @@
 #include <cmath>
 #include <vector>
 #include <chrono>
-#include <immintrin.h>
 #include <cassert>
 #include <cstdlib>
 #include <cstring>

includes/Tensorium/Core/Derivate.hpp

Lines changed: 0 additions & 1 deletion
@@ -8,7 +8,6 @@
 #include "Vector.hpp"
 #include <cassert>
 #include <cmath>
-#include <immintrin.h>
 #include <iostream>
 #include <numeric>
 #include <vector>

includes/Tensorium/Core/Matrix.hpp

Lines changed: 63 additions & 31 deletions
@@ -8,7 +8,6 @@
 #include "Vector.hpp"
 #include <cassert>
 #include <cmath>
-#include <immintrin.h>
 #include <iostream>
 #include <vector>

@@ -209,25 +208,51 @@ template <typename K, bool RowMajor = false> class Matrix {
  * Uses blocking and micro-kernels to avoid cache bottleneck with FMA/AVX units repartition.
  * Fast-paths exist for 4×4, 8×8, and 16×16.
  */
+
 inline Matrix _mul_mat(const Matrix<K> &mat) const {
     if (cols != mat.rows)
         throw std::invalid_argument("Matrix dimensions do not match for multiplication");
 
     Matrix<K> result(rows, mat.cols);
 
-    const K *A = data.data();     // Already column-major (this)
-    const K *B = mat.data.data(); // Already column-major (rhs)
-    K *C = result.data.data();    // Output (also column-major)
+    const K *A = data.data();     // column-major (this)
+    const K *B = mat.data.data(); // column-major (rhs)
+    K *C = result.data.data();    // column-major output
 
+#if defined(TENSORIUM_X86)
+    // SIMD kernel for x86 (AVX2 / AVX512)
     tensorium::GemmKernelBigger<K> kernel;
     kernel.matmul(const_cast<K *>(A), const_cast<K *>(B), C,
                   static_cast<int>(rows),     // M
                   static_cast<int>(mat.cols), // N
-                  static_cast<int>(cols)      // K
-    );
+                  static_cast<int>(cols));    // K
+
+#elif defined(TENSORIUM_ARM)
+    // Temporary fallback (naïve scalar matmul)
+    for (size_t i = 0; i < rows; ++i) {
+        for (size_t j = 0; j < mat.cols; ++j) {
+            K sum = static_cast<K>(0);
+            for (size_t k = 0; k < cols; ++k)
+                sum += A[i + k * rows] * B[k + j * mat.rows];
+            C[i + j * rows] = sum;
+        }
+    }
+
+#else
+    // Generic scalar fallback
+    for (size_t i = 0; i < rows; ++i) {
+        for (size_t j = 0; j < mat.cols; ++j) {
+            K sum = static_cast<K>(0);
+            for (size_t k = 0; k < cols; ++k)
+                sum += A[i + k * rows] * B[k + j * mat.rows];
+            C[i + j * rows] = sum;
+        }
+    }
+#endif
 
     return result;
 }
+
 /**
  * @brief Multiply matrix by a vector using SIMD
  *
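
For reference, the indexing in the scalar fallback follows column-major storage: element (i, j) of an M-row matrix sits at data[i + j * M]. A standalone sketch of the same triple loop, with a tiny 2×2 check (helper name is illustrative, not from the repo):

#include <cassert>
#include <cstddef>

// Column-major GEMM, same loop structure as the fallback above:
// C (M x N) = A (M x K) * B (K x N); element (i, j) lives at base[i + j * rows].
template <typename T>
void naive_gemm_colmajor(const T *A, const T *B, T *C,
                         std::size_t M, std::size_t N, std::size_t K) {
    for (std::size_t i = 0; i < M; ++i)
        for (std::size_t j = 0; j < N; ++j) {
            T sum = T(0);
            for (std::size_t k = 0; k < K; ++k)
                sum += A[i + k * M] * B[k + j * K];
            C[i + j * M] = sum;
        }
}

int main() {
    // A = [[1,2],[3,4]] stored column-major; B = identity, so C must equal A.
    double A[4] = {1, 3, 2, 4};
    double B[4] = {1, 0, 0, 1};
    double C[4] = {0, 0, 0, 0};
    naive_gemm_colmajor(A, B, C, 2, 2, 2);
    assert(C[0] == 1 && C[1] == 3 && C[2] == 2 && C[3] == 4);
    return 0;
}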
@@ -455,37 +480,44 @@ template <typename K, bool RowMajor = false> class Matrix {
     }
 
     return r;
-    }
+}
 
-    Matrix& operator+=(const Matrix& m) { this->add(m); return *this; }
-    Matrix& operator-=(const Matrix& m) { this->sub(m); return *this; }
-    Matrix& operator*=(K alpha) { this->scl(alpha); return *this; }
+    Matrix &operator+=(const Matrix &m) {
+        this->add(m);
+        return *this;
+    }
+    Matrix &operator-=(const Matrix &m) {
+        this->sub(m);
+        return *this;
+    }
+    Matrix &operator*=(K alpha) {
+        this->scl(alpha);
+        return *this;
+    }
 };
-template<typename K, bool RM>
-Matrix<K, RM> operator+(const Matrix<K, RM>& a, const Matrix<K, RM>& b) {
-    Matrix<K, RM> res = a;
-    res.add(b);
-    return res;
+template <typename K, bool RM>
+Matrix<K, RM> operator+(const Matrix<K, RM> &a, const Matrix<K, RM> &b) {
+    Matrix<K, RM> res = a;
+    res.add(b);
+    return res;
 }
-template<typename K, bool RM>
-Matrix<K, RM> operator-(const Matrix<K, RM>& a, const Matrix<K, RM>& b) {
-    Matrix<K, RM> res = a;
-    res.sub(b);
-    return res;
+template <typename K, bool RM>
+Matrix<K, RM> operator-(const Matrix<K, RM> &a, const Matrix<K, RM> &b) {
+    Matrix<K, RM> res = a;
+    res.sub(b);
+    return res;
 }
-template<typename K, bool RM>
-Matrix<K, RM> operator*(const Matrix<K, RM>& a, const Matrix<K, RM>& b) {
-    return a._mul_mat(b);
+template <typename K, bool RM>
+Matrix<K, RM> operator*(const Matrix<K, RM> &a, const Matrix<K, RM> &b) {
+    return a._mul_mat(b);
 }
-template<typename K, bool RM>
-Matrix<K, RM> operator*(const Matrix<K, RM>& m, K alpha) {
-    Matrix<K, RM> res = m;
-    res.scl(alpha);
-    return res;
+template <typename K, bool RM> Matrix<K, RM> operator*(const Matrix<K, RM> &m, K alpha) {
+    Matrix<K, RM> res = m;
+    res.scl(alpha);
+    return res;
 }
-template<typename K, bool RM>
-Matrix<K, RM> operator*(K alpha, const Matrix<K, RM>& m) {
-    return m * alpha;
+template <typename K, bool RM> Matrix<K, RM> operator*(K alpha, const Matrix<K, RM> &m) {
+    return m * alpha;
 }
 
 } // namespace tensorium
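
A short usage sketch of the operator set above. The Matrix(rows, cols) constructor appears in the diff; how elements get filled is not shown in this commit, so that step is left as a comment:

#include "Tensorium/Core/Matrix.hpp" // path per this repo's include tree

int main() {
    using tensorium::Matrix;
    Matrix<float> A(64, 64), B(64, 64);
    // ... fill A and B (element access is not part of this commit) ...
    Matrix<float> C = A * B; // _mul_mat: GemmKernelBigger on TENSORIUM_X86,
                             // the scalar loops on TENSORIUM_ARM or unknown targets
    C += A;                  // in-place add via Matrix::add
    C *= 0.5f;               // in-place scale via Matrix::scl
    return 0;
}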

includes/Tensorium/Core/MatrixKernels/GemmKernel_bigger.hpp

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,8 @@
  * sizes
  *
  */
+#ifdef TENSORIUM_X86
+
 namespace tensorium {
 template <typename T> class GemmKernelBigger {
 public:

@@ -877,3 +879,4 @@ template <typename T> T GemmKernelBigger<T>::blockA_packed[MC * KC] __attribute_
 
 template <typename T> T GemmKernelBigger<T>::blockB_packed[NC * KC] __attribute__((aligned(64)));
 } // namespace tensorium
+#endif
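
With the whole kernel header fenced by TENSORIUM_X86, callers can include it unconditionally; a sketch of the consuming side (build-system macro assumed, as above):

#include "Tensorium/Core/MatrixKernels/GemmKernel_bigger.hpp"

#if defined(TENSORIUM_X86)
// tensorium::GemmKernelBigger<T> exists only in x86 builds; other
// targets fall through to the scalar paths in Matrix.hpp.
#endif

int main() { return 0; }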

includes/Tensorium/SIMD/SIMD.hpp

Lines changed: 76 additions & 16 deletions
@@ -27,6 +27,21 @@
 # define ALIGN 8
 #endif
 
+// Defaults if nothing else sets them further down
+#ifndef UNROLL
+# define UNROLL 4
+#endif
+#ifndef SIMD_WIDTH
+# define SIMD_WIDTH 4
+#endif
+#ifndef ALIGN
+# define ALIGN 16
+#endif
+
+// OpenMP: only pull in the include when the compiler enables it
+#ifdef _OPENMP
+# include <omp.h>
+#endif
 // disable x86-only prefetch macros on ARM
 #if !defined(__x86_64__)
 # define _MM_HINT_T0 0
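
The _OPENMP guard keeps the header buildable without -fopenmp. The same pattern at a use site, as a minimal sketch (not code from this repo):

#include <vector>
#ifdef _OPENMP
# include <omp.h> // only present when the compiler enables OpenMP
#endif

// The pragma is ignored when OpenMP is off, so both builds compile.
double sum(const std::vector<double> &v) {
    double s = 0.0;
#ifdef _OPENMP
#pragma omp parallel for reduction(+ : s)
#endif
    for (long i = 0; i < (long)v.size(); ++i)
        s += v[i];
    return s;
}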
@@ -1071,7 +1086,15 @@ template <typename F> inline void dispatch_simd(F &&f) { f(DefaultISA{}); }
 
 namespace simd {
 
-// ───────────────────────── float32 (neon32_t) ─────────────────────────
+static inline float32x4_t andnot_f32(float32x4_t a, float32x4_t b) {
+    uint32x4_t na = veorq_u32(vreinterpretq_u32_f32(a), vdupq_n_u32(~0u));
+    return vreinterpretq_f32_u32(vandq_u32(na, vreinterpretq_u32_f32(b)));
+}
+static inline float64x2_t andnot_f64(float64x2_t a, float64x2_t b) {
+    uint64x2_t na = veorq_u64(vreinterpretq_u64_f64(a), vdupq_n_u64(~0ULL));
+    return vreinterpretq_f64_u64(vandq_u64(na, vreinterpretq_u64_f64(b)));
+}
+
 template <> struct SimdTraits<float, neon32_t> {
     using reg = float32x4_t;
     static constexpr size_t width = 4;
@@ -1083,6 +1106,7 @@ template <> struct SimdTraits<float, neon32_t> {
     static inline reg loadu(const float *p) { return vld1q_f32(p); }
     static inline void store(float *p, reg v) { vst1q_f32(p, v); }
     static inline void storeu(float *p, reg v) { vst1q_f32(p, v); }
+    static inline void store_stream(float *p, reg v) { vst1q_f32(p, v); } // no NEON NT store
     static inline reg zero() { return vdupq_n_f32(0.f); }
     static inline reg add(reg a, reg b) { return vaddq_f32(a, b); }
     static inline reg sub(reg a, reg b) { return vsubq_f32(a, b); }
@@ -1092,6 +1116,17 @@ template <> struct SimdTraits<float, neon32_t> {
 # else
     static inline reg fmadd(reg a, reg b, reg c) { return vaddq_f32(c, vmulq_f32(a, b)); }
 # endif
+    static inline reg max(reg a, reg b) { return vmaxq_f32(a, b); }
+    static inline reg min(reg a, reg b) { return vminq_f32(a, b); }
+    static inline reg andnot(reg a, reg b) { return andnot_f32(a, b); }
+
+    // float (neon32_t)
+    static inline float extract(reg x, size_t idx) {
+        alignas(16) float t[4];
+        vst1q_f32(t, x);
+        return t[idx & 3];
+    }
+
     static inline float horizontal_add(reg v) {
         float32x2_t lo = vget_low_f32(v);
         float32x2_t hi = vget_high_f32(v);
@@ -1112,24 +1147,31 @@ template <> struct SimdTraits<double, neon64_t> {
     static inline reg loadu(const double *p) { return vld1q_f64(p); }
     static inline void store(double *p, reg v) { vst1q_f64(p, v); }
     static inline void storeu(double *p, reg v) { vst1q_f64(p, v); }
+    static inline void store_stream(double *p, reg v) { vst1q_f64(p, v); }
     static inline reg zero() { return vdupq_n_f64(0.0); }
     static inline reg add(reg a, reg b) { return vaddq_f64(a, b); }
     static inline reg sub(reg a, reg b) { return vsubq_f64(a, b); }
     static inline reg mul(reg a, reg b) { return vmulq_f64(a, b); }
 # if defined(__aarch64__)
-    static inline reg fmadd(reg a, reg b, reg c) { return vfmaq_f64(c, a, b); } // c + a*b
+    static inline reg fmadd(reg a, reg b, reg c) { return vfmaq_f64(c, a, b); }
 # else
     static inline reg fmadd(reg a, reg b, reg c) { return vaddq_f64(c, vmulq_f64(a, b)); }
 # endif
+    static inline reg max(reg a, reg b) { return vmaxq_f64(a, b); }
+    static inline reg min(reg a, reg b) { return vminq_f64(a, b); }
+    static inline reg andnot(reg a, reg b) { return andnot_f64(a, b); }
+    static inline double extract(reg x, size_t idx = 0) {
+        alignas(16) double t[2];
+        vst1q_f64(t, x);
+        return t[idx & 1];
+    }
     static inline double horizontal_add(reg v) {
         float64x1_t s = vadd_f64(vget_low_f64(v), vget_high_f64(v));
         return vget_lane_f64(s, 0);
     }
 };
 
-// ───────────────────────── 64-bit integer (simple fallback) ─────────────────────────
-// Stick to elementary operations (no portable 64x64 SIMD multiply in NEON).
-template <> struct SimdTraits<size_t, neon64_t> {
+template <> struct SimdTraits<size_t, neon32_t> {
     using reg = uint64x2_t;
     static constexpr size_t width = 2;
     static constexpr size_t alignment = 16;
@@ -1140,11 +1182,11 @@ template <> struct SimdTraits<size_t, neon64_t> {
     static inline reg loadu(const size_t *p) { return vld1q_u64((const uint64_t *)p); }
     static inline void store(size_t *p, reg v) { vst1q_u64((uint64_t *)p, v); }
     static inline void storeu(size_t *p, reg v) { vst1q_u64((uint64_t *)p, v); }
+    static inline void store_stream(size_t *p, reg v) { vst1q_u64((uint64_t *)p, v); }
     static inline reg zero() { return vdupq_n_u64(0); }
     static inline reg add(reg a, reg b) { return vaddq_u64(a, b); }
     static inline reg sub(reg a, reg b) { return vsubq_u64(a, b); }
-    // element-wise (scalar) mul to stay portable
-    static inline reg mul(reg a, reg b) {
+    static inline reg mul(reg a, reg b) {
         uint64_t A[2], B[2], R[2];
         vst1q_u64(A, a);
         vst1q_u64(B, b);
@@ -1161,10 +1203,19 @@ template <> struct SimdTraits<size_t, neon64_t> {
         R[1] = A[1] * B[1] + C[1];
         return vld1q_u64(R);
     }
+    static inline reg andnot(reg a, reg b) {
+        uint64x2_t na = veorq_u64(a, vdupq_n_u64(~0ULL));
+        return vandq_u64(na, b);
+    }
+    static inline size_t extract(reg x, size_t idx) {
+        alignas(16) uint64_t t[2];
+        vst1q_u64(t, x);
+        return (size_t)t[idx & 1];
+    }
     static inline uint64_t horizontal_add(reg v) {
-        uint64_t tmp[2];
-        vst1q_u64(tmp, v);
-        return tmp[0] + tmp[1];
+        alignas(16) uint64_t t[2];
+        vst1q_u64(t, v);
+        return t[0] + t[1];
     }
 };
 
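NEON has no 64×64-bit vector multiply (there is no vmulq_u64), which is why mul above round-trips through scalar arrays. The pattern in isolation, for an AArch64 target:

#include <arm_neon.h>
#include <cassert>
#include <cstdint>

// Store lanes, multiply in scalar code, reload, as in SimdTraits<size_t>::mul.
static uint64x2_t mul_u64(uint64x2_t a, uint64x2_t b) {
    uint64_t A[2], B[2], R[2];
    vst1q_u64(A, a);
    vst1q_u64(B, b);
    R[0] = A[0] * B[0];
    R[1] = A[1] * B[1];
    return vld1q_u64(R);
}

int main() {
    uint64_t t[2];
    vst1q_u64(t, mul_u64(vdupq_n_u64(3), vdupq_n_u64(7)));
    assert(t[0] == 21 && t[1] == 21);
    return 0;
}
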
@@ -1247,12 +1298,21 @@ template <> struct SimdTraits<std::complex<double>, neon64_t> {
 } // namespace simd
 
 namespace detail {
-template <typename Simd> static inline typename Simd::reg reduce_sum(typename Simd::reg v) {
-# if defined(__x86_64__)
-    return v;
-# else
-    return v;
-# endif
+inline float reduce_sum(float32x4_t v) { // float
+    float32x2_t lo = vget_low_f32(v);
+    float32x2_t hi = vget_high_f32(v);
+    float32x2_t s = vadd_f32(lo, hi);
+    s = vpadd_f32(s, s);
+    return vget_lane_f32(s, 0);
+}
+inline double reduce_sum(float64x2_t v) { // double
+    float64x1_t s = vadd_f64(vget_low_f64(v), vget_high_f64(v));
+    return vget_lane_f64(s, 0);
+}
+inline uint64_t reduce_sum(uint64x2_t v) { // integer
+    uint64_t t[2];
+    vst1q_u64(t, v);
+    return t[0] + t[1];
 }
 } // namespace detail
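
Finally, a self-contained check of the andnot and reduce_sum semantics added above; the bodies are copied here rather than including the header, so the sketch builds standalone on AArch64:

#include <arm_neon.h>
#include <cassert>

// andnot(a, b) = (~a) & b, mirroring x86 _mm_andnot_ps; with the sign-bit
// mask as `a` it clears signs, i.e. computes |b| lane-wise.
static float32x4_t andnot_f32(float32x4_t a, float32x4_t b) {
    uint32x4_t na = veorq_u32(vreinterpretq_u32_f32(a), vdupq_n_u32(~0u));
    return vreinterpretq_f32_u32(vandq_u32(na, vreinterpretq_u32_f32(b)));
}

static float reduce_sum(float32x4_t v) {
    float32x2_t s = vadd_f32(vget_low_f32(v), vget_high_f32(v));
    s = vpadd_f32(s, s);
    return vget_lane_f32(s, 0);
}

int main() {
    const float xs[4] = {-1.f, 2.f, -3.f, 4.f};
    float32x4_t x = vld1q_f32(xs);
    float32x4_t sign = vreinterpretq_f32_u32(vdupq_n_u32(0x80000000u));
    assert(reduce_sum(andnot_f32(sign, x)) == 10.f); // |x| summed: 1+2+3+4
    return 0;
}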
