Skip to content

Commit 7986335

Browse files
pintaoz-aws and pintaoz authored
Retrieve image uri from s3 (#1694)
* Retrieve image uri from s3
* Add intelligent default test
---------
Co-authored-by: pintaoz <pintaoz@amazon.com>
1 parent 7fc508a commit 7986335

File tree

5 files changed

+157
-16
lines changed

5 files changed

+157
-16
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
from __future__ import absolute_import
2+
3+
from sagemaker.utils.image_retriever.image_retriever import ImageRetriever # noqa: F401

sagemaker_utils/src/sagemaker/utils/image_retriever/image_retriever.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,27 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Functions for generating ECR image URIs for pre-built SageMaker Docker images."""
14+
from __future__ import absolute_import
15+
116
import re
217
from typing import Optional
318
from graphene.utils.str_converters import to_camel_case
419

5-
# TODO: Update these dependencies after they are moved to corresponding submodule
6-
from legacy.src.sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
7-
from legacy.src.sagemaker.training_compiler.config import TrainingCompilerConfig
20+
from sagemaker.utils.serverless_inference_config import ServerlessInferenceConfig
21+
from sagemaker.utils.training_compiler_config import TrainingCompilerConfig
822
from sagemaker.utils.utils import _botocore_resolver
923
from sagemaker.utils.workflow import is_pipeline_variable
10-
from image_retriever_utils import (
24+
from sagemaker.utils.image_retriever.image_retriever_utils import (
1125
_config_for_framework_and_scope,
1226
_get_final_image_scope,
1327
_get_image_tag,
@@ -26,7 +40,7 @@
2640
config_for_framework,
2741
)
2842
from sagemaker.utils.workflow.utilities import override_pipeline_parameter_var
29-
from sagemaker.utils.config.config_schema import IMAGE_RETRIEVER, MODULES, SAGEMAKER, _simple_path
43+
from sagemaker.utils.config.config_schema import IMAGE_RETRIEVER, MODULES, PYTHON_SDK, SAGEMAKER, _simple_path
3044
from sagemaker.utils.config.config_manager import SageMakerConfig
3145

3246
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
@@ -109,10 +123,13 @@ def retrieve_hugging_face_uri(
109123
str: The ECR URI for the corresponding SageMaker Docker image.
110124
"""
111125
args = dict(locals())
126+
config = SageMakerConfig()
112127
for name, val in args.items():
113128
if name in CONFIGURABLE_ATTRIBUTES and not val:
114-
default_value = SageMakerConfig.resolve_value_from_config(
115-
config_path=_simple_path(SAGEMAKER, MODULES, IMAGE_RETRIEVER, to_camel_case(name))
129+
default_value = config.resolve_value_from_config(
130+
config_path=_simple_path(
131+
SAGEMAKER, PYTHON_SDK, MODULES, IMAGE_RETRIEVER, to_camel_case(name)
132+
)
116133
)
117134
if default_value is not None:
118135
locals()[name] = default_value
@@ -497,10 +514,13 @@ def retrieve(
497514
DeprecatedJumpStartModelError: If the version of the model is deprecated.
498515
"""
499516
args = dict(locals())
517+
config = SageMakerConfig()
500518
for name, val in args.items():
501519
if name in CONFIGURABLE_ATTRIBUTES and not val:
502-
default_value = SageMakerConfig.resolve_value_from_config(
503-
config_path=_simple_path(SAGEMAKER, MODULES, IMAGE_RETRIEVER, to_camel_case(name))
520+
default_value = config.resolve_value_from_config(
521+
config_path=_simple_path(
522+
SAGEMAKER, PYTHON_SDK, MODULES, IMAGE_RETRIEVER, to_camel_case(name)
523+
)
504524
)
505525
if default_value is not None:
506526
locals()[name] = default_value

sagemaker_utils/src/sagemaker/utils/image_retriever/image_retriever_utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
"""Functions for generating ECR image URIs for pre-built SageMaker Docker images."""
1414
from __future__ import absolute_import
1515

16-
import json
1716
import logging
18-
import os
1917
from typing import Optional
2018
from packaging.version import Version
2119
import requests
@@ -44,6 +42,8 @@
4442
DATA_WRANGLER_FRAMEWORK = "data-wrangler"
4543
STABILITYAI_FRAMEWORK = "stabilityai"
4644
SAGEMAKER_TRITONSERVER_FRAMEWORK = "sagemaker-tritonserver"
45+
# TODO: Update to use a bucket in prod account when GA
46+
S3_URL_TEMPLATE = "https://image_uri_configs_beta.s3.amazonaws.com/{framework}.json"
4747

4848

4949
def _get_image_tag(
@@ -185,11 +185,13 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
185185

186186
def config_for_framework(framework):
    """Loads the JSON config for the given framework.

    The config is downloaded over HTTPS from the URL produced by formatting
    ``S3_URL_TEMPLATE`` with ``framework``.

    Args:
        framework (str): Framework name substituted into ``S3_URL_TEMPLATE``.

    Returns:
        The decoded JSON body of the HTTP response.

    Raises:
        requests.exceptions.RequestException: If the request fails, times out,
            or the server responds with an HTTP error status. The error is
            logged and then re-raised unchanged.
    """
    s3_url = S3_URL_TEMPLATE.format(framework=framework)
    try:
        # requests applies no timeout by default, so an unresponsive endpoint
        # would otherwise block the caller forever: (connect, read) seconds.
        response = requests.get(s3_url, timeout=(5, 30))
        # Surface 4xx/5xx as HTTPError instead of trying to parse the error
        # body as if it were the framework config.
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as err:
        logging.getLogger(__name__).error(
            "Error retrieving config from S3 (%s): %s", s3_url, err
        )
        # Bare raise keeps the original traceback intact.
        raise
193195

194196

195197
def _get_final_image_scope(framework, instance_type, image_scope):

sagemaker_utils/tests/integ/image_retriever/__init__.py

Whitespace-only changes.
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from unittest.mock import patch
2+
import pytest
3+
4+
from sagemaker.utils.config.config_schema import (
5+
IMAGE_RETRIEVER,
6+
MODULES,
7+
PYTHON_SDK,
8+
SAGEMAKER,
9+
_simple_path,
10+
)
11+
from sagemaker.utils.image_retriever import ImageRetriever
12+
from sagemaker.utils.config.config_manager import SageMakerConfig
13+
14+
15+
@pytest.mark.integ
def test_retrieve_image_uri():
    """Check ``ImageRetriever.retrieve`` against four known image URIs."""
    # Case 1: two positional arguments only.
    clarify_uri = ImageRetriever.retrieve("clarify", "us-west-2")
    expected_clarify = (
        "306415355426.dkr.ecr.us-west-2.amazonaws.com/sagemaker-clarify-processing:1.0"
    )
    assert clarify_uri == expected_clarify

    # Case 2: sagemaker-distribution, all arguments passed by keyword.
    distribution_uri = ImageRetriever.retrieve(
        framework="sagemaker-distribution",
        image_scope="inference",
        instance_type="ml.g5.4xlarge",
        region="us-west-1",
    )
    expected_distribution = (
        "053634841547.dkr.ecr.us-west-1.amazonaws.com/sagemaker-distribution-prod:3.0.0-gpu"
    )
    assert distribution_uri == expected_distribution

    # Case 3: xgboost, mixing positional and keyword arguments.
    xgboost_uri = ImageRetriever.retrieve(
        "xgboost",
        "eu-west-1",
        version="0.90-1",
        instance_type="ml.m5.xlarge",
        image_scope="inference",
    )
    expected_xgboost = (
        "141502667606.dkr.ecr.eu-west-1.amazonaws.com/sagemaker-xgboost:0.90-1-cpu-py3"
    )
    assert xgboost_uri == expected_xgboost

    # Case 4: tensorflow training image with an explicit Python version.
    tensorflow_uri = ImageRetriever.retrieve(
        framework="tensorflow",
        region="us-west-2",
        version="2.3",
        py_version="py37",
        instance_type="ml.p4d.24xlarge",
        image_scope="training",
    )
    expected_tensorflow = (
        "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
        "tensorflow-training:2.3-gpu-py37-cu110-ubuntu18.04-v3"
    )
    assert tensorflow_uri == expected_tensorflow
56+
57+
58+
@pytest.mark.integ
def test_retrieve_pytorch_uri():
    """Check ``ImageRetriever.retrieve_pytorch_uri`` for a PyTorch 1.6 training image."""
    query = {
        "region": "us-west-2",
        "version": "1.6",
        "py_version": "py3",
        "instance_type": "ml.p4d.24xlarge",
        "image_scope": "training",
    }
    expected_uri = (
        "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
        "pytorch-training:1.6-gpu-py3-cu110-ubuntu18.04-v3"
    )

    actual_uri = ImageRetriever.retrieve_pytorch_uri(**query)

    assert actual_uri == expected_uri
71+
72+
73+
@pytest.mark.integ
def test_retrieve_hugging_face_uri():
    """Check ``ImageRetriever.retrieve_hugging_face_uri`` returns the full tagged URI."""
    image_uri = ImageRetriever.retrieve_hugging_face_uri(
        version="4.28.1",
        py_version="py310",
        instance_type="ml.p2.xlarge",
        region="us-east-1",
        image_scope="training",
        base_framework_version="pytorch2.0.0",
        container_version="cu110-ubuntu20.04",
    )
    # The two literals must sit inside one parenthesised expression so they
    # concatenate; otherwise the ":<tag>" half is a separate statement and the
    # assertion only compares against the untagged repository.
    assert image_uri == (
        "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training"
        ":2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04"
    )
86+
87+
88+
@pytest.mark.integ
def test_retrieve_base_python_image_uri():
    """Check ``ImageRetriever.retrieve_base_python_image_uri`` when called with no arguments."""
    expected_uri = "236514542706.dkr.ecr.us-west-2.amazonaws.com/sagemaker-base-python-310:1.0"

    actual_uri = ImageRetriever.retrieve_base_python_image_uri()

    assert actual_uri == expected_uri
92+
93+
94+
@pytest.mark.integ
@patch.object(SageMakerConfig, "resolve_value_from_config")
def test_retrieve_image_uri_intelligent_default(mock_load_config):
    """Check that ``retrieve`` fills omitted arguments from the (patched) config.

    ``SageMakerConfig.resolve_value_from_config`` is replaced by a fake that
    answers only for the ImageScope and InstanceType config paths.
    """
    image_scope_path = _simple_path(
        SAGEMAKER, PYTHON_SDK, MODULES, IMAGE_RETRIEVER, "ImageScope"
    )
    instance_type_path = _simple_path(
        SAGEMAKER, PYTHON_SDK, MODULES, IMAGE_RETRIEVER, "InstanceType"
    )

    def fake_resolve(config_path):
        # Any path other than the two below falls through to None.
        if config_path == image_scope_path:
            return "inference"
        return "ml.g5.4xlarge" if config_path == instance_type_path else None

    mock_load_config.side_effect = fake_resolve

    # Will get image_scope="inference" and instance_type="ml.g5.4xlarge" from intelligent default
    expected_uri = (
        "053634841547.dkr.ecr.us-west-1.amazonaws.com/sagemaker-distribution-prod:3.0.0-gpu"
    )
    actual_uri = ImageRetriever.retrieve(
        framework="sagemaker-distribution",
        region="us-west-1",
    )
    assert actual_uri == expected_uri

0 commit comments

Comments
 (0)