Skip to content

Commit f837ebc

Browse files
Merge pull request #32357 from akshay-babbar:fix-local-window-size-masking
PiperOrigin-RevId: 842753992
2 parents 1a78757 + 099f0a1 commit f837ebc

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

jax/_src/nn/functions.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -880,7 +880,7 @@ def _get_padding_mask_encoded(T, q_seqlen):
880 880

881 881
def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
882 882
local_window_size):
883-
if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None:
883+
if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None and local_window_size is None:
884 884
return logits
885 885

886 886
combined_mask = jnp.ones_like(logits, dtype=bool)

tests/nn_test.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -791,6 +791,27 @@ def testLog1mExpGrad(self):
791 791
atol=1e-3,
792 792
)
793 793

794+
def testDotProductAttention_localWindowSizeWithoutMask(self):
795+
dtype = jnp.float32
796+
B, S, T, N, H = 2, 128, 128, 4, 32
797+
keys = random.split(random.PRNGKey(0), 3)
798+
Q = random.normal(keys[0], (B, T, N, H), dtype)
799+
K = random.normal(keys[1], (B, S, N, H), dtype)
800+
V = random.normal(keys[2], (B, S, N, H), dtype)
801+
802+
output_large_window = nn.dot_product_attention(
803+
Q, K, V, mask=None, local_window_size=(32, 32)
804+
)
805+
806+
output_small_window = nn.dot_product_attention(
807+
Q, K, V, mask=None, local_window_size=(1, 1)
808+
)
809+
810+
self.assertFalse(
811+
jnp.allclose(output_large_window, output_small_window),
812+
"Attention output should differ with different local_window_size, even without a mask.",
813+
)
814+
794 815

795 816
InitializerRecord = collections.namedtuple(
796 817
"InitializerRecord",

0 commit comments

Comments
 (0)