Commit 13500e8

refactor template parameters
1 parent 301ae30 commit 13500e8

3 files changed: +53 -47 lines changed

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 6 additions & 6 deletions
@@ -798,12 +798,12 @@ template<> struct mma_tile_sizes<8> {
 };
 #else // Volta
 template<int ncols> struct mma_tile_sizes {
-    using T_A_KQ = tile< 8, 4, half2, DATA_SPLIT_MIRRORED, false>; // row-major
-    using T_B_KQ = tile<32, 4, half2, DATA_SPLIT_NONE, false>; // column-major
-    using T_C_KQ = tile<32, 8, float, DATA_SPLIT_NONE, false>; // column-major
-    using T_A_VKQ = tile< 8, 4, half2, DATA_SPLIT_MIRRORED, true>; // column-major
-    using T_B_VKQ = tile<32, 4, half2, DATA_SPLIT_NONE, false>; // column-major
-    using T_C_VKQ = tile<32, 4, half2, DATA_SPLIT_NONE, false>; // column-major
+    using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
+    using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
+    using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major
+    using T_A_VKQ = tile< 8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED>; // column-major
+    using T_B_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
+    using T_C_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
 };
 #endif // defined(TURING_MMA_AVAILABLE)
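Read as a whole, this hunk (like the rest of the commit) is a mechanical substitution: the former (data_split, transposed) template-parameter pair of tile is collapsed into a single data_layout value. A purely illustrative summary of that mapping, inferred from the replacements above rather than stated anywhere in the diff:

// Old template arguments                    -> new template argument
// tile<I, J, T, DATA_SPLIT_NONE,     false> -> tile<I, J, T, DATA_LAYOUT_I_MAJOR>
// tile<I, J, T, DATA_SPLIT_MIRRORED, false> -> tile<I, J, T, DATA_LAYOUT_I_MAJOR_MIRRORED>
// tile<I, J, T, DATA_SPLIT_MIRRORED, true>  -> tile<I, J, T, DATA_LAYOUT_J_MAJOR_MIRRORED>
//
// Example from this file (Volta A tile of the K*Q multiplication):
//   before: using T_A_KQ = tile< 8, 4, half2, DATA_SPLIT_MIRRORED, false>;
//   after:  using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>;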

ggml/src/ggml-cuda/mma.cuh

Lines changed: 41 additions & 35 deletions
@@ -71,22 +71,28 @@ namespace ggml_cuda_mma {
     // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,
     // effectively the warp is being split into subgroups of threads that each perform a single mma instruction.
     // In those cases the data can be split in different ways across the warp.
-    enum data_split {
-        DATA_SPLIT_NONE = 0, // Each data value is held exactly once per warp (always applies to Turing, Ampere, Ada Lovelace, consumer Blackwell).
-        DATA_SPLIT_MIRRORED = 10, // Each data value is held exactly once per subgroup.
+    enum data_layout {
+        // By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
+        // For the A/C matrices this means I major == row major, J major == column major.
+        // For the B matrix this means I major == column major, J major == row major.
+        // MIRRORED == Each data value is held exactly once per thread subgroup.
+        DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
+        DATA_LAYOUT_I_MAJOR_MIRRORED = 10,
+        DATA_LAYOUT_J_MAJOR_MIRRORED = 20,
     };
     // Implemented mma combinations are:
-    // - (NONE, NONE) -> NONE
-    // - (NONE, MIRRORED) -> NONE
+    // - (I_MAJOR, I_MAJOR) -> I_MAJOR
+    // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
+    // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR

-    template <int I_, int J_, typename T, data_split ds_=DATA_SPLIT_NONE, bool transposed=false>
+    template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
     struct tile {};

     template <int I_, int J_, typename T>
-    struct tile<I_, J_, T, DATA_SPLIT_NONE, false> {
-        static constexpr int I = I_;
-        static constexpr int J = J_;
-        static constexpr data_split ds = DATA_SPLIT_NONE;
+    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;

 #if defined(AMD_MFMA_AVAILABLE)
         static constexpr int ne = I * J / 64;
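The core of this hunk is the dispatch pattern: per-layout tile implementations are selected through partial specialization on a single enum parameter with a default, so existing three-argument uses such as tile<16, 8, T> keep resolving to the I-major variant. Below is a minimal, self-contained sketch of that pattern; it is a simplified stand-in, not the commit's code, and the real tiles carry architecture-specific element counts, storage, and index helpers.

#include <cstdio>

// Simplified stand-in for ggml_cuda_mma::data_layout (values as in the diff).
enum data_layout {
    DATA_LAYOUT_I_MAJOR          =  0,
    DATA_LAYOUT_I_MAJOR_MIRRORED = 10,
    DATA_LAYOUT_J_MAJOR_MIRRORED = 20,
};

// Primary template: the layout defaults to I-major, so call sites that do not
// care about the layout can keep writing tile<I, J, T>.
template <int I_, int J_, typename T, data_layout dl_ = DATA_LAYOUT_I_MAJOR>
struct tile {};

// One partial specialization per layout replaces the old (data_split, transposed)
// pair; only two of the three layouts are sketched here.
template <int I_, int J_, typename T>
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
    static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
};

template <int I_, int J_, typename T>
struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR_MIRRORED> {
    static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
};

int main() {
    // The default argument picks the I-major specialization; an explicit
    // argument selects a different one.
    static_assert(tile<16, 8, float>::dl == DATA_LAYOUT_I_MAJOR, "default layout");
    static_assert(tile<8, 4, float, DATA_LAYOUT_J_MAJOR_MIRRORED>::dl == DATA_LAYOUT_J_MAJOR_MIRRORED, "explicit layout");
    std::printf("layout dispatch ok\n");
    return 0;
}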
@@ -242,10 +248,10 @@ namespace ggml_cuda_mma {
     };

     template <int I_, int J_>
-    struct tile<I_, J_, half2, DATA_SPLIT_NONE, false> {
-        static constexpr int I = I_;
-        static constexpr int J = J_;
-        static constexpr data_split ds = DATA_SPLIT_NONE;
+    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;

 #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
         static constexpr int ne = I * J / WARP_SIZE;
@@ -349,11 +355,11 @@ namespace ggml_cuda_mma {
     };

     template <int I_, int J_>
-    struct tile<I_, J_, nv_bfloat162, DATA_SPLIT_NONE, false> {
-        static constexpr int I = I_;
-        static constexpr int J = J_;
-        static constexpr data_split ds = DATA_SPLIT_NONE;
-        static constexpr int ne = I * J / WARP_SIZE;
+    struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
+        static constexpr int ne = I * J / WARP_SIZE;

         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

@@ -417,11 +423,11 @@ namespace ggml_cuda_mma {
     };

     template <int I_, int J_>
-    struct tile<I_, J_, half2, DATA_SPLIT_MIRRORED, false> {
-        static constexpr int I = I_;
-        static constexpr int J = J_;
-        static constexpr data_split ds = DATA_SPLIT_MIRRORED;
-        static constexpr int ne = I * J / (WARP_SIZE/4);
+    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+        static constexpr int ne = I * J / (WARP_SIZE/4);

         half2 x[ne] = {{0.0f, 0.0f}};

@@ -450,11 +456,11 @@ namespace ggml_cuda_mma {
     };

     template <int I_, int J_>
-    struct tile<I_, J_, half2, DATA_SPLIT_MIRRORED, true> {
-        static constexpr int I = I_;
-        static constexpr int J = J_;
-        static constexpr data_split ds = DATA_SPLIT_MIRRORED;
-        static constexpr int ne = I * J / (WARP_SIZE/4);
+    struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
+        static constexpr int ne = I * J / (WARP_SIZE/4);

         half2 x[ne] = {{0.0f, 0.0f}};

@@ -518,8 +524,8 @@ namespace ggml_cuda_mma {
     }
 #endif // defined(TURING_MMA_AVAILABLE)

-    template <int I, int J, typename T, data_split ds, bool transposed>
-    static __device__ __forceinline__ void load_generic(tile<I, J, T, ds, transposed> & t, const T * __restrict__ xs0, const int stride) {
+    template <int I, int J, typename T, data_layout dl>
+    static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
 #if defined(AMD_MFMA_AVAILABLE)
         if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
 #pragma unroll
@@ -622,12 +628,12 @@ namespace ggml_cuda_mma {
     }

     static __device__ __forceinline__ void load_ldmatrix(
-            tile<8, 4, half2, DATA_SPLIT_MIRRORED, false> & t, const half2 * __restrict__ xs0, const int stride) {
+            tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
         ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
     }

     static __device__ __forceinline__ void load_ldmatrix(
-            tile<8, 4, half2, DATA_SPLIT_MIRRORED, true> & t, const half2 * __restrict__ xs0, const int stride) {
+            tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
 #pragma unroll
         for (int l0 = 0; l0 < t.ne; l0 += 2) {
             ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));
@@ -972,7 +978,7 @@ namespace ggml_cuda_mma {
     }

     static __device__ __forceinline__ void mma(
-            tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_SPLIT_MIRRORED, false> & B) {
+            tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
 #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -992,7 +998,7 @@ namespace ggml_cuda_mma {
     }

     static __device__ __forceinline__ void mma(
-            tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_SPLIT_MIRRORED, true> & B) {
+            tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
 #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;

ggml/src/ggml-cuda/mmf.cuh

Lines changed: 6 additions & 6 deletions
@@ -40,9 +40,9 @@ static __global__ void mul_mat_f(
 #else
 #ifdef VOLTA_MMA_AVAILABLE
     if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
-    typedef tile<32, 4, T, DATA_SPLIT_NONE, false> tile_A;
-    typedef tile< 8, 4, T, DATA_SPLIT_MIRRORED, false> tile_B;
-    typedef tile<32, 8, float, DATA_SPLIT_NONE, false> tile_C;
+    typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
+    typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
+    typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
 #else
     typedef tile<16, 8, T> tile_A;
     typedef tile<8, 8, T> tile_B;
@@ -280,9 +280,9 @@ static __global__ void mul_mat_f_ids(
 #else
 #ifdef VOLTA_MMA_AVAILABLE
     if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
-    typedef tile<32, 4, T, DATA_SPLIT_NONE, false> tile_A;
-    typedef tile< 8, 4, T, DATA_SPLIT_MIRRORED, false> tile_B;
-    typedef tile<32, 8, float, DATA_SPLIT_NONE, false> tile_C;
+    typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
+    typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
+    typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
 #else
     typedef tile<16, 8, T> tile_A;
     typedef tile<8, 8, T> tile_B;
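For orientation, a hedged sketch of how these Volta tile typedefs line up with the mma() overload updated in ggml/src/ggml-cuda/mma.cuh earlier in this commit. Only the tile shapes and the mma() signature come from the diff; the helper function name and its body are hypothetical, and the snippet assumes mma.cuh is included, the ggml_cuda_mma namespace is in scope, and T == half2 (which the if constexpr guard above guarantees on this path).

// tile_A: 32x4 half2, I-major          -> A operand of the Volta mma
// tile_B:  8x4 half2, I-major mirrored -> B operand, mirrored across thread subgroups
// tile_C: 32x8 float, I-major          -> accumulator
typedef tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>          tile_A;
typedef tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR>          tile_C;

// Hypothetical helper: one warp-level multiply-accumulate step with these tile types.
static __device__ __forceinline__ void mul_mat_f_tile_step_sketch(
        tile_C & C, const tile_A & A, const tile_B & B) {
    // Resolves to the overload
    //   mma(tile<32, 8, float> &, const tile<32, 4, half2> &,
    //       const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> &)
    // from the mma.cuh hunk above.
    mma(C, A, B);
}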
