Skip to content

Commit 467743f

Browse files
committed
metal : prevent division by zero in FA kernels
1 parent 62f9209 commit 467743f

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5101,7 +5101,7 @@ void kernel_flash_attn_ext_impl(
51015101

51025102
device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
51035103

5104-
const float scale = 1.0f/S[jj];
5104+
const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj];
51055105

51065106
if (DV4 % NW == 0) {
51075107
FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
@@ -5721,7 +5721,7 @@ void kernel_flash_attn_ext_vec_impl(
57215721
device float4 * dst4 = (device float4 *) dst;
57225722
device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results
57235723

5724-
const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f;
5724+
const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f;
57255725

57265726
// interleave the workgroup data
57275727
for (short i = tiisg; i < DV4; i += NW) {
@@ -5899,7 +5899,8 @@ kernel void kernel_flash_attn_ext_vec_reduce(
58995899
const float m = simd_max(M);
59005900
const float ms = exp(M - m);
59015901

5902-
S = 1.0f/simd_sum(S*ms);
5902+
S = simd_sum(S*ms);
5903+
S = S == 0.0f ? 0.0f : 1.0f/S;
59035904

59045905
const short DV4 = DV/4;
59055906

0 commit comments

Comments
 (0)