
Commit fd6f71a

add phi 4 mini - 4b
1 parent 063d2d4 commit fd6f71a

File tree

src/constants.py
src/module_chat.py

2 files changed: +94 −12 lines

src/constants.py

Lines changed: 11 additions & 0 deletions
@@ -354,6 +354,17 @@
         'gated': False,
         'max_tokens': 4096,
     },
+    'Phi 4 Mini - 4b': {
+        'model': 'Phi 4 Mini - 4b',
+        'repo_id': 'microsoft/Phi-4-mini-instruct',
+        'cache_dir': 'microsoft--Phi-4-mini-instruct',
+        'cps': 222.77,
+        'vram': 4761.80,
+        'function': 'Phi4',
+        'precision': 'bfloat16',
+        'gated': False,
+        'max_new_tokens': 4096,
+    },
     'Qwen 3 - 4b': {
         'model': 'Qwen 3 - 4b',
         'repo_id': 'Qwen/Qwen3-4B',
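The new registry entry's 'function': 'Phi4' names the handler class added to src/module_chat.py below. A minimal sketch of how such an entry might be resolved to its class; the build_chat_model helper and getattr-style dispatch here are illustrative assumptions, not part of this commit:

# Illustrative only -- assumes CHAT_MODELS lives in src/constants.py and that the
# 'function' field names a class defined in src/module_chat.py (as the new Phi4 class suggests).
from constants import CHAT_MODELS
import module_chat

def build_chat_model(name: str, generation_settings: dict):
    info = CHAT_MODELS[name]                               # e.g. the new 'Phi 4 Mini - 4b' entry
    handler_cls = getattr(module_chat, info["function"])   # 'Phi4' -> module_chat.Phi4
    return handler_cls(generation_settings, name)          # classes take (generation_settings, model_name)

# model = build_chat_model('Phi 4 Mini - 4b', {"max_new_tokens": 1024, "do_sample": False})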

src/module_chat.py

Lines changed: 83 additions & 12 deletions
@@ -133,6 +133,15 @@ def get_hf_token():
     return None


+class _StopOnToken(StoppingCriteria):
+    """Stop generation when any ID in `stop_ids` is produced."""
+    def __init__(self, stop_ids):
+        self.stop_ids = set(stop_ids)
+
+    def __call__(self, input_ids, scores, **kwargs):
+        return input_ids[0, -1].item() in self.stop_ids
+
+
 class StopAfterThink(StoppingCriteria):
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
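Hugging Face stopping criteria are called once per decoding step with all token ids generated so far, so the class above halts generation as soon as the most recent id is in stop_ids; because it only inspects row 0, it assumes a batch size of 1. A standalone check of that contract (the stop id 200007 is a made-up placeholder, not a real Phi-4 token id):

# Placeholder ids, purely to show how generate() would evaluate the criterion each step.
import torch

crit = _StopOnToken({200007})
ids = torch.tensor([[1, 5, 9, 200007]])   # batch of one; last generated id matches a stop id
print(crit(ids, scores=None))             # True -> generation would halt at this step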
@@ -370,18 +379,6 @@ def __init__(self, generation_settings: dict, model_name: str):

         self.generation_settings["pad_token_id"] = self.tokenizer.eos_token_id

-    @torch.inference_mode()
-    def generate_response(self, inputs, remove_token_type_ids=False):
-        if remove_token_type_ids:
-            inputs.pop("token_type_ids", None)
-
-        settings = {**inputs, **self.generation_settings,
-                    "pad_token_id": self.tokenizer.eos_token_id}
-        generated = self.model.generate(**settings)
-        txt = self.tokenizer.decode(generated[0], skip_special_tokens=True)
-        txt = txt[txt.rfind("</think>") + len("</think>"):].lstrip()
-        yield txt
-
     def create_prompt(self, augmented_query: str) -> str:
         return f"""[gMASK]<sop><|system|>
 {system_message}<|user|>
@@ -423,6 +420,80 @@ def generate_response(self, inputs):
         yield from super().generate_response(inputs)


+class Phi4(BaseModel):
+    def __init__(self, generation_settings: dict, model_name: str):
+        model_info = CHAT_MODELS[model_name]
+
+        settings = copy.deepcopy(bnb_bfloat16_settings)
+        settings["model_settings"]["attn_implementation"] = "sdpa"
+        settings["model_settings"]["device_map"] = "auto"
+
+        # Pure-CPU fallback: no quant-weights on GPU, force everything to CPU
+        if not torch.cuda.is_available():
+            settings = {"tokenizer_settings": {}, "model_settings": {"device_map": "cpu"}}
+
+        super().__init__(model_info, settings, generation_settings)
+
+        self.generation_settings["pad_token_id"] = self.tokenizer.eos_token_id
+
+    def create_prompt(self, augmented_query: str) -> str:
+        return (
+            f"<|system|>{system_message}<|end|>"
+            f"<|user|>{augmented_query}<|end|><|assistant|>"
+        )
+
+    @torch.inference_mode()
+    def generate_response(self, inputs, remove_token_type_ids: bool = False):
+        if remove_token_type_ids:
+            inputs.pop("token_type_ids", None)
+
+        eos_id = self.tokenizer.eos_token_id
+        user_id = self.tokenizer.convert_tokens_to_ids("<|user|>")
+        assist_id = self.tokenizer.convert_tokens_to_ids("<|assistant|>")
+
+        stop_criteria = StoppingCriteriaList([_StopOnToken({user_id, eos_id})])
+
+        streamer = TextIteratorStreamer(
+            self.tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=False
+        )
+
+        gen_thread = threading.Thread(
+            target=self.model.generate,
+            kwargs={**inputs,
+                    **self.generation_settings,
+                    "streamer": streamer,
+                    "eos_token_id": eos_id,
+                    "pad_token_id": eos_id,
+                    "stopping_criteria": stop_criteria},
+            daemon=True
+        )
+        gen_thread.start()
+
+        buffer, sent = "", 0
+        ASSIST, USER, END = "<|assistant|>", "<|user|>", "<|end|>"
+
+        for chunk in streamer:
+            buffer += chunk
+
+            if ASSIST in buffer:
+                buffer = buffer.split(ASSIST)[-1]
+
+            for tag in (USER, END):
+                cut = buffer.find(tag)
+                if cut != -1:
+                    buffer = buffer[:cut]
+                    streamer.break_on_eos = True
+
+            clean = buffer.replace(ASSIST, "").replace(USER, "").replace(END, "")
+
+            if len(clean) > sent:
+                yield clean[sent:]
+                sent = len(clean)
+
+        gen_thread.join()
+
 def generate_response(model_instance, augmented_query):
     prompt = model_instance.create_prompt(augmented_query)
     inputs = model_instance.create_inputs(prompt)
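For orientation, a rough usage sketch of the new streaming path: generation runs in a daemon thread, the TextIteratorStreamer yields decoded text as it arrives, and the loop above strips the chat-template markers before yielding incremental chunks. The settings values below are placeholders and create_inputs is the helper called by the module-level generate_response shown above; this is not code from the commit:

# Illustrative only -- assumes CHAT_MODELS / Phi4 are importable as in the diffs above.
generation_settings = {"max_new_tokens": 1024, "do_sample": False}

phi = Phi4(generation_settings, "Phi 4 Mini - 4b")      # loads microsoft/Phi-4-mini-instruct
prompt = phi.create_prompt("Summarize the attached document.")
inputs = phi.create_inputs(prompt)

for chunk in phi.generate_response(inputs):             # chunks arrive as the background thread decodes
    print(chunk, end="", flush=True)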

0 commit comments
