
Commit d5c7827

image processing responseDocument format compatible endpoint - Adithya S k
1 parent 7dbf349 commit d5c7827

File tree

7 files changed: +158 -56 lines changed


omniparse/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def load_omnimodel(load_documents: bool, load_media: bool, load_web: bool):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     if load_documents:
         print("[LOG] ✅ Loading OCR Model")
-        shared_state.model_list = load_all_models()
+        # shared_state.model_list = load_all_models()
         print("[LOG] ✅ Loading Vision Model")
         # if device == "cuda":
         shared_state.vision_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device)

omniparse/image/__init__.py

Lines changed: 12 additions & 6 deletions
@@ -6,8 +6,9 @@
 from PIL import Image
 # from omniparse.document.parse import parse_single_image
 from omniparse.documents.parse import parse_single_pdf
-from omniparse.image.process import pre_process_image
+from omniparse.image.process import process_image_task
 from omniparse.utils import encode_images
+from omniparse.models import responseDocument
 
 def parse_image(input_data, model_state) -> dict:
     temp_files = []
@@ -44,17 +45,22 @@ def parse_image(input_data, model_state) -> dict:
 
         # Parse the PDF file
         full_text, images, out_meta = parse_single_pdf(temp_pdf_path, model_state.model_list)
-        images = encode_images(images)
+
+        parse_image_result = responseDocument(
+            text=full_text,
+            metadata=out_meta
+        )
+        encode_images(images, parse_image_result)
 
-        return {"message": "Document parsed successfully", "markdown": full_text, "metadata": out_meta, "images": images}
+        return parse_image_result
 
     finally:
         # Clean up the temporary files
         for file_path in temp_files:
            if os.path.exists(file_path):
                os.remove(file_path)
 
-def process_image(input_data, task, model_state) -> dict:
+def process_image(input_data, task, model_state) -> responseDocument:
     try:
         temp_files = []
 
@@ -76,9 +82,9 @@ def process_image(input_data, task, model_state) -> dict:
         image_data = Image.open(temp_file_path).convert("RGB")
 
         # Process the image using your function (e.g., process_image)
-        results = pre_process_image(image_data, task, vision_model = model_state.vision_model, vision_processor = model_state.vision_processor)
+        image_process_results : responseDocument = process_image_task(image_data, task, model_state)
 
-        return {"results": results}
+        return image_process_results
 
     finally:
         # Clean up the temporary files
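Note: the responseDocument class used throughout this commit lives in omniparse/models.py, which is not part of this diff. Inferred purely from how it is constructed (text=..., metadata=...) and consumed (.add_image(...), .model_dump()), a minimal Pydantic sketch might look like the following; the field names and the add_image signature are assumptions, not the actual implementation.

# Hypothetical sketch only -- omniparse/models.py is not shown in this commit.
from typing import Any, Dict
from pydantic import BaseModel, Field

class responseDocument(BaseModel):
    text: str = ""
    metadata: Dict[str, Any] = Field(default_factory=dict)
    images: Dict[str, Any] = Field(default_factory=dict)

    def add_image(self, image_name: str, image_data: Any) -> None:
        # Register a rendered image (base64 string or PIL.Image) under a name
        # so the whole response can later be serialized with model_dump().
        self.images[image_name] = image_data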

omniparse/image/parse.py

Lines changed: 6 additions & 8 deletions
@@ -1,9 +1,7 @@
 import warnings
-warnings.filterwarnings("ignore", category=UserWarning) # Filter torch pytree user warnings
-
+from typing import List, Dict, Tuple, Optional
 import pypdfium2 as pdfium # Needs to be at the top to avoid warnings
 from PIL import Image
-
 from omniparse.documents.utils import flush_cuda_memory
 from omniparse.documents.tables.table import format_tables
 from omniparse.documents.debug.data import dump_bbox_debug_data
@@ -25,14 +23,14 @@
 from omniparse.documents.cleaners.text import cleanup_text
 from omniparse.documents.images.extract import extract_images
 from omniparse.documents.images.save import images_to_dict
-
-from typing import List, Dict, Tuple, Optional
 from omniparse.documents.settings import settings
 
+warnings.filterwarnings("ignore", category=UserWarning) # Filter torch pytree user warnings
+
 
 def parse_single_image(
     image: Image.Image,
-    model_lst: List,
+    model_list: List,
     metadata: Optional[Dict] = None,
     langs: Optional[List[str]] = None,
     batch_multiplier: int = 1
@@ -54,10 +52,10 @@ def parse_single_image(
         "languages": langs,
     }
 
-    texify_model, layout_model, order_model, edit_model, detection_model, ocr_model = model_lst
+    texify_model, layout_model, order_model, edit_model, detection_model, ocr_model = model_list
 
     # Identify text lines on pages
-    text_line_prediction = surya_detection(image, detection_model, batch_multiplier=batch_multiplier)
+    surya_detection(image, detection_model, batch_multiplier=batch_multiplier)
     flush_cuda_memory()
 
     # OCR pages as needed

omniparse/image/process.py

Lines changed: 124 additions & 33 deletions
@@ -1,53 +1,144 @@
+from typing import Dict, Any, Union
+from PIL import Image as PILImage
+import base64
+from io import BytesIO
+import copy
+from omniparse.image.utils import plot_bbox, fig_to_pil, draw_polygons, draw_ocr_bboxes
+from omniparse.models import responseDocument
 
-def pre_process_image(image, task_prompt, vision_model, vision_processor):
-    # :Convert binary image data to PIL Image
-    # image = Image.fromarray(image)
+def process_image_task(image_data: Union[str, bytes, PILImage.Image], task_prompt: str, model_state) -> Dict[str, Any]:
+    # Convert image_data if it's in bytes
+    if isinstance(image_data, bytes):
+        pil_image = PILImage.open(BytesIO(image_data))
+    elif isinstance(image_data, str):
+        try:
+            image_bytes = base64.b64decode(image_data)
+            pil_image = PILImage.open(BytesIO(image_bytes))
+        except Exception as e:
+            raise ValueError(f"Failed to decode base64 image: {str(e)}")
+    elif isinstance(image_data, PILImage.Image):
+        pil_image = image_data
+    else:
+        raise ValueError("Unsupported image_data type. Should be either string (file path), bytes (binary image data), or PIL.Image instance.")
+
+    # Process based on task_prompt
     if task_prompt == 'Caption':
-        task_prompt = '<CAPTION>'
-        results = run_example(task_prompt, image, vision_model, vision_processor)
+        task_prompt_model = '<CAPTION>'
     elif task_prompt == 'Detailed Caption':
-        task_prompt = '<DETAILED_CAPTION>'
-        results = run_example(task_prompt, image, vision_model, vision_processor)
+        task_prompt_model = '<DETAILED_CAPTION>'
     elif task_prompt == 'More Detailed Caption':
-        task_prompt = '<MORE_DETAILED_CAPTION>'
-        results = run_example(task_prompt, image, vision_model, vision_processor)
+        task_prompt_model = '<MORE_DETAILED_CAPTION>'
+    elif task_prompt == 'Caption + Grounding':
+        task_prompt_model = '<CAPTION>'
+    elif task_prompt == 'Detailed Caption + Grounding':
+        task_prompt_model = '<DETAILED_CAPTION>'
+    elif task_prompt == 'More Detailed Caption + Grounding':
+        task_prompt_model = '<MORE_DETAILED_CAPTION>'
    elif task_prompt == 'Object Detection':
-        task_prompt = '<OD>'
-        results = run_example(task_prompt, image, vision_model, vision_processor)
+        task_prompt_model = '<OD>'
    elif task_prompt == 'Dense Region Caption':
-        task_prompt = '<DENSE_REGION_CAPTION>'
-        results = run_example(task_prompt, image, vision_model, vision_processor)
+        task_prompt_model = '<DENSE_REGION_CAPTION>'
    elif task_prompt == 'Region Proposal':
-        task_prompt = '<REGION_PROPOSAL>'
-        results = run_example(task_prompt, image, vision_model, vision_processor)
+        task_prompt_model = '<REGION_PROPOSAL>'
    elif task_prompt == 'Caption to Phrase Grounding':
-        task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
-        results = run_example(task_prompt, image, vision_model, vision_processor)
+        task_prompt_model = '<CAPTION_TO_PHRASE_GROUNDING>'
    elif task_prompt == 'Referring Expression Segmentation':
-        task_prompt = '<REFERRING_EXPRESSION_SEGMENTATION>'
-        results = run_example(task_prompt, image, vision_model, vision_processor)
+        task_prompt_model = '<REFERRING_EXPRESSION_SEGMENTATION>'
    elif task_prompt == 'Region to Segmentation':
-        task_prompt = '<REGION_TO_SEGMENTATION>'
-        results = run_example(task_prompt, image, vision_model, vision_processor)
+        task_prompt_model = '<REGION_TO_SEGMENTATION>'
    elif task_prompt == 'Open Vocabulary Detection':
-        task_prompt = '<OPEN_VOCABULARY_DETECTION>'
-        results = run_example(task_prompt, image, vision_model, vision_processor)
+        task_prompt_model = '<OPEN_VOCABULARY_DETECTION>'
    elif task_prompt == 'Region to Category':
-        task_prompt = '<REGION_TO_CATEGORY>'
-        results = run_example(task_prompt, image, vision_model, vision_processor)
+        task_prompt_model = '<REGION_TO_CATEGORY>'
    elif task_prompt == 'Region to Description':
-        task_prompt = '<REGION_TO_DESCRIPTION>'
-        results = run_example(task_prompt, image, vision_model, vision_processor)
+        task_prompt_model = '<REGION_TO_DESCRIPTION>'
    elif task_prompt == 'OCR':
-        task_prompt = '<OCR>'
-        results = run_example(task_prompt, image, vision_model, vision_processor)
+        task_prompt_model = '<OCR>'
    elif task_prompt == 'OCR with Region':
-        task_prompt = '<OCR_WITH_REGION>'
-        results = run_example(task_prompt, image, vision_model, vision_processor)
+        task_prompt_model = '<OCR_WITH_REGION>'
    else:
-        return {"error": "Invalid task prompt"}
+        raise ValueError("Invalid task prompt")
+
+    results, processed_image = pre_process_image(pil_image, task_prompt_model, model_state.vision_model, model_state.vision_processor)
+    # Update responseDocument fields based on the results
+    process_image_result = responseDocument(
+        text = str(results)
+    )
+
+    if processed_image is not None:
+        process_image_result.add_image(f"{task_prompt}", processed_image)
+
+    return process_image_result
 
-    return results
+# Your pre_process_image function with some adjustments
+def pre_process_image(image, task_prompt, vision_model, vision_processor):
+    if task_prompt == '<CAPTION>':
+        results = run_example(task_prompt, image, vision_model, vision_processor)
+        return results, None
+    elif task_prompt == '<DETAILED_CAPTION>':
+        results = run_example(task_prompt, image, vision_model, vision_processor)
+        return results, None
+    elif task_prompt == '<MORE_DETAILED_CAPTION>':
+        results = run_example(task_prompt, image, vision_model, vision_processor)
+        return results, None
+    elif task_prompt == '<CAPTION_TO_PHRASE_GROUNDING>':
+        results = run_example(task_prompt, image, vision_model, vision_processor)
+        fig = plot_bbox(image, results[task_prompt])
+        return results, fig_to_pil(fig)
+    elif task_prompt == '<DETAILED_CAPTION + GROUNDING>':
+        results = run_example(task_prompt, image, vision_model, vision_processor)
+        fig = plot_bbox(image, results[task_prompt])
+        return results, fig_to_pil(fig)
+    elif task_prompt == '<MORE_DETAILED_CAPTION + GROUNDING>':
+        results = run_example(task_prompt, image, vision_model, vision_processor)
+        fig = plot_bbox(image, results[task_prompt])
+        return results, fig_to_pil(fig)
+    elif task_prompt == '<OD>':
+        results = run_example(task_prompt, image, vision_model, vision_processor)
+        fig = plot_bbox(image, results[task_prompt])
+        return results, fig_to_pil(fig)
+    elif task_prompt == '<DENSE_REGION_CAPTION>':
+        results = run_example(task_prompt, image, vision_model, vision_processor)
+        fig = plot_bbox(image, results[task_prompt])
+        return results, fig_to_pil(fig)
+    elif task_prompt == '<REGION_PROPOSAL>':
+        results = run_example(task_prompt, image, vision_model, vision_processor)
+        fig = plot_bbox(image, results[task_prompt])
+        return results, fig_to_pil(fig)
+    elif task_prompt == '<CAPTION_TO_PHRASE_GROUNDING>':
+        results = run_example(task_prompt, image, vision_model, vision_processor)
+        fig = plot_bbox(image, results[task_prompt])
+        return results, fig_to_pil(fig)
+    elif task_prompt == '<REFERRING_EXPRESSION_SEGMENTATION>':
+        results = run_example(task_prompt, image, vision_model, vision_processor)
+        output_image = copy.deepcopy(image)
+        output_image = draw_polygons(output_image, results[task_prompt], fill_mask=True)
+        return results, output_image
+    elif task_prompt == '<REGION_TO_SEGMENTATION>':
+        results = run_example(task_prompt, image, vision_model, vision_processor)
+        output_image = copy.deepcopy(image)
+        output_image = draw_polygons(output_image, results[task_prompt], fill_mask=True)
+        return results, output_image
+    elif task_prompt == '<OPEN_VOCABULARY_DETECTION>':
+        results = run_example(task_prompt, image, vision_model, vision_processor)
+        fig = plot_bbox(image, results[task_prompt])
+        return results, fig_to_pil(fig)
+    elif task_prompt == '<REGION_TO_CATEGORY>':
+        results = run_example(task_prompt, image, vision_model, vision_processor)
+        return results, None
+    elif task_prompt == '<REGION_TO_DESCRIPTION>':
+        results = run_example(task_prompt, image, vision_model, vision_processor)
+        return results, None
+    elif task_prompt == '<OCR>':
+        results = run_example(task_prompt, image, vision_model, vision_processor)
+        return results, None
+    elif task_prompt == '<OCR_WITH_REGION>':
+        results = run_example(task_prompt, image, vision_model, vision_processor)
+        output_image = copy.deepcopy(image)
+        output_image = draw_ocr_bboxes(output_image, results[task_prompt])
+        return results, output_image
+    else:
+        raise ValueError("Invalid task prompt")
 
 def run_example(task_prompt, image, vision_model, vision_processor):
     # if text_input is None:
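For reference, a hedged usage sketch of the new process_image_task entry point. It assumes load_omnimodel has populated the shared Florence-2 vision model and processor (see omniparse/__init__.py above) and that the caller passes raw image bytes plus one of the task strings handled in this file; the file name is hypothetical.

# Usage sketch only -- assumes the vision model/processor are loaded into shared state.
from omniparse import load_omnimodel, get_shared_state
from omniparse.image.process import process_image_task

load_omnimodel(load_documents=True, load_media=False, load_web=False)
model_state = get_shared_state()

with open("sample.jpg", "rb") as f:  # hypothetical local file
    result = process_image_task(f.read(), "Object Detection", model_state)

print(result.text)  # stringified Florence-2 results
# Grounding/detection tasks also attach a rendered image via add_image(task_prompt, ...)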

omniparse/image/router.py

Lines changed: 3 additions & 2 deletions
@@ -2,6 +2,7 @@
 from fastapi.responses import JSONResponse
 from omniparse import get_shared_state
 from omniparse.image import parse_image, process_image
+from omniparse.models import responseDocument
 
 image_router = APIRouter()
 model_state = get_shared_state()
@@ -20,8 +21,8 @@ async def parse_image_endpoint(file: UploadFile = File(...)):
 async def process_image_route(image: UploadFile = File(...), task: str = Form(...)):
     try:
         file_bytes = await image.read()
-        result = process_image(file_bytes, task, model_state)
-        return JSONResponse(content=result)
+        result : responseDocument = process_image(file_bytes, task, model_state)
+        return JSONResponse(content=result.model_dump())
 
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
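The endpoint now returns result.model_dump() rather than the old ad-hoc dict. A possible client call is sketched below; the host, port, and route path are assumptions, since the route decorator and router prefix are not visible in this diff.

# Illustrative client call; URL and path are assumptions not shown in this commit.
import requests

with open("sample.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/image/process_image",
        files={"image": ("sample.jpg", f, "image/jpeg")},
        data={"task": "OCR"},
    )
resp.raise_for_status()
payload = resp.json()  # serialized responseDocument (text, metadata, images, ...)
print(payload.get("text"))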

omniparse/image/utils.py

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,4 @@
+import io
 import random
 import numpy as np
 from PIL import Image, ImageDraw, ImageFont
@@ -66,3 +67,8 @@ def draw_ocr_bboxes(image, prediction):
     return image
 
 
+def fig_to_pil(fig):
+    buf = io.BytesIO()
+    fig.savefig(buf, format='png')
+    buf.seek(0)
+    return Image.open(buf)
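The new fig_to_pil helper converts a Matplotlib figure into a PIL image; process.py uses it to attach rendered bounding-box plots to the responseDocument. A quick, hypothetical round-trip:

# Hypothetical example: render a Matplotlib figure and hand it back as a PIL image.
import matplotlib.pyplot as plt
from omniparse.image.utils import fig_to_pil

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1])
pil_img = fig_to_pil(fig)  # PIL.Image.Image backed by the in-memory PNG buffer
pil_img.save("figure.png")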

omniparse/utils.py

Lines changed: 6 additions & 6 deletions
@@ -1,23 +1,23 @@
 import base64
 import os
 from art import text2art
+from omniparse.models import responseDocument
 
-def encode_images(images):
-    image_data = {}
+def encode_images(images, inputDocument:responseDocument):
+    image_data = []
     for i, (filename, image) in enumerate(images.items()):
         # print(f"Processing image {filename}")
-
         # Save image as PNG
         image.save(filename, "PNG")
-
         # Read the saved image file as bytes
         with open(filename, "rb") as f:
             image_bytes = f.read()
-
         # Convert image to base64
         image_base64 = base64.b64encode(image_bytes).decode('utf-8')
         image_data[f'{filename}'] = image_base64
-
+
+        inputDocument.add_image(image_name=filename, image_data=image_base64)
+
         # Remove the temporary image file
         os.remove(filename)
     return image_data
