@@ -732,57 +732,57 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
732732 (1 , 2 , 1 , 4096 , 4096 , 128 , 2048 , True ),
733733 (1 , 2 , 1 , 8192 , 8192 , 128 , 2048 , True ),
734734 (1 , 2 , 1 , 16384 , 16384 , 128 , 2048 , True ),
735- # (1, 2, 1, 32768, 32768, 128, 2048, True),
735+ (1 , 2 , 1 , 32768 , 32768 , 128 , 2048 , True ),
736736
737737 # # Inference
738- # (1, 2, 1, 2 , 256, 128, 2048, True),
739- # (1, 2, 1, 2 , 512, 128, 2048, True),
740- # (1, 2, 1, 2 , 1024, 128, 2048, True),
741- # (1, 2, 1, 2 , 2048, 128, 2048, True),
742- # (1, 2, 1, 2 , 4096, 128, 2048, True),
743- # (1, 2, 1, 2 , 8192, 128, 2048, True),
744- # (1, 2, 1, 2 , 16384, 128, 2048, True),
745- # (1, 2, 1, 2 , 32768, 128, 2048, True),
746- (1 , 2 , 1 , 2 , 65536 , 128 , 2048 , True ),
747- # (1, 2, 1, 2 , 131072, 128, 2048, True),
748- # (1, 2, 1, 2 , 262144, 128, 2048, True),
749- # (1, 2, 1, 2 , 524288, 128, 2048, True),
750-
751- # # Vary batch size
752- # (1, 2, 1, 4096, 4096, 32, 2048, True),
753- # (2, 2, 1, 4096, 4096, 32, 2048, True),
754- # (4, 2, 1, 4096, 4096, 32, 2048, True),
755- # (8, 2, 1, 4096, 4096, 32, 2048, True),
756-
757- # # Vary head count
758- # (1, 1, 1, 4096, 4096, 32, 2048, True),
759- # (1, 2, 1, 4096, 4096, 32, 2048, True),
760- # (1, 4, 1, 4096, 4096, 32, 2048, True),
761- # (1, 8, 2, 4096, 4096, 32, 2048, True),
762-
763- # # Vary head dimension
764- # (1, 2, 1, 4096, 4096, 32, 2048, True),
765- # (1, 2, 1, 4096, 4096, 64, 2048, True),
766- # (1, 2, 1, 4096, 4096, 96, 2048, True),
767- # (1, 2, 1, 4096, 4096, 128, 2048, True),
768- # (1, 2, 1, 4096, 4096, 192, 2048, True),
769- # (1, 2, 1, 4096, 4096, 256, 2048, True),
770-
771- # # Vary keep_window_size
772- # (1, 2, 1, 32768, 32768, 128, 32, True),
773- # (1, 2, 1, 32768, 32768, 128, 64, True),
774- # (1, 2, 1, 32768, 32768, 128, 128, True),
775- # (1, 2, 1, 32768, 32768, 128, 256, True),
776- # (1, 2, 1, 32768, 32768, 128, 512, True),
777- # (1, 2, 1, 32768, 32768, 128, 1024, True),
778- # (1, 2, 1, 32768, 32768, 128, 2048, True),
779- # (1, 2, 1, 32768, 32768, 128, 4096, True),
780- # (1, 2, 1, 32768, 32768, 128, 8192, True),
781- # (1, 2, 1, 32768, 32768, 128, 16384, True),
782- # (1, 2, 1, 32768, 32768, 128, 32768, True),
783-
784- # # Test non-causal
785- # (1, 2, 1, 4096, 4096, 128, 2048, False),
738+ (1 , 2 , 1 , 1 , 256 , 128 , 2048 , True ),
739+ (1 , 2 , 1 , 1 , 512 , 128 , 2048 , True ),
740+ (1 , 2 , 1 , 1 , 1024 , 128 , 2048 , True ),
741+ (1 , 2 , 1 , 1 , 2048 , 128 , 2048 , True ),
742+ (1 , 2 , 1 , 1 , 4096 , 128 , 2048 , True ),
743+ (1 , 2 , 1 , 1 , 8192 , 128 , 2048 , True ),
744+ (1 , 2 , 1 , 1 , 16384 , 128 , 2048 , True ),
745+ (1 , 2 , 1 , 1 , 32768 , 128 , 2048 , True ),
746+ (1 , 2 , 1 , 1 , 65536 , 128 , 2048 , True ),
747+ (1 , 2 , 1 , 1 , 131072 , 128 , 2048 , True ),
748+ (1 , 2 , 1 , 1 , 262144 , 128 , 2048 , True ),
749+ (1 , 2 , 1 , 1 , 524288 , 128 , 2048 , True ),
750+
751+ # Vary batch size
752+ (1 , 2 , 1 , 4096 , 4096 , 32 , 2048 , True ),
753+ (2 , 2 , 1 , 4096 , 4096 , 32 , 2048 , True ),
754+ (4 , 2 , 1 , 4096 , 4096 , 32 , 2048 , True ),
755+ (8 , 2 , 1 , 4096 , 4096 , 32 , 2048 , True ),
756+
757+ # Vary head count
758+ (1 , 1 , 1 , 4096 , 4096 , 32 , 2048 , True ),
759+ (1 , 2 , 1 , 4096 , 4096 , 32 , 2048 , True ),
760+ (1 , 4 , 1 , 4096 , 4096 , 32 , 2048 , True ),
761+ (1 , 8 , 2 , 4096 , 4096 , 32 , 2048 , True ),
762+
763+ # Vary head dimension
764+ (1 , 2 , 1 , 4096 , 4096 , 32 , 2048 , True ),
765+ (1 , 2 , 1 , 4096 , 4096 , 64 , 2048 , True ),
766+ (1 , 2 , 1 , 4096 , 4096 , 96 , 2048 , True ),
767+ (1 , 2 , 1 , 4096 , 4096 , 128 , 2048 , True ),
768+ (1 , 2 , 1 , 4096 , 4096 , 192 , 2048 , True ),
769+ (1 , 2 , 1 , 4096 , 4096 , 256 , 2048 , True ),
770+
771+ # Vary keep_window_size
772+ (1 , 2 , 1 , 32768 , 32768 , 128 , 32 , True ),
773+ (1 , 2 , 1 , 32768 , 32768 , 128 , 64 , True ),
774+ (1 , 2 , 1 , 32768 , 32768 , 128 , 128 , True ),
775+ (1 , 2 , 1 , 32768 , 32768 , 128 , 256 , True ),
776+ (1 , 2 , 1 , 32768 , 32768 , 128 , 512 , True ),
777+ (1 , 2 , 1 , 32768 , 32768 , 128 , 1024 , True ),
778+ (1 , 2 , 1 , 32768 , 32768 , 128 , 2048 , True ),
779+ (1 , 2 , 1 , 32768 , 32768 , 128 , 4096 , True ),
780+ (1 , 2 , 1 , 32768 , 32768 , 128 , 8192 , True ),
781+ (1 , 2 , 1 , 32768 , 32768 , 128 , 16384 , True ),
782+ (1 , 2 , 1 , 32768 , 32768 , 128 , 32768 , True ),
783+
784+ # Test non-causal
785+ (1 , 2 , 1 , 4096 , 4096 , 128 , 2048 , False ),
786786 ]
787787
788788 print (f"\n 📊 Benchmark Results (averaged over { num_runs } runs):" )
0 commit comments