Skip to content

Commit 12f9e6d

Browse files
Better support for popular SageMaker Studio images
1 parent 91c37e0 commit 12f9e6d

File tree

6 files changed

+146
-38
lines changed

6 files changed

+146
-38
lines changed

SageMaker_SSH_IDE.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@
5353
"outputs": [],
5454
"source": [
5555
"%%sh\n",
56-
"pip uninstall --root-user-action ignore -y -q awscli\n",
57-
"pip install --root-user-action ignore -q -U sagemaker-ssh-helper\n",
56+
"pip uninstall -y -q awscli\n",
57+
"pip install -q -U sagemaker-ssh-helper\n",
5858
"pip freeze | grep sagemaker-ssh-helper"
5959
]
6060
},

SageMaker_SSH_Notebook.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
"outputs": [],
2222
"source": [
2323
"%%sh\n",
24-
"pip uninstall --root-user-action ignore -y -q awscli\n",
25-
"pip install --root-user-action ignore -q -U sagemaker-ssh-helper\n",
24+
"pip uninstall -y -q awscli\n",
25+
"pip install q -U sagemaker-ssh-helper\n",
2626
"pip freeze | grep sagemaker-ssh-helper"
2727
]
2828
},

server-lc-config.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ pip uninstall -y -q awscli
1717
pip install -q sagemaker-ssh-helper
1818

1919
# Uncomment two lines below to update SageMaker SSH Helper to the latest dev version from main branch
20-
#git clone https://github.com/aws-samples/sagemaker-ssh-helper.git /tmp/sagemaker-ssh-helper/
21-
#cd /tmp/sagemaker-ssh-helper/ && pip install . && cd ..
20+
#git clone https://github.com/aws-samples/sagemaker-ssh-helper.git ./sagemaker-ssh-helper/ || echo 'Already cloned'
21+
#cd ./sagemaker-ssh-helper/ && git pull --no-rebase && pip install . && cd ..
2222

2323
ps xfaeww
2424

tests/conftest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11

22
def pytest_addoption(parser):
3-
parser.addini('sagemaker_role', '')
4-
parser.addini('kernel_gateway_name', '')
3+
parser.addini('sagemaker_studio_domain', '')

tests/pytest.ini

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,12 @@
22
log_cli = true
33
log_cli_level = 20
44
log_format = %(asctime)s %(levelname)s %(name)s - %(message)s
5-
log_date_format = %Y-%m-%d %H:%M:%S
5+
log_date_format = %Y-%m-%d %H:%M:%S %z %Z
66

77
markers=
88
manual: Optional tests that cannot be executed in an automated CI/CD pipeline, but helpful for troubleshooting
99

10-
# Change to your role or pass as an extra parameter to pytest: '-o sagemaker_role=...'.
11-
sagemaker_role = arn:aws:iam::<<YOUR_ACCOUNT_ID>>:role/service-role/<<YOUR_AmazonSageMaker_ExecutionRole>>
12-
1310
# Manually start SageMaker_SSH_IDE.ipynb in SageMaker Studio and replace with your kgw app name
14-
kernel_gateway_name = datascience-1-0-ml-t3-medium-xxx
11+
sagemaker_studio_domain = d-egm0dexample
1512

1613
# Also see conftest.py

tests/test_ide.py

Lines changed: 137 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,163 @@
11
import logging
2+
import subprocess
3+
import time
24

5+
import pytest
6+
7+
from sagemaker_ssh_helper.ide import SSHIDE
38
from sagemaker_ssh_helper.manager import SSMManager
49
from sagemaker_ssh_helper.proxy import SSMProxy
510

