Commit 6d60c3a

Standardizes parameter names in flash attention API
Renames function parameters to follow more conventional naming patterns:

- `causal` becomes `is_causal` for boolean clarity
- `q`, `k`, `v` become `query`, `key`, `value` for readability
- `mask`, `bias` become `attn_mask`, `attn_bias` for specificity

Updates function signatures, internal usage, and wrapper function to maintain consistency throughout the codebase.
1 parent 5c895aa commit 6d60c3a
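
Not part of the commit itself: a minimal sketch of how a call site changes under the renamed API, assuming the module path `flash_dmattn/flash_dmattn_triton.py` shown below and illustrative shapes, dtype, device, and scale. Because `softmax_scale` now comes before the causal flag in the signatures, keyword arguments are the safest way to pass the trailing options.

```python
import torch
from flash_dmattn.flash_dmattn_triton import triton_dmattn_func

# Illustrative tensors only: (batch_size, seqlen, nheads, headdim), as in the docstring in the diff.
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)

# Before this commit (positional args of FlashDMAttnFunc.apply): q, k, v, mask, bias, causal, softmax_scale
#   out = triton_dmattn_func(q, k, v, None, None, True, 0.125)
# After this commit, with the renamed keyword arguments of the new wrapper:
out = triton_dmattn_func(q, k, v, attn_mask=None, attn_bias=None, scale=0.125, is_causal=True)
```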


flash_dmattn/flash_dmattn_triton.py

Lines changed: 36 additions & 35 deletions
@@ -846,7 +846,7 @@ def _bwd_kernel(
             )
 
 
-def _flash_attn_forward(q, k, v, mask, bias, causal=False, softmax_scale=None):
+def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False):
     # shape constraints
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
@@ -919,7 +919,7 @@ def _flash_attn_forward(q, k, v, mask, bias, causal=False, softmax_scale=None):
         seqlen_k // 32, # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
-        causal,
+        is_causal,
         BLOCK_HEADDIM,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
@@ -930,7 +930,7 @@ def _flash_attn_forward(q, k, v, mask, bias, causal=False, softmax_scale=None):
 
 
 def _flash_attn_backward(
-    do, q, k, v, mask, bias, o, lse, dq, dk, dv, dbias, causal=False, softmax_scale=None
+    do, q, k, v, mask, bias, o, lse, dq, dk, dv, dbias, softmax_scale=None, is_causal=False
 ):
     # Make sure that the last dimension is contiguous
     if do.stride(-1) != 1:
@@ -1040,7 +1040,7 @@ def _flash_attn_backward(
         seqlen_k // 32, # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
-        causal,
+        is_causal,
         BLOCK_HEADDIM,
         # SEQUENCE_PARALLEL=False,
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
@@ -1052,63 +1052,64 @@
 
 class FlashDMAttnFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, k, v, mask=None, bias=None, causal=False, softmax_scale=None):
+    def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, softmax_scale=None, is_causal=False):
         """
-        q: (batch_size, seqlen_q, nheads, headdim)
-        k: (batch_size, seqlen_k, nheads, headdim)
-        v: (batch_size, seqlen_k, nheads, headdim)
-        mask: optional, (batch, nheads, seqlen_q, seqlen_k)
-        bias: optional, (batch, nheads, seqlen_q, seqlen_k)
-        causal: bool, whether to apply causal masking
+        query: (batch_size, seqlen_q, nheads, headdim)
+        key: (batch_size, seqlen_k, nheads, headdim)
+        value: (batch_size, seqlen_k, nheads, headdim)
+        attn_mask: optional, (batch, nheads, seqlen_q, seqlen_k)
+        attn_bias: optional, (batch, nheads, seqlen_q, seqlen_k)
         softmax_scale: float, scaling factor for attention scores
+        is_causal: bool, whether to apply causal masking
         """
-        batch, seqlen_q, nheads, _ = q.shape
-        _, seqlen_k, _, _ = k.shape
-        if mask is not None:
-            if mask.dtype == torch.bool:
-                mask = torch.where(mask, 1.0, 0.0)
+        batch, seqlen_q, nheads, _ = query.shape
+        _, seqlen_k, _, _ = key.shape
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_mask = torch.where(attn_mask, 1.0, 0.0)
         else:
-            mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype)
-        if bias is None:
-            bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype)
+            attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
+        if attn_bias is None:
+            attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
 
         # Make sure that the last dimension is contiguous
-        q, k, v, mask, bias = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v, mask, bias]]
+        query, key, value, attn_mask, attn_bias = [x if x.stride(-1) == 1 else x.contiguous() for x in [query, key, value, attn_mask, attn_bias]]
         o, lse, ctx.softmax_scale = _flash_attn_forward(
-            q, k, v, mask, bias, causal=causal, softmax_scale=softmax_scale
+            query, key, value, attn_mask, attn_bias, softmax_scale=softmax_scale, is_causal=is_causal
        )
-        ctx.save_for_backward(q, k, v, o, lse, mask, bias)
-        ctx.causal = causal
+        ctx.save_for_backward(query, key, value, o, lse, attn_mask, attn_bias)
+        ctx.is_causal = is_causal
         return o
 
     @staticmethod
     def backward(ctx, do):
-        q, k, v, o, lse, mask, bias = ctx.saved_tensors
+        query, key, value, o, lse, attn_mask, attn_bias = ctx.saved_tensors
         assert not ctx.needs_input_grad[3], "FlashDMAttn does not support mask gradient yet"
         # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
         # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
         with torch.inference_mode():
-            dq = torch.empty_like(q)
-            dk = torch.empty_like(k)
-            dv = torch.empty_like(v)
-            dbias = torch.empty_like(bias)
+            dq = torch.empty_like(query)
+            dk = torch.empty_like(key)
+            dv = torch.empty_like(value)
+            dbias = torch.empty_like(attn_bias)
             _flash_attn_backward(
                 do,
-                q,
-                k,
-                v,
-                mask,
-                bias,
+                query,
+                key,
+                value,
+                attn_mask,
+                attn_bias,
                 o,
                 lse,
                 dq,
                 dk,
                 dv,
                 dbias,
-                causal=ctx.causal,
                 softmax_scale=ctx.softmax_scale,
+                is_causal=ctx.is_causal,
             )
         return dq, dk, dv, None, dbias, None, None
 
 
-triton_dmattn_func = FlashDMAttnFunc.apply
+def triton_dmattn_func(query, key, value, attn_mask=None, attn_bias=None, scale=None, is_causal=False):
+    return FlashDMAttnFunc.apply(query, key, value, attn_mask, attn_bias, scale, is_causal)
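
For reference, a hedged usage sketch of the updated wrapper based on the behavior visible in `forward` above: a boolean `attn_mask` is converted to a float mask via `torch.where`, and a missing `attn_bias` is zero-filled to `(batch, nheads, seqlen_q, seqlen_k)`. The shapes, dtype, and device here are illustrative assumptions, not part of the commit.

```python
import math
import torch
from flash_dmattn.flash_dmattn_triton import triton_dmattn_func

batch, seqlen_q, seqlen_k, nheads, headdim = 2, 128, 128, 8, 64
query = torch.randn(batch, seqlen_q, nheads, headdim, device="cuda", dtype=torch.float16)
key = torch.randn(batch, seqlen_k, nheads, headdim, device="cuda", dtype=torch.float16)
value = torch.randn(batch, seqlen_k, nheads, headdim, device="cuda", dtype=torch.float16)

# Boolean masks are accepted: forward() maps True -> 1.0 and False -> 0.0 with torch.where.
attn_mask = torch.ones(batch, nheads, seqlen_q, seqlen_k, device="cuda", dtype=torch.bool)

# attn_bias may be omitted; forward() zero-fills a (batch, nheads, seqlen_q, seqlen_k) bias.
out = triton_dmattn_func(query, key, value, attn_mask=attn_mask, attn_bias=None,
                         scale=1.0 / math.sqrt(headdim), is_causal=False)
print(out.shape)  # expected: same layout as query, i.e. torch.Size([2, 128, 8, 64])
```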
