Skip to content

Commit ee8f12b

Browse files
Better support for popular SageMaker Studio images
1 parent b4eb468 commit ee8f12b

File tree

3 files changed

+206
-52
lines changed

3 files changed

+206
-52
lines changed

sagemaker_ssh_helper/ide.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
import logging
2+
import time
3+
4+
import boto3
5+
from botocore.exceptions import ClientError
6+
7+
from sagemaker_ssh_helper.log import SSHLog
8+
from sagemaker_ssh_helper.manager import SSMManager
9+
10+
11+
class IDEAppStatus:
12+
13+
def __init__(self, status, failure_reason) -> None:
14+
super().__init__()
15+
self.failure_reason = failure_reason
16+
self.status = status
17+
18+
def is_pending(self):
19+
return self.status == 'Pending'
20+
21+
def is_in_transition(self):
22+
return self.status == 'Deleting' or self.status == 'Pending'
23+
24+
def is_deleting(self):
25+
return self.status == 'Deleting'
26+
27+
def is_in_service(self):
28+
return self.status == 'InService'
29+
30+
def is_deleted(self):
31+
return self.status == 'Deleted'
32+
33+
def __str__(self) -> str:
34+
if self.failure_reason:
35+
return f"{self.status}, failure reason: {self.failure_reason}"
36+
return f"{self.status}"
37+
38+
39+
class SSHIDE:
40+
logger = logging.getLogger('sagemaker-ssh-helper:SSHIDE')
41+
42+
def __init__(self, domain: str, user: str, region_name: str = None):
43+
self.user = user
44+
self.domain = domain
45+
self.current_region = region_name or boto3.session.Session().region_name
46+
self.client = boto3.client('sagemaker', region_name=self.current_region)
47+
self.ssh_log = SSHLog(region_name=self.current_region)
48+
49+
def create_ssh_kernel_app(self, app_name: str,
50+
image_name='sagemaker-datascience-38',
51+
instance_type='ml.m5.xlarge',
52+
ssh_lifecycle_config='sagemaker-ssh-helper',
53+
recreate=False):
54+
"""
55+
Creates new kernel app with SSH lifecycle config (see kernel-lc-config.sh ).
56+
57+
Images: https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-images.html .
58+
59+
Note that doc is not always up-to-date and doesn't list full names,
60+
e.g., sagemaker-base-python-310 in the doc is sagemaker-base-python-310-v1 in the CreateApp API .
61+
62+
:param app_name:
63+
:param image_name: [name] from the images doc above
64+
:param instance_type:
65+
:param ssh_lifecycle_config:
66+
:param recreate:
67+
"""
68+
self.logger.info(f"Creating kernel app {app_name} with SSH lifecycle config {ssh_lifecycle_config}")
69+
self.log_urls(app_name)
70+
status = self.get_kernel_app_status(app_name)
71+
while status.is_in_transition():
72+
self.logger.info(f"Waiting for the final status. Current status: {status}")
73+
time.sleep(10)
74+
status = self.get_kernel_app_status(app_name)
75+
76+
self.logger.info(f"Previous app status: {status}")
77+
78+
if status.is_in_service():
79+
if recreate:
80+
self.delete_app(app_name, 'KernelGateway')
81+
else:
82+
raise ValueError(f"App {app_name} is in service, pass recreate=True to delete and create again.")
83+
84+
# Here status is None or 'Deleted' or 'Failed'. Safe to create
85+
86+
account_id = boto3.client('sts').get_caller_identity().get('Account')
87+
image_arn = self.resolve_sagemaker_kernel_image_arn(image_name)
88+
lifecycle_arn = f"arn:aws:sagemaker:{self.current_region}:{account_id}:" \
89+
f"studio-lifecycle-config/{ssh_lifecycle_config}"
90+
91+
self.create_app(app_name, 'KernelGateway', instance_type, image_arn, lifecycle_arn)
92+
93+
def get_kernel_app_status(self, app_name: str) -> IDEAppStatus:
94+
"""
95+
:param app_name:
96+
:return: None | 'InService' | 'Deleted' | 'Deleting' | 'Failed' | 'Pending'
97+
"""
98+
response = None
99+
try:
100+
response = self.client.describe_app(
101+
DomainId=self.domain,
102+
AppType='KernelGateway',
103+
UserProfileName=self.user,
104+
AppName=app_name,
105+
)
106+
except ClientError as e:
107+
error_code = e.response.get("Error", {}).get("Code")
108+
if error_code == 'ResourceNotFound':
109+
pass
110+
else:
111+
raise
112+
113+
status = None
114+
failure_reason = None
115+
if response:
116+
status = response['Status']
117+
if 'FailureReason' in response:
118+
failure_reason = response['FailureReason']
119+
return IDEAppStatus(status, failure_reason)
120+
121+
def delete_kernel_app(self, app_name, wait: bool = True):
122+
self.delete_app(app_name, 'KernelGateway', wait)
123+
124+
def delete_app(self, app_name, app_type, wait: bool = True):
125+
self.logger.info(f"Deleting app {app_name}")
126+
127+
try:
128+
_ = self.client.delete_app(
129+
DomainId=self.domain,
130+
AppType=app_type,
131+
UserProfileName=self.user,
132+
AppName=app_name,
133+
)
134+
except ClientError as e:
135+
# probably, already deleted
136+
code = e.response.get("Error", {}).get("Code")
137+
message = e.response.get("Error", {}).get("Message")
138+
self.logger.warning("ClientError code: " + code)
139+
self.logger.warning("ClientError message: " + message)
140+
if code == 'AccessDeniedException':
141+
raise
142+
return
143+
144+
status = self.get_kernel_app_status(app_name)
145+
while wait and status.is_deleting():
146+
self.logger.info(f"Waiting for the Deleted status. Current status: {status}")
147+
time.sleep(10)
148+
status = self.get_kernel_app_status(app_name)
149+
self.logger.info(f"Status after delete: {status}")
150+
if wait and not status.is_deleted():
151+
raise ValueError(f"Failed to delete app {app_name}. Status: {status}")
152+
153+
def create_app(self, app_name, app_type, instance_type, image_arn, lifecycle_arn: str = None):
154+
self.logger.info(f"Creating {app_type} app {app_name} on {instance_type} "
155+
f"with {image_arn} and lifecycle {lifecycle_arn}")
156+
resource_spec = {
157+
'InstanceType': instance_type,
158+
'SageMakerImageArn': image_arn,
159+
}
160+
if lifecycle_arn:
161+
resource_spec['LifecycleConfigArn'] = lifecycle_arn
162+
163+
_ = self.client.create_app(
164+
DomainId=self.domain,
165+
AppType=app_type,
166+
AppName=app_name,
167+
UserProfileName=self.user,
168+
ResourceSpec=resource_spec,
169+
)
170+
status = self.get_kernel_app_status(app_name)
171+
while status.is_pending():
172+
self.logger.info(f"Waiting for the InService status. Current status: {status}")
173+
time.sleep(10)
174+
status = self.get_kernel_app_status(app_name)
175+
176+
self.logger.info(f"New app status: {status}")
177+
178+
if not status.is_in_service():
179+
raise ValueError(f"Failed to create app {app_name}. Status: {status}")
180+
181+
def resolve_sagemaker_kernel_image_arn(self, image_name):
182+
sagemaker_account_id = "470317259841" # eu-west-1, TODO: check all images
183+
return f"arn:aws:sagemaker:{self.current_region}:{sagemaker_account_id}:image/{image_name}"
184+
185+
def get_kernel_instance_ids(self, app_name, timeout_in_sec):
186+
self.logger.info("Resolving IDE instance IDs through SSM tags")
187+
self.log_urls(app_name)
188+
# FIXME: resolve with domain and user
189+
result = SSMManager().get_studio_kgw_instance_ids(app_name, timeout_in_sec)
190+
return result
191+
192+
def log_urls(self, app_name):
193+
self.logger.info(f"Remote logs are at {self.get_cloudwatch_url(app_name)}")
194+
if self.domain and self.user:
195+
self.logger.info(f"Remote apps metadata is at {self.get_user_metadata_url()}")
196+
197+
def get_cloudwatch_url(self, app_name):
198+
return self.ssh_log.get_ide_cloudwatch_url(self.domain, self.user, app_name)
199+
200+
def get_user_metadata_url(self):
201+
return self.ssh_log.get_ide_metadata_url(self.domain, self.user)

