Skip to content

Commit 1f517c3

Browse files
committed
Enables comprehensive test configurations
Uncomments previously disabled test cases to run full suite of forward equivalence tests across various batch sizes, head configurations, sequence lengths, and causal/non-causal modes. Adds two new edge case configurations with very short sequence lengths to improve test coverage.
1 parent 2959585 commit 1f517c3

File tree

1 file changed

+21
-20
lines changed

1 file changed

+21
-20
lines changed

benchmarks/forward_equivalence.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -518,27 +518,28 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
518518
# If you encounter NAN issues when running multiple configurations, try running a single configuration
519519
test_configs = [
520520
# (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
521-
# (1, 1, 1, 64, 64, 32, True),
522-
# (1, 1, 1, 64, 64, 32, False),
523-
# (1, 1, 1, 128, 128, 32, True),
524-
# (1, 1, 1, 128, 128, 32, False),
525-
# (1, 1, 1, 256, 256, 32, True),
526-
# (1, 1, 1, 256, 256, 32, False),
527-
# (1, 1, 1, 512, 512, 32, True),
528-
# (1, 1, 1, 512, 512, 32, False),
529-
# (1, 1, 1, 1024, 1024, 32, True),
530-
# (1, 1, 1, 1024, 1024, 32, False),
531-
# (1, 1, 1, 2048, 2048, 32, True),
532-
# (1, 1, 1, 2048, 2048, 32, False),
521+
(1, 1, 1, 64, 64, 32, True),
522+
(1, 1, 1, 64, 64, 32, False),
523+
(1, 1, 1, 128, 128, 32, True),
524+
(1, 1, 1, 128, 128, 32, False),
525+
(1, 1, 1, 256, 256, 32, True),
526+
(1, 1, 1, 256, 256, 32, False),
527+
(1, 1, 1, 512, 512, 32, True),
528+
(1, 1, 1, 512, 512, 32, False),
529+
(1, 1, 1, 1024, 1024, 32, True),
530+
(1, 1, 1, 1024, 1024, 32, False),
531+
(1, 1, 1, 2048, 2048, 32, True),
532+
(1, 1, 1, 2048, 2048, 32, False),
533533
(1, 1, 1, 4096, 4096, 32, True),
534-
# (1, 1, 1, 4096, 4096, 32, False),
535-
# (1, 2, 1, 64, 64, 32, True),
536-
# (2, 1, 1, 128, 128, 32, True),
537-
# (2, 2, 1, 128, 128, 32, True),
538-
# (1, 2, 1, 64, 64, 128, True),
539-
# (1, 2, 1, 128, 128, 128, True),
540-
# (1, 2, 1, 256, 256, 128, True),
541-
# (1, 2, 1, 512, 512, 128, True),
534+
(1, 1, 1, 4096, 4096, 32, False),
535+
(1, 2, 1, 64, 64, 32, True),
536+
(2, 1, 1, 128, 128, 32, True),
537+
(2, 2, 1, 128, 128, 32, True),
538+
(1, 2, 1, 64, 64, 128, True),
539+
(1, 2, 1, 128, 128, 128, True),
540+
(1, 2, 1, 256, 256, 128, True),
541+
(1, 2, 1, 3, 512, 128, True),
542+
(1, 2, 1, 1, 512, 128, True),
542543
]
543544

544545
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

0 commit comments

Comments
 (0)