Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sagemaker-core/src/sagemaker/core/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def pascal_to_snake(pascal_str):


def is_not_primitive(obj):
return not isinstance(obj, (int, float, str, bool, datetime.datetime))
return not isinstance(obj, (int, float, str, bool, datetime.datetime, bytes))


def is_not_str_dict(obj):
Expand Down
5 changes: 5 additions & 0 deletions sagemaker-serve/src/sagemaker/serve/builder/schema_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,11 @@ def _get_deserializer(self, obj):
return StringDeserializer()
if _is_jsonable(obj):
return JSONDeserializer()
if isinstance(obj, dict) and "content_type" in obj:
try:
return BytesDeserializer()
except ValueError as e:
logger.error(e)

raise ValueError(
(
Expand Down
13 changes: 13 additions & 0 deletions sagemaker-serve/src/sagemaker/serve/model_builder_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,20 @@ def _build_for_transformers(self) -> Model:
hf_model_id, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
)
elif isinstance(self.model, str): # Only set HF_MODEL_ID if model is a string
# Get model metadata for task detection (same pattern as _build_for_triton)
hf_model_md = self.get_huggingface_model_metadata(
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
)
model_task = hf_model_md.get("pipeline_tag")
if model_task:
self.env_vars.update({"HF_TASK": model_task})

self.env_vars.update({"HF_MODEL_ID": self.model})

# Add HuggingFace token if available (same as other methods)
if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"):
self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN")

# Get HF config for string model IDs
if hasattr(self.env_vars, "HF_API_TOKEN"):
self.hf_model_config = _get_model_config_properties_from_hf(
Expand Down
6 changes: 6 additions & 0 deletions sagemaker-serve/src/sagemaker/serve/model_builder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,12 @@ def _hf_schema_builder_init(self, model_task: str) -> None:
sample_outputs,
) = remote_hf_schema_helper.get_resolved_hf_schema_for_task(model_task)

# Unwrap list outputs for binary tasks (text-to-image, audio, etc.)
# Remote schema retriever returns [{'data': b'...', 'content_type': '...'}]
# but SchemaBuilder expects {'data': b'...', 'content_type': '...'}
if isinstance(sample_outputs, list) and len(sample_outputs) > 0:
sample_outputs = sample_outputs[0]

self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs)

except ValueError as e:
Expand Down
Loading