diff --git a/sagemaker-serve/tests/integ/test_train_inference_e2e_integration.py b/sagemaker-serve/tests/integ/test_train_inference_e2e_integration.py index 8dd09ac2a0..7b98345366 100644 --- a/sagemaker-serve/tests/integ/test_train_inference_e2e_integration.py +++ b/sagemaker-serve/tests/integ/test_train_inference_e2e_integration.py @@ -25,6 +25,7 @@ from sagemaker.train.model_trainer import ModelTrainer from sagemaker.train.configs import SourceCode from sagemaker.core.resources import EndpointConfig +from sagemaker.core.helper.session_helper import Session logger = logging.getLogger(__name__) @@ -37,6 +38,8 @@ AWS_REGION = "us-west-2" PYTORCH_TRAINING_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.13.1-cpu-py39" +sagemaker_session = Session() + @pytest.mark.slow_test def test_train_inference_e2e_build_deploy_invoke_cleanup(): @@ -143,13 +146,6 @@ def create_schema_builder(): def train_model(): """Train model using ModelTrainer.""" - from sagemaker.core.helper.session_helper import Session - import boto3 - - # Create SageMaker session with AWS region - boto_session = boto3.Session(region_name=AWS_REGION) - sagemaker_session = Session(boto_session=boto_session) - training_code_dir = create_pytorch_training_code() unique_id = str(uuid.uuid4())[:8] @@ -192,9 +188,10 @@ def invoke(self, input_object, model): inference_spec=SimpleInferenceSpec(), image_uri=PYTORCH_TRAINING_IMAGE.replace("training", "inference"), dependencies={"auto": False}, + sagemaker_session=sagemaker_session, ) - core_model = model_builder.build(model_name=f"{MODEL_NAME_PREFIX}-{unique_id}", region="us-west-2") + core_model = model_builder.build(model_name=f"{MODEL_NAME_PREFIX}-{unique_id}") logger.info(f"Model Successfully Created: {core_model.model_name}") core_endpoint = model_builder.deploy( @@ -221,7 +218,9 @@ def make_prediction(core_endpoint): def cleanup_resources(core_model, core_endpoint): """Fully clean up model and endpoint creation - preserving exact logic from manual test""" - core_endpoint_config = EndpointConfig.get(endpoint_config_name=core_endpoint.endpoint_name) + core_endpoint_config = EndpointConfig.get( + endpoint_config_name=core_endpoint.endpoint_name, + ) core_model.delete() core_endpoint.delete() diff --git a/sagemaker-serve/tests/integ/test_triton_integration.py b/sagemaker-serve/tests/integ/test_triton_integration.py index 5b2c3b2e3a..0ff32a76a2 100644 --- a/sagemaker-serve/tests/integ/test_triton_integration.py +++ b/sagemaker-serve/tests/integ/test_triton_integration.py @@ -22,6 +22,7 @@ from sagemaker.serve.model_builder import ModelBuilder from sagemaker.serve.utils.types import ModelServer from sagemaker.core.resources import EndpointConfig +from sagemaker.core.helper.session_helper import Session # PyTorch Imports import torch @@ -33,6 +34,8 @@ MODEL_NAME_PREFIX = "triton-test-model" ENDPOINT_NAME_PREFIX = "triton-test-endpoint" +sagemaker_session = Session() + # Create a simple PyTorch model class SimpleModel(nn.Module): @@ -96,11 +99,12 @@ def build_and_deploy(): schema_builder = create_schema_builder() model_builder = ModelBuilder( - model=pytorch_model, - model_path=model_path, - model_server=ModelServer.TRITON, - schema_builder=schema_builder - ) + model=pytorch_model, + model_path=model_path, + model_server=ModelServer.TRITON, + schema_builder=schema_builder, + sagemaker_session=sagemaker_session, + ) unique_id = str(uuid.uuid4())[:8] # Build and deploy your model. Returns SageMaker Core Model and Endpoint objects @@ -139,7 +143,9 @@ def make_prediction(core_endpoint): def cleanup_resources(core_model, core_endpoint): """Fully clean up model and endpoint creation - preserving exact logic from manual test""" - core_endpoint_config = EndpointConfig.get(endpoint_config_name=core_endpoint.endpoint_name) + core_endpoint_config = EndpointConfig.get( + endpoint_config_name=core_endpoint.endpoint_name, + ) core_model.delete() core_endpoint.delete()