
Commit 262d216

fix granite vision
1 parent 34b329f commit 262d216

1 file changed (+19, -21)

src/module_process_images.py

Lines changed: 19 additions & 21 deletions
@@ -184,7 +184,6 @@ def initialize_model_and_tokenizer(self):
 
     @torch.inference_mode()
     def process_single_image(self, raw_image):
-        # NEW – drop alpha if present
         if raw_image.mode != "RGB":
             raw_image = raw_image.convert("RGB")
 
@@ -565,16 +564,20 @@ def initialize_model_and_tokenizer(self):
         save_dir = VISION_MODELS[chosen_model]["cache_dir"]
         cache_dir = CACHE_DIR / save_dir
         cache_dir.mkdir(parents=True, exist_ok=True)
+
         config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_compute_dtype=torch.bfloat16,
             bnb_4bit_quant_type="nf4",
             llm_int8_skip_modules=[
                 "vision_tower",
-                "multi_modal_projector",
-                "language_model.lm_head"
+                "multi_modal_projector",
+                "language_model.embed_tokens",
+                "language_model.norm",
+                "lm_head"
             ]
         )
+
         processor = AutoProcessor.from_pretrained(
             model_id,
             use_fast=True,
@@ -588,30 +591,25 @@ def initialize_model_and_tokenizer(self):
             low_cpu_mem_usage=True,
             cache_dir=cache_dir,
             token=False
-        ).eval()
+        )
+        model.to(self.device)
+        model.eval()
         my_cprint("Granite Vision model loaded into memory", "green")
         return model, None, processor
 
     @torch.inference_mode()
     def process_single_image(self, raw_image):
-        msg = "Describe in detail what this image depicts but limit your response to one paragraph with no line breaks in it."
-        prompt = f"<|user|>\n<image>\n{msg}\n<|assistant|>\n"
-        inputs = self.processor(
-            images=raw_image,
-            text=prompt,
-            return_tensors="pt"
-        ).to(self.device)
-        output = self.model.generate(
-            **inputs,
-            max_new_tokens=1024,
-            do_sample=False,
-            num_beams=1
+        if raw_image.mode != "RGB":
+            raw_image = raw_image.convert("RGB")
+        msg = (
+            "Describe in detail what this image depicts but limit your response "
+            "to one paragraph with no line breaks in it."
         )
-        resp = self.processor.decode(
-            output[0],
-            skip_special_tokens=True
-        ).split('<|assistant|>')[-1].strip()
-        return ' '.join(line.strip() for line in resp.split('\n') if line.strip())
+        prompt = f"<|user|>\n<image>\n{msg}\n<|assistant|>\n"
+        inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(self.device)
+        output = self.model.generate(**inputs, max_new_tokens=1024, do_sample=False, num_beams=1)
+        resp = self.processor.decode(output[0], skip_special_tokens=True).split('<|assistant|>')[-1].strip()
+        return " ".join(line.strip() for line in resp.split("\n") if line.strip())
 
 
 class loader_qwenvl(BaseLoader):

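And a self-contained sketch of the rewritten inference path, lifted from the new hunk above; describe_image is a hypothetical free-function wrapper around what the loader implements as a method:

import torch
from PIL import Image  # raw_image is a PIL image

@torch.inference_mode()
def describe_image(model, processor, raw_image, device):
    # Drop any alpha channel; the processor expects RGB input.
    if raw_image.mode != "RGB":
        raw_image = raw_image.convert("RGB")
    msg = (
        "Describe in detail what this image depicts but limit your response "
        "to one paragraph with no line breaks in it."
    )
    prompt = f"<|user|>\n<image>\n{msg}\n<|assistant|>\n"
    inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(device)
    # Greedy, single-beam decoding for a deterministic description.
    output = model.generate(**inputs, max_new_tokens=1024, do_sample=False, num_beams=1)
    # Keep only the assistant's turn and flatten it to one paragraph.
    resp = processor.decode(output[0], skip_special_tokens=True).split("<|assistant|>")[-1].strip()
    return " ".join(line.strip() for line in resp.split("\n") if line.strip())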