Skip to content

Commit e741ec8

Browse files
authored
CUDA: Fix FA for Pascal GPU (#1036)
Co-authored-by: firecoperana <firecoperana>
1 parent f4def9b commit e741ec8

File tree

1 file changed

+26
-33
lines changed

1 file changed

+26
-33
lines changed

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
#define FATTN_KQ_STRIDE_TILE_F32 32
1313

14-
template<int Dk, int Dv, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
14+
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
1515
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
1616
__launch_bounds__(nwarps*WARP_SIZE, 1)
1717
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -52,9 +52,9 @@ static __global__ void flash_attn_tile_ext_f32(
5252
const int ne1,
5353
const int ne2,
5454
const int ne3) {
55-
static_assert(Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512));
55+
5656
// Skip unused kernel variants for faster compilation:
57-
if (use_softcap && !(Dk == 128 || Dk == 256)) {
57+
if (use_softcap && !(D == 128 || D == 256)) {
5858
NO_DEVICE_CODE;
5959
return;
6060
}
@@ -70,22 +70,15 @@ static __global__ void flash_attn_tile_ext_f32(
7070
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
7171
const half * maskh = (const half *) mask + ne11*ic0;
7272

73-
const int stride_K2 = nb11 / sizeof(half2);
74-
const int stride_V2 = nb12 / sizeof(half2);
73+
const int stride_KV2 = nb11 / sizeof(half2);
7574

7675
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
77-
78-
// TODO: is it Dk or Dv or both that need to be multiple of 2*WARP_SIZE ?
79-
// let's assume it is is both.
80-
static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64.");
81-
static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64.");
82-
83-
constexpr int Dkv = Dk < Dv ? Dv : Dk; // let's use this when we don't understand if it is Dk or Dv
76+
static_assert(D % (2 * WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
8477

8578
__shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32];
8679

87-
// This is being used to store either K or V data. Hence we need max(Dk, Dv) as the dimension
88-
__shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][Dkv + 1]; // Pad D to avoid memory bank conflicts.
80+
__shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts.
81+
8982
float2 * KV_tmp2 = (float2 *) KV_tmp;
9083

9184
float kqmax[ncols/nwarps];
@@ -95,16 +88,16 @@ static __global__ void flash_attn_tile_ext_f32(
9588
}
9689
float kqsum[ncols/nwarps] = {0.0f};
9790

98-
float2 VKQ[ncols/nwarps][(Dv/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
91+
float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
9992

10093
// Convert Q to half2 and store in registers:
101-
__shared__ float Q_f[ncols][Dk];
94+
__shared__ float Q_f[ncols][D];
10295
#pragma unroll
10396
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
10497
const int j = j0 + threadIdx.y;
10598

10699
#pragma unroll
107-
for (int i0 = 0; i0 < Dk; i0 += 2*WARP_SIZE) {
100+
for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
108101
float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
109102
Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
110103
Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
@@ -128,8 +121,8 @@ static __global__ void flash_attn_tile_ext_f32(
128121
const int i_KQ = i_KQ_0 + threadIdx.y;
129122

130123
#pragma unroll
131-
for (int k_KQ_0 = 0; k_KQ_0 < Dk; k_KQ_0 += 2*WARP_SIZE) {
132-
const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_K2 + k_KQ_0/2 + threadIdx.x];
124+
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
125+
const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
133126
KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
134127
KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
135128
}
@@ -140,7 +133,7 @@ static __global__ void flash_attn_tile_ext_f32(
140133
float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}};
141134

142135
#pragma unroll
143-
for (int k_KQ = 0; k_KQ < Dk; ++k_KQ) {
136+
for (int k_KQ = 0; k_KQ < D; ++k_KQ) {
144137
float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE];
145138
float Q_k[ncols/nwarps];
146139

@@ -209,7 +202,7 @@ static __global__ void flash_attn_tile_ext_f32(
209202
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
210203

211204
#pragma unroll
212-
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
205+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
213206
VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
214207
VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
215208
}
@@ -222,26 +215,26 @@ static __global__ void flash_attn_tile_ext_f32(
222215
const int k = k0 + threadIdx.y;
223216

224217
#pragma unroll
225-
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
218+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
226219
const int i = i0 + threadIdx.x;
227220

228-
KV_tmp2[k*(Dv/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]);
229-
KV_tmp2[k*(Dv/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]);
221+
KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)* stride_KV2 + i]);
222+
KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)* stride_KV2 + i]);
230223
}
231224
}
232225

233226
__syncthreads();
234227

235228
#pragma unroll
236229
for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) {
237-
float2 V_k[(Dv/2)/WARP_SIZE];
230+
float2 V_k[(D/2)/WARP_SIZE];
238231
float KQ_k[ncols/nwarps];
239232

240233
#pragma unroll
241-
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
234+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
242235
const int i = i0 + threadIdx.x;
243236

244-
V_k[i0/WARP_SIZE] = KV_tmp2[k*(Dv/2) + i];
237+
V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i];
245238
}
246239
#pragma unroll
247240
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
@@ -251,7 +244,7 @@ static __global__ void flash_attn_tile_ext_f32(
251244
}
252245

253246
#pragma unroll
254-
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
247+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
255248
#pragma unroll
256249
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
257250
VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps];
@@ -275,7 +268,7 @@ static __global__ void flash_attn_tile_ext_f32(
275268
kqsum_j = warp_reduce_sum(kqsum_j);
276269

277270
#pragma unroll
278-
for (int i00 = 0; i00 < Dv; i00 += 2*WARP_SIZE) {
271+
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
279272
const int i0 = i00 + 2*threadIdx.x;
280273

281274
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
@@ -284,8 +277,8 @@ static __global__ void flash_attn_tile_ext_f32(
284277
dst_val.y /= kqsum_j;
285278
}
286279
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
287-
dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 0] = dst_val.x;
288-
dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 1] = dst_val.y;
280+
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
281+
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
289282
}
290283

291284
if (parallel_blocks != 1 && threadIdx.x == 0) {
@@ -301,13 +294,13 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
301294
case 64: {
302295
constexpr int D = 64;
303296
constexpr int nwarps = 8;
304-
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
297+
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
305298
launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
306299
} break;
307300
case 128: {
308301
constexpr int D = 128;
309302
constexpr int nwarps = 8;
310-
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
303+
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
311304
launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
312305
} break;
313306
default: {

0 commit comments

Comments
 (0)