Commit 2a9b919

Fixes mask application in backward kernel computation
Adds the missing mask and bias arguments to the apply_mask function call so that masking is applied correctly during the backward pass. This prevents potential infinite values in the gradient computation for elements beyond the actual sequence length.
1 parent d0169a5 commit 2a9b919

1 file changed: +1 -1 lines changed

csrc/src/flash_bwd_kernel.h

Lines changed: 1 addition & 1 deletion
@@ -671,7 +671,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
             // So we need to mask out the elements beyond actual_seqlen_k.
             FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal>(
-                scores, params.scale_softmax,
+                scores, mask, bias, params.scale_softmax,
                 n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + get<0>(taccScS_row(0)),
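
The code comment in the hunk above describes why columns beyond `actual_seqlen_k` must be masked before the probabilities are recomputed. The following host-side C++ sketch illustrates that idea only. It is not the kernel's `apply_mask`: the helper names `apply_mask_row` and `recompute_probs`, the 0/1 mask convention, the additive bias, and the order "scale, then add bias" are all assumptions made for illustration, and causal masking is left out.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical stand-in for the device-side apply_mask call (assumed semantics):
// a column is kept only if it lies inside the real key length and its mask entry is nonzero.
inline void apply_mask_row(std::vector<float> &scores,
                           const std::vector<float> &mask,
                           const std::vector<float> &bias,
                           float scale_softmax,
                           std::size_t col_idx_offset,
                           std::size_t actual_seqlen_k) {
    assert(scores.size() == mask.size() && scores.size() == bias.size());
    const float neg_inf = -std::numeric_limits<float>::infinity();
    for (std::size_t j = 0; j < scores.size(); ++j) {
        const bool out_of_range = (col_idx_offset + j) >= actual_seqlen_k;
        const bool masked_out = (mask[j] == 0.0f);
        // Masked entries become -inf so that exp(score - lse) is exactly 0.
        scores[j] = (out_of_range || masked_out)
                        ? neg_inf
                        : scores[j] * scale_softmax + bias[j];
    }
}

// Recompute probabilities from a stored log-sum-exp, as a backward pass would.
inline std::vector<float> recompute_probs(const std::vector<float> &scores, float lse) {
    std::vector<float> probs(scores.size());
    for (std::size_t j = 0; j < scores.size(); ++j) {
        probs[j] = std::exp(scores[j] - lse);
    }
    return probs;
}
```

In this sketch, a padded column holding an arbitrary large value would overflow to Inf in `recompute_probs` if left unmasked. After `apply_mask_row` it contributes a probability of exactly zero, so nothing non-finite reaches the later multiplication with dP.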
