|
1 | 1 | import logging |
| 2 | +import subprocess |
| 3 | +import time |
2 | 4 |
|
| 5 | +import pytest |
| 6 | + |
| 7 | +from sagemaker_ssh_helper.ide import SSHIDE |
3 | 8 | from sagemaker_ssh_helper.manager import SSMManager |
4 | 9 | from sagemaker_ssh_helper.proxy import SSMProxy |
5 | 10 |
|
6 | | -logger = logging.getLogger('sagemaker-ssh-helper') |
7 | | - |
8 | | - |
9 | | -def test_sagemaker_studio(request): |
10 | | - kernel_gateway_name = request.config.getini('kernel_gateway_name') |
11 | | - |
12 | | - studio_ids = SSMManager().get_studio_kgw_instance_ids(kernel_gateway_name, timeout_in_sec=300) |
logger = logging.getLogger('sagemaker-ssh-helper:test_ide')


# TODO: add a test for typing SageMaker Studio terminal commands - check conda env is activated (Selenium?)

# See https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-images.html .

# The two instance types used by the matrix below.
_CPU_INSTANCE = 'ml.m5.large'
_GPU_INSTANCE = 'ml.g4dn.xlarge'

# Parameter matrix for the SageMaker Studio tests. Each entry is a 5-tuple:
#   (user profile, kernel app name, image name, instance type, expected Python version)
# The "# N" labels are stable identifiers for the entries, NOT list indices:
# entry 9 is disabled, so every later entry sits one position earlier in the list.
SSH_TEST_INSTANCES = [
    # --- Data Science images ---
    # 0
    ('test-data-science', 'ssh-test-ds1-ml-m5-large',
     'datascience-1.0', _CPU_INSTANCE, 'Python 3.7.10'),
    # 1
    ('test-data-science', 'ssh-test-ds2-ml-m5-large',
     'sagemaker-data-science-38', _CPU_INSTANCE, 'Python 3.8.13'),
    # 2
    ('test-data-science', 'ssh-test-ds3-ml-m5-large',
     'sagemaker-data-science-310-v1', _CPU_INSTANCE, 'Python 3.10.6'),

    # --- Base Python images ---
    # 3
    ('test-base-python', 'ssh-test-bp2-ml-m5-large',
     'sagemaker-base-python-38', _CPU_INSTANCE, 'Python 3.8.12'),
    # 4
    ('test-base-python', 'ssh-test-bp3-ml-m5-large',
     'sagemaker-base-python-310-v1', _CPU_INSTANCE, 'Python 3.10.8'),

    # --- Spark images ---
    # 5
    ('test-spark', 'ssh-test-magic-ml-m5-large',
     'sagemaker-sparkmagic', _CPU_INSTANCE, 'Python 3.7.10'),
    # 6
    ('test-spark', 'ssh-test-analytics-ml-m5-large',
     'sagemaker-sparkanalytics-v1', _CPU_INSTANCE, 'Python 3.8.13'),
    # 7
    ('test-spark', 'ssh-test-analytics2-ml-m5-large',
     'sagemaker-sparkanalytics-310-v1', _CPU_INSTANCE, 'Python 3.10.6'),

    # --- MXNet images ---
    # 8
    ('test-mxnet', 'ssh-test-mx19-ml-m5-large',
     'mxnet-1.9-cpu-py38-ubuntu20.04-sagemaker-v1.0', _CPU_INSTANCE, 'Python 3.8.10'),
    # 9 - TODO: https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key/
    # ('test-mxnet', 'ssh-test-mx19-ml-g4dn-xlarge',
    #  'mxnet-1.9-gpu-py38-cu112-ubuntu20.04-sagemaker-v1.0', 'ml.g4dn.xlarge', 'Python 3'),

    # --- PyTorch images ---
    # 10
    ('test-pytorch', 'ssh-test-pt112-ml-m5-large',
     'pytorch-1.12-cpu-py38', _CPU_INSTANCE, 'Python 3.8.16'),
    # 11
    ('test-pytorch', 'ssh-test-pt112-ml-g4dn-xlarge',
     'pytorch-1.12-gpu-py38', _GPU_INSTANCE, 'Python 3.8.16'),
    # 12
    ('test-pytorch', 'ssh-test-pt113-ml-m5-large',
     'pytorch-1.13-cpu-py39', _CPU_INSTANCE, 'Python 3.9.16'),
    # 13
    ('test-pytorch', 'ssh-test-pt113-ml-g4dn-xlarge',
     'pytorch-1.13-gpu-py39', _GPU_INSTANCE, 'Python 3.9.16'),

    # --- TensorFlow images ---
    # 14
    ('test-tensorflow', 'ssh-test-tf211-ml-m5-large',
     'tensorflow-2.11.0-cpu-py39-ubuntu20.04-sagemaker-v1.1', _CPU_INSTANCE, 'Python 3.9.10'),
    # 15
    ('test-tensorflow', 'ssh-test-tf211-ml-g4dn-xlarge',
     'tensorflow-2.11.0-gpu-py39-cu112-ubuntu20.04-sagemaker-v1.1', _GPU_INSTANCE, 'Python 3.9.10'),
    # 16
    ('test-tensorflow', 'ssh-test-tf212-ml-m5-large',
     'tensorflow-2.12.0-cpu-py310-ubuntu20.04-sagemaker-v1', _CPU_INSTANCE, 'Python 3.10.10'),
    # 17
    ('test-tensorflow', 'ssh-test-tf212-ml-g4dn-xlarge',
     'tensorflow-2.12.0-gpu-py310-cu118-ubuntu20.04-sagemaker-v1', _GPU_INSTANCE, 'Python 3.10.10'),
]
| 79 | + |
| 80 | + |
@pytest.mark.parametrize('instances', SSH_TEST_INSTANCES)
def test_sagemaker_studio(instances, request):
    """Create a kernel app for one image, check it over SSM, then delete it.

    For each entry of ``SSH_TEST_INSTANCES`` the test (re)creates the kernel app
    with the ``sagemaker-ssh-helper-dev`` lifecycle config, connects through an
    SSM proxy on local port 10022, and checks that

    * ``sm-ssh-ide status`` reports listeners on 127.0.0.1:8889 and 127.0.0.1:5901;
    * ``sm-ssh-ide get-studio-python-version`` contains the expected version.

    The kernel app is deleted in a ``finally`` clause, so a failed connection or
    assertion does not leave the instance running.

    :param instances: 5-tuple ``(user, app_name, image_name, instance_type,
        expected_version)`` supplied by the parametrization.
    :param request: pytest ``request`` fixture; the ``sagemaker_studio_domain``
        ini option is read from it.
    """
    user, app_name, image_name, instance_type, expected_version = instances

    ide = SSHIDE(request.config.getini('sagemaker_studio_domain'), user)

    ide.create_ssh_kernel_app(
        app_name,
        image_name=image_name,
        instance_type=instance_type,
        ssh_lifecycle_config='sagemaker-ssh-helper-dev',
        recreate=True
    )

    try:
        # Need to wait here, otherwise it will try to connect to an old offline instance
        # TODO: more robust mechanism?
        time.sleep(60)

        studio_ids = ide.get_kernel_instance_ids(app_name, timeout_in_sec=300)
        # Fail with a readable message instead of an IndexError on an empty result.
        assert studio_ids, f"No SSM instance IDs found for kernel app {app_name}"
        studio_id = studio_ids[0]

        with SSMProxy(10022) as ssm_proxy:
            ssm_proxy.connect_to_ssm_instance(studio_id)

            services_running = ssm_proxy.run_command_with_output("sm-ssh-ide status")
            services_running = services_running.decode('latin1')

            # Diagnostics are only logged, to help investigate failures per image.
            output = ssm_proxy.run_command_with_output("sm-ssh-ide env-diagnostics")
            output = output.decode('latin1')
            logger.info(f"Collected env diagnostics for {image_name}: {output}")

            python_version = ssm_proxy.run_command_with_output("sm-ssh-ide get-studio-python-version")
            python_version = python_version.decode('latin1')
            logger.info(f"Collected SageMaker Studio Python version: {python_version}")

        assert "127.0.0.1:8889" in services_running
        assert "127.0.0.1:5901" in services_running

        assert expected_version in python_version
    finally:
        # Always request deletion, even when a step above failed, so that
        # failed runs do not leak running (possibly GPU) instances.
        ide.delete_kernel_app(app_name, wait=False)
