3030from sagemaker .batch_inference .batch_transform_inference_config import \
3131 BatchTransformInferenceConfig
3232from sagemaker .compute_resource_requirements import ResourceRequirements
33- from sagemaker .deserializers import JSONDeserializer , TorchTensorDeserializer
33+ from sagemaker .deserializers import (CSVDeserializer , JSONDeserializer ,
34+ RecordDeserializer ,
35+ TorchTensorDeserializer ,
36+ NumpyDeserializer )
3437from sagemaker .enums import EndpointType , Tag
3538from sagemaker .estimator import Estimator
3639from sagemaker .huggingface .llm_utils import (
4346from sagemaker .modules .train import ModelTrainer
4447from sagemaker .predictor import Predictor
4548from sagemaker .s3 import S3Downloader
46- from sagemaker .serializers import NumpySerializer , TorchTensorSerializer
49+ from sagemaker .serializers import (LibSVMSerializer , NumpySerializer ,
50+ RecordSerializer , TorchTensorSerializer ,
51+ JSONSerializer )
4752from sagemaker .serve .builder .djl_builder import DJL
4853from sagemaker .serve .builder .jumpstart_builder import JumpStart
4954from sagemaker .serve .builder .schema_builder import SchemaBuilder
116121 ModelServer .SMD ,
117122}
118123
# Default (serializer, deserializer) pair for each framework name.
#
# Keys are plain strings, so a lookup must match the spelling and case used
# here exactly. Each value is a 2-tuple whose first element is the request
# serializer and whose second element is the response deserializer.
#
# The serializer/deserializer objects are instantiated once, when this module
# is imported, and the same instances are handed out on every lookup.
#
# NOTE(review): the trailing ``# XxxPredictor`` comments presumably name the
# framework predictor class each pair was modeled on -- confirm that every
# pair really matches that predictor's own defaults.
DEFAULT_SERIALIZERS_BY_FRAMEWORK = {
    "XGBoost": (LibSVMSerializer(), CSVDeserializer()),
    "LDA": (RecordSerializer(), RecordDeserializer()),
    "PyTorch": (TorchTensorSerializer(), JSONDeserializer()),
    "TensorFlow": (NumpySerializer(), JSONDeserializer()),
    "MXNet": (RecordSerializer(), JSONDeserializer()),  # MxNetPredictor
    "Chainer": (NumpySerializer(), JSONDeserializer()),  # ChainerPredictor
    "SKLearn": (NumpySerializer(), NumpyDeserializer()),  # SKLearnPredictor
    "HuggingFace": (JSONSerializer(), JSONDeserializer()),  # HuggingFacePredictor
    "DJL": (JSONSerializer(), JSONDeserializer()),  # DJLPredictor
    "SparkML": (NumpySerializer(), JSONDeserializer()),  # SparkMLPredictor
    "NTM": (RecordSerializer(), JSONDeserializer()),  # NTMPredictor
}

138+
119139
120140# pylint: disable=attribute-defined-outside-init, disable=E1101, disable=R0901, disable=R1705
121141@dataclass
@@ -465,9 +485,27 @@ def _prepare_for_mode(
465485 % (Mode .LOCAL_CONTAINER , Mode .SAGEMAKER_ENDPOINT , Mode .IN_PROCESS )
466486 )
467487
def _fetch_serializer_and_deserializer_for_framework(self, framework: str):
    """Fetch the default serializer and deserializer for a given framework.

    Args:
        framework (str): The framework name. It is looked up as-is in
            ``DEFAULT_SERIALIZERS_BY_FRAMEWORK``, so it must match one of
            that mapping's keys exactly to get a framework-specific pair.

    Returns:
        tuple: A tuple containing (serializer, deserializer). For a known
        framework this is the pair stored in
        ``DEFAULT_SERIALIZERS_BY_FRAMEWORK`` (the stored objects themselves,
        not copies). For any other value a freshly constructed
        ``NumpySerializer`` and ``JSONDeserializer`` are returned.
    """
    try:
        return DEFAULT_SERIALIZERS_BY_FRAMEWORK[framework]
    except KeyError:
        # Framework not found: fall back to NumPy serialization for requests
        # and JSON deserialization for responses.
        return NumpySerializer(), JSONDeserializer()

