Skip to content

Commit 4f3570e

Browse files
Introducing SSMManager, a faster alternative to SSHLog for fetching instance IDs
1 parent 2364b7f commit 4f3570e

File tree

2 files changed

+311
-0
lines changed

2 files changed

+311
-0
lines changed

sagemaker_ssh_helper/manager.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import logging
2+
import time
3+
4+
import boto3
5+
from typing import Dict
6+
7+
8+
class SSMManager:
    """Find SSM managed instances registered by SageMaker SSH Helper via their tags.

    Instances are matched through the ``SSHResourceName`` / ``SSHResourceArn``
    tags and ordered by the ``SSHTimestamp`` tag. The connection status that
    SSM reports is stored next to the real tags under the synthetic key
    ``$__SSMManager__.PingStatus``.
    """

    logger = logging.getLogger('sagemaker-ssh-helper:SSMManager')

    # Pause between re-fetches while waiting for the remaining instances
    # of a multi-instance resource to show up.
    SLEEP_BETWEEN_REDO_ATTEMPTS_IN_SECONDS = 30

    def __init__(self, region_name=None, sleep_between_retries_in_seconds=10, redo_attempts=5,
                 clock_timestamp_override=None) -> None:
        """Store the lookup configuration; no AWS call is made here.

        :param region_name: region passed to ``boto3.client``; ``None`` lets boto3 pick its default.
        :param sleep_between_retries_in_seconds: pause between polls while no instance is found yet.
        :param redo_attempts: how many extra fetches to make while fewer instances
            than expected have been found.
        :param clock_timestamp_override: if not ``None``, used instead of the current
            time when computing expiration (handy for deterministic tests).
        """
        super().__init__()
        self.clock_timestamp_override = clock_timestamp_override
        self.redo_attempts = redo_attempts
        self.sleep_between_retries_in_seconds = sleep_between_retries_in_seconds
        self.region_name = region_name

    def list_all_instances_with_tags(self) -> Dict[str, Dict[str, str]]:
        """Return ``{instance_id: tags}`` for every SSM managed instance.

        Each tags dict holds the instance's SSM tags plus the synthetic
        ``$__SSMManager__.PingStatus`` entry. One ``list_tags_for_resource``
        call is issued per instance.
        """
        ssm = boto3.client('ssm', region_name=self.region_name)

        result = {}
        request = {
            'Filters': [{'Key': 'ResourceType', 'Values': ['ManagedInstance']}],
            'MaxResults': 50,
        }
        while True:
            response = ssm.describe_instance_information(**request)
            for info in response.get('InstanceInformationList') or []:
                instance_id = info['InstanceId']
                tags = ssm.list_tags_for_resource(ResourceType='ManagedInstance', ResourceId=instance_id)
                tags_dict = {tag['Key']: tag['Value'] for tag in tags.get('TagList', [])}
                tags_dict['$__SSMManager__.PingStatus'] = info['PingStatus']
                result[instance_id] = tags_dict

            # NextToken is only sent on follow-up pages; a missing or empty
            # token marks the last page, so the loop cannot spin forever.
            next_token = response.get('NextToken')
            if not next_token:
                break
            request['NextToken'] = next_token

        return result

    def get_training_instance_ids(self, training_job_name, timeout_in_sec=0, expected_count=1):
        """Return the instance IDs tagged for the given training job, newest first."""
        self.logger.info(f"Querying SSM instance IDs for training job {training_job_name}, "
                         f"expected instance count = {expected_count}")
        return self.get_instance_ids('training-job', training_job_name, timeout_in_sec,
                                     expected_count)

    def get_processing_instance_ids(self, processing_job_name, timeout_in_sec=0):
        """Return the instance IDs tagged for the given processing job, newest first."""
        self.logger.info(f"Querying SSM instance IDs for processing job {processing_job_name}")
        return self.get_instance_ids('processing-job', processing_job_name, timeout_in_sec)

    def get_endpoint_instance_ids(self, endpoint_name, timeout_in_sec=0):
        """Not implemented yet; always raises ``AssertionError``."""
        raise AssertionError("Not supported yet.")

    def get_transformer_instance_ids(self, transform_job_name, timeout_in_sec=0):
        """Return the instance IDs tagged for the given transform job, newest first."""
        self.logger.info(f"Querying SSM instance IDs for transform job {transform_job_name}")
        return self.get_instance_ids('transform-job', transform_job_name, timeout_in_sec)

    def get_studio_kgw_instance_ids(self, kgw_name, timeout_in_sec=0):
        """Return the instance IDs tagged for the given Studio kernel gateway app, newest first."""
        self.logger.info(f"Querying SSM instance IDs for SageMaker Studio kernel gateway {kgw_name}")
        return self.get_instance_ids('app', kgw_name, timeout_in_sec)

    def get_notebook_instance_ids(self, instance_name, timeout_in_sec=0):
        """Return the instance IDs tagged for the given notebook instance, newest first."""
        self.logger.info(f"Querying SSM instance IDs for SageMaker notebook instance {instance_name}")
        return self.get_instance_ids('notebook-instance', instance_name, timeout_in_sec)

    @staticmethod
    def _ssh_timestamp(tags):
        """Return the ``SSHTimestamp`` tag as an int; 0 when missing or not a number."""
        try:
            return int(tags.get("SSHTimestamp", 0))
        except (TypeError, ValueError):
            return 0

    def get_instance_ids_once(self, arn_resource_type, arn_resource_name):
        """Fetch all instances once and return the IDs matching the resource.

        An instance matches when its ``SSHResourceName`` tag equals
        ``arn_resource_name`` and its ``SSHResourceArn`` tag contains both
        ``:{arn_resource_type}/`` and ``/{arn_resource_name}``.

        :return: matching instance IDs sorted by ``SSHTimestamp``, most recent first.
        """
        all_instances = self.list_all_instances_with_tags()
        result_pairs = []
        for mi_id, tags in all_instances.items():
            if "SSHResourceName" not in tags or "SSHResourceArn" not in tags:
                continue
            resource_arn = tags["SSHResourceArn"]
            if (f"/{arn_resource_name}" in resource_arn
                    and arn_resource_name == tags["SSHResourceName"]
                    and f":{arn_resource_type}/" in resource_arn):
                # Tag values arrive as strings: compare them numerically so
                # that "10" sorts after "9" and a missing tag (0) can be mixed in.
                result_pairs.append((mi_id, self._ssh_timestamp(tags)))

        # The sort is stable, so instances with equal timestamps keep listing order.
        result_pairs.sort(key=lambda pair: pair[1], reverse=True)
        return [mi_id for mi_id, _ in result_pairs]

    def get_instance_ids(self, arn_resource_type, arn_resource_name,
                         timeout_in_sec=0,
                         expected_count=1):
        """Return the instance IDs for a resource, polling until some are found.

        First polls every ``sleep_between_retries_in_seconds`` until at least one
        instance shows up or ``timeout_in_sec`` is used up. Then, while fewer than
        ``expected_count`` instances are known, re-fetches up to ``redo_attempts``
        more times.

        :return: instance IDs, most recent first; may hold fewer than
            ``expected_count`` entries (or none) when the waiting budget runs out.
        """
        mi_ids = self.get_instance_ids_once(arn_resource_type, arn_resource_name)

        while not mi_ids and timeout_in_sec > 0:
            self.logger.info(f"SSH Helper not yet started? Retrying. Seconds left: {timeout_in_sec}")
            time.sleep(self.sleep_between_retries_in_seconds)
            mi_ids = self.get_instance_ids_once(arn_resource_type, arn_resource_name)
            timeout_in_sec -= self.sleep_between_retries_in_seconds

        self.logger.info(f"Got preliminary SSM instance IDs: {mi_ids}")

        redo_attempts = self.redo_attempts
        # noinspection DuplicatedCode
        while len(mi_ids) < expected_count and redo_attempts > 0:
            self.logger.info(f"Re-fetch results for other instances to catchup. Attempts left: {redo_attempts}")
            time.sleep(self.SLEEP_BETWEEN_REDO_ATTEMPTS_IN_SECONDS)
            mi_ids = self.get_instance_ids_once(arn_resource_type, arn_resource_name)
            redo_attempts -= 1

        self.logger.info(f"Got final SSM instance IDs: {mi_ids}")
        return mi_ids

    def list_expired_ssh_instances(self, expiration_days=0):
        """Return IDs of instances that are not online and older than the cut-off.

        The cut-off is the current time (or ``clock_timestamp_override``) minus
        ``expiration_days``. An instance without the ping status entry is treated
        as online and never reported; an instance without ``SSHTimestamp`` is
        treated as having timestamp 0.

        :raises ValueError: if an offline instance has a non-numeric ``SSHTimestamp`` tag.
        """
        all_instances = self.list_all_instances_with_tags()
        self.logger.info("Found %s instances in SSM", len(all_instances))

        # The cut-off does not depend on the instance, so compute it once.
        if self.clock_timestamp_override is not None:
            now = self.clock_timestamp_override
        else:
            now = int(round(time.time()))
        expiration_timestamp = now - expiration_days * 3600 * 24

        expired_instances = []
        for mi_id, tags in all_instances.items():
            if tags.get("$__SSMManager__.PingStatus", "Online") == "Online":
                continue
            timestamp = int(tags["SSHTimestamp"]) if "SSHTimestamp" in tags else 0
            if timestamp < expiration_timestamp:
                expired_instances.append(mi_id)
                self.logger.info("Found expired offline SSH instance %s with timestamp %s", mi_id, timestamp)

        self.logger.info("Found %s expired offline SSH instances", len(expired_instances))
        return expired_instances

