 
 #define FATTN_KQ_STRIDE_TILE_F32 32
 
-template <int Dk, int Dv, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
+template <int D, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -52,9 +52,9 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
-    static_assert(Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512));
+
     // Skip unused kernel variants for faster compilation:
-    if (use_softcap && !(Dk == 128 || Dk == 256)) {
+    if (use_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
         return;
     }
@@ -70,22 +70,15 @@ static __global__ void flash_attn_tile_ext_f32(
     const half2 * V_h2  = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
     const half  * maskh = (const half  *) mask + ne11*ic0;
 
-    const int stride_K2 = nb11 / sizeof(half2);
-    const int stride_V2 = nb12 / sizeof(half2);
+    const int stride_KV2 = nb11 / sizeof(half2);
 
     const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
-
-    // TODO: is it Dk or Dv or both that need to be a multiple of 2*WARP_SIZE?
-    // let's assume it is both.
-    static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64.");
-    static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64.");
-
-    constexpr int Dkv = Dk < Dv ? Dv : Dk; // use this when it is unclear whether Dk or Dv applies
+    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 
     __shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32];
 
-    // This is used to store either K or V data, hence it needs max(Dk, Dv) as its dimension.
-    __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][Dkv + 1]; // Pad D to avoid memory bank conflicts.
+    __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts.
+
     float2 * KV_tmp2 = (float2 *) KV_tmp;
 
     float kqmax[ncols/nwarps];
@@ -95,16 +88,16 @@ static __global__ void flash_attn_tile_ext_f32(
     }
     float kqsum[ncols/nwarps] = {0.0f};
 
-    float2 VKQ[ncols/nwarps][(Dv/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
+    float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
 
     // Convert Q to half2 and store in registers:
-    __shared__ float Q_f[ncols][Dk];
+    __shared__ float Q_f[ncols][D];
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 
 #pragma unroll
-        for (int i0 = 0; i0 < Dk; i0 += 2*WARP_SIZE) {
+        for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
             float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
             Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
             Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
@@ -128,8 +121,8 @@ static __global__ void flash_attn_tile_ext_f32(
             const int i_KQ = i_KQ_0 + threadIdx.y;
 
 #pragma unroll
-            for (int k_KQ_0 = 0; k_KQ_0 < Dk; k_KQ_0 += 2*WARP_SIZE) {
-                const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_K2 + k_KQ_0/2 + threadIdx.x];
+            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
+                const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
                 KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
                 KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
             }
@@ -140,7 +133,7 @@ static __global__ void flash_attn_tile_ext_f32(
         float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}};
 
 #pragma unroll
-        for (int k_KQ = 0; k_KQ < Dk; ++k_KQ) {
+        for (int k_KQ = 0; k_KQ < D; ++k_KQ) {
             float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE];
             float Q_k[ncols/nwarps];
 
@@ -209,7 +202,7 @@ static __global__ void flash_attn_tile_ext_f32(
             kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
 
 #pragma unroll
-            for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
                 VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
             }
@@ -222,26 +215,26 @@ static __global__ void flash_attn_tile_ext_f32(
             const int k = k0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
 
-                KV_tmp2[k*(Dv/2) + i].x =  __low2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]);
-                KV_tmp2[k*(Dv/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]);
+                KV_tmp2[k*(D/2) + i].x =  __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+                KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
             }
         }
 
         __syncthreads();
 
 #pragma unroll
         for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) {
-            float2 V_k[(Dv/2)/WARP_SIZE];
+            float2 V_k[(D/2)/WARP_SIZE];
             float  KQ_k[ncols/nwarps];
 
 #pragma unroll
-            for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
 
-                V_k[i0/WARP_SIZE] = KV_tmp2[k*(Dv/2) + i];
+                V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i];
             }
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += nwarps) {
@@ -251,7 +244,7 @@ static __global__ void flash_attn_tile_ext_f32(
             }
 
 #pragma unroll
-            for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
 #pragma unroll
                 for (int j0 = 0; j0 < ncols; j0 += nwarps) {
                     VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps];
@@ -275,7 +268,7 @@ static __global__ void flash_attn_tile_ext_f32(
         kqsum_j = warp_reduce_sum(kqsum_j);
 
 #pragma unroll
-        for (int i00 = 0; i00 < Dv; i00 += 2*WARP_SIZE) {
+        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
             const int i0 = i00 + 2*threadIdx.x;
 
             float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
@@ -284,8 +277,8 @@ static __global__ void flash_attn_tile_ext_f32(
                 dst_val.y /= kqsum_j;
             }
             const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-            dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 0] = dst_val.x;
-            dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 1] = dst_val.y;
+            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
+            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
         }
 
         if (parallel_blocks != 1 && threadIdx.x == 0) {
@@ -301,13 +294,13 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
         case 64: {
             constexpr int D      = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
             launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D      = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
             launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {