Commit 7d14f8e

ikawrakow (Iwan Kawrakow) authored
Fix GLM-4.5 attention (#700)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent 93a4f60 commit 7d14f8e

File tree

2 files changed: +5 -5 lines changed


ggml/src/ggml-cuda/fattn-wmma-f16.cuh

Lines changed: 1 addition & 1 deletion

@@ -96,7 +96,7 @@ static __global__ void flash_attn_ext_f16(
     const half * V_h = (const half *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape
     const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
     const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
-    const float * sinks_f = sinks ? (const float *)sinks + blockIdx.y : nullptr;
+    [[maybe_unused]] const float * sinks_f = sinks ? (const float *)sinks + blockIdx.y : nullptr;
 
     const int stride_Q = nb01 / sizeof(float);
     const int stride_K = nb11 / sizeof(half);
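
Context note (not part of the commit): [[maybe_unused]] is the C++17 attribute that silences unused-variable warnings for declarations that are only consumed under some compile-time configurations of a kernel. A minimal standalone sketch, using a hypothetical kernel_stub() purely for illustration:

    // Hypothetical illustration of [[maybe_unused]] on a conditionally used pointer.
    template <bool use_sinks>
    static float kernel_stub(const float * sinks) {
        [[maybe_unused]] const float * sinks_f = sinks;  // never read when use_sinks == false
        if constexpr (use_sinks) {
            return *sinks_f;
        } else {
            return 0.0f;
        }
    }

    // kernel_stub<false>(nullptr) compiles warning-free even though sinks_f is unused there.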

ggml/src/ggml-cuda/fattn.cu

Lines changed: 4 additions & 4 deletions

@@ -78,7 +78,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
     float max_bias = 0.0f;
     memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
 
-    const float use_gqa_opt = mask && max_bias == 0.0f;
+    const bool use_gqa_opt = mask && max_bias == 0.0f;
 
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
@@ -88,12 +88,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
         return;
     }
 
-    if (use_gqa_opt && gqa_ratio == 4) {
+    if (use_gqa_opt && gqa_ratio % 4 == 0) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
         return;
     }
 
-    if (use_gqa_opt && gqa_ratio == 2) {
+    if (use_gqa_opt && gqa_ratio % 2 == 0) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
         return;
     }
@@ -508,7 +508,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     //const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
     //const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion;
     const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies;
-    const bool can_use_vector_kernel = Q->ne[0] % (2*WARP_SIZE) == 0;
+    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (precision == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
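
Context note on the dispatch change above: gqa_ratio = Q->ne[2] / K->ne[2] is the number of query heads per KV head. GLM-4.5 reportedly groups 96 query heads over 8 KV heads (an assumption here, not stated in the diff), giving gqa_ratio = 12; the old exact comparisons (== 4, == 2) rejected it, while the new divisibility checks (% 4 == 0, % 2 == 0) accept it. A minimal standalone sketch of that dispatch order, with a hypothetical pick_gqa_tile() helper:

    #include <cstdio>

    // Hypothetical helper mirroring the dispatch order in
    // ggml_cuda_flash_attn_ext_mma_f16: prefer the wider GQA tile,
    // falling back based on divisibility of gqa_ratio.
    static int pick_gqa_tile(int n_q_heads, int n_kv_heads, bool use_gqa_opt) {
        const int gqa_ratio = n_q_heads / n_kv_heads;
        if (use_gqa_opt && gqa_ratio % 4 == 0) return 4;
        if (use_gqa_opt && gqa_ratio % 2 == 0) return 2;
        return 1;
    }

    int main() {
        // Assumed GLM-4.5 head counts: 96 query heads, 8 KV heads -> gqa_ratio = 12.
        // 12 equals neither 4 nor 2, so the old checks fell through to the generic
        // path; 12 % 4 == 0, so the GQA-optimized tile of width 4 is now selected.
        printf("GQA tile width: %d\n", pick_gqa_tile(96, 8, /*use_gqa_opt=*/true));
        return 0;
    }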