tests/test_ssm_manager.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import logging
2+
3+
from mock.mock import Mock
4+
5+
from sagemaker_ssh_helper.manager import SSMManager
6+
7+
# Module-level logger for this test module.
logger = logging.getLogger('sagemaker-ssh-helper')
8+
9+
10+
def test_can_fetch_instance_by_name():
    """Instances are selected by resource type and name from a stubbed instance listing."""
    kgw_name = "sagemaker-data-science-ml-m5-large-1234567890abcdef0"

    def ssh_tags(resource_name, resource_arn, timestamp):
        # Every tagged fixture instance carries the same five tags.
        return {
            "SSHResourceName": resource_name,
            "SSHResourceArn": resource_arn,
            "SSHCreator": "",
            "SSHOwner": "",
            "SSHTimestamp": timestamp,
        }

    fake_listing = {
        "mi-01234567890abcd00": {},
        "mi-01234567890abcd01": ssh_tags(
            "ssh-job",
            "arn:aws:sagemaker:eu-west-1:555555555555:training-job/ssh-job",
            1677072061,
        ),
        "mi-01234567890abcd02": ssh_tags(
            "ssh-job",
            "arn:aws:sagemaker:eu-west-1:555555555555:training-job/ssh-job",
            1677072061,
        ),
        "mi-01234567890abcd03": ssh_tags(
            "ssh-job",
            "arn:aws:sagemaker:eu-west-1:555555555555:processing-job/ssh-job",
            1677071209,
        ),
        "mi-01234567890abcd04": ssh_tags(
            "ssh-job",
            "arn:aws:sagemaker:eu-west-1:555555555555:transform-job/ssh-job",
            1677069966,
        ),
        "mi-01234567890abcd05": ssh_tags("", "", 1677073892),
        "mi-01234567890abcd06": ssh_tags(
            "ssh-training",
            "arn:aws:sagemaker:eu-west-1:555555555555:training-job/ssh-training",
            1677077641,
        ),
        "mi-01234567890abcd07": ssh_tags(
            "sagemaker-data-science-ml-m5-large-1234567890abcdef0",
            "arn:aws:sagemaker:eu-west-1:555555555555:app/d-0123456789ab/default-1111111111111/KernelGateway/sagemaker-data-science-ml-m5-large-1234567890abcdef0",
            1677077641,
        ),
    }

    manager = SSMManager(redo_attempts=0)
    manager.list_all_instances_with_tags = Mock(return_value=fake_listing)

    # Two instances share the training job name.
    training_ids = manager.get_training_instance_ids("ssh-job", expected_count=2)
    assert len(training_ids) == 2
    assert training_ids[0] == "mi-01234567890abcd01"
    assert training_ids[1] == "mi-01234567890abcd02"

    processing_ids = manager.get_processing_instance_ids("ssh-job")
    assert len(processing_ids) == 1
    assert processing_ids[0] == "mi-01234567890abcd03"

    transform_ids = manager.get_transformer_instance_ids("ssh-job")
    assert len(transform_ids) == 1
    assert transform_ids[0] == "mi-01234567890abcd04"

    kgw_ids = manager.get_studio_kgw_instance_ids(kgw_name)
    assert len(kgw_ids) == 1
    assert kgw_ids[0] == "mi-01234567890abcd07"

    # The right name under the wrong resource type must not match.
    assert len(manager.get_training_instance_ids(kgw_name)) == 0
    assert len(manager.get_studio_kgw_instance_ids("ssh-job")) == 0
