Skip to content

Commit 6bb64d1

Browse files
Test for MXNet
1 parent 36914ac commit 6bb64d1

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

tests/test_frameworks.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import os
23
import time
34

45
import pytest
@@ -532,3 +533,29 @@ def test_train_estimator_ssh(request):
532533

533534
finally:
534535
predictor.delete_endpoint()
536+
537+
538+
def test_train_mxnet_ssh(request):
539+
logging.info("Starting training")
540+
541+
from sagemaker.mxnet import MXNet
542+
estimator = MXNet(entry_point=os.path.basename('source_dir/training/train.py'),
543+
source_dir='source_dir/training/',
544+
dependencies=[SSHEstimatorWrapper.dependency_dir()],
545+
base_job_name='ssh-training-mxnet',
546+
role=request.config.getini('sagemaker_role'),
547+
py_version='py38',
548+
framework_version='1.9',
549+
instance_count=1,
550+
instance_type='ml.m5.xlarge',
551+
max_run=60 * 30,
552+
keep_alive_period_in_seconds=1800,
553+
container_log_level=logging.INFO)
554+
555+
ssh_wrapper = SSHEstimatorWrapper.create(estimator, connection_wait_time_seconds=600)
556+
estimator.fit(wait=False)
557+
ssh_wrapper.start_ssm_connection_and_continue(11022, 60)
558+
ssh_wrapper.wait_training_job()
559+
logging.info("Finished training")
560+
561+
assert estimator.model_data.find("model.tar.gz") != -1

0 commit comments

Comments
 (0)