6-
logger = logging.getLogger('sagemaker-ssh-helper')
7-
8-
9-
def test_sagemaker_studio(request):
10-
kernel_gateway_name = request.config.getini('kernel_gateway_name')
11-
12-
studio_ids = SSMManager().get_studio_kgw_instance_ids(kernel_gateway_name, timeout_in_sec=300)
11+
logger = logging.getLogger('sagemaker-ssh-helper:test_ide')
12+
13+
14+
# TODO: add a test for typing SageMaker Studio terminal commands - check conda env is activated (Selenium?)
15+
16+
# See https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-images.html .
17+
18+
SSH_TEST_INSTANCES = [
19+
# 0
20+
('test-data-science', 'ssh-test-ds1-ml-m5-large',
21+
'datascience-1.0', 'ml.m5.large', 'Python 3.7.10'),
22+
# 1
23+
('test-data-science', 'ssh-test-ds2-ml-m5-large',
24+
'sagemaker-data-science-38', 'ml.m5.large', 'Python 3.8.13'),
25+
# 2
26+
('test-data-science', 'ssh-test-ds3-ml-m5-large',
27+
'sagemaker-data-science-310-v1', 'ml.m5.large', 'Python 3.10.6'),
28+
29+
# 3
30+
('test-base-python', 'ssh-test-bp2-ml-m5-large',
31+
'sagemaker-base-python-38', 'ml.m5.large', 'Python 3.8.12'),
32+
# 4
33+
('test-base-python', 'ssh-test-bp3-ml-m5-large',
34+
'sagemaker-base-python-310-v1', 'ml.m5.large', 'Python 3.10.8'),
35+
36+
# 5
37+
('test-spark', 'ssh-test-magic-ml-m5-large',
38+
'sagemaker-sparkmagic', 'ml.m5.large', 'Python 3.7.10'),
39+
# 6
40+
('test-spark', 'ssh-test-analytics-ml-m5-large',
41+
'sagemaker-sparkanalytics-v1', 'ml.m5.large', 'Python 3.8.13'),
42+
# 7
43+
('test-spark', 'ssh-test-analytics2-ml-m5-large',
44+
'sagemaker-sparkanalytics-310-v1', 'ml.m5.large', 'Python 3.10.6'),
45+
46+
# 8
47+
('test-mxnet', 'ssh-test-mx19-ml-m5-large',
48+
'mxnet-1.9-cpu-py38-ubuntu20.04-sagemaker-v1.0', 'ml.m5.large', 'Python 3.8.10'),
49+
# 9 - TODO: https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key/
50+
# ('test-mxnet', 'ssh-test-mx19-ml-g4dn-xlarge',
51+
# 'mxnet-1.9-gpu-py38-cu112-ubuntu20.04-sagemaker-v1.0', 'ml.g4dn.xlarge', 'Python 3'),
52+
53+
# 10
54+
('test-pytorch', 'ssh-test-pt112-ml-m5-large',
55+
'pytorch-1.12-cpu-py38', 'ml.m5.large', 'Python 3.8.16'),
56+
# 11
57+
('test-pytorch', 'ssh-test-pt112-ml-g4dn-xlarge',
58+
'pytorch-1.12-gpu-py38', 'ml.g4dn.xlarge', 'Python 3.8.16'),
59+
# 12
60+
('test-pytorch', 'ssh-test-pt113-ml-m5-large',
61+
'pytorch-1.13-cpu-py39', 'ml.m5.large', 'Python 3.9.16'),
62+
# 13
63+
('test-pytorch', 'ssh-test-pt113-ml-g4dn-xlarge',
64+
'pytorch-1.13-gpu-py39', 'ml.g4dn.xlarge', 'Python 3.9.16'),
65+
66+
# 14
67+
('test-tensorflow', 'ssh-test-tf211-ml-m5-large',
68+
'tensorflow-2.11.0-cpu-py39-ubuntu20.04-sagemaker-v1.1', 'ml.m5.large', 'Python 3.9.10'),
69+
# 15
70+
('test-tensorflow', 'ssh-test-tf211-ml-g4dn-xlarge',
71+
'tensorflow-2.11.0-gpu-py39-cu112-ubuntu20.04-sagemaker-v1.1', 'ml.g4dn.xlarge', 'Python 3.9.10'),
72+
# 16
73+
('test-tensorflow', 'ssh-test-tf212-ml-m5-large',
74+
'tensorflow-2.12.0-cpu-py310-ubuntu20.04-sagemaker-v1', 'ml.m5.large', 'Python 3.10.10'),
75+
# 17
76+
('test-tensorflow', 'ssh-test-tf212-ml-g4dn-xlarge',
77+
'tensorflow-2.12.0-gpu-py310-cu118-ubuntu20.04-sagemaker-v1', 'ml.g4dn.xlarge', 'Python 3.10.10'),
78+
]
79+
80+
81+
@pytest.mark.parametrize('instances', SSH_TEST_INSTANCES)
82+
def test_sagemaker_studio(instances, request):
83+
user, app_name, image_name, instance_type, expected_version = instances
84+
85+
ide = SSHIDE(request.config.getini('sagemaker_studio_domain'), user)
86+
87+
ide.create_ssh_kernel_app(
88+
app_name,
89+
image_name=image_name,
90+
instance_type=instance_type,
91+
ssh_lifecycle_config='sagemaker-ssh-helper-dev',
92+
recreate=True
93+
)
94+
95+
# Need to wait here, otherwise it will try to connect to an old offline instance
96+
# TODO: more robust mechanism?
97+
time.sleep(60)
98+
99+
studio_ids = ide.get_kernel_instance_ids(app_name, timeout_in_sec=300)
13100
studio_id = studio_ids[0]
14101

