# Module-level tuning knobs; each int knob is overridable via an environment variable.
5757PAD_SEQUENCE_TO_MULTIPLE_OF = int (os .environ .get ("PAD_SEQUENCE_TO_MULTIPLE_OF" , 256 ))
# Candidate chunk sizes (powers of two, 1..2048); presumably used for sequence bucketing — TODO confirm against callers.
5858CHUNK_SIZES = [1 , 2 , 4 , 8 , 16 , 32 , 64 , 128 , 256 , 512 , 1024 , 2048 ]
# Non-zero enables HPU lazy-execution mode (defaults on).
5959LAZY_MODE = int (os .environ .get ("PT_HPU_LAZY_MODE" , 1 ))
# Removed in this change: fixed-step batch bucketing knobs, superseded by exponent-based bucketing below.
60- BATCH_BUCKET_SIZE = int (os .environ .get ("BATCH_BUCKET_SIZE" , 8 ))
61- PREFILL_BATCH_BUCKET_SIZE = int (os .environ .get ("PREFILL_BATCH_BUCKET_SIZE" , 2 ))
# Added: batch sizes are now rounded up to powers of this base (see round_up_batch).
60+ BATCH_SIZE_EXPONENT_BASE = int (os .environ .get ("BATCH_SIZE_EXPONENT_BASE" , 2 ))
6261MAX_BATCH_SIZE = (
6362 int (os .environ .get ("MAX_BATCH_SIZE" ))
6463 if os .environ .get ("MAX_BATCH_SIZE" ) is not None
@@ -74,10 +73,16 @@ def torch_compile_for_eager(func):
7473 )
7574
7675
77- def round_up (number , k ):
def round_up_seq(number, k):
    """Round *number* up to the nearest multiple of *k* (k > 0).

    Uses negated floor division to express ceiling division in pure
    integer arithmetic: ceil(number / k) * k.
    """
    return -(-number // k) * k
7978
8079
def round_up_batch(number):
    """Round *number* up to the nearest integer power of BATCH_SIZE_EXPONENT_BASE.

    Returns 1 (base**0) for any ``number <= 1``.

    Implemented with exact integer arithmetic instead of
    ``math.ceil(math.log(number, base))``: two-argument ``math.log`` is
    computed as ``log(x) / log(base)`` in floating point and can come out
    slightly above the true integer for an exact power (e.g. yielding
    29.000000000000004 for 2**29), which would double the returned bucket
    and waste warmup memory. The integer loop cannot misround.
    """
    base = BATCH_SIZE_EXPONENT_BASE
    power = 1
    # Multiply up until we reach or exceed the requested size.
    while power < number:
        power *= base
    return power
def to_tensor_indices(indices, device):
    """Materialize a Python sequence of indices as a ``torch.long`` tensor on *device*."""
    tensor_kwargs = {"dtype": torch.long, "device": device}
    return torch.tensor(indices, **tensor_kwargs)
8388
@@ -399,7 +404,7 @@ def recombine(
399404
400405 total_requests = sum (len (b ) for b in batches )
401406 new_bs = total_requests
402- new_bs = round_up (total_requests , BATCH_BUCKET_SIZE )
407+ new_bs = round_up_batch (total_requests )
403408
404409 batch_id = batches [0 ].batch_id
405410 device = batches [0 ].input_ids .device
@@ -540,7 +545,7 @@ def from_pb(
540545 # TODO: by tokenizing all inputs at once we loose information on actual input lengths
541546 # this means that we cannot shift inputs to the left after a long input sequence
542547 # was filtered out
543- new_bs = round_up (len (requests ), PREFILL_BATCH_BUCKET_SIZE )
548+ new_bs = round_up_batch (len (requests ))
544549 missing_inputs = new_bs - len (inputs )
545550 dummy_inputs = ["?" ] * missing_inputs
546551 parameters = [r .parameters for r in pb .requests ]
@@ -572,7 +577,7 @@ def from_pb(
572577 assert (
573578 PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
574579 ), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
575- rounded_seq_len = round_up (input_len + 1 , PAD_SEQUENCE_TO_MULTIPLE_OF )
580+ rounded_seq_len = round_up_seq (input_len + 1 , PAD_SEQUENCE_TO_MULTIPLE_OF )
576581 if rounded_seq_len <= max_input_length :
577582 bucket_size = rounded_seq_len - 1
578583 else :
@@ -1068,10 +1073,10 @@ def generate_token(
10681073 if (
10691074 self .enable_hpu_graph
10701075 and self .limit_hpu_graph
1071- and round_up (batch .batch_size , BATCH_BUCKET_SIZE ) != self .prev_bs
1076+ and round_up_batch (batch .batch_size ) != self .prev_bs
10721077 ):
10731078 self .model .clear_cache ()
1074- self .prev_bs = round_up (batch .batch_size , BATCH_BUCKET_SIZE )
1079+ self .prev_bs = round_up_batch (batch .batch_size )
10751080 dbg_trace (
10761081 scenario ,
10771082 f"bs:{ batch .batch_size } num_reqs:{ len (batch .requests )} seq_len:{ batch .seq_length } padding:{ batch .right_padding } " ,
@@ -1325,15 +1330,14 @@ def warmup(
13251330
13261331 # Warmup prefill batch_size
13271332 max_input_tokens = request .max_input_tokens
1333+ max_exp = math .ceil (math .log (max_prefill_batch_size , BATCH_SIZE_EXPONENT_BASE ))
13281334 prefill_batch_size_list = [
1329- batch
1330- for batch in range (
1331- PREFILL_BATCH_BUCKET_SIZE ,
1332- max_prefill_batch_size ,
1333- PREFILL_BATCH_BUCKET_SIZE ,
1335+ BATCH_SIZE_EXPONENT_BASE ** exp
1336+ for exp in range (
1337+ 0 ,
1338+ max_exp + 1 ,
13341339 )
13351340 ]
1336- prefill_batch_size_list .append (max_prefill_batch_size )
13371341 prefill_seqlen_list = [
13381342 seq
13391343 for seq in range (
@@ -1370,12 +1374,10 @@ def warmup(
13701374 )
13711375
13721376 max_decode_batch_size = math .floor (MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS )
1373- max_decode_batch_size = round_up ( max_decode_batch_size , BATCH_BUCKET_SIZE )
1377+ max_exp = math . ceil ( math . log ( max_decode_batch_size , BATCH_SIZE_EXPONENT_BASE ) )
13741378 decode_batch_size_list = [
1375- i
1376- for i in range (BATCH_BUCKET_SIZE , max_decode_batch_size , BATCH_BUCKET_SIZE )
1379+ BATCH_SIZE_EXPONENT_BASE ** exp for exp in range (0 , max_exp + 1 )
13771380 ]
1378- decode_batch_size_list .append (max_decode_batch_size )
13791381 decode_batch_size_list .sort (reverse = True )
13801382
13811383 try :
0 commit comments