
Commit 5df1659 (parent: d7bfeda)

Added support for SageMaker defaults

File tree: 10 files changed (+299 / -209 lines)
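The three test files shown below all make the same two changes: they stop passing `role=` explicitly (relying on SageMaker defaults instead of a pytest ini option) and they replace the `60 * 60 * 3` literal for `max_run` with a `timedelta` expression. A minimal sketch of the defaults side, assuming the SageMaker Python SDK's intelligent-defaults config file; the key names and the role ARN below are illustrative assumptions, not part of this commit:

```python
# A hedged sketch, not part of this commit: a defaults config that lets the
# SDK fill in role= automatically. Verify the exact keys against your
# sagemaker SDK version; the ARN is a placeholder.
#
# ~/.config/sagemaker/config.yaml
#   SchemaVersion: '1.0'
#   SageMaker:
#     TrainingJob:
#       RoleArn: arn:aws:iam::123456789012:role/service-role/MySageMakerRole
#     Model:
#       ExecutionRoleArn: arn:aws:iam::123456789012:role/service-role/MySageMakerRole

from sagemaker.pytorch import PyTorch

# With a config like the above in place, no role= argument is needed:
estimator = PyTorch(entry_point='train_clean.py',
                    source_dir='source_dir/training_clean/',
                    framework_version='1.9.1',
                    py_version='py38',
                    instance_count=1,
                    instance_type='ml.m5.xlarge')
```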

tests/test_batch_inference.py
Lines changed: 6 additions & 7 deletions

```diff
@@ -1,4 +1,5 @@
 import logging
+from datetime import timedelta
 
 import sagemaker
 from sagemaker.pytorch import PyTorch
@@ -8,19 +9,18 @@
 import test_util
 
 
-def test_clean_batch_inference(request):
+def test_clean_batch_inference():
     # noinspection DuplicatedCode
     sagemaker_session = sagemaker.Session()
     bucket = sagemaker_session.default_bucket()
 
     estimator = PyTorch(entry_point='train_clean.py',
                         source_dir='source_dir/training_clean/',
-                        role=request.config.getini('sagemaker_role'),
                         framework_version='1.9.1',
                         py_version='py38',
                         instance_count=1,
                         instance_type='ml.m5.xlarge',
-                        max_run=60 * 60 * 3,
+                        max_run=int(timedelta(minutes=15).total_seconds()),
                         keep_alive_period_in_seconds=1800,
                         container_log_level=logging.INFO)
     estimator.fit()
@@ -51,24 +51,23 @@ def test_clean_batch_inference(request):
                               key_prefix='batch-transform/output')
 
 
-def test_batch_ssh(request):
+def test_batch_ssh():
     # noinspection DuplicatedCode
     sagemaker_session = sagemaker.Session()
     bucket = sagemaker_session.default_bucket()
 
     estimator = PyTorch(entry_point='train_clean.py',
                         source_dir='source_dir/training_clean/',
-                        role=request.config.getini('sagemaker_role'),
                         framework_version='1.9.1',
                         py_version='py38',
                         instance_count=1,
                         instance_type='ml.m5.xlarge',
-                        max_run=60 * 60 * 3,
+                        max_run=int(timedelta(minutes=15).total_seconds()),
                         keep_alive_period_in_seconds=1800,
                         container_log_level=logging.INFO)
     estimator.fit()
 
-    model = estimator.create_model(entry_point='inference.py',
+    model = estimator.create_model(entry_point='inference_ssh.py',
                                    source_dir='source_dir/inference/',
                                    dependencies=[SSHModelWrapper.dependency_dir()])
```
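The `max_run` rewrite above trades a magic number for an explicit, unit-carrying duration. A small self-contained illustration of the arithmetic:

```python
from datetime import timedelta

# The old literal 60 * 60 * 3 (three hours) hides its unit; timedelta spells
# it out, and total_seconds() yields the value that int() truncates to the
# whole seconds that max_run expects. The commit also shortens the cap to
# 15 minutes, which is plenty for these smoke tests.
assert int(timedelta(minutes=15).total_seconds()) == 900
assert int(timedelta(hours=3).total_seconds()) == 60 * 60 * 3
```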

tests/test_clean_core.py
Lines changed: 25 additions & 49 deletions

```diff
@@ -1,5 +1,7 @@
 import logging
 import os
+from datetime import timedelta
+from typing import Optional
 
 import pytest
 import sagemaker
@@ -21,7 +23,7 @@ def test_clean_train_warm_pool():
                         py_version='py38',
                         instance_count=1,
                         instance_type='ml.m5.xlarge',
-                        max_run=60 * 60 * 3,
+                        max_run=int(timedelta(minutes=15).total_seconds()),
                         keep_alive_period_in_seconds=1800,
                         container_log_level=logging.INFO)
     estimator.fit()
@@ -32,15 +34,14 @@ def test_clean_train_warm_pool():
 
 
 # noinspection DuplicatedCode
-def test_clean_inference(request):
+def test_clean_inference():
     estimator = PyTorch(entry_point='train_clean.py',
                         source_dir='source_dir/training_clean/',
-                        role=request.config.getini('sagemaker_role'),
                         framework_version='1.9.1',
                         py_version='py38',
                         instance_count=1,
                         instance_type='ml.m5.xlarge',
-                        max_run=60 * 60 * 3,
+                        max_run=int(timedelta(minutes=15).total_seconds()),
                         keep_alive_period_in_seconds=1800,
                         container_log_level=logging.INFO)
     estimator.fit()
@@ -65,41 +66,30 @@ def test_clean_inference(request):
 
 # noinspection DuplicatedCode
 @pytest.mark.parametrize("instance_type", ["ml.m5.xlarge"])
-def test_clean_inference_mms(request, instance_type):
+def test_clean_inference_mms(instance_type):
     estimator = PyTorch(entry_point='train_clean.py',
                         source_dir='source_dir/training_clean/',
-                        role=request.config.getini('sagemaker_role'),
                         framework_version='1.9.1',
                         py_version='py38',
                         instance_count=1,
                         instance_type=instance_type,
-                        max_run=60 * 60 * 3,
+                        max_run=int(timedelta(minutes=15).total_seconds()),
                         keep_alive_period_in_seconds=1800,
                         container_log_level=logging.INFO)
     estimator.fit()
 
     model_1 = estimator.create_model(entry_point='inference_clean.py',
                                      source_dir='source_dir/inference_clean/')
 
-    # we need a temp endpoint to produce 'repacked_model_data'
-    temp_endpoint_name = name_from_base('temp-inference-mms')
-    temp_predictor: Predictor = model_1.deploy(initial_instance_count=1,
-                                               instance_type='ml.m5.xlarge',
-                                               endpoint_name=temp_endpoint_name,
-                                               wait=True)
+    _ = model_1.prepare_container_def(instance_type='ml.m5.xlarge')
     repacked_model_data_1 = model_1.repacked_model_data
-    temp_predictor.delete_endpoint()
 
     # MUST have the same entry point file name as for the model_1
     model_2 = estimator.create_model(entry_point='inference_clean.py',
                                      source_dir='source_dir/inference_clean_model2/')
-    temp_endpoint_name = name_from_base('temp-inference-mms')
-    temp_predictor: Predictor = model_2.deploy(initial_instance_count=1,
-                                               instance_type='ml.m5.xlarge',
-                                               endpoint_name=temp_endpoint_name,
-                                               wait=True)
+
+    _ = model_2.prepare_container_def(instance_type='ml.m5.xlarge')
     repacked_model_data_2 = model_2.repacked_model_data
-    temp_predictor.delete_endpoint()
 
     bucket = sagemaker.Session().default_bucket()
     job_name = estimator.latest_training_job.name
@@ -115,12 +105,15 @@ def test_clean_inference_mms(request, instance_type):
 
     endpoint_name = name_from_base('inference-mms')
 
-    predictor: Predictor = mdm.deploy(initial_instance_count=1,
-                                      instance_type='ml.m5.xlarge',
-                                      endpoint_name=endpoint_name,
-                                      wait=True)
-
+    predictor: Optional[Predictor] = None
     try:
+        predictor = mdm.deploy(
+            initial_instance_count=1,
+            instance_type='ml.m5.xlarge',
+            endpoint_name=endpoint_name,
+            wait=True
+        )
+
         # Note: we need a repacked model data here, not an estimator data
         mdm.add_model(model_data_source=repacked_model_data_1, model_data_path='model_1.tar.gz')
         mdm.add_model(model_data_source=repacked_model_data_2, model_data_path='model_2.tar.gz')
```
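The last hunk above also hardens cleanup: the predictor now starts as `None` and `deploy()` moves inside the `try` block, so a failed deployment can no longer leak an endpoint. A sketch of the full pattern, assuming the `finally` clause (outside the visible hunk) performs the deletion; `mdm` and `endpoint_name` are the test's own variables:

```python
from typing import Optional
from sagemaker.predictor import Predictor

predictor: Optional[Predictor] = None
try:
    # Deploying inside the try means even a failed deploy reaches cleanup.
    predictor = mdm.deploy(initial_instance_count=1,
                           instance_type='ml.m5.xlarge',
                           endpoint_name=endpoint_name,
                           wait=True)
    # ... invoke the endpoint and assert on the responses ...
finally:
    # Assumed cleanup; the finally body is not shown in the visible hunk.
    if predictor is not None:
        predictor.delete_endpoint()
```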
```diff
@@ -143,49 +136,32 @@ def test_clean_inference_mms(request, instance_type):
 
 # noinspection DuplicatedCode
 @pytest.mark.parametrize("instance_type", ["ml.m5.xlarge"])
-def test_clean_inference_mms_without_model(request, instance_type):
+def test_clean_inference_mms_without_model(instance_type):
     estimator = PyTorch(entry_point='train_clean.py',
                         source_dir='source_dir/training_clean/',
-                        role=request.config.getini('sagemaker_role'),
                         framework_version='1.9.1',
                         py_version='py38',
                         instance_count=1,
                         instance_type=instance_type,
-                        max_run=60 * 60 * 3,
+                        max_run=int(timedelta(minutes=15).total_seconds()),
                         keep_alive_period_in_seconds=1800,
                         container_log_level=logging.INFO)
     estimator.fit()
 
     model_1 = estimator.create_model(entry_point='inference_clean.py',
                                      source_dir='source_dir/inference_clean/')
 
-    # we need a temp endpoint to produce 'repacked_model_data'
-    temp_endpoint_name = name_from_base('temp-inference-mms')
-    temp_predictor: Predictor = model_1.deploy(initial_instance_count=1,
-                                               instance_type='ml.m5.xlarge',
-                                               endpoint_name=temp_endpoint_name,
-                                               wait=True)
+    model_1_description = model_1.prepare_container_def(instance_type='ml.m5.xlarge')
     repacked_model_data_1 = model_1.repacked_model_data
-    temp_predictor.delete_endpoint()
-
-    # But we still don't have access to the deployed container URI from Model object, so still need to use boto3.
-    # Re-fetch container and model data location from Container 1 of the model:
-    model_1_description = model_1.sagemaker_session.describe_model(model_1.name)
-    container_uri = model_1_description['PrimaryContainer']['Image']
-    # Also re-fetch deploy environment:
-    deploy_env = model_1_description['PrimaryContainer']['Environment']
+    container_uri = model_1_description['Image']
+    deploy_env = model_1_description['Environment']
 
     # MUST have the same entry point file name as for the model_1
     model_2 = estimator.create_model(entry_point='inference_clean.py',
                                      source_dir='source_dir/inference_clean_model2/')
 
-    temp_endpoint_name = name_from_base('temp-inference-mms')
-    temp_predictor: Predictor = model_2.deploy(initial_instance_count=1,
-                                               instance_type='ml.m5.xlarge',
-                                               endpoint_name=temp_endpoint_name,
-                                               wait=True)
+    _ = model_2.prepare_container_def(instance_type='ml.m5.xlarge')
     repacked_model_data_2 = model_2.repacked_model_data
-    temp_predictor.delete_endpoint()
 
     bucket = sagemaker.Session().default_bucket()
     job_name = estimator.latest_training_job.name
@@ -196,8 +172,8 @@ def test_clean_inference_mms_without_model(request, instance_type):
     mdm = MultiDataModel(
         name=mdm_name,
         model_data_prefix=model_data_prefix,
-        role=model_1.role,
         image_uri=container_uri,
+        # entry_point=model_1.entry_point,  # NOTE: entry point ignored
         env=deploy_env,  # will copy 'SAGEMAKER_PROGRAM' env variable with entry point file name
         predictor_cls=PyTorchPredictor
     )
```
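The key simplification in this file is `prepare_container_def()`: it repacks the model locally (bundling the entry point and `source_dir` into a fresh `model.tar.gz` on S3) and returns the container definition, so the old workaround of deploying a temporary endpoint and then calling `describe_model` through the session is gone. A sketch of what the returned dict provides, assuming `model_1` is a framework model created as in the tests above:

```python
# prepare_container_def() is a public sagemaker.model.Model method; for
# framework models it triggers the same repacking that deploy() would,
# but without creating any SageMaker resources or endpoints.
container_def = model_1.prepare_container_def(instance_type='ml.m5.xlarge')

container_uri = container_def['Image']      # framework container URI in ECR
deploy_env = container_def['Environment']   # includes SAGEMAKER_PROGRAM (entry point file name)
model_data = model_1.repacked_model_data    # s3://.../model.tar.gz, populated by the repack
```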

tests/test_distributed.py
Lines changed: 7 additions & 9 deletions

```diff
@@ -1,5 +1,6 @@
 import logging
 import os
+from datetime import timedelta
 
 import pytest
 from sagemaker.pytorch import PyTorch
@@ -20,19 +21,18 @@ def test_node_rank_from_env_json_non_existing_rc():
     assert node_rank == 0
 
 
-def test_distributed_training_with_default_instance_count(request):
+def test_distributed_training_with_default_instance_count():
     instance_count = 3
     default_ssh_instance_count = 2
     estimator = PyTorch(entry_point='train.py',
                         source_dir='source_dir/training/',
                         dependencies=[SSHEstimatorWrapper.dependency_dir()],
                         base_job_name='ssh-training',
-                        role=request.config.getini('sagemaker_role'),
                         framework_version='1.9.1',
                         py_version='py38',
                         instance_count=instance_count,
                         instance_type='ml.m5.xlarge',
-                        max_run=60 * 60 * 3,
+                        max_run=int(timedelta(minutes=15).total_seconds()),
                         keep_alive_period_in_seconds=1800,
                         container_log_level=logging.INFO)
 
@@ -45,18 +45,17 @@ def test_distributed_training_with_default_instance_count(request):
 
 
 @pytest.mark.parametrize("ssh_instance_count", [3, 1])
-def test_distributed_training_with_changed_instance_count(request, ssh_instance_count):
+def test_distributed_training_with_changed_instance_count(ssh_instance_count):
     instance_count = 3
     estimator = PyTorch(entry_point='train.py',
                         source_dir='source_dir/training/',
                         dependencies=[SSHEstimatorWrapper.dependency_dir()],
                         base_job_name='ssh-training',
-                        role=request.config.getini('sagemaker_role'),
                         framework_version='1.9.1',
                         py_version='py38',
                         instance_count=instance_count,
                         instance_type='ml.m5.xlarge',
-                        max_run=60 * 60 * 3,
+                        max_run=int(timedelta(minutes=15).total_seconds()),
                         keep_alive_period_in_seconds=1800,
                         container_log_level=logging.INFO)
 
@@ -69,18 +68,17 @@ def test_distributed_training_with_changed_instance_count(request, ssh_instance_count):
     assert len(mi_ids) == ssh_instance_count
 
 
-def test_distributed_training_mpi_single_node(request):
+def test_distributed_training_mpi_single_node():
     instance_count = 1
     estimator = PyTorch(entry_point='train.py',
                         source_dir='source_dir/training/',
                         dependencies=[SSHEstimatorWrapper.dependency_dir()],
                         base_job_name='ssh-training',
-                        role=request.config.getini('sagemaker_role'),
                         framework_version='1.9.1',
                         py_version='py38',
                         instance_count=instance_count,
                         instance_type='ml.g4dn.xlarge',
-                        max_run=60 * 60 * 3,
+                        max_run=int(timedelta(minutes=15).total_seconds()),
                         keep_alive_period_in_seconds=1800,
                         container_log_level=logging.INFO,
                         distribution={
```
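For context, the `mi_ids` asserted in the hunks above come from the SSH wrapper attached to each estimator. A hedged sketch of that flow, with method names taken from the sagemaker-ssh-helper README and worth verifying against the installed version:

```python
from sagemaker_ssh_helper.wrapper import SSHEstimatorWrapper

# 'estimator' is a PyTorch estimator built as in the tests above, with
# dependencies=[SSHEstimatorWrapper.dependency_dir()] already attached.
ssh_wrapper = SSHEstimatorWrapper.create(estimator, connection_wait_time_seconds=0)
estimator.fit(wait=False)

# Managed-instance IDs become available once the training containers register
# with SSM; the tests assert len(mi_ids) against the configured ssh_instance_count.
mi_ids = ssh_wrapper.get_instance_ids()
```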
