@@ -71,22 +71,28 @@ namespace ggml_cuda_mma {
     // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,
     // effectively the warp is being split into subgroups of threads that each perform a single mma instruction.
     // In those cases the data can be split in different ways across the warp.
-    enum data_split {
-        DATA_SPLIT_NONE     =  0, // Each data value is held exactly once per warp (always applies to Turing, Ampere, Ada Lovelace, consumer Blackwell).
-        DATA_SPLIT_MIRRORED = 10, // Each data value is held exactly once per subgroup.
+    enum data_layout {
+        // By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
+        // For the A/C matrices this means I major == row major, J major == column major.
+        // For the B matrix this means I major == column major, J major == row major.
+        // MIRRORED == Each data value is held exactly once per thread subgroup.
+        DATA_LAYOUT_I_MAJOR          =  0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
+        DATA_LAYOUT_I_MAJOR_MIRRORED = 10,
+        DATA_LAYOUT_J_MAJOR_MIRRORED = 20,
     };
     // Implemented mma combinations are:
-    // - (NONE,    NONE)              -> NONE
-    // - (NONE,    MIRRORED)          -> NONE
+    // - (I_MAJOR, I_MAJOR)          -> I_MAJOR
+    // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
+    // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR

-    template <int I_, int J_, typename T, data_split ds_=DATA_SPLIT_NONE, bool transposed=false>
+    template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
     struct tile {};

     template <int I_, int J_, typename T>
-    struct tile<I_, J_, T, DATA_SPLIT_NONE, false> {
-        static constexpr int I = I_;
-        static constexpr int J = J_;
-        static constexpr data_split ds = DATA_SPLIT_NONE;
+    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;

 #if defined(AMD_MFMA_AVAILABLE)
         static constexpr int ne = I * J / 64;
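Reviewer note (not part of the patch): the per-thread element count ne follows directly from the layout, since an I-major tile is held once per 32-thread warp while a mirrored tile is held once per 8-thread subgroup. A minimal standalone sketch, using the tile shapes and formulas taken from the hunks above and below:

// Standalone illustration of the ne formulas used by the tile specializations:
// an I-major tile holds one copy of the data per 32-thread warp, a mirrored
// tile holds one copy per 8-thread subgroup (i.e. it is replicated 4x per warp).
#include <cstdio>

int main() {
    constexpr int WARP_SIZE = 32;

    constexpr int ne_A = 32 * 4 / WARP_SIZE;      // tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>     -> 4 half2 per thread
    constexpr int ne_B = 8 * 4 / (WARP_SIZE / 4); // tile<8, 4, half2, DATA_LAYOUT_*_MIRRORED>   -> 4 half2 per thread

    std::printf("ne_A = %d, ne_B = %d\n", ne_A, ne_B);
    return 0;
}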
@@ -242,10 +248,10 @@ namespace ggml_cuda_mma {
     };

     template <int I_, int J_>
-    struct tile<I_, J_, half2, DATA_SPLIT_NONE, false> {
-        static constexpr int I = I_;
-        static constexpr int J = J_;
-        static constexpr data_split ds = DATA_SPLIT_NONE;
+    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;

 #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
         static constexpr int ne = I * J / WARP_SIZE;
@@ -349,11 +355,11 @@ namespace ggml_cuda_mma {
     };

     template <int I_, int J_>
-    struct tile<I_, J_, nv_bfloat162, DATA_SPLIT_NONE, false> {
-        static constexpr int I = I_;
-        static constexpr int J = J_;
-        static constexpr data_split ds = DATA_SPLIT_NONE;
-        static constexpr int ne = I * J / WARP_SIZE;
+    struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
+        static constexpr int ne = I * J / WARP_SIZE;

         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

@@ -417,11 +423,11 @@ namespace ggml_cuda_mma {
     };

     template <int I_, int J_>
-    struct tile<I_, J_, half2, DATA_SPLIT_MIRRORED, false> {
-        static constexpr int I = I_;
-        static constexpr int J = J_;
-        static constexpr data_split ds = DATA_SPLIT_MIRRORED;
-        static constexpr int ne = I * J / (WARP_SIZE/4);
+    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+        static constexpr int ne = I * J / (WARP_SIZE/4);

         half2 x[ne] = {{0.0f, 0.0f}};

@@ -450,11 +456,11 @@ namespace ggml_cuda_mma {
     };

     template <int I_, int J_>
-    struct tile<I_, J_, half2, DATA_SPLIT_MIRRORED, true> {
-        static constexpr int I = I_;
-        static constexpr int J = J_;
-        static constexpr data_split ds = DATA_SPLIT_MIRRORED;
-        static constexpr int ne = I * J / (WARP_SIZE/4);
+    struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
+        static constexpr int ne = I * J / (WARP_SIZE/4);

         half2 x[ne] = {{0.0f, 0.0f}};

@@ -518,8 +524,8 @@ namespace ggml_cuda_mma {
     }
 #endif // defined(TURING_MMA_AVAILABLE)

-    template <int I, int J, typename T, data_split ds, bool transposed>
-    static __device__ __forceinline__ void load_generic(tile<I, J, T, ds, transposed> & t, const T * __restrict__ xs0, const int stride) {
+    template <int I, int J, typename T, data_layout dl>
+    static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
 #if defined(AMD_MFMA_AVAILABLE)
         if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
 #pragma unroll
@@ -622,12 +628,12 @@ namespace ggml_cuda_mma {
     }

     static __device__ __forceinline__ void load_ldmatrix(
-            tile<8, 4, half2, DATA_SPLIT_MIRRORED, false> & t, const half2 * __restrict__ xs0, const int stride) {
+            tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
         ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
     }

     static __device__ __forceinline__ void load_ldmatrix(
-            tile<8, 4, half2, DATA_SPLIT_MIRRORED, true> & t, const half2 * __restrict__ xs0, const int stride) {
+            tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
 #pragma unroll
         for (int l0 = 0; l0 < t.ne; l0 += 2) {
             ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));
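Reviewer note (not part of the patch): folding the old transposed flag into data_layout means the layout is part of the tile's type, so the matching load is selected by ordinary overload resolution at the call site. A minimal sketch, assuming device code inside the ggml_cuda_mma namespace; the helper, pointer, and stride names are hypothetical:

// Hypothetical call site: the same load_ldmatrix name dispatches to the
// I-major or J-major mirrored overload based on the tile's layout argument.
static __device__ __forceinline__ void load_both_layouts(
        const half2 * __restrict__ src, const int stride) {
    tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> t_i; // 4 contiguous half2 per thread
    tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> t_j; // 2 half2 at a time, strided

    load_ldmatrix(t_i, src, stride); // first overload above
    load_ldmatrix(t_j, src, stride); // second overload above
}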
@@ -972,7 +978,7 @@ namespace ggml_cuda_mma {
     }

     static __device__ __forceinline__ void mma(
-            tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_SPLIT_MIRRORED, false> & B) {
+            tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
 #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -992,7 +998,7 @@ namespace ggml_cuda_mma {
     }

     static __device__ __forceinline__ void mma(
-            tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_SPLIT_MIRRORED, true> & B) {
+            tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
 #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
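Reviewer note (not part of the patch): a sketch of how the pieces fit together for the (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR combination listed at the top of the diff. The helper and its parameters are hypothetical, it assumes device code inside the ggml_cuda_mma namespace, and the mma body shown above is only filled in for Volta:

// Hypothetical device helper composed from the types/functions in this diff.
static __device__ __forceinline__ void example_mma_step(
        const half2 * __restrict__ A_src, const int stride_A,
        const half2 * __restrict__ B_src, const int stride_B,
        tile<32, 8, float> & acc) {
    tile<32, 4, half2>                              A; // I-major, one copy per warp
    tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> B; // one copy per thread subgroup

    load_generic (A, A_src, stride_A);
    load_ldmatrix(B, B_src, stride_B);

    mma(acc, A, B); // accumulates the product into the I-major accumulator tile
}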