diff --git a/mindee/input/inference_parameters.py b/mindee/input/inference_parameters.py
index 0df0495a..9cde3baa 100644
--- a/mindee/input/inference_parameters.py
+++ b/mindee/input/inference_parameters.py
@@ -1,9 +1,41 @@
+import json
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import List, Optional, Union
 
 from mindee.input.polling_options import PollingOptions
 
 
+class DataSchema:
+    """Data schema modifications to apply at inference time."""
+
+    _override: Optional[list] = None
+
+    def __init__(self, override: Optional[list] = None):
+        self._override = override
+
+    @property
+    def override(self) -> Optional[list]:
+        """Override the data schema."""
+        return self._override
+
+    @override.setter
+    def override(self, value: Optional[Union[str, list]]) -> None:
+        if value is None:
+            _override = None
+        elif isinstance(value, str):
+            _override = json.loads(value)
+        elif isinstance(value, list):
+            _override = value
+        else:
+            raise TypeError("Invalid type for data schema override")
+        if _override is not None and not _override:
+            raise ValueError("Empty override provided")
+        self._override = _override
+
+    def __str__(self) -> str:
+        return json.dumps({"override": self.override})
+
+
 @dataclass
 class InferenceParameters:
     """Inference parameters to set when sending a file."""
@@ -31,3 +63,5 @@ class InferenceParameters:
     """Whether to close the file after parsing."""
     text_context: Optional[str] = None
     """Additional text context used by the model during inference. Not recommended, for specific use only."""
+    data_schema: Optional[DataSchema] = None
+    """Optional data schema modifications to apply for this inference."""
diff --git a/mindee/mindee_http/mindee_api_v2.py b/mindee/mindee_http/mindee_api_v2.py
index 8a7260fd..673a94c9 100644
--- a/mindee/mindee_http/mindee_api_v2.py
+++ b/mindee/mindee_http/mindee_api_v2.py
@@ -94,8 +94,10 @@ def req_post_inference_enqueue(
         data["webhook_ids"] = params.webhook_ids
     if params.alias and len(params.alias):
         data["alias"] = params.alias
-    if params.text_context and (params.text_context):
+    if params.text_context and len(params.text_context):
         data["text_context"] = params.text_context
+    if params.data_schema is not None:
+        data["data_schema"] = str(params.data_schema)
 
     if isinstance(input_source, LocalInputSource):
         files = {"file": input_source.read_contents(params.close_file)}
diff --git a/tests/v2/test_client.py b/tests/v2/test_client.py
index 7ae18db4..e5d15282 100644
--- a/tests/v2/test_client.py
+++ b/tests/v2/test_client.py
@@ -6,6 +6,7 @@
 from mindee.error.mindee_error import MindeeApiV2Error, MindeeError
 from mindee.error.mindee_http_error_v2 import MindeeHTTPErrorV2
 from mindee.input import LocalInputSource, PathInput
+from mindee.input.inference_parameters import DataSchema
 from mindee.mindee_http.base_settings import USER_AGENT
 from mindee.parsing.v2.inference import Inference
 from mindee.parsing.v2.job import Job
@@ -130,7 +131,11 @@ def test_enqueue_and_parse_path_with_env_token(custom_base_url_client):
 
     with pytest.raises(MindeeHTTPErrorV2):
         custom_base_url_client.enqueue_and_get_inference(
             input_doc,
-            InferenceParameters("dummy-model", text_context="ignore this message"),
+            InferenceParameters(
+                "dummy-model",
+                text_context="ignore this message",
+                data_schema=DataSchema(override={"test_field": {}}),
+            ),
         )
 
diff --git a/tests/v2/test_client_integration.py b/tests/v2/test_client_integration.py
index 3330f7fb..4d26625c 100644
--- a/tests/v2/test_client_integration.py
+++ b/tests/v2/test_client_integration.py
@@ -5,6 +5,7 @@
 from mindee import ClientV2, InferenceParameters, PathInput, UrlInputSource
 from mindee.error.mindee_http_error_v2 import MindeeHTTPErrorV2
+from mindee.input.inference_parameters import DataSchema
 from mindee.parsing.v2.inference_response import InferenceResponse
 from tests.utils import FILE_TYPES_DIR, V2_PRODUCT_DATA_DIR
 
 
@@ -25,6 +26,22 @@ def v2_client() -> ClientV2:
     return ClientV2(api_key)
 
 
+def _basic_assert_success(
+    response: InferenceResponse, page_count: int, model_id: str
+) -> None:
+    assert response is not None
+    assert response.inference is not None
+
+    assert response.inference.file is not None
+    assert response.inference.file.page_count == page_count
+
+    assert response.inference.model is not None
+    assert response.inference.model.id == model_id
+
+    assert response.inference.result is not None
+    assert response.inference.active_options is not None
+
+
 @pytest.mark.integration
 @pytest.mark.v2
 def test_parse_file_empty_multiple_pages_must_succeed(
@@ -49,24 +66,15 @@ def test_parse_file_empty_multiple_pages_must_succeed(
     response: InferenceResponse = v2_client.enqueue_and_get_inference(
         input_source, params
     )
-    assert response is not None
-    assert response.inference is not None
+    _basic_assert_success(response=response, page_count=2, model_id=findoc_model_id)
 
-    assert response.inference.file is not None
     assert response.inference.file.name == "multipage_cut-2.pdf"
-    assert response.inference.file.page_count == 2
-
-    assert response.inference.model is not None
-    assert response.inference.model.id == findoc_model_id
 
-    assert response.inference.active_options is not None
     assert response.inference.active_options.rag is False
     assert response.inference.active_options.raw_text is True
     assert response.inference.active_options.polygon is False
     assert response.inference.active_options.confidence is False
 
-    assert response.inference.result is not None
-
     assert response.inference.result.raw_text is not None
     assert len(response.inference.result.raw_text.pages) == 2
 
@@ -93,24 +101,15 @@ def test_parse_file_empty_single_page_options_must_succeed(
     response: InferenceResponse = v2_client.enqueue_and_get_inference(
         input_source, params
     )
-    assert response is not None
-    assert response.inference is not None
-
-    assert response.inference.model is not None
-    assert response.inference.model.id == findoc_model_id
+    _basic_assert_success(response=response, page_count=1, model_id=findoc_model_id)
 
-    assert response.inference.file is not None
     assert response.inference.file.name == "blank_1.pdf"
-    assert response.inference.file.page_count == 1
 
-    assert response.inference.active_options is not None
     assert response.inference.active_options.rag is True
     assert response.inference.active_options.raw_text is True
     assert response.inference.active_options.polygon is True
     assert response.inference.active_options.confidence is True
 
-    assert response.inference.result is not None
-
 
 @pytest.mark.integration
 @pytest.mark.v2
@@ -137,18 +136,10 @@ def test_parse_file_filled_single_page_must_succeed(
     response: InferenceResponse = v2_client.enqueue_and_get_inference(
         input_source, params
     )
+    _basic_assert_success(response=response, page_count=1, model_id=findoc_model_id)
 
-    assert response is not None
-    assert response.inference is not None
-
-    assert response.inference.file is not None
     assert response.inference.file.name == "default_sample.jpg"
-    assert response.inference.file.page_count == 1
 
-    assert response.inference.model is not None
-    assert response.inference.model.id == findoc_model_id
-
-    assert response.inference.active_options is not None
     assert response.inference.active_options.rag is False
     assert response.inference.active_options.raw_text is False
     assert response.inference.active_options.polygon is False
@@ -156,7 +147,6 @@ def test_parse_file_filled_single_page_must_succeed(
 
     assert response.inference.result.raw_text is None
 
-    assert response.inference.result is not None
     supplier_name = response.inference.result.fields["supplier_name"]
     assert supplier_name is not None
     assert supplier_name.value == "John Smith"
@@ -266,15 +256,43 @@ def test_blank_url_input_source_must_succeed(
     response: InferenceResponse = v2_client.enqueue_and_get_inference(
         input_source, params
     )
-    assert response is not None
-    assert response.inference is not None
-
-    assert response.inference.file is not None
-    assert response.inference.file.page_count == 1
+    _basic_assert_success(response=response, page_count=1, model_id=findoc_model_id)
 
-    assert response.inference.model is not None
-    assert response.inference.model.id == findoc_model_id
-    assert response.inference.result is not None
 
-    assert response.inference.active_options is not None
-
+@pytest.mark.integration
+@pytest.mark.v2
+def test_data_schema_must_succeed(
+    v2_client: ClientV2,
+    findoc_model_id: str,
+) -> None:
+    """
+    Send a data schema override and make sure the inference call completes without raising any errors.
+    """
+    input_path: Path = FILE_TYPES_DIR / "pdf" / "blank_1.pdf"
+    input_source = PathInput(input_path)
+    params = InferenceParameters(
+        model_id=findoc_model_id,
+        rag=False,
+        raw_text=False,
+        polygon=False,
+        confidence=False,
+        webhook_ids=[],
+        data_schema=DataSchema(
+            override=[
+                {
+                    "name": "test",
+                    "title": "Test",
+                    "is_array": False,
+                    "type": "string",
+                    "description": "A test field",
+                }
+            ]
+        ),
+        alias="py_integration_data_schema_override",
+    )
+
+    response: InferenceResponse = v2_client.enqueue_and_get_inference(
+        input_source, params
+    )
+    _basic_assert_success(response=response, page_count=1, model_id=findoc_model_id)
+    assert response.inference.result.fields["test"] is not None
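
Usage note (not part of the diff): a minimal sketch of how the new `data_schema` parameter fits together with the existing V2 client calls used in the tests above. The API key, model ID, file path, and the `total_amount` field definition below are placeholder values, not anything shipped by the library.

```python
from mindee import ClientV2, InferenceParameters, PathInput
from mindee.input.inference_parameters import DataSchema

# Placeholder values -- substitute your own API key, model ID, and file path.
client = ClientV2("my-api-key")
input_source = PathInput("path/to/document.pdf")

params = InferenceParameters(
    model_id="my-model-id",
    # `override` holds a list of field definitions; the `override` setter
    # also accepts a JSON string that parses to one.
    data_schema=DataSchema(
        override=[
            {
                "name": "total_amount",
                "title": "Total Amount",
                "is_array": False,
                "type": "string",
                "description": "The total amount on the document.",
            }
        ]
    ),
)

# DataSchema.__str__ serializes the schema to '{"override": [...]}', which
# req_post_inference_enqueue sends as the `data_schema` form field.
response = client.enqueue_and_get_inference(input_source, params)
print(response.inference.result.fields["total_amount"])
```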