|
8 | 8 | #include "Vector.hpp" |
9 | 9 | #include <cassert> |
10 | 10 | #include <cmath> |
11 | | -#include <immintrin.h> |
12 | 11 | #include <iostream> |
13 | 12 | #include <vector> |
14 | 13 |
|
@@ -209,25 +208,51 @@ template <typename K, bool RowMajor = false> class Matrix { |
209 | 208 | * Uses blocking and micro-kernels to reduce cache bottlenecks and to distribute work across the FMA/AVX units. |
210 | 209 | * Fast-paths exist for 4×4, 8×8, and 16×16. |
211 | 210 | */ |
| 211 | + |
212 | 212 | inline Matrix _mul_mat(const Matrix<K> &mat) const { |
213 | 213 | if (cols != mat.rows) |
214 | 214 | throw std::invalid_argument("Matrix dimensions do not match for multiplication"); |
215 | 215 |
|
216 | 216 | Matrix<K> result(rows, mat.cols); |
217 | 217 |
|
218 | | - const K *A = data.data(); // Already column-major (this) |
219 | | - const K *B = mat.data.data(); // Already column-major (rhs) |
220 | | - K *C = result.data.data(); // Output (also column-major) |
| 218 | + const K *A = data.data(); // column-major (this) |
| 219 | + const K *B = mat.data.data(); // column-major (rhs) |
| 220 | + K *C = result.data.data(); // column-major output |
221 | 221 |
|
| 222 | +#if defined(TENSORIUM_X86) |
| 223 | + // SIMD kernel for x86 (AVX2 / AVX512) |
222 | 224 | tensorium::GemmKernelBigger<K> kernel; |
223 | 225 | kernel.matmul(const_cast<K *>(A), const_cast<K *>(B), C, |
224 | 226 | static_cast<int>(rows), // M |
225 | 227 | static_cast<int>(mat.cols), // N |
226 | | - static_cast<int>(cols) // K |
227 | | - ); |
| 228 | + static_cast<int>(cols)); // K |
| 229 | + |
| 230 | +#elif defined(TENSORIUM_ARM) |
| 231 | + // Temporary fallback (naïve scalar matmul) |
| 232 | + for (size_t i = 0; i < rows; ++i) { |
| 233 | + for (size_t j = 0; j < mat.cols; ++j) { |
| 234 | + K sum = static_cast<K>(0); |
| 235 | + for (size_t k = 0; k < cols; ++k) |
| 236 | + sum += A[i + k * rows] * B[k + j * mat.rows]; |
| 237 | + C[i + j * rows] = sum; |
| 238 | + } |
| 239 | + } |
| 240 | + |
| 241 | +#else |
| 242 | + // Generic scalar fallback |
| 243 | + for (size_t i = 0; i < rows; ++i) { |
| 244 | + for (size_t j = 0; j < mat.cols; ++j) { |
| 245 | + K sum = static_cast<K>(0); |
| 246 | + for (size_t k = 0; k < cols; ++k) |
| 247 | + sum += A[i + k * rows] * B[k + j * mat.rows]; |
| 248 | + C[i + j * rows] = sum; |
| 249 | + } |
| 250 | + } |
| 251 | +#endif |
228 | 252 |
|
229 | 253 | return result; |
230 | 254 | } |
| 255 | + |
231 | 256 | /** |
232 | 257 | * @brief Multiply matrix by a vector using SIMD |
233 | 258 | * |
@@ -455,37 +480,44 @@ template <typename K, bool RowMajor = false> class Matrix { |
455 | 480 | } |
456 | 481 |
|
457 | 482 | return r; |
458 | | - } |
| 483 | + } |
459 | 484 |
|
460 | | - Matrix& operator+=(const Matrix& m) { this->add(m); return *this; } |
461 | | - Matrix& operator-=(const Matrix& m) { this->sub(m); return *this; } |
462 | | - Matrix& operator*=(K alpha) { this->scl(alpha); return *this; } |
| 485 | + Matrix &operator+=(const Matrix &m) { |
| 486 | + this->add(m); |
| 487 | + return *this; |
| 488 | + } |
| 489 | + Matrix &operator-=(const Matrix &m) { |
| 490 | + this->sub(m); |
| 491 | + return *this; |
| 492 | + } |
| 493 | + Matrix &operator*=(K alpha) { |
| 494 | + this->scl(alpha); |
| 495 | + return *this; |
| 496 | + } |
463 | 497 | }; |
464 | | -template<typename K, bool RM> |
465 | | -Matrix<K, RM> operator+(const Matrix<K, RM>& a, const Matrix<K, RM>& b) { |
466 | | - Matrix<K, RM> res = a; |
467 | | - res.add(b); |
468 | | - return res; |
| 498 | +template <typename K, bool RM> |
| 499 | +Matrix<K, RM> operator+(const Matrix<K, RM> &a, const Matrix<K, RM> &b) { |
| 500 | + Matrix<K, RM> res = a; |
| 501 | + res.add(b); |
| 502 | + return res; |
469 | 503 | } |
470 | | -template<typename K, bool RM> |
471 | | -Matrix<K, RM> operator-(const Matrix<K, RM>& a, const Matrix<K, RM>& b) { |
472 | | - Matrix<K, RM> res = a; |
473 | | - res.sub(b); |
474 | | - return res; |
| 504 | +template <typename K, bool RM> |
| 505 | +Matrix<K, RM> operator-(const Matrix<K, RM> &a, const Matrix<K, RM> &b) { |
| 506 | + Matrix<K, RM> res = a; |
| 507 | + res.sub(b); |
| 508 | + return res; |
475 | 509 | } |
476 | | -template<typename K, bool RM> |
477 | | -Matrix<K, RM> operator*(const Matrix<K, RM>& a, const Matrix<K, RM>& b) { |
478 | | - return a._mul_mat(b); |
| 510 | +template <typename K, bool RM> |
| 511 | +Matrix<K, RM> operator*(const Matrix<K, RM> &a, const Matrix<K, RM> &b) { |
| 512 | + return a._mul_mat(b); |
479 | 513 | } |
480 | | -template<typename K, bool RM> |
481 | | -Matrix<K, RM> operator*(const Matrix<K, RM>& m, K alpha) { |
482 | | - Matrix<K, RM> res = m; |
483 | | - res.scl(alpha); |
484 | | - return res; |
| 514 | +template <typename K, bool RM> Matrix<K, RM> operator*(const Matrix<K, RM> &m, K alpha) { |
| 515 | + Matrix<K, RM> res = m; |
| 516 | + res.scl(alpha); |
| 517 | + return res; |
485 | 518 | } |
486 | | -template<typename K, bool RM> |
487 | | -Matrix<K, RM> operator*(K alpha, const Matrix<K, RM>& m) { |
488 | | - return m * alpha; |
| 519 | +template <typename K, bool RM> Matrix<K, RM> operator*(K alpha, const Matrix<K, RM> &m) { |
| 520 | + return m * alpha; |
489 | 521 | } |
490 | 522 |
|
491 | 523 | } // namespace tensorium |
0 commit comments