@@ -168,10 +168,11 @@ class AttentionBackendName(str, Enum):
     FLASH = "flash"
     FLASH_HUB = "flash_hub"
     FLASH_VARLEN = "flash_varlen"
+    FLASH_VARLEN_HUB = "flash_varlen_hub"
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
-    # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.
+    _FLASH_3_VARLEN_HUB = "_flash_3_varlen_hub"

     # `aiter`
     AITER = "aiter"
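The two new members make the Hub-fetched varlen kernels addressable by their string values, alongside the existing `flash_hub` / `_flash_3_hub` entries. A minimal sketch of looking one up by value (assuming the enum lives in `diffusers.models.attention_dispatch`; models typically expose a `set_attention_backend(...)` helper that accepts the same strings, though that API is outside this diff):

```python
from diffusers.models.attention_dispatch import AttentionBackendName

# AttentionBackendName is a str Enum, so backends round-trip through their string values.
backend = AttentionBackendName("flash_varlen_hub")
assert backend is AttentionBackendName.FLASH_VARLEN_HUB
print(backend.value)  # "flash_varlen_hub"
```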
@@ -263,9 +264,17 @@ class _HubKernelConfig:
     AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
     ),
+    AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
+        repo_id="kernels-community/flash-attn3",
+        function_attr="flash_attn_varlen_func",
+        # revision="fake-ops-return-probs",
+    ),
     AttentionBackendName.FLASH_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
     ),
+    AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
+        repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
+    ),
     AttentionBackendName.SAGE_HUB: _HubKernelConfig(
         repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
     ),
@@ -425,8 +434,13 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
             )

-    # TODO: add support Hub variant of varlen later
-    elif backend in [AttentionBackendName._FLASH_3_HUB, AttentionBackendName.FLASH_HUB, AttentionBackendName.SAGE_HUB]:
+    elif backend in [
+        AttentionBackendName.FLASH_HUB,
+        AttentionBackendName.FLASH_VARLEN_HUB,
+        AttentionBackendName._FLASH_3_HUB,
+        AttentionBackendName._FLASH_3_VARLEN_HUB,
+        AttentionBackendName.SAGE_HUB,
+    ]:
         if not is_kernels_available():
             raise RuntimeError(
                 f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
@@ -1387,6 +1401,63 @@ def _flash_attention_hub(
     return (out, lse) if return_lse else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_VARLEN_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _flash_varlen_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    batch_size, seq_len_q, _, _ = query.shape
+    _, seq_len_kv, _, _ = key.shape
+
+    if attn_mask is not None:
+        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+        _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+        )
+    )
+
+    key_valid, value_valid = [], []
+    for b in range(batch_size):
+        valid_len = seqlens_k[b]
+        key_valid.append(key[b, :valid_len])
+        value_valid.append(value[b, :valid_len])
+
+    query_packed = query.flatten(0, 1)
+    key_packed = torch.cat(key_valid, dim=0)
+    value_packed = torch.cat(value_valid, dim=0)
+
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn
+    out = func(
+        q=query_packed,
+        k=key_packed,
+        v=value_packed,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
+        return_attn_probs=return_lse,
+    )
+    out = out.unflatten(0, (batch_size, -1))
+
+    return out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
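The new `_flash_varlen_attention_hub` wrapper follows the same packing recipe as the existing non-Hub varlen backends: derive per-sequence K/V lengths from the mask, build cumulative sequence lengths, and concatenate only the valid (unpadded) K/V rows before handing flat tensors to the kernel. A standalone sketch of that packing step (variable names are illustrative, not the internals of `_prepare_for_flash_attn_or_sage_varlen`):

```python
import torch

batch_size, seq_len_q, seq_len_kv, heads, dim = 2, 4, 6, 8, 64
query = torch.randn(batch_size, seq_len_q, heads, dim, dtype=torch.bfloat16)
key = torch.randn(batch_size, seq_len_kv, heads, dim, dtype=torch.bfloat16)
value = torch.randn_like(key)
# Boolean mask: True where a K/V position is real, False where it is padding.
attn_mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]], dtype=torch.bool)

seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)                 # tensor([4, 6])
cu_seqlens_k = torch.nn.functional.pad(seqlens_k.cumsum(0, dtype=torch.int32), (1, 0))
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seq_len_q, seq_len_q, dtype=torch.int32)
max_seqlen_q, max_seqlen_k = seq_len_q, int(seqlens_k.max())

# Queries are dense, so a plain flatten packs them; K/V drop their padded tails.
query_packed = query.flatten(0, 1)                                  # (B * Sq, H, D)
key_packed = torch.cat([key[b, : seqlens_k[b]] for b in range(batch_size)], dim=0)
value_packed = torch.cat([value[b, : seqlens_k[b]] for b in range(batch_size)], dim=0)
assert key_packed.shape[0] == int(cu_seqlens_k[-1])
```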
@@ -1509,6 +1580,60 @@ def _flash_attention_3_hub(
     return (out[0], out[1]) if return_attn_probs else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_3_VARLEN_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _flash_attention_3_varlen_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    batch_size, seq_len_q, _, _ = query.shape
+    _, seq_len_kv, _, _ = key.shape
+
+    if attn_mask is not None:
+        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+        _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+        )
+    )
+
+    key_valid, value_valid = [], []
+    for b in range(batch_size):
+        valid_len = seqlens_k[b]
+        key_valid.append(key[b, :valid_len])
+        value_valid.append(value[b, :valid_len])
+
+    query_packed = query.flatten(0, 1)
+    key_packed = torch.cat(key_valid, dim=0)
+    value_packed = torch.cat(value_valid, dim=0)
+
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn
+    out, lse, *_ = func(
+        q=query_packed,
+        k=key_packed,
+        v=value_packed,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        softmax_scale=scale,
+        causal=is_causal,
+    )
+    out = out.unflatten(0, (batch_size, -1))
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_VARLEN_3,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
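One difference between the two new wrappers is the return handling: the FA2 Hub kernel's `flash_attn_varlen_func` is called with `return_attn_probs=return_lse` and its output is returned directly, while the FA3 Hub kernel always returns a tuple, so the wrapper unpacks `out, lse, *_` and only forwards `lse` when `return_lse` is set. A hedged end-to-end usage sketch; the `set_attention_backend` helper and the model choice are assumptions about the surrounding diffusers API rather than something this diff adds:

```python
import torch
from diffusers import DiffusionPipeline

# Assumes a CUDA machine, `pip install kernels`, and a transformer whose attention
# is routed through the backend dispatcher shown above.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

pipe.transformer.set_attention_backend("flash_varlen_hub")       # FA2 varlen via the Hub
# pipe.transformer.set_attention_backend("_flash_3_varlen_hub")  # FA3 varlen via the Hub
image = pipe("an astronaut riding a horse on the moon").images[0]
```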