Skip to content

Commit 1fb5bf6

Browse files
committed
#98 Introduce ExtractResult class and modify strategy interface
This commit introduces a new ExtractResult class, which replaces the use of a primitive type (string) for extract return values. This modification preserves vital metadata about the document that was previously lost when only the resulting string was returned. The change also modifies the strategy interface and its execution to use the new ExtractResult class rather than a primitive string.
1 parent 31b1a07 commit 1fb5bf6

File tree

6 files changed

+107
-10
lines changed

6 files changed

+107
-10
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from typing import Any, Callable, Optional
2+
3+
"""
4+
IMPORTANT INFORMATION ABOUT THIS CLASS:
5+
6+
This is not the final version of the object, namespace, or intended use.
7+
8+
For this reason, I am not creating an interface, etc. Add code here as soon as possible
9+
along with further integrations, and once we have gained sufficient experience, we will
10+
undertake a refactor.
11+
12+
Currently, the object's purpose is to replace the use of a primitive type, a string, for
13+
extract returns. The limitation of this approach became evident when returning only the
14+
resulting string caused us to lose valuable metadata about the document. Thanks to this
15+
class, we retain DoclingDocument and foresee that other converters/OCRs may have similar
16+
metadata.
17+
"""
18+
class ExtractResult:
    """
    Result of an extraction: the raw value plus a way to obtain its text.

    Wraps the object produced by an extraction (``value``) together with a
    callable (``text_gatherer``) that turns that object into plain text, so
    callers can keep the original object instead of only a string.

    Examples:
        Using the default text gatherer (string values only)

        >>> result = ExtractResult("Example text")
        >>> print(result.text)
        Example text

        Using a custom text gatherer

        >>> def custom_gatherer(value): return f"Custom: {value}"
        >>> result = ExtractResult(123, custom_gatherer)
        >>> print(result.text)
        Custom: 123
    """

    def __init__(
            self,
            value: Any,
            text_gatherer: Optional[Callable[[Any], str]] = None
    ):
        """
        Initializes an ExtractResult instance.

        Args:
            value (Any): The object containing or representing the text.
            text_gatherer (Callable[[Any], str], optional): A callable that extracts text
                from `value`. Defaults to `_default_text_gatherer`, which only accepts strings.

        Raises:
            ValueError: If `text_gatherer` is not callable or not provided when `value` is not a string.
        """
        if text_gatherer is None:
            # Without a gatherer only plain strings can be turned into text.
            if not isinstance(value, str):
                raise ValueError("If `value` is not a string, `text_gatherer` must be provided.")
            text_gatherer = self._default_text_gatherer
        elif not callable(text_gatherer):
            raise ValueError("The `text_gatherer` provided to ExtractResult must be a callable.")

        self.value = value
        # Compared against None above (not truthiness), so a callable object
        # that happens to be falsy is still honoured.
        self.text_gatherer = text_gatherer

    @classmethod
    def from_text(cls, value: str) -> 'ExtractResult':
        """
        Builds a result from an already extracted string.

        Args:
            value (str): The extracted text.

        Returns:
            ExtractResult: An instance using the default text gatherer.

        Raises:
            ValueError: If `value` is not a string.
        """
        return cls(value)

    @property
    def text(self) -> str:
        """
        Retrieves text using the text gatherer.

        This is a property: access it as `result.text`, without parentheses.

        Returns:
            str: The extracted text from `value`.
        """
        return self.text_gatherer(self.value)

    @staticmethod
    def _default_text_gatherer(value: Any) -> str:
        """
        Default method to extract str from str; returns `value` unchanged.

        Args:
            value (Any): The input value.

        Returns:
            str: The input value itself.

        Raises:
            TypeError: If the `value` is not a string.
        """
        if isinstance(value, str):
            return value
        raise TypeError("Default text gatherer only supports strings.")

text_extract_api/extract/strategies/easyocr.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from PIL import Image
44
import easyocr
55

6+
from extract.extract_result import ExtractResult
67
from text_extract_api.extract.strategies.strategy import Strategy
78
from text_extract_api.files.file_formats.file_format import FileFormat
89
from text_extract_api.files.file_formats.image import ImageFileFormat
@@ -13,7 +14,7 @@ class EasyOCRStrategy(Strategy):
1314
def name(cls) -> str:
1415
return "easyOCR"
1516

