Skip to content

Commit b28acd1

Browse files
committed
Streamlines benchmark suite structure and test scope
Removes obsolete MQAR benchmark configuration and reorganizes benchmark files with cleaner naming convention. Comments out extensive test configurations to focus on essential test cases, reducing test execution time while maintaining core functionality validation. Simplifies the benchmark suite to improve maintainability and development workflow efficiency.
1 parent 3d80c84 commit b28acd1

File tree

4 files changed

+76
-138
lines changed

4 files changed

+76
-138
lines changed

benchmarks/benchmark_mqar.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

benchmarks/benchmark_forward_equivalence.py renamed to benchmarks/forward_equivalence.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -518,27 +518,27 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
518518
# If you encounter NAN issues when running multiple configurations, try running a single configuration
519519
test_configs = [
520520
# (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
521-
(1, 1, 1, 64, 64, 32, True),
522-
(1, 1, 1, 64, 64, 32, False),
523-
(1, 1, 1, 128, 128, 32, True),
524-
(1, 1, 1, 128, 128, 32, False),
525-
(1, 1, 1, 256, 256, 32, True),
526-
(1, 1, 1, 256, 256, 32, False),
527-
(1, 1, 1, 512, 512, 32, True),
528-
(1, 1, 1, 512, 512, 32, False),
529-
(1, 1, 1, 1024, 1024, 32, True),
530-
(1, 1, 1, 1024, 1024, 32, False),
531-
(1, 1, 1, 2048, 2048, 32, True),
532-
(1, 1, 1, 2048, 2048, 32, False),
521+
# (1, 1, 1, 64, 64, 32, True),
522+
# (1, 1, 1, 64, 64, 32, False),
523+
# (1, 1, 1, 128, 128, 32, True),
524+
# (1, 1, 1, 128, 128, 32, False),
525+
# (1, 1, 1, 256, 256, 32, True),
526+
# (1, 1, 1, 256, 256, 32, False),
527+
# (1, 1, 1, 512, 512, 32, True),
528+
# (1, 1, 1, 512, 512, 32, False),
529+
# (1, 1, 1, 1024, 1024, 32, True),
530+
# (1, 1, 1, 1024, 1024, 32, False),
531+
# (1, 1, 1, 2048, 2048, 32, True),
532+
# (1, 1, 1, 2048, 2048, 32, False),
533533
(1, 1, 1, 4096, 4096, 32, True),
534-
(1, 1, 1, 4096, 4096, 32, False),
535-
(1, 2, 1, 64, 64, 32, True),
536-
(2, 1, 1, 128, 128, 32, True),
537-
(2, 2, 1, 128, 128, 32, True),
538-
(1, 2, 1, 64, 64, 128, True),
539-
(1, 2, 1, 128, 128, 128, True),
540-
(1, 2, 1, 256, 256, 128, True),
541-
(1, 2, 1, 512, 512, 128, True),
534+
# (1, 1, 1, 4096, 4096, 32, False),
535+
# (1, 2, 1, 64, 64, 32, True),
536+
# (2, 1, 1, 128, 128, 32, True),
537+
# (2, 2, 1, 128, 128, 32, True),
538+
# (1, 2, 1, 64, 64, 128, True),
539+
# (1, 2, 1, 128, 128, 128, True),
540+
# (1, 2, 1, 256, 256, 128, True),
541+
# (1, 2, 1, 512, 512, 128, True),
542542
]
543543

544544
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -1050,13 +1050,13 @@ def main():
10501050
print("\n" + "📍" + " Starting Standard Forward Pass Tests " + "📍")
10511051
test_results['cuda'] = test_cuda_forward_equivalence(args.accuracy_threshold)
10521052

1053-
if args.test_type in ['all', 'triton']:
1054-
print("\n" + "🔥" + " Starting Python vs Triton Tests " + "🔥")
1055-
test_results['triton'] = test_triton_forward_equivalence(args.accuracy_threshold)
1053+
# if args.test_type in ['all', 'triton']:
1054+
# print("\n" + "🔥" + " Starting Python vs Triton Tests " + "🔥")
1055+
# test_results['triton'] = test_triton_forward_equivalence(args.accuracy_threshold)
10561056

1057-
if args.test_type in ['all', 'flex']:
1058-
print("\n" + "🌟" + " Starting Python vs Flex Attention Tests " + "🌟")
1059-
test_results['flex'] = test_flex_forward_equivalence(args.accuracy_threshold)
1057+
# if args.test_type in ['all', 'flex']:
1058+
# print("\n" + "🌟" + " Starting Python vs Flex Attention Tests " + "🌟")
1059+
# test_results['flex'] = test_flex_forward_equivalence(args.accuracy_threshold)
10601060

10611061

10621062
# Print overall summary

