Skip to content

Commit a477e9f

Browse files
authored
Streamlines benchmark suite structure and test scope
2 parents 69de751 + 7099f37 commit a477e9f

File tree

4 files changed

+76
-138
lines changed

4 files changed

+76
-138
lines changed

benchmarks/benchmark_mqar.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

benchmarks/benchmark_forward_equivalence.py renamed to benchmarks/forward_equivalence.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -518,27 +518,27 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
518518
# If you encounter NAN issues when running multiple configurations, try running a single configuration
519519
test_configs = [
520520
# (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
521-
(1, 1, 1, 64, 64, 32, True),
522-
(1, 1, 1, 64, 64, 32, False),
523-
(1, 1, 1, 128, 128, 32, True),
524-
(1, 1, 1, 128, 128, 32, False),
525-
(1, 1, 1, 256, 256, 32, True),
526-
(1, 1, 1, 256, 256, 32, False),
527-
(1, 1, 1, 512, 512, 32, True),
528-
(1, 1, 1, 512, 512, 32, False),
529-
(1, 1, 1, 1024, 1024, 32, True),
530-
(1, 1, 1, 1024, 1024, 32, False),
531-
(1, 1, 1, 2048, 2048, 32, True),
532-
(1, 1, 1, 2048, 2048, 32, False),
521+
# (1, 1, 1, 64, 64, 32, True),
522+
# (1, 1, 1, 64, 64, 32, False),
523+
# (1, 1, 1, 128, 128, 32, True),
524+
# (1, 1, 1, 128, 128, 32, False),
525+
# (1, 1, 1, 256, 256, 32, True),
526+
# (1, 1, 1, 256, 256, 32, False),
527+
# (1, 1, 1, 512, 512, 32, True),
528+
# (1, 1, 1, 512, 512, 32, False),
529+
# (1, 1, 1, 1024, 1024, 32, True),
530+
# (1, 1, 1, 1024, 1024, 32, False),
531+
# (1, 1, 1, 2048, 2048, 32, True),
532+
# (1, 1, 1, 2048, 2048, 32, False),
533533
(1, 1, 1, 4096, 4096, 32, True),
534-
(1, 1, 1, 4096, 4096, 32, False),
535-
(1, 2, 1, 64, 64, 32, True),
536-
(2, 1, 1, 128, 128, 32, True),
537-
(2, 2, 1, 128, 128, 32, True),
538-
(1, 2, 1, 64, 64, 128, True),
539-
(1, 2, 1, 128, 128, 128, True),
540-
(1, 2, 1, 256, 256, 128, True),
541-
(1, 2, 1, 512, 512, 128, True),
534+
# (1, 1, 1, 4096, 4096, 32, False),
535+
# (1, 2, 1, 64, 64, 32, True),
536+
# (2, 1, 1, 128, 128, 32, True),
537+
# (2, 2, 1, 128, 128, 32, True),
538+
# (1, 2, 1, 64, 64, 128, True),
539+
# (1, 2, 1, 128, 128, 128, True),
540+
# (1, 2, 1, 256, 256, 128, True),
541+
# (1, 2, 1, 512, 512, 128, True),
542542
]
543543

544544
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -1050,13 +1050,13 @@ def main():
10501050
print("\n" + "📍" + " Starting Standard Forward Pass Tests " + "📍")
10511051
test_results['cuda'] = test_cuda_forward_equivalence(args.accuracy_threshold)
10521052

1053-
if args.test_type in ['all', 'triton']:
1054-
print("\n" + "🔥" + " Starting Python vs Triton Tests " + "🔥")
1055-
test_results['triton'] = test_triton_forward_equivalence(args.accuracy_threshold)
1053+
# if args.test_type in ['all', 'triton']:
1054+
# print("\n" + "🔥" + " Starting Python vs Triton Tests " + "🔥")
1055+
# test_results['triton'] = test_triton_forward_equivalence(args.accuracy_threshold)
10561056

1057-
if args.test_type in ['all', 'flex']:
1058-
print("\n" + "🌟" + " Starting Python vs Flex Attention Tests " + "🌟")
1059-
test_results['flex'] = test_flex_forward_equivalence(args.accuracy_threshold)
1057+
# if args.test_type in ['all', 'flex']:
1058+
# print("\n" + "🌟" + " Starting Python vs Flex Attention Tests " + "🌟")
1059+
# test_results['flex'] = test_flex_forward_equivalence(args.accuracy_threshold)
10601060

10611061

10621062
# Print overall summary

