Skip to content

Commit 0f7192c

Browse files
Improved logging and optimized connect procedure
1 parent 7b0a1b3 commit 0f7192c

14 files changed

+507
-118
lines changed

sagemaker_ssh_helper/log.py

Lines changed: 91 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,19 @@
44
from datetime import datetime, timedelta
55

66
import boto3
7+
from botocore.exceptions import ClientError
8+
from sagemaker import Session
9+
10+
from sagemaker_ssh_helper.aws import AWS
711

812

913
class SSHLog:
10-
logger = logging.getLogger('sagemaker-ssh-helper')
14+
logger = logging.getLogger('sagemaker-ssh-helper:SSHLog')
1115

1216
def __init__(self, region_name=None) -> None:
1317
super().__init__()
14-
self.region_name = region_name
18+
self.region_name = region_name or Session().boto_region_name
19+
self.aws_console = AWS(self.region_name)
1520

1621
def get_ip_addresses(self, training_job_name, retry=0):
1722
SSHLog.logger.info(f"Querying SSH IP addresses for job {training_job_name}")
@@ -80,10 +85,11 @@ def get_ssm_instance_ids_once(self, log_group, stream_name):
8085

8186
def get_ssm_instance_ids(self, log_group, stream_name, retry=0, sleep_between_retries_seconds=10,
8287
expected_count=1):
88+
self.logger.info("Using AWS Region: %s", self.region_name)
8389
mi_ids = self.get_ssm_instance_ids_once(log_group, stream_name)
8490

8591
while not mi_ids and retry > 0:
86-
SSHLog.logger.info(f"SSH Helper not yet started? Retrying. Attempts left: {retry}")
92+
SSHLog.logger.info(f"SSH Helper not yet started on the remote? Retrying. Attempts left: {retry}")
8793
time.sleep(sleep_between_retries_seconds)
8894
mi_ids = self.get_ssm_instance_ids_once(log_group, stream_name)
8995
retry -= 1
@@ -102,12 +108,19 @@ def get_ssm_instance_ids(self, log_group, stream_name, retry=0, sleep_between_re
102108

103109
def _query_log_group(self, log_group, query):
104110
boto_client = boto3.client('logs', region_name=self.region_name)
105-
start_query_response = boto_client.start_query(
106-
logGroupName=log_group,
107-
startTime=int((datetime.now() - timedelta(weeks=2)).timestamp()),
108-
endTime=int(datetime.now().timestamp()),
109-
queryString=query
110-
)
111+
try:
112+
start_query_response = boto_client.start_query(
113+
logGroupName=log_group,
114+
startTime=int((datetime.now() - timedelta(weeks=2)).timestamp()),
115+
endTime=int(datetime.now().timestamp()),
116+
queryString=query
117+
)
118+
except ClientError as e:
119+
if e.response["Error"]["Code"] == "ResourceNotFoundException":
120+
return []
121+
else:
122+
raise
123+
111124
query_id = start_query_response['queryId']
112125
response = None
113126
while response is None or response['status'] == 'Running':
@@ -117,3 +130,72 @@ def _query_log_group(self, log_group, query):
117130
)
118131
lines = response['results']
119132
return lines
133+
134+
def get_training_cloudwatch_url(self, training_job_name):
135+
return f"https://{self.aws_console.get_console_domain()}/" \
136+
f"cloudwatch/home?region={self.region_name}#" \
137+
f"logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252FTrainingJobs$3F" \
138+
f"logStreamNameFilter$3D{training_job_name}$252F"
139+
140+
def get_training_metadata_url(self, training_job_name):
141+
return f"https://{self.aws_console.get_console_domain()}/" \
142+
f"sagemaker/home?region={self.region_name}#" \
143+
f"/jobs/{training_job_name}"
144+
145+
def get_endpoint_cloudwatch_url(self, endpoint_name):
146+
return f"https://{self.aws_console.get_console_domain()}/" \
147+
f"cloudwatch/home?region={self.region_name}#" \
148+
f"logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252FEndpoints$252F{endpoint_name}"
149+
150+
def get_endpoint_metadata_url(self, endpoint_name):
151+
return f"https://{self.aws_console.get_console_domain()}/" \
152+
f"sagemaker/home?region={self.region_name}#" \
153+
f"/endpoints/{endpoint_name}"
154+
155+
def get_endpoint_config_metadata_url(self, endpoint_config_name):
156+
return f"https://{self.aws_console.get_console_domain()}/" \
157+
f"sagemaker/home?region={self.region_name}#" \
158+
f"/endpointConfig/{endpoint_config_name}"
159+
160+
def get_model_metadata_url(self, model_name):
161+
return f"https://{self.aws_console.get_console_domain()}/" \
162+
f"sagemaker/home?region={self.region_name}#" \
163+
f"/models/{model_name}"
164+
165+
def get_processing_cloudwatch_url(self, processing_job_name):
166+
return f"https://{self.aws_console.get_console_domain()}/" \
167+
f"cloudwatch/home?region={self.region_name}#" \
168+
f"logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252FProcessingJobs$3F" \
169+
f"logStreamNameFilter$3D{processing_job_name}$252F"
170+
171+
def get_processing_metadata_url(self, processing_job_name):
172+
return f"https://{self.aws_console.get_console_domain()}/" \
173+
f"sagemaker/home?region={self.region_name}#" \
174+
f"/processing-jobs/{processing_job_name}"
175+
176+
def get_transform_cloudwatch_url(self, transform_job_name):
177+
return f"https://{self.aws_console.get_console_domain()}/" \
178+
f"cloudwatch/home?region={self.region_name}#" \
179+
f"logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252FTransformJobs$3F" \
180+
f"logStreamNameFilter$3D{transform_job_name}$252F"
181+
182+
def get_transform_metadata_url(self, transform_job_name):
183+
return f"https://{self.aws_console.get_console_domain()}/" \
184+
f"sagemaker/home?region={self.region_name}#" \
185+
f"/transform-jobs/{transform_job_name}"
186+
187+
def get_ide_cloudwatch_url(self, domain, user, app_name):
188+
if domain and user:
189+
return f"https://{self.aws_console.get_console_domain()}/" \
190+
f"cloudwatch/home?region={self.region_name}#" \
191+
f"logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252Fstudio" \
192+
f"$3FlogStreamNameFilter$3D{domain}$252F{user}$252FKernelGateway$252F{app_name}"
193+
return f"https://{self.aws_console.get_console_domain()}/" \
194+
f"cloudwatch/home?region={self.region_name}#" \
195+
f"logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252Fstudio" \
196+
f"$3FlogStreamNameFilter$3DKernelGateway$252F{app_name}"
197+
198+
def get_ide_metadata_url(self, domain, user):
199+
return f"https://{self.aws_console.get_console_domain()}/" \
200+
f"sagemaker/home?region={self.region_name}#" \
201+
f"/studio/{domain}/user/{user}"

