Skip to content

Commit 7b0a1b3

Browse files
Support attaching to already submitted training jobs with SSHEstimatorWrapper.attach(job_name)
1 parent ee8f12b commit 7b0a1b3

File tree

3 files changed

+109
-1
lines changed

3 files changed

+109
-1
lines changed

FAQ.md

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,27 @@ You might want (optionally) to configure [AWS PrivateLink for Session Manager en
153153
## API Questions
154154

155155
### I'm using boto3 Python SDK instead of SageMaker Python SDK, how can I use SageMaker SSH Helper?
156-
This is a tricky question. In short, this use case is not supported by SageMaker SSH Helper.
156+
This use case is not fully supported by SageMaker SSH Helper.
157157
However, [you can](https://repost.aws/questions/QU8-U_XgPVRSuLTSXf8eW8fA/can-we-connect-to-the-instance-via-ssh-or-other-means-where-a-triton-sagemaker-endpoint-is-deployed) analyze the source code and re-implement SageMaker SSH Helper behaviour with boto3, e.g., by passing environment variables from your code.
158+
158159
In general, this is not recommended, because the set of environment variables and internal logic is a subject to future changes. These changes won't necessarily appear in the release notes and can break your code.
159160

161+
However, if submitted a job with boto3 in this way and started SSH Helper inside the container, you can use high-level APIs to fetch instance IDs and connect to the job with `sm-local` scripts from your local machine:
162+
163+
```python
164+
import logging
165+
from sagemaker_ssh_helper.wrapper import SSHEstimatorWrapper
166+
167+
training_job_name = ...
168+
169+
ssh_wrapper = SSHEstimatorWrapper.attach(training_job_name)
170+
171+
instance_ids = ssh_wrapper.get_instance_ids()
172+
173+
logging.info(f"To connect over SSM run: aws ssm start-session --target {instance_ids[0]}")
174+
logging.info(f"To connect over SSH run: sm-local-ssh-training connect {ssh_wrapper.latest_training_job_name()}")
175+
```
176+
160177

161178
### How can I change the SSH authorized keys bucket and location when running `sm-local-ssh-*` commands?
162179
The **public** key is transferred to the container through the default SageMaker bucket with the S3 URI that looks
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import sagemaker
2+
from sagemaker import Session
3+
# noinspection PyProtectedMember
4+
from sagemaker.estimator import _TrainingJob
5+
6+
7+
class DetachedEstimator(sagemaker.estimator.EstimatorBase):
8+
f"""
9+
A sagemaker.estimator.Estimator that does not block on attach().
10+
"""
11+
12+
def __init__(self, training_job_name: str, sagemaker_session: Session):
13+
super().__init__(sagemaker_session=sagemaker_session, instance_count=0)
14+
self._current_job_name = training_job_name
15+
self.latest_training_job = _TrainingJob(sagemaker_session, self._current_job_name)
16+
17+
def training_image_uri(self):
18+
raise ValueError("Not implemented")
19+
20+
def hyperparameters(self):
21+
raise ValueError("Not implemented")
22+
23+
def create_model(self, **kwargs):
24+
raise ValueError("Not implemented")
25+
26+
@classmethod
27+
def attach(cls, training_job_name: str, sagemaker_session=None, model_channel_name="model"):
28+
if not isinstance(training_job_name, str):
29+
raise ValueError("training_job_name MUST be a string")
30+
# TODO: fetch job details and call _prepare_init_params_from_job_description()
31+
return DetachedEstimator(training_job_name, sagemaker_session)

tests/test_attach.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import logging
2+
import os
3+
from datetime import timedelta
4+
5+
import pytest
6+
from sagemaker import Session
7+
# noinspection PyProtectedMember
8+
from sagemaker.estimator import _TrainingJob
9+
from sagemaker.pytorch import PyTorch
10+
11+
from sagemaker_ssh_helper.detached_sagemaker import DetachedEstimator
12+
from sagemaker_ssh_helper.wrapper import SSHEstimatorWrapper
13+
14+
15+
# noinspection DuplicatedCode
16+
def test_attach_estimator():
17+
estimator = PyTorch(entry_point=os.path.basename('source_dir/training/train.py'),
18+
source_dir='source_dir/training/',
19+
dependencies=[SSHEstimatorWrapper.dependency_dir()],
20+
base_job_name='ssh-training',
21+
framework_version='1.9.1',
22+
py_version='py38',
23+
instance_count=1,
24+
instance_type='ml.m5.xlarge',
25+
max_run=int(timedelta(minutes=15).total_seconds()),
26+
keep_alive_period_in_seconds=1800,
27+
container_log_level=logging.INFO)
28+
29+
_ = SSHEstimatorWrapper.create(estimator, connection_wait_time_seconds=600)
30+
31+
estimator.fit(wait=False)
32+
33+
job: _TrainingJob = estimator.latest_training_job
34+
ssh_wrapper = SSHEstimatorWrapper.attach(job.name)
35+
36+
instance_ids = ssh_wrapper.get_instance_ids()
37+
38+
logging.info(f"To connect over SSM run: aws ssm start-session --target {instance_ids[0]}")
39+
logging.info(f"To connect over SSH run: sm-local-ssh-training connect {ssh_wrapper.latest_training_job_name()}")
40+
41+
ssh_wrapper.start_ssm_connection_and_continue(11022, 60)
42+
43+
ssh_wrapper.wait_training_job()
44+
45+
assert estimator.model_data.find("model.tar.gz") != -1
46+
47+
48+
def test_cannot_fit_detached_estimator():
49+
estimator = DetachedEstimator.attach('training-job-name', Session())
50+
51+
with pytest.raises(ValueError):
52+
_ = SSHEstimatorWrapper.create(estimator)
53+
54+
55+
def test_can_fetch_job_name_from_detached_estimator():
56+
ssh_wrapper = SSHEstimatorWrapper.attach('training-job-name', Session())
57+
58+
job_name = ssh_wrapper.latest_training_job_name()
59+
60+
assert job_name == 'training-job-name'

0 commit comments

Comments
 (0)