Commit 43d73e8

Refactors backward pass to return gradients directly
Simplifies the backward function interface by removing the output tensor parameters and returning gradients directly instead of writing into pre-allocated tensors in place. Removes the inference-mode wrapper and the associated tensor pre-allocation from the autograd function, streamlining the gradient computation flow. Comments out a problematic Triton configuration that may cause issues with certain sequence lengths.
1 parent 974451e commit 43d73e8
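
The net effect at the call site of _flash_attn_backward, distilled from the diff below (a sketch only; "scale" stands in for whatever softmax scale the caller uses):

# Before this commit: the caller pre-allocated gradient buffers and the
# function wrote into them in-place.
dq = torch.empty_like(query)
dk = torch.empty_like(key)
dv = torch.empty_like(value)
dbias = torch.empty_like(attn_bias)
_flash_attn_backward(do, query, key, value, attn_mask, attn_bias, o, lse,
                     dq, dk, dv, dbias, softmax_scale=scale, is_causal=is_causal)

# After this commit: the function allocates its outputs and returns them directly.
dq, dk, dv, dbias = _flash_attn_backward(do, query, key, value, attn_mask, attn_bias,
                                         o, lse, softmax_scale=scale, is_causal=is_causal)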

File tree

1 file changed (+31, -34 lines)

flash_dmattn/flash_dmattn_triton.py

Lines changed: 31 additions & 34 deletions
@@ -665,12 +665,12 @@ def init_func(nargs):
             num_stages=1,
             pre_hook=init_to_zero(["DQ", "DBias"]),
         ),
-        triton.Config(
-            {"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
-            num_warps=8,
-            num_stages=1,
-            pre_hook=init_to_zero(["DQ", "DBias"]),
-        ),
+        # triton.Config(
+        #     {"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
+        #     num_warps=8,
+        #     num_stages=1,
+        #     pre_hook=init_to_zero(["DQ", "DBias"]),
+        # ),
         # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
         # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
         # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])),
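
For reference, the configs kept above follow the usual triton.autotune shape. The sketch below shows such an entry with a plausible (assumed, not taken from this file) implementation of the init_to_zero pre-hook, which zeroes the named kernel buffers before each tuning run, presumably because the backward kernel accumulates into DQ and DBias rather than overwriting them:

import triton

def init_to_zero(names):
    # Assumed implementation: zero the named kernel arguments before launch.
    return lambda nargs: [nargs[name].zero_() for name in names]

configs = [
    triton.Config(
        # Block sizes here are illustrative; the actual surviving configs are
        # defined earlier in flash_dmattn_triton.py.
        {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
        num_warps=8,
        num_stages=1,
        pre_hook=init_to_zero(["DQ", "DBias"]),
    ),
]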
@@ -930,7 +930,7 @@ def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False
 
 
 def _flash_attn_backward(
-    do, q, k, v, mask, bias, o, lse, dq, dk, dv, dbias, softmax_scale=None, is_causal=False
+    do, q, k, v, mask, bias, o, lse, softmax_scale=None, is_causal=False
 ):
     # Make sure that the last dimension is contiguous
     if do.stride(-1) != 1:
@@ -941,8 +941,6 @@ def _flash_attn_backward(
     assert d <= 128
     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
     assert lse.shape == (batch, nheads, seqlen_q_rounded)
-    assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
-    assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
 
     assert mask.dtype in [q.dtype, torch.float]
     assert mask.is_cuda
@@ -959,6 +957,9 @@ def _flash_attn_backward(
     dq_accum = torch.empty_like(q, dtype=torch.float32)
     delta = torch.empty_like(lse)
     # delta = torch.zeros_like(lse)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v)
+    dbias = torch.empty_like(bias)
 
     BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
     grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
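
Because dk, dv, and dbias are now created inside the function, the stride assertions on caller-provided gradient buffers (removed earlier in this diff) are no longer needed. torch.empty_like returns a fresh uninitialized tensor matching the reference tensor's shape, dtype, and device; a small illustrative check with an arbitrary shape:

import torch

k = torch.randn(2, 128, 8, 64)   # arbitrary illustrative shape
dk = torch.empty_like(k)
assert dk.shape == k.shape and dk.dtype == k.dtype and dk.device == k.device
assert dk.stride(-1) == 1        # contiguous here because k itself is contiguous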
@@ -1047,7 +1048,8 @@ def _flash_attn_backward(
         # num_warps=num_warps,
         # num_stages=1,
     )
-    dq.copy_(dq_accum)
+    dq = dq_accum.to(q.dtype)
+    return dq, dk, dv, dbias
 
 
 class FlashDMAttnFunc(torch.autograd.Function):
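
dq keeps its float32 accumulator but is now cast back to the input dtype and returned, rather than copied into a caller-provided tensor. A self-contained sketch of that accumulate-then-cast pattern (shapes and the partial-sum loop are purely illustrative; in the real code the accumulator is allocated with empty_like and zeroed by the autotune pre-hook):

import torch

q = torch.randn(2, 128, 8, 64).half()                # illustrative stand-in for the query
dq_accum = torch.zeros_like(q, dtype=torch.float32)  # fp32 accumulator, zero-initialized
for _ in range(4):                                   # stand-in for per-block partial dQ sums
    dq_accum += torch.randn_like(dq_accum)
dq = dq_accum.to(q.dtype)                            # replaces the old dq.copy_(dq_accum)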
@@ -1075,7 +1077,13 @@ def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=Fa
         # Make sure that the last dimension is contiguous
         query, key, value, attn_mask, attn_bias = [x if x.stride(-1) == 1 else x.contiguous() for x in [query, key, value, attn_mask, attn_bias]]
         o, lse, ctx.softmax_scale = _flash_attn_forward(
-            query, key, value, attn_mask, attn_bias, softmax_scale=softmax_scale, is_causal=is_causal
+            query,
+            key,
+            value,
+            attn_mask,
+            attn_bias,
+            softmax_scale=softmax_scale,
+            is_causal=is_causal
         )
         ctx.save_for_backward(query, key, value, o, lse, attn_mask, attn_bias)
         ctx.is_causal = is_causal
@@ -1085,29 +1093,18 @@ def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=Fa
     def backward(ctx, do):
         query, key, value, o, lse, attn_mask, attn_bias = ctx.saved_tensors
         assert not ctx.needs_input_grad[3], "FlashDMAttn does not support mask gradient yet"
-        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
-        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
-        with torch.inference_mode():
-            dq = torch.empty_like(query)
-            dk = torch.empty_like(key)
-            dv = torch.empty_like(value)
-            dbias = torch.empty_like(attn_bias)
-            _flash_attn_backward(
-                do,
-                query,
-                key,
-                value,
-                attn_mask,
-                attn_bias,
-                o,
-                lse,
-                dq,
-                dk,
-                dv,
-                dbias,
-                softmax_scale=ctx.softmax_scale,
-                is_causal=ctx.is_causal,
-            )
+        dq, dk, dv, dbias = _flash_attn_backward(
+            do,
+            query,
+            key,
+            value,
+            attn_mask,
+            attn_bias,
+            o,
+            lse,
+            softmax_scale=ctx.softmax_scale,
+            is_causal=ctx.is_causal,
+        )
         return dq, dk, dv, None, dbias, None, None
 
 
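With this change the wrapper follows the plain torch.autograd.Function convention: backward computes and returns one gradient per forward input, using None for non-differentiable arguments, instead of pre-allocating buffers under inference_mode. A self-contained sketch of that convention (a generic example, not the flash_dmattn API):

import torch

class ScaledMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b, scale):
        ctx.save_for_backward(a, b)
        ctx.scale = scale
        return (a @ b) * scale

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        da = (grad_out @ b.transpose(-2, -1)) * ctx.scale
        db = (a.transpose(-2, -1) @ grad_out) * ctx.scale
        # Gradients are returned directly, one per forward input; the non-tensor
        # scale argument gets None, mirroring the Nones returned for attn_mask,
        # is_causal, and softmax_scale above.
        return da, db, None

a = torch.randn(4, 8, requires_grad=True)
b = torch.randn(8, 3, requires_grad=True)
ScaledMatmul.apply(a, b, 0.5).sum().backward()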