@@ -68,22 +68,23 @@ Main attention function. Supports multi-head and grouped-query attention (when t
 
 ```python
 def flash_dmattn_func(
-    q: torch.Tensor,  # (batch, seqlen_q, num_heads, head_dim)
-    k: torch.Tensor,  # (batch, seqlen_k, num_kv_heads, head_dim)
-    v: torch.Tensor,  # (batch, seqlen_k, num_kv_heads, head_dim)
-    attn_mask: Optional[torch.Tensor] = None,  # (batch, num_heads, seqlen_q, seqlen_k)
-    attn_bias: Optional[torch.Tensor] = None,  # (batch, num_heads, seqlen_q, seqlen_k)
-    scale: Optional[float] = None,  # score scaling, defaults to 1/sqrt(head_dim)
-    is_causal: Optional[bool] = None,  # causal mask
-    softcap: Optional[float] = None,  # CUDA-only
-    deterministic: Optional[bool] = None,  # CUDA-only
+    query: torch.Tensor,                        # (batch, seqlen_q, num_heads, head_dim)
+    key: torch.Tensor,                          # (batch, seqlen_k, num_kv_heads, head_dim)
+    value: torch.Tensor,                        # (batch, seqlen_k, num_kv_heads, head_dim)
+    attn_mask: Optional[torch.Tensor] = None,   # (batch, num_heads, seqlen_q, seqlen_k)
+    attn_bias: Optional[torch.Tensor] = None,   # (batch, num_heads, seqlen_q, seqlen_k)
+    scale: Optional[float] = None,              # score scaling, defaults to 1/sqrt(head_dim)
+    is_causal: Optional[bool] = None,           # causal mask
+    softcap: Optional[float] = None,            # CUDA-only
+    deterministic: Optional[bool] = None,       # CUDA-only
 ) -> torch.Tensor
 ```
 
 #### Parameters
 
-- q: (B, Q, H, D). CUDA tensor, fp16/bf16, last dim contiguous
-- k, v: (B, K, H_kv, D). Same dtype/device as q; GQA when H_kv < H
+- query: (B, Q, H, D). CUDA tensor, fp16/bf16, last dim contiguous
+- key: (B, K, H_kv, D). Same dtype/device as query; GQA when H_kv <= H
+- value: (B, K, H_kv, D). Same dtype/device as query; GQA when H_kv <= H
 - attn_mask: (B, H, Q, K). 1.0 = visible, 0.0 = masked. None to disable
 - attn_bias: (B, H, Q, K). Added to scores before softmax. None to disable
 - scale: score scaling; default 1/sqrt(D)
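For reference, a minimal call sketch using the renamed `query`/`key`/`value` parameters. This is not part of the diff itself; the `from flash_dmattn import flash_dmattn_func` import path and the toy shapes, dtypes, and mask/bias values are assumptions based only on the parameter docs above.

```python
import torch
from flash_dmattn import flash_dmattn_func  # assumed import path, not confirmed by this diff

batch, seqlen_q, seqlen_k, num_heads, num_kv_heads, head_dim = 2, 128, 128, 8, 2, 64
device, dtype = "cuda", torch.float16

# Shapes follow the parameter docs: (B, Q, H, D) for query, (B, K, H_kv, D) for key/value (GQA).
query = torch.randn(batch, seqlen_q, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(batch, seqlen_k, num_kv_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(batch, seqlen_k, num_kv_heads, head_dim, device=device, dtype=dtype)

# Optional (B, H, Q, K) tensors: mask uses 1.0 = visible / 0.0 = masked, bias is added to scores.
attn_mask = torch.ones(batch, num_heads, seqlen_q, seqlen_k, device=device, dtype=dtype)
attn_bias = torch.zeros(batch, num_heads, seqlen_q, seqlen_k, device=device, dtype=dtype)

out = flash_dmattn_func(query, key, value, attn_mask=attn_mask, attn_bias=attn_bias, is_causal=True)
print(out.shape)  # (batch, seqlen_q, num_heads, head_dim)
```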
@@ -137,20 +138,20 @@ Variable length attention for batches with mixed sequence lengths.
 
 ```python
 def flash_dmattn_varlen_func(
-    q: torch.Tensor,  # (total_q, H, D) or (B, Q, H, D)
-    k: torch.Tensor,  # same layout as q
-    v: torch.Tensor,  # same layout as q
-    attn_mask: Optional[torch.Tensor] = None,  # (B, H, Q, K)
-    attn_bias: Optional[torch.Tensor] = None,  # (B, H, Q, K)
-    cu_seqlens_q: torch.Tensor = None,  # (B+1,)
-    cu_seqlens_k: torch.Tensor = None,  # (B+1,)
+    query: torch.Tensor,                         # (total_q, H, D) or (B, Q, H, D)
+    key: torch.Tensor,                           # same layout as query
+    value: torch.Tensor,                         # same layout as query
+    attn_mask: Optional[torch.Tensor] = None,    # (B, H, Q, K)
+    attn_bias: Optional[torch.Tensor] = None,    # (B, H, Q, K)
+    cu_seqlens_q: torch.Tensor = None,           # (B+1,)
+    cu_seqlens_k: torch.Tensor = None,           # (B+1,)
     max_seqlen_q: int = None,
     max_seqlen_k: int = None,
     scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
-    softcap: Optional[float] = None,  # CUDA-only
-    deterministic: Optional[bool] = None,  # CUDA-only
-    block_table: Optional[torch.Tensor] = None,  # experimental: paged attention
+    softcap: Optional[float] = None,             # CUDA-only
+    deterministic: Optional[bool] = None,        # CUDA-only
+    block_table: Optional[torch.Tensor] = None,  # experimental: paged attention
 ) -> torch.Tensor
 ```
 
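Similarly, a hedged sketch of a packed varlen call under the rename, assuming the same illustrative import path, a packed (total_q, H, D) layout, and int32 cumulative sequence lengths for `cu_seqlens_q`/`cu_seqlens_k`; the sizes are made up for the example.

```python
import torch
import torch.nn.functional as F
from flash_dmattn import flash_dmattn_varlen_func  # assumed import path, not confirmed by this diff

num_heads, num_kv_heads, head_dim = 8, 2, 64
device, dtype = "cuda", torch.float16

# Two sequences (lengths 96 and 32) packed into one (total_q, H, D) tensor.
seqlens = torch.tensor([96, 32], device=device, dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # (B+1,) = [0, 96, 128]
total = int(cu_seqlens[-1])

query = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(total, num_kv_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(total, num_kv_heads, head_dim, device=device, dtype=dtype)

out = flash_dmattn_varlen_func(
    query, key, value,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seqlens.max()), max_seqlen_k=int(seqlens.max()),
    is_causal=True,
)
print(out.shape)  # (total_q, num_heads, head_dim)
```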
@@ -386,6 +387,3 @@ print_memory_stats()
 torch.cuda.empty_cache()
 ```
 
----
-
-See also: `docs/integration.md` and `benchmarks/`.