
Commit d861a34

Aman Gupta authored and JohannesGaessler committed
use struct for MMA FA kernel config
1 parent ec176ee commit d861a34

File tree

1 file changed: +53 / -55 lines changed

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 53 additions & 55 deletions
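
In short, the commit replaces the bit-packed uint32_t kernel config with a named struct. As a minimal standalone sketch of the two representations (the pack/unpack arithmetic mirrors the shift/mask code removed in the diff below; the *_sketch names are illustrative and not part of the commit):

    #include <cstdint>

    // Old representation: nthreads was stored as (nthreads/32 - 1) in bits 0..3 of a
    // uint32_t and recovered with a shift and a 4-bit mask (see the removed getters).
    constexpr uint32_t pack_nthreads_sketch(int nthreads) {
        return uint32_t((nthreads / 32) - 1) << 0;
    }
    constexpr int unpack_nthreads_sketch(uint32_t cfg) {
        return (int((cfg >> 0) & ((1u << 4) - 1)) + 1) * 32;
    }

    // New representation: the same value is simply a named field of the config struct.
    struct fattn_mma_config_sketch {
        int nthreads;
    };

    static_assert(unpack_nthreads_sketch(pack_nthreads_sketch(128)) == 128, "round trip through the packed word");
    static_assert(fattn_mma_config_sketch{128}.nthreads == 128, "direct field access, no decoding");

The struct keeps the constexpr __host__ __device__ usability of the packed word while dropping the manual shift/mask bookkeeping in every getter.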
@@ -5,39 +5,37 @@
 
 using namespace ggml_cuda_mma;
 
-// Config options for specific head sizes.
+// Config options for the MMA kernel.
 // Should not affect results, only speed/register pressure/shared memory use.
-//
-// nbatch_fa:      number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
-// nwarps_max:     maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory).
-// Q_in_reg:       whether the Q values should be kept permanently in registers.
-// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading.
-// nbatch_K2:      number of K half2 values in direction of DKQ to load in parallel.
-// nbatch_V2:      number of V half2 values in direction of DV to load in parallel.
-// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel.
-
-// The ROCm compiler cannot handle templating in __launch_bounds__.
-// As a workaround, define a macro to package the kernel parameters as uint32_t:
-#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K2, nbatch_V2, nbatch_combine, nstages_target, Q_in_reg) \
-    if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \
-        static_assert((nthreads) % 32 == 0 && (nthreads) <= 512, "bad nthreads"); \
-        static_assert( (occupancy) <= 8, "bad occupancy"); \
-        static_assert((nbatch_fa) % 32 == 0 && (nbatch_fa) <= 256, "bad nbatch_fa"); \
-        static_assert((nbatch_K2) % 4 == 0 && (nbatch_K2) <= 512, "bad nbatch_K2"); \
-        static_assert((nbatch_V2) % 4 == 0 && (nbatch_V2) <= 256, "bad nbatch_V2"); \
-        static_assert((nbatch_combine) % 4 == 0 && (nbatch_combine) <= 128, "bad nbatch_combine"); \
-        static_assert((nstages_target) >= 1 && (nstages_target) <= 2, "bad nstages_target"); \
-        return ((((nthreads) / 32) - 1) << 0) | \
-               ((((occupancy) / 1) - 1) << 4) | \
-               ((((nbatch_fa) / 32) - 1) << 7) | \
-               ((((nbatch_K2) / 8) - 1) << 10) | \
-               ((((nbatch_V2) / 8) - 1) << 17) | \
-               ((((nbatch_combine) / 8) - 1) << 23) | \
-               ((((nstages_target) / 1) - 1) << 28) | \
-               (((Q_in_reg) ? 1 : 0) << 29); \
-    } \
-
-static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) {
+struct fattn_mma_config {
+    int  nthreads;       // Number of threads per CUDA block.
+    int  occupancy;      // Targeted occupancy for the MMA kernel.
+    int  nbatch_fa;      // Number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
+    int  nbatch_K2;      // Number of K half2 values in direction of DKQ to load in parallel.
+    int  nbatch_V2;      // Number of V half2 values in direction of DV to load in parallel.
+    int  nbatch_combine; // Number of VKQ half2 values in direction of DV to combine in parallel.
+    int  nstages_target; // Number of pipeline stages to use ideally, 1 == always load data synchronously, 2 == preload data if there is hardware support.
+    bool Q_in_reg;       // Whether the Q values should be kept permanently in registers.
+
+    constexpr __host__ __device__ fattn_mma_config(
+            int nthreads, int occupancy, int nbatch_fa, int nbatch_K2, int nbatch_V2, int nbatch_combine, int nstages_target, bool Q_in_reg) :
+        nthreads(nthreads), occupancy(occupancy), nbatch_fa(nbatch_fa), nbatch_K2(nbatch_K2), nbatch_V2(nbatch_V2), nbatch_combine(nbatch_combine),
+        nstages_target(nstages_target), Q_in_reg(Q_in_reg) {}
+};
+
+#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads_, occupancy_, nbatch_fa_, nbatch_K2_, nbatch_V2_, nbatch_combine_, nstages_target_, Q_in_reg_) \
+    if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \
+        static_assert((nthreads_) % 32 == 0 && (nthreads_) <= 512, "bad nthreads"); \
+        static_assert( (occupancy_) <= 8, "bad occupancy"); \
+        static_assert((nbatch_fa_) % 32 == 0 && (nbatch_fa_) <= 256, "bad nbatch_fa"); \
+        static_assert((nbatch_K2_) % 4 == 0 && (nbatch_K2_) <= 512, "bad nbatch_K2"); \
+        static_assert((nbatch_V2_) % 4 == 0 && (nbatch_V2_) <= 256, "bad nbatch_V2"); \
+        static_assert((nbatch_combine_) % 4 == 0 && (nbatch_combine_) <= 128, "bad nbatch_combine"); \
+        static_assert((nstages_target_) >= 1 && (nstages_target_) <= 2, "bad nstages_target"); \
+        return fattn_mma_config{(nthreads_), (occupancy_), (nbatch_fa_), (nbatch_K2_), (nbatch_V2_), (nbatch_combine_), (nstages_target_), (Q_in_reg_)}; \
+    } \
+
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) {
     GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 2, true);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 2, true);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 2, true);
@@ -73,10 +71,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_mma_get_config_amp
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
 
-    return 0;
+    return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
 }
 
-static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) {
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) {
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 128, 2, 64, 128, 128, 128, 2, true);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
@@ -90,7 +88,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_mma_get_config_tur
     return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
 }
 
-static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
@@ -100,7 +98,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_mma_get_config_vol
     return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
 }
 
-static __host__ uint32_t ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
+static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
     if (ampere_mma_available(cc)) {
         return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
     }
@@ -111,7 +109,7 @@ static __host__ uint32_t ggml_cuda_fattn_mma_get_config(const int DKQ, const int
     return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
 }
 
-static constexpr __device__ uint32_t ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) {
+static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) {
 #if defined(AMPERE_MMA_AVAILABLE)
     return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
 #elif defined(TURING_MMA_AVAILABLE)
@@ -120,72 +118,72 @@ static constexpr __device__ uint32_t ggml_cuda_fattn_mma_get_config(const int DK
     return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
 #else
     GGML_UNUSED_VARS(DKQ, DV, ncols);
-    return 0;
+    return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
 #endif // defined(AMPERE_MMA_AVAILABLE)
 }
 
 static __host__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 4) - 1)) + 1) * 32;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nthreads;
 }
 
 static constexpr __device__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 0) & ((1 << 4) - 1)) + 1) * 32;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nthreads;
 }
 
 static __host__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 4) & ((1 << 3) - 1)) + 1) * 1;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).occupancy;
 }
 
 static constexpr __device__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 4) & ((1 << 3) - 1)) + 1) * 1;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).occupancy;
 }
 
 static __host__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 7) & ((1 << 3) - 1)) + 1) * 32;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_fa;
 }
 
 static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 7) & ((1 << 3) - 1)) + 1) * 32;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_fa;
 }
 
 static __host__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols, const int cc) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 7) - 1)) + 1) * 8;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_K2;
 }
 
 static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 10) & ((1 << 7) - 1)) + 1) * 8;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_K2;
 }
 
 static __host__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols, const int cc) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 17) & ((1 << 6) - 1)) + 1) * 8;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_V2;
 }
 
 static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 17) & ((1 << 6) - 1)) + 1) * 8;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_V2;
 }
 
 static __host__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols, const int cc) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 5) - 1)) + 1) * 8;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_combine;
 }
 
 static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 23) & ((1 << 5) - 1)) + 1) * 8;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_combine;
 }
 
 static __host__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols, const int cc) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 28) & ((1 << 2) - 1)) + 1) * 1;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nstages_target;
 }
 
 static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 28) & ((1 << 2) - 1)) + 1) * 1;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nstages_target;
 }
 
 static __host__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols, const int cc) {
-    return (ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 29) & 1;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).Q_in_reg;
 }
 
-static constexpr __device__ int ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) {
-    return (ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 29) & 1;
+static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
 }
 
 // ------------------------------------------------------------------------------------------------------------------
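
For readers skimming the diff, a hedged usage sketch of what a call site looks like after this change: the getters above return the struct (the host overloads additionally take the compute capability cc), so launch parameters are read as named fields rather than decoded from a packed word. Only ggml_cuda_fattn_mma_get_config and the field names come from the diff; the variable names, example arguments, and the dim3 line are illustrative assumptions, not code from this commit.

    // Hypothetical host-side sketch, assuming a compute capability value `cc` is already known.
    const fattn_mma_config cfg = ggml_cuda_fattn_mma_get_config(/*DKQ=*/64, /*DV=*/64, /*ncols=*/8, cc);

    const dim3 block_dims(cfg.nthreads, 1, 1);      // threads per CUDA block
    const int  kv_step  = cfg.nbatch_fa;            // KV rows per softmax rescaling step
    const bool preload  = cfg.nstages_target == 2;  // pipeline K/V loads with cp_async if supported

The struct also gives the device-side constexpr getters trivial definitions (plain member reads), which is what most of the changed lines in the last hunk show.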
