@@ -846,7 +846,7 @@ def _bwd_kernel(
         )


-def _flash_attn_forward(q, k, v, mask, bias, causal=False, softmax_scale=None):
+def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False):
     # shape constraints
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
@@ -919,7 +919,7 @@ def _flash_attn_forward(q, k, v, mask, bias, causal=False, softmax_scale=None):
         seqlen_k // 32,  # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
-        causal,
+        is_causal,
         BLOCK_HEADDIM,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
@@ -930,7 +930,7 @@ def _flash_attn_forward(q, k, v, mask, bias, causal=False, softmax_scale=None):


 def _flash_attn_backward(
-    do, q, k, v, mask, bias, o, lse, dq, dk, dv, dbias, causal=False, softmax_scale=None
+    do, q, k, v, mask, bias, o, lse, dq, dk, dv, dbias, softmax_scale=None, is_causal=False
 ):
     # Make sure that the last dimension is contiguous
     if do.stride(-1) != 1:
@@ -1040,7 +1040,7 @@ def _flash_attn_backward(
         seqlen_k // 32,  # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
-        causal,
+        is_causal,
         BLOCK_HEADDIM,
         # SEQUENCE_PARALLEL=False,
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
@@ -1052,63 +1052,64 @@ def _flash_attn_backward(


 class FlashDMAttnFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, k, v, mask=None, bias=None, causal=False, softmax_scale=None):
+    def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, softmax_scale=None, is_causal=False):
         """
-        q: (batch_size, seqlen_q, nheads, headdim)
-        k: (batch_size, seqlen_k, nheads, headdim)
-        v: (batch_size, seqlen_k, nheads, headdim)
-        mask: optional, (batch, nheads, seqlen_q, seqlen_k)
-        bias: optional, (batch, nheads, seqlen_q, seqlen_k)
-        causal: bool, whether to apply causal masking
+        query: (batch_size, seqlen_q, nheads, headdim)
+        key: (batch_size, seqlen_k, nheads, headdim)
+        value: (batch_size, seqlen_k, nheads, headdim)
+        attn_mask: optional, (batch, nheads, seqlen_q, seqlen_k)
+        attn_bias: optional, (batch, nheads, seqlen_q, seqlen_k)
         softmax_scale: float, scaling factor for attention scores
+        is_causal: bool, whether to apply causal masking
         """
-        batch, seqlen_q, nheads, _ = q.shape
-        _, seqlen_k, _, _ = k.shape
-        if mask is not None:
-            if mask.dtype == torch.bool:
-                mask = torch.where(mask, 1.0, 0.0)
+        batch, seqlen_q, nheads, _ = query.shape
+        _, seqlen_k, _, _ = key.shape
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_mask = torch.where(attn_mask, 1.0, 0.0)
         else:
-            mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype)
-        if bias is None:
-            bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype)
+            attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
+        if attn_bias is None:
+            attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)

         # Make sure that the last dimension is contiguous
-        q, k, v, mask, bias = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v, mask, bias]]
+        query, key, value, attn_mask, attn_bias = [x if x.stride(-1) == 1 else x.contiguous() for x in [query, key, value, attn_mask, attn_bias]]
         o, lse, ctx.softmax_scale = _flash_attn_forward(
-            q, k, v, mask, bias, causal=causal, softmax_scale=softmax_scale
+            query, key, value, attn_mask, attn_bias, softmax_scale=softmax_scale, is_causal=is_causal
         )
-        ctx.save_for_backward(q, k, v, o, lse, mask, bias)
-        ctx.causal = causal
+        ctx.save_for_backward(query, key, value, o, lse, attn_mask, attn_bias)
+        ctx.is_causal = is_causal
         return o

     @staticmethod
     def backward(ctx, do):
-        q, k, v, o, lse, mask, bias = ctx.saved_tensors
+        query, key, value, o, lse, attn_mask, attn_bias = ctx.saved_tensors
         assert not ctx.needs_input_grad[3], "FlashDMAttn does not support mask gradient yet"
         # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
         # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
         with torch.inference_mode():
-            dq = torch.empty_like(q)
-            dk = torch.empty_like(k)
-            dv = torch.empty_like(v)
-            dbias = torch.empty_like(bias)
+            dq = torch.empty_like(query)
+            dk = torch.empty_like(key)
+            dv = torch.empty_like(value)
+            dbias = torch.empty_like(attn_bias)
             _flash_attn_backward(
                 do,
-                q,
-                k,
-                v,
-                mask,
-                bias,
+                query,
+                key,
+                value,
+                attn_mask,
+                attn_bias,
                 o,
                 lse,
                 dq,
                 dk,
                 dv,
                 dbias,
-                causal=ctx.causal,
                 softmax_scale=ctx.softmax_scale,
+                is_causal=ctx.is_causal,
             )
         return dq, dk, dv, None, dbias, None, None


-triton_dmattn_func = FlashDMAttnFunc.apply
+def triton_dmattn_func(query, key, value, attn_mask=None, attn_bias=None, scale=None, is_causal=False):
+    return FlashDMAttnFunc.apply(query, key, value, attn_mask, attn_bias, scale, is_causal)
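With this change the public entry point is the `triton_dmattn_func` wrapper rather than the raw `FlashDMAttnFunc.apply`, and the keyword names follow the SDPA-style `attn_mask` / `attn_bias` / `is_causal` convention. Below is a minimal usage sketch of the renamed API; the import path, tensor sizes, and dtype are assumptions for illustration and are not part of this diff.

```python
# Usage sketch for the renamed API. The import path below is an assumption;
# adjust it to wherever triton_dmattn_func lives in this repo.
import torch

from flash_dmattn.flash_dmattn_triton import triton_dmattn_func  # hypothetical path

batch, seqlen_q, seqlen_k, nheads, headdim = 2, 128, 128, 8, 64
device, dtype = "cuda", torch.float16

# Tensors follow the layout documented in forward(): (batch, seqlen, nheads, headdim).
q = torch.randn(batch, seqlen_q, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch, seqlen_k, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch, seqlen_k, nheads, headdim, device=device, dtype=dtype, requires_grad=True)

# Optional mask/bias are (batch, nheads, seqlen_q, seqlen_k); a bool mask is
# converted to 1.0/0.0 inside forward(), so either dtype is accepted.
attn_mask = torch.ones(batch, nheads, seqlen_q, seqlen_k, device=device, dtype=torch.bool)
attn_bias = torch.zeros(batch, nheads, seqlen_q, seqlen_k, device=device, dtype=dtype, requires_grad=True)

out = triton_dmattn_func(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias, scale=None, is_causal=True)
out.sum().backward()  # gradients flow to q, k, v, and attn_bias; the mask receives no gradient
```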