From 61c7d1b1bb7680a179e6ae2cb9f9b98379b8dd10 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Fri, 5 Dec 2025 14:05:12 -0800 Subject: [PATCH 1/2] Bug fixes for HF models --- .../src/sagemaker/serve/builder/schema_builder.py | 5 +++++ .../src/sagemaker/serve/model_builder_servers.py | 13 +++++++++++++ .../src/sagemaker/serve/model_builder_utils.py | 6 ++++++ sagemaker-serve/tests/integ/test_tei_integration.py | 2 +- 4 files changed, 25 insertions(+), 1 deletion(-) diff --git a/sagemaker-serve/src/sagemaker/serve/builder/schema_builder.py b/sagemaker-serve/src/sagemaker/serve/builder/schema_builder.py index d68c2bffe1..faa8066d52 100644 --- a/sagemaker-serve/src/sagemaker/serve/builder/schema_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/builder/schema_builder.py @@ -196,6 +196,11 @@ def _get_deserializer(self, obj): return StringDeserializer() if _is_jsonable(obj): return JSONDeserializer() + if isinstance(obj, dict) and "content_type" in obj: + try: + return BytesDeserializer() + except ValueError as e: + logger.error(e) raise ValueError( ( diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py index 831c37ee14..bd2c9a70a2 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py @@ -687,7 +687,20 @@ def _build_for_transformers(self) -> Model: hf_model_id, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) elif isinstance(self.model, str): # Only set HF_MODEL_ID if model is a string + # Get model metadata for task detection (same pattern as _build_for_triton) + hf_model_md = self.get_huggingface_model_metadata( + self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") + ) + model_task = hf_model_md.get("pipeline_tag") + if model_task: + self.env_vars.update({"HF_TASK": model_task}) + self.env_vars.update({"HF_MODEL_ID": self.model}) + + # Add HuggingFace token if available (same as other methods) + if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"): + self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN") + # Get HF config for string model IDs if hasattr(self.env_vars, "HF_API_TOKEN"): self.hf_model_config = _get_model_config_properties_from_hf( diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index 1c3016cf86..83e44a56c4 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -992,6 +992,12 @@ def _hf_schema_builder_init(self, model_task: str) -> None: sample_outputs, ) = remote_hf_schema_helper.get_resolved_hf_schema_for_task(model_task) + # Unwrap list outputs for binary tasks (text-to-image, audio, etc.) + # Remote schema retriever returns [{'data': b'...', 'content_type': '...'}] + # but SchemaBuilder expects {'data': b'...', 'content_type': '...'} + if isinstance(sample_outputs, list) and len(sample_outputs) > 0: + sample_outputs = sample_outputs[0] + self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs) except ValueError as e: diff --git a/sagemaker-serve/tests/integ/test_tei_integration.py b/sagemaker-serve/tests/integ/test_tei_integration.py index 5f4107213d..f91f0a18a6 100644 --- a/sagemaker-serve/tests/integ/test_tei_integration.py +++ b/sagemaker-serve/tests/integ/test_tei_integration.py @@ -104,7 +104,7 @@ def build_and_deploy(): core_endpoint = model_builder.deploy( endpoint_name=f"{ENDPOINT_NAME_PREFIX}-{unique_id}", - initial_instance_count=1 + initial_instance_count=1, ) logger.info(f"Endpoint Successfully Created: {core_endpoint.endpoint_name}") From 7ee9cd671953d1585ee1ef4eb1a01ddf64b41396 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 17 Dec 2025 08:49:54 -0800 Subject: [PATCH 2/2] Fix serialization deserialization issues in core --- sagemaker-core/src/sagemaker/core/training/configs.py | 7 ++++--- sagemaker-core/src/sagemaker/core/utils/utils.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/training/configs.py b/sagemaker-core/src/sagemaker/core/training/configs.py index 9a712cb19a..ad2232d630 100644 --- a/sagemaker-core/src/sagemaker/core/training/configs.py +++ b/sagemaker-core/src/sagemaker/core/training/configs.py @@ -257,15 +257,16 @@ class InputData(BaseConfig): Parameters: channel_name (StrPipeVar): The name of the input data source channel. - data_source (Union[str, S3DataSource, FileSystemDataSource, DatasetSource]): + data_source (Union[StrPipeVar, S3DataSource, FileSystemDataSource, DatasetSource]): The data source for the channel. Can be an S3 URI string, local file path string, - S3DataSource object, or FileSystemDataSource object. + S3DataSource object, FileSystemDataSource object, DatasetSource object, or a + pipeline variable (Properties) from a previous step. content_type (StrPipeVar): The MIME type of the data. """ channel_name: StrPipeVar = None - data_source: Union[str, FileSystemDataSource, S3DataSource, DatasetSource] = None + data_source: Union[StrPipeVar, FileSystemDataSource, S3DataSource, DatasetSource] = None content_type: StrPipeVar = None diff --git a/sagemaker-core/src/sagemaker/core/utils/utils.py b/sagemaker-core/src/sagemaker/core/utils/utils.py index 909e8463a9..163b3b5354 100644 --- a/sagemaker-core/src/sagemaker/core/utils/utils.py +++ b/sagemaker-core/src/sagemaker/core/utils/utils.py @@ -273,7 +273,7 @@ def pascal_to_snake(pascal_str): def is_not_primitive(obj): - return not isinstance(obj, (int, float, str, bool, datetime.datetime)) + return not isinstance(obj, (int, float, str, bool, datetime.datetime, bytes)) def is_not_str_dict(obj):