
Commit 39ddb40

committed
Standardizes parameter naming across attention functions
Renames the `softmax_scale` parameter to `scale` for consistency across the CUDA, Triton, and Flex attention implementations. Updates the Flex attention call to use keyword arguments and adds tensor transposes so the inputs match the expected layout. Also drops the unused second return value from the Flex attention call to align with the other implementations.
1 parent f6e910a commit 39ddb40

File tree

1 file changed: +13 -13 lines changed


benchmarks/benchmark_forward_performance.py

Lines changed: 13 additions & 13 deletions
@@ -266,8 +266,8 @@ def dynamic_mask_attention_cuda(
         attn_mask=attn_mask,            # [batch, num_kv_heads, query_len, key_len]
         attn_bias=attn_bias,            # [batch, num_kv_heads, query_len, key_len]
         dropout_p=0.0,
-        softmax_scale=scaling,
         is_causal=is_causal,
+        scale=scaling,
         softcap=0.0,
         deterministic=True,
         return_attn_probs=return_softmax
@@ -350,10 +350,10 @@ def dynamic_mask_attention_triton(
         query_states,                   # q: [batch, seqlen_q, num_heads, head_dim]
         key_states,                     # k: [batch, seqlen_k, num_heads, head_dim]
         value_states,                   # v: [batch, seqlen_k, num_heads, head_dim]
-        attn_mask,                      # mask: [batch, num_heads, seqlen_q, seqlen_k]
-        attn_bias,                      # bias: [batch, num_heads, seqlen_q, seqlen_k]
-        is_causal,                      # causal masking
-        scaling                         # scaling factor
+        attn_mask=attn_mask,            # mask: [batch, num_heads, seqlen_q, seqlen_k]
+        attn_bias=attn_bias,            # bias: [batch, num_heads, seqlen_q, seqlen_k]
+        is_causal=is_causal,            # causal masking
+        scale=scaling                   # scaling factor
     )

     torch.cuda.synchronize()
@@ -425,14 +425,14 @@ def dynamic_mask_attention_flex(
     start_time = time.time()

     # Call the Flex Attention implementation
-    attn_outputs, _ = flex_dmattn_func(
-        query_states,                   # q: [batch, num_heads, query_len, head_dim]
-        key_states,                     # k: [batch, num_heads, key_len, head_dim]
-        value_states,                   # v: [batch, num_heads, key_len, head_dim]
-        attention_mask=attn_mask,       # attention_mask: [batch, num_heads, query_len, key_len]
-        attention_bias=attn_bias,       # attention_bias: [batch, num_heads, query_len, key_len]
-        is_causal=is_causal,            # is_causal: Whether to apply causal masking
-        scaling=scaling                 # scaling factor
+    attn_outputs = flex_dmattn_func(
+        query_states.transpose(1, 2),   # q: [batch, query_len, num_heads, head_dim]
+        key_states.transpose(1, 2),     # k: [batch, key_len, num_heads, head_dim]
+        value_states.transpose(1, 2),   # v: [batch, key_len, num_heads, head_dim]
+        attn_mask=attn_mask,            # attn_mask: [batch, num_heads, query_len, key_len]
+        attn_bias=attn_bias,            # attn_bias: [batch, num_heads, query_len, key_len]
+        is_causal=is_causal,            # is_causal: whether to apply causal masking
+        scale=scaling                   # scaling factor
     )

     torch.cuda.synchronize()
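
For readers skimming the diff, here is a minimal, self-contained sketch of the calling convention after this commit: q/k/v transposed from [batch, num_heads, seq_len, head_dim] to [batch, seq_len, num_heads, head_dim], mask and bias passed by keyword, and the scaling factor passed as `scale`. The `flex_dmattn_func` defined below is a plain-PyTorch stand-in (naive attention) used only to make the example runnable; its argument names follow the diff above, but its body is not the library's actual implementation.

import torch


def flex_dmattn_func(q, k, v, attn_mask=None, attn_bias=None, is_causal=True, scale=None):
    # Stand-in for the real kernel: naive attention in plain PyTorch,
    # shown only to illustrate the post-commit calling convention.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> [batch, num_heads, seq_len, head_dim]
    scale = scale if scale is not None else q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if attn_bias is not None:
        scores = scores + attn_bias
    if attn_mask is not None:
        scores = scores.masked_fill(attn_mask == 0, float("-inf"))
    if is_causal:
        causal = torch.ones(scores.shape[-2:], dtype=torch.bool, device=scores.device).tril()
        scores = scores.masked_fill(~causal, float("-inf"))
    out = torch.matmul(torch.softmax(scores, dim=-1), v)
    return out.transpose(1, 2)  # back to [batch, seq_len, num_heads, head_dim]


batch, num_heads, query_len, key_len, head_dim = 1, 2, 4, 4, 8
is_causal = True
scaling = head_dim ** -0.5

# Benchmark tensors start in [batch, num_heads, seq_len, head_dim] layout.
query_states = torch.randn(batch, num_heads, query_len, head_dim)
key_states = torch.randn(batch, num_heads, key_len, head_dim)
value_states = torch.randn(batch, num_heads, key_len, head_dim)
attn_mask = torch.ones(batch, num_heads, query_len, key_len)
attn_bias = torch.zeros(batch, num_heads, query_len, key_len)

attn_outputs = flex_dmattn_func(
    query_states.transpose(1, 2),   # q: [batch, query_len, num_heads, head_dim]
    key_states.transpose(1, 2),     # k: [batch, key_len, num_heads, head_dim]
    value_states.transpose(1, 2),   # v: [batch, key_len, num_heads, head_dim]
    attn_mask=attn_mask,
    attn_bias=attn_bias,
    is_causal=is_causal,
    scale=scaling,
)
print(attn_outputs.shape)  # torch.Size([1, 4, 2, 8])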
