Skip to content

Commit ce3099a

Browse files
rsareddy0329Roja Reddy Sareddy
andauthored
Add serializer and deserializer in Endpoint configuration (#1680)
* Test intelligent defaults in sagemaker-core with utils module * Test intelligent defaults in sagemaker-core with utils module * Test intelligent defaults in sagemaker-core with utils module * Test intelligent defaults in sagemaker-core with utils module * Test intelligent defaults in sagemaker-core with utils module * Add serializer and deserializer in Endpoint configuration * Add serializer and deserializer in Endpoint configuration * Add serializer and deserializer in Endpoint configuration --------- Co-authored-by: Roja Reddy Sareddy <rsareddy@amazon.com>
1 parent 59bdfec commit ce3099a

File tree

16 files changed

+1922
-13
lines changed

16 files changed

+1922
-13
lines changed
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
from operator import index
2+
3+
import pytest
4+
from unittest.mock import Mock, patch
5+
import numpy as np
6+
from sagemaker.core.helper.session_helper import Session, get_execution_role
7+
from sagemaker.core.resources import (
8+
TrainingJob, Model, EndpointConfig, Endpoint,
9+
AlgorithmSpecification, Channel, DataSource, S3DataSource,
10+
OutputDataConfig, ResourceConfig, StoppingCondition
11+
)
12+
from sagemaker.core.shapes import InvokeEndpointOutput
13+
from sagemaker.core.shapes import ContainerDefinition, ProductionVariant
14+
from sagemaker.utils.base_serializers import CSVSerializer, JSONSerializer
15+
from sagemaker.utils.base_deserializers import CSVDeserializer, JSONDeserializer
16+
import pandas as pd
17+
import time
18+
import boto3
19+
20+
21+
@pytest.mark.integration
22+
class TestEndpointInvoke:
23+
24+
@pytest.fixture(scope="class")
25+
def sagemaker_session(self):
26+
"""Create a SageMaker session."""
27+
return Session()
28+
29+
@pytest.fixture(scope="class")
30+
def role(self):
31+
"""Get the execution role."""
32+
return get_execution_role()
33+
34+
@pytest.fixture(scope="class")
35+
def region(self, sagemaker_session):
36+
"""Get the AWS region."""
37+
return sagemaker_session.boto_region_name
38+
39+
@pytest.fixture(scope="class")
40+
def simple_data(self):
41+
# Create a very simple dataset with 5 rows
42+
# First column is target (0 or 1), followed by two features
43+
train_df = pd.DataFrame([
44+
[0, 1.0, 2.0], # Row 1
45+
[1, 2.0, 3.0], # Row 2
46+
[0, 1.5, 2.5], # Row 3
47+
], columns=['target', 'feature1', 'feature2'])
48+
49+
test_df = pd.DataFrame([
50+
[1, 2.5, 3.5], # Row 4
51+
[0, 1.2, 2.2], # Row 5
52+
], columns=['target', 'feature1', 'feature2'])
53+
54+
# Create version of test data without target
55+
test_df_no_target = test_df.drop('target', axis=1)
56+
57+
return {
58+
"train_data": train_df,
59+
"test_data": test_df,
60+
"test_data_no_target": test_df_no_target
61+
}
62+
63+
@pytest.fixture(scope="class")
64+
def training_resources(self, sagemaker_session, simple_data):
65+
# Set up S3 paths
66+
bucket = "sagemaker-us-west-2-913524917855"
67+
prefix = f"test-scikit-iris-{int(time.time())}"
68+
train_path = f"s3://{bucket}/{prefix}/train.csv"
69+
output_path = f"s3://{bucket}/{prefix}/output"
70+
71+
simple_data["train_data"].to_csv('train.csv', index=False, header=False)#.encode('utf-8')
72+
simple_data["test_data"].to_csv('test.csv', index=False, header=False)
73+
74+
# Upload training data
75+
sagemaker_session.upload_data(
76+
"train.csv",
77+
bucket=bucket,
78+
key_prefix=f"{prefix}"
79+
)
80+
81+
return {
82+
"bucket": bucket,
83+
"prefix": prefix,
84+
"train_path": train_path,
85+
"output_path": output_path
86+
}
87+
88+
89+
@pytest.fixture(scope="class")
90+
def endpoint(self, sagemaker_session, training_resources, role):
91+
# Create training job
92+
image = "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:latest"
93+
job_name = f"test-xgboost-{int(time.time())}"
94+
95+
training_job = TrainingJob.create(
96+
training_job_name=job_name,
97+
hyper_parameters={
98+
"objective": "multi:softmax",
99+
"num_class": "3",
100+
"num_round": "10",
101+
"eval_metric": "merror",
102+
},
103+
algorithm_specification=AlgorithmSpecification(
104+
training_image=image,
105+
training_input_mode="File"
106+
),
107+
role_arn=role,
108+
input_data_config=[
109+
Channel(
110+
channel_name="train",
111+
content_type="csv",
112+
data_source=DataSource(
113+
s3_data_source=S3DataSource(
114+
s3_data_type="S3Prefix",
115+
s3_uri=training_resources["train_path"],
116+
s3_data_distribution_type="FullyReplicated",
117+
)
118+
),
119+
)
120+
],
121+
output_data_config=OutputDataConfig(
122+
s3_output_path=training_resources["output_path"]
123+
),
124+
resource_config=ResourceConfig(
125+
instance_type="ml.m4.xlarge",
126+
instance_count=1,
127+
volume_size_in_gb=30
128+
),
129+
stopping_condition=StoppingCondition(max_runtime_in_seconds=600),
130+
)
131+
training_job.wait()
132+
133+
# Create model, endpoint config, and endpoint
134+
model_name = f"test-model-{int(time.time())}"
135+
model = Model.create(
136+
model_name=model_name,
137+
primary_container=ContainerDefinition(
138+
image=image,
139+
model_data_url=training_job.model_artifacts.s3_model_artifacts,
140+
),
141+
execution_role_arn=role,
142+
)
143+
144+
endpoint_config = EndpointConfig.create(
145+
endpoint_config_name=model_name,
146+
production_variants=[
147+
ProductionVariant(
148+
variant_name=model_name,
149+
initial_instance_count=1,
150+
instance_type="ml.m5.xlarge",
151+
model_name=model,
152+
)
153+
],
154+
)
155+
156+
endpoint = Endpoint.create(
157+
endpoint_name=model_name,
158+
endpoint_config_name=endpoint_config,
159+
)
160+
endpoint.wait_for_status("InService")
161+
162+
yield endpoint
163+
164+
# Cleanup
165+
endpoint.delete()
166+
endpoint_config.delete()
167+
model.delete()
168+
169+
def test_endpoint_invoke_with_serializers(self, endpoint, simple_data):
170+
# Test with serializer and deserializer
171+
serializer = CSVSerializer()
172+
deserializer = CSVDeserializer()
173+
174+
endpoint.serializer = serializer
175+
endpoint.deserializer = deserializer
176+
177+
178+
response = endpoint.invoke(
179+
body=simple_data["test_data_no_target"],
180+
content_type="text/csv",
181+
accept="text/csv"
182+
)
183+
184+
assert response is not None
185+
assert isinstance(response, InvokeEndpointOutput)
186+
assert hasattr(response, 'body')
187+
assert hasattr(response, 'content_type')
188+
189+
def test_endpoint_invoke_without_serializers(self, endpoint, simple_data):
190+
# Test without serializer and deserializer
191+
endpoint.serializer = None
192+
endpoint.deserializer = None
193+
194+
response = endpoint.invoke(
195+
body=simple_data["test_data_no_target"].to_csv(index=False, header=False),
196+
content_type="text/csv",
197+
accept="text/csv"
198+
)
199+
200+
assert response is not None
201+
assert isinstance(response, InvokeEndpointOutput)
202+
assert hasattr(response, 'body')
203+
assert hasattr(response, 'content_type')
204+
205+
def test_endpoint_invoke_with_invalid_serializer_config(self, endpoint):
206+
# Test with only serializer but no deserializer
207+
endpoint.serializer = CSVSerializer()
208+
endpoint.deserializer = None
209+
210+
with pytest.raises(ValueError) as exc_info:
211+
endpoint.invoke(
212+
body="test data",
213+
content_type="text/csv"
214+
)
215+
assert "Both serializer and deserializer must be provided together" in str(exc_info.value)
216+
217+

sagemaker-core/src/sagemaker/core/resources.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,15 @@
3232
from sagemaker.core.utils.logs import MultiLogStreamHandler
3333
from sagemaker.core.utils.exceptions import *
3434
from typing import ClassVar
35+
from sagemaker.utils.base_serializers import BaseSerializer
36+
from sagemaker.utils.base_deserializers import BaseDeserializer
3537

3638

3739
logger = get_textual_rich_logger(__name__)
3840

3941

4042
class Base(BaseModel):
41-
model_config = ConfigDict(protected_namespaces=(), validate_assignment=True, extra="forbid")
43+
model_config = ConfigDict(protected_namespaces=(), validate_assignment=True, extra="forbid", arbitrary_types_allowed=True)
4244
config_manager: ClassVar[SageMakerConfig] = SageMakerConfig()
4345

4446
@classmethod
@@ -8744,6 +8746,8 @@ class Endpoint(Base):
87448746
pending_deployment_summary: Optional[PendingDeploymentSummary] = Unassigned()
87458747
explainer_config: Optional[ExplainerConfig] = Unassigned()
87468748
shadow_production_variants: Optional[List[ProductionVariantSummary]] = Unassigned()
8749+
serializer: Optional[BaseSerializer] = None
8750+
deserializer: Optional[BaseDeserializer] = None
87478751

87488752
def get_name(self) -> str:
87498753
attributes = vars(self)
@@ -9318,6 +9322,14 @@ def invoke(
93189322
"""
93199323

93209324

9325+
use_serializer = False
9326+
if ((self.serializer is not None and self.deserializer is None) or
9327+
(self.serializer is None and self.deserializer is not None)):
9328+
raise ValueError("Both serializer and deserializer must be provided together, or neither should be provided")
9329+
if self.serializer is not None and self.deserializer is not None:
9330+
use_serializer = True
9331+
if use_serializer:
9332+
body = self.serializer.serialize(body)
93219333
operation_input_args = {
93229334
'EndpointName': self.endpoint_name,
93239335
'Body': body,
@@ -9343,6 +9355,11 @@ def invoke(
93439355
logger.debug(f"Response: {response}")
93449356

93459357
transformed_response = transform(response, 'InvokeEndpointOutput')
9358+
# Deserialize the body if a deserializer is provided
9359+
if use_serializer:
9360+
body_content = transformed_response["body"]
9361+
deserialized_body = self.deserializer.deserialize(body_content, transformed_response["content_type"])
9362+
transformed_response["body"] = deserialized_body
93469363
return InvokeEndpointOutput(**transformed_response)
93479364

93489365

sagemaker-core/src/sagemaker/core/tools/resources_codegen.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@
7474
RESOURCE_METHOD_EXCEPTION_DOCSTRING,
7575
INIT_WAIT_LOGS_TEMPLATE,
7676
PRINT_WAIT_LOGS,
77+
SERIALIZE_INPUT_ENDPOINT_TEMPLATE,
78+
DESERIALIZE_RESPONSE_ENDPOINT_TEMPLATE,
7779
)
7880
from sagemaker.core.tools.data_extractor import (
7981
load_combined_shapes_data,
@@ -195,6 +197,8 @@ def generate_imports(self) -> str:
195197
"from sagemaker.core.utils.logs import MultiLogStreamHandler",
196198
"from sagemaker.core.utils.exceptions import *",
197199
"from typing import ClassVar",
200+
"from sagemaker.utils.base_serializers import BaseSerializer",
201+
"from sagemaker.utils.base_deserializers import BaseDeserializer",
198202
]
199203

200204
formated_imports = "\n".join(imports)
@@ -518,6 +522,18 @@ def _get_class_attributes(self, resource_name: str, class_methods: list) -> tupl
518522
class_attributes_string = (
519523
class_attributes_string + "hub_name: Optional[str] = Unassigned()"
520524
)
525+
if resource_name == "Endpoint":
526+
class_attributes["serializer"] = "Optional[BaseSerializer] = None"
527+
class_attributes_string = class_attributes_string.replace("serializer: BaseSerializer", "")
528+
class_attributes_string = (
529+
class_attributes_string + "serializer: Optional[BaseSerializer] = None\n"
530+
)
531+
class_attributes["deserializer"] = "Optional[BaseDeserializer] = None"
532+
class_attributes_string = class_attributes_string.replace("deserializer: BaseDeserializer", "")
533+
class_attributes_string = (
534+
class_attributes_string + "deserializer: Optional[BaseDeserializer] = None\n"
535+
)
536+
521537

522538
return class_attributes, class_attributes_string, attributes_and_documentation
523539
elif "get_all" in class_methods:
@@ -1471,9 +1487,21 @@ def generate_method(self, method: Method, resource_attributes: list):
14711487
initialize_client = INITIALIZE_CLIENT_TEMPLATE.format(service_name=method.service_name)
14721488
if len(self.shapes[operation_input_shape_name]["members"]) != 0:
14731489
# the method has input arguments
1474-
serialize_operation_input = SERIALIZE_INPUT_TEMPLATE.format(
1475-
operation_input_args=operation_input_args
1476-
)
1490+
if method.resource_name == "Endpoint" and method.method_name == "invoke":
1491+
serialize_operation_input = SERIALIZE_INPUT_ENDPOINT_TEMPLATE.format(
1492+
operation_input_args=operation_input_args
1493+
)
1494+
return_type_conversion = method.return_type
1495+
operation_output_shape = operation_metadata["output"]["shape"]
1496+
deserialize_response = DESERIALIZE_RESPONSE_ENDPOINT_TEMPLATE.format(
1497+
operation_output_shape=operation_output_shape,
1498+
return_type_conversion=return_type_conversion,
1499+
)
1500+
1501+
else:
1502+
serialize_operation_input = SERIALIZE_INPUT_TEMPLATE.format(
1503+
operation_input_args=operation_input_args
1504+
)
14771505
call_operation_api = CALL_OPERATION_API_TEMPLATE.format(
14781506
operation=convert_to_snake_case(method.operation_name)
14791507
)

sagemaker-core/src/sagemaker/core/tools/templates.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ def {method_name}(
600600

601601
RESOURCE_BASE_CLASS_TEMPLATE = """
602602
class Base(BaseModel):
603-
model_config = ConfigDict(protected_namespaces=(), validate_assignment=True, extra="forbid")
603+
model_config = ConfigDict(protected_namespaces=(), validate_assignment=True, extra="forbid", arbitrary_types_allowed=True)
604604
config_manager: ClassVar[SageMakerConfig] = SageMakerConfig()
605605
606606
@classmethod
@@ -711,3 +711,28 @@ class {class_name}:
711711
error_message = e.response['Error']['Message']
712712
error_code = e.response['Error']['Code']
713713
```"""
714+
715+
SERIALIZE_INPUT_ENDPOINT_TEMPLATE = """
716+
use_serializer = False
717+
if ((self.serializer is not None and self.deserializer is None) or
718+
(self.serializer is None and self.deserializer is not None)):
719+
raise ValueError("Both serializer and deserializer must be provided together, or neither should be provided")
720+
if self.serializer is not None and self.deserializer is not None:
721+
use_serializer = True
722+
if use_serializer:
723+
body = self.serializer.serialize(body)
724+
operation_input_args = {{
725+
{operation_input_args}
726+
}}
727+
# serialize the input request
728+
operation_input_args = serialize(operation_input_args)
729+
logger.debug(f"Serialized input request: {{operation_input_args}}")"""
730+
731+
DESERIALIZE_RESPONSE_ENDPOINT_TEMPLATE = """
732+
transformed_response = transform(response, 'InvokeEndpointOutput')
733+
# Deserialize the body if a deserializer is provided
734+
if use_serializer:
735+
body_content = transformed_response["body"]
736+
deserialized_body = self.deserializer.deserialize(body_content, transformed_response["content_type"])
737+
transformed_response["body"] = deserialized_body
738+
return {return_type_conversion}(**transformed_response)"""

sagemaker-core/src/sagemaker/core/utils/code_injection/codec.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
LIST_TYPE,
2323
MAP_TYPE,
2424
)
25-
25+
from io import BytesIO
2626

2727
def pascal_to_snake(pascal_str):
2828
"""
@@ -244,6 +244,15 @@ def transform(data, shape, object_instance=None) -> dict:
244244
elif _member_type == MAP_TYPE:
245245
_map_type_shape = SHAPE_DAG[_member_shape]
246246
evaluated_value = _evaluate_map_type(data[_member_name], _map_type_shape)
247+
elif _member_type == 'blob':
248+
blob_data = data[_member_name]
249+
if isinstance(blob_data, bytes):
250+
evaluated_value = BytesIO(blob_data)
251+
elif hasattr(blob_data, 'read'):
252+
# If it's already a file-like object, use it as is
253+
evaluated_value = blob_data
254+
else:
255+
raise ValueError(f"Unexpected blob data type: {type(blob_data)}")
247256
else:
248257
raise ValueError(f"Unexpected member type encountered: {_member_type}")
249258

0 commit comments

Comments
 (0)