@@ -732,57 +732,57 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
732732 (1 , 2 , 1 , 4096 , 4096 , 128 , 2048 , True ),
733733 (1 , 2 , 1 , 8192 , 8192 , 128 , 2048 , True ),
734734 (1 , 2 , 1 , 16384 , 16384 , 128 , 2048 , True ),
735- (1 , 2 , 1 , 32768 , 32768 , 128 , 2048 , True ),
736-
737- # Inference
738- (1 , 2 , 1 , 2 , 256 , 128 , 2048 , True ),
739- (1 , 2 , 1 , 2 , 512 , 128 , 2048 , True ),
740- (1 , 2 , 1 , 2 , 1024 , 128 , 2048 , True ),
741- (1 , 2 , 1 , 2 , 2048 , 128 , 2048 , True ),
742- (1 , 2 , 1 , 2 , 4096 , 128 , 2048 , True ),
743- (1 , 2 , 1 , 2 , 8192 , 128 , 2048 , True ),
744- (1 , 2 , 1 , 2 , 16384 , 128 , 2048 , True ),
745- (1 , 2 , 1 , 2 , 32768 , 128 , 2048 , True ),
735+ # (1, 2, 1, 32768, 32768, 128, 2048, True),
736+
737+ # # Inference
738+ # (1, 2, 1, 2, 256, 128, 2048, True),
739+ # (1, 2, 1, 2, 512, 128, 2048, True),
740+ # (1, 2, 1, 2, 1024, 128, 2048, True),
741+ # (1, 2, 1, 2, 2048, 128, 2048, True),
742+ # (1, 2, 1, 2, 4096, 128, 2048, True),
743+ # (1, 2, 1, 2, 8192, 128, 2048, True),
744+ # (1, 2, 1, 2, 16384, 128, 2048, True),
745+ # (1, 2, 1, 2, 32768, 128, 2048, True),
746746 (1 , 2 , 1 , 2 , 65536 , 128 , 2048 , True ),
747- (1 , 2 , 1 , 2 , 131072 , 128 , 2048 , True ),
748- (1 , 2 , 1 , 2 , 262144 , 128 , 2048 , True ),
749- (1 , 2 , 1 , 2 , 524288 , 128 , 2048 , True ),
750-
751- # Vary batch size
752- (1 , 2 , 1 , 4096 , 4096 , 32 , 2048 , True ),
753- (2 , 2 , 1 , 4096 , 4096 , 32 , 2048 , True ),
754- (4 , 2 , 1 , 4096 , 4096 , 32 , 2048 , True ),
755- (8 , 2 , 1 , 4096 , 4096 , 32 , 2048 , True ),
756-
757- # Vary head count
758- (1 , 1 , 1 , 4096 , 4096 , 32 , 2048 , True ),
759- (1 , 2 , 1 , 4096 , 4096 , 32 , 2048 , True ),
760- (1 , 4 , 1 , 4096 , 4096 , 32 , 2048 , True ),
761- (1 , 8 , 2 , 4096 , 4096 , 32 , 2048 , True ),
762-
763- # Vary head dimension
764- (1 , 2 , 1 , 4096 , 4096 , 32 , 2048 , True ),
765- (1 , 2 , 1 , 4096 , 4096 , 64 , 2048 , True ),
766- (1 , 2 , 1 , 4096 , 4096 , 96 , 2048 , True ),
767- (1 , 2 , 1 , 4096 , 4096 , 128 , 2048 , True ),
768- (1 , 2 , 1 , 4096 , 4096 , 192 , 2048 , True ),
769- (1 , 2 , 1 , 4096 , 4096 , 256 , 2048 , True ),
770-
771- # Vary keep_window_size
772- (1 , 2 , 1 , 32768 , 32768 , 128 , 32 , True ),
773- (1 , 2 , 1 , 32768 , 32768 , 128 , 64 , True ),
774- (1 , 2 , 1 , 32768 , 32768 , 128 , 128 , True ),
775- (1 , 2 , 1 , 32768 , 32768 , 128 , 256 , True ),
776- (1 , 2 , 1 , 32768 , 32768 , 128 , 512 , True ),
777- (1 , 2 , 1 , 32768 , 32768 , 128 , 1024 , True ),
778- (1 , 2 , 1 , 32768 , 32768 , 128 , 2048 , True ),
779- (1 , 2 , 1 , 32768 , 32768 , 128 , 4096 , True ),
780- (1 , 2 , 1 , 32768 , 32768 , 128 , 8192 , True ),
781- (1 , 2 , 1 , 32768 , 32768 , 128 , 16384 , True ),
782- (1 , 2 , 1 , 32768 , 32768 , 128 , 32768 , True ),
783-
784- # Test non-causal
785- (1 , 2 , 1 , 4096 , 4096 , 128 , 2048 , False ),
747+ # (1, 2, 1, 2, 131072, 128, 2048, True),
748+ # (1, 2, 1, 2, 262144, 128, 2048, True),
749+ # (1, 2, 1, 2, 524288, 128, 2048, True),
750+
751+ # # Vary batch size
752+ # (1, 2, 1, 4096, 4096, 32, 2048, True),
753+ # (2, 2, 1, 4096, 4096, 32, 2048, True),
754+ # (4, 2, 1, 4096, 4096, 32, 2048, True),
755+ # (8, 2, 1, 4096, 4096, 32, 2048, True),
756+
757+ # # Vary head count
758+ # (1, 1, 1, 4096, 4096, 32, 2048, True),
759+ # (1, 2, 1, 4096, 4096, 32, 2048, True),
760+ # (1, 4, 1, 4096, 4096, 32, 2048, True),
761+ # (1, 8, 2, 4096, 4096, 32, 2048, True),
762+
763+ # # Vary head dimension
764+ # (1, 2, 1, 4096, 4096, 32, 2048, True),
765+ # (1, 2, 1, 4096, 4096, 64, 2048, True),
766+ # (1, 2, 1, 4096, 4096, 96, 2048, True),
767+ # (1, 2, 1, 4096, 4096, 128, 2048, True),
768+ # (1, 2, 1, 4096, 4096, 192, 2048, True),
769+ # (1, 2, 1, 4096, 4096, 256, 2048, True),
770+
771+ # # Vary keep_window_size
772+ # (1, 2, 1, 32768, 32768, 128, 32, True),
773+ # (1, 2, 1, 32768, 32768, 128, 64, True),
774+ # (1, 2, 1, 32768, 32768, 128, 128, True),
775+ # (1, 2, 1, 32768, 32768, 128, 256, True),
776+ # (1, 2, 1, 32768, 32768, 128, 512, True),
777+ # (1, 2, 1, 32768, 32768, 128, 1024, True),
778+ # (1, 2, 1, 32768, 32768, 128, 2048, True),
779+ # (1, 2, 1, 32768, 32768, 128, 4096, True),
780+ # (1, 2, 1, 32768, 32768, 128, 8192, True),
781+ # (1, 2, 1, 32768, 32768, 128, 16384, True),
782+ # (1, 2, 1, 32768, 32768, 128, 32768, True),
783+
784+ # # Test non-causal
785+ # (1, 2, 1, 4096, 4096, 128, 2048, False),
786786 ]
787787
788788 print (f"\n 📊 Benchmark Results (averaged over { num_runs } runs):" )
0 commit comments