@@ -665,12 +665,12 @@ def init_func(nargs):
             num_stages=1,
             pre_hook=init_to_zero(["DQ", "DBias"]),
         ),
-        triton.Config(
-            {"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
-            num_warps=8,
-            num_stages=1,
-            pre_hook=init_to_zero(["DQ", "DBias"]),
-        ),
+        # triton.Config(
+        #     {"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
+        #     num_warps=8,
+        #     num_stages=1,
+        #     pre_hook=init_to_zero(["DQ", "DBias"]),
+        # ),
         # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
         # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
         # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])),
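For reference, the `pre_hook=init_to_zero([...])` used by these configs matches the `init_func(nargs)` shown in the hunk header. A minimal sketch of what that helper plausibly looks like (it is defined earlier in the file, outside this diff):

def init_to_zero(names):
    # Pre-hook factory: returns init_func, which zeroes the named kernel
    # arguments before every launch so that buffers the kernel accumulates
    # into (here DQ and DBias) start from zero on each autotune trial.
    def init_func(nargs):
        for name in names:
            nargs[name].zero_()
    return init_func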
@@ -930,7 +930,7 @@ def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False
 
 
 def _flash_attn_backward(
-    do, q, k, v, mask, bias, o, lse, dq, dk, dv, dbias, softmax_scale=None, is_causal=False
+    do, q, k, v, mask, bias, o, lse, softmax_scale=None, is_causal=False
 ):
     # Make sure that the last dimension is contiguous
     if do.stride(-1) != 1:
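The contiguity guard above matters because the Triton kernels index memory with raw strides. A hedged illustration with hypothetical tensors (not from this file):

import torch

# Slicing the last dimension yields stride(-1) == 2: elements are
# interleaved in memory, which would break the kernel's flat pointer
# arithmetic if passed through unchanged.
do = torch.randn(2, 128, 8, 256, dtype=torch.float16)[..., ::2]
assert do.stride(-1) == 2
do = do if do.stride(-1) == 1 else do.contiguous()
assert do.stride(-1) == 1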
@@ -941,8 +941,6 @@ def _flash_attn_backward(
     assert d <= 128
     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
     assert lse.shape == (batch, nheads, seqlen_q_rounded)
-    assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
-    assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
 
     assert mask.dtype in [q.dtype, torch.float]
     assert mask.is_cuda
@@ -959,6 +957,9 @@ def _flash_attn_backward(
     dq_accum = torch.empty_like(q, dtype=torch.float32)
     delta = torch.empty_like(lse)
     # delta = torch.zeros_like(lse)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v)
+    dbias = torch.empty_like(bias)
 
     BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
     grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
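The new `dk`/`dv`/`dbias` buffers are allocated next to the existing float32 `dq_accum` scratch tensor. A sketch of the accumulate-in-fp32, cast-once pattern this function uses (shapes here are assumptions):

import torch

# dQ receives partial contributions from multiple kernel blocks, so it is
# accumulated into a float32 scratch buffer and down-cast to the input
# dtype once at the end, rather than accumulated directly in fp16.
q = torch.randn(1, 256, 8, 64, dtype=torch.float16)
dq_accum = torch.empty_like(q, dtype=torch.float32)
dq_accum.zero_()                  # the configs' pre_hook does this for DQ
# ... kernel launches add partial dQ contributions into dq_accum ...
dq = dq_accum.to(q.dtype)         # single fp32 -> fp16 cast at the end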
@@ -1047,7 +1048,8 @@ def _flash_attn_backward(
         # num_warps=num_warps,
         # num_stages=1,
     )
-    dq.copy_(dq_accum)
+    dq = dq_accum.to(q.dtype)
+    return dq, dk, dv, dbias
 
 
 class FlashDMAttnFunc(torch.autograd.Function):
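Note the replaced line: `dq.copy_(dq_accum)` wrote into a caller-supplied buffer, while `dq_accum.to(q.dtype)` hands back a freshly allocated tensor. A quick illustration of `.to` copy semantics:

import torch

x32 = torch.zeros(4, dtype=torch.float32)
x16 = x32.to(torch.float16)            # dtype conversion always allocates
assert x16 is not x32
assert x32.to(torch.float32) is x32    # no-op when dtype/device already match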
@@ -1075,7 +1077,13 @@ def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=Fa
         # Make sure that the last dimension is contiguous
         query, key, value, attn_mask, attn_bias = [x if x.stride(-1) == 1 else x.contiguous() for x in [query, key, value, attn_mask, attn_bias]]
         o, lse, ctx.softmax_scale = _flash_attn_forward(
-            query, key, value, attn_mask, attn_bias, softmax_scale=softmax_scale, is_causal=is_causal
+            query,
+            key,
+            value,
+            attn_mask,
+            attn_bias,
+            softmax_scale=softmax_scale,
+            is_causal=is_causal
         )
         ctx.save_for_backward(query, key, value, o, lse, attn_mask, attn_bias)
         ctx.is_causal = is_causal
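For orientation, `ctx.save_for_backward` is the standard way a `torch.autograd.Function` hands tensors from `forward` to `backward`. A minimal, hypothetical sketch of the same pattern (not from this file):

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)   # tensors must go through save_for_backward
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 2 * x    # one gradient per forward input

y = Square.apply(torch.full((3,), 2.0, requires_grad=True))
y.sum().backward()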
@@ -1085,29 +1093,18 @@ def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=Fa
     def backward(ctx, do):
         query, key, value, o, lse, attn_mask, attn_bias = ctx.saved_tensors
         assert not ctx.needs_input_grad[3], "FlashDMAttn does not support mask gradient yet"
-        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
-        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
-        with torch.inference_mode():
-            dq = torch.empty_like(query)
-            dk = torch.empty_like(key)
-            dv = torch.empty_like(value)
-            dbias = torch.empty_like(attn_bias)
-            _flash_attn_backward(
-                do,
-                query,
-                key,
-                value,
-                attn_mask,
-                attn_bias,
-                o,
-                lse,
-                dq,
-                dk,
-                dv,
-                dbias,
-                softmax_scale=ctx.softmax_scale,
-                is_causal=ctx.is_causal,
-            )
+        dq, dk, dv, dbias = _flash_attn_backward(
+            do,
+            query,
+            key,
+            value,
+            attn_mask,
+            attn_bias,
+            o,
+            lse,
+            softmax_scale=ctx.softmax_scale,
+            is_causal=ctx.is_causal,
+        )
         return dq, dk, dv, None, dbias, None, None
 
 
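Putting the pieces together, a hypothetical smoke test of the new backward path. The module path, tensor layout, mask/bias shapes, and `apply` argument order are assumptions inferred from the seven gradients returned above, not facts from this diff:

import torch
from flash_dmattn_triton import FlashDMAttnFunc   # assumed module path

batch, seqlen, nheads, d = 2, 256, 8, 64
q, k, v = (
    torch.randn(batch, seqlen, nheads, d, device="cuda",
                dtype=torch.float16, requires_grad=True)
    for _ in range(3)
)
# Mask/bias shapes are guesses; the diff only shows the dtype/device checks.
mask = torch.ones(batch, nheads, seqlen, seqlen, device="cuda", dtype=torch.float16)
bias = torch.zeros(batch, nheads, seqlen, seqlen, device="cuda",
                   dtype=torch.float16, requires_grad=True)

out = FlashDMAttnFunc.apply(q, k, v, mask, bias, False, None)  # is_causal, softmax_scale
out.sum().backward()    # gradients now come back from _flash_attn_backward
print(q.grad.shape, k.grad.shape, v.grad.shape, bias.grad.shape)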