Skip to content

Commit 2adc4bd

Browse files
committed
Update image_uri_config, fw_utils and image_uris.py in sagemaker-core
1 parent 4055fcf commit 2adc4bd

File tree

76 files changed

+724
-22978
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+724
-22978
lines changed

sagemaker-core/src/sagemaker/core/fw_utils.py

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,14 @@
2525

2626
from packaging import version
2727

28-
import sagemaker.core.common_utils as sagemaker_utils
29-
from sagemaker.core.deprecations import deprecation_warn_base, renamed_kwargs
28+
from sagemaker.core import image_uris
29+
import sagemaker.core.common_utils as utils
30+
from sagemaker.core.deprecations import deprecation_warn_base, renamed_kwargs, renamed_warning
3031
from sagemaker.core.instance_group import InstanceGroup
31-
from sagemaker.core.s3 import s3_path_join
32+
from sagemaker.core.s3.utils import s3_path_join
3233
from sagemaker.core.session_settings import SessionSettings
3334
from sagemaker.core.workflow import is_pipeline_variable
34-
from sagemaker.core.helper.pipeline_variable import PipelineVariable
35+
from sagemaker.core.workflow.entities import PipelineVariable
3536

3637
logger = logging.getLogger(__name__)
3738

@@ -155,6 +156,9 @@
155156
"2.3.1",
156157
"2.4.1",
157158
"2.5.1",
159+
"2.6.0",
160+
"2.7.1",
161+
"2.8.0",
158162
]
159163