16-
def extract_text(self, file_format: FileFormat, language: str = 'en') -> str:
17+
def extract_text(self, file_format: FileFormat, language: str = 'en') -> ExtractResult:
1718
"""
1819
Extract text using EasyOCR after converting the input file to images
1920
(if not already an ImageFileFormat).
@@ -53,4 +54,6 @@ def extract_text(self, file_format: FileFormat, language: str = 'en') -> str:
5354

5455
# Join text from all images/pages
5556
full_text = "\n\n".join(all_extracted_text)
56-
return full_text
57+
58+
59+
return ExtractResult.from_text(full_text)

text_extract_api/extract/strategies/llama_vision.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import ollama
66

7+
from extract.extract_result import ExtractResult
78
from text_extract_api.extract.strategies.strategy import Strategy
89
from text_extract_api.files.file_formats.file_format import FileFormat
910
from text_extract_api.files.file_formats.image import ImageFileFormat
@@ -16,7 +17,7 @@ class LlamaVisionStrategy(Strategy):
1617
def name(cls) -> str:
1718
return "llama_vision"
1819

19-
def extract_text(self, file_format: FileFormat, language: str = 'en') -> str:
20+
def extract_text(self, file_format: FileFormat, language: str = 'en') -> ExtractResult:
2021

2122
if (
2223
not isinstance(file_format, ImageFileFormat)
@@ -66,4 +67,4 @@ def extract_text(self, file_format: FileFormat, language: str = 'en') -> str:
6667

6768
print(response)
6869

69-
return extracted_text
70+
return ExtractResult.from_text(extracted_text)

text_extract_api/extract/strategies/strategy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from pydantic.v1.typing import get_class
99

10+
from extract.extract_result import ExtractResult
1011
from text_extract_api.files.file_formats.file_format import FileFormat
1112

1213
class Strategy:
@@ -27,7 +28,7 @@ def name(cls) -> str:
2728
raise NotImplementedError("Strategy subclasses must implement name")
2829

2930
@classmethod
30-
def extract_text(cls, file_format: Type["FileFormat"], language: str = 'en') -> str:
31+
def extract_text(cls, file_format: Type["FileFormat"], language: str = 'en') -> ExtractResult:
3132
raise NotImplementedError("Strategy subclasses must implement extract_text method")
3233

3334
@classmethod

text_extract_api/extract/tasks.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,13 @@ def ocr_task(
4848
extracted_text = cached_result.decode('utf-8')
4949

5050
if extracted_text is None:
51-
print("Extracting text from PDF...")
51+
print(f"Extracting text from file using strategy: {strategy.name()}")
5252
self.update_state(state='PROGRESS',
53-
meta={'progress': 30, 'status': 'Extracting text from PDF', 'start_time': start_time,
53+
meta={'progress': 30, 'status': 'Extracting text from file', 'start_time': start_time,
5454
'elapsed_time': time.time() - start_time}) # Example progress update
55-
extracted_text = strategy.extract_text(FileFormat.from_binary(binary_content), language)
55+
extract_result = strategy.extract_text(FileFormat.from_binary(binary_content), language)
56+
extracted_text = extract_result.text
57+
5658
else:
5759
print("Using cached result...")
5860

@@ -62,11 +64,12 @@ def ocr_task(
6264
'start_time': start_time,
6365
'elapsed_time': time.time() - start_time}) # Example progress update
6466

67+
# @todo Universal Text Object - is cache available
6568
if ocr_cache:
6669
redis_client.set(file_hash, extracted_text)
6770

6871
if prompt:
69-
print("Transforming text using LLM (prompt={prompt}, model={model}) ...")
72+
print(f"Transforming text using LLM (prompt={prompt}, model={model}) ...")
7073
self.update_state(state='PROGRESS', meta={'progress': 75, 'status': 'Processing LLM', 'start_time': start_time,
7174
'elapsed_time': time.time() - start_time}) # Example progress update
7275
llm_resp = ollama.generate(model, prompt + extracted_text, stream=True)

text_extract_api/files/file_formats/file_format.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def from_binary(
6868
mime_type = mime_type or FileFormat._guess_mime_type(binary_data=binary, filename=filename)
6969
from text_extract_api.files.file_formats.pdf import PdfFileFormat # type: ignore
7070
file_format_class = cls._get_file_format_class(mime_type)
71-
print(file_format_class)
7271
return file_format_class(binary_file_content=binary, filename=filename, mime_type=mime_type)
7372

7473
def __repr__(self) -> str:

0 commit comments

Comments
 (0)