 import logging
 import os
+from datetime import timedelta
+from typing import Optional

 import pytest
 import sagemaker
@@ -21,7 +23,7 @@ def test_clean_train_warm_pool():
                         py_version='py38',
                         instance_count=1,
                         instance_type='ml.m5.xlarge',
-                        max_run=60 * 60 * 3,
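+                        # max_run is the training job timeout in seconds;
+                        # timedelta makes the 15-minute cap explicit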
+                        max_run=int(timedelta(minutes=15).total_seconds()),
                         keep_alive_period_in_seconds=1800,
                         container_log_level=logging.INFO)
     estimator.fit()
@@ -32,15 +34,14 @@ def test_clean_train_warm_pool():


 # noinspection DuplicatedCode
-def test_clean_inference(request):
+def test_clean_inference():
     estimator = PyTorch(entry_point='train_clean.py',
                         source_dir='source_dir/training_clean/',
-                        role=request.config.getini('sagemaker_role'),
                         framework_version='1.9.1',
                         py_version='py38',
                         instance_count=1,
                         instance_type='ml.m5.xlarge',
-                        max_run=60 * 60 * 3,
+                        max_run=int(timedelta(minutes=15).total_seconds()),
                         keep_alive_period_in_seconds=1800,
                         container_log_level=logging.INFO)
     estimator.fit()
@@ -65,41 +66,30 @@ def test_clean_inference(request):

 # noinspection DuplicatedCode
 @pytest.mark.parametrize("instance_type", ["ml.m5.xlarge"])
-def test_clean_inference_mms(request, instance_type):
+def test_clean_inference_mms(instance_type):
     estimator = PyTorch(entry_point='train_clean.py',
                         source_dir='source_dir/training_clean/',
-                        role=request.config.getini('sagemaker_role'),
                         framework_version='1.9.1',
                         py_version='py38',
                         instance_count=1,
                         instance_type=instance_type,
-                        max_run=60 * 60 * 3,
+                        max_run=int(timedelta(minutes=15).total_seconds()),
                         keep_alive_period_in_seconds=1800,
                         container_log_level=logging.INFO)
     estimator.fit()

     model_1 = estimator.create_model(entry_point='inference_clean.py',
                                      source_dir='source_dir/inference_clean/')

-    # we need a temp endpoint to produce 'repacked_model_data'
-    temp_endpoint_name = name_from_base('temp-inference-mms')
-    temp_predictor: Predictor = model_1.deploy(initial_instance_count=1,
-                                               instance_type='ml.m5.xlarge',
-                                               endpoint_name=temp_endpoint_name,
-                                               wait=True)
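+    # prepare_container_def() repacks the inference code into the model archive,
+    # populating repacked_model_data without deploying a temporary endpoint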
+    _ = model_1.prepare_container_def(instance_type='ml.m5.xlarge')
     repacked_model_data_1 = model_1.repacked_model_data
-    temp_predictor.delete_endpoint()

     # MUST have the same entry point file name as for the model_1
     model_2 = estimator.create_model(entry_point='inference_clean.py',
                                      source_dir='source_dir/inference_clean_model2/')
-    temp_endpoint_name = name_from_base('temp-inference-mms')
-    temp_predictor: Predictor = model_2.deploy(initial_instance_count=1,
-                                               instance_type='ml.m5.xlarge',
-                                               endpoint_name=temp_endpoint_name,
-                                               wait=True)
+
+    _ = model_2.prepare_container_def(instance_type='ml.m5.xlarge')
     repacked_model_data_2 = model_2.repacked_model_data
-    temp_predictor.delete_endpoint()

     bucket = sagemaker.Session().default_bucket()
     job_name = estimator.latest_training_job.name
@@ -115,12 +105,15 @@ def test_clean_inference_mms(request, instance_type):

     endpoint_name = name_from_base('inference-mms')
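+    # name_from_base() appends a timestamp suffix, so repeated runs get unique endpoint names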

-    predictor: Predictor = mdm.deploy(initial_instance_count=1,
-                                      instance_type='ml.m5.xlarge',
-                                      endpoint_name=endpoint_name,
-                                      wait=True)
-
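+    # predictor starts as None so cleanup can skip delete_endpoint() when deploy itself fails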
+    predictor: Optional[Predictor] = None
     try:
+        predictor = mdm.deploy(
+            initial_instance_count=1,
+            instance_type='ml.m5.xlarge',
+            endpoint_name=endpoint_name,
+            wait=True
+        )
+
         # Note: we need a repacked model data here, not an estimator data
         mdm.add_model(model_data_source=repacked_model_data_1, model_data_path='model_1.tar.gz')
         mdm.add_model(model_data_source=repacked_model_data_2, model_data_path='model_2.tar.gz')
@@ -143,49 +136,32 @@ def test_clean_inference_mms(request, instance_type):

 # noinspection DuplicatedCode
 @pytest.mark.parametrize("instance_type", ["ml.m5.xlarge"])
-def test_clean_inference_mms_without_model(request, instance_type):
+def test_clean_inference_mms_without_model(instance_type):
     estimator = PyTorch(entry_point='train_clean.py',
                         source_dir='source_dir/training_clean/',
-                        role=request.config.getini('sagemaker_role'),
                         framework_version='1.9.1',
                         py_version='py38',
                         instance_count=1,
                         instance_type=instance_type,
-                        max_run=60 * 60 * 3,
+                        max_run=int(timedelta(minutes=15).total_seconds()),
                         keep_alive_period_in_seconds=1800,
                         container_log_level=logging.INFO)
     estimator.fit()

     model_1 = estimator.create_model(entry_point='inference_clean.py',
                                      source_dir='source_dir/inference_clean/')

-    # we need a temp endpoint to produce 'repacked_model_data'
-    temp_endpoint_name = name_from_base('temp-inference-mms')
-    temp_predictor: Predictor = model_1.deploy(initial_instance_count=1,
-                                               instance_type='ml.m5.xlarge',
-                                               endpoint_name=temp_endpoint_name,
-                                               wait=True)
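+    # prepare_container_def() also returns the container definition dict, so the
+    # image URI and environment are available without a boto3 describe_model call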
+    model_1_description = model_1.prepare_container_def(instance_type='ml.m5.xlarge')
     repacked_model_data_1 = model_1.repacked_model_data
-    temp_predictor.delete_endpoint()
-
-    # But we still don't have access to the deployed container URI from Model object, so still need to use boto3.
-    # Re-fetch container and model data location from Container 1 of the model:
-    model_1_description = model_1.sagemaker_session.describe_model(model_1.name)
-    container_uri = model_1_description['PrimaryContainer']['Image']
-    # Also re-fetch deploy environment:
-    deploy_env = model_1_description['PrimaryContainer']['Environment']
+    container_uri = model_1_description['Image']
+    deploy_env = model_1_description['Environment']

     # MUST have the same entry point file name as for the model_1
     model_2 = estimator.create_model(entry_point='inference_clean.py',
                                      source_dir='source_dir/inference_clean_model2/')

-    temp_endpoint_name = name_from_base('temp-inference-mms')
-    temp_predictor: Predictor = model_2.deploy(initial_instance_count=1,
-                                               instance_type='ml.m5.xlarge',
-                                               endpoint_name=temp_endpoint_name,
-                                               wait=True)
+    _ = model_2.prepare_container_def(instance_type='ml.m5.xlarge')
     repacked_model_data_2 = model_2.repacked_model_data
-    temp_predictor.delete_endpoint()

     bucket = sagemaker.Session().default_bucket()
     job_name = estimator.latest_training_job.name
@@ -196,8 +172,8 @@ def test_clean_inference_mms_without_model(request, instance_type):
     mdm = MultiDataModel(
         name=mdm_name,
         model_data_prefix=model_data_prefix,
-        role=model_1.role,
         image_uri=container_uri,
+        # entry_point=model_1.entry_point,  # NOTE: entry point ignored
         env=deploy_env,  # will copy 'SAGEMAKER_PROGRAM' env variable with entry point file name
         predictor_cls=PyTorchPredictor
     )