Commit f6e910a

Standardizes parameter naming across attention functions
Updates parameter names to use consistent naming conventions across the CUDA, Triton, and Flex attention implementations. Changes 'softmax_scale' to 'scale' and converts positional arguments to keyword arguments for better API consistency and clarity. Fixes tensor dimension ordering in Flex attention by adding transpose operations to match the expected input format.
1 parent e5eb029 commit f6e910a
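
To make the intent of the renaming concrete, the sketch below shows the calling convention the commit standardizes on: query/key/value passed positionally, everything else keyword-only, and 'scale' in place of 'softmax_scale'. This is a minimal, hypothetical stand-in for the CUDA/Triton/Flex entry points, not the repository's actual signature.

    # Hypothetical reference signature illustrating the standardized call style;
    # not the real CUDA/Triton/Flex API, only the shape of the convention.
    from typing import Optional
    import torch

    def attention_fn(
        q: torch.Tensor,                # [batch, num_heads, seqlen_q, head_dim] for this toy version
        k: torch.Tensor,
        v: torch.Tensor,
        *,                              # everything below is keyword-only
        attn_mask: Optional[torch.Tensor] = None,
        attn_bias: Optional[torch.Tensor] = None,
        is_causal: bool = True,
        scale: Optional[float] = None,  # renamed from 'softmax_scale'
    ) -> torch.Tensor:
        # Default to the usual 1/sqrt(head_dim) softmax scaling.
        scale = scale if scale is not None else q.shape[-1] ** -0.5
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        if attn_bias is not None:
            scores = scores + attn_bias
        if attn_mask is not None:
            # Assumes a boolean mask that is True where attention is allowed.
            scores = scores.masked_fill(~attn_mask.bool(), float("-inf"))
        if is_causal:
            causal = torch.ones(scores.shape[-2], scores.shape[-1],
                                dtype=torch.bool, device=q.device).tril()
            scores = scores.masked_fill(~causal, float("-inf"))
        return torch.matmul(torch.softmax(scores, dim=-1), v)

    # Call sites mirror the diff below: tensors positional, options by keyword.
    q = k = v = torch.randn(2, 8, 16, 64)
    out = attention_fn(q, k, v, is_causal=True, scale=64 ** -0.5)

Passing mask, bias, is_causal, and scale by keyword keeps the three backends interchangeable at the call site even where their internal positional orders differ.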

1 file changed: +13 -13 lines changed

benchmarks/benchmark_forward_equivalence.py

Lines changed: 13 additions & 13 deletions
@@ -257,8 +257,8 @@ def dynamic_mask_attention_cuda(
         attn_mask=attn_mask,  # [batch, num_kv_heads, query_len, key_len]
         attn_bias=attn_bias,  # [batch, num_kv_heads, query_len, key_len]
         dropout_p=0.0,
-        softmax_scale=scaling,
         is_causal=is_causal,
+        scale=scaling,
         softcap=0.0,
         deterministic=True,
         return_attn_probs=return_softmax
@@ -331,10 +331,10 @@ def dynamic_mask_attention_triton(
         query_states,  # q: [batch, seqlen_q, num_heads, head_dim]
         key_states,  # k: [batch, seqlen_k, num_heads, head_dim]
         value_states,  # v: [batch, seqlen_k, num_heads, head_dim]
-        attn_mask,  # mask: [batch, num_heads, seqlen_q, seqlen_k]
-        attn_bias,  # bias: [batch, num_heads, seqlen_q, seqlen_k]
-        is_causal,  # causal masking
-        scaling  # scaling factor
+        attn_mask=attn_mask,  # mask: [batch, num_heads, seqlen_q, seqlen_k]
+        attn_bias=attn_bias,  # bias: [batch, num_heads, seqlen_q, seqlen_k]
+        is_causal=is_causal,  # causal masking
+        scale=scaling  # scaling factor
     )
 
     return attn_outputs  # [batch, query_len, num_heads, head_dim]
@@ -396,14 +396,14 @@ def dynamic_mask_attention_flex(
     # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format
 
     # Call the Flex Attention implementation
-    attn_outputs, _ = flex_dmattn_func(
-        query_states,  # q: [batch, num_heads, query_len, head_dim]
-        key_states,  # k: [batch, num_heads, key_len, head_dim]
-        value_states,  # v: [batch, num_heads, key_len, head_dim]
-        attention_mask=attn_mask,  # attention_mask: [batch, num_heads, query_len, key_len]
-        attention_bias=attn_bias,  # attention_bias: [batch, num_heads, query_len, key_len]
-        is_causal=is_causal,  # is_causal: whether to apply causal masking
-        scaling=scaling  # scaling factor
+    attn_outputs = flex_dmattn_func(
+        query_states.transpose(1, 2),  # q: [batch, query_len, num_heads, head_dim]
+        key_states.transpose(1, 2),  # k: [batch, key_len, num_heads, head_dim]
+        value_states.transpose(1, 2),  # v: [batch, key_len, num_heads, head_dim]
+        attn_mask=attn_mask,  # attn_mask: [batch, num_heads, query_len, key_len]
+        attn_bias=attn_bias,  # attn_bias: [batch, num_heads, query_len, key_len]
+        is_causal=is_causal,  # is_causal: whether to apply causal masking
+        scale=scaling  # scaling factor
     )
 
     return attn_outputs  # [batch, query_len, num_heads, head_dim]
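
The Flex hunk above adds transpose(1, 2) because the benchmark holds its states in [batch, num_heads, seq_len, head_dim] order while the updated call passes [batch, seq_len, num_heads, head_dim] tensors (per the diff's inline comments). A minimal shape check with dummy tensors, assuming those layouts:

    import torch

    batch, num_heads, query_len, head_dim = 2, 8, 16, 64

    # States as the benchmark holds them: [batch, num_heads, query_len, head_dim].
    query_states = torch.randn(batch, num_heads, query_len, head_dim)

    # transpose(1, 2) swaps the num_heads and query_len axes, yielding the
    # [batch, query_len, num_heads, head_dim] layout the updated call passes in.
    q = query_states.transpose(1, 2)
    assert q.shape == (batch, query_len, num_heads, head_dim)

Note that attn_mask and attn_bias keep their [batch, num_heads, query_len, key_len] layout and are passed through unchanged.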
