File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed
Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -5101,7 +5101,7 @@ void kernel_flash_attn_ext_impl(
51015101
51025102 device float4 * dst4 = (device float4 *) dst + ((uint64_t )iq3*args.ne2 *args.ne1 + iq2 + (uint64_t )(iq1 + j)*args.ne1 )*DV4;
51035103
5104- const float scale = 1 .0f /S[jj];
5104+ const float scale = S[jj] == 0.0 ? 0 . 0f : 1 .0f /S[jj];
51055105
51065106 if (DV4 % NW == 0 ) {
51075107 FOR_UNROLL (short ii = 0 ; ii < DV4/NW; ++ii) {
@@ -5721,7 +5721,7 @@ void kernel_flash_attn_ext_vec_impl(
57215721 device float4 * dst4 = (device float4 *) dst;
57225722 device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results
57235723
5724- const float S = NWG == 1 ? 1 .0f /ss[0 ] : 1 .0f ;
5724+ const float S = NWG == 1 ? (ss[ 0 ] == 0 . 0f ? 0 . 0f : 1 .0f /ss[0 ]) : 1 .0f ;
57255725
57265726 // interleave the workgroup data
57275727 for (short i = tiisg; i < DV4; i += NW) {
@@ -5899,7 +5899,8 @@ kernel void kernel_flash_attn_ext_vec_reduce(
58995899 const float m = simd_max (M);
59005900 const float ms = exp (M - m);
59015901
5902- S = 1 .0f /simd_sum (S*ms);
5902+ S = simd_sum (S*ms);
5903+ S = S == 0 .0f ? 0 .0f : 1 .0f /S;
59035904
59045905 const short DV4 = DV/4 ;
59055906
You can’t perform that action at this time.
0 commit comments