Skip to content

Commit 1bc1bd5

Browse files
Multiple users support on local machine
1 parent 36d222f commit 1bc1bd5

File tree

7 files changed

+148
-42
lines changed

7 files changed

+148
-42
lines changed

sagemaker_ssh_helper/ide.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,13 +198,20 @@ def resolve_sagemaker_kernel_image_arn(self, image_name):
198198
sagemaker_account_id = "470317259841" # eu-west-1, TODO: check all images
199199
return f"arn:aws:sagemaker:{self.current_region}:{sagemaker_account_id}:image/{image_name}"
200200

201+
def print_kernel_instance_id(self, app_name, timeout_in_sec, index: int = 0):
202+
print(self.get_kernel_instance_ids(app_name, timeout_in_sec)[index])
203+
201204
def get_kernel_instance_ids(self, app_name, timeout_in_sec):
202205
self.logger.info("Resolving IDE instance IDs through SSM tags")
203206
self.log_urls(app_name)
204207
self.logger.info(f"Connect from local machine (with GUI and Jupyter): sm-local-ssh-ide connect {app_name}")
205208
self.logger.info(f"To connect with SSH only: sm-local-ssh-ide connect {app_name} --ssh-only")
206-
# FIXME: resolve with domain and user
207-
result = SSMManager().get_studio_kgw_instance_ids(app_name, timeout_in_sec)
209+
if self.domain_id and self.user:
210+
result = SSMManager().get_studio_user_kgw_instance_ids(self.domain_id, self.user, app_name, timeout_in_sec)
211+
else:
212+
self.logger.warning(f"Domain ID or user profile name are not set. Will attempt to connect to the latest "
213+
f"active kernel gateway with the name {app_name} in the region {self.current_region}")
214+
result = SSMManager().get_studio_kgw_instance_ids(app_name, timeout_in_sec)
208215
return result
209216

210217
def log_urls(self, app_name):

sagemaker_ssh_helper/log.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ def get_studio_kgw_ssm_instance_ids(self, kgw_name, timeout_in_sec=0):
7474
return self.get_ssm_instance_ids(f'/aws/sagemaker/studio', f"KernelGateway/{kgw_name}",
7575
timeout_in_sec=timeout_in_sec)
7676

77-
def get_instance_ids_once(self, arn_resource_type, arn_resource_name):
77+
def get_instance_ids_once(self, arn_resource_type, arn_resource_name, arn_filter: str = None):
78+
if arn_filter:
79+
raise ValueError("Not supported for SSHLog")
7880
return self.get_ssm_instance_ids_once(log_group=arn_resource_type, stream_name=arn_resource_name)
7981

8082
def get_ssm_instance_ids_once(self, log_group, stream_name):

sagemaker_ssh_helper/manager.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@ def __init__(self, region_name: str = None,
1919

2020
def get_instance_ids(self, arn_resource_type, arn_resource_name,
2121
timeout_in_sec=0,
22-
expected_count=1):
22+
expected_count=1,
23+
arn_filter: str = None):
2324
if arn_resource_name.startswith('mi-'):
2425
self.logger.warning("SageMaker resource name usually doesn't not start with 'mi-', "
2526
"did you pass the SSM instance ID by mistake?")
2627
self.logger.info("Using AWS Region: %s", self.region_name)
27-
mi_ids = self.get_instance_ids_once(arn_resource_type, arn_resource_name)
28+
mi_ids = self.get_instance_ids_once(arn_resource_type, arn_resource_name, arn_filter)
2829

