Skip to content

Commit 34b329f

Browse files
authored
Add files via upload
1 parent 32f679f commit 34b329f

File tree

1 file changed

+18

-7

lines changed

src/module_process_images.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def choose_image_loader():
6262
loader_func = loader_molmo(config).process_images
6363
elif chosen_model in ['Ovis2 - 1b', 'Ovis2 - 2b']:
6464
loader_func = loader_ovis(config).process_images
65-
elif chosen_model in ['InternVL2.5 - 1b', 'InternVL2.5 - 4b']:
66-
loader_func = loader_internvl2_5(config).process_images
65+
elif chosen_model in ['InternVL3 - 1b', 'InternVL3 - 2b', 'InternVL2.5 - 4b', 'InternVL3 - 8b', 'InternVL3 - 14b']:
66+
loader_func = loader_internvl(config).process_images
6767
elif chosen_model in ['Qwen VL - 3b', 'Qwen VL - 7b']:
6868
loader_func = loader_qwenvl(config).process_images
6969
else:
@@ -184,13 +184,24 @@ def initialize_model_and_tokenizer(self):
184184

185185
@torch.inference_mode()
186186
def process_single_image(self, raw_image):
187+
# NEW – drop alpha if present
188+
if raw_image.mode != "RGB":
189+
raw_image = raw_image.convert("RGB")
190+
187191
query = "Describe this image in as much detail as possible but do not repeat yourself."
188-
inputs = self.tokenizer.apply_chat_template([{"role":"user","image":raw_image,"content":query}], add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True).to(self.device)
192+
inputs = self.tokenizer.apply_chat_template(
193+
[{"role": "user", "image": raw_image, "content": query}],
194+
add_generation_prompt=True,
195+
tokenize=True,
196+
return_tensors="pt",
197+
return_dict=True
198+
).to(self.device)
199+
189200
with torch.no_grad():
190201
outputs = self.model.generate(**inputs, max_length=1024, do_sample=False)
191-
outputs = outputs[:, inputs['input_ids'].shape[1]:]
192-
desc = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
193-
return ' '.join(line.strip() for line in desc.split('\n') if line.strip())
202+
203+
outputs = outputs[:, inputs["input_ids"].shape[1]:]
204+
return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
194205

195206

196207
class loader_molmo(BaseLoader):
@@ -444,7 +455,7 @@ def process_single_image(self, raw_image):
444455
return " ".join(line.strip() for line in description.split("\n") if line.strip())
445456

446457

447-
class loader_internvl2_5(BaseLoader):
458+
class loader_internvl(BaseLoader):
448459
def initialize_model_and_tokenizer(self):
449460
chosen_model = self.config['vision']['chosen_model']
450461
info = VISION_MODELS[chosen_model]

0 commit comments

Comments (0)