@@ -145,8 +145,9 @@ def quant_model(self, model_type, model_path, out_path, **quant_kwargs):
 
     def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs):
         max_new_tokens = generate_kwargs.get("max_new_tokens", -1)
+        self.batch_size = input_ids.shape[0]
         if self.model is None:
-            self.init_from_bin(self.model_type, self.bin_file, batch_size=input_ids.shape[0],
+            self.init_from_bin(self.model_type, self.bin_file, batch_size=self.batch_size,
                                **generate_kwargs)
             self.generate_round = 0
         elif not interactive:
@@ -160,9 +161,6 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
         beam_search = False
         if (generate_kwargs.get("num_beams", 1) > 1) and not generate_kwargs.get("do_sample", False):
             beam_search = True
-        if not beam_search:
-            # TODO support multi batch
-            assert input_ids.shape[0] == 1, "Unsupport multi-batch input ids."
 
         if streamer:
             assert input_ids.shape[0] == 1, "Streamer only supports batch size 1."
@@ -190,9 +188,12 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
             if stopping_criteria is not None:
                 if stopping_criteria(torch.tensor(ret), None):
                     break
-            elif ret[0][-1] == self.eos_token_id() or \
-                (max_new_tokens != -1 and out_count >= max_new_tokens):
+            elif (max_new_tokens != -1 and out_count >= max_new_tokens):
                 break
+            else:
+                all_done = [(r[-1] in [self.eos_token_id(), self.pad_token_id()]) for r in ret]
+                if False not in all_done:
+                    break
         if streamer:
             streamer.end()
 
@@ -206,6 +207,15 @@ def eos_token_id(self):
         if self.model_type == 'qwen':
             return self.tokenizer.special_tokens['<|endoftext|>']
         return self.tokenizer.eos_token_id
+
+    def pad_token_id(self):
+        if self.tokenizer.pad_token_id == None:
+            if self.batch_size == 1:
+                return None
+            else:
+                raise ValueError("Please set pad_token_id when doing multi batch inference"
+                                 " with padding!")
+        return self.tokenizer.pad_token_id
 
     def __call__(self, input_ids, reinit=False, **kwargs):
         if self.model is None:
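
For context, a minimal usage sketch of what this change enables: batched generation with padded inputs, provided the tokenizer has a `pad_token_id` (otherwise the new `pad_token_id()` helper raises for batch sizes above 1). The tokenizer calls are standard Hugging Face `transformers` API; `model` stands in for an already-initialized instance of the class patched above (construction elided), and the checkpoint name is only an example.

```python
from transformers import AutoTokenizer

# Example checkpoint; any causal-LM tokenizer works the same way.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Multi-batch inference with padding requires a pad token; without one,
# pad_token_id() raises ValueError as soon as batch_size > 1.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

prompts = [
    "Tell me a joke.",
    "Summarize the plot of Hamlet in one sentence.",
]
inputs = tokenizer(prompts, padding=True, return_tensors="pt")

# `model` is assumed to be an already-initialized instance of the patched class.
# Generation now stops only once every sequence in the batch has emitted
# EOS (or PAD), or when max_new_tokens is reached.
outputs = model.generate(inputs.input_ids, max_new_tokens=32)
texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
```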