Skip to content

Commit c17957d

Browse files
Allow empty "" domain
1 parent 184dd6b commit c17957d

File tree

6 files changed

+139
-15
lines changed

6 files changed

+139
-15
lines changed

sagemaker_ssh_helper/ide.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,14 @@ def print_kernel_instance_id(self, app_name, timeout_in_sec, index: int = 0):
204204
def get_kernel_instance_ids(self, app_name, timeout_in_sec):
205205
self.logger.info("Resolving IDE instance IDs through SSM tags")
206206
self.log_urls(app_name)
207-
self.logger.info(f"Connect from local machine (with GUI and Jupyter): sm-local-ssh-ide connect {app_name}")
208-
self.logger.info(f"To connect with SSH only: sm-local-ssh-ide connect {app_name} --ssh-only")
209207
if self.domain_id and self.user:
210208
result = SSMManager().get_studio_user_kgw_instance_ids(self.domain_id, self.user, app_name, timeout_in_sec)
209+
elif self.user:
210+
self.logger.warning(f"Domain ID is not set. Will attempt to connect to the latest "
211+
f"active kernel gateway with the name {app_name} in the region {self.current_region} "
212+
f"for user profile {self.user}")
213+
result = SSMManager().get_studio_user_kgw_instance_ids("", self.user, app_name,
214+
timeout_in_sec)
211215
else:
212216
self.logger.warning(f"Domain ID or user profile name are not set. Will attempt to connect to the latest "
213217
f"active kernel gateway with the name {app_name} in the region {self.current_region}")

sagemaker_ssh_helper/log.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ def get_studio_kgw_ssm_instance_ids(self, kgw_name, timeout_in_sec=0):
7474
return self.get_ssm_instance_ids(f'/aws/sagemaker/studio', f"KernelGateway/{kgw_name}",
7575
timeout_in_sec=timeout_in_sec)
7676

77-
def get_instance_ids_once(self, arn_resource_type, arn_resource_name, arn_filter: str = None):
78-
if arn_filter:
77+
def get_instance_ids_once(self, arn_resource_type, arn_resource_name, arn_filter_regex: str = None):
78+
if arn_filter_regex:
7979
raise ValueError("Not supported for SSHLog")
8080
return self.get_ssm_instance_ids_once(log_group=arn_resource_type, stream_name=arn_resource_name)
8181

@@ -194,7 +194,7 @@ def get_transform_metadata_url(self, transform_job_name):
194194

195195
def get_ide_cloudwatch_url(self, domain, user, app_name):
196196
app_type = 'JupyterServer' if app_name == 'default' else 'KernelGateway'
197-
if domain and user:
197+
if user:
198198
return f"https://{self.aws_console.get_console_domain()}/" \
199199
f"cloudwatch/home?region={self.region_name}#" \
200200
f"logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252Fstudio" \

sagemaker_ssh_helper/manager.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import boto3
66
from typing import Dict
77

8+
import re
9+
810

911
class SSMManagerBase(ABC):
1012
logger = logging.getLogger('sagemaker-ssh-helper:SSMManagerBase')
@@ -20,12 +22,12 @@ def __init__(self, region_name: str = None,
2022
def get_instance_ids(self, arn_resource_type, arn_resource_name,
2123
timeout_in_sec=0,
2224
expected_count=1,
23-
arn_filter: str = None):
25+
arn_filter_regex: str = None):
2426
if arn_resource_name.startswith('mi-'):
2527
self.logger.warning("SageMaker resource name usually doesn't not start with 'mi-', "
2628
"did you pass the SSM instance ID by mistake?")
2729
self.logger.info("Using AWS Region: %s", self.region_name)
28-
mi_ids = self.get_instance_ids_once(arn_resource_type, arn_resource_name, arn_filter)
30+
mi_ids = self.get_instance_ids_once(arn_resource_type, arn_resource_name, arn_filter_regex)
2931