87+
88+
89+
def test_instances_sorted_by_lru():
    """Matching instances come back ordered by SSHTimestamp, highest first."""
    kgw_name = "sagemaker-data-science-ml-m5-large-1234567890abcdef0"
    kgw_arn = "arn:aws:sagemaker:eu-west-1:555555555555:app/d-0123456789ab/default-1111111111111/KernelGateway/sagemaker-data-science-ml-m5-large-1234567890abcdef0"

    # Same resource on every instance; only the timestamp differs.
    timestamps_by_instance = [
        ("mi-01234567890abcd07", 2),
        ("mi-01234567890abcd08", 3),
        ("mi-01234567890abcd09", 1),
    ]
    fake_listing = {}
    for instance_id, timestamp in timestamps_by_instance:
        fake_listing[instance_id] = {
            "SSHResourceName": kgw_name,
            "SSHResourceArn": kgw_arn,
            "SSHCreator": "",
            "SSHOwner": "",
            "SSHTimestamp": timestamp,
        }

    manager = SSMManager(redo_attempts=0)
    manager.list_all_instances_with_tags = Mock(return_value=fake_listing)

    found = manager.get_studio_kgw_instance_ids(kgw_name)
    assert len(found) == 3
    assert found[0] == "mi-01234567890abcd08"
    assert found[1] == "mi-01234567890abcd07"
    assert found[2] == "mi-01234567890abcd09"
