@@ -78,7 +78,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
     float max_bias = 0.0f;
     memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
 
-    const float use_gqa_opt = mask && max_bias == 0.0f;
+    const bool use_gqa_opt = mask && max_bias == 0.0f;
 
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
@@ -88,12 +88,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
         return;
     }
 
-    if (use_gqa_opt && gqa_ratio == 4) {
+    if (use_gqa_opt && gqa_ratio % 4 == 0) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
         return;
     }
 
-    if (use_gqa_opt && gqa_ratio == 2) {
+    if (use_gqa_opt && gqa_ratio % 2 == 0) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
         return;
     }
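A minimal, self-contained sketch (not the ggml code itself) of why relaxing the equality checks to `% 4` and `% 2` matters: with strict equality, a GQA ratio that is a larger multiple of 4 or 2 (e.g. 6 or 12) fell through to the generic path, while the modulo form still reaches a specialized kernel. The helper name and the `% 8` branch below are assumptions that mirror the pattern of the surrounding dispatch code.

```cpp
// Illustrative only; pick_gqa_specialization and the %8 branch are assumed.
#include <cstdio>

static int pick_gqa_specialization(bool use_gqa_opt, int gqa_ratio) {
    if (use_gqa_opt && gqa_ratio % 8 == 0) return 8; // assumed preceding branch
    if (use_gqa_opt && gqa_ratio % 4 == 0) return 4; // was: gqa_ratio == 4
    if (use_gqa_opt && gqa_ratio % 2 == 0) return 2; // was: gqa_ratio == 2
    return 1;                                        // generic fallback
}

int main() {
    // gqa_ratio = Q->ne[2] / K->ne[2], i.e. query heads per KV head.
    const int ratios[] = {2, 4, 6, 8, 12};
    for (int gqa_ratio : ratios) {
        printf("gqa_ratio %2d -> switch_hs<%d>\n", gqa_ratio,
               pick_gqa_specialization(/*use_gqa_opt=*/true, gqa_ratio));
    }
    return 0;
}
```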
@@ -508,7 +508,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     // const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
     // const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion;
     const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies;
-    const bool can_use_vector_kernel = Q->ne[0] % (2*WARP_SIZE) == 0;
+    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (precision == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
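For the tightened vector-kernel guard, a small sketch of which head sizes (Q->ne[0]) now qualify, assuming WARP_SIZE == 32: head sizes above 256 are excluded even when they are multiples of 2*WARP_SIZE, so they fall through to the other kernel paths. The head-size values below are illustrative, not taken from the patch.

```cpp
// Illustrative only: WARP_SIZE == 32 is assumed (NVIDIA warp width).
#include <cstdio>

constexpr int WARP_SIZE = 32;

static bool can_use_vector_kernel(int head_size /* Q->ne[0] */) {
    return head_size <= 256 && head_size % (2*WARP_SIZE) == 0; // new guard
}

int main() {
    const int head_sizes[] = {64, 80, 128, 256, 576};
    for (int hs : head_sizes) {
        printf("head size %3d -> %s\n", hs,
               can_use_vector_kernel(hs) ? "vector kernel candidate" : "other path");
    }
    return 0;
}
```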