benchmarks/benchmark_forward_performance.py renamed to benchmarks/forward_performance.py

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -732,57 +732,57 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
732732
(1, 2, 1, 4096, 4096, 128, 2048, True),
733733
(1, 2, 1, 8192, 8192, 128, 2048, True),
734734
(1, 2, 1, 16384, 16384, 128, 2048, True),
735-
(1, 2, 1, 32768, 32768, 128, 2048, True),
736-
737-
# Inference
738-
(1, 2, 1, 2, 256, 128, 2048, True),
739-
(1, 2, 1, 2, 512, 128, 2048, True),
740-
(1, 2, 1, 2, 1024, 128, 2048, True),
741-
(1, 2, 1, 2, 2048, 128, 2048, True),
742-
(1, 2, 1, 2, 4096, 128, 2048, True),
743-
(1, 2, 1, 2, 8192, 128, 2048, True),
744-
(1, 2, 1, 2, 16384, 128, 2048, True),
745-
(1, 2, 1, 2, 32768, 128, 2048, True),
735+
# (1, 2, 1, 32768, 32768, 128, 2048, True),
736+
737+
# # Inference
738+
# (1, 2, 1, 2, 256, 128, 2048, True),
739+
# (1, 2, 1, 2, 512, 128, 2048, True),
740+
# (1, 2, 1, 2, 1024, 128, 2048, True),
741+
# (1, 2, 1, 2, 2048, 128, 2048, True),
742+
# (1, 2, 1, 2, 4096, 128, 2048, True),
743+
# (1, 2, 1, 2, 8192, 128, 2048, True),
744+
# (1, 2, 1, 2, 16384, 128, 2048, True),
745+
# (1, 2, 1, 2, 32768, 128, 2048, True),
746746
(1, 2, 1, 2, 65536, 128, 2048, True),
747-
(1, 2, 1, 2, 131072, 128, 2048, True),
748-
(1, 2, 1, 2, 262144, 128, 2048, True),
749-
(1, 2, 1, 2, 524288, 128, 2048, True),
750-
751-
# Vary batch size
752-
(1, 2, 1, 4096, 4096, 32, 2048, True),
753-
(2, 2, 1, 4096, 4096, 32, 2048, True),
754-
(4, 2, 1, 4096, 4096, 32, 2048, True),
755-
(8, 2, 1, 4096, 4096, 32, 2048, True),
756-
757-
# Vary head count
758-
(1, 1, 1, 4096, 4096, 32, 2048, True),
759-
(1, 2, 1, 4096, 4096, 32, 2048, True),
760-
(1, 4, 1, 4096, 4096, 32, 2048, True),
761-
(1, 8, 2, 4096, 4096, 32, 2048, True),
762-
763-
# Vary head dimension
764-
(1, 2, 1, 4096, 4096, 32, 2048, True),
765-
(1, 2, 1, 4096, 4096, 64, 2048, True),
766-
(1, 2, 1, 4096, 4096, 96, 2048, True),
767-
(1, 2, 1, 4096, 4096, 128, 2048, True),
768-
(1, 2, 1, 4096, 4096, 192, 2048, True),
769-
(1, 2, 1, 4096, 4096, 256, 2048, True),
770-
771-
# Vary keep_window_size
772-
(1, 2, 1, 32768, 32768, 128, 32, True),
773-
(1, 2, 1, 32768, 32768, 128, 64, True),
774-
(1, 2, 1, 32768, 32768, 128, 128, True),
775-
(1, 2, 1, 32768, 32768, 128, 256, True),
776-
(1, 2, 1, 32768, 32768, 128, 512, True),
777-
(1, 2, 1, 32768, 32768, 128, 1024, True),
778-
(1, 2, 1, 32768, 32768, 128, 2048, True),
779-
(1, 2, 1, 32768, 32768, 128, 4096, True),
780-
(1, 2, 1, 32768, 32768, 128, 8192, True),
781-
(1, 2, 1, 32768, 32768, 128, 16384, True),
782-
(1, 2, 1, 32768, 32768, 128, 32768, True),
783-
784-
# Test non-causal
785-
(1, 2, 1, 4096, 4096, 128, 2048, False),
747+
# (1, 2, 1, 2, 131072, 128, 2048, True),
748+
# (1, 2, 1, 2, 262144, 128, 2048, True),
749+
# (1, 2, 1, 2, 524288, 128, 2048, True),
750+
751+
# # Vary batch size
752+
# (1, 2, 1, 4096, 4096, 32, 2048, True),
753+
# (2, 2, 1, 4096, 4096, 32, 2048, True),
754+
# (4, 2, 1, 4096, 4096, 32, 2048, True),
755+
# (8, 2, 1, 4096, 4096, 32, 2048, True),
756+
757+
# # Vary head count
758+
# (1, 1, 1, 4096, 4096, 32, 2048, True),
759+
# (1, 2, 1, 4096, 4096, 32, 2048, True),
760+
# (1, 4, 1, 4096, 4096, 32, 2048, True),
761+
# (1, 8, 2, 4096, 4096, 32, 2048, True),
762+
763+
# # Vary head dimension
764+
# (1, 2, 1, 4096, 4096, 32, 2048, True),
765+
# (1, 2, 1, 4096, 4096, 64, 2048, True),
766+
# (1, 2, 1, 4096, 4096, 96, 2048, True),
767+
# (1, 2, 1, 4096, 4096, 128, 2048, True),
768+
# (1, 2, 1, 4096, 4096, 192, 2048, True),
769+
# (1, 2, 1, 4096, 4096, 256, 2048, True),
770+
771+
# # Vary keep_window_size
772+
# (1, 2, 1, 32768, 32768, 128, 32, True),
773+
# (1, 2, 1, 32768, 32768, 128, 64, True),
774+
# (1, 2, 1, 32768, 32768, 128, 128, True),
775+
# (1, 2, 1, 32768, 32768, 128, 256, True),
776+
# (1, 2, 1, 32768, 32768, 128, 512, True),
777+
# (1, 2, 1, 32768, 32768, 128, 1024, True),
778+
# (1, 2, 1, 32768, 32768, 128, 2048, True),
779+
# (1, 2, 1, 32768, 32768, 128, 4096, True),
780+
# (1, 2, 1, 32768, 32768, 128, 8192, True),
781+
# (1, 2, 1, 32768, 32768, 128, 16384, True),
782+
# (1, 2, 1, 32768, 32768, 128, 32768, True),
783+
784+
# # Test non-causal
785+
# (1, 2, 1, 4096, 4096, 128, 2048, False),
786786
]
787787

788788
print(f"\n📊 Benchmark Results (averaged over {num_runs} runs):")

0 commit comments

Comments
 (0)