Skip to content

Commit d939202

Browse files
authored
change: task 3 remove dependency on predictor (#1725)
* add the remaining default serializers and remove commented-out code * entirely remove _get_predictor() * test: add unit test
1 parent 3f95c12 commit d939202

File tree

3 files changed

+97
-51
lines changed

3 files changed

+97
-51
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,4 @@ env/
3535
**/_repack_script_launcher.sh
3636
sagemaker_train/src/**/container_drivers/sm_train.sh
3737
sagemaker_train/src/**/container_drivers/sourcecode.json
38-
sagemaker_train/src/**/container_drivers/distributed.json
38+
sagemaker_train/src/**/container_drivers/distributed.json

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 78 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030
from sagemaker.batch_inference.batch_transform_inference_config import \
3131
BatchTransformInferenceConfig
3232
from sagemaker.compute_resource_requirements import ResourceRequirements
33-
from sagemaker.deserializers import JSONDeserializer, TorchTensorDeserializer
33+
from sagemaker.deserializers import (CSVDeserializer, JSONDeserializer,
34+
RecordDeserializer,
35+
TorchTensorDeserializer,
36+
NumpyDeserializer)
3437
from sagemaker.enums import EndpointType, Tag
3538
from sagemaker.estimator import Estimator
3639
from sagemaker.huggingface.llm_utils import (
@@ -43,7 +46,9 @@
4346
from sagemaker.modules.train import ModelTrainer
4447
from sagemaker.predictor import Predictor
4548
from sagemaker.s3 import S3Downloader
46-
from sagemaker.serializers import NumpySerializer, TorchTensorSerializer
49+
from sagemaker.serializers import (LibSVMSerializer, NumpySerializer,
50+
RecordSerializer, TorchTensorSerializer,
51+
JSONSerializer)
4752
from sagemaker.serve.builder.djl_builder import DJL
4853
from sagemaker.serve.builder.jumpstart_builder import JumpStart
4954
from sagemaker.serve.builder.schema_builder import SchemaBuilder
@@ -116,6 +121,21 @@
116121
ModelServer.SMD,
117122
}
118123

124+
# Default serializers and deserializers by framework
125+
DEFAULT_SERIALIZERS_BY_FRAMEWORK = {
126+
"XGBoost": (LibSVMSerializer(), CSVDeserializer()),
127+
"LDA": (RecordSerializer(), RecordDeserializer()),
128+
"PyTorch": (TorchTensorSerializer(), JSONDeserializer()),
129+
"TensorFlow": (NumpySerializer(), JSONDeserializer()),
130+
"MXNet": (RecordSerializer(), JSONDeserializer()), # MxNetPredictor
131+
"Chainer": (NumpySerializer(), JSONDeserializer()), # ChainerPredictor
132+
"SKLearn": (NumpySerializer(), NumpyDeserializer()), # SKLearnPredictor
133+
"HuggingFace": (JSONSerializer(), JSONDeserializer()), # HuggingFacePredictor
134+
"DJL": (JSONSerializer(), JSONDeserializer()), # DJLPredictor
135+
"SparkML": (NumpySerializer(), JSONDeserializer()), # SparkMLPredictor
136+
"NTM": (RecordSerializer(), JSONDeserializer()), # NTMPredictor
137+
}
138+
119139

120140
# pylint: disable=attribute-defined-outside-init, disable=E1101, disable=R0901, disable=R1705
121141
@dataclass
@@ -465,9 +485,27 @@ def _prepare_for_mode(
465485
% (Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT, Mode.IN_PROCESS)
466486
)
467487

488+
def _fetch_serializer_and_deserializer_for_framework(self, framework: str):
489+
"""Fetch the default serializer and deserializer for a given framework.
490+
491+
Args:
492+
framework (str): The framework name.
493+
494+
Returns:
495+
tuple: A tuple containing (serializer, deserializer).
496+
"""
497+
if framework in DEFAULT_SERIALIZERS_BY_FRAMEWORK:
498+
return DEFAULT_SERIALIZERS_BY_FRAMEWORK[framework]
499+
500+
# Default to NumPy serialization and JSON deserialization if framework not found
501+
return NumpySerializer(), JSONDeserializer()
502+
468503
def _get_client_translators(self):
469-
"""Placeholder docstring"""
504+
"""Get serializer and deserializer for client-side translation."""
470505
serializer = None
506+
deserializer = None
507+
508+
# Prefer an explicitly provided content_type / accept_type; otherwise fall back to the SchemaBuilder translators
471509
if self.content_type == "application/x-npy":
472510
serializer = NumpySerializer()
473511
elif self.content_type == "tensor/pt":
@@ -476,10 +514,7 @@ def _get_client_translators(self):
476514
serializer = self.schema_builder.custom_input_translator
477515
elif self.schema_builder:
478516
serializer = self.schema_builder.input_serializer
479-
else:
480-
raise Exception("Cannot serialize. Try providing a SchemaBuilder if not present.")
481517

482-
deserializer = None
483518
if self.accept_type == "application/json":
484519
deserializer = JSONDeserializer()
485520
elif self.accept_type == "tensor/pt":
@@ -488,58 +523,45 @@ def _get_client_translators(self):
488523
deserializer = self.schema_builder.custom_output_translator
489524
elif self.schema_builder:
490525
deserializer = self.schema_builder.output_deserializer
491-
else:
492-
raise Exception("Cannot deserialize. Try providing a SchemaBuilder if not present.")
493526

494-
return serializer, deserializer
527+
# If either the serializer or the deserializer is still None, fall back to the framework defaults
528+
if (serializer is None or deserializer is None) and hasattr(self, "_framework"):
529+
default_serializer, default_deserializer = self._fetch_serializer_and_deserializer_for_framework(
530+
self._framework
531+
)
532+
if serializer is None:
533+
serializer = default_serializer
534+
if deserializer is None:
535+
deserializer = default_deserializer
495536

496-
def _get_predictor(
497-
self, endpoint_name: str, sagemaker_session: Session, component_name: Optional[str] = None
498-
) -> Predictor:
499-
"""Placeholder docstring"""
500-
serializer, deserializer = self._get_client_translators()
537+
# If still None, raise an exception
538+
if serializer is None:
539+
raise Exception("Cannot serialize. Try providing a SchemaBuilder if not present.")
540+
if deserializer is None:
541+
raise Exception("Cannot deserialize. Try providing a SchemaBuilder if not present.")
501542

502-
return Predictor(
503-
endpoint_name=endpoint_name,
504-
sagemaker_session=sagemaker_session,
505-
serializer=serializer,
506-
deserializer=deserializer,
507-
component_name=component_name,
508-
)
543+
return serializer, deserializer
509544

510545
def _create_model(self):
511-
"""Placeholder docstring"""
512-
# TODO: we should create model as per the framework
513-
self.pysdk_model = Model(
546+
"""Create a sagemaker-core Model instance."""
547+
from sagemaker_core.resources import Model as CoreModel
548+
549+
# Create the sagemaker-core Model
550+
core_model = CoreModel(
514551
image_uri=self.image_uri,
515552
image_config=self.image_config,
516553
vpc_config=self.vpc_config,
517554
model_data=self.s3_upload_path,
518555
role=self.serve_settings.role_arn,
519556
env=self.env_vars,
520557
sagemaker_session=self.sagemaker_session,
521-
predictor_cls=self._get_predictor,
522-
name=self.name,
558+
name=self.name
523559
)
524560

525-
# store the modes in the model so that we may
526-
# reference the configurations for local deploy() & predict()
527-
self.pysdk_model.mode = self.mode
528-
self.pysdk_model.modes = self.modes
529-
self.pysdk_model.serve_settings = self.serve_settings
530-
if self.role_arn:
531-
self.pysdk_model.role = self.role_arn
532-
if self.sagemaker_session:
533-
self.pysdk_model.sagemaker_session = self.sagemaker_session
534-
535-
# dynamically generate a method to direct model.deploy() logic based on mode
536-
# unique method to models created via ModelBuilder()
537-
self._original_deploy = self.pysdk_model.deploy
538-
self.pysdk_model.deploy = self._model_builder_deploy_wrapper
539-
self._original_register = self.pysdk_model.register
540-
self.pysdk_model.register = self._model_builder_register_wrapper
541-
self.model_package = None
542-
return self.pysdk_model
561+
# Store the client-side serializer and deserializer for later use in deploy()
562+
self._serializer, self._deserializer = self._get_client_translators()
563+
564+
return core_model
543565

544566
@_capture_telemetry("register")
545567
def _model_builder_register_wrapper(self, *args, **kwargs):
@@ -766,8 +788,8 @@ def _build_for_torchserve(self) -> Type[Model]:
766788
)
767789

768790
self._prepare_for_mode()
769-
self.model = self._create_model()
770-
return self.model
791+
model = self._create_model()
792+
return model
771793

772794
def _build_for_smd(self) -> Type[Model]:
773795
"""Build the model for SageMaker Distribution"""
@@ -784,8 +806,8 @@ def _build_for_smd(self) -> Type[Model]:
784806
)
785807

