
Commit 4edb7a8

Improves API parameter naming consistency
Renames the q/k/v parameters to query/key/value in the flash attention functions for better readability and standardization. Updates the parameter documentation to reflect the new naming convention and fixes the GQA condition description to use <= instead of <. Removes an outdated footer reference to the integration docs.
1 parent 4e29d06 commit 4edb7a8

File tree: 1 file changed (+22, -24 lines)


docs/api_reference.md

Lines changed: 22 additions & 24 deletions
````diff
@@ -68,22 +68,23 @@ Main attention function. Supports multi-head and grouped-query attention (when t
 
 ```python
 def flash_dmattn_func(
-    q: torch.Tensor,                            # (batch, seqlen_q, num_heads, head_dim)
-    k: torch.Tensor,                            # (batch, seqlen_k, num_kv_heads, head_dim)
-    v: torch.Tensor,                            # (batch, seqlen_k, num_kv_heads, head_dim)
-    attn_mask: Optional[torch.Tensor] = None,   # (batch, num_heads, seqlen_q, seqlen_k)
-    attn_bias: Optional[torch.Tensor] = None,   # (batch, num_heads, seqlen_q, seqlen_k)
-    scale: Optional[float] = None,              # score scaling, defaults to 1/sqrt(head_dim)
-    is_causal: Optional[bool] = None,           # causal mask
-    softcap: Optional[float] = None,            # CUDA-only
-    deterministic: Optional[bool] = None,       # CUDA-only
+    query: torch.Tensor,                        # (batch, seqlen_q, num_heads, head_dim)
+    key: torch.Tensor,                          # (batch, seqlen_k, num_kv_heads, head_dim)
+    value: torch.Tensor,                        # (batch, seqlen_k, num_kv_heads, head_dim)
+    attn_mask: Optional[torch.Tensor] = None,   # (batch, num_heads, seqlen_q, seqlen_k)
+    attn_bias: Optional[torch.Tensor] = None,   # (batch, num_heads, seqlen_q, seqlen_k)
+    scale: Optional[float] = None,              # score scaling, defaults to 1/sqrt(head_dim)
+    is_causal: Optional[bool] = None,           # causal mask
+    softcap: Optional[float] = None,            # CUDA-only
+    deterministic: Optional[bool] = None,       # CUDA-only
 ) -> torch.Tensor
 ```
 
 #### Parameters
 
-- q: (B, Q, H, D). CUDA tensor, fp16/bf16, last dim contiguous
-- k, v: (B, K, H_kv, D). Same dtype/device as q; GQA when H_kv < H
+- query: (B, Q, H, D). CUDA tensor, fp16/bf16, last dim contiguous
+- key: (B, K, H_kv, D). Same dtype/device as query; GQA when H_kv <= H
+- value: (B, K, H_kv, D). Same dtype/device as query; GQA when H_kv <= H
 - attn_mask: (B, H, Q, K). 1.0 = visible, 0.0 = masked. None to disable
 - attn_bias: (B, H, Q, K). Added to scores before softmax. None to disable
 - scale: score scaling; default 1/sqrt(D)
````
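
For orientation, a minimal call sketch against the renamed signature. The `flash_dmattn` import path and all shapes and values below are illustrative assumptions, not part of this diff:

```python
import torch

# Assumed import path; the diff only documents the signature.
from flash_dmattn import flash_dmattn_func

batch, seqlen_q, seqlen_k, head_dim = 2, 128, 128, 64
num_heads, num_kv_heads = 8, 2  # GQA is allowed since H_kv <= H

query = torch.randn(batch, seqlen_q, num_heads, head_dim,
                    device="cuda", dtype=torch.float16)
key = torch.randn(batch, seqlen_k, num_kv_heads, head_dim,
                  device="cuda", dtype=torch.float16)
value = torch.randn(batch, seqlen_k, num_kv_heads, head_dim,
                    device="cuda", dtype=torch.float16)

# attn_mask: 1.0 = visible, 0.0 = masked; attn_bias would be added to the
# scores before softmax. Both are optional and default to None.
attn_mask = torch.ones(batch, num_heads, seqlen_q, seqlen_k,
                       device="cuda", dtype=query.dtype)

# scale is left as None, so it defaults to 1/sqrt(head_dim).
out = flash_dmattn_func(query, key, value, attn_mask=attn_mask, is_causal=True)
assert out.shape == (batch, seqlen_q, num_heads, head_dim)
```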
````diff
@@ -137,20 +138,20 @@ Variable length attention for batches with mixed sequence lengths.
 
 ```python
 def flash_dmattn_varlen_func(
-    q: torch.Tensor,                            # (total_q, H, D) or (B, Q, H, D)
-    k: torch.Tensor,                            # same layout as q
-    v: torch.Tensor,                            # same layout as q
-    attn_mask: Optional[torch.Tensor] = None,   # (B, H, Q, K)
-    attn_bias: Optional[torch.Tensor] = None,   # (B, H, Q, K)
-    cu_seqlens_q: torch.Tensor = None,          # (B+1,)
-    cu_seqlens_k: torch.Tensor = None,          # (B+1,)
+    query: torch.Tensor,                        # (total_q, H, D) or (B, Q, H, D)
+    key: torch.Tensor,                          # same layout as query
+    value: torch.Tensor,                        # same layout as query
+    attn_mask: Optional[torch.Tensor] = None,   # (B, H, Q, K)
+    attn_bias: Optional[torch.Tensor] = None,   # (B, H, Q, K)
+    cu_seqlens_q: torch.Tensor = None,          # (B+1,)
+    cu_seqlens_k: torch.Tensor = None,          # (B+1,)
     max_seqlen_q: int = None,
     max_seqlen_k: int = None,
     scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
-    softcap: Optional[float] = None,            # CUDA-only
-    deterministic: Optional[bool] = None,       # CUDA-only
-    block_table: Optional[torch.Tensor] = None, # experimental: paged attention
+    softcap: Optional[float] = None,            # CUDA-only
+    deterministic: Optional[bool] = None,       # CUDA-only
+    block_table: Optional[torch.Tensor] = None, # experimental: paged attention
 ) -> torch.Tensor
 ```
````
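
The (B+1,)-shaped cu_seqlens tensors follow the usual packed varlen convention: prefix sums of the per-sequence lengths with a leading zero. A sketch of how they are typically built (the lengths here are made up for illustration):

```python
import torch

# Per-sequence lengths for a batch of three variable-length sequences.
seqlens = torch.tensor([10, 4, 7], dtype=torch.int32, device="cuda")

# Prefix sums with a leading 0, shape (B+1,): sequence i occupies rows
# cu_seqlens[i]:cu_seqlens[i+1] of the packed (total_q, H, D) tensor.
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
)
max_seqlen = int(seqlens.max())

print(cu_seqlens.tolist())  # [0, 10, 14, 21]
print(max_seqlen)           # 10

# With query/key/value packed as (total, H, D), these would be passed as
# cu_seqlens_q/cu_seqlens_k and max_seqlen_q/max_seqlen_k.
```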

````diff
@@ -386,6 +387,3 @@ print_memory_stats()
 torch.cuda.empty_cache()
 ```
 
----
-
-See also: `docs/integration.md` and `benchmarks/`.
````
