Commit 5d3bd41

Returns bias gradients in backward pass
Updates all FlashDMAttn autograd function classes to return the computed bias gradient (dbias) from their backward methods instead of None, so that gradients actually propagate to the bias parameter during backpropagation. A minimal sketch of the autograd return convention behind the fix follows below.
1 parent a56ed00 commit 5d3bd41
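
For context: torch.autograd.Function requires backward to return exactly one value per argument that forward received, in the same positional order, with None for inputs that need no gradient. Returning None in the bias slot therefore silently discards the bias gradient even when it has been computed. Below is a minimal sketch of that convention using a hypothetical BiasedMatmul toy function (not part of this repo):

    import torch

    class BiasedMatmul(torch.autograd.Function):
        # Toy example, not the FlashDMAttn kernel: out = (x @ w + bias) * scale

        @staticmethod
        def forward(ctx, x, w, bias, scale):
            ctx.save_for_backward(x, w)
            ctx.scale = scale
            return (x @ w + bias) * scale

        @staticmethod
        def backward(ctx, dout):
            x, w = ctx.saved_tensors
            dout = dout * ctx.scale      # chain rule through the scaling
            dx = dout @ w.t()
            dw = x.t() @ dout
            dbias = dout.sum(dim=0)      # bias broadcasts over the batch dim
            # One return slot per forward input, in positional order.
            # Writing None in the bias slot here would silently drop the
            # gradient, which is the exact bug this commit fixes.
            return dx, dw, dbias, None   # None: the float `scale` needs no grad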

File tree: 1 file changed, +6 −6 lines

flash_dmattn/flash_dmattn_interface.py (6 additions, 6 deletions)
@@ -490,7 +490,7 @@ def backward(
         )

         dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None, None, None, None
+        return dqkv, None, dbias, None, None, None, None, None, None, None


 class FlashDMAttnVarlenQKVPackedFunc(torch.autograd.Function):

@@ -604,7 +604,7 @@ def backward(
         )

         dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None, None, None, None, None, None
+        return dqkv, None, dbias, None, None, None, None, None, None, None, None, None


 class FlashDMAttnKVPackedFunc(torch.autograd.Function):

@@ -712,7 +712,7 @@ def backward(

         dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None, None, None
+        return dq, dkv, None, dbias, None, None, None, None, None, None, None


 class FlashDMAttnVarlenKVPackedFunc(torch.autograd.Function):

@@ -837,7 +837,7 @@ def backward(

         dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dkv, None, dbias, None, None, None, None, None, None, None, None, None, None, None


 class FlashDMAttnFunc(torch.autograd.Function):

@@ -941,7 +941,7 @@ def backward(
         dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, dbias, None, None, None, None, None, None, None


 class FlashDMAttnVarlenFunc(torch.autograd.Function):

@@ -1063,7 +1063,7 @@ def backward(
         dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, dbias, None, None, None, None, None, None, None, None, None, None, None, None


 def flash_dmattn_qkvpacked_func(
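
With dbias returned, the gradient actually reaches the bias tensor; with None in that slot it was silently dropped. A quick way to catch this class of bug on the toy function sketched above (a hypothetical check, not from this commit) is torch.autograd.gradcheck, which compares analytic gradients against numerical ones for every differentiable input:

    import torch

    x = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
    w = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
    b = torch.randn(4, dtype=torch.double, requires_grad=True)

    # Raises if backward returns None in the bias slot (analytic grad treated
    # as zero); passes once dbias is returned.
    torch.autograd.gradcheck(lambda x, w, b: BiasedMatmul.apply(x, w, b, 0.5), (x, w, b))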