502+
468503 def _get_client_translators (self ):
469- """Placeholder docstring """
504+ """Get serializer and deserializer for client-side translation. """
470505 serializer = None
506+ deserializer = None
507+
508+ # If content_type or accept_type are explicitly provided, use those
471509 if self .content_type == "application/x-npy" :
472510 serializer = NumpySerializer ()
473511 elif self .content_type == "tensor/pt" :
@@ -476,10 +514,7 @@ def _get_client_translators(self):
476514 serializer = self .schema_builder .custom_input_translator
477515 elif self .schema_builder :
478516 serializer = self .schema_builder .input_serializer
479- else :
480- raise Exception ("Cannot serialize. Try providing a SchemaBuilder if not present." )
481517
482- deserializer = None
483518 if self .accept_type == "application/json" :
484519 deserializer = JSONDeserializer ()
485520 elif self .accept_type == "tensor/pt" :
@@ -488,58 +523,45 @@ def _get_client_translators(self):
488523 deserializer = self .schema_builder .custom_output_translator
489524 elif self .schema_builder :
490525 deserializer = self .schema_builder .output_deserializer
491- else :
492- raise Exception ("Cannot deserialize. Try providing a SchemaBuilder if not present." )
493526
494- return serializer , deserializer
527+ # If serializer or deserializer are still None, try to infer from framework
528+ if (serializer is None or deserializer is None ) and hasattr (self , "_framework" ):
529+ default_serializer , default_deserializer = self ._fetch_serializer_and_deserializer_for_framework (
530+ self ._framework
531+ )
532+ if serializer is None :
533+ serializer = default_serializer
534+ if deserializer is None :
535+ deserializer = default_deserializer
495536
496- def _get_predictor (
497- self , endpoint_name : str , sagemaker_session : Session , component_name : Optional [ str ] = None
498- ) -> Predictor :
499- """Placeholder docstring"""
500- serializer , deserializer = self . _get_client_translators ( )
537+ # If still None, raise an exception
538+ if serializer is None :
539+ raise Exception ( "Cannot serialize. Try providing a SchemaBuilder if not present." )
540+ if deserializer is None :
541+ raise Exception ( "Cannot deserialize. Try providing a SchemaBuilder if not present." )
501542
502- return Predictor (
503- endpoint_name = endpoint_name ,
504- sagemaker_session = sagemaker_session ,
505- serializer = serializer ,
506- deserializer = deserializer ,
507- component_name = component_name ,
508- )
543+ return serializer , deserializer
509544
def _create_model(self):
    """Build a sagemaker-core ``Model`` from this builder's settings.

    The model is constructed from the builder's image, VPC, model-data,
    role, environment, session and name attributes. As a side effect, the
    client-side translators returned by ``_get_client_translators()`` are
    stored on ``self._serializer`` and ``self._deserializer`` so they can be
    used later (e.g. at deploy time).

    Returns:
        The newly constructed sagemaker-core ``Model`` instance.

    Raises:
        Any exception propagated from the ``Model`` constructor or from
        ``_get_client_translators()``.
    """
    # Imported here rather than at module level, as in the original code.
    from sagemaker_core.resources import Model as CoreModel

    # NOTE(review): these keyword names are forwarded unchanged -- confirm
    # that the sagemaker-core Model constructor accepts all of them.
    model_kwargs = {
        "image_uri": self.image_uri,
        "image_config": self.image_config,
        "vpc_config": self.vpc_config,
        "model_data": self.s3_upload_path,
        "role": self.serve_settings.role_arn,
        "env": self.env_vars,
        "sagemaker_session": self.sagemaker_session,
        "name": self.name,
    }
    core_model = CoreModel(**model_kwargs)

    # Remember the translators on the builder for later use in deploy().
    translators = self._get_client_translators()
    self._serializer, self._deserializer = translators

    return core_model

543565
544566 @_capture_telemetry ("register" )
545567 def _model_builder_register_wrapper (self , * args , ** kwargs ):
@@ -766,8 +788,8 @@ def _build_for_torchserve(self) -> Type[Model]:
766788 )
767789
768790 self ._prepare_for_mode ()
769- self . model = self ._create_model ()
770- return self . model
791+ model = self ._create_model ()
792+ return model
771793
772794 def _build_for_smd (self ) -> Type [Model ]:
773795 """Build the model for SageMaker Distribution"""
@@ -784,8 +806,8 @@ def _build_for_smd(self) -> Type[Model]:
784806 )
785807
786808 self ._prepare_for_mode ()
787- self . model = self ._create_model ()
788- return self . model
809+ model = self ._create_model ()
810+ return model
789811
790812 def _user_agent_decorator (self , func ):
791813 """Placeholder docstring"""
@@ -1204,7 +1226,12 @@ def _build_single_modelbuilder( # pylint: disable=R0911
12041226
12051227 Returns:
12061228 Type[Model]: A deployable ``Model`` object.
1229+
1230+ .. note::
1231+ In a future version, this method will return a sagemaker-core Model object instead.
12071232 """
1233+ # Store serializers/deserializers for later use
1234+ self ._serializer , self ._deserializer = self ._get_client_translators ()
12081235
12091236 self .modes = dict ()
12101237
@@ -1314,8 +1341,9 @@ def _build_single_modelbuilder( # pylint: disable=R0911
13141341 # Set TorchServe as default model server
13151342 if not self .model_server :
13161343 self .model_server = ModelServer .TORCHSERVE
1317- self .built_model = self ._build_for_torchserve ()
1318- return self .built_model
1344+ model = self ._build_for_torchserve ()
1345+ self .built_model = model
1346+ return model
13191347
13201348 raise ValueError ("%s model server is not supported" % self .model_server )
13211349
0 commit comments