120+
121+
122+
def test_can_fetch_instances_from_default_region():
    """Smoke test: listing instances without an explicit region does not raise.

    The listing method is not stubbed here, so the real implementation runs.
    """
    default_region_manager = SSMManager(redo_attempts=0)
    default_region_manager.list_all_instances_with_tags()
125+
126+
127+
def test_can_fetch_instances_from_another_region():
    """Smoke test: listing instances with an explicit region ("eu-west-2") does not raise.

    The listing method is not stubbed here, so the real implementation runs.
    """
    other_region_manager = SSMManager(region_name="eu-west-2", redo_attempts=0)
    other_region_manager.list_all_instances_with_tags()
130+
131+
132+
def test_can_filter_instances_by_timestamp():
    """Only offline instances older than the expiration cut-off are reported.

    The clock is pinned to 1677158462, so with a one-day expiration the
    cut-off is 1677072062.
    """
    lost = "ConnectionLost"

    def offline_job_tags(job_name, timestamp):
        # Tags of an offline training-job instance.
        return {
            "SSHResourceName": job_name,
            "SSHResourceArn": f"arn:aws:sagemaker:eu-west-1:555555555555:training-job/{job_name}",
            "SSHCreator": "",
            "SSHOwner": "",
            "SSHTimestamp": timestamp,
            "$__SSMManager__.PingStatus": lost,
        }

    fake_listing = {
        # No tags at all.
        "mi-01234567890abcd00": {},
        # Offline, no timestamp tag.
        "mi-01234567890abcd01": {"SSHOwner": "", "$__SSMManager__.PingStatus": lost},
        # Online.
        "mi-01234567890abcd02": {"SSHOwner": "", "$__SSMManager__.PingStatus": "Online"},
        # Offline, no timestamp tag.
        "mi-01234567890abcd03": {"SSHOwner": "", "$__SSMManager__.PingStatus": lost},
        # Offline, one second before the cut-off.
        "mi-01234567890abcd04": offline_job_tags("ssh-job-1", 1677072061),
        # Offline, one second before the pinned clock.
        "mi-01234567890abcd05": offline_job_tags("ssh-job-2", 1677158461),
    }

    manager = SSMManager(redo_attempts=0, clock_timestamp_override=1677158462)
    manager.list_all_instances_with_tags = Mock(return_value=fake_listing)

    expired = manager.list_expired_ssh_instances(expiration_days=1)
    assert len(expired) == 3
    for expected_id in ("mi-01234567890abcd01", "mi-01234567890abcd03", "mi-01234567890abcd04"):
        assert expected_id in expired

0 commit comments

Comments
 (0)