Skip to content

Commit 5e6e9e1

Browse files
pintaoz-awspintaoz
andauthored
Support intelligent defaults for ImageRetriever (#1679)
* Support intelligent defaults for ImageRetriever * Address comment * Delete config schema for other modules * move training configs from legacy * Update path --------- Co-authored-by: pintaoz <pintaoz@amazon.com>
1 parent a9a5072 commit 5e6e9e1

File tree

12 files changed

+870
-4
lines changed

12 files changed

+870
-4
lines changed

sagemaker_utils/src/sagemaker/utils/config/config_schema.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
TELEMETRY_OPT_OUT = "TelemetryOptOut"
118118
NOTEBOOK_JOB = "NotebookJob"
119119
MODEL_TRAINER = "ModelTrainer"
120+
IMAGE_RETRIEVER = "ImageRetriever"
120121

121122

122123
def _simple_path(*args: str):
@@ -666,6 +667,12 @@ def _simple_path(*args: str):
666667
},
667668
"baseJobName": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
668669
"sourceCode": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
670+
"version": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
671+
"py_version": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
672+
"instance_type": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
673+
"accelerator_type": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
674+
"image_scope": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
675+
"container_version": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
669676
"distributed": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
670677
"compute": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
671678
"networking": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
@@ -677,6 +684,14 @@ def _simple_path(*args: str):
677684
"trainingInputMode": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
678685
"environment": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
679686
"hyperparameters": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
687+
"smp": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
688+
"base_framework_version": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
689+
"training_compiler_config": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
690+
"model_id": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
691+
"model_version": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
692+
"sdk_version": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
693+
"inference_tool": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
694+
"serverless_inference_config": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
680695
},
681696
PROPERTIES: {
682697
SCHEMA_VERSION: {
@@ -731,6 +746,7 @@ def _simple_path(*args: str):
731746
},
732747
},
733748
MODEL_TRAINER: {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
749+
IMAGE_RETRIEVER: {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
734750
ESTIMATOR: {
735751
TYPE: OBJECT,
736752
ADDITIONAL_PROPERTIES: False,
@@ -1231,6 +1247,7 @@ def _simple_path(*args: str):
12311247
},
12321248
CONTAINER_CONFIG: {
12331249
TYPE: OBJECT,
1250+
PROPERTIES: {}
12341251
},
12351252
},
12361253
},
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Module for deprecation abstractions."""
14+
from __future__ import absolute_import
15+
16+
import logging
17+
import warnings
18+
19+
logger = logging.getLogger(__name__)
20+
21+
V2_URL = "https://sagemaker.readthedocs.io/en/stable/v2.html"
22+
23+
24+
def _warn(msg, sdk_version=None):
25+
"""Generic warning raiser referencing V2
26+
27+
Args:
28+
phrase: The phrase to include in the warning.
29+
sdk_version: the sdk version of removal of support.
30+
"""
31+
_sdk_version = sdk_version if sdk_version is not None else "2"
32+
full_msg = f"{msg} in sagemaker>={_sdk_version}.\nSee: {V2_URL} for details."
33+
warnings.warn(full_msg, DeprecationWarning, stacklevel=2)
34+
logger.warning(full_msg)
35+
36+
37+
def removed_warning(phrase, sdk_version=None):
38+
"""Raise a warning for a no-op in sagemaker>=2
39+
40+
Args:
41+
phrase: the prefix phrase of the warning message.
42+
sdk_version: the sdk version of removal of support.
43+
"""
44+
_warn(f"{phrase} is a no-op", sdk_version)
45+
46+
47+
def renamed_warning(phrase):
48+
"""Raise a warning for a rename in sagemaker>=2
49+
50+
Args:
51+
phrase: the prefix phrase of the warning message.
52+
"""
53+
_warn(f"{phrase} has been renamed")
54+
55+
56+
def deprecation_warn(name, date, msg=None):
57+
"""Raise a warning for soon to be deprecated feature in sagemaker>=2
58+
59+
Args:
60+
name (str): Name of the feature
61+
date (str): the date when the feature will be deprecated
62+
msg (str): the prefix phrase of the warning message.
63+
"""
64+
_warn(f"{name} will be deprecated on {date}.{msg}")
65+
66+
67+
def deprecation_warn_base(msg):
68+
"""Raise a warning for soon to be deprecated feature in sagemaker>=2
69+
70+
Args:
71+
msg (str): the warning message.
72+
"""
73+
_warn(msg)
74+
75+
76+
def deprecation_warning(date, msg=None):
77+
"""Decorator for raising deprecation warning for a feature in sagemaker>=2
78+
79+
Args:
80+
date (str): the date when the feature will be deprecated
81+
msg (str): the prefix phrase of the warning message.
82+
83+
Usage:
84+
@deprecation_warning(msg="message", date="date")
85+
def sample_function():
86+
print("xxxx....")
87+
88+
@deprecation_warning(msg="message", date="date")
89+
class SampleClass():
90+
def __init__(self):
91+
print("xxxx....")
92+
93+
"""
94+
95+
def deprecate(obj):
96+
def wrapper(*args, **kwargs):
97+
deprecation_warn(obj.__name__, date, msg)
98+
return obj(*args, **kwargs)
99+
100+
return wrapper
101+
102+
return deprecate
103+
104+
105+
def renamed_kwargs(old_name, new_name, value, kwargs):
106+
"""Checks if the deprecated argument is in kwargs
107+
108+
Raises warning, if present.
109+
110+
Args:
111+
old_name: name of deprecated argument
112+
new_name: name of the new argument
113+
value: value associated with new name, if supplied
114+
kwargs: keyword arguments dict
115+
116+
Returns:
117+
value of the keyword argument, if present
118+
"""
119+
if old_name in kwargs:
120+
value = kwargs.get(old_name, value)
121+
kwargs[new_name] = value
122+
renamed_warning(old_name)
123+
return value
124+
125+
126+
def removed_arg(name, arg):
127+
"""Checks if the deprecated argument is populated.
128+
129+
Raises warning, if not None.
130+
131+
Args:
132+
name: name of deprecated argument
133+
arg: the argument to check
134+
"""
135+
if arg is not None:
136+
removed_warning(name)
137+
138+
139+
def removed_kwargs(name, kwargs):
140+
"""Checks if the deprecated argument is in kwargs
141+
142+
Raises warning, if present.
143+
144+
Args:
145+
name: name of deprecated argument
146+
kwargs: keyword arguments dict
147+
"""
148+
if name in kwargs:
149+
removed_warning(name)
150+
151+
152+
def removed_function(name):
153+
"""A no-op deprecated function factory."""
154+
155+
def func(*args, **kwargs): # pylint: disable=W0613
156+
removed_warning(f"The function {name}")
157+
158+
return func
159+
160+
161+
def deprecated(sdk_version=None):
162+
"""Decorator for raising deprecated warning for a feature in sagemaker>=2
163+
164+
Args:
165+
sdk_version (str): the sdk version of removal of support.
166+
167+
Usage:
168+
@deprecated()
169+
def sample_function():
170+
print("xxxx....")
171+
172+
@deprecated(sdk_version="2.66")
173+
class SampleClass():
174+
def __init__(self):
175+
print("xxxx....")
176+
177+
"""
178+
179+
def deprecate(obj):
180+
def wrapper(*args, **kwargs):
181+
removed_warning(obj.__name__, sdk_version)
182+
return obj(*args, **kwargs)
183+
184+
return wrapper
185+
186+
return deprecate
187+
188+
189+
def deprecated_function(func, name):
190+
"""Wrap a function with a deprecation warning.
191+
192+
Args:
193+
func: Function to wrap in a deprecation warning.
194+
name: The name that has been deprecated.
195+
196+
Returns:
197+
The modified function
198+
"""
199+
200+
def deprecate(*args, **kwargs):
201+
renamed_warning(f"The {name}")
202+
return func(*args, **kwargs)
203+
204+
return deprecate
205+
206+
207+
def deprecated_serialize(instance, name):
208+
"""Modifies a serializer instance serialize method.
209+
210+
Args:
211+
instance: Instance to modify serialize method.
212+
name: The name that has been deprecated.
213+
214+
Returns:
215+
The modified instance
216+
"""
217+
instance.serialize = deprecated_function(instance.serialize, name)
218+
return instance
219+
220+
221+
def deprecated_deserialize(instance, name):
222+
"""Modifies a deserializer instance deserialize method.
223+
224+
Args:
225+
instance: Instance to modify deserialize method.
226+
name: The name that has been deprecated.
227+
228+
Returns:
229+
The modified instance
230+
"""
231+
instance.deserialize = deprecated_function(instance.deserialize, name)
232+
return instance
233+
234+
235+
def deprecated_class(cls, name):
236+
"""Returns a class based on super class with a deprecation warning.
237+
238+
Args:
239+
cls: The class to derive with a deprecation warning on __init__
240+
name: The name of the class.
241+
242+
Returns:
243+
The modified class.
244+
"""
245+
246+
class DeprecatedClass(cls):
247+
"""Provides a warning for the class name."""
248+
249+
def __init__(self, *args, **kwargs):
250+
"""Provides a warning for the class name."""
251+
renamed_warning(f"The class {name}")
252+
super(DeprecatedClass, self).__init__(*args, **kwargs)
253+
254+
return DeprecatedClass
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Defines enum values."""
14+
15+
from __future__ import absolute_import
16+
17+
import logging
18+
from enum import Enum
19+
20+
21+
LOGGER = logging.getLogger("sagemaker")
22+
23+
24+
class EndpointType(Enum):
25+
"""Types of endpoint"""
26+
27+
MODEL_BASED = "ModelBased" # Amazon SageMaker Model Based Endpoint
28+
INFERENCE_COMPONENT_BASED = (
29+
"InferenceComponentBased" # Amazon SageMaker Inference Component Based Endpoint
30+
)
31+
32+
33+
class RoutingStrategy(Enum):
34+
"""Strategy for routing https traffics."""
35+
36+
RANDOM = "RANDOM"
37+
"""The endpoint routes each request to a randomly chosen instance.
38+
"""
39+
LEAST_OUTSTANDING_REQUESTS = "LEAST_OUTSTANDING_REQUESTS"
40+
"""The endpoint routes requests to the specific instances that have
41+
more capacity to process them.
42+
"""
43+
44+
45+
class Tag(str, Enum):
46+
"""Enum class for tag keys to apply to models."""
47+
48+
OPTIMIZATION_JOB_NAME = "sagemaker-sdk:optimization-job-name"
49+
SPECULATIVE_DRAFT_MODEL_PROVIDER = "sagemaker-sdk:speculative-draft-model-provider"
50+
FINE_TUNING_MODEL_PATH = "sagemaker-sdk:fine-tuning-model-path"
51+
FINE_TUNING_JOB_NAME = "sagemaker-sdk:fine-tuning-job-name"

sagemaker_utils/src/sagemaker/utils/fw_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from packaging import version
2727

2828
import sagemaker.utils
29-
from sagemaker.utils.deprecations import deprecation_warn_base, renamed_kwargs, renamed_warning
29+
from sagemaker.utils.deprecations import deprecation_warn_base, renamed_kwargs
3030
from sagemaker.utils.instance_group import InstanceGroup
3131
from sagemaker.utils.s3_utils import s3_path_join
3232
from sagemaker.utils.session_settings import SessionSettings

0 commit comments

Comments
 (0)