Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import SourceCode
from sagemaker.core.resources import EndpointConfig
from sagemaker.core.helper.session_helper import Session

logger = logging.getLogger(__name__)

Expand All @@ -37,6 +38,8 @@
AWS_REGION = "us-west-2"
PYTORCH_TRAINING_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.13.1-cpu-py39"

sagemaker_session = Session()


@pytest.mark.slow_test
def test_train_inference_e2e_build_deploy_invoke_cleanup():
Expand Down Expand Up @@ -143,13 +146,6 @@ def create_schema_builder():

def train_model():
"""Train model using ModelTrainer."""
from sagemaker.core.helper.session_helper import Session
import boto3

# Create SageMaker session with AWS region
boto_session = boto3.Session(region_name=AWS_REGION)
sagemaker_session = Session(boto_session=boto_session)

training_code_dir = create_pytorch_training_code()
unique_id = str(uuid.uuid4())[:8]

Expand Down Expand Up @@ -192,9 +188,10 @@ def invoke(self, input_object, model):
inference_spec=SimpleInferenceSpec(),
image_uri=PYTORCH_TRAINING_IMAGE.replace("training", "inference"),
dependencies={"auto": False},
sagemaker_session=sagemaker_session,
)

core_model = model_builder.build(model_name=f"{MODEL_NAME_PREFIX}-{unique_id}", region="us-west-2")
core_model = model_builder.build(model_name=f"{MODEL_NAME_PREFIX}-{unique_id}")
logger.info(f"Model Successfully Created: {core_model.model_name}")

core_endpoint = model_builder.deploy(
Expand All @@ -221,7 +218,9 @@ def make_prediction(core_endpoint):

def cleanup_resources(core_model, core_endpoint):
"""Fully clean up model and endpoint creation - preserving exact logic from manual test"""
core_endpoint_config = EndpointConfig.get(endpoint_config_name=core_endpoint.endpoint_name)
core_endpoint_config = EndpointConfig.get(
endpoint_config_name=core_endpoint.endpoint_name,
)

core_model.delete()
core_endpoint.delete()
Expand Down
18 changes: 12 additions & 6 deletions sagemaker-serve/tests/integ/test_triton_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from sagemaker.serve.model_builder import ModelBuilder
from sagemaker.serve.utils.types import ModelServer
from sagemaker.core.resources import EndpointConfig
from sagemaker.core.helper.session_helper import Session

# PyTorch Imports
import torch
Expand All @@ -33,6 +34,8 @@
MODEL_NAME_PREFIX = "triton-test-model"
ENDPOINT_NAME_PREFIX = "triton-test-endpoint"

sagemaker_session = Session()


# Create a simple PyTorch model
class SimpleModel(nn.Module):
Expand Down Expand Up @@ -96,11 +99,12 @@ def build_and_deploy():
schema_builder = create_schema_builder()

model_builder = ModelBuilder(
model=pytorch_model,
model_path=model_path,
model_server=ModelServer.TRITON,
schema_builder=schema_builder
)
model=pytorch_model,
model_path=model_path,
model_server=ModelServer.TRITON,
schema_builder=schema_builder,
sagemaker_session=sagemaker_session,
)

unique_id = str(uuid.uuid4())[:8]
# Build and deploy your model. Returns SageMaker Core Model and Endpoint objects
Expand Down Expand Up @@ -139,7 +143,9 @@ def make_prediction(core_endpoint):

def cleanup_resources(core_model, core_endpoint):
"""Fully clean up model and endpoint creation - preserving exact logic from manual test"""
core_endpoint_config = EndpointConfig.get(endpoint_config_name=core_endpoint.endpoint_name)
core_endpoint_config = EndpointConfig.get(
endpoint_config_name=core_endpoint.endpoint_name,
)

core_model.delete()
core_endpoint.delete()
Expand Down
Loading