11import logging
22import os
3+ import socket
34import subprocess
5+ import sys
46import time
57from abc import ABC
8+ from queue import Queue , Empty
9+ from threading import Thread
10+ from typing import Optional
611
712import psutil
813
914
1015class SSMProxy (ABC ):
1116 logger = logging .getLogger ('sagemaker-ssh-helper' )
1217
13- def __init__ (self , ssh_listen_port : int , extra_args : str = "" , region_name : str = None ) -> None :
18+ def __init__ (self , ssh_listen_port : int , extra_args : str = "" , region_name : str = None ,
19+ cloudwatch_url : str = None ) -> None :
1420 super ().__init__ ()
15- self .p = None
21+ self .cloudwatch_url = cloudwatch_url
22+ self .p : Optional [subprocess .Popen ] = None
23+ self .q : Optional [Queue ] = None
24+ self .t : Optional [Thread ] = None
1625 self .region_name = region_name
1726 self .extra_args = extra_args
1827 self .ssh_listen_port = ssh_listen_port
1928
2029 def connect_to_ssm_instance (self , instance_id ) -> None :
21- self .logger .info (f"Connecting to { instance_id } with SSM and start SSH forwarding "
22- f"on local port { self .ssh_listen_port } with extra args: '{ self .extra_args } '" )
30+ self .logger .info (
31+ f"Connecting to { instance_id } with SSM and starting SSH port forwarding "
32+ f"on local port { self .ssh_listen_port } "
33+ + (f" with extra args: '{ self .extra_args } '" if self .extra_args else '' )
34+ )
2335
2436 env = os .environ .copy ()
2537 if self .region_name :
26- self .logger .info (f"Overriding default region : { self .region_name } " )
38+ self .logger .info (f"Setting AWS Region for SSH : { self .region_name } " )
2739 env ["AWS_REGION" ] = self .region_name
2840 env ["AWS_DEFAULT_REGION" ] = self .region_name
2941
42+ env ["LC_ALL" ] = "C"
43+
3044 # The script will create a new SSH key in ~/.ssh/sagemaker-ssh-gw
3145 # and transfer the public key ~/.ssh/sagemaker-ssh-gw.pub to the instance via S3
32- self .p = subprocess .Popen (f"sm-local-start-ssh { instance_id } "
33- f" -L localhost:{ self .ssh_listen_port } :localhost:22"
34- f" { self .extra_args } "
35- " -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"
36- .split (' ' ), env = env )
37-
38- time .sleep (30 ) # allow 30 sec to initialize
39-
40- self .logger .info (f"Getting remote Python version as a health check" )
41-
42- output = self .run_command_with_output ("python --version 2>&1" )
46+ self .p = subprocess .Popen (
47+ f"sm-local-start-ssh { instance_id } "
48+ f" -N -L localhost:{ self .ssh_listen_port } :localhost:22"
49+ " -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"
50+ f" { self .extra_args } "
51+ .split (' ' ),
52+ env = env ,
53+ stdout = subprocess .PIPE ,
54+ stderr = subprocess .STDOUT ,
55+ bufsize = 0 ,
56+ close_fds = ('posix' in sys .builtin_module_names )
57+ )
58+
59+ def enqueue_output (out , queue ):
60+ for line in iter (out .readline , b'' ):
61+ queue .put (line )
62+ out .close ()
63+
64+ #
65+ self .q = Queue ()
66+ self .t = Thread (target = enqueue_output , args = (self .p .stdout , self .q ))
67+ self .t .daemon = True # thread dies with the program
68+ self .t .start ()
69+
70+ self .logger .info (f"Getting remote system information as a health check" )
71+
72+ output = self .run_command_with_output ("uname -a 2>&1" )
4373 output_str = output .decode ("latin1" )
4474
4575 self .logger .info ("Got output from the remote: " + output_str .replace ("\n " , " " ))
4676
47- if not output_str .startswith ("Python " ):
48- raise AssertionError ("Failed to get Python version" )
77+ if not output_str .startswith ("Linux " ):
78+ raise ValueError ("Failed to get system version. Got instead: " + output_str )
4979
5080 def terminate_waiting_loop (self ):
5181 self .logger .info ("Terminating the remote waiting loop / sleep process" )
@@ -64,25 +94,78 @@ def terminate_waiting_loop(self):
6494 break
6595
6696 if retval != 0 :
67- raise AssertionError (f"Return value is not zero: { retval } . Do you need to you increase "
68- f"'connection_wait_time' parameter?" )
97+ raise ValueError (
98+ f"Return value is not zero: { retval } . Do you need to you increase "
99+ f"'connection_wait_time' parameter?"
100+ )
69101 self .logger .info ("Successfully terminated the waiting loop" )
70102
71103 def run_command (self , command ):
72- retval = subprocess .call (f"ssh root@localhost -p { self .ssh_listen_port } "
73- " -i ~/.ssh/sagemaker-ssh-gw"
74- " -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"
75- f" { command } "
76- .split (' ' ))
104+ retval = subprocess .call (
105+ f"ssh -4 root@localhost -p { self .ssh_listen_port } "
106+ " -i ~/.ssh/sagemaker-ssh-gw"
107+ " -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"
108+ f" { command } "
109+ .split (' ' ))
77110 return retval
78111
79112 def run_command_with_output (self , command ):
80- return subprocess .check_output (f"ssh root@localhost -p { self .ssh_listen_port } "
81- " -i ~/.ssh/sagemaker-ssh-gw"
82- " -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"
83- " -o ConnectTimeout=10"
84- f" { command } "
85- .split (' ' ))
113+ self ._wait_for_tcp_port ()
114+
115+ try :
116+ # Pre-fetching the key to avoid the 'Warning: Permanently added ... to the list of known hosts' in output
117+ retval = os .system (f"ssh-keyscan -4 -H -p { self .ssh_listen_port } localhost >>~/.ssh/known_hosts" ) # nosec start_process_with_a_shell
118+ if retval != 0 :
119+ self .logger .error (f"Failed to fetch host key. Return value is not zero: { retval } ." )
120+ # No exception here, need to try the command anyway
121+
122+ env = os .environ .copy ()
123+ env ["LC_ALL" ] = "C"
124+
125+ return subprocess .check_output (
126+ f"ssh -4 root@localhost -p { self .ssh_listen_port } "
127+ " -i ~/.ssh/sagemaker-ssh-gw"
128+ " -o PasswordAuthentication=no"
129+ " -o ConnectTimeout=10"
130+ f" { command } "
131+ .split (' ' ),
132+ stderr = subprocess .STDOUT ,
133+ env = env
134+ )
135+ except subprocess .CalledProcessError as e :
136+ out = e .output .decode ('latin1' )
137+ proxy_out = self .fetch_proxy_output ()
138+ raise ValueError (
139+ f"Failed to run command: { command } . "
140+ f"Return code: { e .returncode } . "
141+ f"\n ---Begin proxy output:---\n { proxy_out } ---End proxy output--- "
142+ f"\n ---Begin output:---\n { out } ---End output---. "
143+ f"Check your local log, stdout, and stderr "
144+ f"as well as remote logs{ ' at ' + self .cloudwatch_url if self .cloudwatch_url else '' } "
145+ f"for more details, if needed."
146+ ) from e
147+
148+ def fetch_proxy_output (self ):
149+ array_of_byte_strings = []
150+ while True :
151+ try :
152+ line = self .q .get (timeout = 2 )
153+ array_of_byte_strings += [line ]
154+ except Empty :
155+ break
156+ proxy_out = "" .join ([x .decode ('latin1' ) for x in array_of_byte_strings ])
157+ return proxy_out
158+
159+ def _wait_for_tcp_port (self , timeout = 45 ):
160+ # Use 127.0.0.1 here to avoid AF_INET6 resolution that can give errors
161+ self .logger .info (f"Connecting to 127.0.0.1:{ self .ssh_listen_port } " )
162+ for i in range (0 , timeout ):
163+ try :
164+ with socket .create_connection (("127.0.0.1" , self .ssh_listen_port ), 2 ):
165+ self .logger .info (f"Connection to 127.0.0.1:{ self .ssh_listen_port } is successful" )
166+ break
167+ except ConnectionRefusedError :
168+ time .sleep (1 )
86169
87170 def disconnect (self ):
88171 self .logger .info (f"Disconnecting proxy and stopping SSH port forwarding" )
@@ -93,3 +176,17 @@ def disconnect(self):
93176 parent .terminate ()
94177 except psutil .NoSuchProcess :
95178 pass
179+
180+ def __enter__ (self , * args ):
181+ """
182+ Usage:
183+
184+ with SSMProxy(local_port) as ssm_proxy:
185+ ssm_proxy.connect_to_ssm_instance(instance_id)
186+ ...
187+
188+ """
189+ return self
190+
191+ def __exit__ (self , * args ):
192+ self .disconnect ()
0 commit comments