30 | 122 |
|
31 | 123 |
|
32 | | -def test_notebook_instance(request): |
@pytest.mark.parametrize('instances', SSH_TEST_INSTANCES)
@pytest.mark.manual
def test_sagemaker_studio_cleanup(instances, request):
    """Request deletion of the kernel app for one ``SSH_TEST_INSTANCES`` entry.

    Carries the ``manual`` marker. Only the user profile and the app name of
    the parameter tuple are used; deletion is requested with ``wait=False``.
    """
    user, app_name, *_ = instances

    domain = request.config.getini('sagemaker_studio_domain')
    SSHIDE(domain, user).delete_kernel_app(app_name, wait=False)
| 131 | + |
| 132 | + |
def test_notebook_instance():
    """Check SSH and Python on the first "sagemaker-ssh-helper" notebook instance.

    Connects through an SSM proxy on local port 17022, installs ``net-tools``
    so that ``netstat`` is available, and then verifies that something listens
    on 0.0.0.0:22 and that ``/opt/conda/bin/python`` reports Python 3.8.
    """
    instance_ids = SSMManager().get_notebook_instance_ids("sagemaker-ssh-helper", timeout_in_sec=300)
    notebook_id = instance_ids[0]

    with SSMProxy(17022) as proxy:
        proxy.connect_to_ssm_instance(notebook_id)

        # netstat comes from net-tools; the command result itself is not needed.
        proxy.run_command("apt-get install -q -y net-tools")

        listeners = proxy.run_command_with_output("netstat -nptl").decode('latin1')
        version_banner = proxy.run_command_with_output("/opt/conda/bin/python --version").decode('latin1')

    assert "0.0.0.0:22" in listeners

    assert "Python 3.8" in version_banner
| 151 | + |
| 152 | + |
def test_called_process_error_with_output():
    """Verify that ``sm-local-ssh-ide run-command`` fails with a readable SSH error.

    No remote kernel is connected, so the command is expected to exit non-zero.
    stderr is merged into stdout so that the SSH error message ends up in
    ``CalledProcessError.output``, which is then compared verbatim.
    """
    command = ["sm-local-ssh-ide", "run-command", "python", "--version"]

    # should fail, because we're not connected to a remote kernel
    with pytest.raises(subprocess.CalledProcessError) as exc_info:
        subprocess.check_output(command, stderr=subprocess.STDOUT)

    output = exc_info.value.output.decode('latin1').strip()
    logger.info(f"Got error (expected): {output}")
    assert output == "ssh: connect to host localhost port 10022: Connection refused"
0 commit comments