15-
ssm_proxy = SSMProxy(10022)
16-
ssm_proxy.connect_to_ssm_instance(studio_id)
102+
with SSMProxy(10022) as ssm_proxy:
103+
ssm_proxy.connect_to_ssm_instance(studio_id)
17104

18-
services_running = ssm_proxy.run_command_with_output("sm-ssh-ide status")
19-
services_running = services_running.decode('latin1')
105+
services_running = ssm_proxy.run_command_with_output("sm-ssh-ide status")
106+
services_running = services_running.decode('latin1')
20107

21-
python_version = ssm_proxy.run_command_with_output("/opt/conda/bin/python --version")
22-
python_version = python_version.decode('latin1')
108+
output = ssm_proxy.run_command_with_output("sm-ssh-ide env-diagnostics")
109+
output = output.decode('latin1')
110+
logger.info(f"Collected env diagnostics for {image_name}: {output}")
23111

24-
ssm_proxy.disconnect()
112+
python_version = ssm_proxy.run_command_with_output("sm-ssh-ide get-studio-python-version")
113+
python_version = python_version.decode('latin1')
114+
logger.info(f"Collected SageMaker Studio Python version: {python_version}")
25115

26116
assert "127.0.0.1:8889" in services_running
27117
assert "127.0.0.1:5901" in services_running
28118

29-
assert "Python 3.8" in python_version
119+
assert expected_version in python_version
120+
121+
ide.delete_kernel_app(app_name, wait=False)
30122

31123

32-
def test_notebook_instance(request):
124+
@pytest.mark.parametrize('instances', SSH_TEST_INSTANCES)
125+
@pytest.mark.manual
126+
def test_sagemaker_studio_cleanup(instances, request):
127+
user, app_name, image_name, instance_type, expected_version = instances
128+
129+
ide = SSHIDE(request.config.getini('sagemaker_studio_domain'), user)
130+
ide.delete_kernel_app(app_name, wait=False)
131+
132+
133+
def test_notebook_instance():
33134
notebook_ids = SSMManager().get_notebook_instance_ids("sagemaker-ssh-helper", timeout_in_sec=300)
34135
studio_id = notebook_ids[0]
35136

36-
ssm_proxy = SSMProxy(17022)
37-
ssm_proxy.connect_to_ssm_instance(studio_id)
137+
with SSMProxy(17022) as ssm_proxy:
138+
ssm_proxy.connect_to_ssm_instance(studio_id)
38139

39-
_ = ssm_proxy.run_command("apt-get install -q -y net-tools")
140+
_ = ssm_proxy.run_command("apt-get install -q -y net-tools")
40141

41-
services_running = ssm_proxy.run_command_with_output("netstat -nptl")
42-
services_running = services_running.decode('latin1')
142+
services_running = ssm_proxy.run_command_with_output("netstat -nptl")
143+
services_running = services_running.decode('latin1')
43144

44-
python_version = ssm_proxy.run_command_with_output("/opt/conda/bin/python --version")
45-
python_version = python_version.decode('latin1')
46-
47-
ssm_proxy.disconnect()
145+
python_version = ssm_proxy.run_command_with_output("/opt/conda/bin/python --version")
146+
python_version = python_version.decode('latin1')
48147

49148
assert "0.0.0.0:22" in services_running
50149

51150
assert "Python 3.8" in python_version
151+
152+
153+
def test_called_process_error_with_output():
154+
got_error = False
155+
try:
156+
# should fail, because we're not connected to a remote kernel
157+
subprocess.check_output("sm-local-ssh-ide run-command python --version".split(' '), stderr=subprocess.STDOUT)
158+
except subprocess.CalledProcessError as e:
159+
output = e.output.decode('latin1').strip()
160+
logger.info(f"Got error (expected): {output}")
161+
got_error = True
162+
assert output == "ssh: connect to host localhost port 10022: Connection refused"
163+
assert got_error

0 commit comments

Comments
 (0)