diff --git a/sagemaker-core/src/sagemaker/core/utils/utils.py b/sagemaker-core/src/sagemaker/core/utils/utils.py index 909e8463a9..163b3b5354 100644 --- a/sagemaker-core/src/sagemaker/core/utils/utils.py +++ b/sagemaker-core/src/sagemaker/core/utils/utils.py @@ -273,7 +273,7 @@ def pascal_to_snake(pascal_str): def is_not_primitive(obj): - return not isinstance(obj, (int, float, str, bool, datetime.datetime)) + return not isinstance(obj, (int, float, str, bool, datetime.datetime, bytes)) def is_not_str_dict(obj): diff --git a/sagemaker-serve/src/sagemaker/serve/builder/schema_builder.py b/sagemaker-serve/src/sagemaker/serve/builder/schema_builder.py index d68c2bffe1..faa8066d52 100644 --- a/sagemaker-serve/src/sagemaker/serve/builder/schema_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/builder/schema_builder.py @@ -196,6 +196,11 @@ def _get_deserializer(self, obj): return StringDeserializer() if _is_jsonable(obj): return JSONDeserializer() + if isinstance(obj, dict) and "content_type" in obj: + try: + return BytesDeserializer() + except ValueError as e: + logger.error(e) raise ValueError( ( diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py index 831c37ee14..e51dca5e6e 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py @@ -687,7 +687,20 @@ def _build_for_transformers(self) -> Model: hf_model_id, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) elif isinstance(self.model, str): # Only set HF_MODEL_ID if model is a string + # Get model metadata for task detection + hf_model_md = self.get_huggingface_model_metadata( + self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") + ) + model_task = hf_model_md.get("pipeline_tag") + if model_task: + self.env_vars.update({"HF_TASK": model_task}) + self.env_vars.update({"HF_MODEL_ID": self.model}) + + # Add HuggingFace token if available + if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"): + self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN") + # Get HF config for string model IDs if hasattr(self.env_vars, "HF_API_TOKEN"): self.hf_model_config = _get_model_config_properties_from_hf( diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index d486f24acb..b483a1b163 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -992,6 +992,12 @@ def _hf_schema_builder_init(self, model_task: str) -> None: sample_outputs, ) = remote_hf_schema_helper.get_resolved_hf_schema_for_task(model_task) + # Unwrap list outputs for binary tasks (text-to-image, audio, etc.) + # Remote schema retriever returns [{'data': b'...', 'content_type': '...'}] + # but SchemaBuilder expects {'data': b'...', 'content_type': '...'} + if isinstance(sample_outputs, list) and len(sample_outputs) > 0: + sample_outputs = sample_outputs[0] + self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs) except ValueError as e: