@@ -133,6 +133,15 @@ def get_hf_token():
     return None


+class _StopOnToken(StoppingCriteria):
+    """Stop generation when any ID in `stop_ids` is produced."""
+    def __init__(self, stop_ids):
+        self.stop_ids = set(stop_ids)
+
+    def __call__(self, input_ids, scores, **kwargs):
+        return input_ids[0, -1].item() in self.stop_ids
+
+
 class StopAfterThink(StoppingCriteria):
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
@@ -370,18 +379,6 @@ def __init__(self, generation_settings: dict, model_name: str):

         self.generation_settings["pad_token_id"] = self.tokenizer.eos_token_id

-    @torch.inference_mode()
-    def generate_response(self, inputs, remove_token_type_ids=False):
-        if remove_token_type_ids:
-            inputs.pop("token_type_ids", None)
-
-        settings = {**inputs, **self.generation_settings,
-                    "pad_token_id": self.tokenizer.eos_token_id}
-        generated = self.model.generate(**settings)
-        txt = self.tokenizer.decode(generated[0], skip_special_tokens=True)
-        txt = txt[txt.rfind("</think>") + len("</think>"):].lstrip()
-        yield txt
-
     def create_prompt(self, augmented_query: str) -> str:
         return f"""[gMASK]<sop><|system|>
 {system_message}<|user|>
@@ -423,6 +420,80 @@ def generate_response(self, inputs):
         yield from super().generate_response(inputs)


+class Phi4(BaseModel):
+    def __init__(self, generation_settings: dict, model_name: str):
+        model_info = CHAT_MODELS[model_name]
+
+        settings = copy.deepcopy(bnb_bfloat16_settings)
+        settings["model_settings"]["attn_implementation"] = "sdpa"
+        settings["model_settings"]["device_map"] = "auto"
+
+        # Pure-CPU fallback: no quantized weights on GPU, force everything to CPU
+        if not torch.cuda.is_available():
+            settings = {"tokenizer_settings": {}, "model_settings": {"device_map": "cpu"}}
+
+        super().__init__(model_info, settings, generation_settings)
+
+        self.generation_settings["pad_token_id"] = self.tokenizer.eos_token_id
+
+    def create_prompt(self, augmented_query: str) -> str:
+        return (
+            f"<|system|>{system_message}<|end|>"
+            f"<|user|>{augmented_query}<|end|><|assistant|>"
+        )
+
+    @torch.inference_mode()
+    def generate_response(self, inputs, remove_token_type_ids: bool = False):
+        if remove_token_type_ids:
+            inputs.pop("token_type_ids", None)
+
+        eos_id = self.tokenizer.eos_token_id
+        user_id = self.tokenizer.convert_tokens_to_ids("<|user|>")
+        assist_id = self.tokenizer.convert_tokens_to_ids("<|assistant|>")
+
+        stop_criteria = StoppingCriteriaList([_StopOnToken({user_id, eos_id})])
+
+        streamer = TextIteratorStreamer(
+            self.tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=False
+        )
+
+        gen_thread = threading.Thread(
+            target=self.model.generate,
+            kwargs={**inputs,
+                    **self.generation_settings,
+                    "streamer": streamer,
+                    "eos_token_id": eos_id,
+                    "pad_token_id": eos_id,
+                    "stopping_criteria": stop_criteria},
+            daemon=True
+        )
+        gen_thread.start()
+
+        buffer, sent = "", 0
+        ASSIST, USER, END = "<|assistant|>", "<|user|>", "<|end|>"
+
+        for chunk in streamer:
+            buffer += chunk
+
+            if ASSIST in buffer:
+                buffer = buffer.split(ASSIST)[-1]
+
+            for tag in (USER, END):
+                cut = buffer.find(tag)
+                if cut != -1:
+                    buffer = buffer[:cut]
+                    streamer.break_on_eos = True
+
+            clean = buffer.replace(ASSIST, "").replace(USER, "").replace(END, "")
+
+            if len(clean) > sent:
+                yield clean[sent:]
+                sent = len(clean)
+
+        gen_thread.join()
+
 def generate_response(model_instance, augmented_query):
     prompt = model_instance.create_prompt(augmented_query)
     inputs = model_instance.create_inputs(prompt)
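
For reference, a minimal sketch of how the new streaming Phi4.generate_response could be consumed downstream. The CHAT_MODELS key "Phi-4" and the generation_settings values are assumptions for illustration only and are not part of this diff; create_prompt and create_inputs are the helpers already used by the module-level generate_response above.

    # Hypothetical usage sketch (model key and settings are assumed, not from this diff)
    generation_settings = {"max_new_tokens": 512, "do_sample": False}

    model = Phi4(generation_settings, model_name="Phi-4")   # "Phi-4" key assumed to exist in CHAT_MODELS
    inputs = model.create_inputs(model.create_prompt("What does this repository do?"))

    # generate_response yields incremental text chunks as the background generate() thread streams tokens
    for piece in model.generate_response(inputs):
        print(piece, end="", flush=True)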