#include <algorithm>
#include <cstdlib>
#include <cstring>
+ #include <omp.h>
+ #include <sched.h> // cpu_set_t / CPU_SET / sched_setaffinity used in matmul() below
/*
 * This GEMM kernel is based on Aman Salykov's version, with improved OMP scheduling and block
 * sizes
@@ -20,7 +21,7 @@ template <typename T> class GemmKernelBigger {
    static constexpr int SimdWidth = Simd::width;
    static constexpr int TileRows = SimdWidth * 4;
    static constexpr int TileCols = 6;
-   static constexpr int NThreads = 72;
+   static constexpr int NThreads = 36;

    // static constexpr int BlockDepth = 256;
    // static constexpr int BlockRows = 384;
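Dropping NThreads from 72 to 36 reads as a switch from one thread per hyperthread to one per physical core on a machine exposing 72 hardware threads over 36 cores; FMA-bound GEMM kernels rarely gain from SMT. A minimal sketch of deriving that count at runtime instead of hardcoding it, assuming 2-way SMT (the helper name is hypothetical, not part of this commit):

#include <omp.h>

// Illustrative only: one OpenMP thread per physical core under assumed 2-way SMT.
inline int one_thread_per_core() {
    int logical = omp_get_num_procs();      // logical CPUs visible to the runtime
    return logical >= 2 ? logical / 2 : 1;  // halve for 2-way hyperthreading
}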
@@ -196,7 +197,7 @@ template <typename T> class GemmKernelBigger {
    }

    inline static void build_masks(__m256i *packed_mask_0, __m256i *packed_mask_1, int mr) {
- #if defined(__AVX512F__)
+ # if defined(__AVX512F__)
        __m128i m0 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(&mask[32 - mr]));
        __m128i m1 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(&mask[32 - mr + 16]));

@@ -206,15 +207,15 @@ template <typename T> class GemmKernelBigger {
        *packed_mask_0 = _mm512_castsi512_si256(p0);
        *packed_mask_1 = _mm512_castsi512_si256(p1);

- #elif defined(__AVX2__)
+ # elif defined(__AVX2__)
        __m128i m0 = _mm_loadl_epi64(reinterpret_cast<const __m128i *>(&mask[16 - mr]));
        __m128i m1 = _mm_loadl_epi64(reinterpret_cast<const __m128i *>(&mask[16 - mr + 8]));

        *packed_mask_0 = _mm256_cvtepi8_epi32(m0);
        *packed_mask_1 = _mm256_cvtepi8_epi32(m1);
- #else
- # error "AVX2 or AVX-512 required"
- #endif
+ # else
+ #  error "AVX2 or AVX-512 required"
+ # endif
    }

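The AVX2 branch above is the classic sliding-mask-table trick: reading a table of 16 set bytes followed by 16 zero bytes at offset 16 - mr yields exactly mr active lanes, widened to a per-lane dword mask. A self-contained sketch of the same idea, assuming float data and a table of the shape those loads imply (the table contents are an assumption, not copied from this file):

#include <immintrin.h>
#include <cstdint>

// Assumed layout: 16 x 0xFF bytes, then 16 x 0x00 (mirrors the &mask[16 - mr] loads).
static const int8_t tail_mask[32] = {-1, -1, -1, -1, -1, -1, -1, -1,
                                     -1, -1, -1, -1, -1, -1, -1, -1,
                                     0,  0,  0,  0,  0,  0,  0,  0,
                                     0,  0,  0,  0,  0,  0,  0,  0};

// Load the first mr floats of src (0 <= mr <= 8); the remaining lanes read as 0.
static inline __m256 load_tail(const float *src, int mr) {
    __m128i bytes = _mm_loadl_epi64(reinterpret_cast<const __m128i *>(&tail_mask[16 - mr]));
    __m256i lanes = _mm256_cvtepi8_epi32(bytes); // sign-extend int8 -> int32 per lane
    return _mm256_maskload_ps(src, lanes);       // compile with -mavx2
}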
    inline void maskload_accum_00(T *C, reg *C_accum_00, reg *C_accum_01, __m256i packed_mask_0,
@@ -507,8 +508,8 @@ template <typename T> class GemmKernelBigger {
        Simd::maskstore(&C[5 * M + 8], packed_mask_1, *C_accum_51);
    }

-   void kernel_16x6_load_accum(T* __restrict blockA_packed, T* __restrict blockB_packed, T* __restrict C, int mr, int nr, int kc,
-                               int M) {
+   inline void kernel_16x6_load_accum(T *__restrict blockA_packed, T *__restrict blockB_packed,
+                                      T *__restrict C, int mr, int nr, int kc, int M) {
        reg C_accum_00 = {};
        reg C_accum_01 = {};
        reg C_accum_10 = {};
@@ -651,8 +652,9 @@ template <typename T> class GemmKernelBigger {
            }
        }

-   void kernel_16x6_zero_init_accum(T* __restrict blockA_packed, T* __restrict blockB_packed, T* __restrict C, int mr, int nr,
-                                    int kc, int M) {
+   inline void kernel_16x6_zero_init_accum(T *__restrict blockA_packed,
+                                           T *__restrict blockB_packed, T *__restrict C, int mr,
+                                           int nr, int kc, int M) {
        reg C_accum_00 = {};
        reg C_accum_01 = {};
        reg C_accum_10 = {};
@@ -769,25 +771,25 @@ template <typename T> class GemmKernelBigger {
            }
        }

- #ifndef NTHREADS
- # define NTHREADS 36
- #endif
+ # ifndef NTHREADS
+ #  define NTHREADS 36
+ # endif

- #define MC (16 * (40 / NTHREADS) * NTHREADS)
- #define NC (6 * (800 / NTHREADS) * NTHREADS)
- #define KC 500
+ # define KC 512
+ # define MC 384
+ # define NC 4096

- #ifndef OMP_SCHEDULE
- # define OMP_SCHEDULE auto
- #endif
- #define _min(x, y) ((x) < (y) ? (x) : (y))
- #define PRAGMA_OMP_PARALLEL_FOR \
-     _Pragma("omp parallel for schedule(OMP_SCHEDULE) num_threads(NTHREADS)")
+ # ifndef OMP_SCHEDULE
+ #  define OMP_SCHEDULE dynamic
+ # endif
+ # define _min(x, y) ((x) < (y) ? (x) : (y))
+ # define PRAGMA_OMP_PARALLEL_FOR \
+     _Pragma("omp parallel for schedule(OMP_SCHEDULE) num_threads(NTHREADS)")

    static T blockA_packed[MC * KC] __attribute__((aligned(64)));
    static T blockB_packed[NC * KC] __attribute__((aligned(64)));

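The rewritten blocking swaps the NTHREADS-derived MC/NC for fixed GotoBLAS-style sizes, presumably chosen so the packed A block stays L2-resident per micro-panel and the packed B block fits in the shared L3. For float that puts blockA_packed at 768 KiB and blockB_packed at 8 MiB; a quick footprint check, purely illustrative:

#include <cstdio>

int main() {
    constexpr size_t KC = 512, MC = 384, NC = 4096;
    // MC*KC floats for the packed A block, NC*KC floats for the packed B block.
    std::printf("blockA_packed: %zu KiB\n", MC * KC * sizeof(float) / 1024);
    std::printf("blockB_packed: %zu KiB\n", NC * KC * sizeof(float) / 1024);
}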
-   void pack_panelB(T *B, T *blockB_packed, int nr, int kc, int K) {
+   inline void pack_panelB(T *B, T *blockB_packed, int nr, int kc, int K) {
        for (int p = 0; p < kc; p++) {
            for (int j = 0; j < nr; j++) {
                *blockB_packed++ = B[j * K + p];
@@ -799,14 +801,14 @@ template <typename T> class GemmKernelBigger {
    }

    void pack_blockB(T *B, T *blockB_packed, int nc, int kc, int K) {
- #pragma omp for schedule(dynamic)
+ # pragma omp for schedule(dynamic)
        for (int j = 0; j < nc; j += 6) {
            int nr = _min(6, nc - j);
            pack_panelB(&B[j * K], &blockB_packed[j * kc], nr, kc, K);
        }
    }

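For reference, pack_panelB emits nr consecutive B values per k step, so the packed stream is exactly the 6-wide micro-panel the 16x6 kernel walks linearly. A toy demo of the resulting layout (float specialization, made-up sizes):

#include <cstdio>

// Same loop structure as pack_panelB above, specialized to float for the demo.
void pack_panel(const float *B, float *out, int nr, int kc, int K) {
    for (int p = 0; p < kc; p++)
        for (int j = 0; j < nr; j++)
            *out++ = B[j * K + p];
}

int main() {
    const int K = 4; // column height of B in this toy example
    float B[3 * 4] = {0, 1, 2, 3, 10, 11, 12, 13, 20, 21, 22, 23};
    float packed[6];
    pack_panel(B, packed, /*nr=*/3, /*kc=*/2, K);
    for (float v : packed) std::printf("%g ", v); // prints: 0 10 20 1 11 21
    std::printf("\n");
}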
-   void pack_panelA(T *A, T *blockA_packed, int mr, int kc, int M) {
+   inline void pack_panelA(T *A, T *blockA_packed, int mr, int kc, int M) {
        for (int p = 0; p < kc; p++) {
            for (int i = 0; i < mr; i++) {
                *blockA_packed++ = A[p * M + i];
@@ -817,15 +819,22 @@ template <typename T> class GemmKernelBigger {
        }
    }

-   void pack_blockA(T *A, T *blockA_packed, int mc, int kc, int M) {
+   inline void pack_blockA(T *A, T *blockA_packed, int mc, int kc, int M) {
        PRAGMA_OMP_PARALLEL_FOR
        for (int i = 0; i < mc; i += 16) {
            int mr = _min(16, mc - i);
            pack_panelA(&A[i], &blockA_packed[i * kc], mr, kc, M);
        }
    }
-
    void matmul(T *A, T *B, T *C, int M, int N, int K) {
+ # pragma omp parallel
+       {
+           int tid = omp_get_thread_num();
+           cpu_set_t cpuset;
+           CPU_ZERO(&cpuset);
+           CPU_SET(tid % 36, &cpuset);
+           sched_setaffinity(0, sizeof(cpuset), &cpuset);
+       }
        for (int j = 0; j < N; j += NC) {
            int nc = _min(NC, N - j);
            int kc = _min(KC, K);
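The new parallel region at the top of matmul pins OpenMP worker i to core i % 36 (hardcoded, matching NThreads) before the blocked loops run, so each thread keeps its packed panels hot in its own core's caches; note it depends on the Linux-only cpu_set_t / sched_setaffinity API from <sched.h>. A standalone sketch of the same pinning idea (assumes the glibc CPU_* macros; the core count is hardcoded as in the commit):

#include <omp.h>
#include <sched.h>
#include <cstdio>

int main() {
    const int nthreads = 36; // mirrors NThreads above
#pragma omp parallel num_threads(nthreads)
    {
        int tid = omp_get_thread_num();
        cpu_set_t cpuset;
        CPU_ZERO(&cpuset);
        CPU_SET(tid % nthreads, &cpuset);
        sched_setaffinity(0, sizeof(cpuset), &cpuset); // 0 = the calling thread
#pragma omp critical
        std::printf("thread %d -> cpu %d\n", tid, sched_getcpu());
    }
}

Because the OpenMP thread pool persists across parallel regions, the pinning done here carries over to the later worksharing loops; OMP_PROC_BIND=close with OMP_PLACES=cores would be the portable way to get the same effect.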