
Commit f48f9c2

[core] start varlen variants for attn backend kernels. (#12765)
* start varlen variants for attn backend kernels.
* maybe unflatten heads.
* updates
* remove unused function.
* doc
* up
1 parent 3c05b9f commit f48f9c2

3 files changed: 132 additions & 5 deletions


docs/source/en/api/pipelines/hunyuan_video15.md

Lines changed: 2 additions & 2 deletions
@@ -56,8 +56,8 @@ export_to_video(video, "output.mp4", fps=15)

 - HunyuanVideo1.5 use attention masks with variable-length sequences. For best performance, we recommend using an attention backend that handles padding efficiently.

-- **H100/H800:** `_flash_3_hub` or `_flash_varlen_3`
-- **A100/A800/RTX 4090:** `flash_hub` or `flash_varlen`
+- **H100/H800:** `_flash_3_hub` or `_flash_3_varlen_hub`
+- **A100/A800/RTX 4090:** `flash_hub` or `flash_varlen_hub`
 - **Other GPUs:** `sage_hub`

 Refer to the [Attention backends](../../optimization/attention_backends) guide for more details about using a different backend.

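To make the renamed recommendations concrete, here is a minimal sketch of applying one of them, assuming the `set_attention_backend` model helper described in the linked attention backends guide; the checkpoint id and generic pipeline loading are illustrative, not taken from this commit:

```python
import torch
from diffusers import DiffusionPipeline

# Hypothetical checkpoint id; substitute the HunyuanVideo 1.5 checkpoint you actually use.
pipe = DiffusionPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5", torch_dtype=torch.bfloat16
).to("cuda")

# H100/H800: FlashAttention-3 varlen kernel fetched from the Hub.
# On A100/A800/RTX 4090, use "flash_varlen_hub" instead.
pipe.transformer.set_attention_backend("_flash_3_varlen_hub")
```
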
docs/source/en/optimization/attention_backends.md

Lines changed: 2 additions & 0 deletions
@@ -141,10 +141,12 @@ Refer to the table below for a complete list of available attention backends and
 | `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
 | `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from kernels |
 | `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
+| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
 | `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
 | `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
 | `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
 | `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
+| `_flash_3_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 from kernels |
 | `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
 | `sage_hub` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) from kernels |
 | `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |

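For reference, a small self-contained sketch of exercising one of the new table entries through the dispatcher, assuming the `attention_backend` context manager and `dispatch_attention_fn` are importable from `diffusers.models.attention_dispatch` as in the guide; the import path, tensor shapes, and the CUDA requirement are illustrative assumptions:

```python
import torch
from diffusers.models.attention_dispatch import attention_backend, dispatch_attention_fn

# Random (batch, seq_len, heads, head_dim) projections in bf16, matching the
# dtype constraint on the flash backends. Requires a CUDA device and `kernels`.
q = torch.randn(2, 256, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 256, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 256, 8, 64, device="cuda", dtype=torch.bfloat16)

# Route this call through the new varlen FlashAttention-2 Hub kernel.
with attention_backend("flash_varlen_hub"):
    out = dispatch_attention_fn(q, k, v)

print(out.shape)  # torch.Size([2, 256, 8, 64])
```
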
src/diffusers/models/attention_dispatch.py

Lines changed: 128 additions & 3 deletions
@@ -168,10 +168,11 @@ class AttentionBackendName(str, Enum):
     FLASH = "flash"
     FLASH_HUB = "flash_hub"
     FLASH_VARLEN = "flash_varlen"
+    FLASH_VARLEN_HUB = "flash_varlen_hub"
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
-    # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.
+    _FLASH_3_VARLEN_HUB = "_flash_3_varlen_hub"

     # `aiter`
     AITER = "aiter"
@@ -263,9 +264,17 @@ class _HubKernelConfig:
263264
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
264265
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
265266
),
267+
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
268+
repo_id="kernels-community/flash-attn3",
269+
function_attr="flash_attn_varlen_func",
270+
# revision="fake-ops-return-probs",
271+
),
266272
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
267273
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
268274
),
275+
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
276+
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
277+
),
269278
AttentionBackendName.SAGE_HUB: _HubKernelConfig(
270279
repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
271280
),
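
Each `_HubKernelConfig` entry names a Hub repo, the attribute to fetch from it, and an optional revision. A rough sketch of how such an entry could be resolved with the `kernels` library, assuming `kernels.get_kernel` (and its optional `revision` argument) behaves as documented; this is an illustration, not the code this module actually runs:

```python
from kernels import get_kernel

# Resolve the FlashAttention-2 varlen entry above: pull the compiled kernel from
# the Hub and look up the callable named by `function_attr`.
repo_id, function_attr, revision = "kernels-community/flash-attn2", "flash_attn_varlen_func", None

kernel_module = get_kernel(repo_id) if revision is None else get_kernel(repo_id, revision=revision)
flash_attn_varlen_func = getattr(kernel_module, function_attr)
print(flash_attn_varlen_func)
```
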
@@ -425,8 +434,13 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
             )

-    # TODO: add support Hub variant of varlen later
-    elif backend in [AttentionBackendName._FLASH_3_HUB, AttentionBackendName.FLASH_HUB, AttentionBackendName.SAGE_HUB]:
+    elif backend in [
+        AttentionBackendName.FLASH_HUB,
+        AttentionBackendName.FLASH_VARLEN_HUB,
+        AttentionBackendName._FLASH_3_HUB,
+        AttentionBackendName._FLASH_3_VARLEN_HUB,
+        AttentionBackendName.SAGE_HUB,
+    ]:
         if not is_kernels_available():
             raise RuntimeError(
                 f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
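
The widened check simply gates every Hub-backed backend on the `kernels` package. As an illustration of the same gating outside diffusers, a hedged sketch with a hypothetical `_kernels_installed` stand-in for `is_kernels_available`:

```python
import importlib.util

def _kernels_installed() -> bool:
    # Hypothetical stand-in for diffusers' is_kernels_available(): the check is
    # essentially "can the `kernels` package be imported?".
    return importlib.util.find_spec("kernels") is not None

# Pick a Hub-backed varlen kernel when `kernels` is present, otherwise fall back
# to the locally installed flash-attn varlen backend.
backend_name = "flash_varlen_hub" if _kernels_installed() else "flash_varlen"
print(backend_name)
```
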
@@ -1387,6 +1401,63 @@ def _flash_attention_hub(
     return (out, lse) if return_lse else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_VARLEN_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _flash_varlen_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    batch_size, seq_len_q, _, _ = query.shape
+    _, seq_len_kv, _, _ = key.shape
+
+    if attn_mask is not None:
+        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+        _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+        )
+    )
+
+    key_valid, value_valid = [], []
+    for b in range(batch_size):
+        valid_len = seqlens_k[b]
+        key_valid.append(key[b, :valid_len])
+        value_valid.append(value[b, :valid_len])
+
+    query_packed = query.flatten(0, 1)
+    key_packed = torch.cat(key_valid, dim=0)
+    value_packed = torch.cat(value_valid, dim=0)
+
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn
+    out = func(
+        q=query_packed,
+        k=key_packed,
+        v=value_packed,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
+        return_attn_probs=return_lse,
+    )
+    out = out.unflatten(0, (batch_size, -1))
+
+    return out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
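
Both new functions pack padded `(batch, seq, heads, head_dim)` tensors into a flat token dimension plus `cu_seqlens` offsets before calling the varlen kernel. A small sketch of that bookkeeping, mirroring what `_prepare_for_flash_attn_or_sage_varlen` and the `key_valid`/`value_valid` loop produce; the shapes and lengths are made up for illustration:

```python
import torch

# Per-sequence valid key/value lengths for a batch of 3 padded sequences.
seqlens_k = torch.tensor([5, 3, 7], dtype=torch.int32)

# cu_seqlens is the prefix sum with a leading zero: packed sequence b occupies
# rows cu_seqlens[b]:cu_seqlens[b + 1] of the packed tensor.
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
max_seqlen_k = int(seqlens_k.max())

print(cu_seqlens_k.tolist())  # [0, 5, 8, 15]
print(max_seqlen_k)           # 7

# Packing: concatenate only the valid rows of each padded (seq, heads, dim) slice,
# mirroring the key_valid/value_valid loop in the new backend functions.
heads, dim = 8, 64
key_padded = torch.randn(3, 7, heads, dim)
key_packed = torch.cat([key_padded[b, : seqlens_k[b]] for b in range(3)], dim=0)
print(key_packed.shape)       # torch.Size([15, 8, 64])
```
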
@@ -1509,6 +1580,60 @@ def _flash_attention_3_hub(
     return (out[0], out[1]) if return_attn_probs else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_3_VARLEN_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _flash_attention_3_varlen_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    batch_size, seq_len_q, _, _ = query.shape
+    _, seq_len_kv, _, _ = key.shape
+
+    if attn_mask is not None:
+        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+        _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+        )
+    )
+
+    key_valid, value_valid = [], []
+    for b in range(batch_size):
+        valid_len = seqlens_k[b]
+        key_valid.append(key[b, :valid_len])
+        value_valid.append(value[b, :valid_len])
+
+    query_packed = query.flatten(0, 1)
+    key_packed = torch.cat(key_valid, dim=0)
+    value_packed = torch.cat(value_valid, dim=0)
+
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn
+    out, lse, *_ = func(
+        q=query_packed,
+        k=key_packed,
+        v=value_packed,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        softmax_scale=scale,
+        causal=is_causal,
+    )
+    out = out.unflatten(0, (batch_size, -1))
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_VARLEN_3,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],

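Finally, both new functions restore the padded layout by unflattening the packed output. A self-contained sketch of that round trip on the query side, which is packed without dropping any tokens; the shapes are illustrative:

```python
import torch

batch_size, seq_len_q, heads, dim = 2, 4, 8, 64
query = torch.randn(batch_size, seq_len_q, heads, dim)

# Queries keep every token, so the query offsets are simply multiples of
# seq_len_q and the flatten/unflatten round trip is loss-free.
query_packed = query.flatten(0, 1)                      # (batch * seq_len_q, heads, dim)
restored = query_packed.unflatten(0, (batch_size, -1))  # (batch, seq_len_q, heads, dim)

assert torch.equal(restored, query)
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seq_len_q, seq_len_q, dtype=torch.int32)
print(cu_seqlens_q.tolist())  # [0, 4, 8]
```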