
Commit 6403ad1

Simplify and standardize flex attention interface
2 parents: eade331 + 43adc40

2 files changed: +48 additions, −46 deletions

flash_dmattn/flash_dmattn_flex.py

Lines changed: 12 additions & 11 deletions
@@ -8,14 +8,17 @@ def flex_attention_forward(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    attention_mask: torch.Tensor,
-    attention_bias: torch.Tensor,
+    attn_mask: torch.Tensor,
+    attn_bias: torch.Tensor,
+    scale: Optional[float] = None,
     is_causal: bool = True,
-    scaling: Optional[float] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    attn_mask = attention_mask[:, :, :, : key.shape[-2]]
-    attn_bias = attention_bias[:, :, :, : key.shape[-2]]
+    query = query.transpose(1, 2).contiguous()  # [B, H, Q_LEN, D]
+    key = key.transpose(1, 2).contiguous()  # [B, H, KV_LEN, D]
+    value = value.transpose(1, 2).contiguous()  # [B, H, KV_LEN, D]
+    attn_mask = attn_mask[:, :, :, : key.shape[-2]]
+    attn_bias = attn_bias[:, :, :, : key.shape[-2]]

     def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
         score = score + attn_bias[batch_idx][head_idx][q_idx][kv_idx]
@@ -44,23 +47,21 @@ def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
         "num_stages": 1,
         "num_warps": 8,
     }
-    attn_output, attention_weights = compile_friendly_flex_attention(
+    attn_output = compile_friendly_flex_attention(
         query,
         key,
         value,
         score_mod=score_mod,
         block_mask=block_mask if is_causal else None,
-        scale=scaling,
+        scale=scale,
         kernel_options=kernel_options,
         # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
         # For simplification, we thus always return it as no additional computations are introduced.
-        return_lse=True,
+        return_lse=False,
         training=False,
     )
-    # lse is returned in float32
-    attention_weights = attention_weights.to(value.dtype)
     attn_output = attn_output.transpose(1, 2).contiguous()

-    return attn_output, attention_weights
+    return attn_output

 flex_dmattn_func = flex_attention_forward
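
For reference, a minimal usage sketch of the updated flex interface, assuming the function is importable as flash_dmattn.flash_dmattn_flex.flex_dmattn_func (assumed import path) and that a CUDA device with a flex-attention-capable PyTorch build is available; sizes are illustrative only. Inputs follow the [batch, seqlen, heads, headdim] layout implied by the transposes above, while mask and bias use [batch, heads, seqlen_q, seqlen_k].

import torch
from flash_dmattn.flash_dmattn_flex import flex_dmattn_func  # assumed import path

B, H, Q_LEN, KV_LEN, D = 2, 8, 128, 128, 64
device, dtype = "cuda", torch.bfloat16

query = torch.randn(B, Q_LEN, H, D, device=device, dtype=dtype)
key = torch.randn(B, KV_LEN, H, D, device=device, dtype=dtype)
value = torch.randn(B, KV_LEN, H, D, device=device, dtype=dtype)
attn_mask = torch.ones(B, H, Q_LEN, KV_LEN, device=device, dtype=dtype)   # 1 = keep, 0 = drop
attn_bias = torch.zeros(B, H, Q_LEN, KV_LEN, device=device, dtype=dtype)  # added to scores via score_mod

# With return_lse=False the function now returns only the attention output,
# transposed back to [B, Q_LEN, H, D].
attn_output = flex_dmattn_func(query, key, value, attn_mask, attn_bias, scale=D ** -0.5, is_causal=True)
print(attn_output.shape)  # torch.Size([2, 128, 8, 64])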

flash_dmattn/flash_dmattn_triton.py

Lines changed: 36 additions & 35 deletions
@@ -846,7 +846,7 @@ def _bwd_kernel(
     )


-def _flash_attn_forward(q, k, v, mask, bias, causal=False, softmax_scale=None):
+def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False):
     # shape constraints
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
@@ -919,7 +919,7 @@ def _flash_attn_forward(q, k, v, mask, bias, causal=False, softmax_scale=None):
         seqlen_k // 32,  # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
-        causal,
+        is_causal,
         BLOCK_HEADDIM,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
@@ -930,7 +930,7 @@ def _flash_attn_forward(q, k, v, mask, bias, causal=False, softmax_scale=None):


 def _flash_attn_backward(
-    do, q, k, v, mask, bias, o, lse, dq, dk, dv, dbias, causal=False, softmax_scale=None
+    do, q, k, v, mask, bias, o, lse, dq, dk, dv, dbias, softmax_scale=None, is_causal=False
 ):
     # Make sure that the last dimension is contiguous
     if do.stride(-1) != 1:
@@ -1040,7 +1040,7 @@ def _flash_attn_backward(
         seqlen_k // 32,  # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
-        causal,
+        is_causal,
         BLOCK_HEADDIM,
         # SEQUENCE_PARALLEL=False,
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
@@ -1052,63 +1052,64 @@

 class FlashDMAttnFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, k, v, mask=None, bias=None, causal=False, softmax_scale=None):
+    def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, softmax_scale=None, is_causal=False):
         """
-        q: (batch_size, seqlen_q, nheads, headdim)
-        k: (batch_size, seqlen_k, nheads, headdim)
-        v: (batch_size, seqlen_k, nheads, headdim)
-        mask: optional, (batch, nheads, seqlen_q, seqlen_k)
-        bias: optional, (batch, nheads, seqlen_q, seqlen_k)
-        causal: bool, whether to apply causal masking
+        query: (batch_size, seqlen_q, nheads, headdim)
+        key: (batch_size, seqlen_k, nheads, headdim)
+        value: (batch_size, seqlen_k, nheads, headdim)
+        attn_mask: optional, (batch, nheads, seqlen_q, seqlen_k)
+        attn_bias: optional, (batch, nheads, seqlen_q, seqlen_k)
         softmax_scale: float, scaling factor for attention scores
+        is_causal: bool, whether to apply causal masking
         """
-        batch, seqlen_q, nheads, _ = q.shape
-        _, seqlen_k, _, _ = k.shape
-        if mask is not None:
-            if mask.dtype == torch.bool:
-                mask = torch.where(mask, 1.0, 0.0)
+        batch, seqlen_q, nheads, _ = query.shape
+        _, seqlen_k, _, _ = key.shape
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_mask = torch.where(attn_mask, 1.0, 0.0)
         else:
-            mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype)
-        if bias is None:
-            bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype)
+            attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
+        if attn_bias is None:
+            attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)

         # Make sure that the last dimension is contiguous
-        q, k, v, mask, bias = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v, mask, bias]]
+        query, key, value, attn_mask, attn_bias = [x if x.stride(-1) == 1 else x.contiguous() for x in [query, key, value, attn_mask, attn_bias]]
         o, lse, ctx.softmax_scale = _flash_attn_forward(
-            q, k, v, mask, bias, causal=causal, softmax_scale=softmax_scale
+            query, key, value, attn_mask, attn_bias, softmax_scale=softmax_scale, is_causal=is_causal
         )
-        ctx.save_for_backward(q, k, v, o, lse, mask, bias)
-        ctx.causal = causal
+        ctx.save_for_backward(query, key, value, o, lse, attn_mask, attn_bias)
+        ctx.is_causal = is_causal
         return o

     @staticmethod
     def backward(ctx, do):
-        q, k, v, o, lse, mask, bias = ctx.saved_tensors
+        query, key, value, o, lse, attn_mask, attn_bias = ctx.saved_tensors
         assert not ctx.needs_input_grad[3], "FlashDMAttn does not support mask gradient yet"
         # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
         # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
         with torch.inference_mode():
-            dq = torch.empty_like(q)
-            dk = torch.empty_like(k)
-            dv = torch.empty_like(v)
-            dbias = torch.empty_like(bias)
+            dq = torch.empty_like(query)
+            dk = torch.empty_like(key)
+            dv = torch.empty_like(value)
+            dbias = torch.empty_like(attn_bias)
             _flash_attn_backward(
                 do,
-                q,
-                k,
-                v,
-                mask,
-                bias,
+                query,
+                key,
+                value,
+                attn_mask,
+                attn_bias,
                 o,
                 lse,
                 dq,
                 dk,
                 dv,
                 dbias,
-                causal=ctx.causal,
                 softmax_scale=ctx.softmax_scale,
+                is_causal=ctx.is_causal,
             )
         return dq, dk, dv, None, dbias, None, None


-triton_dmattn_func = FlashDMAttnFunc.apply
+def triton_dmattn_func(query, key, value, attn_mask=None, attn_bias=None, scale=None, is_causal=False):
+    return FlashDMAttnFunc.apply(query, key, value, attn_mask, attn_bias, scale, is_causal)
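
A short usage sketch for the Triton path with the new wrapper, assuming the module is importable as flash_dmattn.flash_dmattn_triton (assumed import path) and that a CUDA device with Triton is available; sizes are illustrative only. Shapes follow the docstring above: q/k/v are [batch, seqlen, nheads, headdim], mask and bias are [batch, nheads, seqlen_q, seqlen_k].

import torch
from flash_dmattn.flash_dmattn_triton import triton_dmattn_func  # assumed import path

batch, seqlen, nheads, headdim = 2, 256, 8, 64
device, dtype = "cuda", torch.float16

q = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)

# Explicit mask/bias; omitting them makes forward() fill in all-ones / all-zeros defaults,
# and a boolean mask is converted to 0/1 floats before the kernel launch.
attn_mask = torch.ones(batch, nheads, seqlen, seqlen, device=device, dtype=dtype)
attn_bias = torch.zeros(batch, nheads, seqlen, seqlen, device=device, dtype=dtype)

out = triton_dmattn_func(q, k, v, attn_mask, attn_bias, scale=None, is_causal=True)
out.sum().backward()  # gradients flow to q, k, v (mask gradients are not supported)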
