2 parents 1a78757 + 099f0a1, commit f837ebc
jax/_src/nn/functions.py
@@ -880,7 +880,7 @@ def _get_padding_mask_encoded(T, q_seqlen):
 
 def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
                  local_window_size):
-  if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None:
+  if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None and local_window_size is None:
     return logits
 
   combined_mask = jnp.ones_like(logits, dtype=bool)
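The changed condition now also requires local_window_size to be None before taking the early return, so a local window passed without any explicit mask still reaches the mask-building path below. As a rough sketch (an illustrative helper, not code from this commit or from JAX internals), a (left, right) window corresponds to a banded boolean mask over query/key positions:

import jax.numpy as jnp

def local_window_mask(T, S, left, right):
  # Query position i may attend to key positions j with i - left <= j <= i + right.
  q_pos = jnp.arange(T)[:, None]
  kv_pos = jnp.arange(S)[None, :]
  return (kv_pos >= q_pos - left) & (kv_pos <= q_pos + right)

# With a (1, 1) window each query sees only its immediate neighbors, so the
# attention output differs from a wide window such as (32, 32).
print(local_window_mask(4, 4, 1, 1))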
tests/nn_test.py
@@ -791,6 +791,27 @@ def testLog1mExpGrad(self):
         atol=1e-3,
     )
 
+  def testDotProductAttention_localWindowSizeWithoutMask(self):
+    dtype = jnp.float32
+    B, S, T, N, H = 2, 128, 128, 4, 32
+    keys = random.split(random.PRNGKey(0), 3)
+    Q = random.normal(keys[0], (B, T, N, H), dtype)
+    K = random.normal(keys[1], (B, S, N, H), dtype)
+    V = random.normal(keys[2], (B, S, N, H), dtype)
+
+    output_large_window = nn.dot_product_attention(
+        Q, K, V, mask=None, local_window_size=(32, 32)
+    )
+
+    output_small_window = nn.dot_product_attention(
+        Q, K, V, mask=None, local_window_size=(1, 1)
+    )
+
+    self.assertFalse(
+        jnp.allclose(output_large_window, output_small_window),
+        "Attention output should differ with different local_window_size, even without a mask.",
+    )
+
 
 InitializerRecord = collections.namedtuple(
     "InitializerRecord",