Skip to content

Commit 8020183

Browse files
committed
overload basic maths operators for matrices ans vectors
1 parent 039b5fb commit 8020183

File tree

4 files changed

+38
-166
lines changed

4 files changed

+38
-166
lines changed

Tests/benchmarks/bench.cpp

Lines changed: 0 additions & 163 deletions
This file was deleted.

includes/Tensorium/Core/Matrix.hpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,37 @@ template <typename K, bool RowMajor = false> class Matrix {
455455
}
456456

457457
return r;
458-
}
458+
}
459+
460+
Matrix& operator+=(const Matrix& m) { this->add(m); return *this; }
461+
Matrix& operator-=(const Matrix& m) { this->sub(m); return *this; }
462+
Matrix& operator*=(K alpha) { this->scl(alpha); return *this; }
459463
};
464+
template<typename K, bool RM>
465+
Matrix<K, RM> operator+(const Matrix<K, RM>& a, const Matrix<K, RM>& b) {
466+
Matrix<K, RM> res = a;
467+
res.add(b);
468+
return res;
469+
}
470+
template<typename K, bool RM>
471+
Matrix<K, RM> operator-(const Matrix<K, RM>& a, const Matrix<K, RM>& b) {
472+
Matrix<K, RM> res = a;
473+
res.sub(b);
474+
return res;
475+
}
476+
template<typename K, bool RM>
477+
Matrix<K, RM> operator*(const Matrix<K, RM>& a, const Matrix<K, RM>& b) {
478+
return a._mul_mat(b);
479+
}
480+
template<typename K, bool RM>
481+
Matrix<K, RM> operator*(const Matrix<K, RM>& m, K alpha) {
482+
Matrix<K, RM> res = m;
483+
res.scl(alpha);
484+
return res;
485+
}
486+
template<typename K, bool RM>
487+
Matrix<K, RM> operator*(K alpha, const Matrix<K, RM>& m) {
488+
return m * alpha;
489+
}
490+
460491
} // namespace tensorium

includes/Tensorium/Core/MatrixKernels/GemmKernel_bigger.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,13 @@
1111
* sizes
1212
*
1313
*/
14-
1514
namespace tensorium {
1615
template <typename T> class GemmKernelBigger {
1716
public:
1817
using Simd = simd::SimdTraits<T, DefaultISA>;
1918
using reg = typename Simd::reg;
2019
static constexpr int SimdWidth = Simd::width;
21-
static constexpr int TileRows = SimdWidth * 2;
20+
static constexpr int TileRows = SimdWidth * 4;
2221
static constexpr int TileCols = 6;
2322
static constexpr int NThreads = 16;
2423

includes/Tensorium/Core/Vector.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,11 @@ template <typename K> class Vector {
447447

448448
return r;
449449
}
450+
451+
Vector<K>& operator+=(const Vector<K>& m) { this->add(m); return *this; }
452+
Vector<K>& operator-=(const Vector<K>& m) { this->sub(m); return *this; }
453+
Vector<K>& operator*=(K alpha) { this->scl(alpha); return *this; }
454+
450455
};
451456

452457
template <typename K> inline Vector<K> operator+(const Vector<K> &a, const Vector<K> &b) {

0 commit comments

Comments
 (0)