Skip to content

Commit 47df7c0

Browse files
committed
Update instance type regex to also include hyphens
For commit: aws/sagemaker-python-sdk-staging@824675b
1 parent 3b07b4a commit 47df7c0

File tree

4 files changed

+10
-5
lines changed

4 files changed

+10
-5
lines changed

sagemaker-core/src/sagemaker/core/common_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1555,7 +1555,7 @@ def get_instance_type_family(instance_type: str) -> str:
15551555
"""
15561556
instance_type_family = ""
15571557
if isinstance(instance_type, str):
1558-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1558+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
15591559
if match is not None:
15601560
instance_type_family = match[1]
15611561
return instance_type_family

sagemaker-core/src/sagemaker/core/fw_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
from packaging import version
2727

28-
from sagemaker.core import image_uris
2928
import sagemaker.core.common_utils as utils
3029
from sagemaker.core.deprecations import deprecation_warn_base, renamed_kwargs, renamed_warning
3130
from sagemaker.core.instance_group import InstanceGroup
@@ -1208,6 +1207,8 @@ def create_image_uri(
12081207
Returns:
12091208
the image uri
12101209
"""
1210+
from sagemaker.core import image_uris
1211+
12111212
renamed_warning("The method create_image_uri")
12121213
return image_uris.retrieve(
12131214
framework=framework,

sagemaker-core/tests/unit/test_fw_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def test_validate_mp_config_ddp_and_horovod(self):
225225
class TestTarAndUploadDir:
226226
"""Test tar_and_upload_dir function."""
227227

228-
@patch("sagemaker.core.fw_utils.sagemaker_utils.create_tar_file")
228+
@patch("sagemaker.core.common_utils.create_tar_file")
229229
def test_tar_and_upload_dir_s3_source(self, mock_create_tar):
230230
"""Test with S3 source directory."""
231231
mock_session = Mock()
@@ -242,7 +242,7 @@ def test_tar_and_upload_dir_s3_source(self, mock_create_tar):
242242
assert result.script_name == "train.py"
243243
mock_create_tar.assert_not_called()
244244

245-
@patch("sagemaker.core.fw_utils.sagemaker_utils.create_tar_file")
245+
@patch("sagemaker.core.common_utils.create_tar_file")
246246
@patch("sagemaker.core.fw_utils.tempfile.mkdtemp")
247247
@patch("sagemaker.core.fw_utils.shutil.rmtree")
248248
def test_tar_and_upload_dir_local_file(
@@ -495,6 +495,8 @@ def test_is_gpu_instance_true(self):
495495
assert _is_gpu_instance("ml.p3.2xlarge") is True
496496
assert _is_gpu_instance("ml.g4dn.xlarge") is True
497497
assert _is_gpu_instance("local_gpu") is True
498+
assert _is_gpu_instance("ml.p6-b200.48xlarge") is True
499+
assert _is_gpu_instance("ml.g6e-12xlarge.xlarge") is True
498500

499501
def test_is_gpu_instance_false(self):
500502
"""Test _is_gpu_instance with non-GPU instance."""
@@ -505,6 +507,7 @@ def test_is_trainium_instance_true(self):
505507
"""Test _is_trainium_instance with Trainium instance."""
506508
assert _is_trainium_instance("ml.trn1.2xlarge") is True
507509
assert _is_trainium_instance("ml.trn1.32xlarge") is True
510+
assert _is_trainium_instance("ml.trn1-n.2xlarge") is True
508511

509512
def test_is_trainium_instance_false(self):
510513
"""Test _is_trainium_instance with non-Trainium instance."""
@@ -523,6 +526,7 @@ def test_region_supports_profiler(self):
523526

524527
def test_instance_type_supports_profiler(self):
525528
"""Test _instance_type_supports_profiler."""
529+
assert _instance_type_supports_profiler("ml.trn1-n.xlarge") is True
526530
assert _instance_type_supports_profiler("ml.trn1.2xlarge") is True
527531
assert _instance_type_supports_profiler("ml.p3.2xlarge") is False
528532

sagemaker-serve/src/sagemaker/serve/model_builder_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1504,7 +1504,7 @@ def _is_inferentia_or_trainium(self, instance_type: Optional[str]) -> bool:
15041504
bool: Whether the given instance type is Inferentia or Trainium.
15051505
"""
15061506
if isinstance(instance_type, str):
1507-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1507+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
15081508
if match:
15091509
if match[1].startswith("inf") or match[1].startswith("trn"):
15101510
return True

0 commit comments

Comments
 (0)