benchmarks/benchmark_forward_performance.py renamed to benchmarks/forward_performance.py

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -732,57 +732,57 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
732732
(1, 2, 1, 4096, 4096, 128, 2048, True),
733733
(1, 2, 1, 8192, 8192, 128, 2048, True),
734734
(1, 2, 1, 16384, 16384, 128, 2048, True),
735-
(1, 2, 1, 32768, 32768, 128, 2048, True),
736-
737-
# Inference
738-
(1, 2, 1, 2, 256, 128, 2048, True),
739-
(1, 2, 1, 2, 512, 128, 2048, True),
740-
(1, 2, 1, 2, 1024, 128, 2048, True),
741-
(1, 2, 1, 2, 2048, 128, 2048, True),
742-
(1, 2, 1, 2, 4096, 128, 2048, True),
743-
(1, 2, 1, 2, 8192, 128, 2048, True),
744-
(1, 2, 1, 2, 16384, 128, 2048, True),
745-
(1, 2, 1, 2, 32768, 128, 2048, True),
735+
# (1, 2, 1, 32768, 32768, 128, 2048, True),
736+
737+
# # Inference
738+
# (1, 2, 1, 2, 256, 128, 2048, True),
739+
# (1, 2, 1, 2, 512, 128, 2048, True),
740+
# (1, 2, 1, 2, 1024, 128, 2048, True),
741+
# (1, 2, 1, 2, 2048, 128, 2048, True),
742+
# (1, 2, 1, 2, 4096, 128, 2048, True),
743+
# (1, 2, 1, 2, 8192, 128, 2048, True),
744+
# (1, 2, 1, 2, 16384, 128, 2048, True),
745+
# (1, 2, 1, 2, 32768, 128, 2048, True),
746746
(1, 2, 1, 2, 65536, 128, 2048, True),
747-
(1, 2, 1, 2, 131072, 128, 2048, True),
748-
(1, 2, 1, 2, 262144, 128, 2048, True),
749-
(1, 2, 1, 2, 524288, 128, 2048, True),
750-
751-
# Vary batch size
752-
(1, 2, 1, 4096, 4096, 32, 2048, True),
753-
(2, 2, 1, 4096, 4096, 32, 2048, True),
754-
(4, 2, 1, 4096, 4096, 32, 2048, True),
755-
(8, 2, 1, 4096, 4096, 32, 2048, True),
756-
757-
# Vary head count
758-
(1, 1, 1, 4096, 4096, 32, 2048, True),
759-
(1, 2, 1, 4096, 4096, 32, 2048, True),
760-
(1, 4, 1, 4096, 4096, 32, 2048, True),
761-
(1, 8, 2, 4096, 4096, 32, 2048, True),
762-
763-
# Vary head dimension
764-
(1, 2, 1, 4096, 4096, 32, 2048, True),
765-
(1, 2, 1, 4096, 4096, 64, 2048, True),
766-
(1, 2, 1, 4096, 4096, 96, 2048, True),
767-
(1, 2, 1, 4096, 4096, 128, 2048, True),
768-
(1, 2, 1, 4096, 4096, 192, 2048, True),
769-
(1, 2, 1, 4096, 4096, 256, 2048, True),
770-
771-
# Vary keep_window_size
772-
(1, 2, 1, 32768, 32768, 128, 32, True),
773-
(1, 2, 1, 32768, 32768, 128, 64, True),
774-
(1, 2, 1, 32768, 32768, 128, 128, True),
775-
(1, 2, 1, 32768, 32768, 128, 256, True),
776-
(1, 2, 1, 32768, 32768, 128, 512, True),
777-
(1, 2, 1, 32768, 32768, 128, 1024, True),
778-
(1, 2, 1, 32768, 32768, 128, 2048, True),
779-
(1, 2, 1, 32768, 32768, 128, 4096, True),
780-
(1, 2, 1, 32768, 32768, 128, 8192, True),
781-
(1, 2, 1, 32768, 32768, 128, 16384, True),
782-
(1, 2, 1, 32768, 32768, 128, 32768, True),
783-
784-
# Test non-causal
785-
(1, 2, 1, 4096, 4096, 128, 2048, False),
747+
# (1, 2, 1, 2, 131072, 128, 2048, True),
748+
# (1, 2, 1, 2, 262144, 128, 2048, True),
749+
# (1, 2, 1, 2, 524288, 128, 2048, True),
750+
751+
# # Vary batch size
752+
# (1, 2, 1, 4096, 4096, 32, 2048, True),
753+
# (2, 2, 1, 4096, 4096, 32, 2048, True),
754+
# (4, 2, 1, 4096, 4096, 32, 2048, True),
755+
# (8, 2, 1, 4096, 4096, 32, 2048, True),
756+
757+
# # Vary head count
758+
# (1, 1, 1, 4096, 4096, 32, 2048, True),
759+
# (1, 2, 1, 4096, 4096, 32, 2048, True),
760+
# (1, 4, 1, 4096, 4096, 32, 2048, True),
761+
# (1, 8, 2, 4096, 4096, 32, 2048, True),
762+
763+
# # Vary head dimension
764+
# (1, 2, 1, 4096, 4096, 32, 2048, True),
765+
# (1, 2, 1, 4096, 4096, 64, 2048, True),
766+
# (1, 2, 1, 4096, 4096, 96, 2048, True),
767+
# (1, 2, 1, 4096, 4096, 128, 2048, True),
768+
# (1, 2, 1, 4096, 4096, 192, 2048, True),
769+
# (1, 2, 1, 4096, 4096, 256, 2048, True),
770+
771+
# # Vary keep_window_size
772+
# (1, 2, 1, 32768, 32768, 128, 32, True),
773+
# (1, 2, 1, 32768, 32768, 128, 64, True),
774+
# (1, 2, 1, 32768, 32768, 128, 128, True),
775+
# (1, 2, 1, 32768, 32768, 128, 256, True),
776+
# (1, 2, 1, 32768, 32768, 128, 512, True),
777+
# (1, 2, 1, 32768, 32768, 128, 1024, True),
778+
# (1, 2, 1, 32768, 32768, 128, 2048, True),
779+
# (1, 2, 1, 32768, 32768, 128, 4096, True),
780+
# (1, 2, 1, 32768, 32768, 128, 8192, True),
781+
# (1, 2, 1, 32768, 32768, 128, 16384, True),
782+
# (1, 2, 1, 32768, 32768, 128, 32768, True),
783+
784+
# # Test non-causal
785+
# (1, 2, 1, 4096, 4096, 128, 2048, False),
786786
]
787787

788788
print(f"\n📊 Benchmark Results (averaged over {num_runs} runs):")

0 commit comments

Comments
 (0)