sagemaker_ssh_helper/sm-ssh-ide

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ source "$dir"/sm-helper-functions
1212
_install_helper_scripts
1313

1414
# TODO: install into another separate venv (not to pollute main venv, both for conda and non-conda envs)
15+
# See custom conda env example here: https://repost.aws/knowledge-center/sagemaker-lifecycle-script-timeout .
1516
# TODO: determine python location only when needed
1617

1718
SM_STUDIO_PYTHON=$(/opt/.sagemakerinternal/conda/bin/python -c \
@@ -204,6 +205,10 @@ elif [[ "$1" == "env-diagnostics" ]]; then
204205
echo "SageMaker Studio Python location: $($0 get-studio-python-path || echo 'Not found')"
205206
echo "SageMaker Studio Python version: $($0 get-studio-python-version || echo 'Not found')"
206207

208+
# Should be the same as `jupyter kernelspec list` executed from SageMaker Studio Python
209+
echo "Jupyter Kernels: $(find $JUPYTER_PATH -name 'kernel.json')"
210+
find "$JUPYTER_PATH" -name 'kernel.json' -print0 | xargs -0 cat
211+
207212
elif [[ "$1" == "stop" ]]; then
208213

209214
pkill -ef amazon-ssm-agent || echo "sm-ssh-ide: SSM agent already stopped?"

tests/source_dir/training/train_tf_mme.py

Lines changed: 0 additions & 52 deletions
This file was deleted.

0 commit comments

Comments
 (0)