2930
while not mi_ids and timeout_in_sec > 0:
3031
self.logger.info(f"No instance IDs found. Retrying. Is SSM Agent running on the remote? "
@@ -46,7 +47,7 @@ def get_instance_ids(self, arn_resource_type, arn_resource_name,
4647
return mi_ids
4748

4849
@abstractmethod
49-
def get_instance_ids_once(self, arn_resource_type, arn_resource_name):
50+
def get_instance_ids_once(self, arn_resource_type, arn_resource_name, arn_filter: str = None):
5051
raise NotImplementedError("Abstract method")
5152

5253

@@ -101,6 +102,11 @@ def get_transformer_instance_ids(self, transform_job_name, timeout_in_sec=0):
101102
self.logger.info(f"Querying SSM instance IDs for transform job {transform_job_name}")
102103
return self.get_instance_ids('transform-job', transform_job_name, timeout_in_sec)
103104

105+
def get_studio_user_kgw_instance_ids(self, domain_id, user_profile_name, kgw_name, timeout_in_sec=0):
106+
self.logger.info(f"Querying SSM instance IDs for SageMaker Studio kernel gateway {kgw_name}")
107+
return self.get_instance_ids('app', f"{kgw_name}", timeout_in_sec,
108+
arn_filter=f":app/{domain_id}/{user_profile_name}/")
109+
104110
def get_studio_kgw_instance_ids(self, kgw_name, timeout_in_sec=0):
105111
self.logger.info(f"Querying SSM instance IDs for SageMaker Studio kernel gateway {kgw_name}")
106112
return self.get_instance_ids('app', f"{kgw_name}", timeout_in_sec)
@@ -109,7 +115,8 @@ def get_notebook_instance_ids(self, instance_name, timeout_in_sec=0):
109115
self.logger.info(f"Querying SSM instance IDs for SageMaker notebook instance {instance_name}")
110116
return self.get_instance_ids('notebook-instance', f"{instance_name}", timeout_in_sec)
111117

112-
def get_instance_ids_once(self, arn_resource_type, arn_resource_name):
118+
def get_instance_ids_once(self, arn_resource_type, arn_resource_name,
119+
arn_filter: str = None):
113120
# TODO: use tag filter instead, for faster performance
114121
all_instances = self.list_all_instances_with_tags()
115122
result_pairs = []
@@ -119,7 +126,8 @@ def get_instance_ids_once(self, arn_resource_type, arn_resource_name):
119126
continue
120127
if f"/{arn_resource_name}" in tags["SSHResourceArn"] and \
121128
arn_resource_name == tags["SSHResourceName"] and \
122-
f":{arn_resource_type}/" in tags["SSHResourceArn"]:
129+
f":{arn_resource_type}/" in tags["SSHResourceArn"] and \
130+
(not arn_filter or arn_filter in tags["SSHResourceArn"]):
123131
if "SSHTimestamp" in tags:
124132
timestamp = tags["SSHTimestamp"]
125133
else:

sagemaker_ssh_helper/sm-local-ssh-ide

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,49 +3,57 @@
33
# Commands:
44
# connect <kernel_gateway_name> [--ssh-only] [<extra_ssh_args>]
55
# run-command <command> <args...>
6-
# TODO: create <kernel_gateway_name> [--recreate] --domain <domain> --user <profile-name> --image datascience-1.0 --instance ml.t3.medium
6+
# set-domain-id <domain_id>
7+
# set-user-profile-name <user_profile_name>
8+
# set-jb-license-server <jb-license-server-hostname-without-http>
9+
# TODO: create <kernel_gateway_name> [--recreate] --image datascience-1.0 --instance ml.t3.medium
710
# TODO: list (all apps from all users and domains marked with '*' if can connect with SSH and with '!' if user don't match)
8-
# TODO: open-firefox <kernel_gateway_name> --domain <domain> --user <profile-name>
9-
# TODO: set-domain, set-user (defaults)
10-
# TODO: set-kernel-gateway-name, get-kernel-gateway-name (sets default for connect and create)
1111

1212
# SageMaker Studio Kernel Gateway name is usually the same as the hostname,
1313
# e. g. sagemaker-data-science-ml-m5-large-1234567890abcdef0
1414

1515
# To open SageMaker Studio UI in Firefox from command line on macOS, use the following command:
1616
# open -a Firefox $(AWS_PROFILE=terry aws sagemaker create-presigned-domain-url --domain-id d-lnwlaexample --user-profile-name terry-whitlock --query AuthorizedUrl --output text)
1717

18+
# replace with your JetBrains License Server host, or leave it as is if you don't use one
19+
JB_LICENSE_SERVER_HOST="jetbrains-license-server.example.com"
20+
21+
1822
COMMAND=$1
1923

2024
if [[ "$COMMAND" == "connect" ]]; then
2125

2226
SM_STUDIO_KGW_NAME="$2"
2327
OPTIONS="$3"
2428

25-
# FIXME: distinguish between user profiles
29+
# TODO: if name is empty, list and choose
30+
31+
DOMAIN_ID=""
32+
if [ -f ~/.sm-studio-domain-id ]; then
33+
DOMAIN_ID="$(cat ~/.sm-studio-domain-id)"
34+
else
35+
echo "sm-local-ssh-ide: WARNING: SageMaker Studio domain ID is not set, "\
36+
"will attempt to connect to the latest active kernel gateway with name '$SM_STUDIO_KGW_NAME' in the region. "\
37+
"Run 'sm-local-ssh-ide set-domain-id' to override."
38+
fi
39+
USER_PROFILE_NAME=""
40+
if [ -f ~/.sm-studio-user-profile-name ]; then
41+
USER_PROFILE_NAME="$(cat ~/.sm-studio-user-profile-name)"
42+
else
43+
echo "sm-local-ssh-ide: WARNING: SageMaker Studio user profile name is not set, "\
44+
"will attempt to connect to the latest active kernel gateway with name '$SM_STUDIO_KGW_NAME' in the region. "\
45+
"Run 'sm-local-ssh-ide set-user-profile-name' to override."
46+
fi
2647

2748
INSTANCE_ID=$(python <<EOF
2849
import sagemaker; from sagemaker_ssh_helper.ide import SSHIDE;
2950
import logging; logging.basicConfig(level=logging.INFO);
30-
print(SSHIDE(None, None).get_kernel_instance_ids("$SM_STUDIO_KGW_NAME", timeout_in_sec=300)[0])
51+
SSHIDE("$DOMAIN_ID", "$USER_PROFILE_NAME").print_kernel_instance_id("$SM_STUDIO_KGW_NAME", timeout_in_sec=300)
3152
EOF
3253
)
3354

34-
# TODO: set-jb-license-server
35-
# TODO: check that it's not started with 'http'
36-
# replace with your JetBrains License Server host, or leave it as is if you don't use one
37-
JB_LICENSE_SERVER_HOST="jetbrains-license-server.example.com"
38-
39-
if [ -f ~/.sm-jb-license-server ]; then
40-
echo "sm-local-ssh-ide: ~/.sm-jb-license-server file with PyCharm license server host is already configured, skipping override"
41-
JB_LICENSE_SERVER_HOST="$(cat ~/.sm-jb-license-server)"
42-
else
43-
echo "sm-local-ssh-ide: Saving PyCharm License server host into ~/.sm-jb-license-server"
44-
echo "$JB_LICENSE_SERVER_HOST" > ~/.sm-jb-license-server
45-
fi
46-
47-
4855
if [[ "$OPTIONS" == "--ssh-only" ]]; then
56+
echo "sm-local-ssh-ide: Connecting only SSH to local port 10022 (got the flag --ssh-only)"
4957
shift
5058
shift
5159
shift
@@ -55,6 +63,11 @@ EOF
5563
-L localhost:10022:localhost:22 \
5664
$*
5765
else
66+
if [ -f ~/.sm-jb-license-server ]; then
67+
JB_LICENSE_SERVER_HOST="$(cat ~/.sm-jb-license-server)"
68+
fi
69+
70+
echo "sm-local-ssh-ide: Connecting SSH, VNC and Jupyter to local ports 10022, 5901 and 8889 (add --ssh-only flag to override)"
5871
shift
5972
shift
6073
sm-local-start-ssh "$INSTANCE_ID" \
@@ -65,6 +78,29 @@ EOF
6578
$*
6679
fi
6780

81+
elif [[ "$COMMAND" == "set-jb-license-server" ]]; then
82+
JB_LICENSE_SERVER_HOST="$2"
83+
84+
echo "sm-local-ssh-ide: Saving PyCharm License server host into ~/.sm-jb-license-server"
85+
echo "$JB_LICENSE_SERVER_HOST" > ~/.sm-jb-license-server
86+
87+
elif [[ "$COMMAND" == "set-domain-id" ]]; then
88+
DOMAIN_ID="$2"
89+
if [[ "$DOMAIN_ID" == "" ]]; then
90+
echo "sm-local-ssh-ide: ERROR: <domain-id> argument is expected"
91+
exit 1
92+
fi
93+
echo "sm-local-ssh-ide: Saving SageMaker Studio domain ID into ~/.sm-studio-domain-id"
94+
echo "$DOMAIN_ID" > ~/.sm-studio-domain-id
95+
96+
elif [[ "$COMMAND" == "set-user-profile-name" ]]; then
97+
USER_PROFILE_NAME="$2"
98+
if [[ "$USER_PROFILE_NAME" == "" ]]; then
99+
echo "sm-local-ssh-ide: ERROR: <user-profile-name> argument is expected"
100+
exit 1
101+
fi
102+
echo "sm-local-ssh-ide: Saving SageMaker Studio user profile name into ~/.sm-studio-user-profile-name"
103+
echo "$USER_PROFILE_NAME" > ~/.sm-studio-user-profile-name
68104

69105
elif [[ "$COMMAND" == "run-command" ]]; then
70106

sagemaker_ssh_helper/sm-ssh-ide

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
#!/bin/bash -l
22
# Very important to start with 'bash -l' - to escape SageMaker Studio notebook environment
33

4+
# Commands:
5+
# configure [--ssh-only]
6+
# set-local-user-id <local_user_id>
7+
# set-jb-license-server <jb-license-server-hostname-without-http>
8+
# set-vnc-password <123456>
9+
# get-user-profile-name
10+
# get-domain-id
11+
# get-metadata
12+
413
set -e
514

615
dir=$(dirname "$0")
@@ -97,7 +106,6 @@ elif [[ "$1" == "set-jb-license-server" ]]; then
97106
fi
98107

99108
JB_LICENSE_SERVER_HOST="$(cat ~/.sm-jb-license-server)"
100-
# TODO: check that it's not started with 'http'
101109

102110
if grep -q "$JB_LICENSE_SERVER_HOST" /etc/hosts; then
103111
echo "sm-ssh-ide: Skipping the update of /etc/hosts with PyCharm license server (already there)"

tests/test_functions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import os
3+
import subprocess
34
from datetime import timedelta
45
from pathlib import Path
56

@@ -14,6 +15,8 @@
1415
from sagemaker_ssh_helper.wrapper import SSHEnvironmentWrapper, SSHEstimatorWrapper, SSHModelWrapper
1516
from test_util import _create_bucket_if_doesnt_exist
1617

18+
logger = logging.getLogger('sagemaker-ssh-helper:test_functions')
19+
1720

1821
def test_ssm_role_from_arn():
1922
assert SSHEnvironmentWrapper.ssm_role_from_iam_arn("arn:aws:iam::012345678901:role/service-role/SageMakerRole") \
@@ -288,3 +291,16 @@ def test_model_repacking_default_entry_point_with_existing_model():
288291
logging.info("Model data: %s", model.repacked_model_data)
289292
assert model.repacked_model_data is not None # FIXME: not working
290293
# FIXME: SAGEMAKER_SUBMIT_DIRECTORY = file://source_dir/inference_hf_accelerate instead of /opt/ml/model/code
294+
295+
296+
def test_called_process_error_with_output():
297+
got_error = False
298+
try:
299+
# should fail, because we're not connected to a remote kernel
300+
subprocess.check_output("sm-local-ssh-ide run-command python --version".split(' '), stderr=subprocess.STDOUT)
301+
except subprocess.CalledProcessError as e:
302+
output = e.output.decode('latin1').strip()
303+
logger.info(f"Got error (expected): {output}")
304+
got_error = True
305+
assert output == "ssh: connect to host localhost port 10022: Connection refused"
306+
assert got_error

tests/test_ide.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -162,19 +162,6 @@ def test_notebook_instance():
162162
assert "Python 3.8" in python_version
163163

164164

165-
def test_called_process_error_with_output():
166-
got_error = False
167-
try:
168-
# should fail, because we're not connected to a remote kernel
169-
subprocess.check_output("sm-local-ssh-ide run-command python --version".split(' '), stderr=subprocess.STDOUT)
170-
except subprocess.CalledProcessError as e:
171-
output = e.output.decode('latin1').strip()
172-
logger.info(f"Got error (expected): {output}")
173-
got_error = True
174-
assert output == "ssh: connect to host localhost port 10022: Connection refused"
175-
assert got_error
176-
177-
178165
def test_studio_internet_free_mode(request):
179166
"""
180167
See https://docs.aws.amazon.com/sagemaker/latest/dg/studio-byoi.html
@@ -236,6 +223,48 @@ def test_studio_internet_free_mode(request):
236223
ide.delete_kernel_app("byoi-studio-app", wait=False)
237224

238225

226+
def test_studio_multiple_users(request):
227+
ide_ds = SSHIDE(request.config.getini('sagemaker_studio_domain'), 'test-data-science')
228+
ide_pt = SSHIDE(request.config.getini('sagemaker_studio_domain'), 'test-pytorch')
229+
230+
ide_ds.create_ssh_kernel_app(
231+
'ssh-test-user',
232+
image_name_or_arn='sagemaker-data-science-310-v1',
233+
instance_type='ml.m5.large',
234+
ssh_lifecycle_config='sagemaker-ssh-helper-dev',
235+
recreate=True
236+
)
237+
238+
# Give a head start
239+
time.sleep(60)
240+
241+
ide_pt.create_ssh_kernel_app(
242+
'ssh-test-user',
243+
image_name_or_arn='sagemaker-data-science-310-v1',
244+
instance_type='ml.m5.large',
245+
ssh_lifecycle_config='sagemaker-ssh-helper-dev',
246+
recreate=True
247+
)
248+
249+
# Give time for instance ID to propagate
250+
time.sleep(60)
251+
252+
studio_ids = ide_ds.get_kernel_instance_ids('ssh-test-user', timeout_in_sec=300)
253+
studio_id = studio_ids[0]
254+
255+
with SSMProxy(10022) as ssm_proxy:
256+
ssm_proxy.connect_to_ssm_instance(studio_id)
257+
258+
user_profile_name = ssm_proxy.run_command_with_output("sm-ssh-ide get-user-profile-name")
259+
user_profile_name = user_profile_name.decode('latin1')
260+
logger.info(f"Collected SageMaker Studio profile name: {user_profile_name}")
261+
262+
ide_ds.delete_kernel_app('ssh-test-user', wait=False)
263+
ide_pt.delete_kernel_app('ssh-test-user', wait=False)
264+
265+
assert "test-data-science" in user_profile_name
266+
267+
239268
def test_studio_notebook_in_firefox(request):
240269
ide = SSHIDE(request.config.getini('sagemaker_studio_domain'), 'test-data-science')
241270

0 commit comments

Comments
 (0)