sagemaker_ssh_helper/manager.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(self, region_name=None, sleep_between_retries_in_seconds=10, redo_a
1414
self.clock_timestamp_override = clock_timestamp_override
1515
self.redo_attempts = redo_attempts
1616
self.sleep_between_retries_in_seconds = sleep_between_retries_in_seconds
17-
self.region_name = region_name
17+
self.region_name = region_name or boto3.session.Session().region_name
1818

1919
def list_all_instances_with_tags(self) -> Dict[str, Dict[str, str]]:
2020
ssm = boto3.client('ssm', region_name=self.region_name)
@@ -68,6 +68,7 @@ def get_notebook_instance_ids(self, instance_name, timeout_in_sec=0):
6868
return self.get_instance_ids('notebook-instance', f"{instance_name}", timeout_in_sec)
6969

7070
def get_instance_ids_once(self, arn_resource_type, arn_resource_name):
71+
# TODO: use tag filter instead, for faster performance
7172
all_instances = self.list_all_instances_with_tags()
7273
result_pairs = []
7374
for mi_id in all_instances:
@@ -90,10 +91,14 @@ def get_instance_ids_once(self, arn_resource_type, arn_resource_name):
9091
def get_instance_ids(self, arn_resource_type, arn_resource_name,
9192
timeout_in_sec=0,
9293
expected_count=1):
94+
if arn_resource_name.startswith('mi-'):
95+
self.logger.warning("SageMaker resource name usually doesn't not start with 'mi-', "
96+
"did you pass the SSM instance ID by mistake?")
97+
self.logger.info("Using AWS Region: %s", self.region_name)
9398
mi_ids = self.get_instance_ids_once(arn_resource_type, arn_resource_name)
9499

95100
while not mi_ids and timeout_in_sec > 0:
96-
self.logger.info(f"SSH Helper not yet started? Retrying. Seconds left: {timeout_in_sec}")
101+
self.logger.info(f"SSM Agent not yet started on the remote? Retrying. Seconds left: {timeout_in_sec}")
97102
time.sleep(self.sleep_between_retries_in_seconds)
98103
mi_ids = self.get_instance_ids_once(arn_resource_type, arn_resource_name)
99104
timeout_in_sec -= self.sleep_between_retries_in_seconds

sagemaker_ssh_helper/proxy.py

Lines changed: 128 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,81 @@
11
import logging
22
import os
3+
import socket
34
import subprocess
5+
import sys
46
import time
57
from abc import ABC
8+
from queue import Queue, Empty
9+
from threading import Thread
10+
from typing import Optional
611

712
import psutil
813

914

1015
class SSMProxy(ABC):
1116
logger = logging.getLogger('sagemaker-ssh-helper')
1217

13-
def __init__(self, ssh_listen_port: int, extra_args: str = "", region_name: str = None) -> None:
18+
def __init__(self, ssh_listen_port: int, extra_args: str = "", region_name: str = None,
19+
cloudwatch_url: str = None) -> None:
1420
super().__init__()
15-
self.p = None
21+
self.cloudwatch_url = cloudwatch_url
22+
self.p: Optional[subprocess.Popen] = None
23+
self.q: Optional[Queue] = None
24+
self.t: Optional[Thread] = None
1625
self.region_name = region_name
1726
self.extra_args = extra_args
1827
self.ssh_listen_port = ssh_listen_port
1928

2029
def connect_to_ssm_instance(self, instance_id) -> None:
21-
self.logger.info(f"Connecting to {instance_id} with SSM and start SSH forwarding "
22-
f"on local port {self.ssh_listen_port} with extra args: '{self.extra_args}'")
30+
self.logger.info(
31+
f"Connecting to {instance_id} with SSM and starting SSH port forwarding "
32+
f"on local port {self.ssh_listen_port}"
33+
+ (f" with extra args: '{self.extra_args}'" if self.extra_args else '')
34+
)
2335

2436
env = os.environ.copy()
2537
if self.region_name:
26-
self.logger.info(f"Overriding default region: {self.region_name}")
38+
self.logger.info(f"Setting AWS Region for SSH: {self.region_name}")
2739
env["AWS_REGION"] = self.region_name
2840
env["AWS_DEFAULT_REGION"] = self.region_name
2941

42+
env["LC_ALL"] = "C"
43+
3044
# The script will create a new SSH key in ~/.ssh/sagemaker-ssh-gw
3145
# and transfer the public key ~/.ssh/sagemaker-ssh-gw.pub to the instance via S3
32-
self.p = subprocess.Popen(f"sm-local-start-ssh {instance_id}"
33-
f" -L localhost:{self.ssh_listen_port}:localhost:22"
34-
f" {self.extra_args}"
35-
" -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"
36-
.split(' '), env=env)
37-
38-
time.sleep(30) # allow 30 sec to initialize
39-
40-
self.logger.info(f"Getting remote Python version as a health check")
41-
42-
output = self.run_command_with_output("python --version 2>&1")
46+
self.p = subprocess.Popen(
47+
f"sm-local-start-ssh {instance_id}"
48+
f" -N -L localhost:{self.ssh_listen_port}:localhost:22"
49+
" -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"
50+
f" {self.extra_args}"
51+
.split(' '),
52+
env=env,
53+
stdout=subprocess.PIPE,
54+
stderr=subprocess.STDOUT,
55+
bufsize=0,
56+
close_fds=('posix' in sys.builtin_module_names)
57+
)
58+
59+
def enqueue_output(out, queue):
60+
for line in iter(out.readline, b''):
61+
queue.put(line)
62+
out.close()
63+
64+
#
65+
self.q = Queue()
66+
self.t = Thread(target=enqueue_output, args=(self.p.stdout, self.q))
67+
self.t.daemon = True # thread dies with the program
68+
self.t.start()
69+
70+
self.logger.info(f"Getting remote system information as a health check")
71+
72+
output = self.run_command_with_output("uname -a 2>&1")
4373
output_str = output.decode("latin1")
4474

4575
self.logger.info("Got output from the remote: " + output_str.replace("\n", " "))
4676

47-
if not output_str.startswith("Python"):
48-
raise AssertionError("Failed to get Python version")
77+
if not output_str.startswith("Linux"):
78+
raise ValueError("Failed to get system version. Got instead: " + output_str)
4979

5080
def terminate_waiting_loop(self):
5181
self.logger.info("Terminating the remote waiting loop / sleep process")
@@ -64,25 +94,78 @@ def terminate_waiting_loop(self):
6494
break
6595

6696
if retval != 0:
67-
raise AssertionError(f"Return value is not zero: {retval}. Do you need to you increase "
68-
f"'connection_wait_time' parameter?")
97+
raise ValueError(
98+
f"Return value is not zero: {retval}. Do you need to you increase "
99+
f"'connection_wait_time' parameter?"
100+
)
69101
self.logger.info("Successfully terminated the waiting loop")
70102

71103
def run_command(self, command):
72-
retval = subprocess.call(f"ssh root@localhost -p {self.ssh_listen_port}"
73-
" -i ~/.ssh/sagemaker-ssh-gw"
74-
" -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"
75-
f" {command}"
76-
.split(' '))
104+
retval = subprocess.call(
105+
f"ssh -4 root@localhost -p {self.ssh_listen_port}"
106+
" -i ~/.ssh/sagemaker-ssh-gw"
107+
" -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"
108+
f" {command}"
109+
.split(' '))
77110
return retval
78111

79112
def run_command_with_output(self, command):
80-
return subprocess.check_output(f"ssh root@localhost -p {self.ssh_listen_port}"
81-
" -i ~/.ssh/sagemaker-ssh-gw"
82-
" -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"
83-
" -o ConnectTimeout=10"
84-
f" {command}"
85-
.split(' '))
113+
self._wait_for_tcp_port()
114+
115+
try:
116+
# Pre-fetching the key to avoid the 'Warning: Permanently added ... to the list of known hosts' in output
117+
retval = os.system(f"ssh-keyscan -4 -H -p {self.ssh_listen_port} localhost >>~/.ssh/known_hosts") # nosec start_process_with_a_shell
118+
if retval != 0:
119+
self.logger.error(f"Failed to fetch host key. Return value is not zero: {retval}.")
120+
# No exception here, need to try the command anyway
121+
122+
env = os.environ.copy()
123+
env["LC_ALL"] = "C"
124+
125+
return subprocess.check_output(
126+
f"ssh -4 root@localhost -p {self.ssh_listen_port}"
127+
" -i ~/.ssh/sagemaker-ssh-gw"
128+
" -o PasswordAuthentication=no"
129+
" -o ConnectTimeout=10"
130+
f" {command}"
131+
.split(' '),
132+
stderr=subprocess.STDOUT,
133+
env=env
134+
)
135+
except subprocess.CalledProcessError as e:
136+
out = e.output.decode('latin1')
137+
proxy_out = self.fetch_proxy_output()
138+
raise ValueError(
139+
f"Failed to run command: {command}. "
140+
f"Return code: {e.returncode}. "
141+
f"\n---Begin proxy output:---\n{proxy_out}---End proxy output--- "
142+
f"\n---Begin output:---\n{out}---End output---. "
143+
f"Check your local log, stdout, and stderr "
144+
f"as well as remote logs{' at ' + self.cloudwatch_url if self.cloudwatch_url else ''} "
145+
f"for more details, if needed."
146+
) from e
147+
148+
def fetch_proxy_output(self):
149+
array_of_byte_strings = []
150+
while True:
151+
try:
152+
line = self.q.get(timeout=2)
153+
array_of_byte_strings += [line]
154+
except Empty:
155+
break
156+
proxy_out = "".join([x.decode('latin1') for x in array_of_byte_strings])
157+
return proxy_out
158+
159+
def _wait_for_tcp_port(self, timeout=45):
160+
# Use 127.0.0.1 here to avoid AF_INET6 resolution that can give errors
161+
self.logger.info(f"Connecting to 127.0.0.1:{self.ssh_listen_port}")
162+
for i in range(0, timeout):
163+
try:
164+
with socket.create_connection(("127.0.0.1", self.ssh_listen_port), 2):
165+
self.logger.info(f"Connection to 127.0.0.1:{self.ssh_listen_port} is successful")
166+
break
167+
except ConnectionRefusedError:
168+
time.sleep(1)
86169

87170
def disconnect(self):
88171
self.logger.info(f"Disconnecting proxy and stopping SSH port forwarding")
@@ -93,3 +176,17 @@ def disconnect(self):
93176
parent.terminate()
94177
except psutil.NoSuchProcess:
95178
pass
179+
180+
def __enter__(self, *args):
181+
"""
182+
Usage:
183+
184+
with SSMProxy(local_port) as ssm_proxy:
185+
ssm_proxy.connect_to_ssm_instance(instance_id)
186+
...
187+
188+
"""
189+
return self
190+
191+
def __exit__(self, *args):
192+
self.disconnect()

0 commit comments

Comments
 (0)