
 using namespace ggml_cuda_mma;

-// Config options for specific head sizes.
+// Config options for the MMA kernel.
 // Should not affect results, only speed/register pressure/shared memory use.
-//
-// nbatch_fa:      number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
-// nwarps_max:     maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory).
-// Q_in_reg:       whether the Q values should be kept permanently in registers.
-// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading.
-// nbatch_K2:      number of K half2 values in direction of DKQ to load in parallel.
-// nbatch_V2:      number of V half2 values in direction of DV to load in parallel.
-// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel.
-
-// The ROCm compiler cannot handle templating in __launch_bounds__.
-// As a workaround, define a macro to package the kernel parameters as uint32_t:
-#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K2, nbatch_V2, nbatch_combine, nstages_target, Q_in_reg) \
-    if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \
-        static_assert((nthreads)       % 32 == 0 && (nthreads)       <= 512, "bad nthreads"); \
-        static_assert(                               (occupancy)      <=   8, "bad occupancy"); \
-        static_assert((nbatch_fa)      % 32 == 0 && (nbatch_fa)      <= 256, "bad nbatch_fa"); \
-        static_assert((nbatch_K2)      %  4 == 0 && (nbatch_K2)      <= 512, "bad nbatch_K2"); \
-        static_assert((nbatch_V2)      %  4 == 0 && (nbatch_V2)      <= 256, "bad nbatch_V2"); \
-        static_assert((nbatch_combine) %  4 == 0 && (nbatch_combine) <= 128, "bad nbatch_combine"); \
-        static_assert((nstages_target) >=  1     && (nstages_target) <=   2, "bad nstages_target"); \
-        return ((((nthreads)       / 32) - 1) <<  0) | \
-               ((((occupancy)      /  1) - 1) <<  4) | \
-               ((((nbatch_fa)      / 32) - 1) <<  7) | \
-               ((((nbatch_K2)      /  8) - 1) << 10) | \
-               ((((nbatch_V2)      /  8) - 1) << 17) | \
-               ((((nbatch_combine) /  8) - 1) << 23) | \
-               ((((nstages_target) /  1) - 1) << 28) | \
-               (((Q_in_reg) ? 1 : 0)          << 29); \
-    } \
-
-static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) {
+struct fattn_mma_config {
+    int  nthreads;       // Number of threads per CUDA block.
+    int  occupancy;      // Targeted occupancy for the MMA kernel.
+    int  nbatch_fa;      // Number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
+    int  nbatch_K2;      // Number of K half2 values in direction of DKQ to load in parallel.
+    int  nbatch_V2;      // Number of V half2 values in direction of DV to load in parallel.
+    int  nbatch_combine; // Number of VKQ half2 values in direction of DV to combine in parallel.
+    int  nstages_target; // Number of pipeline stages to use ideally, 1 == always load data synchronously, 2 == preload data if there is hardware support.
+    bool Q_in_reg;       // Whether the Q values should be kept permanently in registers.
+
+    constexpr __host__ __device__ fattn_mma_config(
+            int nthreads, int occupancy, int nbatch_fa, int nbatch_K2, int nbatch_V2, int nbatch_combine, int nstages_target, bool Q_in_reg) :
+        nthreads(nthreads), occupancy(occupancy), nbatch_fa(nbatch_fa), nbatch_K2(nbatch_K2), nbatch_V2(nbatch_V2), nbatch_combine(nbatch_combine),
+        nstages_target(nstages_target), Q_in_reg(Q_in_reg) {}
+};
+
+#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads_, occupancy_, nbatch_fa_, nbatch_K2_, nbatch_V2_, nbatch_combine_, nstages_target_, Q_in_reg_) \
+    if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \
+        static_assert((nthreads_)       % 32 == 0 && (nthreads_)       <= 512, "bad nthreads"); \
+        static_assert(                               (occupancy_)      <=   8, "bad occupancy"); \
+        static_assert((nbatch_fa_)      % 32 == 0 && (nbatch_fa_)      <= 256, "bad nbatch_fa"); \
+        static_assert((nbatch_K2_)      %  4 == 0 && (nbatch_K2_)      <= 512, "bad nbatch_K2"); \
+        static_assert((nbatch_V2_)      %  4 == 0 && (nbatch_V2_)      <= 256, "bad nbatch_V2"); \
+        static_assert((nbatch_combine_) %  4 == 0 && (nbatch_combine_) <= 128, "bad nbatch_combine"); \
+        static_assert((nstages_target_) >=  1     && (nstages_target_) <=   2, "bad nstages_target"); \
+        return fattn_mma_config{(nthreads_), (occupancy_), (nbatch_fa_), (nbatch_K2_), (nbatch_V2_), (nbatch_combine_), (nstages_target_), (Q_in_reg_)}; \
+    } \
+
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) {
     GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64,  8, 128, 2, 128,  32,  32, 32, 2, true);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 16, 128, 2,  64,  32,  32, 32, 2, true);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 32, 128, 2,  64,  32,  32, 32, 2, true);
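
Illustration (not part of the diff): with the new macro each case now returns a full config struct instead of a packed uint32_t. For example, the first case above, GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64,  8, 128, 2, 128,  32,  32, 32, 2, true), expands to roughly:

    if (DKQ == 64 && DV == 64 && ncols == 8) {
        static_assert(128 % 32 == 0 && 128 <= 512, "bad nthreads");
        // ... the remaining static_asserts on occupancy, nbatch_fa, nbatch_K2, nbatch_V2, nbatch_combine and nstages_target ...
        return fattn_mma_config{128, 2, 128, 32, 32, 32, 2, true};
    }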
@@ -73,10 +71,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) {
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);

-    return 0;
+    return fattn_mma_config(0, 0, 0, 0, 0, 0, 0, false);
 }

-static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) {
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) {
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256,  8, 128, 2, 64, 128, 128, 128, 2, true);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128,  64, 2, true);
@@ -90,7 +88,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) {
     return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
 }

-static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4, 32, 288, 256, 64, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4, 32, 288, 256, 64, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
@@ -100,7 +98,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
     return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
 }

-static __host__ uint32_t ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
+static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
     if (ampere_mma_available(cc)) {
         return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
     }
@@ -111,7 +109,7 @@ static __host__ uint32_t ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
     return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
 }

-static constexpr __device__ uint32_t ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) {
+static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) {
 #if defined(AMPERE_MMA_AVAILABLE)
     return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
 #elif defined(TURING_MMA_AVAILABLE)
@@ -120,72 +118,72 @@ static constexpr __device__ uint32_t ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) {
     return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
 #else
     GGML_UNUSED_VARS(DKQ, DV, ncols);
-    return 0;
+    return fattn_mma_config(0, 0, 0, 0, 0, 0, 0, false);
 #endif // defined(AMPERE_MMA_AVAILABLE)
 }

 static __host__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >>  0) & ((1 << 4) - 1)) + 1) * 32;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nthreads;
 }

 static constexpr __device__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >>  0) & ((1 << 4) - 1)) + 1) * 32;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nthreads;
 }

 static __host__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >>  4) & ((1 << 3) - 1)) + 1) * 1;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).occupancy;
 }

 static constexpr __device__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >>  4) & ((1 << 3) - 1)) + 1) * 1;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).occupancy;
 }

 static __host__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >>  7) & ((1 << 3) - 1)) + 1) * 32;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_fa;
 }

 static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >>  7) & ((1 << 3) - 1)) + 1) * 32;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_fa;
 }

 static __host__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols, const int cc) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 7) - 1)) + 1) * 8;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_K2;
 }

 static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 10) & ((1 << 7) - 1)) + 1) * 8;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_K2;
 }

 static __host__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols, const int cc) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 17) & ((1 << 6) - 1)) + 1) * 8;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_V2;
 }

 static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 17) & ((1 << 6) - 1)) + 1) * 8;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_V2;
 }

 static __host__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols, const int cc) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 5) - 1)) + 1) * 8;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_combine;
 }

 static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 23) & ((1 << 5) - 1)) + 1) * 8;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_combine;
 }

 static __host__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols, const int cc) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 28) & ((1 << 2) - 1)) + 1) * 1;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nstages_target;
 }

 static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols) {
-    return (((ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 28) & ((1 << 2) - 1)) + 1) * 1;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nstages_target;
 }

 static __host__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols, const int cc) {
-    return (ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc) >> 29) & 1;
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).Q_in_reg;
 }

-static constexpr __device__ int ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) {
-    return (ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols) >> 29) & 1;
+static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
 }

 // ------------------------------------------------------------------------------------------------------------------
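
For context, a minimal sketch of how the accessors above might be consumed on the host side; the helper below is hypothetical and assumes only the getters defined in this diff:

    // Hypothetical host-side helper (illustrative only): derive launch parameters
    // from the config for a given head size / column count / compute capability.
    static void example_fattn_mma_launch_params(const int DKQ, const int DV, const int ncols, const int cc) {
        const int nthreads  = ggml_cuda_fattn_mma_get_nthreads (DKQ, DV, ncols, cc);
        const int occupancy = ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols, cc);
        const int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols, cc);
        // nthreads threads per CUDA block, targeting `occupancy` blocks per SM,
        // with nbatch_fa KV rows processed per softmax rescaling.
        GGML_UNUSED_VARS(nthreads, occupancy, nbatch_fa);
    }

On the device side the constexpr overloads can be evaluated at compile time, e.g. constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); inside a kernel templated on DKQ/DV/ncols (again a hypothetical call site, not shown in this diff).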