Skip to content

Commit b364cb6

Browse files
authored
Add JumpStart Support in ModelTrainer (#1677)
* Add basic JumpStart support in ModelTrainer * Add handling of default training data * Add handling of model artifacts * Fix JumpStart defaults and add basic Integs * bug fixes * Update logs * Improve defaults logic * update tests and add error handling for tags * Add EULA message * remove force HP to string * Add better handling of defaults * fix * rename class * fix * fix ResourceConfig resolution * format * update logs and defaults * fix sagemaker_session in defaults * fix variant and default resolution * Account for SageMakerGatedS3Uri env var * remove SageMakerGatedModelS3Uri env var * fix condition in model artifacts * Add util method and update exception wording * use util
1 parent 0fbac6d commit b364cb6

File tree

18 files changed

+1083
-178
lines changed

18 files changed

+1083
-178
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dist/
1313
**/*.pyc
1414
**.pyc
1515
scratch*.py
16+
scratch/
1617
.eggs
1718
*.egg
1819
examples/tensorflow/distributed_mnist/data

sagemaker_train/src/sagemaker/train/configs.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
RemoteDebugConfig,
4343
SessionChainingConfig,
4444
InstanceGroup,
45+
HubAccessConfig,
46+
ModelAccessConfig,
4547
)
4648

4749
from sagemaker.train.utils import convert_unassigned_to_none
@@ -65,6 +67,8 @@
6567
"InstanceGroup",
6668
"TensorBoardOutputConfig",
6769
"CheckpointConfig",
70+
"HubAccessConfig",
71+
"ModelAccessConfig",
6872
"Compute",
6973
"Networking",
7074
"InputData",
@@ -85,7 +89,8 @@ class SourceCode(BaseConfig):
8589
8690
Parameters:
8791
source_dir (Optional[str]):
88-
The local directory containing the source code to be used in the training job container.
92+
The local directory, s3 uri, or path to tar.gz file stored locally or in s3 that
93+
contains the source code to be used in the training job container.
8994
requirements (Optional[str]):
9095
The path within ``source_dir`` to a ``requirements.txt`` file. If specified, the listed
9196
requirements will be installed in the training job container.
@@ -103,6 +108,29 @@ class SourceCode(BaseConfig):
103108
command: Optional[str] = None
104109

105110

111+
class OutputDataConfig(shapes.OutputDataConfig):
112+
"""OutputDataConfig.
113+
114+
Provides the configuration for the output data location of the training job.
115+
116+
Parameters:
117+
s3_output_path (Optional[str]):
118+
The S3 URI where the output data will be stored. This is the location where the
119+
training job will save its output data, such as model artifacts and logs.
120+
kms_key_id (Optional[str]):
121+
The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that
122+
SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side
123+
encryption.
124+
compression_type (Optional[str]):
125+
The model output compression type. Select None to output an uncompressed model,
126+
recommended for large model outputs. Defaults to gzip.
127+
"""
128+
129+
s3_output_path: Optional[str] = None
130+
kms_key_id: Optional[str] = None
131+
compression_type: Optional[str] = None
132+
133+
106134
class Compute(shapes.ResourceConfig):
107135
"""Compute.
108136
@@ -149,8 +177,12 @@ def _to_resource_config(self) -> shapes.ResourceConfig:
149177
compute_config_dict = self.model_dump()
150178
resource_config_fields = set(shapes.ResourceConfig.__annotations__.keys())
151179
filtered_dict = {
152-
k: v for k, v in compute_config_dict.items() if k in resource_config_fields
180+
k: v
181+
for k, v in compute_config_dict.items()
182+
if k in resource_config_fields and v is not None
153183
}
184+
if not filtered_dict:
185+
return None
154186
return shapes.ResourceConfig(**filtered_dict)
155187

156188

@@ -181,6 +213,8 @@ class Networking(shapes.VpcConfig):
181213
algorithm in distributed training.
182214
"""
183215

216+
security_group_ids: Optional[list[str]] = None
217+
subnets: Optional[list[str]] = None
184218
enable_network_isolation: Optional[bool] = None
185219
enable_inter_container_traffic_encryption: Optional[bool] = None
186220

@@ -192,10 +226,12 @@ def _model_validator(self) -> "Networking":
192226
def _to_vpc_config(self) -> shapes.VpcConfig:
193227
"""Convert to a sagemaker_core.shapes.VpcConfig object."""
194228
compute_config_dict = self.model_dump()
195-
resource_config_fields = set(shapes.VpcConfig.__annotations__.keys())
229+
vpc_config_fields = set(shapes.VpcConfig.__annotations__.keys())
196230
filtered_dict = {
197-
k: v for k, v in compute_config_dict.items() if k in resource_config_fields
231+
k: v for k, v in compute_config_dict.items() if k in vpc_config_fields and v is not None
198232
}
233+
if not filtered_dict:
234+
return None
199235
return shapes.VpcConfig(**filtered_dict)
200236

201237

sagemaker_train/src/sagemaker/train/constants.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
from __future__ import absolute_import
1515
import os
1616

17-
DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge"
18-
1917
SM_CODE = "code"
2018
SM_CODE_CONTAINER_PATH = "/opt/ml/input/data/code"
2119

0 commit comments

Comments
 (0)