786808
self._prepare_for_mode()
787-
self.model = self._create_model()
788-
return self.model
809+
model = self._create_model()
810+
return model
789811

790812
def _user_agent_decorator(self, func):
791813
"""Placeholder docstring"""
@@ -1204,7 +1226,12 @@ def _build_single_modelbuilder( # pylint: disable=R0911
12041226
12051227
Returns:
12061228
Type[Model]: A deployable ``Model`` object.
1229+
1230+
.. note::
1231+
In a future version, this method will return a sagemaker-core Model object instead.
12071232
"""
1233+
# Store serializers/deserializers for later use
1234+
self._serializer, self._deserializer = self._get_client_translators()
12081235

12091236
self.modes = dict()
12101237

@@ -1314,8 +1341,9 @@ def _build_single_modelbuilder( # pylint: disable=R0911
13141341
# Set TorchServe as default model server
13151342
if not self.model_server:
13161343
self.model_server = ModelServer.TORCHSERVE
1317-
self.built_model = self._build_for_torchserve()
1318-
return self.built_model
1344+
model = self._build_for_torchserve()
1345+
self.built_model = model
1346+
return model
13191347

13201348
raise ValueError("%s model server is not supported" % self.model_server)
13211349

sagemaker-serve/tests/unit/test_model_builder.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from sagemaker.enums import EndpointType
2929
from sagemaker.model import Model
3030
from sagemaker.serve import SchemaBuilder
31+
# TODO: Target the v3 model builder; this is currently not possible because of the patched MagicMocks.
3132
from sagemaker.serve.builder.model_builder import ModelBuilder
3233
from sagemaker.serve.mode.function_pointers import Mode
3334
from sagemaker.serve.model_format.mlflow.constants import MLFLOW_TRACKING_ARN
@@ -39,6 +40,9 @@
3940
_validate_optimization_configuration
4041
from sagemaker.serverless.serverless_inference_config import \
4142
ServerlessInferenceConfig
43+
from sagemaker.serve.model_builder import DEFAULT_SERIALIZERS_BY_FRAMEWORK
44+
from sagemaker.base_deserializers import JSONDeserializer
45+
from sagemaker.base_serializers import NumpySerializer
4246

4347
schema_builder = MagicMock()
4448
mock_inference_spec = Mock()
@@ -3439,6 +3443,20 @@ def test_optimize_with_gpu_instance_and_compilation_with_speculative_decoding(
34393443
),
34403444
)
34413445

3446+
def test_fetch_serializer_and_deserializer_for_framework(self):
3447+
"""Test that _fetch_serializer_and_deserializer_for_framework returns the correct serializer/deserializer pairs."""
3448+
builder = ModelBuilder()
3449+
3450+
# Test for known frameworks
3451+
for framework, expected_pair in DEFAULT_SERIALIZERS_BY_FRAMEWORK.items():
3452+
serializer, deserializer = builder._fetch_serializer_and_deserializer_for_framework(framework)
3453+
self.assertEqual(type(serializer), type(expected_pair[0]))
3454+
self.assertEqual(type(deserializer), type(expected_pair[1]))
3455+
3456+
# Test for unknown framework - should return default (NumpySerializer, JSONDeserializer)
3457+
serializer, deserializer = builder._fetch_serializer_and_deserializer_for_framework("UnknownFramework")
3458+
self.assertIsInstance(serializer, NumpySerializer)
3459+
self.assertIsInstance(deserializer, JSONDeserializer)
34423460

34433461
class TestModelBuilderOptimizationSharding(unittest.TestCase):
34443462
@patch.object(ModelBuilder, "_prepare_for_mode")

0 commit comments

Comments
 (0)