3032
while not mi_ids and timeout_in_sec > 0:
3133
self.logger.info(f"No instance IDs found. Retrying. Is SSM Agent running on the remote? "
@@ -47,7 +49,7 @@ def get_instance_ids(self, arn_resource_type, arn_resource_name,
4749
return mi_ids
4850

4951
@abstractmethod
50-
def get_instance_ids_once(self, arn_resource_type, arn_resource_name, arn_filter: str = None):
52+
def get_instance_ids_once(self, arn_resource_type, arn_resource_name, arn_filter_regex: str = None):
5153
raise NotImplementedError("Abstract method")
5254

5355

@@ -104,8 +106,12 @@ def get_transformer_instance_ids(self, transform_job_name, timeout_in_sec=0):
104106

105107
def get_studio_user_kgw_instance_ids(self, domain_id, user_profile_name, kgw_name, timeout_in_sec=0):
106108
self.logger.info(f"Querying SSM instance IDs for SageMaker Studio kernel gateway {kgw_name}")
109+
if not domain_id:
110+
arn_filter = f":app/.*/{user_profile_name}/"
111+
else:
112+
arn_filter = f":app/{domain_id}/{user_profile_name}/"
107113
return self.get_instance_ids('app', f"{kgw_name}", timeout_in_sec,
108-
arn_filter=f":app/{domain_id}/{user_profile_name}/")
114+
arn_filter_regex=arn_filter)
109115

110116
def get_studio_kgw_instance_ids(self, kgw_name, timeout_in_sec=0):
111117
self.logger.info(f"Querying SSM instance IDs for SageMaker Studio kernel gateway {kgw_name}")
@@ -116,7 +122,7 @@ def get_notebook_instance_ids(self, instance_name, timeout_in_sec=0):
116122
return self.get_instance_ids('notebook-instance', f"{instance_name}", timeout_in_sec)
117123

118124
def get_instance_ids_once(self, arn_resource_type, arn_resource_name,
119-
arn_filter: str = None):
125+
arn_filter_regex: str = None):
120126
# TODO: use tag filter instead, for faster performance
121127
all_instances = self.list_all_instances_with_tags()
122128
result_pairs = []
@@ -127,7 +133,7 @@ def get_instance_ids_once(self, arn_resource_type, arn_resource_name,
127133
if f"/{arn_resource_name}" in tags["SSHResourceArn"] and \
128134
arn_resource_name == tags["SSHResourceName"] and \
129135
f":{arn_resource_type}/" in tags["SSHResourceArn"] and \
130-
(not arn_filter or arn_filter in tags["SSHResourceArn"]):
136+
(not arn_filter_regex or re.search(arn_filter_regex, tags["SSHResourceArn"]) is not None):
131137
if "SSHTimestamp" in tags:
132138
timestamp = tags["SSHTimestamp"]
133139
else:

sagemaker_ssh_helper/sm-local-ssh-ide

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,14 @@ if [[ "$COMMAND" == "connect" ]]; then
3232
if [ -f ~/.sm-studio-domain-id ]; then
3333
DOMAIN_ID="$(cat ~/.sm-studio-domain-id)"
3434
else
35-
echo "sm-local-ssh-ide: WARNING: SageMaker Studio domain ID is not set, "\
36-
"will attempt to connect to the latest active kernel gateway with name '$SM_STUDIO_KGW_NAME' in the region. "\
35+
echo "sm-local-ssh-ide: WARNING: SageMaker Studio domain ID is not set."\
3736
"Run 'sm-local-ssh-ide set-domain-id' to override."
3837
fi
3938
USER_PROFILE_NAME=""
4039
if [ -f ~/.sm-studio-user-profile-name ]; then
4140
USER_PROFILE_NAME="$(cat ~/.sm-studio-user-profile-name)"
4241
else
43-
echo "sm-local-ssh-ide: WARNING: SageMaker Studio user profile name is not set, "\
44-
"will attempt to connect to the latest active kernel gateway with name '$SM_STUDIO_KGW_NAME' in the region. "\
42+
echo "sm-local-ssh-ide: WARNING: SageMaker Studio user profile name is not set."\
4543
"Run 'sm-local-ssh-ide set-user-profile-name' to override."
4644
fi
4745

tests/test_ide.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def test_studio_internet_free_mode(request):
227227
ide.delete_kernel_app("byoi-studio-app", wait=False)
228228

229229

230+
# noinspection DuplicatedCode
230231
def test_studio_multiple_users(request):
231232
ide_ds = SSHIDE(request.config.getini('sagemaker_studio_domain'), 'test-data-science')
232233
ide_pt = SSHIDE(request.config.getini('sagemaker_studio_domain'), 'test-pytorch')
@@ -269,6 +270,50 @@ def test_studio_multiple_users(request):
269270
assert "test-data-science" in user_profile_name
270271

271272

273+
# noinspection DuplicatedCode
274+
def test_studio_default_domain_multiple_users(request):
275+
ide_ds = SSHIDE(request.config.getini('sagemaker_studio_domain'), 'test-data-science')
276+
ide_pt = SSHIDE(request.config.getini('sagemaker_studio_domain'), 'test-pytorch')
277+
278+
ide_ds.create_ssh_kernel_app(
279+
'ssh-test-user',
280+
image_name_or_arn='sagemaker-data-science-310-v1',
281+
instance_type='ml.m5.large',
282+
ssh_lifecycle_config='sagemaker-ssh-helper-dev',
283+
recreate=True
284+
)
285+
286+
# Give a head start
287+
time.sleep(60)
288+
289+
ide_pt.create_ssh_kernel_app(
290+
'ssh-test-user',
291+
image_name_or_arn='sagemaker-data-science-310-v1',
292+
instance_type='ml.m5.large',
293+
ssh_lifecycle_config='sagemaker-ssh-helper-dev',
294+
recreate=True
295+
)
296+
297+
# Give time for instance ID to propagate
298+
time.sleep(60)
299+
300+
# An empty domain "" matches the given user profile in any domain, with the latest instance returned first; useful when switching between many AWS accounts that share the same profile name
301+
studio_ids = SSHIDE("", 'test-data-science').get_kernel_instance_ids('ssh-test-user', timeout_in_sec=300)
302+
studio_id = studio_ids[0]
303+
304+
with SSMProxy(10022) as ssm_proxy:
305+
ssm_proxy.connect_to_ssm_instance(studio_id)
306+
307+
user_profile_name = ssm_proxy.run_command_with_output("sm-ssh-ide get-user-profile-name")
308+
user_profile_name = user_profile_name.decode('latin1')
309+
logger.info(f"Collected SageMaker Studio profile name: {user_profile_name}")
310+
311+
ide_ds.delete_kernel_app('ssh-test-user', wait=False)
312+
ide_pt.delete_kernel_app('ssh-test-user', wait=False)
313+
314+
assert "test-data-science" in user_profile_name
315+
316+
272317
def test_studio_notebook_in_firefox(request):
273318
ide = SSHIDE(request.config.getini('sagemaker_studio_domain'), 'test-data-science')
274319

tests/test_ssm_manager.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,74 @@ def test_can_filter_instances_by_timestamp():
168168
assert "mi-01234567890abcd01" in ids
169169
assert "mi-01234567890abcd03" in ids
170170
assert "mi-01234567890abcd04" in ids
171+
172+
173+
# noinspection DuplicatedCode
174+
def test_can_filter_by_domain_and_user():
175+
manager = SSMManager(redo_attempts=0)
176+
manager.list_all_instances_with_tags = Mock(return_value={
177+
"mi-01234567890abcd07": {
178+
"SSHResourceName": "sagemaker-data-science-ml-m5-large-1234567890abcdef0",
179+
"SSHResourceArn": "arn:aws:sagemaker:eu-west-1:555555555555:app/d-0123456789ab/default-1111111111111/KernelGateway/sagemaker-data-science-ml-m5-large-1234567890abcdef0",
180+
"SSHCreator": "",
181+
"SSHOwner": "",
182+
"SSHTimestamp": 2
183+
},
184+
"mi-01234567890abcd08": {
185+
"SSHResourceName": "sagemaker-data-science-ml-m5-large-1234567890abcdef0",
186+
"SSHResourceArn": "arn:aws:sagemaker:eu-west-1:555555555555:app/d-0123456789bc/default-1111111111111/KernelGateway/sagemaker-data-science-ml-m5-large-1234567890abcdef0",
187+
"SSHCreator": "",
188+
"SSHOwner": "",
189+
"SSHTimestamp": 3
190+
},
191+
"mi-01234567890abcd09": {
192+
"SSHResourceName": "sagemaker-data-science-ml-m5-large-1234567890abcdef0",
193+
"SSHResourceArn": "arn:aws:sagemaker:eu-west-1:555555555555:app/d-0123456789ab/default-5555555555555/KernelGateway/sagemaker-data-science-ml-m5-large-1234567890abcdef0",
194+
"SSHCreator": "",
195+
"SSHOwner": "",
196+
"SSHTimestamp": 1
197+
},
198+
})
199+
200+
ids = manager.get_studio_user_kgw_instance_ids(
201+
"d-0123456789bc", "default-1111111111111",
202+
"sagemaker-data-science-ml-m5-large-1234567890abcdef0"
203+
)
204+
assert len(ids) == 1
205+
assert ids[0] == "mi-01234567890abcd08"
206+
207+
208+
# noinspection DuplicatedCode
209+
def test_can_filter_by_user_with_latest_domain():
210+
manager = SSMManager(redo_attempts=0)
211+
manager.list_all_instances_with_tags = Mock(return_value={
212+
"mi-01234567890abcd07": {
213+
"SSHResourceName": "sagemaker-data-science-ml-m5-large-1234567890abcdef0",
214+
"SSHResourceArn": "arn:aws:sagemaker:eu-west-1:555555555555:app/d-0123456789ab/default-1111111111111/KernelGateway/sagemaker-data-science-ml-m5-large-1234567890abcdef0",
215+
"SSHCreator": "",
216+
"SSHOwner": "",
217+
"SSHTimestamp": 2
218+
},
219+
"mi-01234567890abcd08": {
220+
"SSHResourceName": "sagemaker-data-science-ml-m5-large-1234567890abcdef0",
221+
"SSHResourceArn": "arn:aws:sagemaker:eu-west-1:555555555555:app/d-0123456789bc/default-1111111111111/KernelGateway/sagemaker-data-science-ml-m5-large-1234567890abcdef0",
222+
"SSHCreator": "",
223+
"SSHOwner": "",
224+
"SSHTimestamp": 3
225+
},
226+
"mi-01234567890abcd09": {
227+
"SSHResourceName": "sagemaker-data-science-ml-m5-large-1234567890abcdef0",
228+
"SSHResourceArn": "arn:aws:sagemaker:eu-west-1:555555555555:app/d-0123456789ab/default-5555555555555/KernelGateway/sagemaker-data-science-ml-m5-large-1234567890abcdef0",
229+
"SSHCreator": "",
230+
"SSHOwner": "",
231+
"SSHTimestamp": 1
232+
},
233+
})
234+
235+
ids = manager.get_studio_user_kgw_instance_ids(
236+
"", "default-1111111111111",
237+
"sagemaker-data-science-ml-m5-large-1234567890abcdef0"
238+
)
239+
assert len(ids) == 2
240+
assert ids[0] == "mi-01234567890abcd08"
241+
assert ids[1] == "mi-01234567890abcd07"

0 commit comments

Comments
 (0)