160164
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
@@ -455,7 +459,7 @@ def tar_and_upload_dir(
455459

456460
try:
457461
source_files = _list_files_to_compress(script, directory) + dependencies
458-
tar_file = sagemaker_utils.create_tar_file(
462+
tar_file = utils.create_tar_file(
459463
source_files, os.path.join(tmp, _TAR_SOURCE_FILENAME)
460464
)
461465

@@ -516,7 +520,7 @@ def framework_name_from_image(image_uri):
516520
- str: The image tag
517521
- str: If the TensorFlow image is script mode
518522
"""
519-
sagemaker_pattern = re.compile(sagemaker_utils.ECR_URI_PATTERN)
523+
sagemaker_pattern = re.compile(utils.ECR_URI_PATTERN)
520524
sagemaker_match = sagemaker_pattern.match(image_uri)
521525
if sagemaker_match is None:
522526
return None, None, None, None
@@ -595,7 +599,7 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image):
595599
"""
596600
name_from_image = f"/model_code/{int(time.time())}"
597601
if not is_pipeline_variable(image):
598-
name_from_image = sagemaker_utils.name_from_image(image)
602+
name_from_image = utils.name_from_image(image)
599603
return s3_path_join(code_location_key_prefix, model_name or name_from_image)
600604

601605

@@ -961,7 +965,7 @@ def validate_distribution_for_instance_type(instance_type, distribution):
961965
"""
962966
err_msg = ""
963967
if isinstance(instance_type, str):
964-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
968+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
965969
if match and match[1].startswith("trn"):
966970
keys = list(distribution.keys())
967971
if len(keys) == 0:
@@ -1062,7 +1066,7 @@ def validate_torch_distributed_distribution(
10621066
)
10631067

10641068
# Check entry point type
1065-
if not entry_point.endswith(".py"):
1069+
if entry_point is not None and not entry_point.endswith(".py"):
10661070
err_msg += (
10671071
"Unsupported entry point type for the distribution torch_distributed.\n"
10681072
"Only python programs (*.py) are supported."
@@ -1082,7 +1086,7 @@ def _is_gpu_instance(instance_type):
10821086
bool: Whether or not the instance_type supports GPU
10831087
"""
10841088
if isinstance(instance_type, str):
1085-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1089+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
10861090
if match:
10871091
if match[1].startswith("p") or match[1].startswith("g"):
10881092
return True
@@ -1101,7 +1105,7 @@ def _is_trainium_instance(instance_type):
11011105
bool: Whether or not the instance_type is a Trainium instance
11021106
"""
11031107
if isinstance(instance_type, str):
1104-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1108+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
11051109
if match and match[1].startswith("trn"):
11061110
return True
11071111
return False
@@ -1148,7 +1152,7 @@ def _instance_type_supports_profiler(instance_type):
11481152
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
11491153
"""
11501154
if isinstance(instance_type, str):
1151-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1155+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
11521156
if match and match[1].startswith("trn"):
11531157
return True
11541158
return False
@@ -1174,3 +1178,42 @@ def validate_version_or_image_args(framework_version, py_version, image_uri):
11741178
"framework_version or py_version was None, yet image_uri was also None. "
11751179
"Either specify both framework_version and py_version, or specify image_uri."
11761180
)
1181+
1182+
1183+
def create_image_uri(
1184+
region,
1185+
framework,
1186+
instance_type,
1187+
framework_version,
1188+
py_version=None,
1189+
account=None, # pylint: disable=W0613
1190+
accelerator_type=None,
1191+
optimized_families=None, # pylint: disable=W0613
1192+
):
1193+
"""Deprecated method. Please use sagemaker.image_uris.retrieve().
1194+
1195+
Args:
1196+
region (str): AWS region where the image is uploaded.
1197+
framework (str): framework used by the image.
1198+
instance_type (str): SageMaker instance type. Used to determine device
1199+
type (cpu/gpu/family-specific optimized).
1200+
framework_version (str): The version of the framework.
1201+
py_version (str): Optional. Python version Ex: `py38, py39, py310, py311`.
1202+
If not specified, image uri will not include a python component.
1203+
account (str): AWS account that contains the image. (default:
1204+
'520713654638')
1205+
accelerator_type (str): SageMaker Elastic Inference accelerator type.
1206+
optimized_families (str): Deprecated. A no-op argument.
1207+
1208+
Returns:
1209+
the image uri
1210+
"""
1211+
renamed_warning("The method create_image_uri")
1212+
return image_uris.retrieve(
1213+
framework=framework,
1214+
region=region,
1215+
version=framework_version,
1216+
py_version=py_version,
1217+
instance_type=instance_type,
1218+
accelerator_type=accelerator_type,
1219+
)

sagemaker-core/src/sagemaker/core/image_uri_config/__init__.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

sagemaker-core/src/sagemaker/core/image_uri_config/huggingface-llm-neuronx.json

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
"inf2"
55
],
66
"version_aliases": {
7-
"0.0": "0.0.28"
7+
"0.0": "0.0.28",
8+
"0.2": "0.2.0",
9+
"0.3": "0.3.0"
810
},
911
"versions": {
1012
"0.0.16": {
@@ -654,6 +656,114 @@
654656
"container_version": {
655657
"inf2": "ubuntu22.04"
656658
}
659+
},
660+
"0.2.0": {
661+
"py_versions": [
662+
"py310"
663+
],
664+
"registries": {
665+
"af-south-1": "626614931356",
666+
"ap-east-1": "871362719292",
667+
"ap-east-2": "975050140332",
668+
"ap-northeast-1": "763104351884",
669+
"ap-northeast-2": "763104351884",
670+
"ap-northeast-3": "364406365360",
671+
"ap-south-1": "763104351884",
672+
"ap-south-2": "772153158452",
673+
"ap-southeast-1": "763104351884",
674+
"ap-southeast-2": "763104351884",
675+
"ap-southeast-3": "907027046896",
676+
"ap-southeast-4": "457447274322",
677+
"ap-southeast-5": "550225433462",
678+
"ap-southeast-6": "633930458069",
679+
"ap-southeast-7": "590183813437",
680+
"ca-central-1": "763104351884",
681+
"ca-west-1": "204538143572",
682+
"cn-north-1": "727897471807",
683+
"cn-northwest-1": "727897471807",
684+
"eu-central-1": "763104351884",
685+
"eu-central-2": "380420809688",
686+
"eu-north-1": "763104351884",
687+
"eu-south-1": "692866216735",
688+
"eu-south-2": "503227376785",
689+
"eu-west-1": "763104351884",
690+
"eu-west-2": "763104351884",
691+
"eu-west-3": "763104351884",
692+
"il-central-1": "780543022126",
693+
"me-central-1": "914824155844",
694+
"me-south-1": "217643126080",
695+
"mx-central-1": "637423239942",
696+
"sa-east-1": "763104351884",
697+
"us-east-1": "763104351884",
698+
"us-east-2": "763104351884",
699+
"us-gov-east-1": "446045086412",
700+
"us-gov-west-1": "442386744353",
701+
"us-iso-east-1": "886529160074",
702+
"us-isob-east-1": "094389454867",
703+
"us-isof-east-1": "303241398832",
704+
"us-isof-south-1": "454834333376",
705+
"us-west-1": "763104351884",
706+
"us-west-2": "763104351884"
707+
},
708+
"tag_prefix": "2.5.1-optimum3.3.4",
709+
"repository": "huggingface-pytorch-tgi-inference",
710+
"container_version": {
711+
"inf2": "ubuntu22.04"
712+
}
713+
},
714+
"0.3.0": {
715+
"py_versions": [
716+
"py310"
717+
],
718+
"registries": {
719+
"af-south-1": "626614931356",
720+
"ap-east-1": "871362719292",
721+
"ap-east-2": "975050140332",
722+
"ap-northeast-1": "763104351884",
723+
"ap-northeast-2": "763104351884",
724+
"ap-northeast-3": "364406365360",
725+
"ap-south-1": "763104351884",
726+
"ap-south-2": "772153158452",
727+
"ap-southeast-1": "763104351884",
728+
"ap-southeast-2": "763104351884",
729+
"ap-southeast-3": "907027046896",
730+
"ap-southeast-4": "457447274322",
731+
"ap-southeast-5": "550225433462",
732+
"ap-southeast-6": "633930458069",
733+
"ap-southeast-7": "590183813437",
734+
"ca-central-1": "763104351884",
735+
"ca-west-1": "204538143572",
736+
"cn-north-1": "727897471807",
737+
"cn-northwest-1": "727897471807",
738+
"eu-central-1": "763104351884",
739+
"eu-central-2": "380420809688",
740+
"eu-north-1": "763104351884",
741+
"eu-south-1": "692866216735",
742+
"eu-south-2": "503227376785",
743+
"eu-west-1": "763104351884",
744+
"eu-west-2": "763104351884",
745+
"eu-west-3": "763104351884",
746+
"il-central-1": "780543022126",
747+
"me-central-1": "914824155844",
748+
"me-south-1": "217643126080",
749+
"mx-central-1": "637423239942",
750+
"sa-east-1": "763104351884",
751+
"us-east-1": "763104351884",
752+
"us-east-2": "763104351884",
753+
"us-gov-east-1": "446045086412",
754+
"us-gov-west-1": "442386744353",
755+
"us-iso-east-1": "886529160074",
756+
"us-isob-east-1": "094389454867",
757+
"us-isof-east-1": "303241398832",
758+
"us-isof-south-1": "454834333376",
759+
"us-west-1": "763104351884",
760+
"us-west-2": "763104351884"
761+
},
762+
"tag_prefix": "2.7.0-optimum3.3.6",
763+
"repository": "huggingface-pytorch-tgi-inference",
764+
"container_version": {
765+
"inf2": "ubuntu22.04"
766+
}
657767
}
658768
}
659769
}

0 commit comments

Comments
 (0)