@@ -490,7 +490,7 @@ def backward(
490490 )
491491
492492 dqkv = dqkv [..., : dout .shape [- 1 ]] # We could have padded the head dimension
493- return dqkv , None , None , None , None , None , None , None , None , None
493+ return dqkv , None , dbias , None , None , None , None , None , None , None
494494
495495
496496class FlashDMAttnVarlenQKVPackedFunc (torch .autograd .Function ):
@@ -604,7 +604,7 @@ def backward(
604604 )
605605
606606 dqkv = dqkv [..., : dout .shape [- 1 ]] # We could have padded the head dimension
607- return dqkv , None , None , None , None , None , None , None , None , None , None , None
607+ return dqkv , None , dbias , None , None , None , None , None , None , None , None , None
608608
609609
610610class FlashDMAttnKVPackedFunc (torch .autograd .Function ):
@@ -712,7 +712,7 @@ def backward(
712712
713713 dq = dq [..., : dout .shape [- 1 ]] # We could have padded the head dimension
714714 dkv = dkv [..., : dout .shape [- 1 ]]
715- return dq , dkv , None , None , None , None , None , None , None , None , None
715+ return dq , dkv , None , dbias , None , None , None , None , None , None , None
716716
717717
718718class FlashDMAttnVarlenKVPackedFunc (torch .autograd .Function ):
@@ -837,7 +837,7 @@ def backward(
837837
838838 dq = dq [..., : dout .shape [- 1 ]] # We could have padded the head dimension
839839 dkv = dkv [..., : dout .shape [- 1 ]]
840- return dq , dkv , None , None , None , None , None , None , None , None , None , None , None , None , None
840+ return dq , dkv , None , dbias , None , None , None , None , None , None , None , None , None , None , None
841841
842842
843843class FlashDMAttnFunc (torch .autograd .Function ):
@@ -941,7 +941,7 @@ def backward(
941941 dq = dq [..., : dout .shape [- 1 ]] # We could have padded the head dimension
942942 dk = dk [..., : dout .shape [- 1 ]]
943943 dv = dv [..., : dout .shape [- 1 ]]
944- return dq , dk , dv , None , None , None , None , None , None , None , None , None
944+ return dq , dk , dv , None , dbias , None , None , None , None , None , None , None
945945
946946
947947class FlashDMAttnVarlenFunc (torch .autograd .Function ):
@@ -1063,7 +1063,7 @@ def backward(
10631063 dq = dq [..., : dout .shape [- 1 ]] # We could have padded the head dimension
10641064 dk = dk [..., : dout .shape [- 1 ]]
10651065 dv = dv [..., : dout .shape [- 1 ]]
1066- return dq , dk , dv , None , None , None , None , None , None , None , None , None , None , None , None , None , None
1066+ return dq , dk , dv , None , dbias , None , None , None , None , None , None , None , None , None , None , None , None
10671067
10681068
10691069def flash_dmattn_qkvpacked_func (
0 commit comments