Skip to content

Commit 099f0a1

Browse files
committed
Fix: local_window_size ignored when mask is None in dot_product_attention
Add check for local_window_size in _apply_masks early return condition. Previously, the function would skip masking when no explicit mask was provided, causing local_window_size to be ignored.
1 parent 20092b6 commit 099f0a1

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

jax/_src/nn/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,7 @@ def _get_padding_mask_encoded(T, q_seqlen):
893893

894894
def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
895895
local_window_size):
896-
if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None:
896+
if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None and local_window_size is None:
897897
return logits
898898

899899
combined_mask = jnp.ones_like(logits, dtype=bool)

tests/nn_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,27 @@ def testLog1mExpGrad(self):
763763
atol=1e-3,
764764
)
765765

766+
def testDotProductAttention_localWindowSizeWithoutMask(self):
767+
dtype = jnp.float32
768+
B, S, T, N, H = 2, 128, 128, 4, 32
769+
keys = random.split(random.PRNGKey(0), 3)
770+
Q = random.normal(keys[0], (B, T, N, H), dtype)
771+
K = random.normal(keys[1], (B, S, N, H), dtype)
772+
V = random.normal(keys[2], (B, S, N, H), dtype)
773+
774+
output_large_window = nn.dot_product_attention(
775+
Q, K, V, mask=None, local_window_size=(32, 32)
776+
)
777+
778+
output_small_window = nn.dot_product_attention(
779+
Q, K, V, mask=None, local_window_size=(1, 1)
780+
)
781+
782+
self.assertFalse(
783+
jnp.allclose(output_large_window, output_small_window),
784+
"Attention output should differ with different local_window_size, even without a mask.",
785+
)
786+
766787

767788
InitializerRecord = collections.namedtuple(
768789
"InitializerRecord",

0 commit comments

Comments
 (0)