From cc27f23e40a0ed9028a897f84b5425af311af3e3 Mon Sep 17 00:00:00 2001 From: jzhaoqwa <52220743+zhaoqizqwang@users.noreply.github.com> Date: Wed, 17 Dec 2025 13:52:05 -0800 Subject: [PATCH 1/8] Remove duplicate lineage folder in sagemaker-core --- .../src/sagemaker/core/lineage/action.py | 2 +- .../src/sagemaker/core/lineage/artifact.py | 2 +- .../src/sagemaker/core/lineage/association.py | 2 +- .../src/sagemaker/core/lineage/context.py | 2 +- .../src/sagemaker/core/lineage/query.py | 10 +++--- .../src/sagemaker/lineage/__init__.py | 33 ------------------- .../src/sagemaker/lineage/action.py | 28 ---------------- .../src/sagemaker/lineage/artifact.py | 28 ---------------- .../src/sagemaker/lineage/context.py | 28 ---------------- .../lineage/lineage_trial_component.py | 28 ---------------- .../tests/unit/lineage/test_query.py | 4 +-- 11 files changed, 11 insertions(+), 156 deletions(-) delete mode 100644 sagemaker-core/src/sagemaker/lineage/__init__.py delete mode 100644 sagemaker-core/src/sagemaker/lineage/action.py delete mode 100644 sagemaker-core/src/sagemaker/lineage/artifact.py delete mode 100644 sagemaker-core/src/sagemaker/lineage/context.py delete mode 100644 sagemaker-core/src/sagemaker/lineage/lineage_trial_component.py diff --git a/sagemaker-core/src/sagemaker/core/lineage/action.py b/sagemaker-core/src/sagemaker/core/lineage/action.py index d6b06196f5..0a0686074d 100644 --- a/sagemaker-core/src/sagemaker/core/lineage/action.py +++ b/sagemaker-core/src/sagemaker/core/lineage/action.py @@ -38,7 +38,7 @@ class Action(_base_types.Record): Examples: .. code-block:: python - from sagemaker.lineage import action + from sagemaker.core.lineage import action my_action = action.Action.create( action_name='MyAction', diff --git a/sagemaker-core/src/sagemaker/core/lineage/artifact.py b/sagemaker-core/src/sagemaker/core/lineage/artifact.py index bc9522069f..8af1420659 100644 --- a/sagemaker-core/src/sagemaker/core/lineage/artifact.py +++ b/sagemaker-core/src/sagemaker/core/lineage/artifact.py @@ -42,7 +42,7 @@ class Artifact(_base_types.Record): Examples: .. code-block:: python - from sagemaker.lineage import artifact + from sagemaker.core.lineage import artifact my_artifact = artifact.Artifact.create( artifact_name='MyArtifact', diff --git a/sagemaker-core/src/sagemaker/core/lineage/association.py b/sagemaker-core/src/sagemaker/core/lineage/association.py index f175622cd8..60e5c028ea 100644 --- a/sagemaker-core/src/sagemaker/core/lineage/association.py +++ b/sagemaker-core/src/sagemaker/core/lineage/association.py @@ -31,7 +31,7 @@ class Association(_base_types.Record): Examples: .. code-block:: python - from sagemaker.lineage import association + from sagemaker.core.lineage import association my_association = association.Association.create( source_arn=artifact_arn, diff --git a/sagemaker-core/src/sagemaker/core/lineage/context.py b/sagemaker-core/src/sagemaker/core/lineage/context.py index 7a086095fa..7017d77992 100644 --- a/sagemaker-core/src/sagemaker/core/lineage/context.py +++ b/sagemaker-core/src/sagemaker/core/lineage/context.py @@ -136,7 +136,7 @@ def load(cls, context_name: str, sagemaker_session=None) -> "Context": Examples: .. 
code-block:: python - from sagemaker.lineage import context + from sagemaker.core.lineage import context my_context = context.Context.create( context_name='MyContext', diff --git a/sagemaker-core/src/sagemaker/core/lineage/query.py b/sagemaker-core/src/sagemaker/core/lineage/query.py index a539cf621c..8a2d14a642 100644 --- a/sagemaker-core/src/sagemaker/core/lineage/query.py +++ b/sagemaker-core/src/sagemaker/core/lineage/query.py @@ -194,9 +194,9 @@ def to_lineage_object(self): A ``Vertex`` object to its corresponding ``Artifact``,``Action``, ``Context`` or ``TrialComponent`` object. """ - from sagemaker.lineage.context import Context, EndpointContext - from sagemaker.lineage.action import Action - from sagemaker.lineage.lineage_trial_component import LineageTrialComponent + from sagemaker.core.lineage.context import Context, EndpointContext + from sagemaker.core.lineage.action import Action + from sagemaker.core.lineage.lineage_trial_component import LineageTrialComponent if self.lineage_entity == LineageEntityEnum.CONTEXT.value: resource_name = get_resource_name_from_arn(self.arn) @@ -221,8 +221,8 @@ def to_lineage_object(self): def _artifact_to_lineage_object(self): """Convert the ``Vertex`` object to its corresponding ``Artifact``.""" - from sagemaker.lineage.artifact import Artifact, ModelArtifact, ImageArtifact - from sagemaker.lineage.artifact import DatasetArtifact + from sagemaker.core.lineage.artifact import Artifact, ModelArtifact, ImageArtifact + from sagemaker.core.lineage.artifact import DatasetArtifact if self.lineage_source == LineageSourceEnum.MODEL.value: return ModelArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session) diff --git a/sagemaker-core/src/sagemaker/lineage/__init__.py b/sagemaker-core/src/sagemaker/lineage/__init__.py deleted file mode 100644 index 4d9cec4b6c..0000000000 --- a/sagemaker-core/src/sagemaker/lineage/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Legacy lineage module - compatibility shim. - -This module provides backward compatibility for code using the old -`sagemaker.lineage` import path. All functionality has been moved to -`sagemaker.core.lineage`. - -DEPRECATED: This module is deprecated. Use `sagemaker.core.lineage` instead. -""" -from __future__ import absolute_import - -import warnings - -# Show deprecation warning -warnings.warn( - "The 'sagemaker.lineage' module is deprecated. " "Please use 'sagemaker.core.lineage' instead.", - DeprecationWarning, - stacklevel=2, -) - -# Re-export from core.lineage for backward compatibility -from sagemaker.core.lineage import * # noqa: F401, F403 diff --git a/sagemaker-core/src/sagemaker/lineage/action.py b/sagemaker-core/src/sagemaker/lineage/action.py deleted file mode 100644 index c14ffa2a69..0000000000 --- a/sagemaker-core/src/sagemaker/lineage/action.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). 
You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Legacy lineage.action module - compatibility shim. - -DEPRECATED: Use `sagemaker.core.lineage.action` instead. -""" -from __future__ import absolute_import - -import warnings - -warnings.warn( - "The 'sagemaker.lineage.action' module is deprecated. " - "Please use 'sagemaker.core.lineage.action' instead.", - DeprecationWarning, - stacklevel=2, -) - -from sagemaker.core.lineage.action import * # noqa: F401, F403 diff --git a/sagemaker-core/src/sagemaker/lineage/artifact.py b/sagemaker-core/src/sagemaker/lineage/artifact.py deleted file mode 100644 index 4d74205fc5..0000000000 --- a/sagemaker-core/src/sagemaker/lineage/artifact.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Legacy lineage.artifact module - compatibility shim. - -DEPRECATED: Use `sagemaker.core.lineage.artifact` instead. -""" -from __future__ import absolute_import - -import warnings - -warnings.warn( - "The 'sagemaker.lineage.artifact' module is deprecated. " - "Please use 'sagemaker.core.lineage.artifact' instead.", - DeprecationWarning, - stacklevel=2, -) - -from sagemaker.core.lineage.artifact import * # noqa: F401, F403 diff --git a/sagemaker-core/src/sagemaker/lineage/context.py b/sagemaker-core/src/sagemaker/lineage/context.py deleted file mode 100644 index d5fe8b3884..0000000000 --- a/sagemaker-core/src/sagemaker/lineage/context.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Legacy lineage.context module - compatibility shim. - -DEPRECATED: Use `sagemaker.core.lineage.context` instead. -""" -from __future__ import absolute_import - -import warnings - -warnings.warn( - "The 'sagemaker.lineage.context' module is deprecated. 
" - "Please use 'sagemaker.core.lineage.context' instead.", - DeprecationWarning, - stacklevel=2, -) - -from sagemaker.core.lineage.context import * # noqa: F401, F403 diff --git a/sagemaker-core/src/sagemaker/lineage/lineage_trial_component.py b/sagemaker-core/src/sagemaker/lineage/lineage_trial_component.py deleted file mode 100644 index b729166f2c..0000000000 --- a/sagemaker-core/src/sagemaker/lineage/lineage_trial_component.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Legacy lineage.lineage_trial_component module - compatibility shim. - -DEPRECATED: Use `sagemaker.core.lineage.lineage_trial_component` instead. -""" -from __future__ import absolute_import - -import warnings - -warnings.warn( - "The 'sagemaker.lineage.lineage_trial_component' module is deprecated. " - "Please use 'sagemaker.core.lineage.lineage_trial_component' instead.", - DeprecationWarning, - stacklevel=2, -) - -from sagemaker.core.lineage.lineage_trial_component import * # noqa: F401, F403 diff --git a/sagemaker-core/tests/unit/lineage/test_query.py b/sagemaker-core/tests/unit/lineage/test_query.py index 2dbeea28d4..e6fdfc5eb9 100644 --- a/sagemaker-core/tests/unit/lineage/test_query.py +++ b/sagemaker-core/tests/unit/lineage/test_query.py @@ -146,7 +146,7 @@ def test_vertex_hash(self): vertex_set = {vertex1, vertex2} assert len(vertex_set) == 1 - @patch("sagemaker.lineage.context.EndpointContext") + @patch("sagemaker.core.lineage.context.EndpointContext") def test_to_lineage_object_context(self, mock_endpoint_context_class): """Test converting vertex to Context""" mock_session = Mock() @@ -164,7 +164,7 @@ def test_to_lineage_object_context(self, mock_endpoint_context_class): # Should call EndpointContext.load for Endpoint source assert result is not None - @patch("sagemaker.lineage.action.Action") + @patch("sagemaker.core.lineage.action.Action") def test_to_lineage_object_action(self, mock_action_class): """Test converting vertex to Action""" mock_session = Mock() From 21a9f846c84ae5b360133898e226c837dfd84fa3 Mon Sep 17 00:00:00 2001 From: jzhaoqwa <52220743+zhaoqizqwang@users.noreply.github.com> Date: Wed, 17 Dec 2025 13:55:36 -0800 Subject: [PATCH 2/8] Remove remote_function in sagemaker.train --- .../train/remote_function/__init__.py | 34 - .../remote_function/checkpoint_location.py | 47 -- .../sagemaker/train/remote_function/client.py | 30 - .../train/remote_function/core/__init__.py | 27 - .../core/_custom_dispatch_table.py | 56 -- .../core/pipeline_variables.py | 30 - .../remote_function/core/serialization.py | 30 - .../remote_function/core/stored_function.py | 30 - .../remote_function/custom_file_filter.py | 128 ---- .../sagemaker/train/remote_function/errors.py | 30 - .../train/remote_function/invoke_function.py | 172 ----- .../sagemaker/train/remote_function/job.py | 30 - .../train/remote_function/logging_config.py | 38 -- .../runtime_environment/__init__.py | 14 - .../bootstrap_runtime_environment.py | 602 ------------------ 
.../runtime_environment/mpi_utils_remote.py | 252 -------- .../runtime_environment_manager.py | 529 --------------- .../runtime_environment/spark_app.py | 18 - .../train/remote_function/spark_config.py | 30 - 19 files changed, 2127 deletions(-) delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/__init__.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/checkpoint_location.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/client.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/core/__init__.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/core/_custom_dispatch_table.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/core/pipeline_variables.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/core/serialization.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/core/stored_function.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/custom_file_filter.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/errors.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/invoke_function.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/job.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/logging_config.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/__init__.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/spark_app.py delete mode 100644 sagemaker-train/src/sagemaker/train/remote_function/spark_config.py diff --git a/sagemaker-train/src/sagemaker/train/remote_function/__init__.py b/sagemaker-train/src/sagemaker/train/remote_function/__init__.py deleted file mode 100644 index bf29079921..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function - -This is a backward compatibility shim. Please update your imports to: - from sagemaker.core.remote_function import ... 
-""" -from __future__ import absolute_import - -import warnings - -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.client import remote, RemoteExecutor # noqa: F401 -from sagemaker.core.remote_function.checkpoint_location import CheckpointLocation # noqa: F401 -from sagemaker.core.remote_function.custom_file_filter import CustomFileFilter # noqa: F401 -from sagemaker.core.remote_function.spark_config import SparkConfig # noqa: F401 - -warnings.warn( - "sagemaker.train.remote_function has been moved to sagemaker.core.remote_function. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/checkpoint_location.py b/sagemaker-train/src/sagemaker/train/remote_function/checkpoint_location.py deleted file mode 100644 index 4153fe03d3..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/checkpoint_location.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""This module is used to define the CheckpointLocation to remote function.""" -from __future__ import absolute_import - -from os import PathLike -import re - -# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CheckpointConfig.html -S3_URI_REGEX_PATTERN = r"^(https|s3)://([^/]+)/?(.*)$" - -_JOB_CHECKPOINT_LOCATION = "/opt/ml/checkpoints/" - - -def _validate_s3_uri_for_checkpoint(s3_uri: str): - """Validate if checkpoint location is specified with a valid s3 URI.""" - return re.match(S3_URI_REGEX_PATTERN, s3_uri) - - -class CheckpointLocation(PathLike): - """Class to represent the location where checkpoints are accessed in a remote function. - - To save or load checkpoints in a remote function, pass an CheckpointLocation object as a - function parameter and use it as a os.PathLike object. This CheckpointLocation object - represents the local directory (/opt/ml/checkpoints/) of checkpoints in side the job. - """ - - _local_path = _JOB_CHECKPOINT_LOCATION - - def __init__(self, s3_uri): - if not _validate_s3_uri_for_checkpoint(s3_uri): - raise ValueError("CheckpointLocation should be specified with valid s3 URI.") - self._s3_uri = s3_uri - - def __fspath__(self): - """Return job local path where checkpoints are stored.""" - return self._local_path diff --git a/sagemaker-train/src/sagemaker/train/remote_function/client.py b/sagemaker-train/src/sagemaker/train/remote_function/client.py deleted file mode 100644 index eb99d14c1e..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/client.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. 
This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.client - -This is a backward compatibility shim. -""" -from __future__ import absolute_import - -import warnings - -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.client import * # noqa: F401, F403 - -warnings.warn( - "sagemaker.train.remote_function.client has been moved to sagemaker.core.remote_function.client. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/__init__.py b/sagemaker-train/src/sagemaker/train/remote_function/core/__init__.py deleted file mode 100644 index 7e9f2d30da..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.core - -This is a backward compatibility shim. -""" -from __future__ import absolute_import - -import warnings - -warnings.warn( - "sagemaker.train.remote_function.core has been moved to sagemaker.core.remote_function.core. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/_custom_dispatch_table.py b/sagemaker-train/src/sagemaker/train/remote_function/core/_custom_dispatch_table.py deleted file mode 100644 index 20b7a297b5..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/_custom_dispatch_table.py +++ /dev/null @@ -1,56 +0,0 @@ - -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
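The module deleted in this hunk wires pipeline-variable classes into a pickle dispatch table whose reducers refuse to serialize them, turning an accidental capture of a pipeline variable into an immediate SerializationError. For context, a minimal self-contained sketch of that pattern, using stand-in classes rather than the real SageMaker types:

    import io
    import pickle


    class SerializationError(Exception):
        """Raised when an object must not cross the serialization boundary."""


    class PipelineVariable:
        """Stand-in for a SageMaker pipeline variable class."""


    def _pipeline_variable_reducer(obj):
        # A normal reducer returns a (callable, args) tuple; raising here
        # instead aborts pickling with a clear, actionable message.
        raise SerializationError(
            "Pass pipeline variables to the @step-decorated function as "
            "arguments; referencing them from inside the function body is "
            "not supported."
        )


    dispatch_table = {PipelineVariable: _pipeline_variable_reducer}


    def dumps_with_table(obj) -> bytes:
        buffer = io.BytesIO()
        pickler = pickle.Pickler(buffer)
        # The per-pickler dispatch_table is consulted before pickle's
        # default reducers, so the error fires at serialization time.
        pickler.dispatch_table = dispatch_table
        pickler.dump(obj)
        return buffer.getvalue()


    dumps_with_table([1, 2, 3])            # ordinary objects pickle normally
    try:
        dumps_with_table(PipelineVariable())
    except SerializationError as err:
        print(err)                          # fails fast, on the client side

Because the lookup happens before pickle's default logic, the failure surfaces in the user's process rather than as an opaque error inside the remote job.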
-"""SageMaker remote function data serializer/deserializer.""" -from __future__ import absolute_import - -from sagemaker.train.remote_function.errors import SerializationError - -from sagemaker.core.helper.pipeline_variable import PipelineVariable -from sagemaker.core.workflow.parameters import ( - ParameterInteger, - ParameterFloat, - ParameterString, - ParameterBoolean, -) -from sagemaker.core.workflow.execution_variables import ExecutionVariable -from sagemaker.mlops.workflow.function_step import DelayedReturn -from sagemaker.core.workflow.properties import ( - Properties, - PropertiesMap, - PropertiesList, -) - - -def _pipeline_variable_reducer(pipeline_variable): - """Reducer for pipeline variable.""" - - raise SerializationError( - """Please pass the pipeline variable to the function decorated with @step as an argument. - Referencing to a pipeline variable from within the function - or passing a pipeline variable nested in a data structure are not supported.""" - ) - - -dispatch_table = { - ParameterInteger: _pipeline_variable_reducer, - ParameterFloat: _pipeline_variable_reducer, - ParameterString: _pipeline_variable_reducer, - ParameterBoolean: _pipeline_variable_reducer, - ExecutionVariable: _pipeline_variable_reducer, - PipelineVariable: _pipeline_variable_reducer, - Properties: _pipeline_variable_reducer, - PropertiesMap: _pipeline_variable_reducer, - PropertiesList: _pipeline_variable_reducer, - DelayedReturn: _pipeline_variable_reducer, -} diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/pipeline_variables.py b/sagemaker-train/src/sagemaker/train/remote_function/core/pipeline_variables.py deleted file mode 100644 index 5767a07596..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/pipeline_variables.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.core.pipeline_variables - -This is a backward compatibility shim. -""" -from __future__ import absolute_import - -import warnings - -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.core.pipeline_variables import * # noqa: F401, F403 - -warnings.warn( - "sagemaker.train.remote_function.core.pipeline_variables has been moved to sagemaker.core.remote_function.core.pipeline_variables. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/serialization.py b/sagemaker-train/src/sagemaker/train/remote_function/core/serialization.py deleted file mode 100644 index d30d1494d5..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/serialization.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. 
A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.core.serialization - -This is a backward compatibility shim. -""" -from __future__ import absolute_import - -import warnings - -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.core.serialization import * # noqa: F401, F403 - -warnings.warn( - "sagemaker.train.remote_function.core.serialization has been moved to sagemaker.core.remote_function.core.serialization. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/stored_function.py b/sagemaker-train/src/sagemaker/train/remote_function/core/stored_function.py deleted file mode 100644 index 34915a4d42..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/stored_function.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.core.stored_function - -This is a backward compatibility shim. -""" -from __future__ import absolute_import - -import warnings - -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.core.stored_function import * # noqa: F401, F403 - -warnings.warn( - "sagemaker.train.remote_function.core.stored_function has been moved to sagemaker.core.remote_function.core.stored_function. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/custom_file_filter.py b/sagemaker-train/src/sagemaker/train/remote_function/custom_file_filter.py deleted file mode 100644 index 9c1b1e1baa..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/custom_file_filter.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
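The custom_file_filter module deleted below packages the local working directory with shutil.copytree, using an ignore callback that matches names against glob-style patterns via fnmatch. A minimal runnable sketch of that callback pattern (temporary directories stand in for a real workspace):

    import fnmatch
    import os
    import pathlib
    import shutil
    import tempfile
    from typing import List, Set


    def make_ignore(patterns: List[str]):
        """Build an ignore callback for shutil.copytree from glob patterns."""

        def _ignore(path: str, names: List[str]) -> Set[str]:
            # copytree calls this once per visited directory; whatever
            # names we return are skipped.
            ignored: Set[str] = set()
            for pattern in patterns:
                ignored.update(fnmatch.filter(names, pattern))
            return ignored

        return _ignore


    src = tempfile.mkdtemp()
    pathlib.Path(src, "train.py").touch()
    pathlib.Path(src, "scratch.ipynb").touch()
    dst = os.path.join(tempfile.mkdtemp(), "workspace")
    shutil.copytree(src, dst, ignore=make_ignore(["*.ipynb", "__pycache__"]))
    assert sorted(os.listdir(dst)) == ["train.py"]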
-"""SageMaker remote function client.""" -from __future__ import absolute_import - -import fnmatch -import os -import shutil -from typing import List, Optional, Callable, Union - -from sagemaker.core.common_utils import resolve_value_from_config -from sagemaker.core.config.config_schema import REMOTE_FUNCTION_PATH, CUSTOM_FILE_FILTER - - -class CustomFileFilter: - """Configuration that specifies how the local working directory should be packaged.""" - - def __init__(self, *, ignore_name_patterns: List[str] = None): - """Initialize a CustomFileFilter. - - Args: - ignore_name_patterns (List[str]): ignore files or directories with names - that match one of the glob-style patterns. Defaults to None. - """ - - if ignore_name_patterns is None: - ignore_name_patterns = [] - - self._workdir = os.getcwd() - self._ignore_name_patterns = ignore_name_patterns - - @property - def ignore_name_patterns(self): - """Get the ignore name patterns.""" - return self._ignore_name_patterns - - @property - def workdir(self): - """Get the working directory.""" - return self._workdir - - -def resolve_custom_file_filter_from_config_file( - direct_input: Union[Callable[[str, List], List], CustomFileFilter] = None, - sagemaker_session=None, -) -> Union[Callable[[str, List], List], CustomFileFilter, None]: - """Resolve the CustomFileFilter configuration from the config file. - - Args: - direct_input (Callable[[str, List], List], CustomFileFilter): direct input from the user. - sagemaker_session (sagemaker.core.helper.session.Session): sagemaker session. - Returns: - CustomFileFilter: configuration that specifies how the local - working directory should be packaged. - """ - if direct_input is not None: - return direct_input - ignore_name_patterns = resolve_value_from_config( - direct_input=None, - config_path=".".join([REMOTE_FUNCTION_PATH, CUSTOM_FILE_FILTER, "IgnoreNamePatterns"]), - default_value=None, - sagemaker_session=sagemaker_session, - ) - if ignore_name_patterns is not None: - return CustomFileFilter(ignore_name_patterns=ignore_name_patterns) - return None - - -def copy_workdir( - dst: str, - custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None, -): - """Copy the local working directory to the destination. - - Args: - dst (str): destination path. - custom_file_filter (Union[Callable[[str, List], List], CustomFileFilter): configuration that - specifies how the local working directory should be packaged. 
- """ - - def _ignore_patterns(path: str, names: List): # pylint: disable=unused-argument - ignored_names = set() - if custom_file_filter.ignore_name_patterns is not None: - for pattern in custom_file_filter.ignore_name_patterns: - ignored_names.update(fnmatch.filter(names, pattern)) - return ignored_names - - def _filter_non_python_files(path: str, names: List) -> List: - """Ignore function for filtering out non python files.""" - to_ignore = [] - for name in names: - full_path = os.path.join(path, name) - if os.path.isfile(full_path): - if not name.endswith(".py"): - to_ignore.append(name) - elif os.path.isdir(full_path): - if name == "__pycache__": - to_ignore.append(name) - else: - to_ignore.append(name) - - return to_ignore - - _ignore = None - _src = os.getcwd() - if not custom_file_filter: - _ignore = _filter_non_python_files - elif callable(custom_file_filter): - _ignore = custom_file_filter - elif isinstance(custom_file_filter, CustomFileFilter): - _ignore = _ignore_patterns - _src = custom_file_filter.workdir - - shutil.copytree( - _src, - dst, - ignore=_ignore, - ) \ No newline at end of file diff --git a/sagemaker-train/src/sagemaker/train/remote_function/errors.py b/sagemaker-train/src/sagemaker/train/remote_function/errors.py deleted file mode 100644 index e67fcf7d9f..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/errors.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.errors - -This is a backward compatibility shim. -""" -from __future__ import absolute_import - -import warnings - -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.errors import * # noqa: F401, F403 - -warnings.warn( - "sagemaker.train.remote_function.errors has been moved to sagemaker.core.remote_function.errors. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/invoke_function.py b/sagemaker-train/src/sagemaker/train/remote_function/invoke_function.py deleted file mode 100644 index 3bafeffd5b..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/invoke_function.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-"""An entry point for invoking remote function inside a job.""" - -from __future__ import absolute_import - -import argparse -import sys -import json -import os -from typing import TYPE_CHECKING - -import boto3 -from sagemaker.train.remote_function.job import ( - KEY_EXPERIMENT_NAME, - KEY_RUN_NAME, -) - -from sagemaker.core.helper.session_helper import Session -from sagemaker.core.s3 import s3_path_join -from sagemaker.train.remote_function.errors import handle_error -from sagemaker.train.remote_function import logging_config -from sagemaker.train.remote_function.core.pipeline_variables import Context - -if TYPE_CHECKING: - from sagemaker.core.experiments.run import Run - - -SUCCESS_EXIT_CODE = 0 - - -def _parse_args(args): - """Parses CLI arguments.""" - parser = argparse.ArgumentParser() - parser.add_argument("--region", type=str, required=True) - parser.add_argument("--s3_base_uri", type=str, required=True) - parser.add_argument("--s3_kms_key", type=str) - parser.add_argument("--run_in_context", type=str) - parser.add_argument("--pipeline_step_name", type=str) - parser.add_argument("--pipeline_execution_id", type=str) - parser.add_argument("--property_references", nargs="+", type=str, default=[]) - parser.add_argument( - "--serialize_output_to_json", default=False, type=lambda x: (str(x).lower() == "true") - ) - parser.add_argument("--func_step_s3_dir", type=str) - - args, _ = parser.parse_known_args(args) - return args - - -def _get_sagemaker_session(region): - """Get sagemaker session for interacting with AWS or Sagemaker services""" - boto_session = boto3.session.Session(region_name=region) - return Session(boto_session=boto_session) - - -def _load_run_object(run_in_context: str, sagemaker_session: Session) -> "Run": - """Load current run in json string into run object""" - from sagemaker.core.experiments.run import Run - - run_dict = json.loads(run_in_context) - return Run( - experiment_name=run_dict.get(KEY_EXPERIMENT_NAME), - run_name=run_dict.get(KEY_RUN_NAME), - sagemaker_session=sagemaker_session, - ) - - -def _load_pipeline_context(args) -> Context: - """Load pipeline build or run context into context object""" - - pipeline_step_name = args.pipeline_step_name - pipeline_execution_id = args.pipeline_execution_id - property_references = args.property_references - serialize_output_to_json = args.serialize_output_to_json - func_step_s3_dir = args.func_step_s3_dir - - property_references_dict = {} - for i in range(0, len(property_references), 2): - property_references_dict[property_references[i]] = property_references[i + 1] - return Context( - step_name=pipeline_step_name, - execution_id=pipeline_execution_id, - property_references=property_references_dict, - serialize_output_to_json=serialize_output_to_json, - func_step_s3_dir=func_step_s3_dir, - ) - - -def _execute_remote_function( - sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, hmac_key, context -): - """Execute stored remote function""" - from sagemaker.train.remote_function.core.stored_function import StoredFunction - - stored_function = StoredFunction( - sagemaker_session=sagemaker_session, - s3_base_uri=s3_base_uri, - s3_kms_key=s3_kms_key, - hmac_key=hmac_key, - context=context, - ) - - if run_in_context: - run_obj = _load_run_object(run_in_context, sagemaker_session) - with run_obj: - stored_function.load_and_invoke() - else: - stored_function.load_and_invoke() - - -def main(sys_args=None): - """Entry point for invoke function script - - Args: - sys_args (list): List of arguments to parse. 
If not specified, sys.argv is used. - """ - - logger = logging_config.get_logger() - - exit_code = SUCCESS_EXIT_CODE - - try: - args = _parse_args(sys_args) - region = args.region - s3_base_uri = args.s3_base_uri - s3_kms_key = args.s3_kms_key - run_in_context = args.run_in_context - pipeline_context = _load_pipeline_context(args) - - hmac_key = os.getenv("REMOTE_FUNCTION_SECRET_KEY") - - sagemaker_session = _get_sagemaker_session(region) - _execute_remote_function( - sagemaker_session=sagemaker_session, - s3_base_uri=s3_base_uri, - s3_kms_key=s3_kms_key, - run_in_context=run_in_context, - hmac_key=hmac_key, - context=pipeline_context, - ) - - except Exception as e: # pylint: disable=broad-except - logger.exception("Error encountered while invoking the remote function.") - s3_uri = ( - s3_path_join(s3_base_uri, pipeline_context.execution_id, pipeline_context.step_name) - if pipeline_context.step_name - else s3_base_uri - ) - exit_code = handle_error( - error=e, - sagemaker_session=sagemaker_session, - s3_base_uri=s3_uri, - s3_kms_key=s3_kms_key, - hmac_key=hmac_key, - ) - finally: - sys.exit(exit_code) - - -if __name__ == "__main__": - main(sys.argv[1:]) \ No newline at end of file diff --git a/sagemaker-train/src/sagemaker/train/remote_function/job.py b/sagemaker-train/src/sagemaker/train/remote_function/job.py deleted file mode 100644 index 33bf62af86..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/job.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.job - -This is a backward compatibility shim. -""" -from __future__ import absolute_import - -import warnings - -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.job import * # noqa: F401, F403 - -warnings.warn( - "sagemaker.train.remote_function.job has been moved to sagemaker.core.remote_function.job. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/logging_config.py b/sagemaker-train/src/sagemaker/train/remote_function/logging_config.py deleted file mode 100644 index 875fabf6e0..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/logging_config.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
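The logging_config module deleted below forces UTC timestamps by overriding logging.Formatter.converter, which formatTime calls in place of the default time.localtime, and disables propagation so records are not duplicated by the root logger's handler. A standalone demonstration of the same setup:

    import logging
    import time


    class UTCFormatter(logging.Formatter):
        # formatTime calls self.converter(record.created); overriding the
        # class attribute swaps localtime for gmtime.
        converter = time.gmtime


    logger = logging.getLogger("demo")
    handler = logging.StreamHandler()
    handler.setFormatter(
        UTCFormatter("%(asctime)s %(name)s %(levelname)-8s %(message)s")
    )
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # keep records off the root handler, as the module does
    logger.info("timestamps are now UTC")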
-"""Utilities related to logging.""" -from __future__ import absolute_import - -import logging -import time - - -class _UTCFormatter(logging.Formatter): - """Class that overrides the default local time provider in log formatter.""" - - converter = time.gmtime - - -def get_logger(): - """Return a logger with the name 'sagemaker'""" - sagemaker_logger = logging.getLogger("sagemaker.remote_function") - if len(sagemaker_logger.handlers) == 0: - sagemaker_logger.setLevel(logging.INFO) - handler = logging.StreamHandler() - formatter = _UTCFormatter("%(asctime)s %(name)s %(levelname)-8s %(message)s") - handler.setFormatter(formatter) - sagemaker_logger.addHandler(handler) - # don't stream logs with the root logger handler - sagemaker_logger.propagate = 0 - - return sagemaker_logger diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/__init__.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/__init__.py deleted file mode 100644 index 18557a2eb5..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Sagemaker modules container_drivers directory.""" -from __future__ import absolute_import diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py deleted file mode 100644 index afe0f80012..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ /dev/null @@ -1,602 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""An entry point for runtime environment. 
This must be kept independent of SageMaker PySDK""" -from __future__ import absolute_import - -import argparse -import getpass -import json -import multiprocessing -import os -import pathlib -import shutil -import subprocess -import sys -from typing import Any, Dict - -if __package__ is None or __package__ == "": - from runtime_environment_manager import ( - RuntimeEnvironmentManager, - _DependencySettings, - get_logger, - ) -else: - from sagemaker.train.remote_function.runtime_environment.runtime_environment_manager import ( - RuntimeEnvironmentManager, - _DependencySettings, - get_logger, - ) - -SUCCESS_EXIT_CODE = 0 -DEFAULT_FAILURE_CODE = 1 - -REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws" -BASE_CHANNEL_PATH = "/opt/ml/input/data" -FAILURE_REASON_PATH = "/opt/ml/output/failure" -JOB_OUTPUT_DIRS = ["/opt/ml/input", "/opt/ml/output", "/opt/ml/model", "/tmp"] -PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh" -JOB_REMOTE_FUNCTION_WORKSPACE = "sagemaker_remote_function_workspace" -SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME = "pre_exec_script_and_dependencies" - -SM_MODEL_DIR = "/opt/ml/model" - -SM_INPUT_DIR = "/opt/ml/input" -SM_INPUT_DATA_DIR = "/opt/ml/input/data" -SM_INPUT_CONFIG_DIR = "/opt/ml/input/config" - -SM_OUTPUT_DIR = "/opt/ml/output" -SM_OUTPUT_FAILURE = "/opt/ml/output/failure" -SM_OUTPUT_DATA_DIR = "/opt/ml/output/data" - -SM_MASTER_ADDR = "algo-1" -SM_MASTER_PORT = 7777 - -RESOURCE_CONFIG = f"{SM_INPUT_CONFIG_DIR}/resourceconfig.json" -ENV_OUTPUT_FILE = "/opt/ml/input/sm_training.env" - -SENSITIVE_KEYWORDS = ["SECRET", "PASSWORD", "KEY", "TOKEN", "PRIVATE", "CREDS", "CREDENTIALS"] -HIDDEN_VALUE = "******" - -SM_EFA_NCCL_INSTANCES = [ - "ml.g4dn.8xlarge", - "ml.g4dn.12xlarge", - "ml.g5.48xlarge", - "ml.p3dn.24xlarge", - "ml.p4d.24xlarge", - "ml.p4de.24xlarge", - "ml.p5.48xlarge", - "ml.trn1.32xlarge", -] - -SM_EFA_RDMA_INSTANCES = [ - "ml.p4d.24xlarge", - "ml.p4de.24xlarge", - "ml.trn1.32xlarge", -] - -logger = get_logger() - - -def _bootstrap_runtime_env_for_remote_function( - client_python_version: str, - conda_env: str = None, - dependency_settings: _DependencySettings = None, -): - """Bootstrap runtime environment for remote function invocation. - - Args: - client_python_version (str): Python version at the client side. - conda_env (str): conda environment to be activated. Default is None. - dependency_settings (dict): Settings for installing dependencies. - """ - - workspace_unpack_dir = _unpack_user_workspace() - if not workspace_unpack_dir: - logger.info("No workspace to unpack and setup.") - return - - _handle_pre_exec_scripts(workspace_unpack_dir) - - _install_dependencies( - workspace_unpack_dir, - conda_env, - client_python_version, - REMOTE_FUNCTION_WORKSPACE, - dependency_settings, - ) - - -def _bootstrap_runtime_env_for_pipeline_step( - client_python_version: str, - func_step_workspace: str, - conda_env: str = None, - dependency_settings: _DependencySettings = None, -): - """Bootstrap runtime environment for pipeline step invocation. - - Args: - client_python_version (str): Python version at the client side. - func_step_workspace (str): s3 folder where workspace for FunctionStep is stored - conda_env (str): conda environment to be activated. Default is None. - dependency_settings (dict): Name of the dependency file. Default is None. 
- """ - - workspace_dir = _unpack_user_workspace(func_step_workspace) - if not workspace_dir: - os.mkdir(JOB_REMOTE_FUNCTION_WORKSPACE) - workspace_dir = pathlib.Path(os.getcwd(), JOB_REMOTE_FUNCTION_WORKSPACE).absolute() - - pre_exec_script_and_dependencies_dir = os.path.join( - BASE_CHANNEL_PATH, SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME - ) - - if not os.path.exists(pre_exec_script_and_dependencies_dir): - logger.info("No dependencies to bootstrap") - return - for file in os.listdir(pre_exec_script_and_dependencies_dir): - src_path = os.path.join(pre_exec_script_and_dependencies_dir, file) - dest_path = os.path.join(workspace_dir, file) - shutil.copy(src_path, dest_path) - - _handle_pre_exec_scripts(workspace_dir) - - _install_dependencies( - workspace_dir, - conda_env, - client_python_version, - SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME, - dependency_settings, - ) - - -def _handle_pre_exec_scripts(script_file_dir: str): - """Run the pre execution scripts. - - Args: - script_file_dir (str): Directory in the container where pre-execution scripts exists. - """ - - path_to_pre_exec_script = os.path.join(script_file_dir, PRE_EXECUTION_SCRIPT_NAME) - RuntimeEnvironmentManager().run_pre_exec_script(pre_exec_script_path=path_to_pre_exec_script) - - -def _install_dependencies( - dependency_file_dir: str, - conda_env: str, - client_python_version: str, - channel_name: str, - dependency_settings: _DependencySettings = None, -): - """Install dependencies in the job container - - Args: - dependency_file_dir (str): Directory in the container where dependency file exists. - conda_env (str): conda environment to be activated. - client_python_version (str): Python version at the client side. - channel_name (str): Channel where dependency file was uploaded. - dependency_settings (dict): Settings for installing dependencies. - """ - - if dependency_settings is not None and dependency_settings.dependency_file is None: - # an empty dict is passed when no dependencies are specified - logger.info("No dependencies to install.") - elif dependency_settings is not None: - dependencies_file = os.path.join(dependency_file_dir, dependency_settings.dependency_file) - RuntimeEnvironmentManager().bootstrap( - local_dependencies_file=dependencies_file, - conda_env=conda_env, - client_python_version=client_python_version, - ) - else: - # no dependency file name is passed when an legacy version of the SDK is used - # we look for a file with .txt, .yml or .yaml extension in the workspace directory - dependencies_file = None - for file in os.listdir(dependency_file_dir): - if file.endswith(".txt") or file.endswith(".yml") or file.endswith(".yaml"): - dependencies_file = os.path.join(dependency_file_dir, file) - break - - if dependencies_file: - RuntimeEnvironmentManager().bootstrap( - local_dependencies_file=dependencies_file, - conda_env=conda_env, - client_python_version=client_python_version, - ) - else: - logger.info( - "Did not find any dependency file in the directory at '%s'." 
- " Assuming no additional dependencies to install.", - os.path.join(BASE_CHANNEL_PATH, channel_name), - ) - - -def _unpack_user_workspace(func_step_workspace: str = None): - """Unzip the user workspace""" - - workspace_archive_dir_path = ( - os.path.join(BASE_CHANNEL_PATH, REMOTE_FUNCTION_WORKSPACE) - if not func_step_workspace - else os.path.join(BASE_CHANNEL_PATH, func_step_workspace) - ) - if not os.path.exists(workspace_archive_dir_path): - logger.info( - "Directory '%s' does not exist.", - workspace_archive_dir_path, - ) - return None - - workspace_archive_path = os.path.join(workspace_archive_dir_path, "workspace.zip") - if not os.path.isfile(workspace_archive_path): - logger.info( - "Workspace archive '%s' does not exist.", - workspace_archive_dir_path, - ) - return None - - workspace_unpack_dir = pathlib.Path(os.getcwd()).absolute() - shutil.unpack_archive(filename=workspace_archive_path, extract_dir=workspace_unpack_dir) - logger.info("Successfully unpacked workspace archive at '%s'.", workspace_unpack_dir) - workspace_unpack_dir = pathlib.Path(workspace_unpack_dir, JOB_REMOTE_FUNCTION_WORKSPACE) - return workspace_unpack_dir - - -def _write_failure_reason_file(failure_msg): - """Create a file 'failure' with failure reason written if bootstrap runtime env failed. - - See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html - Args: - failure_msg: The content of file to be written. - """ - if not os.path.exists(FAILURE_REASON_PATH): - with open(FAILURE_REASON_PATH, "w") as f: - f.write("RuntimeEnvironmentError: " + failure_msg) - - -def _parse_args(sys_args): - """Parses CLI arguments.""" - parser = argparse.ArgumentParser() - parser.add_argument("--job_conda_env", type=str) - parser.add_argument("--client_python_version", type=str) - parser.add_argument("--client_sagemaker_pysdk_version", type=str, default=None) - parser.add_argument("--pipeline_execution_id", type=str) - parser.add_argument("--dependency_settings", type=str) - parser.add_argument("--func_step_s3_dir", type=str) - parser.add_argument("--distribution", type=str, default=None) - parser.add_argument("--user_nproc_per_node", type=str, default=None) - args, _ = parser.parse_known_args(sys_args) - return args - - -def log_key_value(key: str, value: str): - """Log a key-value pair, masking sensitive values if necessary.""" - if any(keyword.lower() in key.lower() for keyword in SENSITIVE_KEYWORDS): - logger.info("%s=%s", key, HIDDEN_VALUE) - elif isinstance(value, dict): - masked_value = mask_sensitive_info(value) - logger.info("%s=%s", key, json.dumps(masked_value)) - else: - try: - decoded_value = json.loads(value) - if isinstance(decoded_value, dict): - masked_value = mask_sensitive_info(decoded_value) - logger.info("%s=%s", key, json.dumps(masked_value)) - else: - logger.info("%s=%s", key, decoded_value) - except (json.JSONDecodeError, TypeError): - logger.info("%s=%s", key, value) - - -def log_env_variables(env_vars_dict: Dict[str, Any]): - """Log Environment Variables from the environment and an env_vars_dict.""" - for key, value in os.environ.items(): - log_key_value(key, value) - - for key, value in env_vars_dict.items(): - log_key_value(key, value) - - -def mask_sensitive_info(data): - """Recursively mask sensitive information in a dictionary.""" - if isinstance(data, dict): - for k, v in data.items(): - if isinstance(v, dict): - data[k] = mask_sensitive_info(v) - elif isinstance(v, str) and any( - keyword.lower() in k.lower() for keyword in SENSITIVE_KEYWORDS - ): - data[k] = 
HIDDEN_VALUE - return data - - -def num_cpus() -> int: - """Return the number of CPUs available in the current container. - - Returns: - int: Number of CPUs available in the current container. - """ - return multiprocessing.cpu_count() - - -def num_gpus() -> int: - """Return the number of GPUs available in the current container. - - Returns: - int: Number of GPUs available in the current container. - """ - try: - cmd = ["nvidia-smi", "--list-gpus"] - output = subprocess.check_output(cmd).decode("utf-8") - return sum(1 for line in output.splitlines() if line.startswith("GPU ")) - except (OSError, subprocess.CalledProcessError): - logger.info("No GPUs detected (normal if no gpus installed)") - return 0 - - -def num_neurons() -> int: - """Return the number of neuron cores available in the current container. - - Returns: - int: Number of Neuron Cores available in the current container. - """ - try: - cmd = ["neuron-ls", "-j"] - output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8") - j = json.loads(output) - neuron_cores = 0 - for item in j: - neuron_cores += item.get("nc_count", 0) - logger.info("Found %s neurons on this instance", neuron_cores) - return neuron_cores - except OSError: - logger.info("No Neurons detected (normal if no neurons installed)") - return 0 - except subprocess.CalledProcessError as e: - if e.output is not None: - try: - msg = e.output.decode("utf-8").partition("error=")[2] - logger.info( - "No Neurons detected (normal if no neurons installed). \ - If neuron installed then %s", - msg, - ) - except AttributeError: - logger.info("No Neurons detected (normal if no neurons installed)") - else: - logger.info("No Neurons detected (normal if no neurons installed)") - - return 0 - - -def safe_serialize(data): - """Serialize the data without wrapping strings in quotes. - - This function handles the following cases: - 1. If `data` is a string, it returns the string as-is without wrapping in quotes. - 2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns - the JSON-encoded string using `json.dumps()`. - 3. If `data` cannot be serialized (e.g., a custom object), it returns the string - representation of the data using `str(data)`. - - Args: - data (Any): The data to serialize. - - Returns: - str: The serialized JSON-compatible string or the string representation of the input. - """ - if isinstance(data, str): - return data - try: - return json.dumps(data) - except TypeError: - return str(data) - - -def set_env( - resource_config: Dict[str, Any], - distribution: str = None, - user_nproc_per_node: bool = None, - output_file: str = ENV_OUTPUT_FILE, -): - """Set environment variables for the training job container. - - Args: - resource_config (Dict[str, Any]): Resource configuration for the training job. - output_file (str): Output file to write the environment variables. 
- """ - # Constants - env_vars = { - "SM_MODEL_DIR": SM_MODEL_DIR, - "SM_INPUT_DIR": SM_INPUT_DIR, - "SM_INPUT_DATA_DIR": SM_INPUT_DATA_DIR, - "SM_INPUT_CONFIG_DIR": SM_INPUT_CONFIG_DIR, - "SM_OUTPUT_DIR": SM_OUTPUT_DIR, - "SM_OUTPUT_FAILURE": SM_OUTPUT_FAILURE, - "SM_OUTPUT_DATA_DIR": SM_OUTPUT_DATA_DIR, - "SM_MASTER_ADDR": SM_MASTER_ADDR, - "SM_MASTER_PORT": SM_MASTER_PORT, - } - - # Host Variables - current_host = resource_config["current_host"] - current_instance_type = resource_config["current_instance_type"] - hosts = resource_config["hosts"] - sorted_hosts = sorted(hosts) - - env_vars["SM_CURRENT_HOST"] = current_host - env_vars["SM_CURRENT_INSTANCE_TYPE"] = current_instance_type - env_vars["SM_HOSTS"] = sorted_hosts - env_vars["SM_NETWORK_INTERFACE_NAME"] = resource_config["network_interface_name"] - env_vars["SM_HOST_COUNT"] = len(sorted_hosts) - env_vars["SM_CURRENT_HOST_RANK"] = sorted_hosts.index(current_host) - - env_vars["SM_NUM_CPUS"] = num_cpus() - env_vars["SM_NUM_GPUS"] = num_gpus() - env_vars["SM_NUM_NEURONS"] = num_neurons() - - # Misc. - env_vars["SM_RESOURCE_CONFIG"] = resource_config - - if user_nproc_per_node is not None and int(user_nproc_per_node) > 0: - env_vars["SM_NPROC_PER_NODE"] = int(user_nproc_per_node) - else: - if int(env_vars["SM_NUM_GPUS"]) > 0: - env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_GPUS"]) - elif int(env_vars["SM_NUM_NEURONS"]) > 0: - env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_NEURONS"]) - else: - env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_CPUS"]) - - # All Training Environment Variables - env_vars["SM_TRAINING_ENV"] = { - "current_host": env_vars["SM_CURRENT_HOST"], - "current_instance_type": env_vars["SM_CURRENT_INSTANCE_TYPE"], - "hosts": env_vars["SM_HOSTS"], - "host_count": env_vars["SM_HOST_COUNT"], - "nproc_per_node": env_vars["SM_NPROC_PER_NODE"], - "master_addr": env_vars["SM_MASTER_ADDR"], - "master_port": env_vars["SM_MASTER_PORT"], - "input_config_dir": env_vars["SM_INPUT_CONFIG_DIR"], - "input_data_dir": env_vars["SM_INPUT_DATA_DIR"], - "input_dir": env_vars["SM_INPUT_DIR"], - "job_name": os.environ["TRAINING_JOB_NAME"], - "model_dir": env_vars["SM_MODEL_DIR"], - "network_interface_name": env_vars["SM_NETWORK_INTERFACE_NAME"], - "num_cpus": env_vars["SM_NUM_CPUS"], - "num_gpus": env_vars["SM_NUM_GPUS"], - "num_neurons": env_vars["SM_NUM_NEURONS"], - "output_data_dir": env_vars["SM_OUTPUT_DATA_DIR"], - "resource_config": env_vars["SM_RESOURCE_CONFIG"], - } - - if distribution and distribution == "torchrun": - logger.info("Distribution: torchrun") - - instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"] - network_interface_name = env_vars.get("SM_NETWORK_INTERFACE_NAME", "eth0") - - if instance_type in SM_EFA_NCCL_INSTANCES: - # Enable EFA use - env_vars["FI_PROVIDER"] = "efa" - if instance_type in SM_EFA_RDMA_INSTANCES: - # Use EFA's RDMA functionality for one-sided and two-sided transfer - env_vars["FI_EFA_USE_DEVICE_RDMA"] = "1" - env_vars["RDMAV_FORK_SAFE"] = "1" - env_vars["NCCL_SOCKET_IFNAME"] = str(network_interface_name) - env_vars["NCCL_PROTO"] = "simple" - elif distribution and distribution == "mpirun": - logger.info("Distribution: mpirun") - - env_vars["MASTER_ADDR"] = env_vars["SM_MASTER_ADDR"] - env_vars["MASTER_PORT"] = str(env_vars["SM_MASTER_PORT"]) - - host_list = [ - "{}:{}".format(host, int(env_vars["SM_NPROC_PER_NODE"])) for host in sorted_hosts - ] - env_vars["SM_HOSTS_LIST"] = ",".join(host_list) - - instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"] - - if instance_type in 
SM_EFA_NCCL_INSTANCES:
- env_vars["SM_FI_PROVIDER"] = "-x FI_PROVIDER=efa"
- env_vars["SM_NCCL_PROTO"] = "-x NCCL_PROTO=simple"
- else:
- env_vars["SM_FI_PROVIDER"] = ""
- env_vars["SM_NCCL_PROTO"] = ""
-
- if instance_type in SM_EFA_RDMA_INSTANCES:
- env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = "-x FI_EFA_USE_DEVICE_RDMA=1"
- else:
- env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = ""
-
- with open(output_file, "w") as f:
- for key, value in env_vars.items():
- f.write(f"export {key}='{safe_serialize(value)}'\n")
-
- logger.info("Environment Variables:")
- log_env_variables(env_vars_dict=env_vars)
-
-
-def main(sys_args=None):
- """Entry point for bootstrap script"""
-
- exit_code = DEFAULT_FAILURE_CODE
-
- try:
- args = _parse_args(sys_args)
-
- logger.info("Arguments:")
- for arg in vars(args):
- logger.info("%s=%s", arg, getattr(args, arg))
-
- client_python_version = args.client_python_version
- client_sagemaker_pysdk_version = args.client_sagemaker_pysdk_version
- job_conda_env = args.job_conda_env
- pipeline_execution_id = args.pipeline_execution_id
- dependency_settings = _DependencySettings.from_string(args.dependency_settings)
- func_step_workspace = args.func_step_s3_dir
- distribution = args.distribution
- user_nproc_per_node = args.user_nproc_per_node
-
- conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV")
-
- RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env)
-
- user = getpass.getuser()
- if user != "root":
- log_message = (
- "The job is running on non-root user: %s. Adding write permissions to the "
- "following job output directories: %s."
- )
- logger.info(log_message, user, JOB_OUTPUT_DIRS)
- RuntimeEnvironmentManager().change_dir_permission(
- dirs=JOB_OUTPUT_DIRS, new_permission="777"
- )
-
- if pipeline_execution_id:
- _bootstrap_runtime_env_for_pipeline_step(
- client_python_version, func_step_workspace, conda_env, dependency_settings
- )
- else:
- _bootstrap_runtime_env_for_remote_function(
- client_python_version, conda_env, dependency_settings
- )
-
- RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
- client_sagemaker_pysdk_version
- )
-
- if os.path.exists(RESOURCE_CONFIG):
- try:
- logger.info("Found %s", RESOURCE_CONFIG)
- with open(RESOURCE_CONFIG, "r") as f:
- resource_config = json.load(f)
- set_env(
- resource_config=resource_config,
- distribution=distribution,
- user_nproc_per_node=user_nproc_per_node,
- )
- except (json.JSONDecodeError, FileNotFoundError) as e:
- logger.error("Error processing %s: %s", RESOURCE_CONFIG, str(e))
-
- exit_code = SUCCESS_EXIT_CODE
- except Exception as e: # pylint: disable=broad-except
- logger.exception("Error encountered while bootstrapping runtime environment: %s", e)
-
- _write_failure_reason_file(str(e))
- finally:
- sys.exit(exit_code)
-
-
-if __name__ == "__main__":
- main(sys.argv[1:])
\ No newline at end of file
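The `set_env` function being deleted above serializes every variable into a
shell-sourceable file of `export KEY='value'` lines, using `safe_serialize` so
that plain strings are written as-is while dicts and numbers round-trip through
JSON. As a minimal sketch of how a consumer could read those variables back
(the path is hypothetical; the real value of `ENV_OUTPUT_FILE` is defined
elsewhere in the module):

    import json

    def read_env_file(path="/opt/ml/input/sm_training.env"):  # hypothetical path
        """Parse lines of the form: export KEY='value'."""
        env = {}
        with open(path) as f:
            for line in f:
                line = line.strip()
                if not line.startswith("export "):
                    continue
                key, _, raw = line[len("export "):].partition("=")
                raw = raw.strip("'")  # values were wrapped in single quotes
                try:
                    env[key] = json.loads(raw)  # dicts/lists/ints round-trip
                except (json.JSONDecodeError, ValueError):
                    env[key] = raw  # plain strings were written unquoted
        return env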
diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py
deleted file mode 100644
index 79ddd4020b..0000000000
--- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py
+++ /dev/null
@@ -1,252 +0,0 @@
-# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-"""Utility functions for the runtime environment. This must be kept independent of SageMaker PySDK"""
-from __future__ import absolute_import
-
-import argparse
-import json
-import os
-import subprocess
-import sys
-import time
-from typing import List
-
-import paramiko
-
-if __package__ is None or __package__ == "":
- from runtime_environment_manager import (
- get_logger,
- )
-else:
- from sagemaker.train.remote_function.runtime_environment.runtime_environment_manager import (
- get_logger,
- )
-
-SUCCESS_EXIT_CODE = 0
-DEFAULT_FAILURE_CODE = 1
-
-FINISHED_STATUS_FILE = "/tmp/done.algo-1"
-READY_FILE = "/tmp/ready.%s"
-DEFAULT_SSH_PORT = 22
-
-FAILURE_REASON_PATH = "/opt/ml/output/failure"
-
-logger = get_logger()
-
-
-class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy):
- """Class to handle host key policy for SageMaker distributed training SSH connections.
-
- Example:
- >>> client = paramiko.SSHClient()
- >>> client.set_missing_host_key_policy(CustomHostKeyPolicy())
- >>> # Will succeed for SageMaker algorithm containers
- >>> client.connect('algo-1234.internal')
- >>> # Will raise SSHException for other unknown hosts
- >>> client.connect('unknown-host') # raises SSHException
- """
-
- def missing_host_key(self, client, hostname, key):
- """Accept host keys for algo-* hostnames, reject others.
-
- Args:
- client: The SSHClient instance
- hostname: The hostname attempting to connect
- key: The host key
- Raises:
- paramiko.SSHException: If hostname doesn't match algo-* pattern
- """
- if hostname.startswith("algo-"):
- client.get_host_keys().add(hostname, key.get_name(), key)
- return
- raise paramiko.SSHException(f"Unknown host key for {hostname}")
-
-
-def _parse_args(sys_args):
- """Parses CLI arguments."""
- parser = argparse.ArgumentParser()
- parser.add_argument("--job_ended", type=str, default="0")
- args, _ = parser.parse_known_args(sys_args)
- return args
-
-
-def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool:
- """Check if the connection to the provided host and port is possible."""
- try:
- with paramiko.SSHClient() as client:
- client.load_system_host_keys()
- client.set_missing_host_key_policy(CustomHostKeyPolicy())
- client.connect(host, port=port)
- logger.info("Can connect to host %s", host)
- return True
- except Exception as e: # pylint: disable=W0703
- logger.info("Cannot connect to host %s", host)
- logger.debug("Connection failed with exception: %s", e)
- return False
-
-
-def _write_file_to_host(host: str, status_file: str) -> bool:
- """Write a file to the provided host."""
- try:
- logger.info("Writing %s to %s", status_file, host)
- subprocess.run(
- ["ssh", host, "touch", f"{status_file}"],
- capture_output=True,
- text=True,
- check=True,
- )
- logger.info("Finished writing status file")
- return True
- except subprocess.CalledProcessError:
- logger.info("Cannot connect to %s", host)
- return False
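# NOTE (editorial sketch, not part of the original module): the helpers above
# implement a simple SSH-based rendezvous. A worker node, for example, polls
# the master with _can_connect() and then signals readiness by touching its
# READY_FILE on the master via _write_file_to_host(), e.g.:
#
#     if _can_connect("algo-1"):
#         _write_file_to_host("algo-1", READY_FILE % "algo-2")
#
# The hostnames here ("algo-1", "algo-2") are illustrative; the actual hosts
# come from the SM_MASTER_ADDR/SM_CURRENT_HOST/SM_HOSTS environment variables
# that the bootstrapper sets (see main() below).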
-
-
-def _write_failure_reason_file(failure_msg):
- """Create a file 'failure' with failure reason written if bootstrap runtime env failed.
-
- See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html
- Args:
- failure_msg: The content of file to be written.
- """
- if not os.path.exists(FAILURE_REASON_PATH):
- with open(FAILURE_REASON_PATH, "w") as f:
- f.write("RuntimeEnvironmentError: " + failure_msg)
-
-
-def _wait_for_master(master_host: str, port: int = DEFAULT_SSH_PORT, timeout: int = 300):
- """Worker nodes wait until they can connect to the master node."""
- start_time = time.time()
- while True:
- logger.info("Worker is attempting to connect to the master node %s...", master_host)
- if _can_connect(master_host, port):
- logger.info("Worker can connect to master node %s.", master_host)
- break
- if time.time() - start_time > timeout:
- raise TimeoutError("Timed out waiting for master %s to be reachable." % master_host)
-
- time.sleep(5) # Wait for 5 seconds before trying again
-
-
-def _wait_for_status_file(status_file: str):
- """Wait for the status file to be created."""
- logger.info("Waiting for status file %s", status_file)
- while not os.path.exists(status_file):
- time.sleep(30)
- logger.info("Found status file %s", status_file)
-
-
-def _wait_for_workers(worker_hosts: List[str], port: int = DEFAULT_SSH_PORT, timeout: int = 300):
- """Master node waits until it can connect to all worker nodes."""
- start_time = time.time()
- if not worker_hosts:
- logger.info("No worker nodes to connect to.")
- return
-
- while True:
- logger.info("Master is attempting to connect to all workers...")
- all_workers_connected = all(
- _can_connect(worker, port) and os.path.exists(READY_FILE % worker)
- for worker in worker_hosts
- )
-
- if all_workers_connected:
- logger.info("Master can connect to all worker nodes.")
- break
- if time.time() - start_time > timeout:
- raise TimeoutError("Timed out waiting for workers to be reachable.")
-
- time.sleep(5) # Wait for 5 seconds before trying again
-
-
-def bootstrap_master_node(worker_hosts: List[str]):
- """Bootstrap the master node."""
- logger.info("Bootstrapping master node...")
- _wait_for_workers(worker_hosts)
-
-
-def bootstrap_worker_node(
- master_host: str, current_host: str, status_file: str = FINISHED_STATUS_FILE
-):
- """Bootstrap the worker nodes."""
- logger.info("Bootstrapping worker node...")
- _wait_for_master(master_host)
- _write_file_to_host(master_host, READY_FILE % current_host)
- _wait_for_status_file(status_file)
-
-
-def start_sshd_daemon():
- """Start the SSH daemon on the current node."""
- sshd_executable = "/usr/sbin/sshd"
-
- if not os.path.exists(sshd_executable):
- raise RuntimeError("SSH daemon not found.")
-
- # Start sshd in the foreground (-D: do not detach) as a child process
- subprocess.Popen([sshd_executable, "-D"])
- logger.info("Started SSH daemon.")
-
-
-def write_status_file_to_workers(worker_hosts: List[str], status_file: str = FINISHED_STATUS_FILE):
- """Write the status file to all worker nodes."""
- for worker in worker_hosts:
- retry = 0
- while not _write_file_to_host(worker, status_file):
- time.sleep(5)
- retry += 1
- if retry > 5:
- raise TimeoutError("Timed out waiting for %s to be reachable."
% worker) - logger.info("Retrying to write status file to %s", worker) - - -def main(sys_args=None): - """Entry point for bootstrap script""" - try: - args = _parse_args(sys_args) - - job_ended = args.job_ended - - main_host = os.environ["SM_MASTER_ADDR"] - current_host = os.environ["SM_CURRENT_HOST"] - - if job_ended == "0": - logger.info("Job is running, bootstrapping nodes") - - start_sshd_daemon() - - if current_host != main_host: - bootstrap_worker_node(main_host, current_host) - else: - sorted_hosts = json.loads(os.environ["SM_HOSTS"]) - worker_hosts = [host for host in sorted_hosts if host != main_host] - - bootstrap_master_node(worker_hosts) - else: - logger.info("Job ended, writing status file to workers") - - if current_host == main_host: - sorted_hosts = json.loads(os.environ["SM_HOSTS"]) - worker_hosts = [host for host in sorted_hosts if host != main_host] - - write_status_file_to_workers(worker_hosts) - except Exception as e: # pylint: disable=broad-except - logger.exception("Error encountered while bootstrapping runtime environment: %s", e) - - _write_failure_reason_file(str(e)) - - sys.exit(DEFAULT_FAILURE_CODE) - - -if __name__ == "__main__": - main(sys.argv[1:]) \ No newline at end of file diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py deleted file mode 100644 index 9cb0c7aee4..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py +++ /dev/null @@ -1,529 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""SageMaker runtime environment module. This must be kept independent of SageMaker PySDK""" - -from __future__ import absolute_import - - -import logging -import sys -import shlex -import os -import subprocess -import time -import dataclasses -import json - - -class _UTCFormatter(logging.Formatter): - """Class that overrides the default local time provider in log formatter.""" - - converter = time.gmtime - - -def get_logger(): - """Return a logger with the name 'sagemaker'""" - sagemaker_logger = logging.getLogger("sagemaker.remote_function") - if len(sagemaker_logger.handlers) == 0: - sagemaker_logger.setLevel(logging.INFO) - handler = logging.StreamHandler() - formatter = _UTCFormatter("%(asctime)s %(name)s %(levelname)-8s %(message)s") - handler.setFormatter(formatter) - sagemaker_logger.addHandler(handler) - # don't stream logs with the root logger handler - sagemaker_logger.propagate = 0 - - return sagemaker_logger - - -logger = get_logger() - - -@dataclasses.dataclass -class _DependencySettings: - """Dependency settings for the remote function. - - Instructs the runtime environment script on how to handle dependencies. - If ``dependency_file`` is set, the runtime environment script will attempt - to install the dependencies. 
If ``dependency_file`` is not set, the runtime
- environment script will assume no dependencies are required.
- """
-
- dependency_file: str = None
-
- def to_string(self):
- """Converts the dependency settings to a string."""
- return json.dumps(dataclasses.asdict(self))
-
- @staticmethod
- def from_string(dependency_settings_string):
- """Converts a json string to dependency settings.
-
- Args:
- dependency_settings_string (str): The json string to convert.
- """
- if dependency_settings_string is None:
- return None
- dependency_settings_dict = json.loads(dependency_settings_string)
- return _DependencySettings(dependency_settings_dict.get("dependency_file"))
-
- @staticmethod
- def from_dependency_file_path(dependency_file_path):
- """Converts a dependency file path to dependency settings.
-
- Args:
- dependency_file_path (str): The path to the dependency file.
- """
- if dependency_file_path is None:
- return _DependencySettings()
- if dependency_file_path == "auto_capture":
- return _DependencySettings("env_snapshot.yml")
- return _DependencySettings(os.path.basename(dependency_file_path))
-
-
-class RuntimeEnvironmentManager:
- """Runtime Environment Manager class to manage runtime environment."""
-
- def _validate_path(self, path: str) -> str:
- """Validate and sanitize file path to prevent path traversal attacks.
-
- Args:
- path (str): The file path to validate
-
- Returns:
- str: The validated absolute path
-
- Raises:
- ValueError: If the path is invalid or contains suspicious patterns
- """
- if not path:
- raise ValueError("Path cannot be empty")
-
- # Check for null bytes (common in path traversal attacks) before any
- # path operations, which may themselves choke on embedded null bytes
- if '\x00' in path:
- raise ValueError(f"Invalid path contains null byte: {path}")
-
- # Get absolute path to prevent path traversal
- abs_path = os.path.abspath(path)
-
- return abs_path
-
- def _validate_env_name(self, env_name: str) -> None:
- """Validate conda environment name to prevent command injection.
-
- Args:
- env_name (str): The environment name to validate
-
- Raises:
- ValueError: If the environment name contains invalid characters
- """
- if not env_name:
- raise ValueError("Environment name cannot be empty")
-
- # Allow only alphanumeric, underscore, and hyphen
- import re
- if not re.match(r'^[a-zA-Z0-9_-]+$', env_name):
- raise ValueError(
- f"Invalid environment name '{env_name}'. "
- "Only alphanumeric characters, underscores, and hyphens are allowed."
- )
-
- def snapshot(self, dependencies: str = None) -> str:
- """Creates snapshot of the user's environment
-
- If a req.txt or conda.yml file is provided, it verifies its existence and
- returns the local file path
- If ``auto_capture`` is set, this method will take the snapshot of
- user's dependencies installed in the local runtime.
- Current support for ``auto_capture``:
- * conda env, generate a yml file and return its local path
-
- Args:
- dependencies (str): Local path where dependencies file exists.
- - Returns: - file path of the existing or generated dependencies file - """ - - # No additional dependencies specified - if dependencies is None: - return None - - if dependencies == "auto_capture": - return self._capture_from_local_runtime() - - # Dependencies specified as either req.txt or conda_env.yml - if ( - dependencies.endswith(".txt") - or dependencies.endswith(".yml") - or dependencies.endswith(".yaml") - ): - self._is_file_exists(dependencies) - return dependencies - - raise ValueError(f'Invalid dependencies provided: "{dependencies}"') - - def _capture_from_local_runtime(self) -> str: - """Generates dependencies list from the user's local runtime. - - Raises RuntimeEnvironmentError if not able to. - - Currently supports: conda environments - """ - - # Try to capture dependencies from the conda environment, if any. - conda_env_name = self._get_active_conda_env_name() - conda_env_prefix = self._get_active_conda_env_prefix() - if conda_env_name: - logger.info("Found conda_env_name: '%s'", conda_env_name) - elif conda_env_prefix: - logger.info("Found conda_env_prefix: '%s'", conda_env_prefix) - else: - raise ValueError("No conda environment seems to be active.") - - if conda_env_name == "base": - logger.warning( - "We recommend using an environment other than base to " - "isolate your project dependencies from conda dependencies" - ) - - local_dependencies_path = os.path.join(os.getcwd(), "env_snapshot.yml") - self._export_conda_env_from_prefix(conda_env_prefix, local_dependencies_path) - - return local_dependencies_path - - def _get_active_conda_env_prefix(self) -> str: - """Returns the conda prefix from the set environment variable. None otherwise.""" - return os.getenv("CONDA_PREFIX") - - def _get_active_conda_env_name(self) -> str: - """Returns the conda environment name from the set environment variable. None otherwise.""" - return os.getenv("CONDA_DEFAULT_ENV") - - def bootstrap( - self, local_dependencies_file: str, client_python_version: str, conda_env: str = None - ): - """Bootstraps the runtime environment by installing the additional dependencies if any. - - Args: - local_dependencies_file (str): path where dependencies file exists. - conda_env (str): conda environment to be activated. Default is None. - - Returns: None - """ - - if local_dependencies_file.endswith(".txt"): - if conda_env: - self._install_req_txt_in_conda_env(conda_env, local_dependencies_file) - self._write_conda_env_to_file(conda_env) - - else: - self._install_requirements_txt(local_dependencies_file, _python_executable()) - - elif local_dependencies_file.endswith(".yml") or local_dependencies_file.endswith(".yaml"): - if conda_env: - self._update_conda_env(conda_env, local_dependencies_file) - else: - conda_env = "sagemaker-runtime-env" - self._create_conda_env(conda_env, local_dependencies_file) - self._validate_python_version(client_python_version, conda_env) - self._write_conda_env_to_file(conda_env) - - def run_pre_exec_script(self, pre_exec_script_path: str): - """Runs script of pre-execution commands if existing. - - Args: - pre_exec_script_path (str): Path to pre-execution command script file. - """ - if os.path.isfile(pre_exec_script_path): - logger.info("Running pre-execution commands in '%s'", pre_exec_script_path) - return_code, error_logs = _run_pre_execution_command_script(pre_exec_script_path) - - if return_code: - error_message = ( - f"Encountered error while running pre-execution commands. 
Reason: {error_logs}" - ) - raise RuntimeEnvironmentError(error_message) - else: - logger.info( - "'%s' does not exist. Assuming no pre-execution commands to run", - pre_exec_script_path, - ) - - def change_dir_permission(self, dirs: list, new_permission: str): - """Change the permission of given directories - - Args: - dirs (list[str]): A list of directories for permission update. - new_permission (str): The new permission for the given directories. - """ - - _ERROR_MSG_PREFIX = "Failed to change directory permissions due to: " - command = ["sudo", "chmod", "-R", new_permission] + dirs - logger.info("Executing '%s'.", " ".join(command)) - - try: - subprocess.run(command, check=True, stderr=subprocess.PIPE) - except subprocess.CalledProcessError as called_process_err: - err_msg = called_process_err.stderr.decode("utf-8") - raise RuntimeEnvironmentError(f"{_ERROR_MSG_PREFIX} {err_msg}") - except FileNotFoundError as file_not_found_err: - if "[Errno 2] No such file or directory: 'sudo'" in str(file_not_found_err): - raise RuntimeEnvironmentError( - f"{_ERROR_MSG_PREFIX} {file_not_found_err}. " - "Please contact the image owner to install 'sudo' in the job container " - "and provide sudo privilege to the container user." - ) - raise RuntimeEnvironmentError(file_not_found_err) - - def _is_file_exists(self, dependencies): - """Check whether the dependencies file exists at the given location. - - Raises error if not - """ - if not os.path.isfile(dependencies): - raise ValueError(f'No dependencies file named "{dependencies}" was found.') - - def _install_requirements_txt(self, local_path, python_executable): - """Install requirements.txt file""" - # Validate path to prevent command injection - validated_path = self._validate_path(local_path) - cmd = [python_executable, "-m", "pip", "install", "-r", validated_path, "-U"] - logger.info("Running command: '%s' in the dir: '%s' ", " ".join(cmd), os.getcwd()) - _run_shell_cmd(cmd) - logger.info("Command %s ran successfully", " ".join(cmd)) - - def _create_conda_env(self, env_name, local_path): - """Create conda env using conda yml file""" - # Validate inputs to prevent command injection - self._validate_env_name(env_name) - validated_path = self._validate_path(local_path) - - cmd = [self._get_conda_exe(), "env", "create", "-n", env_name, "--file", validated_path] - logger.info("Creating conda environment %s using: %s.", env_name, " ".join(cmd)) - _run_shell_cmd(cmd) - logger.info("Conda environment %s created successfully.", env_name) - - def _install_req_txt_in_conda_env(self, env_name, local_path): - """Install requirements.txt in the given conda environment""" - # Validate inputs to prevent command injection - self._validate_env_name(env_name) - validated_path = self._validate_path(local_path) - - cmd = [self._get_conda_exe(), "run", "-n", env_name, "pip", "install", "-r", validated_path, "-U"] - logger.info("Activating conda env and installing requirements: %s", " ".join(cmd)) - _run_shell_cmd(cmd) - logger.info("Requirements installed successfully in conda env %s", env_name) - - def _update_conda_env(self, env_name, local_path): - """Update conda env using conda yml file""" - # Validate inputs to prevent command injection - self._validate_env_name(env_name) - validated_path = self._validate_path(local_path) - - cmd = [self._get_conda_exe(), "env", "update", "-n", env_name, "--file", validated_path] - logger.info("Updating conda env: %s", " ".join(cmd)) - _run_shell_cmd(cmd) - logger.info("Conda env %s updated succesfully", env_name) - - def 
- def _export_conda_env_from_prefix(self, prefix, local_path):
- """Export the conda env to a conda yml file"""
-
- # Shell redirection ("> file") is not interpreted when a list-form command
- # runs without a shell, so capture stdout and write the file directly.
- cmd = [self._get_conda_exe(), "env", "export", "-p", prefix, "--no-builds"]
- logger.info("Exporting conda environment: %s", " ".join(cmd))
- output = subprocess.check_output(cmd).decode("utf-8")
- with open(local_path, "w") as f:
- f.write(output)
- logger.info("Conda environment %s exported successfully", prefix)
-
- def _write_conda_env_to_file(self, env_name):
- """Writes conda env to the text file"""
-
- file_name = "remote_function_conda_env.txt"
- file_path = os.path.join(os.getcwd(), file_name)
- with open(file_path, "w") as output_file:
- output_file.write(env_name)
-
- def _get_conda_exe(self):
- """Checks whether conda or mamba is available to use"""
-
- if not subprocess.Popen(["which", "mamba"]).wait():
- return "mamba"
- if not subprocess.Popen(["which", "conda"]).wait():
- return "conda"
- raise ValueError("Neither conda nor mamba is installed on the image")
-
- def _python_version_in_conda_env(self, env_name):
- """Returns python version inside a conda environment"""
- cmd = f"{self._get_conda_exe()} run -n {env_name} python --version"
- try:
- output = (
- subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT)
- .decode("utf-8")
- .strip()
- )
- # convert 'Python 3.7.16' to '3.7'
- version = output.split("Python ")[1].split(".")
- return version[0] + "." + version[1]
- except subprocess.CalledProcessError as e:
- raise RuntimeEnvironmentError(e.output)
-
- def _current_python_version(self):
- """Returns the current python version where program is running"""
-
- return f"{sys.version_info.major}.{sys.version_info.minor}".strip()
-
- def _current_sagemaker_pysdk_version(self):
- """Returns the current sagemaker python sdk version where program is running"""
- try:
- from importlib import metadata
- return metadata.version("sagemaker")
- except Exception:
- return "3.0.0.dev0" # Development version fallback
-
- def _validate_python_version(self, client_python_version: str, conda_env: str = None):
- """Validate the python version
-
- Validates if the python version where remote function runs
- matches the one used on client side.
- """
- if conda_env:
- job_python_version = self._python_version_in_conda_env(conda_env)
- else:
- job_python_version = self._current_python_version()
- if client_python_version.strip() != job_python_version.strip():
- raise RuntimeEnvironmentError(
- f"Python version found in the container is '{job_python_version}' which "
- f"does not match python version '{client_python_version}' on the local client. "
- f"Please make sure that the python version used in the training container "
- f"is the same as the local python version."
- )
-
- def _validate_sagemaker_pysdk_version(self, client_sagemaker_pysdk_version):
- """Validate the sagemaker python sdk version
-
- Validates if the sagemaker python sdk version where remote function runs
- matches the one used on client side.
- Otherwise, log a warning to call out that unexpected behaviors
- may occur in this case.
- """
- job_sagemaker_pysdk_version = self._current_sagemaker_pysdk_version()
- if (
- client_sagemaker_pysdk_version
- and client_sagemaker_pysdk_version != job_sagemaker_pysdk_version
- ):
- logger.warning(
- "Inconsistent sagemaker versions found: "
- "sagemaker python sdk version found in the container is "
- "'%s' which does not match the '%s' on the local client. "
" - "Please make sure that the sagemaker version used in the training container " - "is the same as the local sagemaker version in case of unexpected behaviors.", - job_sagemaker_pysdk_version, - client_sagemaker_pysdk_version, - ) - - -def _run_and_get_output_shell_cmd(cmd: str) -> str: - """Run and return the output of the given shell command""" - return subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT).decode("utf-8") - - -def _run_pre_execution_command_script(script_path: str): - """This method runs a given shell script using subprocess - - Raises RuntimeEnvironmentError if the shell script fails - """ - current_dir = os.path.dirname(script_path) - - process = subprocess.Popen( - ["/bin/bash", "-eu", script_path], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - cwd=current_dir, - ) - - _log_output(process) - error_logs = _log_error(process) - return_code = process.wait() - - return return_code, error_logs - - -def _run_shell_cmd(cmd: list): - """This method runs a given shell command using subprocess - - Args: - cmd (list): Command and arguments as a list (e.g., ['pip', 'install', '-r', 'requirements.txt']) - - Raises: - RuntimeEnvironmentError: If the command fails - ValueError: If cmd is not a list - """ - if not isinstance(cmd, list): - raise ValueError("Command must be a list of arguments for security reasons") - - process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False) - - _log_output(process) - error_logs = _log_error(process) - return_code = process.wait() - if return_code: - error_message = f"Encountered error while running command '{' '.join(cmd)}'. Reason: {error_logs}" - raise RuntimeEnvironmentError(error_message) - - -def _log_output(process): - """This method takes in Popen process and logs the output of that process""" - with process.stdout as pipe: - for line in iter(pipe.readline, b""): - logger.info(str(line, "UTF-8")) - - -def _log_error(process): - """This method takes in Popen process and logs the error of that process. - - Returns those logs as a string - """ - - error_logs = "" - with process.stderr as pipe: - for line in iter(pipe.readline, b""): - error_str = str(line, "UTF-8") - if "ERROR:" in error_str: - logger.error(error_str) - else: - logger.warning(error_str) - error_logs = error_logs + error_str - - return error_logs - - -def _python_executable(): - """Return the real path for the Python executable, if it exists. - - Return RuntimeEnvironmentError otherwise. - - Returns: - (str): The real path of the current Python executable. - """ - if not sys.executable: - raise RuntimeEnvironmentError( - "Failed to retrieve the path for the Python executable binary" - ) - return sys.executable - - -class RuntimeEnvironmentError(Exception): - """The base exception class for bootstrap env excepitons""" - - def __init__(self, message): - self.message = message - super().__init__(self.message) \ No newline at end of file diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/spark_app.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/spark_app.py deleted file mode 100644 index 6d4eaeb18e..0000000000 --- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/spark_app.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. 
diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/spark_app.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/spark_app.py
deleted file mode 100644
index 6d4eaeb18e..0000000000
--- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/spark_app.py
+++ /dev/null
@@ -1,18 +0,0 @@
-# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-"""A simple Spark script which invokes the pickled remote function"""
-from __future__ import absolute_import
-
-from sagemaker.train.remote_function import invoke_function
-
-invoke_function.main()
diff --git a/sagemaker-train/src/sagemaker/train/remote_function/spark_config.py b/sagemaker-train/src/sagemaker/train/remote_function/spark_config.py
deleted file mode 100644
index b5083b0566..0000000000
--- a/sagemaker-train/src/sagemaker/train/remote_function/spark_config.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-"""
-DEPRECATED: This module has been moved to sagemaker.core.remote_function.spark_config
-
-This is a backward compatibility shim.
-"""
-from __future__ import absolute_import
-
-import warnings
-
-# Backward compatibility: re-export from core
-from sagemaker.core.remote_function.spark_config import * # noqa: F401, F403
-
-warnings.warn(
- "sagemaker.train.remote_function.spark_config has been moved to sagemaker.core.remote_function.spark_config. "
- "Please update your imports. This shim will be removed in a future version.",
- DeprecationWarning,
- stacklevel=2
-)
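The deleted shim above follows a standard deprecation pattern: emit a
DeprecationWarning at import time, then star-re-export everything from the new
location so old import paths keep working. As a generic sketch of the same
idiom (module and target names here are illustrative, not from this patch):

    """old_package.module - compatibility shim; use new_package.module."""
    from __future__ import absolute_import

    import warnings

    warnings.warn(
        "old_package.module is deprecated; use new_package.module instead.",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the importing module
    )

    # Re-export the public API from the new location
    from new_package.module import *  # noqa: F401, F403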
From 8020538f17b94828493eceb06a1bd4b47bd99f90 Mon Sep 17 00:00:00 2001
From: jzhaoqwa <52220743+zhaoqizqwang@users.noreply.github.com>
Date: Wed, 17 Dec 2025 13:56:12 -0800
Subject: [PATCH 3/8] Move remote_function folder to sagemaker-train

---
 .../src/sagemaker/train}/remote_function/__init__.py | 0
 .../src/sagemaker/train}/remote_function/checkpoint_location.py | 0
 .../src/sagemaker/train}/remote_function/client.py | 0
 .../src/sagemaker/train}/remote_function/core/__init__.py | 0
 .../train}/remote_function/core/_custom_dispatch_table.py | 0
 .../sagemaker/train}/remote_function/core/pipeline_variables.py | 0
 .../src/sagemaker/train}/remote_function/core/serialization.py | 0
 .../src/sagemaker/train}/remote_function/core/stored_function.py | 0
 .../src/sagemaker/train}/remote_function/custom_file_filter.py | 0
 .../src/sagemaker/train}/remote_function/errors.py | 0
 .../src/sagemaker/train}/remote_function/invoke_function.py | 0
 .../src/sagemaker/train}/remote_function/job.py | 0
 .../src/sagemaker/train}/remote_function/logging_config.py | 0
 .../train}/remote_function/runtime_environment/__init__.py | 0
 .../runtime_environment/bootstrap_runtime_environment.py | 0
 .../remote_function/runtime_environment/mpi_utils_remote.py | 0
 .../runtime_environment/runtime_environment_manager.py | 0
 .../train}/remote_function/runtime_environment/spark_app.py | 0
 .../src/sagemaker/train}/remote_function/spark_config.py | 0
 19 files changed, 0 insertions(+), 0 deletions(-)
 rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/__init__.py (100%)
 rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/checkpoint_location.py (100%)
 rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/client.py (100%)
 rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/core/__init__.py (100%)
 rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/core/_custom_dispatch_table.py (100%)
 rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/core/pipeline_variables.py (100%)
 rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/core/serialization.py (100%)
 rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/core/stored_function.py (100%)
 rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/custom_file_filter.py (100%)
 rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/errors.py (100%)
 rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/invoke_function.py (100%)
 rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/job.py (100%)
 rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/logging_config.py (100%)
 rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/runtime_environment/__init__.py (100%)
 rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/runtime_environment/bootstrap_runtime_environment.py (100%)
 rename {sagemaker-core/src/sagemaker/core =>
sagemaker-train/src/sagemaker/train}/remote_function/runtime_environment/mpi_utils_remote.py (100%) rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/runtime_environment/runtime_environment_manager.py (100%) rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/runtime_environment/spark_app.py (100%) rename {sagemaker-core/src/sagemaker/core => sagemaker-train/src/sagemaker/train}/remote_function/spark_config.py (100%) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/__init__.py b/sagemaker-train/src/sagemaker/train/remote_function/__init__.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/__init__.py rename to sagemaker-train/src/sagemaker/train/remote_function/__init__.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/checkpoint_location.py b/sagemaker-train/src/sagemaker/train/remote_function/checkpoint_location.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/checkpoint_location.py rename to sagemaker-train/src/sagemaker/train/remote_function/checkpoint_location.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/client.py b/sagemaker-train/src/sagemaker/train/remote_function/client.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/client.py rename to sagemaker-train/src/sagemaker/train/remote_function/client.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/__init__.py b/sagemaker-train/src/sagemaker/train/remote_function/core/__init__.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/core/__init__.py rename to sagemaker-train/src/sagemaker/train/remote_function/core/__init__.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/_custom_dispatch_table.py b/sagemaker-train/src/sagemaker/train/remote_function/core/_custom_dispatch_table.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/core/_custom_dispatch_table.py rename to sagemaker-train/src/sagemaker/train/remote_function/core/_custom_dispatch_table.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/pipeline_variables.py b/sagemaker-train/src/sagemaker/train/remote_function/core/pipeline_variables.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/core/pipeline_variables.py rename to sagemaker-train/src/sagemaker/train/remote_function/core/pipeline_variables.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py b/sagemaker-train/src/sagemaker/train/remote_function/core/serialization.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py rename to sagemaker-train/src/sagemaker/train/remote_function/core/serialization.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py b/sagemaker-train/src/sagemaker/train/remote_function/core/stored_function.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py rename to sagemaker-train/src/sagemaker/train/remote_function/core/stored_function.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/custom_file_filter.py b/sagemaker-train/src/sagemaker/train/remote_function/custom_file_filter.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/custom_file_filter.py 
rename to sagemaker-train/src/sagemaker/train/remote_function/custom_file_filter.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/errors.py b/sagemaker-train/src/sagemaker/train/remote_function/errors.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/errors.py rename to sagemaker-train/src/sagemaker/train/remote_function/errors.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py b/sagemaker-train/src/sagemaker/train/remote_function/invoke_function.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py rename to sagemaker-train/src/sagemaker/train/remote_function/invoke_function.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/job.py b/sagemaker-train/src/sagemaker/train/remote_function/job.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/job.py rename to sagemaker-train/src/sagemaker/train/remote_function/job.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/logging_config.py b/sagemaker-train/src/sagemaker/train/remote_function/logging_config.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/logging_config.py rename to sagemaker-train/src/sagemaker/train/remote_function/logging_config.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/__init__.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/__init__.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/__init__.py rename to sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/__init__.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py rename to sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py rename to sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py rename to sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/spark_app.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/spark_app.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/spark_app.py rename to sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/spark_app.py diff --git 
a/sagemaker-core/src/sagemaker/core/remote_function/spark_config.py b/sagemaker-train/src/sagemaker/train/remote_function/spark_config.py similarity index 100% rename from sagemaker-core/src/sagemaker/core/remote_function/spark_config.py rename to sagemaker-train/src/sagemaker/train/remote_function/spark_config.py From 9cc183d0f2885a1486f329a2b8c74cf991ae1c2b Mon Sep 17 00:00:00 2001 From: jzhaoqwa <52220743+zhaoqizqwang@users.noreply.github.com> Date: Wed, 17 Dec 2025 14:09:38 -0800 Subject: [PATCH 4/8] Update import paths after moving remote_function folder --- .../core/workflow/execution_variables.py | 2 +- .../src/sagemaker/core/workflow/parameters.py | 2 +- .../src/sagemaker/core/workflow/properties.py | 2 +- .../tests/unit/remote_function/__init__.py | 12 - .../runtime_environment/__init__.py | 12 - .../test_bootstrap_runtime_environment.py | 679 ------------- .../test_mpi_utils_remote.py | 424 -------- .../test_runtime_environment_manager.py | 572 ----------- .../test_checkpoint_location.py | 82 -- .../tests/unit/remote_function/test_client.py | 97 -- .../test_custom_file_filter.py | 169 ---- .../remote_function/test_invoke_function.py | 280 ------ .../tests/unit/remote_function/test_job.py | 932 ------------------ .../remote_function/test_job_comprehensive.py | 533 ---------- .../remote_function/test_logging_config.py | 86 -- .../sagemaker/mlops/workflow/function_step.py | 18 +- .../src/sagemaker/mlops/workflow/pipeline.py | 8 +- .../tests/unit/workflow/test_pipeline.py | 10 +- .../train/remote_function/__init__.py | 8 +- .../sagemaker/train/remote_function/client.py | 16 +- .../core/_custom_dispatch_table.py | 2 +- .../core/pipeline_variables.py | 2 +- .../remote_function/core/serialization.py | 2 +- .../remote_function/core/stored_function.py | 6 +- .../sagemaker/train/remote_function/errors.py | 2 +- .../train/remote_function/invoke_function.py | 10 +- .../sagemaker/train/remote_function/job.py | 14 +- .../bootstrap_runtime_environment.py | 2 +- .../runtime_environment/mpi_utils_remote.py | 2 +- .../runtime_environment/spark_app.py | 2 +- 30 files changed, 55 insertions(+), 3933 deletions(-) delete mode 100644 sagemaker-core/tests/unit/remote_function/__init__.py delete mode 100644 sagemaker-core/tests/unit/remote_function/runtime_environment/__init__.py delete mode 100644 sagemaker-core/tests/unit/remote_function/runtime_environment/test_bootstrap_runtime_environment.py delete mode 100644 sagemaker-core/tests/unit/remote_function/runtime_environment/test_mpi_utils_remote.py delete mode 100644 sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py delete mode 100644 sagemaker-core/tests/unit/remote_function/test_checkpoint_location.py delete mode 100644 sagemaker-core/tests/unit/remote_function/test_client.py delete mode 100644 sagemaker-core/tests/unit/remote_function/test_custom_file_filter.py delete mode 100644 sagemaker-core/tests/unit/remote_function/test_invoke_function.py delete mode 100644 sagemaker-core/tests/unit/remote_function/test_job.py delete mode 100644 sagemaker-core/tests/unit/remote_function/test_job_comprehensive.py delete mode 100644 sagemaker-core/tests/unit/remote_function/test_logging_config.py diff --git a/sagemaker-core/src/sagemaker/core/workflow/execution_variables.py b/sagemaker-core/src/sagemaker/core/workflow/execution_variables.py index efb0b8b6ef..380ad0c280 100644 --- a/sagemaker-core/src/sagemaker/core/workflow/execution_variables.py +++ 
b/sagemaker-core/src/sagemaker/core/workflow/execution_variables.py @@ -56,7 +56,7 @@ def expr(self) -> RequestType: def _pickleable(self): """The pickleable object that can be passed to a remote function invocation.""" - from sagemaker.core.remote_function.core.pipeline_variables import _ExecutionVariable + from sagemaker.train.remote_function.core.pipeline_variables import _ExecutionVariable return _ExecutionVariable(name=self.name) diff --git a/sagemaker-core/src/sagemaker/core/workflow/parameters.py b/sagemaker-core/src/sagemaker/core/workflow/parameters.py index 90505c99cc..81d6d59d94 100644 --- a/sagemaker-core/src/sagemaker/core/workflow/parameters.py +++ b/sagemaker-core/src/sagemaker/core/workflow/parameters.py @@ -96,7 +96,7 @@ def expr(self) -> Dict[str, str]: def _pickleable(self): """The pickleable object that can be passed to a remote function invocation.""" - from sagemaker.core.remote_function.core.pipeline_variables import ( + from sagemaker.train.remote_function.core.pipeline_variables import ( _ParameterString, _ParameterInteger, _ParameterBoolean, diff --git a/sagemaker-core/src/sagemaker/core/workflow/properties.py b/sagemaker-core/src/sagemaker/core/workflow/properties.py index c9e897e178..366cecfd3f 100644 --- a/sagemaker-core/src/sagemaker/core/workflow/properties.py +++ b/sagemaker-core/src/sagemaker/core/workflow/properties.py @@ -137,7 +137,7 @@ def __reduce__(self): def _pickleable(self): """The pickleable object that can be passed to a remote function invocation.""" - from sagemaker.core.remote_function.core.pipeline_variables import _Properties + from sagemaker.train.remote_function.core.pipeline_variables import _Properties prefix = f"Steps.{self.step_name}" full_path = prefix if self.path is None else f"{prefix}.{self.path}" diff --git a/sagemaker-core/tests/unit/remote_function/__init__.py b/sagemaker-core/tests/unit/remote_function/__init__.py deleted file mode 100644 index 6549052177..0000000000 --- a/sagemaker-core/tests/unit/remote_function/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. diff --git a/sagemaker-core/tests/unit/remote_function/runtime_environment/__init__.py b/sagemaker-core/tests/unit/remote_function/runtime_environment/__init__.py deleted file mode 100644 index 6549052177..0000000000 --- a/sagemaker-core/tests/unit/remote_function/runtime_environment/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
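A pattern worth noting in the workflow hunks earlier in this patch
(execution_variables.py, parameters.py, properties.py): each `_pickleable`
property imports its pickle substitute inside the method body rather than at
module top level, which defers the dependency on the remote_function package
until a pipeline variable is actually pickled and avoids import cycles between
the workflow and remote_function modules. A generic sketch of the idiom, with
illustrative names (not the SDK's actual classes):

    class PipelineVariable:
        """Placeholder for a value resolved at pipeline execution time."""

        def __init__(self, name: str):
            self.name = name

        @property
        def _pickleable(self):
            # Local import: resolved only when pickling is requested,
            # so module import order cannot create a cycle.
            from some_package.pipeline_variables import _ExecutionVariable  # hypothetical

            return _ExecutionVariable(name=self.name)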
diff --git a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_bootstrap_runtime_environment.py b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_bootstrap_runtime_environment.py deleted file mode 100644 index 461a3ecb73..0000000000 --- a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_bootstrap_runtime_environment.py +++ /dev/null @@ -1,679 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tests for bootstrap_runtime_environment module.""" -from __future__ import absolute_import - -import json -import os -import pytest -import subprocess -from unittest.mock import patch, MagicMock, mock_open, call - -from sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment import ( - _parse_args, - _bootstrap_runtime_env_for_remote_function, - _bootstrap_runtime_env_for_pipeline_step, - _handle_pre_exec_scripts, - _install_dependencies, - _unpack_user_workspace, - _write_failure_reason_file, - log_key_value, - log_env_variables, - mask_sensitive_info, - num_cpus, - num_gpus, - num_neurons, - safe_serialize, - set_env, - main, - SUCCESS_EXIT_CODE, - DEFAULT_FAILURE_CODE, - FAILURE_REASON_PATH, - REMOTE_FUNCTION_WORKSPACE, - BASE_CHANNEL_PATH, - JOB_REMOTE_FUNCTION_WORKSPACE, - SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME, - SENSITIVE_KEYWORDS, - HIDDEN_VALUE, -) -from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( - _DependencySettings, -) - - -class TestParseArgs: - """Test _parse_args function.""" - - def test_parse_required_args(self): - """Test parsing required arguments.""" - args = [ - "--client_python_version", "3.8", - ] - parsed = _parse_args(args) - assert parsed.client_python_version == "3.8" - - def test_parse_all_args(self): - """Test parsing all arguments.""" - args = [ - "--job_conda_env", "my-env", - "--client_python_version", "3.9", - "--client_sagemaker_pysdk_version", "2.100.0", - "--pipeline_execution_id", "exec-123", - "--dependency_settings", '{"dependency_file": "requirements.txt"}', - "--func_step_s3_dir", "s3://bucket/func", - "--distribution", "torchrun", - "--user_nproc_per_node", "4", - ] - parsed = _parse_args(args) - assert parsed.job_conda_env == "my-env" - assert parsed.client_python_version == "3.9" - assert parsed.client_sagemaker_pysdk_version == "2.100.0" - assert parsed.pipeline_execution_id == "exec-123" - assert parsed.dependency_settings == '{"dependency_file": "requirements.txt"}' - assert parsed.func_step_s3_dir == "s3://bucket/func" - assert parsed.distribution == "torchrun" - assert parsed.user_nproc_per_node == "4" - - def test_parse_default_values(self): - """Test default values for optional arguments.""" - args = [ - "--client_python_version", "3.8", - ] - parsed = _parse_args(args) - assert parsed.job_conda_env is None - assert parsed.client_sagemaker_pysdk_version is None - assert parsed.pipeline_execution_id is None - assert parsed.dependency_settings is None - assert parsed.func_step_s3_dir is None - assert 
parsed.distribution is None - assert parsed.user_nproc_per_node is None - - -class TestLogKeyValue: - """Test log_key_value function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger") - def test_logs_regular_value(self, mock_logger): - """Test logs regular key-value pair.""" - log_key_value("my_name", "my_value") - mock_logger.info.assert_called_once_with("%s=%s", "my_name", "my_value") - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger") - def test_masks_sensitive_key(self, mock_logger): - """Test masks sensitive keywords.""" - for keyword in ["PASSWORD", "SECRET", "TOKEN", "KEY", "PRIVATE", "CREDENTIALS"]: - mock_logger.reset_mock() - log_key_value(f"my_{keyword}", "sensitive_value") - mock_logger.info.assert_called_once_with("%s=%s", f"my_{keyword}", HIDDEN_VALUE) - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger") - def test_logs_dict_value(self, mock_logger): - """Test logs dictionary value.""" - value = {"field1": "value1", "field2": "value2"} - log_key_value("my_config", value) - mock_logger.info.assert_called_once_with("%s=%s", "my_config", json.dumps(value)) - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger") - def test_logs_json_string_value(self, mock_logger): - """Test logs JSON string value.""" - value = '{"key1": "value1"}' - log_key_value("my_key", value) - mock_logger.info.assert_called_once() - - -class TestLogEnvVariables: - """Test log_env_variables function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_key_value") - @patch.dict("os.environ", {"ENV_VAR1": "value1", "ENV_VAR2": "value2"}) - def test_logs_env_and_dict_variables(self, mock_log_kv): - """Test logs both environment and dictionary variables.""" - env_dict = {"DICT_VAR1": "dict_value1", "DICT_VAR2": "dict_value2"} - log_env_variables(env_dict) - - # Should be called for env vars and dict vars - assert mock_log_kv.call_count >= 4 - - -class TestMaskSensitiveInfo: - """Test mask_sensitive_info function.""" - - def test_masks_sensitive_keys_in_dict(self): - """Test masks sensitive keys in dictionary.""" - data = { - "username": "user", - "password": "secret123", - "api_key": "key123", - } - result = mask_sensitive_info(data) - assert result["username"] == "user" - assert result["password"] == HIDDEN_VALUE - assert result["api_key"] == HIDDEN_VALUE - - def test_masks_nested_dict(self): - """Test masks sensitive keys in nested dictionary.""" - data = { - "config": { - "username": "user", - "secret": "secret123", - } - } - result = mask_sensitive_info(data) - assert result["config"]["username"] == "user" - assert result["config"]["secret"] == HIDDEN_VALUE - - def test_returns_non_dict_unchanged(self): - """Test returns non-dictionary unchanged.""" - data = "string_value" - result = mask_sensitive_info(data) - assert result == "string_value" - - -class TestNumCpus: - """Test num_cpus function.""" - - @patch("multiprocessing.cpu_count") - def test_returns_cpu_count(self, mock_cpu_count): - """Test returns CPU count.""" - mock_cpu_count.return_value = 8 - assert num_cpus() == 8 - - -class TestNumGpus: - """Test num_gpus function.""" - - @patch("subprocess.check_output") - def test_returns_gpu_count(self, mock_check_output): - """Test returns GPU count.""" - mock_check_output.return_value = b"GPU 0: Tesla V100\nGPU 1: Tesla V100\n" - assert num_gpus() == 2 - - 
@patch("subprocess.check_output") - def test_returns_zero_on_error(self, mock_check_output): - """Test returns zero when nvidia-smi fails.""" - mock_check_output.side_effect = subprocess.CalledProcessError(1, "nvidia-smi") - assert num_gpus() == 0 - - @patch("subprocess.check_output") - def test_returns_zero_on_os_error(self, mock_check_output): - """Test returns zero when nvidia-smi not found.""" - mock_check_output.side_effect = OSError() - assert num_gpus() == 0 - - -class TestNumNeurons: - """Test num_neurons function.""" - - @patch("subprocess.check_output") - def test_returns_neuron_count(self, mock_check_output): - """Test returns neuron core count.""" - mock_output = json.dumps([{"nc_count": 2}, {"nc_count": 4}]) - mock_check_output.return_value = mock_output.encode("utf-8") - assert num_neurons() == 6 - - @patch("subprocess.check_output") - def test_returns_zero_on_os_error(self, mock_check_output): - """Test returns zero when neuron-ls not found.""" - mock_check_output.side_effect = OSError() - assert num_neurons() == 0 - - @patch("subprocess.check_output") - def test_returns_zero_on_called_process_error(self, mock_check_output): - """Test returns zero when neuron-ls fails.""" - error = subprocess.CalledProcessError(1, "neuron-ls") - error.output = b"error=No neuron devices found" - mock_check_output.side_effect = error - assert num_neurons() == 0 - - -class TestSafeSerialize: - """Test safe_serialize function.""" - - def test_returns_string_as_is(self): - """Test returns string without quotes.""" - assert safe_serialize("test_string") == "test_string" - - def test_serializes_dict(self): - """Test serializes dictionary.""" - data = {"key": "value"} - assert safe_serialize(data) == '{"key": "value"}' - - def test_serializes_list(self): - """Test serializes list.""" - data = [1, 2, 3] - assert safe_serialize(data) == "[1, 2, 3]" - - def test_returns_str_for_non_serializable(self): - """Test returns str() for non-serializable objects.""" - class CustomObj: - def __str__(self): - return "custom_object" - - obj = CustomObj() - assert safe_serialize(obj) == "custom_object" - - -class TestSetEnv: - """Test set_env function.""" - - @patch("builtins.open", new_callable=mock_open) - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") - @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) - def test_sets_basic_env_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test sets basic environment variables.""" - mock_cpus.return_value = 8 - mock_gpus.return_value = 2 - mock_neurons.return_value = 0 - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.p3.2xlarge", - "hosts": ["algo-1", "algo-2"], - "network_interface_name": "eth0", - } - - set_env(resource_config) - - mock_file.assert_called_once() - mock_log_env.assert_called_once() - - @patch("builtins.open", new_callable=mock_open) - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") - 
@patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") - @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) - def test_sets_torchrun_distribution_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test sets torchrun distribution environment variables.""" - mock_cpus.return_value = 8 - mock_gpus.return_value = 2 - mock_neurons.return_value = 0 - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.p4d.24xlarge", - "hosts": ["algo-1"], - "network_interface_name": "eth0", - } - - set_env(resource_config, distribution="torchrun") - - # Verify file was written - mock_file.assert_called_once() - - @patch("builtins.open", new_callable=mock_open) - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") - @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) - def test_sets_mpirun_distribution_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test sets mpirun distribution environment variables.""" - mock_cpus.return_value = 8 - mock_gpus.return_value = 2 - mock_neurons.return_value = 0 - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.p3.2xlarge", - "hosts": ["algo-1", "algo-2"], - "network_interface_name": "eth0", - } - - set_env(resource_config, distribution="mpirun") - - mock_file.assert_called_once() - - @patch("builtins.open", new_callable=mock_open) - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") - @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) - def test_uses_user_nproc_per_node(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test uses user-specified nproc_per_node.""" - mock_cpus.return_value = 8 - mock_gpus.return_value = 2 - mock_neurons.return_value = 0 - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.p3.2xlarge", - "hosts": ["algo-1"], - "network_interface_name": "eth0", - } - - set_env(resource_config, user_nproc_per_node="4") - - mock_file.assert_called_once() - - -class TestWriteFailureReasonFile: - """Test _write_failure_reason_file function.""" - - @patch("builtins.open", new_callable=mock_open) - @patch("os.path.exists") - def test_writes_failure_file(self, mock_exists, mock_file): - """Test writes failure reason file.""" - mock_exists.return_value = False - - _write_failure_reason_file("Test error message") - - mock_file.assert_called_once_with(FAILURE_REASON_PATH, "w") - mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error message") - - @patch("builtins.open", new_callable=mock_open) - @patch("os.path.exists") - def test_does_not_write_if_exists(self, mock_exists, 
mock_file): - """Test does not write if failure file already exists.""" - mock_exists.return_value = True - - _write_failure_reason_file("Test error message") - - mock_file.assert_not_called() - - -class TestUnpackUserWorkspace: - """Test _unpack_user_workspace function.""" - - @patch("os.path.exists") - def test_returns_none_if_dir_not_exists(self, mock_exists): - """Test returns None if workspace directory doesn't exist.""" - mock_exists.return_value = False - - result = _unpack_user_workspace() - - assert result is None - - @patch("os.path.isfile") - @patch("os.path.exists") - def test_returns_none_if_archive_not_exists(self, mock_exists, mock_isfile): - """Test returns None if workspace archive doesn't exist.""" - mock_exists.return_value = True - mock_isfile.return_value = False - - result = _unpack_user_workspace() - - assert result is None - - @patch("shutil.unpack_archive") - @patch("os.path.isfile") - @patch("os.path.exists") - @patch("os.getcwd") - def test_unpacks_workspace_successfully(self, mock_getcwd, mock_exists, mock_isfile, mock_unpack): - """Test unpacks workspace successfully.""" - mock_getcwd.return_value = "/tmp/workspace" - mock_exists.return_value = True - mock_isfile.return_value = True - - result = _unpack_user_workspace() - - mock_unpack.assert_called_once() - assert result is not None - - -class TestHandlePreExecScripts: - """Test _handle_pre_exec_scripts function.""" - - @patch("os.path.isfile") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - def test_runs_pre_exec_script(self, mock_manager_class, mock_isfile): - """Test runs pre-execution script.""" - mock_isfile.return_value = True - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - _handle_pre_exec_scripts("/tmp/scripts") - - mock_manager.run_pre_exec_script.assert_called_once() - - -class TestInstallDependencies: - """Test _install_dependencies function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - def test_installs_with_dependency_settings(self, mock_manager_class): - """Test installs dependencies with dependency settings.""" - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - dep_settings = _DependencySettings(dependency_file="requirements.txt") - - _install_dependencies( - "/tmp/deps", - "my-env", - "3.8", - "channel", - dep_settings - ) - - mock_manager.bootstrap.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - def test_skips_if_no_dependency_file(self, mock_manager_class): - """Test skips installation if no dependency file.""" - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - dep_settings = _DependencySettings(dependency_file=None) - - _install_dependencies( - "/tmp/deps", - "my-env", - "3.8", - "channel", - dep_settings - ) - - mock_manager.bootstrap.assert_not_called() - - @patch("os.listdir") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - def test_finds_dependency_file_legacy(self, mock_manager_class, mock_listdir): - """Test finds dependency file in legacy mode.""" - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - mock_listdir.return_value = ["requirements.txt", "script.py"] - - _install_dependencies( - "/tmp/deps", - "my-env", - "3.8", - "channel", - None - ) - 
- mock_manager.bootstrap.assert_called_once() - - -class TestBootstrapRuntimeEnvForRemoteFunction: - """Test _bootstrap_runtime_env_for_remote_function function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") - def test_bootstraps_successfully(self, mock_unpack, mock_handle_scripts, mock_install): - """Test bootstraps runtime environment successfully.""" - mock_unpack.return_value = "/tmp/workspace" - - _bootstrap_runtime_env_for_remote_function("3.8", "my-env", None) - - mock_unpack.assert_called_once() - mock_handle_scripts.assert_called_once() - mock_install.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") - def test_returns_early_if_no_workspace(self, mock_unpack): - """Test returns early if no workspace to unpack.""" - mock_unpack.return_value = None - - _bootstrap_runtime_env_for_remote_function("3.8", "my-env", None) - - mock_unpack.assert_called_once() - - -class TestBootstrapRuntimeEnvForPipelineStep: - """Test _bootstrap_runtime_env_for_pipeline_step function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts") - @patch("shutil.copy") - @patch("os.listdir") - @patch("os.path.exists") - @patch("os.mkdir") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") - def test_bootstraps_with_workspace(self, mock_unpack, mock_mkdir, mock_exists, mock_listdir, mock_copy, mock_handle_scripts, mock_install): - """Test bootstraps pipeline step with workspace.""" - mock_unpack.return_value = "/tmp/workspace" - mock_exists.return_value = True - mock_listdir.return_value = ["requirements.txt"] - - _bootstrap_runtime_env_for_pipeline_step("3.8", "func_step", "my-env", None) - - mock_unpack.assert_called_once() - mock_handle_scripts.assert_called_once() - mock_install.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts") - @patch("os.path.exists") - @patch("os.mkdir") - @patch("os.getcwd") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") - def test_creates_workspace_if_none(self, mock_unpack, mock_getcwd, mock_mkdir, mock_exists, mock_handle_scripts, mock_install): - """Test creates workspace directory if none exists.""" - mock_unpack.return_value = None - mock_getcwd.return_value = "/tmp" - mock_exists.return_value = False - - _bootstrap_runtime_env_for_pipeline_step("3.8", "func_step", "my-env", None) - - mock_mkdir.assert_called_once() - - -class TestMain: - """Test main function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.set_env") - @patch("builtins.open", new_callable=mock_open, read_data='{"current_host": "algo-1", "current_instance_type": "ml.m5.xlarge", "hosts": ["algo-1"], "network_interface_name": "eth0"}') - @patch("os.path.exists") - 
@patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._bootstrap_runtime_env_for_remote_function") - @patch("getpass.getuser") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args") - def test_main_success(self, mock_parse_args, mock_getuser, mock_bootstrap, mock_manager_class, mock_exists, mock_file, mock_set_env): - """Test main function successful execution.""" - mock_getuser.return_value = "root" - mock_exists.return_value = True - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - # Mock parsed args - mock_args = MagicMock() - mock_args.client_python_version = "3.8" - mock_args.client_sagemaker_pysdk_version = None - mock_args.job_conda_env = None - mock_args.pipeline_execution_id = None - mock_args.dependency_settings = None - mock_args.func_step_s3_dir = None - mock_args.distribution = None - mock_args.user_nproc_per_node = None - mock_parse_args.return_value = mock_args - - args = [ - "--client_python_version", "3.8", - ] - - with pytest.raises(SystemExit) as exc_info: - main(args) - - assert exc_info.value.code == SUCCESS_EXIT_CODE - mock_bootstrap.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._write_failure_reason_file") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - @patch("getpass.getuser") - def test_main_handles_exception(self, mock_getuser, mock_manager_class, mock_write_failure): - """Test main function handles exceptions.""" - mock_getuser.return_value = "root" - mock_manager = MagicMock() - mock_manager._validate_python_version.side_effect = Exception("Test error") - mock_manager_class.return_value = mock_manager - - args = [ - "--client_python_version", "3.8", - ] - - with pytest.raises(SystemExit) as exc_info: - main(args) - - assert exc_info.value.code == DEFAULT_FAILURE_CODE - mock_write_failure.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.set_env") - @patch("builtins.open", new_callable=mock_open, read_data='{"current_host": "algo-1", "current_instance_type": "ml.m5.xlarge", "hosts": ["algo-1"], "network_interface_name": "eth0"}') - @patch("os.path.exists") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._bootstrap_runtime_env_for_pipeline_step") - @patch("getpass.getuser") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args") - def test_main_pipeline_execution(self, mock_parse_args, mock_getuser, mock_bootstrap, mock_manager_class, mock_exists, mock_file, mock_set_env): - """Test main function for pipeline execution.""" - mock_getuser.return_value = "root" - mock_exists.return_value = True - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - # Mock parsed args - mock_args = MagicMock() - mock_args.client_python_version = "3.8" - mock_args.client_sagemaker_pysdk_version = None - mock_args.job_conda_env = None - mock_args.pipeline_execution_id = "exec-123" - mock_args.dependency_settings = None - mock_args.func_step_s3_dir = "s3://bucket/func" - mock_args.distribution = None - 
mock_args.user_nproc_per_node = None - mock_parse_args.return_value = mock_args - - args = [ - "--client_python_version", "3.8", - "--pipeline_execution_id", "exec-123", - "--func_step_s3_dir", "s3://bucket/func", - ] - - with pytest.raises(SystemExit) as exc_info: - main(args) - - assert exc_info.value.code == SUCCESS_EXIT_CODE - mock_bootstrap.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - @patch("getpass.getuser") - def test_main_non_root_user(self, mock_getuser, mock_manager_class): - """Test main function with non-root user.""" - mock_getuser.return_value = "ubuntu" - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - args = [ - "--client_python_version", "3.8", - ] - - with pytest.raises(SystemExit): - main(args) - - mock_manager.change_dir_permission.assert_called_once() diff --git a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_mpi_utils_remote.py b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_mpi_utils_remote.py deleted file mode 100644 index b84dda5c1a..0000000000 --- a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_mpi_utils_remote.py +++ /dev/null @@ -1,424 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-"""Tests for mpi_utils_remote module.""" -from __future__ import absolute_import - -import os -import pytest -import subprocess -import time -from unittest.mock import patch, MagicMock, mock_open, call -import paramiko - -from sagemaker.core.remote_function.runtime_environment.mpi_utils_remote import ( - CustomHostKeyPolicy, - _parse_args, - _can_connect, - _write_file_to_host, - _write_failure_reason_file, - _wait_for_master, - _wait_for_status_file, - _wait_for_workers, - bootstrap_master_node, - bootstrap_worker_node, - start_sshd_daemon, - write_status_file_to_workers, - main, - SUCCESS_EXIT_CODE, - DEFAULT_FAILURE_CODE, - FAILURE_REASON_PATH, - FINISHED_STATUS_FILE, - READY_FILE, - DEFAULT_SSH_PORT, -) - - -class TestCustomHostKeyPolicy: - """Test CustomHostKeyPolicy class.""" - - def test_accepts_algo_hostname(self): - """Test accepts hostnames starting with algo-.""" - policy = CustomHostKeyPolicy() - mock_client = MagicMock() - mock_hostname = "algo-1234" - mock_key = MagicMock() - mock_key.get_name.return_value = "ssh-rsa" - - # Should not raise exception - policy.missing_host_key(mock_client, mock_hostname, mock_key) - - mock_client.get_host_keys().add.assert_called_once_with(mock_hostname, "ssh-rsa", mock_key) - - def test_rejects_non_algo_hostname(self): - """Test rejects hostnames not starting with algo-.""" - policy = CustomHostKeyPolicy() - mock_client = MagicMock() - mock_hostname = "unknown-host" - mock_key = MagicMock() - - with pytest.raises(paramiko.SSHException): - policy.missing_host_key(mock_client, mock_hostname, mock_key) - - -class TestParseArgs: - """Test _parse_args function.""" - - def test_parse_default_args(self): - """Test parsing with default arguments.""" - args = [] - parsed = _parse_args(args) - assert parsed.job_ended == "0" - - def test_parse_job_ended_true(self): - """Test parsing with job_ended set to true.""" - args = ["--job_ended", "1"] - parsed = _parse_args(args) - assert parsed.job_ended == "1" - - def test_parse_job_ended_false(self): - """Test parsing with job_ended set to false.""" - args = ["--job_ended", "0"] - parsed = _parse_args(args) - assert parsed.job_ended == "0" - - -class TestCanConnect: - """Test _can_connect function.""" - - @patch("paramiko.SSHClient") - def test_can_connect_success(self, mock_ssh_client_class): - """Test successful connection.""" - mock_client = MagicMock() - mock_ssh_client_class.return_value.__enter__.return_value = mock_client - - result = _can_connect("algo-1", DEFAULT_SSH_PORT) - - assert result is True - mock_client.connect.assert_called_once_with("algo-1", port=DEFAULT_SSH_PORT) - - @patch("paramiko.SSHClient") - def test_can_connect_failure(self, mock_ssh_client_class): - """Test failed connection.""" - mock_client = MagicMock() - mock_client.connect.side_effect = Exception("Connection failed") - mock_ssh_client_class.return_value.__enter__.return_value = mock_client - - result = _can_connect("algo-1", DEFAULT_SSH_PORT) - - assert result is False - - @patch("paramiko.SSHClient") - def test_can_connect_uses_custom_port(self, mock_ssh_client_class): - """Test connection with custom port.""" - mock_client = MagicMock() - mock_ssh_client_class.return_value.__enter__.return_value = mock_client - - _can_connect("algo-1", 2222) - - mock_client.connect.assert_called_once_with("algo-1", port=2222) - - -class TestWriteFileToHost: - """Test _write_file_to_host function.""" - - @patch("subprocess.run") - def test_write_file_success(self, mock_run): - """Test successful file write.""" - mock_run.return_value = 
MagicMock(returncode=0) - - result = _write_file_to_host("algo-1", "/tmp/status") - - assert result is True - mock_run.assert_called_once() - - @patch("subprocess.run") - def test_write_file_failure(self, mock_run): - """Test failed file write.""" - mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") - - result = _write_file_to_host("algo-1", "/tmp/status") - - assert result is False - - -class TestWriteFailureReasonFile: - """Test _write_failure_reason_file function.""" - - @patch("builtins.open", new_callable=mock_open) - @patch("os.path.exists") - def test_writes_failure_file(self, mock_exists, mock_file): - """Test writes failure reason file.""" - mock_exists.return_value = False - - _write_failure_reason_file("Test error message") - - mock_file.assert_called_once_with(FAILURE_REASON_PATH, "w") - mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error message") - - @patch("builtins.open", new_callable=mock_open) - @patch("os.path.exists") - def test_does_not_write_if_exists(self, mock_exists, mock_file): - """Test does not write if failure file already exists.""" - mock_exists.return_value = True - - _write_failure_reason_file("Test error message") - - mock_file.assert_not_called() - - -class TestWaitForMaster: - """Test _wait_for_master function.""" - - @patch("time.sleep") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_master_success(self, mock_can_connect, mock_sleep): - """Test successful wait for master.""" - mock_can_connect.return_value = True - - _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300) - - mock_can_connect.assert_called_once_with("algo-1", DEFAULT_SSH_PORT) - - @patch("time.time") - @patch("time.sleep") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_master_timeout(self, mock_can_connect, mock_sleep, mock_time): - """Test timeout waiting for master.""" - mock_can_connect.return_value = False - # Need enough values for all time.time() calls in the loop - mock_time.side_effect = [0] + [i * 5 for i in range(1, 100)] # Simulate time passing - - with pytest.raises(TimeoutError): - _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300) - - @patch("time.time") - @patch("time.sleep") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_master_retries(self, mock_can_connect, mock_sleep, mock_time): - """Test retries before successful connection.""" - mock_can_connect.side_effect = [False, False, True] - # Return value instead of side_effect for time.time() - mock_time.return_value = 0 - - _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300) - - assert mock_can_connect.call_count == 3 - - -class TestWaitForStatusFile: - """Test _wait_for_status_file function.""" - - @patch("time.sleep") - @patch("os.path.exists") - def test_wait_for_status_file_exists(self, mock_exists, mock_sleep): - """Test wait for status file that exists.""" - mock_exists.return_value = True - - _wait_for_status_file("/tmp/status") - - mock_exists.assert_called_once_with("/tmp/status") - - @patch("time.sleep") - @patch("os.path.exists") - def test_wait_for_status_file_waits(self, mock_exists, mock_sleep): - """Test waits until status file exists.""" - mock_exists.side_effect = [False, False, True] - - _wait_for_status_file("/tmp/status") - - assert mock_exists.call_count == 3 - assert mock_sleep.call_count == 2 - - -class TestWaitForWorkers: - """Test _wait_for_workers 
function.""" - - @patch("os.path.exists") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_workers_empty_list(self, mock_can_connect, mock_exists): - """Test wait for workers with empty list.""" - _wait_for_workers([], DEFAULT_SSH_PORT, timeout=300) - - mock_can_connect.assert_not_called() - - @patch("time.sleep") - @patch("os.path.exists") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_workers_success(self, mock_can_connect, mock_exists, mock_sleep): - """Test successful wait for workers.""" - mock_can_connect.return_value = True - mock_exists.return_value = True - - _wait_for_workers(["algo-2", "algo-3"], DEFAULT_SSH_PORT, timeout=300) - - assert mock_can_connect.call_count == 2 - - @patch("time.time") - @patch("time.sleep") - @patch("os.path.exists") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_workers_timeout(self, mock_can_connect, mock_exists, mock_sleep, mock_time): - """Test timeout waiting for workers.""" - mock_can_connect.return_value = False - mock_exists.return_value = False - # Need enough values for all time.time() calls in the loop - mock_time.side_effect = [0] + [i * 5 for i in range(1, 100)] - - with pytest.raises(TimeoutError): - _wait_for_workers(["algo-2"], DEFAULT_SSH_PORT, timeout=300) - - -class TestBootstrapMasterNode: - """Test bootstrap_master_node function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_workers") - def test_bootstrap_master_node(self, mock_wait): - """Test bootstrap master node.""" - worker_hosts = ["algo-2", "algo-3"] - - bootstrap_master_node(worker_hosts) - - mock_wait.assert_called_once_with(worker_hosts) - - -class TestBootstrapWorkerNode: - """Test bootstrap_worker_node function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_status_file") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_master") - def test_bootstrap_worker_node(self, mock_wait_master, mock_write, mock_wait_status): - """Test bootstrap worker node.""" - bootstrap_worker_node("algo-1", "algo-2", "/tmp/status") - - mock_wait_master.assert_called_once_with("algo-1") - mock_write.assert_called_once() - mock_wait_status.assert_called_once_with("/tmp/status") - - -class TestStartSshdDaemon: - """Test start_sshd_daemon function.""" - - @patch("subprocess.Popen") - @patch("os.path.exists") - def test_starts_sshd_successfully(self, mock_exists, mock_popen): - """Test starts SSH daemon successfully.""" - mock_exists.return_value = True - - start_sshd_daemon() - - mock_popen.assert_called_once_with(["/usr/sbin/sshd", "-D"]) - - @patch("os.path.exists") - def test_raises_error_if_sshd_not_found(self, mock_exists): - """Test raises error if SSH daemon not found.""" - mock_exists.return_value = False - - with pytest.raises(RuntimeError): - start_sshd_daemon() - - -class TestWriteStatusFileToWorkers: - """Test write_status_file_to_workers function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") - def test_writes_to_all_workers(self, mock_write): - """Test writes status file to all workers.""" - mock_write.return_value = True - worker_hosts = ["algo-2", "algo-3"] - - write_status_file_to_workers(worker_hosts, 
"/tmp/status") - - assert mock_write.call_count == 2 - - @patch("time.sleep") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") - def test_retries_on_failure(self, mock_write, mock_sleep): - """Test retries writing status file on failure.""" - mock_write.side_effect = [False, False, True] - worker_hosts = ["algo-2"] - - write_status_file_to_workers(worker_hosts, "/tmp/status") - - assert mock_write.call_count == 3 - assert mock_sleep.call_count == 2 - - @patch("time.sleep") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") - def test_raises_timeout_after_retries(self, mock_write, mock_sleep): - """Test raises timeout after max retries.""" - mock_write.return_value = False - worker_hosts = ["algo-2"] - - with pytest.raises(TimeoutError): - write_status_file_to_workers(worker_hosts, "/tmp/status") - - -class TestMain: - """Test main function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") - @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}) - def test_main_worker_node_running(self, mock_start_sshd, mock_bootstrap_worker): - """Test main function for worker node during job run.""" - args = ["--job_ended", "0"] - - main(args) - - mock_start_sshd.assert_called_once() - mock_bootstrap_worker.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") - @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-1", "SM_HOSTS": '["algo-1", "algo-2"]'}) - def test_main_master_node_running(self, mock_start_sshd, mock_bootstrap_master): - """Test main function for master node during job run.""" - args = ["--job_ended", "0"] - - main(args) - - mock_start_sshd.assert_called_once() - mock_bootstrap_master.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.write_status_file_to_workers") - @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-1", "SM_HOSTS": '["algo-1", "algo-2"]'}) - def test_main_master_node_job_ended(self, mock_write_status): - """Test main function for master node after job ends.""" - args = ["--job_ended", "1"] - - main(args) - - mock_write_status.assert_called_once() - - @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}) - def test_main_worker_node_job_ended(self): - """Test main function for worker node after job ends.""" - args = ["--job_ended", "1"] - - # Should not raise any exceptions - main(args) - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_failure_reason_file") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") - @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}) - def test_main_handles_exception(self, mock_start_sshd, mock_write_failure): - """Test main function handles exceptions.""" - mock_start_sshd.side_effect = Exception("Test error") - args = ["--job_ended", "0"] - - with pytest.raises(SystemExit) as exc_info: - main(args) - - assert exc_info.value.code == DEFAULT_FAILURE_CODE - mock_write_failure.assert_called_once() diff --git 
a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py deleted file mode 100644 index a300daf2b3..0000000000 --- a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py +++ /dev/null @@ -1,572 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tests for runtime_environment_manager module.""" -from __future__ import absolute_import - -import json -import os -import subprocess -import sys -import pytest -from unittest.mock import patch, MagicMock, mock_open, call - -from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( - _DependencySettings, - RuntimeEnvironmentManager, - RuntimeEnvironmentError, - get_logger, - _run_and_get_output_shell_cmd, - _run_pre_execution_command_script, - _run_shell_cmd, - _log_output, - _log_error, - _python_executable, -) - - -class TestDependencySettings: - """Test _DependencySettings class.""" - - def test_init_with_no_file(self): - """Test initialization without dependency file.""" - settings = _DependencySettings() - assert settings.dependency_file is None - - def test_init_with_file(self): - """Test initialization with dependency file.""" - settings = _DependencySettings(dependency_file="requirements.txt") - assert settings.dependency_file == "requirements.txt" - - def test_to_string(self): - """Test converts to JSON string.""" - settings = _DependencySettings(dependency_file="requirements.txt") - result = settings.to_string() - assert result == '{"dependency_file": "requirements.txt"}' - - def test_from_string_with_file(self): - """Test creates from JSON string with file.""" - json_str = '{"dependency_file": "requirements.txt"}' - settings = _DependencySettings.from_string(json_str) - assert settings.dependency_file == "requirements.txt" - - def test_from_string_with_none(self): - """Test creates from None.""" - settings = _DependencySettings.from_string(None) - assert settings is None - - def test_from_dependency_file_path_with_none(self): - """Test creates from None file path.""" - settings = _DependencySettings.from_dependency_file_path(None) - assert settings.dependency_file is None - - def test_from_dependency_file_path_with_auto_capture(self): - """Test creates from auto_capture.""" - settings = _DependencySettings.from_dependency_file_path("auto_capture") - assert settings.dependency_file == "env_snapshot.yml" - - def test_from_dependency_file_path_with_path(self): - """Test creates from file path.""" - settings = _DependencySettings.from_dependency_file_path("/path/to/requirements.txt") - assert settings.dependency_file == "requirements.txt" - - -class TestGetLogger: - """Test get_logger function.""" - - def test_returns_logger(self): - """Test returns logger instance.""" - logger = get_logger() - assert logger is not None - assert logger.name == "sagemaker.remote_function" - - -class 
TestRuntimeEnvironmentManager: - """Test RuntimeEnvironmentManager class.""" - - def test_init(self): - """Test initialization.""" - manager = RuntimeEnvironmentManager() - assert manager is not None - - @patch("os.path.isfile") - def test_snapshot_returns_none_for_none(self, mock_isfile): - """Test snapshot returns None when dependencies is None.""" - manager = RuntimeEnvironmentManager() - result = manager.snapshot(None) - assert result is None - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._capture_from_local_runtime") - def test_snapshot_auto_capture(self, mock_capture): - """Test snapshot with auto_capture.""" - mock_capture.return_value = "/path/to/env_snapshot.yml" - manager = RuntimeEnvironmentManager() - result = manager.snapshot("auto_capture") - assert result == "/path/to/env_snapshot.yml" - mock_capture.assert_called_once() - - @patch("os.path.isfile") - def test_snapshot_with_txt_file(self, mock_isfile): - """Test snapshot with requirements.txt file.""" - mock_isfile.return_value = True - manager = RuntimeEnvironmentManager() - result = manager.snapshot("requirements.txt") - assert result == "requirements.txt" - - @patch("os.path.isfile") - def test_snapshot_with_yml_file(self, mock_isfile): - """Test snapshot with conda.yml file.""" - mock_isfile.return_value = True - manager = RuntimeEnvironmentManager() - result = manager.snapshot("environment.yml") - assert result == "environment.yml" - - @patch("os.path.isfile") - def test_snapshot_raises_error_for_invalid_file(self, mock_isfile): - """Test snapshot raises error for invalid file.""" - mock_isfile.return_value = False - manager = RuntimeEnvironmentManager() - with pytest.raises(ValueError): - manager.snapshot("requirements.txt") - - def test_snapshot_raises_error_for_invalid_format(self): - """Test snapshot raises error for invalid format.""" - manager = RuntimeEnvironmentManager() - with pytest.raises(ValueError): - manager.snapshot("invalid.json") - - @patch("os.getenv") - def test_get_active_conda_env_prefix(self, mock_getenv): - """Test gets active conda environment prefix.""" - mock_getenv.return_value = "/opt/conda/envs/myenv" - manager = RuntimeEnvironmentManager() - result = manager._get_active_conda_env_prefix() - assert result == "/opt/conda/envs/myenv" - - @patch("os.getenv") - def test_get_active_conda_env_name(self, mock_getenv): - """Test gets active conda environment name.""" - mock_getenv.return_value = "myenv" - manager = RuntimeEnvironmentManager() - result = manager._get_active_conda_env_name() - assert result == "myenv" - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._export_conda_env_from_prefix") - @patch("os.getcwd") - @patch("os.getenv") - def test_capture_from_local_runtime(self, mock_getenv, mock_getcwd, mock_export): - """Test captures from local runtime.""" - mock_getenv.side_effect = lambda x: "myenv" if x == "CONDA_DEFAULT_ENV" else "/opt/conda/envs/myenv" - mock_getcwd.return_value = "/tmp" - manager = RuntimeEnvironmentManager() - result = manager._capture_from_local_runtime() - assert result == "/tmp/env_snapshot.yml" - mock_export.assert_called_once() - - @patch("os.getenv") - def test_capture_from_local_runtime_raises_error_no_conda(self, mock_getenv): - """Test raises error when no conda environment active.""" - mock_getenv.return_value = None - manager = RuntimeEnvironmentManager() - with pytest.raises(ValueError): - 
manager._capture_from_local_runtime() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._install_requirements_txt") - def test_bootstrap_with_txt_file_no_conda(self, mock_install): - """Test bootstrap with requirements.txt without conda.""" - manager = RuntimeEnvironmentManager() - manager.bootstrap("requirements.txt", "3.8", None) - mock_install.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._install_req_txt_in_conda_env") - def test_bootstrap_with_txt_file_with_conda(self, mock_install, mock_write): - """Test bootstrap with requirements.txt with conda.""" - manager = RuntimeEnvironmentManager() - manager.bootstrap("requirements.txt", "3.8", "myenv") - mock_install.assert_called_once() - mock_write.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._update_conda_env") - def test_bootstrap_with_yml_file_with_conda(self, mock_update, mock_write): - """Test bootstrap with conda.yml with existing conda env.""" - manager = RuntimeEnvironmentManager() - manager.bootstrap("environment.yml", "3.8", "myenv") - mock_update.assert_called_once() - mock_write.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._validate_python_version") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._create_conda_env") - def test_bootstrap_with_yml_file_without_conda(self, mock_create, mock_validate, mock_write): - """Test bootstrap with conda.yml without existing conda env.""" - manager = RuntimeEnvironmentManager() - manager.bootstrap("environment.yml", "3.8", None) - mock_create.assert_called_once() - mock_validate.assert_called_once() - mock_write.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script") - @patch("os.path.isfile") - def test_run_pre_exec_script_exists(self, mock_isfile, mock_run_script): - """Test runs pre-execution script when it exists.""" - mock_isfile.return_value = True - mock_run_script.return_value = (0, "") - manager = RuntimeEnvironmentManager() - manager.run_pre_exec_script("/path/to/script.sh") - mock_run_script.assert_called_once() - - @patch("os.path.isfile") - def test_run_pre_exec_script_not_exists(self, mock_isfile): - """Test handles pre-execution script not existing.""" - mock_isfile.return_value = False - manager = RuntimeEnvironmentManager() - # Should not raise exception - manager.run_pre_exec_script("/path/to/script.sh") - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script") - @patch("os.path.isfile") - def test_run_pre_exec_script_raises_error_on_failure(self, mock_isfile, mock_run_script): - """Test raises error when pre-execution script fails.""" - mock_isfile.return_value = True - 
mock_run_script.return_value = (1, "Error message") - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager.run_pre_exec_script("/path/to/script.sh") - - @patch("subprocess.run") - def test_change_dir_permission_success(self, mock_run): - """Test changes directory permissions successfully.""" - manager = RuntimeEnvironmentManager() - manager.change_dir_permission(["/tmp/dir1", "/tmp/dir2"], "777") - mock_run.assert_called_once() - - @patch("subprocess.run") - def test_change_dir_permission_raises_error_on_failure(self, mock_run): - """Test raises error when permission change fails.""" - mock_run.side_effect = subprocess.CalledProcessError(1, "chmod", stderr=b"Permission denied") - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager.change_dir_permission(["/tmp/dir1"], "777") - - @patch("subprocess.run") - def test_change_dir_permission_raises_error_no_sudo(self, mock_run): - """Test raises error when sudo not found.""" - mock_run.side_effect = FileNotFoundError("[Errno 2] No such file or directory: 'sudo'") - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager.change_dir_permission(["/tmp/dir1"], "777") - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") - def test_install_requirements_txt(self, mock_run_cmd): - """Test installs requirements.txt.""" - manager = RuntimeEnvironmentManager() - manager._install_requirements_txt("/path/to/requirements.txt", "/usr/bin/python") - mock_run_cmd.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_create_conda_env(self, mock_get_conda, mock_run_cmd): - """Test creates conda environment.""" - mock_get_conda.return_value = "conda" - manager = RuntimeEnvironmentManager() - manager._create_conda_env("myenv", "/path/to/environment.yml") - mock_run_cmd.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_install_req_txt_in_conda_env(self, mock_get_conda, mock_run_cmd): - """Test installs requirements.txt in conda environment.""" - mock_get_conda.return_value = "conda" - manager = RuntimeEnvironmentManager() - manager._install_req_txt_in_conda_env("myenv", "/path/to/requirements.txt") - mock_run_cmd.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_update_conda_env(self, mock_get_conda, mock_run_cmd): - """Test updates conda environment.""" - mock_get_conda.return_value = "conda" - manager = RuntimeEnvironmentManager() - manager._update_conda_env("myenv", "/path/to/environment.yml") - mock_run_cmd.assert_called_once() - - @patch("builtins.open", new_callable=mock_open) - @patch("subprocess.Popen") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_export_conda_env_from_prefix(self, mock_get_conda, mock_popen, mock_file): - """Test exports 
conda environment.""" - mock_get_conda.return_value = "conda" - mock_process = MagicMock() - mock_process.communicate.return_value = (b"env output", b"") - mock_process.wait.return_value = 0 - mock_popen.return_value = mock_process - - manager = RuntimeEnvironmentManager() - manager._export_conda_env_from_prefix("/opt/conda/envs/myenv", "/tmp/env.yml") - - mock_popen.assert_called_once() - mock_file.assert_called_once_with("/tmp/env.yml", "w") - - @patch("builtins.open", new_callable=mock_open) - @patch("os.getcwd") - def test_write_conda_env_to_file(self, mock_getcwd, mock_file): - """Test writes conda environment name to file.""" - mock_getcwd.return_value = "/tmp" - manager = RuntimeEnvironmentManager() - manager._write_conda_env_to_file("myenv") - mock_file.assert_called_once_with("/tmp/remote_function_conda_env.txt", "w") - mock_file().write.assert_called_once_with("myenv") - - @patch("subprocess.Popen") - def test_get_conda_exe_returns_mamba(self, mock_popen): - """Test returns mamba when available.""" - mock_popen.return_value.wait.side_effect = [0, 1] # mamba exists, conda doesn't - manager = RuntimeEnvironmentManager() - result = manager._get_conda_exe() - assert result == "mamba" - - @patch("subprocess.Popen") - def test_get_conda_exe_returns_conda(self, mock_popen): - """Test returns conda when mamba not available.""" - mock_popen.return_value.wait.side_effect = [1, 0] # mamba doesn't exist, conda does - manager = RuntimeEnvironmentManager() - result = manager._get_conda_exe() - assert result == "conda" - - @patch("subprocess.Popen") - def test_get_conda_exe_raises_error(self, mock_popen): - """Test raises error when neither conda nor mamba available.""" - mock_popen.return_value.wait.return_value = 1 - manager = RuntimeEnvironmentManager() - with pytest.raises(ValueError): - manager._get_conda_exe() - - @patch("subprocess.check_output") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_python_version_in_conda_env(self, mock_get_conda, mock_check_output): - """Test gets Python version in conda environment.""" - mock_get_conda.return_value = "conda" - mock_check_output.return_value = b"Python 3.8.10" - manager = RuntimeEnvironmentManager() - result = manager._python_version_in_conda_env("myenv") - assert result == "3.8" - - @patch("subprocess.check_output") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_python_version_in_conda_env_raises_error(self, mock_get_conda, mock_check_output): - """Test raises error when getting Python version fails.""" - mock_get_conda.return_value = "conda" - mock_check_output.side_effect = subprocess.CalledProcessError(1, "conda", output=b"Error") - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager._python_version_in_conda_env("myenv") - - def test_current_python_version(self): - """Test gets current Python version.""" - manager = RuntimeEnvironmentManager() - result = manager._current_python_version() - expected = f"{sys.version_info.major}.{sys.version_info.minor}" - assert result == expected - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._python_version_in_conda_env") - def test_validate_python_version_with_conda(self, mock_python_version): - """Test validates Python version with conda environment.""" - mock_python_version.return_value = "3.8" - 
manager = RuntimeEnvironmentManager() - # Should not raise exception - manager._validate_python_version("3.8", "myenv") - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._python_version_in_conda_env") - def test_validate_python_version_mismatch_with_conda(self, mock_python_version): - """Test raises error on Python version mismatch with conda.""" - mock_python_version.return_value = "3.9" - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager._validate_python_version("3.8", "myenv") - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_python_version") - def test_validate_python_version_without_conda(self, mock_current_version): - """Test validates Python version without conda environment.""" - mock_current_version.return_value = "3.8" - manager = RuntimeEnvironmentManager() - # Should not raise exception - manager._validate_python_version("3.8", None) - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_python_version") - def test_validate_python_version_mismatch_without_conda(self, mock_current_version): - """Test raises error on Python version mismatch without conda.""" - mock_current_version.return_value = "3.9" - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager._validate_python_version("3.8", None) - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version") - def test_validate_sagemaker_pysdk_version_match(self, mock_current_version): - """Test validates matching SageMaker SDK version.""" - mock_current_version.return_value = "2.100.0" - manager = RuntimeEnvironmentManager() - # Should not raise exception or warning - manager._validate_sagemaker_pysdk_version("2.100.0") - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version") - def test_validate_sagemaker_pysdk_version_mismatch(self, mock_current_version): - """Test logs warning on SageMaker SDK version mismatch.""" - mock_current_version.return_value = "2.101.0" - manager = RuntimeEnvironmentManager() - # Should log warning but not raise exception - manager._validate_sagemaker_pysdk_version("2.100.0") - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version") - def test_validate_sagemaker_pysdk_version_none(self, mock_current_version): - """Test handles None client version.""" - mock_current_version.return_value = "2.100.0" - manager = RuntimeEnvironmentManager() - # Should not raise exception - manager._validate_sagemaker_pysdk_version(None) - - -class TestRunAndGetOutputShellCmd: - """Test _run_and_get_output_shell_cmd function.""" - - @patch("subprocess.check_output") - def test_runs_command_successfully(self, mock_check_output): - """Test runs command and returns output.""" - mock_check_output.return_value = b"command output" - result = _run_and_get_output_shell_cmd("echo test") - assert result == "command output" - - -class TestRunPreExecutionCommandScript: - """Test _run_pre_execution_command_script function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error") - 
@patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output") - @patch("subprocess.Popen") - @patch("os.path.dirname") - def test_runs_script_successfully(self, mock_dirname, mock_popen, mock_log_output, mock_log_error): - """Test runs script successfully.""" - mock_dirname.return_value = "/tmp" - mock_process = MagicMock() - mock_process.wait.return_value = 0 - mock_popen.return_value = mock_process - mock_log_error.return_value = "" - - return_code, error_logs = _run_pre_execution_command_script("/tmp/script.sh") - - assert return_code == 0 - assert error_logs == "" - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output") - @patch("subprocess.Popen") - @patch("os.path.dirname") - def test_runs_script_with_error(self, mock_dirname, mock_popen, mock_log_output, mock_log_error): - """Test runs script that returns error.""" - mock_dirname.return_value = "/tmp" - mock_process = MagicMock() - mock_process.wait.return_value = 1 - mock_popen.return_value = mock_process - mock_log_error.return_value = "Error message" - - return_code, error_logs = _run_pre_execution_command_script("/tmp/script.sh") - - assert return_code == 1 - assert error_logs == "Error message" - - -class TestRunShellCmd: - """Test _run_shell_cmd function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output") - @patch("subprocess.Popen") - def test_runs_command_successfully(self, mock_popen, mock_log_output, mock_log_error): - """Test runs command successfully.""" - mock_process = MagicMock() - mock_process.wait.return_value = 0 - mock_popen.return_value = mock_process - mock_log_error.return_value = "" - - _run_shell_cmd(["echo", "test"]) - - mock_popen.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output") - @patch("subprocess.Popen") - def test_runs_command_raises_error_on_failure(self, mock_popen, mock_log_output, mock_log_error): - """Test raises error when command fails.""" - mock_process = MagicMock() - mock_process.wait.return_value = 1 - mock_popen.return_value = mock_process - mock_log_error.return_value = "Error message" - - with pytest.raises(RuntimeEnvironmentError): - _run_shell_cmd(["false"]) - - -class TestLogOutput: - """Test _log_output function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.logger") - def test_logs_output(self, mock_logger): - """Test logs process output.""" - from io import BytesIO - mock_process = MagicMock() - mock_process.stdout = BytesIO(b"line1\nline2\n") - - _log_output(mock_process) - - assert mock_logger.info.call_count == 2 - - -class TestLogError: - """Test _log_error function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.logger") - def test_logs_error(self, mock_logger): - """Test logs process errors.""" - from io import BytesIO - mock_process = MagicMock() - mock_process.stderr = BytesIO(b"ERROR: error message\nwarning message\n") - - error_logs = _log_error(mock_process) - - assert "ERROR: error message" in error_logs - assert "warning message" in error_logs - - -class 
TestPythonExecutable: - """Test _python_executable function.""" - - def test_returns_python_executable(self): - """Test returns Python executable path.""" - result = _python_executable() - assert result == sys.executable - - @patch("sys.executable", None) - def test_raises_error_if_no_executable(self): - """Test raises error if no Python executable.""" - with pytest.raises(RuntimeEnvironmentError): - _python_executable() - - -class TestRuntimeEnvironmentError: - """Test RuntimeEnvironmentError class.""" - - def test_creates_error_with_message(self): - """Test creates error with message.""" - error = RuntimeEnvironmentError("Test error") - assert str(error) == "Test error" - assert error.message == "Test error" diff --git a/sagemaker-core/tests/unit/remote_function/test_checkpoint_location.py b/sagemaker-core/tests/unit/remote_function/test_checkpoint_location.py deleted file mode 100644 index 98a5f8bcc8..0000000000 --- a/sagemaker-core/tests/unit/remote_function/test_checkpoint_location.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tests for checkpoint_location module.""" -from __future__ import absolute_import - -import pytest -from sagemaker.core.remote_function.checkpoint_location import ( - CheckpointLocation, - _validate_s3_uri_for_checkpoint, - _JOB_CHECKPOINT_LOCATION, -) - - -class TestValidateS3Uri: - """Test _validate_s3_uri_for_checkpoint function.""" - - def test_valid_s3_uri(self): - """Test valid s3:// URI.""" - assert _validate_s3_uri_for_checkpoint("s3://my-bucket/path/to/checkpoints") - - def test_valid_https_uri(self): - """Test valid https:// URI.""" - assert _validate_s3_uri_for_checkpoint("https://my-bucket.s3.amazonaws.com/path") - - def test_valid_s3_uri_no_path(self): - """Test valid s3:// URI without path.""" - assert _validate_s3_uri_for_checkpoint("s3://my-bucket") - - def test_invalid_uri_no_protocol(self): - """Test invalid URI without protocol.""" - assert not _validate_s3_uri_for_checkpoint("my-bucket/path") - - def test_invalid_uri_wrong_protocol(self): - """Test invalid URI with wrong protocol.""" - assert not _validate_s3_uri_for_checkpoint("http://my-bucket/path") - - def test_invalid_uri_empty(self): - """Test invalid empty URI.""" - assert not _validate_s3_uri_for_checkpoint("") - - -class TestCheckpointLocation: - """Test CheckpointLocation class.""" - - def test_init_with_valid_s3_uri(self): - """Test initialization with valid s3 URI.""" - s3_uri = "s3://my-bucket/checkpoints" - checkpoint_loc = CheckpointLocation(s3_uri) - assert checkpoint_loc._s3_uri == s3_uri - - def test_init_with_valid_https_uri(self): - """Test initialization with valid https URI.""" - s3_uri = "https://my-bucket.s3.amazonaws.com/checkpoints" - checkpoint_loc = CheckpointLocation(s3_uri) - assert checkpoint_loc._s3_uri == s3_uri - - def test_init_with_invalid_uri_raises_error(self): - """Test initialization with invalid URI raises ValueError.""" - with pytest.raises(ValueError, match="CheckpointLocation should be specified with 
valid s3 URI"): - CheckpointLocation("invalid-uri") - - def test_fspath_returns_local_path(self): - """Test __fspath__ returns the job local path.""" - checkpoint_loc = CheckpointLocation("s3://my-bucket/checkpoints") - assert checkpoint_loc.__fspath__() == _JOB_CHECKPOINT_LOCATION - - def test_can_be_used_as_pathlike(self): - """Test CheckpointLocation can be used as os.PathLike.""" - import os - checkpoint_loc = CheckpointLocation("s3://my-bucket/checkpoints") - path = os.fspath(checkpoint_loc) - assert path == _JOB_CHECKPOINT_LOCATION diff --git a/sagemaker-core/tests/unit/remote_function/test_client.py b/sagemaker-core/tests/unit/remote_function/test_client.py deleted file mode 100644 index 83e1a2db80..0000000000 --- a/sagemaker-core/tests/unit/remote_function/test_client.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. - -import pytest -from unittest.mock import Mock -from collections import deque - -from sagemaker.core.remote_function.client import ( - RemoteExecutor, - _submit_worker, - _polling_worker, - _API_CALL_LIMIT, - _PENDING, - _RUNNING, - _CANCELLED, - _FINISHED, -) - - -class TestConstants: - """Test module constants""" - - def test_api_call_limit_constants(self): - assert _API_CALL_LIMIT["SubmittingIntervalInSecs"] == 1 - assert _API_CALL_LIMIT["MinBatchPollingIntervalInSecs"] == 10 - assert _API_CALL_LIMIT["PollingIntervalInSecs"] == 0.5 - - def test_future_state_constants(self): - assert _PENDING == "PENDING" - assert _RUNNING == "RUNNING" - assert _CANCELLED == "CANCELLED" - assert _FINISHED == "FINISHED" - - -class TestRemoteExecutorValidation: - """Test RemoteExecutor argument validation""" - - def test_validate_submit_args_with_valid_args(self): - def my_function(x, y, z=10): - return x + y + z - - RemoteExecutor._validate_submit_args(my_function, 1, 2, z=3) - - def test_validate_submit_args_with_missing_args(self): - def my_function(x, y): - return x + y - - with pytest.raises(TypeError): - RemoteExecutor._validate_submit_args(my_function, 1) - - def test_validate_submit_args_with_extra_args(self): - def my_function(x): - return x - - with pytest.raises(TypeError): - RemoteExecutor._validate_submit_args(my_function, 1, 2) - - -class TestWorkerFunctions: - """Test worker thread functions""" - - def test_submit_worker_exits_on_none(self): - """Test that submit worker exits when None is in queue""" - executor = Mock() - executor._pending_request_queue = deque([None]) - executor._running_jobs = {} - executor.max_parallel_jobs = 1 - - mock_condition = Mock() - mock_condition.__enter__ = Mock(return_value=mock_condition) - mock_condition.__exit__ = Mock(return_value=False) - mock_condition.wait_for = Mock(return_value=True) - executor._state_condition = mock_condition - - _submit_worker(executor) - - assert len(executor._pending_request_queue) == 0 - - def test_polling_worker_exits_on_shutdown(self): - """Test that polling worker exits when shutdown flag is set""" - executor = Mock() - executor._running_jobs = {} - 
executor._pending_request_queue = deque() - executor._shutdown = True - executor._state_condition = Mock() - - _polling_worker(executor) diff --git a/sagemaker-core/tests/unit/remote_function/test_custom_file_filter.py b/sagemaker-core/tests/unit/remote_function/test_custom_file_filter.py deleted file mode 100644 index 5145a77adf..0000000000 --- a/sagemaker-core/tests/unit/remote_function/test_custom_file_filter.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tests for custom_file_filter module.""" -from __future__ import absolute_import - -import os -import tempfile -import shutil -from unittest.mock import patch, MagicMock -import pytest - -from sagemaker.core.remote_function.custom_file_filter import ( - CustomFileFilter, - resolve_custom_file_filter_from_config_file, - copy_workdir, -) - - -class TestCustomFileFilter: - """Test CustomFileFilter class.""" - - def test_init_with_no_patterns(self): - """Test initialization without ignore patterns.""" - filter_obj = CustomFileFilter() - assert filter_obj.ignore_name_patterns == [] - assert filter_obj.workdir == os.getcwd() - - def test_init_with_patterns(self): - """Test initialization with ignore patterns.""" - patterns = ["*.pyc", "__pycache__", "*.log"] - filter_obj = CustomFileFilter(ignore_name_patterns=patterns) - assert filter_obj.ignore_name_patterns == patterns - - def test_ignore_name_patterns_property(self): - """Test ignore_name_patterns property.""" - patterns = ["*.txt", "temp*"] - filter_obj = CustomFileFilter(ignore_name_patterns=patterns) - assert filter_obj.ignore_name_patterns == patterns - - def test_workdir_property(self): - """Test workdir property.""" - filter_obj = CustomFileFilter() - assert filter_obj.workdir == os.getcwd() - - -class TestResolveCustomFileFilterFromConfigFile: - """Test resolve_custom_file_filter_from_config_file function.""" - - def test_returns_direct_input_when_provided_as_filter(self): - """Test returns direct input when CustomFileFilter is provided.""" - filter_obj = CustomFileFilter(ignore_name_patterns=["*.pyc"]) - result = resolve_custom_file_filter_from_config_file(direct_input=filter_obj) - assert result is filter_obj - - def test_returns_direct_input_when_provided_as_callable(self): - """Test returns direct input when callable is provided.""" - def custom_filter(path, names): - return [] - result = resolve_custom_file_filter_from_config_file(direct_input=custom_filter) - assert result is custom_filter - - @patch("sagemaker.core.remote_function.custom_file_filter.resolve_value_from_config") - def test_returns_none_when_no_config(self, mock_resolve): - """Test returns None when no config is found.""" - mock_resolve.return_value = None - result = resolve_custom_file_filter_from_config_file() - assert result is None - - @patch("sagemaker.core.remote_function.custom_file_filter.resolve_value_from_config") - def test_creates_filter_from_config(self, mock_resolve): - """Test creates CustomFileFilter from config.""" - patterns = ["*.pyc", 
"*.log"] - mock_resolve.return_value = patterns - result = resolve_custom_file_filter_from_config_file() - assert isinstance(result, CustomFileFilter) - assert result.ignore_name_patterns == patterns - - @patch("sagemaker.core.remote_function.custom_file_filter.resolve_value_from_config") - def test_passes_sagemaker_session_to_resolve(self, mock_resolve): - """Test passes sagemaker_session to resolve_value_from_config.""" - mock_session = MagicMock() - mock_resolve.return_value = None - resolve_custom_file_filter_from_config_file(sagemaker_session=mock_session) - mock_resolve.assert_called_once() - assert mock_resolve.call_args[1]["sagemaker_session"] == mock_session - - -class TestCopyWorkdir: - """Test copy_workdir function.""" - - def setup_method(self): - """Set up test fixtures.""" - self.temp_src = tempfile.mkdtemp() - self.temp_dst = tempfile.mkdtemp() - - # Create test files - with open(os.path.join(self.temp_src, "test.py"), "w") as f: - f.write("print('test')") - with open(os.path.join(self.temp_src, "test.txt"), "w") as f: - f.write("text file") - os.makedirs(os.path.join(self.temp_src, "__pycache__")) - with open(os.path.join(self.temp_src, "__pycache__", "test.pyc"), "w") as f: - f.write("compiled") - - def teardown_method(self): - """Clean up test fixtures.""" - if os.path.exists(self.temp_src): - shutil.rmtree(self.temp_src) - if os.path.exists(self.temp_dst): - shutil.rmtree(self.temp_dst) - - @patch("os.getcwd") - def test_copy_workdir_without_filter_only_python_files(self, mock_getcwd): - """Test copy_workdir without filter copies only Python files.""" - mock_getcwd.return_value = self.temp_src - dst = os.path.join(self.temp_dst, "output") - - copy_workdir(dst) - - assert os.path.exists(os.path.join(dst, "test.py")) - assert not os.path.exists(os.path.join(dst, "test.txt")) - assert not os.path.exists(os.path.join(dst, "__pycache__")) - - @patch("os.getcwd") - def test_copy_workdir_with_callable_filter(self, mock_getcwd): - """Test copy_workdir with callable filter.""" - mock_getcwd.return_value = self.temp_src - dst = os.path.join(self.temp_dst, "output") - - def custom_filter(path, names): - return ["test.txt"] - - copy_workdir(dst, custom_file_filter=custom_filter) - - assert os.path.exists(os.path.join(dst, "test.py")) - assert not os.path.exists(os.path.join(dst, "test.txt")) - - def test_copy_workdir_with_custom_file_filter_object(self): - """Test copy_workdir with CustomFileFilter object.""" - filter_obj = CustomFileFilter(ignore_name_patterns=["*.py"]) - filter_obj._workdir = self.temp_src - dst = os.path.join(self.temp_dst, "output") - - copy_workdir(dst, custom_file_filter=filter_obj) - - assert not os.path.exists(os.path.join(dst, "test.py")) - assert os.path.exists(os.path.join(dst, "test.txt")) - - def test_copy_workdir_with_pattern_matching(self): - """Test copy_workdir with pattern matching in CustomFileFilter.""" - filter_obj = CustomFileFilter(ignore_name_patterns=["*.txt", "__pycache__"]) - filter_obj._workdir = self.temp_src - dst = os.path.join(self.temp_dst, "output") - - copy_workdir(dst, custom_file_filter=filter_obj) - - assert os.path.exists(os.path.join(dst, "test.py")) - assert not os.path.exists(os.path.join(dst, "test.txt")) - assert not os.path.exists(os.path.join(dst, "__pycache__")) diff --git a/sagemaker-core/tests/unit/remote_function/test_invoke_function.py b/sagemaker-core/tests/unit/remote_function/test_invoke_function.py deleted file mode 100644 index 4810eba2e0..0000000000 --- 
a/sagemaker-core/tests/unit/remote_function/test_invoke_function.py +++ /dev/null @@ -1,280 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tests for invoke_function module.""" -from __future__ import absolute_import - -import json -import pytest -from unittest.mock import patch, MagicMock, call - -from sagemaker.core.remote_function.invoke_function import ( - _parse_args, - _get_sagemaker_session, - _load_run_object, - _load_pipeline_context, - _execute_remote_function, - main, - SUCCESS_EXIT_CODE, -) -from sagemaker.core.remote_function.job import KEY_EXPERIMENT_NAME, KEY_RUN_NAME - - -class TestParseArgs: - """Test _parse_args function.""" - - def test_parse_required_args(self): - """Test parsing required arguments.""" - args = [ - "--region", "us-west-2", - "--s3_base_uri", "s3://my-bucket/path", - ] - parsed = _parse_args(args) - assert parsed.region == "us-west-2" - assert parsed.s3_base_uri == "s3://my-bucket/path" - - def test_parse_all_args(self): - """Test parsing all arguments.""" - args = [ - "--region", "us-east-1", - "--s3_base_uri", "s3://bucket/path", - "--s3_kms_key", "key-123", - "--run_in_context", '{"experiment": "exp1"}', - "--pipeline_step_name", "step1", - "--pipeline_execution_id", "exec-123", - "--property_references", "prop1", "val1", "prop2", "val2", - "--serialize_output_to_json", "true", - "--func_step_s3_dir", "s3://bucket/func", - ] - parsed = _parse_args(args) - assert parsed.region == "us-east-1" - assert parsed.s3_base_uri == "s3://bucket/path" - assert parsed.s3_kms_key == "key-123" - assert parsed.run_in_context == '{"experiment": "exp1"}' - assert parsed.pipeline_step_name == "step1" - assert parsed.pipeline_execution_id == "exec-123" - assert parsed.property_references == ["prop1", "val1", "prop2", "val2"] - assert parsed.serialize_output_to_json is True - assert parsed.func_step_s3_dir == "s3://bucket/func" - - def test_parse_serialize_output_false(self): - """Test parsing serialize_output_to_json as false.""" - args = [ - "--region", "us-west-2", - "--s3_base_uri", "s3://bucket/path", - "--serialize_output_to_json", "false", - ] - parsed = _parse_args(args) - assert parsed.serialize_output_to_json is False - - def test_parse_default_values(self): - """Test default values for optional arguments.""" - args = [ - "--region", "us-west-2", - "--s3_base_uri", "s3://bucket/path", - ] - parsed = _parse_args(args) - assert parsed.s3_kms_key is None - assert parsed.run_in_context is None - assert parsed.pipeline_step_name is None - assert parsed.pipeline_execution_id is None - assert parsed.property_references == [] - assert parsed.serialize_output_to_json is False - assert parsed.func_step_s3_dir is None - - -class TestGetSagemakerSession: - """Test _get_sagemaker_session function.""" - - @patch("sagemaker.core.remote_function.invoke_function.boto3.session.Session") - @patch("sagemaker.core.remote_function.invoke_function.Session") - def test_creates_session_with_region(self, mock_session_class, mock_boto_session): - """Test creates 
SageMaker session with correct region.""" - mock_boto = MagicMock() - mock_boto_session.return_value = mock_boto - - _get_sagemaker_session("us-west-2") - - mock_boto_session.assert_called_once_with(region_name="us-west-2") - mock_session_class.assert_called_once_with(boto_session=mock_boto) - - -class TestLoadRunObject: - """Test _load_run_object function.""" - - @patch("sagemaker.core.experiments.run.Run") - def test_loads_run_from_json(self, mock_run_class): - """Test loads Run object from JSON string.""" - run_dict = { - KEY_EXPERIMENT_NAME: "my-experiment", - KEY_RUN_NAME: "my-run", - } - run_json = json.dumps(run_dict) - mock_session = MagicMock() - - _load_run_object(run_json, mock_session) - - mock_run_class.assert_called_once_with( - experiment_name="my-experiment", - run_name="my-run", - sagemaker_session=mock_session, - ) - - -class TestLoadPipelineContext: - """Test _load_pipeline_context function.""" - - def test_loads_context_with_all_fields(self): - """Test loads pipeline context with all fields.""" - args = MagicMock() - args.pipeline_step_name = "step1" - args.pipeline_execution_id = "exec-123" - args.property_references = ["prop1", "val1", "prop2", "val2"] - args.serialize_output_to_json = True - args.func_step_s3_dir = "s3://bucket/func" - - context = _load_pipeline_context(args) - - assert context.step_name == "step1" - assert context.execution_id == "exec-123" - assert context.property_references == {"prop1": "val1", "prop2": "val2"} - assert context.serialize_output_to_json is True - assert context.func_step_s3_dir == "s3://bucket/func" - - def test_loads_context_with_empty_property_references(self): - """Test loads pipeline context with empty property references.""" - args = MagicMock() - args.pipeline_step_name = "step1" - args.pipeline_execution_id = "exec-123" - args.property_references = [] - args.serialize_output_to_json = False - args.func_step_s3_dir = None - - context = _load_pipeline_context(args) - - assert context.property_references == {} - - -class TestExecuteRemoteFunction: - """Test _execute_remote_function function.""" - - @patch("sagemaker.core.remote_function.core.stored_function.StoredFunction") - def test_executes_without_run_context(self, mock_stored_function_class): - """Test executes stored function without run context.""" - mock_stored_func = MagicMock() - mock_stored_function_class.return_value = mock_stored_func - mock_session = MagicMock() - mock_context = MagicMock() - - _execute_remote_function( - sagemaker_session=mock_session, - s3_base_uri="s3://bucket/path", - s3_kms_key="key-123", - run_in_context=None, - context=mock_context, - ) - - mock_stored_function_class.assert_called_once_with( - sagemaker_session=mock_session, - s3_base_uri="s3://bucket/path", - s3_kms_key="key-123", - context=mock_context, - ) - mock_stored_func.load_and_invoke.assert_called_once() - - @patch("sagemaker.core.remote_function.invoke_function._load_run_object") - @patch("sagemaker.core.remote_function.core.stored_function.StoredFunction") - def test_executes_with_run_context(self, mock_stored_function_class, mock_load_run): - """Test executes stored function with run context.""" - mock_stored_func = MagicMock() - mock_stored_function_class.return_value = mock_stored_func - mock_run = MagicMock() - mock_load_run.return_value = mock_run - mock_session = MagicMock() - mock_context = MagicMock() - run_json = '{"experiment": "exp1"}' - - _execute_remote_function( - sagemaker_session=mock_session, - s3_base_uri="s3://bucket/path", - s3_kms_key=None, - 
run_in_context=run_json, - context=mock_context, - ) - - # Verify run object was loaded and used as context manager - mock_load_run.assert_called_once_with(run_json, mock_session) - mock_run.__enter__.assert_called_once() - mock_run.__exit__.assert_called_once() - - -class TestMain: - """Test main function.""" - - @patch("sagemaker.core.remote_function.invoke_function._execute_remote_function") - @patch("sagemaker.core.remote_function.invoke_function._get_sagemaker_session") - @patch("sagemaker.core.remote_function.invoke_function._load_pipeline_context") - @patch("sagemaker.core.remote_function.invoke_function._parse_args") - def test_main_success(self, mock_parse, mock_load_context, mock_get_session, mock_execute): - """Test main function successful execution.""" - mock_args = MagicMock() - mock_args.region = "us-west-2" - mock_args.s3_base_uri = "s3://bucket/path" - mock_args.s3_kms_key = None - mock_args.run_in_context = None - mock_parse.return_value = mock_args - - mock_context = MagicMock() - mock_context.step_name = None - mock_load_context.return_value = mock_context - - mock_session = MagicMock() - mock_get_session.return_value = mock_session - - with pytest.raises(SystemExit) as exc_info: - main(["--region", "us-west-2", "--s3_base_uri", "s3://bucket/path"]) - - assert exc_info.value.code == SUCCESS_EXIT_CODE - mock_execute.assert_called_once() - - @patch("sagemaker.core.remote_function.invoke_function.handle_error") - @patch("sagemaker.core.remote_function.invoke_function._execute_remote_function") - @patch("sagemaker.core.remote_function.invoke_function._get_sagemaker_session") - @patch("sagemaker.core.remote_function.invoke_function._load_pipeline_context") - @patch("sagemaker.core.remote_function.invoke_function._parse_args") - def test_main_handles_exception( - self, mock_parse, mock_load_context, mock_get_session, mock_execute, mock_handle_error - ): - """Test main function handles exceptions.""" - mock_args = MagicMock() - mock_args.region = "us-west-2" - mock_args.s3_base_uri = "s3://bucket/path" - mock_args.s3_kms_key = None - mock_args.run_in_context = None - mock_parse.return_value = mock_args - - mock_context = MagicMock() - mock_context.step_name = None - mock_load_context.return_value = mock_context - - mock_session = MagicMock() - mock_get_session.return_value = mock_session - - test_exception = Exception("Test error") - mock_execute.side_effect = test_exception - mock_handle_error.return_value = 1 - - with pytest.raises(SystemExit) as exc_info: - main(["--region", "us-west-2", "--s3_base_uri", "s3://bucket/path"]) - - assert exc_info.value.code == 1 - mock_handle_error.assert_called_once() diff --git a/sagemaker-core/tests/unit/remote_function/test_job.py b/sagemaker-core/tests/unit/remote_function/test_job.py deleted file mode 100644 index 6f10016643..0000000000 --- a/sagemaker-core/tests/unit/remote_function/test_job.py +++ /dev/null @@ -1,932 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-"""Unit tests for sagemaker.core.remote_function.job module.""" -from __future__ import absolute_import - -import json -import os -import pytest -import sys -from unittest.mock import Mock, patch, MagicMock, call, mock_open -from io import BytesIO - -from sagemaker.core.remote_function.job import ( - _JobSettings, - _Job, - _prepare_and_upload_runtime_scripts, - _generate_input_data_config, - _prepare_dependencies_and_pre_execution_scripts, - _prepare_and_upload_workspace, - _convert_run_to_json, - _prepare_and_upload_spark_dependent_files, - _upload_spark_submit_deps, - _upload_serialized_spark_configuration, - _extend_mpirun_to_request, - _extend_torchrun_to_request, - _extend_spark_config_to_request, - _update_job_request_with_checkpoint_config, - _RunInfo, - _get_initial_job_state, - _logs_for_job, - _check_job_status, - _flush_log_streams, - _rule_statuses_changed, - _logs_init, - LogState, -) -from sagemaker.core.remote_function.spark_config import SparkConfig -from sagemaker.core.remote_function.checkpoint_location import CheckpointLocation - - -@pytest.fixture -def mock_session(): - session = Mock() - session.boto_region_name = "us-west-2" - session.default_bucket.return_value = "test-bucket" - session.default_bucket_prefix = "prefix" - session.sagemaker_client = Mock() - session.boto_session = Mock() - session.sagemaker_config = None - return session - - -class TestJobSettings: - """Test _JobSettings class.""" - - def test_init_with_spark_and_image_raises_error(self, mock_session): - """Test that spark_config and image_uri cannot be set together.""" - spark_config = SparkConfig() - with pytest.raises(ValueError, match="spark_config and image_uri cannot be specified"): - _JobSettings( - sagemaker_session=mock_session, - spark_config=spark_config, - image_uri="test-image", - instance_type="ml.m5.xlarge", - ) - - def test_init_with_spark_and_conda_env_raises_error(self, mock_session): - """Test that spark_config and job_conda_env cannot be set together.""" - spark_config = SparkConfig() - with pytest.raises(ValueError, match="Remote Spark jobs do not support job_conda_env"): - _JobSettings( - sagemaker_session=mock_session, - spark_config=spark_config, - job_conda_env="test-env", - instance_type="ml.m5.xlarge", - ) - - def test_init_with_spark_and_auto_capture_raises_error(self, mock_session): - """Test that spark_config and auto_capture dependencies cannot be set together.""" - spark_config = SparkConfig() - with pytest.raises(ValueError, match="Remote Spark jobs do not support automatically"): - _JobSettings( - sagemaker_session=mock_session, - spark_config=spark_config, - dependencies="auto_capture", - instance_type="ml.m5.xlarge", - ) - - def test_init_with_pre_execution_commands_and_script_raises_error(self, mock_session): - """Test that pre_execution_commands and pre_execution_script cannot be set together.""" - with pytest.raises( - ValueError, match="Only one of pre_execution_commands or pre_execution_script" - ): - _JobSettings( - sagemaker_session=mock_session, - pre_execution_commands=["echo test"], - pre_execution_script="/path/to/script.sh", - instance_type="ml.m5.xlarge", - image_uri="test-image", - ) - - def test_init_without_instance_type_raises_error(self, mock_session): - """Test that instance_type is required.""" - with pytest.raises(ValueError, match="instance_type is a required parameter"): - _JobSettings(sagemaker_session=mock_session, image_uri="test-image") - - @patch.dict(os.environ, {"SAGEMAKER_INTERNAL_IMAGE_URI": "custom-image"}) - def 
test_get_default_image_from_env(self, mock_session): - """Test getting default image from environment variable.""" - image = _JobSettings._get_default_image(mock_session) - assert image == "custom-image" - - def test_get_default_image_unsupported_python_raises_error(self, mock_session): - """Test that unsupported Python version raises error.""" - with patch.object(sys, "version_info", (3, 7, 0)): - with pytest.raises( - ValueError, match="Default image is supported only for Python versions" - ): - _JobSettings._get_default_image(mock_session) - - def test_get_default_spark_image_unsupported_python_raises_error(self, mock_session): - """Test that unsupported Python version for Spark raises error.""" - with patch.object(sys, "version_info", (3, 8, 0)): - with pytest.raises( - ValueError, - match="SageMaker Spark image for remote job only supports Python version 3.9", - ): - _JobSettings._get_default_spark_image(mock_session) - - -class TestJob: - """Test _Job class.""" - - def test_init(self, mock_session): - """Test _Job initialization.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - assert job.job_name == "test-job" - assert job.s3_uri == "s3://bucket/output" - - def test_from_describe_response(self, mock_session): - """Test creating _Job from describe response.""" - response = { - "TrainingJobName": "test-job", - "OutputDataConfig": {"S3OutputPath": "s3://bucket/output"}, - } - job = _Job.from_describe_response(response, mock_session) - assert job.job_name == "test-job" - assert job.s3_uri == "s3://bucket/output" - - def test_describe_returns_cached_response(self, mock_session): - """Test that describe returns cached response for completed jobs.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - job._last_describe_response = {"TrainingJobStatus": "Completed"} - - result = job.describe() - assert result["TrainingJobStatus"] == "Completed" - mock_session.sagemaker_client.describe_training_job.assert_not_called() - - def test_describe_calls_api_for_in_progress_jobs(self, mock_session): - """Test that describe calls API for in-progress jobs.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - mock_session.sagemaker_client.describe_training_job.return_value = { - "TrainingJobStatus": "InProgress" - } - - result = job.describe() - assert result["TrainingJobStatus"] == "InProgress" - mock_session.sagemaker_client.describe_training_job.assert_called_once() - - def test_stop(self, mock_session): - """Test stopping a job.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - job.stop() - mock_session.sagemaker_client.stop_training_job.assert_called_once_with( - TrainingJobName="test-job" - ) - - @patch("sagemaker.core.remote_function.job._logs_for_job") - def test_wait(self, mock_logs, mock_session): - """Test waiting for job completion.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - mock_logs.return_value = {"TrainingJobStatus": "Completed"} - - job.wait(timeout=100) - mock_logs.assert_called_once_with( - sagemaker_session=mock_session, job_name="test-job", wait=True, timeout=100 - ) - - -class TestUpdateJobRequestWithCheckpointConfig: - """Test _update_job_request_with_checkpoint_config function.""" - - def test_with_checkpoint_in_args(self): - """Test checkpoint config in positional args.""" - checkpoint = CheckpointLocation(s3_uri="s3://bucket/checkpoint") - args = (checkpoint,) - kwargs = {} - request_dict = {} - - _update_job_request_with_checkpoint_config(args, kwargs, request_dict) - - assert "CheckpointConfig" in 
request_dict - assert request_dict["CheckpointConfig"]["S3Uri"] == "s3://bucket/checkpoint" - assert request_dict["CheckpointConfig"]["LocalPath"] == "/opt/ml/checkpoints/" - - def test_with_checkpoint_in_kwargs(self): - """Test checkpoint config in keyword args.""" - checkpoint = CheckpointLocation(s3_uri="s3://bucket/checkpoint") - args = () - kwargs = {"checkpoint": checkpoint} - request_dict = {} - - _update_job_request_with_checkpoint_config(args, kwargs, request_dict) - - assert "CheckpointConfig" in request_dict - - def test_with_multiple_checkpoints_raises_error(self): - """Test that multiple checkpoints raise error.""" - checkpoint1 = CheckpointLocation(s3_uri="s3://bucket/checkpoint1") - checkpoint2 = CheckpointLocation(s3_uri="s3://bucket/checkpoint2") - args = (checkpoint1,) - kwargs = {"checkpoint": checkpoint2} - request_dict = {} - - with pytest.raises( - ValueError, match="cannot have more than one argument of type CheckpointLocation" - ): - _update_job_request_with_checkpoint_config(args, kwargs, request_dict) - - def test_without_checkpoint(self): - """Test without checkpoint location.""" - args = ("arg1", "arg2") - kwargs = {"key": "value"} - request_dict = {} - - _update_job_request_with_checkpoint_config(args, kwargs, request_dict) - - assert "CheckpointConfig" not in request_dict - - -class TestConvertRunToJson: - """Test _convert_run_to_json function.""" - - def test_convert_run_to_json(self): - """Test converting run to JSON.""" - mock_run = Mock() - mock_run.experiment_name = "test-experiment" - mock_run.run_name = "test-run" - - result = _convert_run_to_json(mock_run) - data = json.loads(result) - - assert data["experiment_name"] == "test-experiment" - assert data["run_name"] == "test-run" - - -class TestUploadSerializedSparkConfiguration: - """Test _upload_serialized_spark_configuration function.""" - - @patch("sagemaker.core.remote_function.job.S3Uploader") - def test_upload_spark_config(self, mock_uploader, mock_session): - """Test uploading Spark configuration.""" - config = {"spark.executor.memory": "4g"} - - _upload_serialized_spark_configuration("s3://bucket/base", "kms-key", config, mock_session) - - mock_uploader.upload_string_as_file_body.assert_called_once() - - def test_upload_spark_config_none(self, mock_session): - """Test uploading None Spark configuration.""" - result = _upload_serialized_spark_configuration( - "s3://bucket/base", "kms-key", None, mock_session - ) - - assert result is None - - -class TestUploadSparkSubmitDeps: - """Test _upload_spark_submit_deps function.""" - - def test_with_none_deps(self, mock_session): - """Test with None dependencies.""" - result = _upload_spark_submit_deps( - None, "workspace", "s3://bucket", "kms-key", mock_session - ) - assert result is None - - def test_with_s3_uri(self, mock_session): - """Test with S3 URI.""" - deps = ["s3://bucket/dep.jar"] - result = _upload_spark_submit_deps( - deps, "workspace", "s3://bucket", "kms-key", mock_session - ) - assert "s3://bucket/dep.jar" in result - - def test_with_empty_workspace_raises_error(self, mock_session): - """Test with empty workspace name.""" - deps = ["s3://bucket/dep.jar"] - with pytest.raises(ValueError, match="workspace_name or s3_base_uri may not be empty"): - _upload_spark_submit_deps(deps, "", "s3://bucket", "kms-key", mock_session) - - @patch("os.path.isfile", return_value=False) - def test_with_invalid_local_file_raises_error(self, mock_isfile, mock_session): - """Test with invalid local file.""" - deps = ["/invalid/path.jar"] - with 
pytest.raises(ValueError, match="is not a valid local file"): - _upload_spark_submit_deps(deps, "workspace", "s3://bucket", "kms-key", mock_session) - - -class TestExtendMpirunToRequest: - """Test _extend_mpirun_to_request function.""" - - def test_without_mpirun(self, mock_session): - """Test without mpirun enabled.""" - job_settings = Mock() - job_settings.use_mpirun = False - request_dict = {"InputDataConfig": []} - - result = _extend_mpirun_to_request(request_dict, job_settings) - assert result == request_dict - - def test_with_single_instance(self, mock_session): - """Test with single instance.""" - job_settings = Mock() - job_settings.use_mpirun = True - job_settings.instance_count = 1 - request_dict = {"InputDataConfig": []} - - result = _extend_mpirun_to_request(request_dict, job_settings) - assert result == request_dict - - def test_with_multiple_instances(self, mock_session): - """Test with multiple instances.""" - job_settings = Mock() - job_settings.use_mpirun = True - job_settings.instance_count = 2 - request_dict = { - "InputDataConfig": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/data"}}}] - } - - result = _extend_mpirun_to_request(request_dict, job_settings) - assert ( - result["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3DataDistributionType"] - == "FullyReplicated" - ) - - -class TestExtendTorchrunToRequest: - """Test _extend_torchrun_to_request function.""" - - def test_without_torchrun(self, mock_session): - """Test without torchrun enabled.""" - job_settings = Mock() - job_settings.use_torchrun = False - request_dict = {"InputDataConfig": []} - - result = _extend_torchrun_to_request(request_dict, job_settings) - assert result == request_dict - - def test_with_single_instance(self, mock_session): - """Test with single instance.""" - job_settings = Mock() - job_settings.use_torchrun = True - job_settings.instance_count = 1 - request_dict = {"InputDataConfig": []} - - result = _extend_torchrun_to_request(request_dict, job_settings) - assert result == request_dict - - def test_with_multiple_instances(self, mock_session): - """Test with multiple instances.""" - job_settings = Mock() - job_settings.use_torchrun = True - job_settings.instance_count = 2 - request_dict = { - "InputDataConfig": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/data"}}}] - } - - result = _extend_torchrun_to_request(request_dict, job_settings) - assert ( - result["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3DataDistributionType"] - == "FullyReplicated" - ) - - -class TestExtendSparkConfigToRequest: - """Test _extend_spark_config_to_request function.""" - - def test_without_spark_config(self, mock_session): - """Test without spark config.""" - job_settings = Mock() - job_settings.spark_config = None - request_dict = {"AlgorithmSpecification": {"ContainerEntrypoint": []}} - - result = _extend_spark_config_to_request(request_dict, job_settings, "s3://bucket") - assert result == request_dict - - @patch("sagemaker.core.remote_function.job._prepare_and_upload_spark_dependent_files") - def test_with_spark_config(self, mock_upload, mock_session): - """Test with spark config.""" - mock_upload.return_value = (None, None, None, "s3://bucket/config.json") - - job_settings = Mock() - spark_config = SparkConfig(spark_event_logs_uri="s3://bucket/logs") - job_settings.spark_config = spark_config - job_settings.s3_kms_key = None - job_settings.sagemaker_session = mock_session - - request_dict = { - "AlgorithmSpecification": {"ContainerEntrypoint": []}, - "InputDataConfig": 
[{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/data"}}}], - } - - result = _extend_spark_config_to_request(request_dict, job_settings, "s3://bucket") - assert ( - "--spark-event-logs-s3-uri" in result["AlgorithmSpecification"]["ContainerEntrypoint"] - ) - - -class TestGetInitialJobState: - """Test _get_initial_job_state function.""" - - def test_with_completed_job_and_wait(self): - """Test with completed job and wait=True.""" - description = {"TrainingJobStatus": "Completed"} - state = _get_initial_job_state(description, "TrainingJobStatus", True) - assert state == LogState.COMPLETE - - def test_with_in_progress_job_and_wait(self): - """Test with in-progress job and wait=True.""" - description = {"TrainingJobStatus": "InProgress"} - state = _get_initial_job_state(description, "TrainingJobStatus", True) - assert state == LogState.TAILING - - def test_with_in_progress_job_and_no_wait(self): - """Test with in-progress job and wait=False.""" - description = {"TrainingJobStatus": "InProgress"} - state = _get_initial_job_state(description, "TrainingJobStatus", False) - assert state == LogState.COMPLETE - - -class TestCheckJobStatus: - """Test _check_job_status function.""" - - def test_with_completed_status(self): - """Test with completed status.""" - desc = {"TrainingJobStatus": "Completed"} - _check_job_status("test-job", desc, "TrainingJobStatus") - - def test_with_stopped_status(self): - """Test with stopped status.""" - desc = {"TrainingJobStatus": "Stopped"} - with patch("sagemaker.core.remote_function.job.logger") as mock_logger: - _check_job_status("test-job", desc, "TrainingJobStatus") - mock_logger.warning.assert_called_once() - - def test_with_failed_status_raises_error(self): - """Test with failed status.""" - desc = {"TrainingJobStatus": "Failed", "FailureReason": "Test failure"} - with pytest.raises(Exception): - _check_job_status("test-job", desc, "TrainingJobStatus") - - def test_with_capacity_error_raises_capacity_error(self): - """Test with CapacityError.""" - desc = { - "TrainingJobStatus": "Failed", - "FailureReason": "CapacityError: Insufficient capacity", - } - from sagemaker.core import exceptions - - with pytest.raises(exceptions.CapacityError): - _check_job_status("test-job", desc, "TrainingJobStatus") - - -class TestRuleStatusesChanged: - """Test _rule_statuses_changed function.""" - - def test_with_no_last_statuses(self): - """Test with no last statuses.""" - current = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "InProgress"}] - result = _rule_statuses_changed(current, None) - assert result is True - - def test_with_changed_status(self): - """Test with changed status.""" - current = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "Completed"}] - last = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "InProgress"}] - result = _rule_statuses_changed(current, last) - assert result is True - - def test_with_unchanged_status(self): - """Test with unchanged status.""" - current = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "InProgress"}] - last = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "InProgress"}] - result = _rule_statuses_changed(current, last) - assert result is False - - -class TestLogsInit: - """Test _logs_init function.""" - - def test_with_training_job(self, mock_session): - """Test with training job.""" - description = {"ResourceConfig": {"InstanceCount": 2}} - result = _logs_init(mock_session.boto_session, description, "Training") - instance_count, stream_names, positions, client, 
log_group, dot, color_wrap = result - assert instance_count == 2 - assert log_group == "/aws/sagemaker/TrainingJobs" - - def test_with_training_job_instance_groups(self, mock_session): - """Test with training job using instance groups.""" - description = { - "ResourceConfig": {"InstanceGroups": [{"InstanceCount": 2}, {"InstanceCount": 3}]} - } - result = _logs_init(mock_session.boto_session, description, "Training") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 5 - - def test_with_transform_job(self, mock_session): - """Test with transform job.""" - description = {"TransformResources": {"InstanceCount": 1}} - result = _logs_init(mock_session.boto_session, description, "Transform") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 1 - assert log_group == "/aws/sagemaker/TransformJobs" - - def test_with_processing_job(self, mock_session): - """Test with processing job.""" - description = {"ProcessingResources": {"ClusterConfig": {"InstanceCount": 3}}} - result = _logs_init(mock_session.boto_session, description, "Processing") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 3 - assert log_group == "/aws/sagemaker/ProcessingJobs" - - def test_with_automl_job(self, mock_session): - """Test with AutoML job.""" - description = {} - result = _logs_init(mock_session.boto_session, description, "AutoML") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 0 - assert log_group == "/aws/sagemaker/AutoMLJobs" - - -class TestFlushLogStreams: - """Test _flush_log_streams function.""" - - @patch("sagemaker.core.remote_function.job.sagemaker_logs") - def test_with_no_streams(self, mock_logs, mock_session): - """Test with no log streams.""" - stream_names = [] - positions = {} - client = Mock() - client.describe_log_streams.return_value = {"logStreams": []} - - _flush_log_streams( - stream_names, - 1, - client, - "/aws/sagemaker/TrainingJobs", - "test-job", - positions, - False, - lambda x, y: None, - ) - - @patch("sagemaker.core.remote_function.job.sagemaker_logs") - def test_with_client_error_resource_not_found(self, mock_logs, mock_session): - """Test with ResourceNotFoundException.""" - from botocore.exceptions import ClientError - - stream_names = [] - positions = {} - client = Mock() - error_response = {"Error": {"Code": "ResourceNotFoundException"}} - client.describe_log_streams.side_effect = ClientError( - error_response, "describe_log_streams" - ) - - _flush_log_streams( - stream_names, - 1, - client, - "/aws/sagemaker/TrainingJobs", - "test-job", - positions, - False, - lambda x, y: None, - ) - - @patch("sagemaker.core.remote_function.job.sagemaker_logs") - def test_with_client_error_other(self, mock_logs, mock_session): - """Test with other ClientError.""" - from botocore.exceptions import ClientError - - stream_names = [] - positions = {} - client = Mock() - error_response = {"Error": {"Code": "OtherError"}} - client.describe_log_streams.side_effect = ClientError( - error_response, "describe_log_streams" - ) - - with pytest.raises(ClientError): - _flush_log_streams( - stream_names, - 1, - client, - "/aws/sagemaker/TrainingJobs", - "test-job", - positions, - False, - lambda x, y: None, - ) - - -class TestPrepareAndUploadRuntimeScripts: - """Test _prepare_and_upload_runtime_scripts function.""" - - 
@patch("sagemaker.core.remote_function.job.S3Uploader") - @patch("sagemaker.core.remote_function.job._tmpdir") - @patch("sagemaker.core.remote_function.job.shutil") - @patch("builtins.open", new_callable=mock_open) - def test_without_spark_or_distributed( - self, mock_file, mock_shutil, mock_tmpdir, mock_uploader, mock_session - ): - """Test without Spark or distributed training.""" - mock_tmpdir.return_value.__enter__ = Mock(return_value="/tmp/test") - mock_tmpdir.return_value.__exit__ = Mock(return_value=False) - mock_uploader.upload.return_value = "s3://bucket/scripts" - - result = _prepare_and_upload_runtime_scripts( - None, "s3://bucket", "kms-key", mock_session, False, False - ) - - assert result == "s3://bucket/scripts" - - @patch("sagemaker.core.remote_function.job.S3Uploader") - @patch("sagemaker.core.remote_function.job._tmpdir") - @patch("sagemaker.core.remote_function.job.shutil") - @patch("builtins.open", new_callable=mock_open) - def test_with_spark(self, mock_file, mock_shutil, mock_tmpdir, mock_uploader, mock_session): - """Test with Spark config.""" - mock_tmpdir.return_value.__enter__ = Mock(return_value="/tmp/test") - mock_tmpdir.return_value.__exit__ = Mock(return_value=False) - mock_uploader.upload.return_value = "s3://bucket/scripts" - - spark_config = SparkConfig() - result = _prepare_and_upload_runtime_scripts( - spark_config, "s3://bucket", "kms-key", mock_session, False, False - ) - - assert result == "s3://bucket/scripts" - - @patch("sagemaker.core.remote_function.job.S3Uploader") - @patch("sagemaker.core.remote_function.job._tmpdir") - @patch("sagemaker.core.remote_function.job.shutil") - @patch("builtins.open", new_callable=mock_open) - def test_with_torchrun(self, mock_file, mock_shutil, mock_tmpdir, mock_uploader, mock_session): - """Test with torchrun.""" - mock_tmpdir.return_value.__enter__ = Mock(return_value="/tmp/test") - mock_tmpdir.return_value.__exit__ = Mock(return_value=False) - mock_uploader.upload.return_value = "s3://bucket/scripts" - - result = _prepare_and_upload_runtime_scripts( - None, "s3://bucket", "kms-key", mock_session, True, False - ) - - assert result == "s3://bucket/scripts" - - @patch("sagemaker.core.remote_function.job.S3Uploader") - @patch("sagemaker.core.remote_function.job._tmpdir") - @patch("sagemaker.core.remote_function.job.shutil") - @patch("builtins.open", new_callable=mock_open) - def test_with_mpirun(self, mock_file, mock_shutil, mock_tmpdir, mock_uploader, mock_session): - """Test with mpirun.""" - mock_tmpdir.return_value.__enter__ = Mock(return_value="/tmp/test") - mock_tmpdir.return_value.__exit__ = Mock(return_value=False) - mock_uploader.upload.return_value = "s3://bucket/scripts" - - result = _prepare_and_upload_runtime_scripts( - None, "s3://bucket", "kms-key", mock_session, False, True - ) - - assert result == "s3://bucket/scripts" - - -class TestPrepareAndUploadWorkspace: - """Test _prepare_and_upload_workspace function.""" - - def test_without_dependencies_or_workdir(self, mock_session): - """Test without dependencies or workdir.""" - result = _prepare_and_upload_workspace( - None, False, None, None, "s3://bucket", "kms-key", mock_session, None - ) - assert result is None - - @patch("sagemaker.core.remote_function.job.S3Uploader") - @patch("sagemaker.core.remote_function.job._tmpdir") - @patch("sagemaker.core.remote_function.job.shutil") - @patch("sagemaker.core.remote_function.job.copy_workdir") - @patch("os.mkdir") - @patch("os.path.isdir", return_value=False) - def test_with_workdir( - self, - mock_isdir, - 
mock_mkdir, - mock_copy, - mock_shutil, - mock_tmpdir, - mock_uploader, - mock_session, - ): - """Test with workdir.""" - mock_tmpdir.return_value.__enter__ = Mock(return_value="/tmp/test") - mock_tmpdir.return_value.__exit__ = Mock(return_value=False) - mock_shutil.make_archive.return_value = "/tmp/test/workspace.zip" - mock_uploader.upload.return_value = "s3://bucket/workspace.zip" - - result = _prepare_and_upload_workspace( - None, True, None, None, "s3://bucket", "kms-key", mock_session, None - ) - - assert result == "s3://bucket/workspace.zip" - - -class TestPrepareDependenciesAndPreExecutionScripts: - """Test _prepare_dependencies_and_pre_execution_scripts function.""" - - def test_without_dependencies_or_scripts(self, mock_session): - """Test without dependencies or scripts.""" - result = _prepare_dependencies_and_pre_execution_scripts( - None, None, None, "s3://bucket", "kms-key", mock_session, "/tmp" - ) - assert result is None - - @patch("sagemaker.core.workflow.utilities.load_step_compilation_context") - @patch("sagemaker.core.remote_function.job.shutil") - @patch("sagemaker.core.remote_function.job.S3Uploader") - def test_with_dependencies(self, mock_uploader, mock_shutil, mock_context, mock_session): - """Test with dependencies file.""" - mock_shutil.copy2.return_value = "/tmp/requirements.txt" - mock_uploader.upload.return_value = "s3://bucket/deps" - mock_context.return_value = Mock(step_name="step", pipeline_build_time="123") - - result = _prepare_dependencies_and_pre_execution_scripts( - "/path/to/requirements.txt", None, None, "s3://bucket", "kms-key", mock_session, "/tmp" - ) - - assert result == "s3://bucket/deps" - - @patch("sagemaker.core.workflow.utilities.load_step_compilation_context") - @patch("builtins.open", create=True) - @patch("sagemaker.core.remote_function.job.S3Uploader") - def test_with_pre_execution_commands( - self, mock_uploader, mock_open, mock_context, mock_session - ): - """Test with pre-execution commands.""" - mock_uploader.upload.return_value = "s3://bucket/scripts" - mock_context.return_value = Mock(step_name="step", pipeline_build_time="123") - - result = _prepare_dependencies_and_pre_execution_scripts( - None, ["echo test"], None, "s3://bucket", "kms-key", mock_session, "/tmp" - ) - - assert result == "s3://bucket/scripts" - - @patch("sagemaker.core.workflow.utilities.load_step_compilation_context") - @patch("sagemaker.core.remote_function.job.shutil") - @patch("sagemaker.core.remote_function.job.S3Uploader") - def test_with_pre_execution_script( - self, mock_uploader, mock_shutil, mock_context, mock_session - ): - """Test with pre-execution script.""" - mock_shutil.copy2.return_value = "/tmp/pre_exec.sh" - mock_uploader.upload.return_value = "s3://bucket/scripts" - mock_context.return_value = Mock(step_name="step", pipeline_build_time="123") - - result = _prepare_dependencies_and_pre_execution_scripts( - None, None, "/path/to/script.sh", "s3://bucket", "kms-key", mock_session, "/tmp" - ) - - assert result == "s3://bucket/scripts" - - -class TestPrepareAndUploadSparkDependentFiles: - """Test _prepare_and_upload_spark_dependent_files function.""" - - def test_without_spark_config(self, mock_session): - """Test without Spark config.""" - result = _prepare_and_upload_spark_dependent_files( - None, "s3://bucket", "kms-key", mock_session - ) - assert result == (None, None, None, None) - - @patch("sagemaker.core.remote_function.job._upload_spark_submit_deps") - @patch("sagemaker.core.remote_function.job._upload_serialized_spark_configuration") - def 
test_with_spark_config(self, mock_upload_config, mock_upload_deps, mock_session): - """Test with Spark config.""" - mock_upload_deps.return_value = "s3://bucket/deps" - mock_upload_config.return_value = "s3://bucket/config.json" - - spark_config = SparkConfig( - submit_jars=["test.jar"], - submit_py_files=["test.py"], - submit_files=["test.txt"], - configuration={"Classification": "spark-defaults", "Properties": {"key": "value"}}, - ) - - result = _prepare_and_upload_spark_dependent_files( - spark_config, "s3://bucket", "kms-key", mock_session - ) - - assert len(result) == 4 - - -class TestJobCompile: - """Test _Job.compile method.""" - - @patch("sagemaker.core.remote_function.job.StoredFunction") - @patch("sagemaker.core.remote_function.job._generate_input_data_config") - def test_compile_basic(self, mock_input_config, mock_stored_func, mock_session): - """Test basic compile.""" - mock_input_config.return_value = [] - mock_stored_func.return_value.save = Mock() - - job_settings = Mock() - job_settings.max_runtime_in_seconds = 3600 - job_settings.max_wait_time_in_seconds = None - job_settings.max_retry_attempts = 1 - job_settings.role = "arn:aws:iam::123456789012:role/test" - job_settings.tags = None - job_settings.s3_kms_key = None - job_settings.disable_output_compression = False - job_settings.volume_size = 30 - job_settings.instance_count = 1 - job_settings.instance_type = "ml.m5.xlarge" - job_settings.volume_kms_key = None - job_settings.keep_alive_period_in_seconds = None - job_settings.enable_network_isolation = False - job_settings.encrypt_inter_container_traffic = False - job_settings.vpc_config = None - job_settings.use_spot_instances = False - job_settings.environment_variables = {} - job_settings.image_uri = "test-image" - job_settings.sagemaker_session = mock_session - job_settings.use_torchrun = False - job_settings.use_mpirun = False - job_settings.nproc_per_node = None - job_settings.job_conda_env = None - job_settings.spark_config = None - job_settings.dependencies = None - - def test_func(): - pass - - result = _Job.compile(job_settings, "test-job", "s3://bucket", test_func, (), {}) - - assert result["TrainingJobName"] == "test-job" - assert result["RoleArn"] == "arn:aws:iam::123456789012:role/test" - - -class TestJobStart: - """Test _Job.start method.""" - - @patch("sagemaker.core.remote_function.job._Job.compile") - @patch("sagemaker.core.remote_function.job._Job._get_job_name") - def test_start(self, mock_get_name, mock_compile, mock_session): - """Test starting a job.""" - mock_get_name.return_value = "test-job" - mock_compile.return_value = { - "TrainingJobName": "test-job", - "Environment": {}, - } - - job_settings = Mock() - job_settings.s3_root_uri = "s3://bucket" - job_settings.sagemaker_session = mock_session - - def test_func(): - pass - - job = _Job.start(job_settings, test_func, (), {}) - - assert job.job_name == "test-job" - mock_session.sagemaker_client.create_training_job.assert_called_once() - - -class TestJobGetJobName: - """Test _Job._get_job_name method.""" - - def test_with_job_name_prefix(self, mock_session): - """Test with job_name_prefix.""" - job_settings = Mock() - job_settings.job_name_prefix = "my-job" - - def test_func(): - pass - - result = _Job._get_job_name(job_settings, test_func) - assert "my-job" in result - - def test_without_job_name_prefix(self, mock_session): - """Test without job_name_prefix.""" - job_settings = Mock() - job_settings.job_name_prefix = None - - def test_func(): - pass - - result = _Job._get_job_name(job_settings, 
test_func) - assert "test-func" in result - - def test_with_special_characters_in_func_name(self, mock_session): - """Test with special characters in function name.""" - job_settings = Mock() - job_settings.job_name_prefix = None - - def _test_func(): - pass - - result = _Job._get_job_name(job_settings, _test_func) - assert result.startswith("test-func") diff --git a/sagemaker-core/tests/unit/remote_function/test_job_comprehensive.py b/sagemaker-core/tests/unit/remote_function/test_job_comprehensive.py deleted file mode 100644 index bc8d5a8e56..0000000000 --- a/sagemaker-core/tests/unit/remote_function/test_job_comprehensive.py +++ /dev/null @@ -1,533 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Comprehensive unit tests for uncovered lines in sagemaker.core.remote_function.job module.""" -from __future__ import absolute_import - -import json -import os -import pytest -import sys -import tempfile -from unittest.mock import Mock, patch, MagicMock, mock_open -from io import BytesIO - -from sagemaker.core.remote_function.job import ( - _JobSettings, - _Job, - _update_job_request_with_checkpoint_config, - _convert_run_to_json, - _upload_spark_submit_deps, - _upload_serialized_spark_configuration, - _extend_mpirun_to_request, - _extend_torchrun_to_request, - _check_job_status, - _rule_statuses_changed, - _logs_init, - _get_initial_job_state, - LogState, - _RunInfo, -) -from sagemaker.core.remote_function.checkpoint_location import CheckpointLocation - - -@pytest.fixture -def mock_session(): - session = Mock() - session.boto_region_name = "us-west-2" - session.default_bucket.return_value = "test-bucket" - session.default_bucket_prefix = "prefix" - session.sagemaker_client = Mock() - session.boto_session = Mock() - session.sagemaker_config = {} - return session - - -class TestJobSettingsValidation: - """Test _JobSettings validation logic for uncovered lines.""" - - def test_spark_config_with_image_uri_raises_error(self, mock_session): - """Test lines 619-620: spark_config and image_uri validation.""" - from sagemaker.core.remote_function.spark_config import SparkConfig - - spark_config = SparkConfig() - with pytest.raises(ValueError, match="spark_config and image_uri cannot be specified"): - _JobSettings( - sagemaker_session=mock_session, - spark_config=spark_config, - image_uri="test-image", - instance_type="ml.m5.xlarge", - ) - - def test_spark_config_with_conda_env_raises_error(self, mock_session): - """Test lines 622-623: spark_config and job_conda_env validation.""" - from sagemaker.core.remote_function.spark_config import SparkConfig - - spark_config = SparkConfig() - with pytest.raises(ValueError, match="Remote Spark jobs do not support job_conda_env"): - _JobSettings( - sagemaker_session=mock_session, - spark_config=spark_config, - job_conda_env="test-env", - instance_type="ml.m5.xlarge", - ) - - def test_spark_config_with_auto_capture_raises_error(self, mock_session): - """Test lines 625-628: spark_config and auto_capture validation.""" - from 
sagemaker.core.remote_function.spark_config import SparkConfig - - spark_config = SparkConfig() - with pytest.raises(ValueError, match="Remote Spark jobs do not support automatically"): - _JobSettings( - sagemaker_session=mock_session, - spark_config=spark_config, - dependencies="auto_capture", - instance_type="ml.m5.xlarge", - ) - - def test_pre_execution_commands_and_script_raises_error(self, mock_session): - """Test lines 651-653: pre_execution validation.""" - with pytest.raises( - ValueError, match="Only one of pre_execution_commands or pre_execution_script" - ): - _JobSettings( - sagemaker_session=mock_session, - pre_execution_commands=["echo test"], - pre_execution_script="/path/to/script.sh", - instance_type="ml.m5.xlarge", - image_uri="test-image", - ) - - def test_instance_type_required(self, mock_session): - """Test lines 665-666: instance_type validation.""" - with pytest.raises(ValueError, match="instance_type is a required parameter"): - _JobSettings(sagemaker_session=mock_session, image_uri="test-image") - - @patch.dict(os.environ, {"SAGEMAKER_INTERNAL_IMAGE_URI": "custom-image"}) - def test_get_default_image_from_env(self, mock_session): - """Test lines 785-788: get default image from environment.""" - image = _JobSettings._get_default_image(mock_session) - assert image == "custom-image" - - def test_get_default_image_unsupported_python(self, mock_session): - """Test lines 792-795: unsupported Python version.""" - with patch.object(sys, "version_info", (3, 7, 0)): - with pytest.raises( - ValueError, match="Default image is supported only for Python versions" - ): - _JobSettings._get_default_image(mock_session) - - def test_get_default_spark_image_unsupported_python(self, mock_session): - """Test lines 815-817: unsupported Python for Spark.""" - with patch.object(sys, "version_info", (3, 8, 0)): - with pytest.raises( - ValueError, - match="SageMaker Spark image for remote job only supports Python version 3.9", - ): - _JobSettings._get_default_spark_image(mock_session) - - -class TestJobMethods: - """Test _Job class methods for uncovered lines.""" - - def test_from_describe_response(self, mock_session): - """Test lines 848-852: from_describe_response method.""" - response = { - "TrainingJobName": "test-job", - "OutputDataConfig": {"S3OutputPath": "s3://bucket/output"}, - } - job = _Job.from_describe_response(response, mock_session) - assert job.job_name == "test-job" - assert job.s3_uri == "s3://bucket/output" - assert job._last_describe_response == response - - def test_describe_cached_completed(self, mock_session): - """Test lines 865-871: describe with cached completed job.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - job._last_describe_response = {"TrainingJobStatus": "Completed"} - - result = job.describe() - assert result["TrainingJobStatus"] == "Completed" - mock_session.sagemaker_client.describe_training_job.assert_not_called() - - def test_describe_cached_failed(self, mock_session): - """Test lines 865-871: describe with cached failed job.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - job._last_describe_response = {"TrainingJobStatus": "Failed"} - - result = job.describe() - assert result["TrainingJobStatus"] == "Failed" - mock_session.sagemaker_client.describe_training_job.assert_not_called() - - def test_describe_cached_stopped(self, mock_session): - """Test lines 865-871: describe with cached stopped job.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - job._last_describe_response = {"TrainingJobStatus": 
"Stopped"} - - result = job.describe() - assert result["TrainingJobStatus"] == "Stopped" - mock_session.sagemaker_client.describe_training_job.assert_not_called() - - def test_stop(self, mock_session): - """Test lines 886-887: stop method.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - job.stop() - mock_session.sagemaker_client.stop_training_job.assert_called_once_with( - TrainingJobName="test-job" - ) - - @patch("sagemaker.core.remote_function.job._logs_for_job") - def test_wait(self, mock_logs, mock_session): - """Test lines 889-903: wait method.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - mock_logs.return_value = {"TrainingJobStatus": "Completed"} - - job.wait(timeout=100) - mock_logs.assert_called_once_with( - sagemaker_session=mock_session, job_name="test-job", wait=True, timeout=100 - ) - assert job._last_describe_response["TrainingJobStatus"] == "Completed" - - -class TestCheckpointConfig: - """Test checkpoint configuration for uncovered lines.""" - - def test_checkpoint_in_args(self): - """Test lines 1219-1227: checkpoint in positional args.""" - checkpoint = CheckpointLocation(s3_uri="s3://bucket/checkpoint") - args = (checkpoint,) - kwargs = {} - request_dict = {} - - _update_job_request_with_checkpoint_config(args, kwargs, request_dict) - - assert "CheckpointConfig" in request_dict - assert request_dict["CheckpointConfig"]["S3Uri"] == "s3://bucket/checkpoint" - assert request_dict["CheckpointConfig"]["LocalPath"] == "/opt/ml/checkpoints/" - - def test_checkpoint_in_kwargs(self): - """Test lines 1228-1230: checkpoint in keyword args.""" - checkpoint = CheckpointLocation(s3_uri="s3://bucket/checkpoint") - args = () - kwargs = {"checkpoint": checkpoint} - request_dict = {} - - _update_job_request_with_checkpoint_config(args, kwargs, request_dict) - - assert "CheckpointConfig" in request_dict - assert request_dict["CheckpointConfig"]["S3Uri"] == "s3://bucket/checkpoint" - - def test_multiple_checkpoints_raises_error(self): - """Test lines 1237-1239: multiple checkpoints error.""" - checkpoint1 = CheckpointLocation(s3_uri="s3://bucket/checkpoint1") - checkpoint2 = CheckpointLocation(s3_uri="s3://bucket/checkpoint2") - args = (checkpoint1,) - kwargs = {"checkpoint": checkpoint2} - request_dict = {} - - with pytest.raises( - ValueError, match="cannot have more than one argument of type CheckpointLocation" - ): - _update_job_request_with_checkpoint_config(args, kwargs, request_dict) - - def test_no_checkpoint(self): - """Test lines 1232-1233: no checkpoint location.""" - args = ("arg1", "arg2") - kwargs = {"key": "value"} - request_dict = {} - - _update_job_request_with_checkpoint_config(args, kwargs, request_dict) - - assert "CheckpointConfig" not in request_dict - - -class TestConvertRunToJson: - """Test _convert_run_to_json for uncovered lines.""" - - def test_convert_run(self): - """Test lines 1276-1278: convert run to JSON.""" - mock_run = Mock() - mock_run.experiment_name = "test-experiment" - mock_run.run_name = "test-run" - - result = _convert_run_to_json(mock_run) - data = json.loads(result) - - assert data["experiment_name"] == "test-experiment" - assert data["run_name"] == "test-run" - - -class TestSparkDependencies: - """Test Spark dependency functions for uncovered lines.""" - - def test_upload_spark_config_none(self, mock_session): - """Test lines 1356: upload None Spark configuration.""" - result = _upload_serialized_spark_configuration( - "s3://bucket/base", "kms-key", None, mock_session - ) - assert result is None - - 
@patch("sagemaker.core.remote_function.job.S3Uploader") - def test_upload_spark_config(self, mock_uploader, mock_session): - """Test lines 1339-1356: upload Spark configuration.""" - config = {"spark.executor.memory": "4g"} - mock_uploader.upload_string_as_file_body = Mock() - - _upload_serialized_spark_configuration("s3://bucket/base", "kms-key", config, mock_session) - - mock_uploader.upload_string_as_file_body.assert_called_once() - - def test_upload_spark_deps_none(self, mock_session): - """Test lines 1379-1380: None dependencies.""" - result = _upload_spark_submit_deps( - None, "workspace", "s3://bucket", "kms-key", mock_session - ) - assert result is None - - def test_upload_spark_deps_s3_uri(self, mock_session): - """Test lines 1388-1389: S3 URI dependency.""" - deps = ["s3://bucket/dep.jar"] - result = _upload_spark_submit_deps( - deps, "workspace", "s3://bucket", "kms-key", mock_session - ) - assert "s3://bucket/dep.jar" in result - - def test_upload_spark_deps_s3a_uri(self, mock_session): - """Test lines 1388-1389: S3A URI dependency.""" - deps = ["s3a://bucket/dep.jar"] - result = _upload_spark_submit_deps( - deps, "workspace", "s3://bucket", "kms-key", mock_session - ) - assert "s3a://bucket/dep.jar" in result - - def test_upload_spark_deps_empty_workspace_raises_error(self, mock_session): - """Test lines 1382-1383: empty workspace validation.""" - deps = ["s3://bucket/dep.jar"] - with pytest.raises(ValueError, match="workspace_name or s3_base_uri may not be empty"): - _upload_spark_submit_deps(deps, "", "s3://bucket", "kms-key", mock_session) - - @patch("os.path.isfile", return_value=False) - def test_upload_spark_deps_invalid_file_raises_error(self, mock_isfile, mock_session): - """Test lines 1391-1392: invalid local file.""" - deps = ["/invalid/path.jar"] - with pytest.raises(ValueError, match="is not a valid local file"): - _upload_spark_submit_deps(deps, "workspace", "s3://bucket", "kms-key", mock_session) - - -class TestDistributedTraining: - """Test distributed training functions for uncovered lines.""" - - def test_extend_mpirun_no_mpirun(self, mock_session): - """Test lines 1441-1442: mpirun disabled.""" - job_settings = Mock() - job_settings.use_mpirun = False - request_dict = {"InputDataConfig": []} - - result = _extend_mpirun_to_request(request_dict, job_settings) - assert result == request_dict - - def test_extend_mpirun_single_instance(self, mock_session): - """Test lines 1444-1445: single instance.""" - job_settings = Mock() - job_settings.use_mpirun = True - job_settings.instance_count = 1 - request_dict = {"InputDataConfig": []} - - result = _extend_mpirun_to_request(request_dict, job_settings) - assert result == request_dict - - def test_extend_mpirun_multiple_instances(self, mock_session): - """Test lines 1447-1453: multiple instances.""" - job_settings = Mock() - job_settings.use_mpirun = True - job_settings.instance_count = 2 - request_dict = { - "InputDataConfig": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/data"}}}] - } - - result = _extend_mpirun_to_request(request_dict, job_settings) - assert ( - result["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3DataDistributionType"] - == "FullyReplicated" - ) - - def test_extend_torchrun_no_torchrun(self, mock_session): - """Test lines 1506-1507: torchrun disabled.""" - job_settings = Mock() - job_settings.use_torchrun = False - request_dict = {"InputDataConfig": []} - - result = _extend_torchrun_to_request(request_dict, job_settings) - assert result == request_dict - - def 
test_extend_torchrun_single_instance(self, mock_session): - """Test lines 1524-1525: single instance.""" - job_settings = Mock() - job_settings.use_torchrun = True - job_settings.instance_count = 1 - request_dict = {"InputDataConfig": []} - - result = _extend_torchrun_to_request(request_dict, job_settings) - assert result == request_dict - - def test_extend_torchrun_multiple_instances(self, mock_session): - """Test lines 1527-1533: multiple instances.""" - job_settings = Mock() - job_settings.use_torchrun = True - job_settings.instance_count = 2 - request_dict = { - "InputDataConfig": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/data"}}}] - } - - result = _extend_torchrun_to_request(request_dict, job_settings) - assert ( - result["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3DataDistributionType"] - == "FullyReplicated" - ) - - -class TestJobStatus: - """Test job status functions for uncovered lines.""" - - def test_check_job_status_completed(self): - """Test lines 1978-1979: completed status.""" - desc = {"TrainingJobStatus": "Completed"} - _check_job_status("test-job", desc, "TrainingJobStatus") - - def test_check_job_status_stopped(self): - """Test lines 1978-1986: stopped status.""" - desc = {"TrainingJobStatus": "Stopped"} - with patch("sagemaker.core.remote_function.job.logger") as mock_logger: - _check_job_status("test-job", desc, "TrainingJobStatus") - mock_logger.warning.assert_called_once() - - def test_check_job_status_failed(self): - """Test lines 1987-2011: failed status.""" - desc = {"TrainingJobStatus": "Failed", "FailureReason": "Test failure"} - from sagemaker.core import exceptions - - with pytest.raises(exceptions.UnexpectedStatusException): - _check_job_status("test-job", desc, "TrainingJobStatus") - - def test_check_job_status_capacity_error(self): - """Test lines 2002-2007: CapacityError.""" - desc = { - "TrainingJobStatus": "Failed", - "FailureReason": "CapacityError: Insufficient capacity", - } - from sagemaker.core import exceptions - - with pytest.raises(exceptions.CapacityError): - _check_job_status("test-job", desc, "TrainingJobStatus") - - -class TestRuleStatuses: - """Test rule status functions for uncovered lines.""" - - def test_rule_statuses_no_last(self): - """Test lines 2092-2093: no last statuses.""" - current = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "InProgress"}] - result = _rule_statuses_changed(current, None) - assert result is True - - def test_rule_statuses_changed(self): - """Test lines 2095-2098: changed status.""" - current = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "Completed"}] - last = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "InProgress"}] - result = _rule_statuses_changed(current, last) - assert result is True - - def test_rule_statuses_unchanged(self): - """Test lines 2100: unchanged status.""" - current = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "InProgress"}] - last = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "InProgress"}] - result = _rule_statuses_changed(current, last) - assert result is False - - -class TestLogsInit: - """Test _logs_init function for uncovered lines.""" - - def test_logs_init_training_job(self, mock_session): - """Test lines 2098-2105: training job.""" - description = {"ResourceConfig": {"InstanceCount": 2}} - result = _logs_init(mock_session.boto_session, description, "Training") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 2 - assert 
log_group == "/aws/sagemaker/TrainingJobs" - - def test_logs_init_training_job_instance_groups(self, mock_session): - """Test lines 2098-2103: training job with instance groups.""" - description = { - "ResourceConfig": {"InstanceGroups": [{"InstanceCount": 2}, {"InstanceCount": 3}]} - } - result = _logs_init(mock_session.boto_session, description, "Training") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 5 - - def test_logs_init_transform_job(self, mock_session): - """Test lines 2106-2107: transform job.""" - description = {"TransformResources": {"InstanceCount": 1}} - result = _logs_init(mock_session.boto_session, description, "Transform") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 1 - assert log_group == "/aws/sagemaker/TransformJobs" - - def test_logs_init_processing_job(self, mock_session): - """Test lines 2108-2109: processing job.""" - description = {"ProcessingResources": {"ClusterConfig": {"InstanceCount": 3}}} - result = _logs_init(mock_session.boto_session, description, "Processing") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 3 - assert log_group == "/aws/sagemaker/ProcessingJobs" - - def test_logs_init_automl_job(self, mock_session): - """Test lines 2110-2111: AutoML job.""" - description = {} - result = _logs_init(mock_session.boto_session, description, "AutoML") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 0 - assert log_group == "/aws/sagemaker/AutoMLJobs" - - -class TestGetInitialJobState: - """Test _get_initial_job_state for uncovered lines.""" - - def test_completed_with_wait(self): - """Test lines 2021-2023: completed job with wait.""" - description = {"TrainingJobStatus": "Completed"} - state = _get_initial_job_state(description, "TrainingJobStatus", True) - assert state == LogState.COMPLETE - - def test_failed_with_wait(self): - """Test lines 2021-2023: failed job with wait.""" - description = {"TrainingJobStatus": "Failed"} - state = _get_initial_job_state(description, "TrainingJobStatus", True) - assert state == LogState.COMPLETE - - def test_stopped_with_wait(self): - """Test lines 2021-2023: stopped job with wait.""" - description = {"TrainingJobStatus": "Stopped"} - state = _get_initial_job_state(description, "TrainingJobStatus", True) - assert state == LogState.COMPLETE - - def test_in_progress_with_wait(self): - """Test lines 2022: in-progress job with wait.""" - description = {"TrainingJobStatus": "InProgress"} - state = _get_initial_job_state(description, "TrainingJobStatus", True) - assert state == LogState.TAILING - - def test_in_progress_without_wait(self): - """Test lines 2022: in-progress job without wait.""" - description = {"TrainingJobStatus": "InProgress"} - state = _get_initial_job_state(description, "TrainingJobStatus", False) - assert state == LogState.COMPLETE diff --git a/sagemaker-core/tests/unit/remote_function/test_logging_config.py b/sagemaker-core/tests/unit/remote_function/test_logging_config.py deleted file mode 100644 index 6454ea1071..0000000000 --- a/sagemaker-core/tests/unit/remote_function/test_logging_config.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. 
A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tests for logging_config module.""" -from __future__ import absolute_import - -import logging -import time -from unittest.mock import patch -from sagemaker.core.remote_function.logging_config import _UTCFormatter, get_logger - - -class TestUTCFormatter: - """Test _UTCFormatter class.""" - - def test_converter_is_gmtime(self): - """Test that converter is set to gmtime.""" - formatter = _UTCFormatter() - assert formatter.converter == time.gmtime - - def test_formats_time_in_utc(self): - """Test that time is formatted in UTC.""" - formatter = _UTCFormatter("%(asctime)s") - record = logging.LogRecord( - name="test", - level=logging.INFO, - pathname="", - lineno=0, - msg="test message", - args=(), - exc_info=None, - ) - formatted = formatter.format(record) - # Should contain UTC time format - assert formatted - - -class TestGetLogger: - """Test get_logger function.""" - - def test_returns_logger_with_correct_name(self): - """Test that logger has correct name.""" - logger = get_logger() - assert logger.name == "sagemaker.remote_function" - - def test_logger_has_info_level(self): - """Test that logger is set to INFO level.""" - logger = get_logger() - assert logger.level == logging.INFO - - def test_logger_has_handler(self): - """Test that logger has at least one handler.""" - logger = get_logger() - assert len(logger.handlers) > 0 - - def test_logger_handler_has_utc_formatter(self): - """Test that logger handler uses UTC formatter.""" - logger = get_logger() - handler = logger.handlers[0] - # Check that formatter has gmtime converter (UTC formatter characteristic) - assert handler.formatter.converter == time.gmtime - - def test_logger_does_not_propagate(self): - """Test that logger does not propagate to root logger.""" - logger = get_logger() - assert logger.propagate == 0 - - def test_get_logger_is_idempotent(self): - """Test that calling get_logger multiple times returns same logger.""" - logger1 = get_logger() - logger2 = get_logger() - assert logger1 is logger2 - - def test_logger_handler_is_stream_handler(self): - """Test that logger uses StreamHandler.""" - logger = get_logger() - assert isinstance(logger.handlers[0], logging.StreamHandler) diff --git a/sagemaker-mlops/src/sagemaker/mlops/workflow/function_step.py b/sagemaker-mlops/src/sagemaker/mlops/workflow/function_step.py index 1f51612c59..ecdf1135a6 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/workflow/function_step.py +++ b/sagemaker-mlops/src/sagemaker/mlops/workflow/function_step.py @@ -44,8 +44,8 @@ from sagemaker.core.common_utils import unique_name_from_base_uuid4, format_tags, Tags if TYPE_CHECKING: - from sagemaker.core.remote_function.spark_config import SparkConfig - from sagemaker.core.remote_function.job import _JobSettings + from sagemaker.train.remote_function.spark_config import SparkConfig + from sagemaker.train.remote_function.job import _JobSettings logger = logging.getLogger(__name__) @@ -83,11 +83,11 @@ def __init__( func_kwargs (dict): keyword arguments of the python function. **kwargs: Additional arguments to be passed to the `step` decorator. 
""" - from sagemaker.core.remote_function.core.pipeline_variables import ( + from sagemaker.train.remote_function.core.pipeline_variables import ( convert_pipeline_variables_to_pickleable, ) - from sagemaker.core.remote_function.core.serialization import CloudpickleSerializer - from sagemaker.core.remote_function.core.stored_function import _SerializedData + from sagemaker.train.remote_function.core.serialization import CloudpickleSerializer + from sagemaker.train.remote_function.core.stored_function import _SerializedData super(_FunctionStep, self).__init__( name, StepTypeEnum.TRAINING, display_name, description, depends_on, retry_policies @@ -151,7 +151,7 @@ def depends_on(self, depends_on: List[Union[str, "Step", StepOutput]]): def _job_settings(self) -> "_JobSettings": """Returns the job settings for the step.""" - from sagemaker.core.remote_function.job import _JobSettings + from sagemaker.train.remote_function.job import _JobSettings context = load_step_compilation_context() @@ -193,7 +193,7 @@ def _job_settings(self) -> "_JobSettings": @property def arguments(self) -> RequestType: """Generates the arguments dictionary that is used to call `create_training_job`.""" - from sagemaker.core.remote_function.job import _Job + from sagemaker.train.remote_function.job import _Job step_compilation_context = load_step_compilation_context() @@ -274,7 +274,7 @@ def expr(self) -> RequestType: def _to_json_get(self) -> JsonGet: """Expression structure for workflow service calls using JsonGet resolution.""" - from sagemaker.core.remote_function.core.stored_function import ( + from sagemaker.train.remote_function.core.stored_function import ( JSON_SERIALIZED_RESULT_KEY, JSON_RESULTS_FILE, ) @@ -547,7 +547,7 @@ def _step(func): raise ValueError("Auto Capture of dependencies is not supported for pipeline steps.") # avoid circular import - from sagemaker.core.remote_function.client import RemoteExecutor + from sagemaker.train.remote_function.client import RemoteExecutor @wraps(func) def wrapper(*args, **kwargs): diff --git a/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py b/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py index 9b4b9a191b..4d9f485ace 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py +++ b/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py @@ -28,10 +28,10 @@ from sagemaker.core.local.local_session import LocalSession from sagemaker.core._studio import _append_project_tags from sagemaker.core.config.config_schema import PIPELINE_ROLE_ARN_PATH, PIPELINE_TAGS_PATH -from sagemaker.core.remote_function.core.serialization import deserialize_obj_from_s3 -from sagemaker.core.remote_function.core.stored_function import RESULTS_FOLDER -from sagemaker.core.remote_function.errors import RemoteFunctionError -from sagemaker.core.remote_function.job import JOBS_CONTAINER_ENTRYPOINT +from sagemaker.train.remote_function.core.serialization import deserialize_obj_from_s3 +from sagemaker.train.remote_function.core.stored_function import RESULTS_FOLDER +from sagemaker.train.remote_function.errors import RemoteFunctionError +from sagemaker.train.remote_function.job import JOBS_CONTAINER_ENTRYPOINT from sagemaker.core.s3 import s3_path_join from sagemaker.core.helper.session_helper import Session from sagemaker.core.common_utils import resolve_value_from_config, retry_with_backoff, format_tags, Tags diff --git a/sagemaker-mlops/tests/unit/workflow/test_pipeline.py b/sagemaker-mlops/tests/unit/workflow/test_pipeline.py index 9169f1ce7f..eb2abe916d 100644 --- 
a/sagemaker-mlops/tests/unit/workflow/test_pipeline.py +++ b/sagemaker-mlops/tests/unit/workflow/test_pipeline.py @@ -352,8 +352,8 @@ def test_get_function_step_result_wrong_container(mock_session): def test_get_function_step_result_incomplete_job(mock_session): from sagemaker.mlops.workflow.pipeline import get_function_step_result - from sagemaker.core.remote_function.job import JOBS_CONTAINER_ENTRYPOINT - from sagemaker.core.remote_function.errors import RemoteFunctionError + from sagemaker.train.remote_function.job import JOBS_CONTAINER_ENTRYPOINT + from sagemaker.train.remote_function.errors import RemoteFunctionError step_list = [{"StepName": "step1", "Metadata": {"TrainingJob": {"Arn": "arn:aws:sagemaker:us-west-2:123456789012:training-job/job"}}}] mock_session.describe_training_job.return_value = { @@ -368,7 +368,7 @@ def test_get_function_step_result_incomplete_job(mock_session): def test_get_function_step_result_success(mock_session): from sagemaker.mlops.workflow.pipeline import get_function_step_result - from sagemaker.core.remote_function.job import JOBS_CONTAINER_ENTRYPOINT + from sagemaker.train.remote_function.job import JOBS_CONTAINER_ENTRYPOINT step_list = [{"StepName": "step1", "Metadata": {"TrainingJob": {"Arn": "arn:aws:sagemaker:us-west-2:123456789012:training-job/job"}}}] mock_session.describe_training_job.return_value = { @@ -431,7 +431,7 @@ def test_pipeline_execution_result_waiter_error(mock_session): def test_pipeline_execution_result_terminal_failure(mock_session): from sagemaker.mlops.workflow.pipeline import _PipelineExecution from botocore.exceptions import WaiterError - from sagemaker.core.remote_function.job import JOBS_CONTAINER_ENTRYPOINT + from sagemaker.train.remote_function.job import JOBS_CONTAINER_ENTRYPOINT execution = _PipelineExecution(arn="arn:aws:sagemaker:us-west-2:123456789012:pipeline/test/execution/exec-id", sagemaker_session=mock_session) mock_session.sagemaker_client.list_pipeline_execution_steps.return_value = { @@ -451,7 +451,7 @@ def test_pipeline_execution_result_terminal_failure(mock_session): def test_get_function_step_result_obsolete_s3_path(mock_session): from sagemaker.mlops.workflow.pipeline import get_function_step_result - from sagemaker.core.remote_function.job import JOBS_CONTAINER_ENTRYPOINT + from sagemaker.train.remote_function.job import JOBS_CONTAINER_ENTRYPOINT step_list = [{"StepName": "step1", "Metadata": {"TrainingJob": {"Arn": "arn:aws:sagemaker:us-west-2:123456789012:training-job/job"}}}] mock_session.describe_training_job.return_value = { diff --git a/sagemaker-train/src/sagemaker/train/remote_function/__init__.py b/sagemaker-train/src/sagemaker/train/remote_function/__init__.py index 6436ddaa22..87e9aca383 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/__init__.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/__init__.py @@ -13,7 +13,7 @@ """Defines classes and helper methods used in remote function executions.""" from __future__ import absolute_import -from sagemaker.core.remote_function.client import remote, RemoteExecutor # noqa: F401 -from sagemaker.core.remote_function.checkpoint_location import CheckpointLocation # noqa: F401 -from sagemaker.core.remote_function.custom_file_filter import CustomFileFilter # noqa: F401 -from sagemaker.core.remote_function.spark_config import SparkConfig # noqa: F401 +from sagemaker.train.remote_function.client import remote, RemoteExecutor # noqa: F401 +from sagemaker.train.remote_function.checkpoint_location import CheckpointLocation # noqa: F401 
+from sagemaker.train.remote_function.custom_file_filter import CustomFileFilter # noqa: F401 +from sagemaker.train.remote_function.spark_config import SparkConfig # noqa: F401 diff --git a/sagemaker-train/src/sagemaker/train/remote_function/client.py b/sagemaker-train/src/sagemaker/train/remote_function/client.py index a38b57662a..ecc193b8b4 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/client.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/client.py @@ -26,24 +26,24 @@ from sagemaker.core.exceptions import UnexpectedStatusException from sagemaker.core.experiments._run_context import _RunContext -import sagemaker.core.remote_function.core.serialization as serialization -from sagemaker.core.remote_function.errors import ( +import sagemaker.train.remote_function.core.serialization as serialization +from sagemaker.train.remote_function.errors import ( RemoteFunctionError, ServiceError, DeserializationError, ) -from sagemaker.core.remote_function.core.stored_function import RESULTS_FOLDER, EXCEPTION_FOLDER -from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( +from sagemaker.train.remote_function.core.stored_function import RESULTS_FOLDER, EXCEPTION_FOLDER +from sagemaker.train.remote_function.runtime_environment.runtime_environment_manager import ( RuntimeEnvironmentError, ) from sagemaker.core.helper.session_helper import Session from sagemaker.core.s3 import s3_path_join -from sagemaker.core.remote_function.job import _JobSettings, _Job, _RunInfo -from sagemaker.core.remote_function import logging_config +from sagemaker.train.remote_function.job import _JobSettings, _Job, _RunInfo +from sagemaker.train.remote_function import logging_config from sagemaker.core.common_utils import name_from_base, base_from_name -from sagemaker.core.remote_function.spark_config import SparkConfig -from sagemaker.core.remote_function.custom_file_filter import CustomFileFilter +from sagemaker.train.remote_function.spark_config import SparkConfig +from sagemaker.train.remote_function.custom_file_filter import CustomFileFilter from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/_custom_dispatch_table.py b/sagemaker-train/src/sagemaker/train/remote_function/core/_custom_dispatch_table.py index 3217e88672..857ac40eb0 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/_custom_dispatch_table.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/core/_custom_dispatch_table.py @@ -13,7 +13,7 @@ """SageMaker remote function data serializer/deserializer.""" from __future__ import absolute_import -from sagemaker.core.remote_function.errors import SerializationError +from sagemaker.train.remote_function.errors import SerializationError from sagemaker.core.helper.pipeline_variable import PipelineVariable from sagemaker.core.workflow.parameters import ( diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/pipeline_variables.py b/sagemaker-train/src/sagemaker/train/remote_function/core/pipeline_variables.py index 491267b35f..2497302080 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/pipeline_variables.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/core/pipeline_variables.py @@ -18,7 +18,7 @@ from typing import Any, Union, Dict, List, Tuple from sagemaker.core.s3 import s3_path_join -from 
sagemaker.core.remote_function.core.serialization import deserialize_obj_from_s3 +from sagemaker.train.remote_function.core.serialization import deserialize_obj_from_s3 from sagemaker.core.workflow.step_outputs import get_step diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/serialization.py b/sagemaker-train/src/sagemaker/train/remote_function/core/serialization.py index 8871f6727f..7eed7e0d21 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/serialization.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/core/serialization.py @@ -27,7 +27,7 @@ import cloudpickle from tblib import pickling_support -from sagemaker.core.remote_function.errors import ( +from sagemaker.train.remote_function.errors import ( ServiceError, SerializationError, DeserializationError, diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/stored_function.py b/sagemaker-train/src/sagemaker/train/remote_function/core/stored_function.py index c7ee86f8a7..ff5d8ddad4 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/stored_function.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/core/stored_function.py @@ -19,13 +19,13 @@ from sagemaker.core.s3 import s3_path_join -from sagemaker.core.remote_function import logging_config -from sagemaker.core.remote_function.core.pipeline_variables import ( +from sagemaker.train.remote_function import logging_config +from sagemaker.train.remote_function.core.pipeline_variables import ( Context, resolve_pipeline_variables, ) -import sagemaker.core.remote_function.core.serialization as serialization +import sagemaker.train.remote_function.core.serialization as serialization from sagemaker.core.helper.session_helper import Session diff --git a/sagemaker-train/src/sagemaker/train/remote_function/errors.py b/sagemaker-train/src/sagemaker/train/remote_function/errors.py index 3f391570cf..bfebb0726a 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/errors.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/errors.py @@ -17,7 +17,7 @@ from tblib import pickling_support from sagemaker.core.s3 import s3_path_join -import sagemaker.core.remote_function.core.serialization as serialization +import sagemaker.train.remote_function.core.serialization as serialization DEFAULT_FAILURE_CODE = 1 diff --git a/sagemaker-train/src/sagemaker/train/remote_function/invoke_function.py b/sagemaker-train/src/sagemaker/train/remote_function/invoke_function.py index 2e69f4f116..4606b73459 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/invoke_function.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/invoke_function.py @@ -21,16 +21,16 @@ from typing import TYPE_CHECKING import boto3 -from sagemaker.core.remote_function.job import ( +from sagemaker.train.remote_function.job import ( KEY_EXPERIMENT_NAME, KEY_RUN_NAME, ) from sagemaker.core.helper.session_helper import Session from sagemaker.core.s3 import s3_path_join -from sagemaker.core.remote_function.errors import handle_error -from sagemaker.core.remote_function import logging_config -from sagemaker.core.remote_function.core.pipeline_variables import Context +from sagemaker.train.remote_function.errors import handle_error +from sagemaker.train.remote_function import logging_config +from sagemaker.train.remote_function.core.pipeline_variables import Context if TYPE_CHECKING: from sagemaker.core.experiments.run import Run @@ -101,7 +101,7 @@ def _execute_remote_function( sagemaker_session, s3_base_uri, s3_kms_key, 
run_in_context, context ): """Execute stored remote function""" - from sagemaker.core.remote_function.core.stored_function import StoredFunction + from sagemaker.train.remote_function.core.stored_function import StoredFunction stored_function = StoredFunction( sagemaker_session=sagemaker_session, diff --git a/sagemaker-train/src/sagemaker/train/remote_function/job.py b/sagemaker-train/src/sagemaker/train/remote_function/job.py index 435062db57..90c6807b53 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/job.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/job.py @@ -49,7 +49,7 @@ from sagemaker.core.experiments.run import Run from sagemaker.core.image_uris import get_base_python_image_uri from sagemaker.core import image_uris -from sagemaker.core.remote_function.checkpoint_location import CheckpointLocation +from sagemaker.train.remote_function.checkpoint_location import CheckpointLocation from sagemaker.core.helper.session_helper import get_execution_role, expand_role, Session from sagemaker.core.common_utils import ( name_from_base, @@ -60,16 +60,16 @@ ) from sagemaker.core.s3 import s3_path_join, S3Uploader -from sagemaker.core.remote_function.core.stored_function import StoredFunction, _SerializedData -from sagemaker.core.remote_function.core.pipeline_variables import Context +from sagemaker.train.remote_function.core.stored_function import StoredFunction, _SerializedData +from sagemaker.train.remote_function.core.pipeline_variables import Context -from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( +from sagemaker.train.remote_function.runtime_environment.runtime_environment_manager import ( RuntimeEnvironmentManager, _DependencySettings, ) -from sagemaker.core.remote_function import logging_config -from sagemaker.core.remote_function.spark_config import SparkConfig -from sagemaker.core.remote_function.custom_file_filter import ( +from sagemaker.train.remote_function import logging_config +from sagemaker.train.remote_function.spark_config import SparkConfig +from sagemaker.train.remote_function.custom_file_filter import ( CustomFileFilter, copy_workdir, resolve_custom_file_filter_from_config_file, diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py index 2c20151ed1..f07e860cf9 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -31,7 +31,7 @@ get_logger, ) else: - from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( + from sagemaker.train.remote_function.runtime_environment.runtime_environment_manager import ( RuntimeEnvironmentManager, _DependencySettings, get_logger, diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py index f36e17a04c..c5d9f15ee2 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py @@ -28,7 +28,7 @@ get_logger, ) else: - from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( + from 
sagemaker.train.remote_function.runtime_environment.runtime_environment_manager import ( get_logger, ) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/spark_app.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/spark_app.py index 21eef068b9..6d4eaeb18e 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/spark_app.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/spark_app.py @@ -13,6 +13,6 @@ """This is a simple script of spark which invokes the pickled remote function""" from __future__ import absolute_import -from sagemaker.core.remote_function import invoke_function +from sagemaker.train.remote_function import invoke_function invoke_function.main() From e539879f64ac642d46e420abdc0d0b226d05b73e Mon Sep 17 00:00:00 2001 From: jzhaoqwa <52220743+zhaoqizqwang@users.noreply.github.com> Date: Wed, 17 Dec 2025 14:13:09 -0800 Subject: [PATCH 5/8] Remove duplicated base_serializer and base_deserializer files --- .../src/sagemaker/core/base_deserializers.py | 35 ------------- .../src/sagemaker/core/base_serializers.py | 35 ------------- .../tests/unit/test_base_deserializers.py | 52 ------------------- .../sagemaker/serve/model_builder_utils.py | 2 +- 4 files changed, 1 insertion(+), 123 deletions(-) delete mode 100644 sagemaker-core/src/sagemaker/core/base_deserializers.py delete mode 100644 sagemaker-core/src/sagemaker/core/base_serializers.py delete mode 100644 sagemaker-core/tests/unit/test_base_deserializers.py diff --git a/sagemaker-core/src/sagemaker/core/base_deserializers.py b/sagemaker-core/src/sagemaker/core/base_deserializers.py deleted file mode 100644 index 69c5be63e4..0000000000 --- a/sagemaker-core/src/sagemaker/core/base_deserializers.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Backward compatibility module for base deserializers. - -This module provides backward compatibility for code importing from -sagemaker.core.base_deserializers. The actual implementation is in -sagemaker.core.deserializers. - -.. deprecated:: 3.0.0 - Use :mod:`sagemaker.core.deserializers` instead. -""" -from __future__ import absolute_import - -import warnings - -# Re-export all deserializers from the correct location -from sagemaker.core.deserializers import * # noqa: F401, F403 - -# Issue deprecation warning -warnings.warn( - "Importing from sagemaker.core.base_deserializers is deprecated. " - "Use sagemaker.core.deserializers instead.", - DeprecationWarning, - stacklevel=2, -) diff --git a/sagemaker-core/src/sagemaker/core/base_serializers.py b/sagemaker-core/src/sagemaker/core/base_serializers.py deleted file mode 100644 index ea9a665866..0000000000 --- a/sagemaker-core/src/sagemaker/core/base_serializers.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License").
You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Backward compatibility module for base serializers. - -This module provides backward compatibility for code importing from -sagemaker.core.base_serializers. The actual implementation is in -sagemaker.core.serializers. - -.. deprecated:: 3.0.0 - Use :mod:`sagemaker.core.serializers` instead. -""" -from __future__ import absolute_import - -import warnings - -# Re-export all serializers from the correct location -from sagemaker.core.serializers import * # noqa: F401, F403 - -# Issue deprecation warning -warnings.warn( - "Importing from sagemaker.core.base_serializers is deprecated. " - "Use sagemaker.core.serializers instead.", - DeprecationWarning, - stacklevel=2, -) diff --git a/sagemaker-core/tests/unit/test_base_deserializers.py b/sagemaker-core/tests/unit/test_base_deserializers.py deleted file mode 100644 index b6c94370f0..0000000000 --- a/sagemaker-core/tests/unit/test_base_deserializers.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
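The two compatibility shims above make the migration mechanical: the old module paths still resolve through a star re-export but emit a ``DeprecationWarning`` on first import, and callers only need to switch module names, as the ``model_builder_utils.py`` hunk later in this patch does for ``JSONSerializer``. A hedged before/after sketch (the warning capture is standard library; the only project-specific names are the module paths shown in the shims):

.. code-block:: python

    import warnings

    # Old path: still works through the shim above, but warns on first import.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        from sagemaker.core.base_serializers import JSONSerializer  # noqa: F401

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # New path: the import this patch moves callers to.
    from sagemaker.core.serializers import JSONSerializer  # noqa: F401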
-from __future__ import absolute_import - -import pytest -import warnings - - -def test_base_deserializers_deprecation_warning(): - """Test that importing from base_deserializers raises DeprecationWarning.""" - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - # Import the module which should trigger the warning - import sagemaker.core.base_deserializers # noqa: F401 - - # Check that a warning was raised - assert len(w) >= 1 - - # Find the deprecation warning - deprecation_warnings = [ - warning for warning in w if issubclass(warning.category, DeprecationWarning) - ] - assert len(deprecation_warnings) >= 1 - - # Check the warning message - assert "base_deserializers is deprecated" in str(deprecation_warnings[0].message) - assert "sagemaker.core.deserializers" in str(deprecation_warnings[0].message) - - -def test_base_deserializers_imports_from_deserializers(): - """Test that base_deserializers re-exports from deserializers module.""" - import sagemaker.core.base_deserializers as base_deser - import sagemaker.core.deserializers as deser - - # Check that the modules have the same attributes - # (excluding private attributes and module-specific ones) - base_attrs = {attr for attr in dir(base_deser) if not attr.startswith("_")} - deser_attrs = {attr for attr in dir(deser) if not attr.startswith("_")} - - # base_deserializers should have at least the public attributes from deserializers - assert base_attrs.intersection(deser_attrs) == deser_attrs diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index 1c3016cf86..75cb927ca0 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -126,7 +126,7 @@ def build(self): from sagemaker.core.helper.pipeline_variable import PipelineVariable from sagemaker.core import model_uris from sagemaker.serve.utils.local_hardware import _get_available_gpus -from sagemaker.core.base_serializers import JSONSerializer +from sagemaker.core.serializers import JSONSerializer from sagemaker.core.deserializers import JSONDeserializer from sagemaker.serve.detector.pickler import save_pkl from sagemaker.serve.builder.requirements_manager import RequirementsManager From 46013fcb36902fb56015f58191fe5722906460b2 Mon Sep 17 00:00:00 2001 From: jzhaoqwa <52220743+zhaoqizqwang@users.noreply.github.com> Date: Wed, 17 Dec 2025 14:18:59 -0800 Subject: [PATCH 6/8] Remove deprecated primary_container No unit tests need to be updated --- .../sagemaker/core/helper/session_helper.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/helper/session_helper.py b/sagemaker-core/src/sagemaker/core/helper/session_helper.py index bc35c7eb9d..c202efd5c8 100644 --- a/sagemaker-core/src/sagemaker/core/helper/session_helper.py +++ b/sagemaker-core/src/sagemaker/core/helper/session_helper.py @@ -1669,22 +1669,10 @@ def _create_model_request( container_defs, vpc_config=None, enable_network_isolation=False, - primary_container=None, tags=None, ): # pylint: disable=redefined-outer-name """Placeholder docstring""" - if container_defs and primary_container: - raise ValueError("Both container_defs and primary_container can not be passed as input") - - if primary_container: - msg = ( - "primary_container is going to be deprecated in a future release. Please use " - "container_defs instead." 
- ) - warnings.warn(msg, DeprecationWarning) - container_defs = primary_container - role = self.expand_role(role) if isinstance(container_defs, list): @@ -1726,7 +1714,6 @@ def create_model( container_defs=None, vpc_config=None, enable_network_isolation=None, - primary_container=None, tags=None, ): """Create an Amazon SageMaker ``Model``. @@ -1754,11 +1741,6 @@ def create_model( * 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. enable_network_isolation (bool): Whether the model requires network isolation or not. - primary_container (str or dict[str, str]): Docker image which defines the inference - code. You can also specify the return value of ``sagemaker.container_def()``, - which is used to create more advanced container configurations, including model - containers which need artifacts from S3. This field is deprecated, please use - container_defs instead. tags(Optional[Tags]): Optional. The list of tags to add to the model. Example: @@ -1794,7 +1776,6 @@ def create_model( container_defs=container_defs, vpc_config=vpc_config, enable_network_isolation=enable_network_isolation, - primary_container=primary_container, tags=tags, ) From b45ea24de533e6f4c87f5e3edb53c1c4cb234a31 Mon Sep 17 00:00:00 2001 From: jzhaoqwa <52220743+zhaoqizqwang@users.noreply.github.com> Date: Wed, 17 Dec 2025 14:45:20 -0800 Subject: [PATCH 7/8] Remove the duplicate training modules in sagemaker-core --- .../sagemaker/core/modules/train/__init__.py | 14 - .../train/container_drivers/__init__.py | 14 - .../container_drivers/common/__init__.py | 14 - .../train/container_drivers/common/utils.py | 213 -------- .../distributed_drivers/__init__.py | 14 - .../basic_script_driver.py | 81 --- .../distributed_drivers/mpi_driver.py | 123 ----- .../distributed_drivers/mpi_utils.py | 302 ------------ .../distributed_drivers/torchrun_driver.py | 129 ----- .../container_drivers/scripts/__init__.py | 14 - .../container_drivers/scripts/environment.py | 305 ------------ .../core/modules/train/sm_recipes/__init__.py | 0 .../train/sm_recipes/training_recipes.json | 17 - .../core/modules/train/sm_recipes/utils.py | 330 ------------- .../tests/unit/modules/train/__init__.py | 12 - .../train/container_drivers/__init__.py | 0 .../distributed_drivers/__init__.py | 0 .../distributed_drivers/test_mpi_utils.py | 466 ------------------ .../unit/modules/train/test_environment.py | 386 --------------- .../modules/train/test_sm_recipes_utils.py | 436 ---------------- .../train/container_drivers/common/utils.py | 12 +- .../src/sagemaker/train/sm_recipes/utils.py | 32 +- 22 files changed, 25 insertions(+), 2889 deletions(-) delete mode 100644 sagemaker-core/src/sagemaker/core/modules/train/__init__.py delete mode 100644 sagemaker-core/src/sagemaker/core/modules/train/container_drivers/__init__.py delete mode 100644 sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/__init__.py delete mode 100644 sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/utils.py delete mode 100644 sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py delete mode 100644 sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py delete mode 100644 sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py delete mode 100644 sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py delete mode 100644 
sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py delete mode 100644 sagemaker-core/src/sagemaker/core/modules/train/container_drivers/scripts/__init__.py delete mode 100644 sagemaker-core/src/sagemaker/core/modules/train/container_drivers/scripts/environment.py delete mode 100644 sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/__init__.py delete mode 100644 sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/training_recipes.json delete mode 100644 sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/utils.py delete mode 100644 sagemaker-core/tests/unit/modules/train/__init__.py delete mode 100644 sagemaker-core/tests/unit/modules/train/container_drivers/__init__.py delete mode 100644 sagemaker-core/tests/unit/modules/train/container_drivers/distributed_drivers/__init__.py delete mode 100644 sagemaker-core/tests/unit/modules/train/container_drivers/distributed_drivers/test_mpi_utils.py delete mode 100644 sagemaker-core/tests/unit/modules/train/test_environment.py delete mode 100644 sagemaker-core/tests/unit/modules/train/test_sm_recipes_utils.py diff --git a/sagemaker-core/src/sagemaker/core/modules/train/__init__.py b/sagemaker-core/src/sagemaker/core/modules/train/__init__.py deleted file mode 100644 index c5b5d01ed4..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Sagemaker modules train directory.""" -from __future__ import absolute_import diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/__init__.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/__init__.py deleted file mode 100644 index 864f3663b8..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Sagemaker modules container drivers directory.""" -from __future__ import absolute_import diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/__init__.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/__init__.py deleted file mode 100644 index aab88c6b97..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Sagemaker modules container drivers - common directory.""" -from __future__ import absolute_import diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/utils.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/utils.py deleted file mode 100644 index c07aa1359a..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/utils.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""This module provides utility functions for the container drivers.""" -from __future__ import absolute_import - -import os -import logging -import sys -import subprocess -import traceback -import json - -from typing import List, Dict, Any, Tuple, IO, Optional - -# Initialize logger -SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20) -logger = logging.getLogger(__name__) -console_handler = logging.StreamHandler(sys.stdout) -logger.addHandler(console_handler) -logger.setLevel(int(SM_LOG_LEVEL)) - -FAILURE_FILE = "/opt/ml/output/failure" -DEFAULT_FAILURE_MESSAGE = """ -Training Execution failed. -For more details, see CloudWatch logs at 'aws/sagemaker/TrainingJobs'. 
-TrainingJob - {training_job_name} -""" - -USER_CODE_PATH = "/opt/ml/input/data/code" -SOURCE_CODE_JSON = "/opt/ml/input/data/sm_drivers/sourcecode.json" -DISTRIBUTED_JSON = "/opt/ml/input/data/sm_drivers/distributed.json" - -HYPERPARAMETERS_JSON = "/opt/ml/input/config/hyperparameters.json" - -SM_EFA_NCCL_INSTANCES = [ - "ml.g4dn.8xlarge", - "ml.g4dn.12xlarge", - "ml.g5.48xlarge", - "ml.p3dn.24xlarge", - "ml.p4d.24xlarge", - "ml.p4de.24xlarge", - "ml.p5.48xlarge", - "ml.trn1.32xlarge", -] - -SM_EFA_RDMA_INSTANCES = [ - "ml.p4d.24xlarge", - "ml.p4de.24xlarge", - "ml.trn1.32xlarge", -] - - -def write_failure_file(message: Optional[str] = None): - """Write a failure file with the message.""" - if message is None: - message = DEFAULT_FAILURE_MESSAGE.format(training_job_name=os.environ["TRAINING_JOB_NAME"]) - if not os.path.exists(FAILURE_FILE): - with open(FAILURE_FILE, "w") as f: - f.write(message) - - -def read_source_code_json(source_code_json: Dict[str, Any] = SOURCE_CODE_JSON): - """Read the source code config json file.""" - try: - with open(source_code_json, "r") as f: - source_code_dict = json.load(f) or {} - except FileNotFoundError: - source_code_dict = {} - return source_code_dict - - -def read_distributed_json(distributed_json: Dict[str, Any] = DISTRIBUTED_JSON): - """Read the distribution config json file.""" - try: - with open(distributed_json, "r") as f: - distributed_dict = json.load(f) or {} - except FileNotFoundError: - distributed_dict = {} - return distributed_dict - - -def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAMETERS_JSON): - """Read the hyperparameters config json file.""" - try: - with open(hyperparameters_json, "r") as f: - hyperparameters_dict = json.load(f) or {} - except FileNotFoundError: - hyperparameters_dict = {} - return hyperparameters_dict - - -def get_process_count(process_count: Optional[int] = None) -> int: - """Get the number of processes to run on each node in the training job.""" - return ( - process_count - or int(os.environ.get("SM_NUM_GPUS", 0)) - or int(os.environ.get("SM_NUM_NEURONS", 0)) - or 1 - ) - - -def hyperparameters_to_cli_args(hyperparameters: Dict[str, Any]) -> List[str]: - """Convert the hyperparameters to CLI arguments.""" - cli_args = [] - for key, value in hyperparameters.items(): - value = safe_deserialize(value) - cli_args.extend([f"--{key}", safe_serialize(value)]) - - return cli_args - - -def safe_deserialize(data: Any) -> Any: - """Safely deserialize data from a JSON string. - - This function handles the following cases: - 1. If `data` is not a string, it returns the input as-is. - 2. If `data` is a string and matches common boolean values ("true" or "false"), - it returns the corresponding boolean value (True or False). - 3. If `data` is a JSON-encoded string, it attempts to deserialize it using `json.loads()`. - 4. If `data` is a string but cannot be decoded as JSON, it returns the original string. - - Returns: - Any: The deserialized data, or the original input if it cannot be JSON-decoded. - """ - if not isinstance(data, str): - return data - - lower_data = data.lower() - if lower_data in ["true"]: - return True - if lower_data in ["false"]: - return False - - try: - return json.loads(data) - except json.JSONDecodeError: - return data - - -def safe_serialize(data): - """Serialize the data without wrapping strings in quotes. - - This function handles the following cases: - 1. If `data` is a string, it returns the string as-is without wrapping in quotes. - 2. 
If `data` is serializable (e.g., a dictionary, list, int, float), it returns - the JSON-encoded string using `json.dumps()`. - 3. If `data` cannot be serialized (e.g., a custom object), it returns the string - representation of the data using `str(data)`. - - Args: - data (Any): The data to serialize. - - Returns: - str: The serialized JSON-compatible string or the string representation of the input. - """ - if isinstance(data, str): - return data - try: - return json.dumps(data) - except TypeError: - return str(data) - - -def get_python_executable() -> str: - """Get the python executable path.""" - return sys.executable - - -def log_subprocess_output(pipe: IO[bytes]): - """Log the output from the subprocess.""" - for line in iter(pipe.readline, b""): - logger.info(line.decode("utf-8").strip()) - - -def execute_commands(commands: List[str]) -> Tuple[int, str]: - """Execute the provided commands and return exit code with failure traceback if any.""" - try: - process = subprocess.Popen( - commands, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - ) - with process.stdout: - log_subprocess_output(process.stdout) - exitcode = process.wait() - if exitcode != 0: - raise subprocess.CalledProcessError(exitcode, commands) - return exitcode, "" - except subprocess.CalledProcessError as e: - # Capture the traceback in case of failure - error_traceback = traceback.format_exc() - print(f"Command failed with exit code {e.returncode}. Traceback: {error_traceback}") - return e.returncode, error_traceback - - -def is_worker_node() -> bool: - """Check if the current node is a worker node.""" - return os.environ.get("SM_CURRENT_HOST") != os.environ.get("SM_MASTER_ADDR") - - -def is_master_node() -> bool: - """Check if the current node is the master node.""" - return os.environ.get("SM_CURRENT_HOST") == os.environ.get("SM_MASTER_ADDR") diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py deleted file mode 100644 index a44e7e81a9..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Sagemaker modules container drivers - drivers directory.""" -from __future__ import absolute_import diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py deleted file mode 100644 index 0b086a8e4f..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). 
You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""This module is the entry point for the Basic Script Driver.""" -from __future__ import absolute_import - -import os -import sys -import json -import shlex - -from pathlib import Path -from typing import List - -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 - logger, - get_python_executable, - execute_commands, - write_failure_file, - hyperparameters_to_cli_args, -) - - -def create_commands() -> List[str]: - """Create the commands to execute.""" - entry_script = os.environ["SM_ENTRY_SCRIPT"] - hyperparameters = json.loads(os.environ["SM_HPS"]) - python_executable = get_python_executable() - - args = hyperparameters_to_cli_args(hyperparameters) - if entry_script.endswith(".py"): - commands = [python_executable, entry_script] - commands += args - elif entry_script.endswith(".sh"): - args_str = " ".join(shlex.quote(arg) for arg in args) - commands = [ - "/bin/sh", - "-c", - f"chmod +x {entry_script} && ./{entry_script} {args_str}", - ] - else: - raise ValueError( - f"Unsupported entry script type: {entry_script}. Only .py and .sh are supported." - ) - return commands - - -def main(): - """Main function for the Basic Script Driver. - - This function is the entry point for the Basic Script Driver. - - Execution Lifecycle: - 1. Read the source code and hyperparameters JSON files. - 2. Set hyperparameters as command line arguments. - 3. Create the commands to execute. - 4. Execute the commands. - """ - - cmd = create_commands() - - logger.info(f"Executing command: {' '.join(cmd)}") - exit_code, traceback = execute_commands(cmd) - if exit_code != 0: - write_failure_file(traceback) - sys.exit(exit_code) - - -if __name__ == "__main__": - main() diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py deleted file mode 100644 index 7d991e30da..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
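# ---------------------------------------------------------------------------
# [Editor's illustration -- not part of the patch] A self-contained sketch of
# how the basic script driver deleted above turns hyperparameters into CLI
# arguments before invoking the entry script. `to_cli_args` is a simplified
# stand-in for the driver's hyperparameters_to_cli_args helper; the
# hyperparameter names/values and the entry script "train.py" are hypothetical.
import json


def to_cli_args(hyperparameters):
    """Round-trip values through JSON so '10' -> 10 -> '10' and 'true' -> 'true'."""
    args = []
    for key, value in hyperparameters.items():
        try:
            value = json.loads(value)  # deserialize, as safe_deserialize does
        except (json.JSONDecodeError, TypeError):
            pass  # keep non-JSON strings as-is
        args.extend([f"--{key}", value if isinstance(value, str) else json.dumps(value)])
    return args


assert to_cli_args({"epochs": "10", "use_fp16": "true"}) == [
    "--epochs", "10", "--use_fp16", "true"
]
# For entry_script="train.py" the driver would then execute:
#   [sys.executable, "train.py", "--epochs", "10", "--use_fp16", "true"]
# ---------------------------------------------------------------------------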
-"""This module is the entry point for the MPI driver script."""
-from __future__ import absolute_import
-
-import os
-import sys
-import json
-from pathlib import Path
-
-try:
- from mpi_utils import (
- start_sshd_daemon,
- bootstrap_master_node,
- bootstrap_worker_node,
- get_mpirun_command,
- write_status_file_to_workers,
- write_env_vars_to_file,
- )
-except ImportError:
- # mpi_utils is an optional external dependency for MPI distributed training
- # If not available, provide stub functions that raise helpful errors
- def _mpi_not_available(*args, **kwargs):
- raise ImportError(
- "MPI distributed training requires the 'mpi_utils' package. "
- "Please install it to use MPI-based distributed training."
- )
-
- start_sshd_daemon = _mpi_not_available
- bootstrap_master_node = _mpi_not_available
- bootstrap_worker_node = _mpi_not_available
- get_mpirun_command = _mpi_not_available
- write_status_file_to_workers = _mpi_not_available
- write_env_vars_to_file = _mpi_not_available
-
-
-sys.path.insert(0, str(Path(__file__).parent.parent))
-from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611
- logger,
- hyperparameters_to_cli_args,
- get_process_count,
- execute_commands,
- write_failure_file,
-)
-
-
-def main():
- """Main function for the MPI driver script.
-
- The MPI Driver is responsible for setting up the MPI environment,
- generating the correct MPI command, and launching the MPI job.
-
- Execution Lifecycle:
- 1. Setup General Environment Variables at /etc/environment
- 2. Start SSHD Daemon
- 3. Bootstrap Worker Nodes
- a. Wait to establish connection with Master Node
- b. Wait for Master Node to write status file
- 4. Bootstrap Master Node
- a. Wait to establish connection with Worker Nodes
- b. Generate MPI Command
- c. Execute MPI Command with user script provided in `entry_script`
- d. Write status file to Worker Nodes
- 5. 
Exit - - """ - entry_script = os.environ["SM_ENTRY_SCRIPT"] - distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) - hyperparameters = json.loads(os.environ["SM_HPS"]) - - sm_current_host = os.environ["SM_CURRENT_HOST"] - sm_hosts = json.loads(os.environ["SM_HOSTS"]) - sm_master_addr = os.environ["SM_MASTER_ADDR"] - - write_env_vars_to_file() - start_sshd_daemon() - - if sm_current_host != sm_master_addr: - bootstrap_worker_node(sm_master_addr) - else: - worker_hosts = [host for host in sm_hosts if host != sm_master_addr] - bootstrap_master_node(worker_hosts) - - host_list = json.loads(os.environ["SM_HOSTS"]) - host_count = int(os.environ["SM_HOST_COUNT"]) - process_count = int(distributed_config["process_count_per_node"] or 0) - process_count = get_process_count(process_count) - - if process_count > 1: - host_list = ["{}:{}".format(host, process_count) for host in host_list] - - mpi_command = get_mpirun_command( - host_count=host_count, - host_list=host_list, - num_processes=process_count, - additional_options=distributed_config["mpi_additional_options"] or [], - entry_script_path=entry_script, - ) - - args = hyperparameters_to_cli_args(hyperparameters) - mpi_command += args - - logger.info(f"Executing command: {' '.join(mpi_command)}") - exit_code, error_traceback = execute_commands(mpi_command) - write_status_file_to_workers(worker_hosts) - - if exit_code != 0: - write_failure_file(error_traceback) - sys.exit(exit_code) - - -if __name__ == "__main__": - main() diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py deleted file mode 100644 index ec9e1fcef9..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
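# ---------------------------------------------------------------------------
# [Editor's illustration -- not part of the patch] The master/worker split the
# MPI driver above performs, shown standalone. Hostnames and the per-node
# process count are hypothetical.
def split_hosts(hosts, master_addr, process_count):
    # Workers are all hosts except the master; with more than one process per
    # node, each entry becomes "<host>:<slots>" for mpirun's --host flag.
    workers = [host for host in hosts if host != master_addr]
    if process_count > 1:
        hosts = [f"{host}:{process_count}" for host in hosts]
    return workers, hosts


workers, host_list = split_hosts(["algo-1", "algo-2"], "algo-1", 8)
assert workers == ["algo-2"]
assert host_list == ["algo-1:8", "algo-2:8"]
# ---------------------------------------------------------------------------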
-"""This module provides MPI-related utility functions for the container drivers."""
-from __future__ import absolute_import
-
-import os
-import sys
-import subprocess
-import time
-
-from pathlib import Path
-from typing import List
-
-import paramiko
-
-sys.path.insert(0, str(Path(__file__).parent.parent))
-
-from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611
- SM_EFA_NCCL_INSTANCES,
- SM_EFA_RDMA_INSTANCES,
- get_python_executable,
- logger,
-)
-
-FINISHED_STATUS_FILE = "/tmp/done.algo-1"
-READY_FILE = "/tmp/ready.%s"
-DEFAULT_SSH_PORT = 22
-
-
-def _write_file_to_host(host: str, status_file: str) -> bool:
- """Write a file to the provided host."""
- try:
- logger.info(f"Writing {status_file} to {host}")
- subprocess.run(
- ["ssh", host, "touch", f"{status_file}"],
- capture_output=True,
- text=True,
- check=True,
- )
- logger.info("Finished writing status file")
- return True
- except subprocess.CalledProcessError:
- logger.info(f"Cannot connect to {host}")
- return False
-
-
-def write_status_file_to_workers(worker_hosts: List[str], status_file: str = FINISHED_STATUS_FILE):
- """Write the status file to all worker nodes."""
- for worker in worker_hosts:
- retry = 0
- while not _write_file_to_host(worker, status_file):
- time.sleep(5)
- retry += 1
- if retry > 5:
- raise TimeoutError(f"Timed out waiting for {worker} to be reachable.")
- logger.info(f"Retrying to write status file to {worker}")
-
-
-def _wait_for_status_file(status_file: str):
- """Wait for the status file to be created."""
- logger.info(f"Waiting for status file {status_file}")
- while not os.path.exists(status_file):
- time.sleep(30)
- logger.info(f"Found status file {status_file}")
-
-
-def start_sshd_daemon():
- """Start the SSH daemon on the current node."""
- sshd_executable = "/usr/sbin/sshd"
-
- if not os.path.exists(sshd_executable):
- raise RuntimeError("SSH daemon not found.")
-
- # Start the sshd in daemon mode (-D)
- subprocess.Popen([sshd_executable, "-D"])
- logger.info("Started SSH daemon.")
-
-
-class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy):
- """Class to handle host key policy for SageMaker distributed training SSH connections.
-
- Example:
- >>> client = paramiko.SSHClient()
- >>> client.set_missing_host_key_policy(CustomHostKeyPolicy())
- >>> # Will succeed for SageMaker algorithm containers
- >>> client.connect('algo-1234.internal')
- >>> # Will raise SSHException for other unknown hosts
- >>> client.connect('unknown-host') # raises SSHException
- """
-
- def missing_host_key(self, client, hostname, key):
- """Accept host keys for algo-* hostnames, reject others. 
- - Args: - client: The SSHClient instance - hostname: The hostname attempting to connect - key: The host key - - Raises: - paramiko.SSHException: If hostname doesn't match algo-* pattern - """ - if hostname.startswith("algo-"): - client.get_host_keys().add(hostname, key.get_name(), key) - return - raise paramiko.SSHException(f"Unknown host key for {hostname}") - - -def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool: - """Check if the connection to the provided host and port is possible.""" - try: - logger.debug("Testing connection to host %s", host) - with paramiko.SSHClient() as client: - client.load_system_host_keys() - client.set_missing_host_key_policy(CustomHostKeyPolicy()) - client.connect(host, port=port) - logger.info("Can connect to host %s", host) - return True - except Exception as e: # pylint: disable=W0703 - logger.info("Cannot connect to host %s", host) - logger.debug(f"Connection failed with exception: {e}") - return False - - -def _wait_for_workers(worker_hosts: List[str], port: int = DEFAULT_SSH_PORT, timeout: int = 300): - """Master node waits until it can connect to all worker nodes.""" - start_time = time.time() - if not worker_hosts: - logger.info("No worker nodes to connect to.") - return - - while True: - logger.info("Master is attempting to connect to all workers...") - all_workers_connected = all( - _can_connect(worker, port) and os.path.exists(READY_FILE % worker) - for worker in worker_hosts - ) - - if all_workers_connected: - logger.info("Master can connect to all worker nodes.") - break - if time.time() - start_time > timeout: - raise TimeoutError("Timed out waiting for workers to be reachable.") - - time.sleep(5) # Wait for 5 seconds before trying again - - -def _wait_for_master(master_host: str, port: int = DEFAULT_SSH_PORT, timeout: int = 300): - """Worker nodes wait until they can connect to the master node.""" - start_time = time.time() - while True: - logger.info(f"Worker is attempting to connect to the master node {master_host}...") - if _can_connect(master_host, port): - logger.info(f"Worker can connect to master node {master_host}.") - break - if time.time() - start_time > timeout: - raise TimeoutError(f"Timed out waiting for master {master_host} to be reachable.") - - time.sleep(5) # Wait for 5 seconds before trying again - - -def bootstrap_worker_node(master_host: str, status_file: str = FINISHED_STATUS_FILE): - """Bootstrap the worker nodes.""" - logger.info("Bootstrapping worker node...") - _wait_for_master(master_host) - _write_file_to_host(master_host, READY_FILE % os.environ["SM_CURRENT_HOST"]) - _wait_for_status_file(status_file) - - -def bootstrap_master_node(worker_hosts: List[str]): - """Bootstrap the master node.""" - logger.info("Bootstrapping master node...") - _wait_for_workers(worker_hosts) - - -def validate_smddprun() -> bool: - """Whether smddprun is installed. - - Returns: - bool: True if installed - """ - try: - output = subprocess.run( - ["which", "smddprun"], - capture_output=True, - text=True, - check=True, - ) - return output.stdout != "" - except subprocess.CalledProcessError: - return False - - -def validate_smddpmprun() -> bool: - """Whether smddpmprun is installed. 
- 
- Returns:
- bool: True if installed
- """
- try:
- output = subprocess.run(
- ["which", "smddpmprun"],
- capture_output=True,
- text=True,
- check=True,
- )
- return output.stdout != ""
- except subprocess.CalledProcessError:
- return False
-
-
-def write_env_vars_to_file():
- """Write environment variables to /etc/environment file."""
- with open("/etc/environment", "a", encoding="utf-8") as f:
- for name in os.environ:
- f.write(f"{name}={os.environ.get(name)}\n")
-
-
-def get_mpirun_command(
- host_count: int,
- host_list: List[str],
- num_processes: int,
- additional_options: List[str],
- entry_script_path: str,
-):
- """Fetch the mpirun command."""
- network_interface_name = os.environ.get("SM_NETWORK_INTERFACE_NAME", "eth0")
-
- mpirun_command = [
- "mpirun",
- "--host",
- ",".join(host_list),
- "-np",
- str(num_processes),
- "--allow-run-as-root",
- "--tag-output",
- "-mca",
- "btl_tcp_if_include",
- network_interface_name,
- "-mca",
- "oob_tcp_if_include",
- network_interface_name,
- "-mca",
- "plm_rsh_no_tree_spawn",
- "1",
- "-mca",
- "pml",
- "ob1",
- "-mca",
- "btl",
- "^openib",
- "-mca",
- "orte_abort_on_non_zero_status",
- "1",
- "-mca",
- "btl_vader_single_copy_mechanism",
- "none",
- "-mca",
- "plm_rsh_num_concurrent",
- str(host_count),
- "-x",
- "NCCL_SOCKET_IFNAME=%s" % network_interface_name,
- "-x",
- "LD_LIBRARY_PATH",
- "-x",
- "PATH",
- ]
-
- if additional_options:
- mpirun_command.extend(additional_options)
-
- instance_type = os.environ["SM_CURRENT_INSTANCE_TYPE"]
- # EFA settings
- if instance_type in SM_EFA_NCCL_INSTANCES:
- mpirun_command.extend(["-x", "FI_PROVIDER=efa"])
- # Use simple protocol to handle the out-of-order data delivery from EFA
- mpirun_command.extend(["-x", "NCCL_PROTO=simple"])
-
- if instance_type in SM_EFA_RDMA_INSTANCES:
- # Use EFA's RDMA functionality for one-sided and two-sided transfer
- mpirun_command.extend(["-x", "FI_EFA_USE_DEVICE_RDMA=1"])
-
- for credential in [
- "AWS_ACCESS_KEY_ID",
- "AWS_SECRET_ACCESS_KEY",
- "AWS_SESSION_TOKEN",
- ]:
- if credential in os.environ:
- mpirun_command.extend(["-x", credential])
-
- mpirun_command.extend([get_python_executable()])
- mpirun_command.extend(["-m", "mpi4py", entry_script_path])
- return mpirun_command
diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py
deleted file mode 100644
index 7fcfabe05d..0000000000
--- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py
+++ /dev/null
@@ -1,129 +0,0 @@
-# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License. 
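# ---------------------------------------------------------------------------
# [Editor's illustration -- not part of the patch] For a hypothetical two-node
# job (hosts algo-1/algo-2, 8 processes per node, interface eth0, entry script
# train.py), get_mpirun_command above produces a command of roughly this shape
# (MCA flags abridged; the EFA and AWS-credential "-x" flags are appended only
# when applicable):
#
#   mpirun --host algo-1:8,algo-2:8 -np <num_processes> \
#       --allow-run-as-root --tag-output \
#       -mca btl_tcp_if_include eth0 -mca oob_tcp_if_include eth0 ... \
#       -x NCCL_SOCKET_IFNAME=eth0 -x LD_LIBRARY_PATH -x PATH \
#       <python> -m mpi4py train.py
# ---------------------------------------------------------------------------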
-"""This module is the entry point for the Torchrun driver script.""" -from __future__ import absolute_import - -import os -import sys -import json - -from pathlib import Path -from typing import List, Tuple - -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 - logger, - hyperparameters_to_cli_args, - get_process_count, - get_python_executable, - execute_commands, - write_failure_file, - SM_EFA_NCCL_INSTANCES, - SM_EFA_RDMA_INSTANCES, -) - - -def pytorch_version() -> Tuple[int, int]: - """Get the PyTorch version as a tuple of integers.""" - import torch - - return tuple(map(int, torch.__version__.split(".")[:2])) - - -def get_base_pytorch_command() -> List[str]: - """Get the base Torch Distributed launcher to execute""" - if pytorch_version() >= (1, 9): - return ["torchrun"] - return [f"{get_python_executable()}", "-m", "torch.distributed.launch"] - - -def setup_env(): - """Setup the environment variables for PyTorch distributed training""" - instance_type = os.environ["SM_CURRENT_INSTANCE_TYPE"] - network_interface_name = os.environ.get("SM_NETWORK_INTERFACE_NAME", "eth0") - if instance_type in SM_EFA_NCCL_INSTANCES: - # Enable EFA use - os.environ["FI_PROVIDER"] = "efa" - if instance_type in SM_EFA_RDMA_INSTANCES: - # Use EFA's RDMA functionality for one-sided and two-sided transfer - os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1" - os.environ["RDMAV_FORK_SAFE"] = "1" - os.environ["NCCL_SOCKET_IFNAME"] = str(network_interface_name) - os.environ["NCCL_PROTO"] = "simple" - - -def create_commands(): - """Create the Torch Distributed command to execute""" - entry_script = os.environ["SM_ENTRY_SCRIPT"] - distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) - hyperparameters = json.loads(os.environ["SM_HPS"]) - - process_count = int(distributed_config["process_count_per_node"] or 0) - process_count = get_process_count(process_count) - host_count = int(os.environ["SM_HOST_COUNT"]) - - torch_cmd = [] - if os.environ.get("RUN_NEURON_PARALLEL_COMPILE") == "1": - torch_cmd.append("neuron_parallel_compile") - - torch_cmd.extend(get_base_pytorch_command()) - torch_cmd.extend( - [ - f"--nnodes={host_count}", - f"--nproc_per_node={process_count}", - ] - ) - - # If more than one node is used, add node rank information - if int(host_count) > 1: - torch_cmd.extend( - [ - f"--master_addr={os.environ['SM_MASTER_ADDR']}", - f"--master_port={os.environ['SM_MASTER_PORT']}", - f"--node_rank={os.environ['SM_CURRENT_HOST_RANK']}", - ] - ) - - torch_cmd.extend([entry_script]) - - args = hyperparameters_to_cli_args(hyperparameters) - torch_cmd += args - - return torch_cmd - - -def main(): - """Main function to execute the PyTorch distributed training script. - - This function sets some environment variables and executes the PyTorch - distributed training script. - - Execution Lifecycle: - 1. Setup Environment Variables for PyTorch Distributed Training - 2. Create Torch Distributed Command - 3. Execute Torch Distributed Command with user script provided in `entry_script` - 4. 
Exit - - """ - setup_env() - torch_cmd = create_commands() - logger.info(f"Executing command: {' '.join(torch_cmd)}") - exit_code, traceback = execute_commands(torch_cmd) - if exit_code != 0: - write_failure_file(traceback) - sys.exit(exit_code) - - -if __name__ == "__main__": - main() diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/scripts/__init__.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/scripts/__init__.py deleted file mode 100644 index f04c5b17a0..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/scripts/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Sagemaker modules container drivers - scripts directory.""" -from __future__ import absolute_import diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/scripts/environment.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/scripts/environment.py deleted file mode 100644 index 897b1f8af4..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/scripts/environment.py +++ /dev/null @@ -1,305 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
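# ---------------------------------------------------------------------------
# [Editor's illustration -- not part of the patch] The command the torchrun
# driver above assembles for a hypothetical 2-node, 8-process-per-node job on
# PyTorch >= 1.9; single-node jobs omit the rendezvous flags, and the entry
# script plus hyperparameter args are appended last:
#
#   torchrun --nnodes=2 --nproc_per_node=8 \
#       --master_addr=algo-1 --master_port=7777 --node_rank=0 \
#       train.py --epochs 10
# ---------------------------------------------------------------------------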
-"""This module is used to define the environment variables for the training job container.""" -from __future__ import absolute_import - -from typing import Dict, Any -import multiprocessing -import subprocess -import json -import os -import sys -from pathlib import Path -import logging - -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 - safe_serialize, - safe_deserialize, - read_distributed_json, - read_source_code_json, -) - -# Initialize logger -SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20) -logger = logging.getLogger(__name__) -console_handler = logging.StreamHandler(sys.stdout) -logger.addHandler(console_handler) -logger.setLevel(int(SM_LOG_LEVEL)) - -SM_MODEL_DIR = "/opt/ml/model" - -SM_INPUT_DIR = "/opt/ml/input" -SM_INPUT_DATA_DIR = "/opt/ml/input/data" -SM_INPUT_CONFIG_DIR = "/opt/ml/input/config" - -SM_OUTPUT_DIR = "/opt/ml/output" -SM_OUTPUT_FAILURE = "/opt/ml/output/failure" -SM_OUTPUT_DATA_DIR = "/opt/ml/output/data" -SM_SOURCE_DIR_PATH = "/opt/ml/input/data/code" -SM_DISTRIBUTED_DRIVER_DIR_PATH = "/opt/ml/input/data/sm_drivers/distributed_drivers" - -SM_MASTER_ADDR = "algo-1" -SM_MASTER_PORT = 7777 - -RESOURCE_CONFIG = f"{SM_INPUT_CONFIG_DIR}/resourceconfig.json" -INPUT_DATA_CONFIG = f"{SM_INPUT_CONFIG_DIR}/inputdataconfig.json" -HYPERPARAMETERS_CONFIG = f"{SM_INPUT_CONFIG_DIR}/hyperparameters.json" - -ENV_OUTPUT_FILE = "/opt/ml/input/sm_training.env" - -SENSITIVE_KEYWORDS = ["SECRET", "PASSWORD", "KEY", "TOKEN", "PRIVATE", "CREDS", "CREDENTIALS"] -HIDDEN_VALUE = "******" - - -def num_cpus() -> int: - """Return the number of CPUs available in the current container. - - Returns: - int: Number of CPUs available in the current container. - """ - return multiprocessing.cpu_count() - - -def num_gpus() -> int: - """Return the number of GPUs available in the current container. - - Returns: - int: Number of GPUs available in the current container. - """ - try: - cmd = ["nvidia-smi", "--list-gpus"] - output = subprocess.check_output(cmd).decode("utf-8") - return sum(1 for line in output.splitlines() if line.startswith("GPU ")) - except (OSError, subprocess.CalledProcessError): - logger.info("No GPUs detected (normal if no gpus installed)") - return 0 - - -def num_neurons() -> int: - """Return the number of neuron cores available in the current container. - - Returns: - int: Number of Neuron Cores available in the current container. - """ - try: - cmd = ["neuron-ls", "-j"] - output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8") - j = json.loads(output) - neuron_cores = 0 - for item in j: - neuron_cores += item.get("nc_count", 0) - logger.info("Found %s neurons on this instance", neuron_cores) - return neuron_cores - except OSError: - logger.info("No Neurons detected (normal if no neurons installed)") - return 0 - except subprocess.CalledProcessError as e: - if e.output is not None: - try: - msg = e.output.decode("utf-8").partition("error=")[2] - logger.info( - "No Neurons detected (normal if no neurons installed). \ - If neuron installed then %s", - msg, - ) - except AttributeError: - logger.info("No Neurons detected (normal if no neurons installed)") - else: - logger.info("No Neurons detected (normal if no neurons installed)") - - return 0 - - -def deserialize_hyperparameters(hyperparameters: Dict[str, str]) -> Dict[str, Any]: - """Deserialize hyperparameters from string to their original types. - - Args: - hyperparameters (Dict[str, str]): Hyperparameters as strings. 
- - Returns: - Dict[str, Any]: Hyperparameters as their original types. - """ - deserialized_hyperparameters = {} - for key, value in hyperparameters.items(): - deserialized_hyperparameters[key] = safe_deserialize(value) - return deserialized_hyperparameters - - -def set_env( - resource_config: Dict[str, Any], - input_data_config: Dict[str, Any], - hyperparameters_config: Dict[str, Any], - output_file: str = ENV_OUTPUT_FILE, -): - """Set environment variables for the training job container. - - Args: - resource_config (Dict[str, Any]): Resource configuration for the training job. - input_data_config (Dict[str, Any]): Input data configuration for the training job. - hyperparameters_config (Dict[str, Any]): Hyperparameters configuration for the training job. - output_file (str): Output file to write the environment variables. - """ - # Constants - env_vars = { - "SM_MODEL_DIR": SM_MODEL_DIR, - "SM_INPUT_DIR": SM_INPUT_DIR, - "SM_INPUT_DATA_DIR": SM_INPUT_DATA_DIR, - "SM_INPUT_CONFIG_DIR": SM_INPUT_CONFIG_DIR, - "SM_OUTPUT_DIR": SM_OUTPUT_DIR, - "SM_OUTPUT_FAILURE": SM_OUTPUT_FAILURE, - "SM_OUTPUT_DATA_DIR": SM_OUTPUT_DATA_DIR, - "SM_LOG_LEVEL": SM_LOG_LEVEL, - "SM_MASTER_ADDR": SM_MASTER_ADDR, - "SM_MASTER_PORT": SM_MASTER_PORT, - } - - # SourceCode and DistributedConfig Environment Variables - source_code = read_source_code_json() - if source_code: - env_vars["SM_SOURCE_DIR"] = SM_SOURCE_DIR_PATH - env_vars["SM_ENTRY_SCRIPT"] = source_code.get("entry_script", "") - - distributed = read_distributed_json() - if distributed: - env_vars["SM_DISTRIBUTED_DRIVER_DIR"] = SM_DISTRIBUTED_DRIVER_DIR_PATH - env_vars["SM_DISTRIBUTED_CONFIG"] = distributed - - # Data Channels - channels = list(input_data_config.keys()) - for channel in channels: - env_vars[f"SM_CHANNEL_{channel.upper()}"] = f"{SM_INPUT_DATA_DIR}/{channel}" - env_vars["SM_CHANNELS"] = channels - - # Hyperparameters - hps = deserialize_hyperparameters(hyperparameters_config) - for key, value in hps.items(): - key_upper = key.replace("-", "_").upper() - env_vars[f"SM_HP_{key_upper}"] = value - env_vars["SM_HPS"] = hps - - # Host Variables - current_host = resource_config["current_host"] - current_instance_type = resource_config["current_instance_type"] - hosts = resource_config["hosts"] - sorted_hosts = sorted(hosts) - - env_vars["SM_CURRENT_HOST"] = current_host - env_vars["SM_CURRENT_INSTANCE_TYPE"] = current_instance_type - env_vars["SM_HOSTS"] = sorted_hosts - env_vars["SM_NETWORK_INTERFACE_NAME"] = resource_config["network_interface_name"] - env_vars["SM_HOST_COUNT"] = len(sorted_hosts) - env_vars["SM_CURRENT_HOST_RANK"] = sorted_hosts.index(current_host) - - env_vars["SM_NUM_CPUS"] = num_cpus() - env_vars["SM_NUM_GPUS"] = num_gpus() - env_vars["SM_NUM_NEURONS"] = num_neurons() - - # Misc. 
- env_vars["SM_RESOURCE_CONFIG"] = resource_config - env_vars["SM_INPUT_DATA_CONFIG"] = input_data_config - - # All Training Environment Variables - env_vars["SM_TRAINING_ENV"] = { - "channel_input_dirs": { - channel: env_vars[f"SM_CHANNEL_{channel.upper()}"] for channel in channels - }, - "current_host": env_vars["SM_CURRENT_HOST"], - "current_instance_type": env_vars["SM_CURRENT_INSTANCE_TYPE"], - "hosts": env_vars["SM_HOSTS"], - "master_addr": env_vars["SM_MASTER_ADDR"], - "master_port": env_vars["SM_MASTER_PORT"], - "hyperparameters": env_vars["SM_HPS"], - "input_data_config": input_data_config, - "input_config_dir": env_vars["SM_INPUT_CONFIG_DIR"], - "input_data_dir": env_vars["SM_INPUT_DATA_DIR"], - "input_dir": env_vars["SM_INPUT_DIR"], - "job_name": os.environ["TRAINING_JOB_NAME"], - "log_level": env_vars["SM_LOG_LEVEL"], - "model_dir": env_vars["SM_MODEL_DIR"], - "network_interface_name": env_vars["SM_NETWORK_INTERFACE_NAME"], - "num_cpus": env_vars["SM_NUM_CPUS"], - "num_gpus": env_vars["SM_NUM_GPUS"], - "num_neurons": env_vars["SM_NUM_NEURONS"], - "output_data_dir": env_vars["SM_OUTPUT_DATA_DIR"], - "resource_config": env_vars["SM_RESOURCE_CONFIG"], - } - with open(output_file, "w") as f: - for key, value in env_vars.items(): - f.write(f"export {key}='{safe_serialize(value)}'\n") - - logger.info("Environment Variables:") - log_env_variables(env_vars_dict=env_vars) - - -def mask_sensitive_info(data): - """Recursively mask sensitive information in a dictionary.""" - if isinstance(data, dict): - for k, v in data.items(): - if isinstance(v, dict): - data[k] = mask_sensitive_info(v) - elif isinstance(v, str) and any( - keyword.lower() in k.lower() for keyword in SENSITIVE_KEYWORDS - ): - data[k] = HIDDEN_VALUE - return data - - -def log_key_value(key: str, value: str): - """Log a key-value pair, masking sensitive values if necessary.""" - if any(keyword.lower() in key.lower() for keyword in SENSITIVE_KEYWORDS): - logger.info("%s=%s", key, HIDDEN_VALUE) - elif isinstance(value, dict): - masked_value = mask_sensitive_info(value) - logger.info("%s=%s", key, json.dumps(masked_value)) - else: - try: - decoded_value = json.loads(value) - if isinstance(decoded_value, dict): - masked_value = mask_sensitive_info(decoded_value) - logger.info("%s=%s", key, json.dumps(masked_value)) - else: - logger.info("%s=%s", key, decoded_value) - except (json.JSONDecodeError, TypeError): - logger.info("%s=%s", key, value) - - -def log_env_variables(env_vars_dict: Dict[str, Any]): - """Log Environment Variables from the environment and an env_vars_dict.""" - for key, value in os.environ.items(): - log_key_value(key, value) - - for key, value in env_vars_dict.items(): - log_key_value(key, value) - - -def main(): - """Main function to set the environment variables for the training job container.""" - with open(RESOURCE_CONFIG, "r") as f: - resource_config = json.load(f) - with open(INPUT_DATA_CONFIG, "r") as f: - input_data_config = json.load(f) - with open(HYPERPARAMETERS_CONFIG, "r") as f: - hyperparameters_config = json.load(f) - - set_env( - resource_config=resource_config, - input_data_config=input_data_config, - hyperparameters_config=hyperparameters_config, - output_file=ENV_OUTPUT_FILE, - ) - - -if __name__ == "__main__": - main() diff --git a/sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/__init__.py b/sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git 
a/sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/training_recipes.json b/sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/training_recipes.json deleted file mode 100644 index a51513f49f..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/training_recipes.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "adapter_repo": "https://github.com/aws/sagemaker-training-adapter-for-nemo.git", - "launcher_repo": "https://github.com/aws/sagemaker-hyperpod-recipes.git", - "neuron_dist_repo": "https://github.com/aws-neuron/neuronx-distributed-training.git", - "gpu_image" : { - "framework": "pytorch-smp", - "version": "2.4.1", - "additional_args": { - "container_version": "cu121" - } - }, - "neuron_image": { - "framework": "hyperpod-recipes-neuron", - "version": "2.1.2", - "additional_args": {} - } -} \ No newline at end of file diff --git a/sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/utils.py b/sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/utils.py deleted file mode 100644 index 1eb0b83e97..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/utils.py +++ /dev/null @@ -1,330 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Utility functions for SageMaker training recipes.""" -from __future__ import absolute_import - -import math -import os -import json -import shutil -import tempfile -from urllib.request import urlretrieve -from typing import Dict, Any, Optional, Tuple - -import omegaconf -from omegaconf import OmegaConf, dictconfig - -from sagemaker.core.image_uris import retrieve - -from sagemaker.core.modules import logger -from sagemaker.core.modules.utils import _run_clone_command_silent -from sagemaker.core.modules.configs import Compute, SourceCode -from sagemaker.core.modules.distributed import Torchrun, SMP - - -def _try_resolve_recipe(recipe, key=None): - """Try to resolve recipe and return resolved recipe.""" - if key is not None: - recipe = dictconfig.DictConfig({key: recipe}) - try: - OmegaConf.resolve(recipe) - except omegaconf.errors.OmegaConfBaseException: - return None - if key is None: - return recipe - return recipe[key] - - -def _determine_device_type(instance_type: str) -> str: - """Determine device type (gpu, cpu, trainium) based on instance type.""" - instance_family = instance_type.split(".")[1] - if instance_family.startswith(("p", "g")): - return "gpu" - if instance_family.startswith("trn"): - return "trainium" - return "cpu" - - -def _load_recipes_cfg() -> str: - """Load training recipes configuration json.""" - training_recipes_cfg_filename = os.path.join(os.path.dirname(__file__), "training_recipes.json") - with open(training_recipes_cfg_filename) as training_recipes_cfg_file: - training_recipes_cfg = json.load(training_recipes_cfg_file) - return training_recipes_cfg - - -def _load_base_recipe( - training_recipe: str, - recipe_overrides: Optional[Dict[str, Any]] = None, - training_recipes_cfg: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: - """Load recipe 
and apply overrides.""" - if recipe_overrides is None: - recipe_overrides = dict() - - temp_local_recipe = tempfile.NamedTemporaryFile(prefix="recipe_original", suffix=".yaml").name - - if training_recipe.endswith(".yaml"): - if os.path.isfile(training_recipe): - shutil.copy(training_recipe, temp_local_recipe) - else: - try: - urlretrieve(training_recipe, temp_local_recipe) - except Exception as e: - raise ValueError( - f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}" - ) - else: - recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_") - - launcher_repo = os.environ.get("TRAINING_LAUNCHER_GIT", None) or training_recipes_cfg.get( - "launcher_repo" - ) - _run_clone_command_silent(launcher_repo, recipe_launcher_dir.name) - - recipe = os.path.join( - recipe_launcher_dir.name, - "recipes_collection", - "recipes", - training_recipe + ".yaml", - ) - if os.path.isfile(recipe): - shutil.copy(recipe, temp_local_recipe) - else: - raise ValueError(f"Recipe {training_recipe} not found.") - - recipe = OmegaConf.load(temp_local_recipe) - os.unlink(temp_local_recipe) - recipe = OmegaConf.merge(recipe, recipe_overrides) - return recipe - - -def _register_custom_resolvers(): - """Register custom resolvers for OmegaConf.""" - if not OmegaConf.has_resolver("multiply"): - OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) - if not OmegaConf.has_resolver("divide_ceil"): - OmegaConf.register_new_resolver( - "divide_ceil", lambda x, y: int(math.ceil(x / y)), replace=True - ) - if not OmegaConf.has_resolver("divide_floor"): - OmegaConf.register_new_resolver( - "divide_floor", lambda x, y: int(math.floor(x / y)), replace=True - ) - if not OmegaConf.has_resolver("add"): - OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers)) - - -def _get_trainining_recipe_gpu_model_name_and_script(model_type: str): - """Get the model base name and script for the training recipe.""" - - model_type_to_script = { - "llama_v3": ("llama", "llama_pretrain.py"), - "mistral": ("mistral", "mistral_pretrain.py"), - "mixtral": ("mixtral", "mixtral_pretrain.py"), - "deepseek": ("deepseek", "deepseek_pretrain.py"), - } - - for key in model_type_to_script: - if model_type.startswith(key): - model_type = key - break - - if model_type not in model_type_to_script: - raise ValueError(f"Model type {model_type} not supported") - - return model_type_to_script[model_type][0], model_type_to_script[model_type][1] - - -def _configure_gpu_args( - training_recipes_cfg: Dict[str, Any], - region_name: str, - recipe: OmegaConf, - recipe_train_dir: tempfile.TemporaryDirectory, -) -> Dict[str, Any]: - """Configure arguments specific to GPU.""" - source_code = SourceCode() - args = dict() - - adapter_repo = os.environ.get("TRAINING_ADAPTER_GIT", None) or training_recipes_cfg.get( - "adapter_repo" - ) - _run_clone_command_silent(adapter_repo, recipe_train_dir.name) - - if "model" not in recipe: - raise ValueError("Supplied recipe does not contain required field model.") - if "model_type" not in recipe["model"]: - raise ValueError("Supplied recipe does not contain required field model_type.") - model_type = recipe["model"]["model_type"] - - model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(model_type) - - source_code.source_dir = os.path.join(recipe_train_dir.name, "examples", model_base_name) - source_code.entry_script = script - - gpu_image_cfg = training_recipes_cfg.get("gpu_image") - if isinstance(gpu_image_cfg, str): - training_image = gpu_image_cfg - 
else: - training_image = retrieve( - gpu_image_cfg.get("framework"), - region=region_name, - version=gpu_image_cfg.get("version"), - image_scope="training", - **gpu_image_cfg.get("additional_args"), - ) - - # Setting dummy parameters for now - torch_distributed = Torchrun(smp=SMP(random_seed="123456")) - args.update( - { - "source_code": source_code, - "training_image": training_image, - "distributed": torch_distributed, - } - ) - return args - - -def _configure_trainium_args( - training_recipes_cfg: Dict[str, Any], - region_name: str, - recipe_train_dir: tempfile.TemporaryDirectory, -) -> Dict[str, Any]: - """Configure arguments specific to Trainium.""" - source_code = SourceCode() - args = dict() - - _run_clone_command_silent(training_recipes_cfg.get("neuron_dist_repo"), recipe_train_dir.name) - - source_code.source_dir = os.path.join(recipe_train_dir.name, "examples") - source_code.entry_script = "training_orchestrator.py" - neuron_image_cfg = training_recipes_cfg.get("neuron_image") - if isinstance(neuron_image_cfg, str): - training_image = neuron_image_cfg - else: - training_image = retrieve( - neuron_image_cfg.get("framework"), - region=region_name, - version=neuron_image_cfg.get("version"), - image_scope="training", - **neuron_image_cfg.get("additional_args"), - ) - - args.update( - { - "source_code": source_code, - "training_image": training_image, - "distributed": Torchrun(), - } - ) - return args - - -def _get_args_from_recipe( - training_recipe: str, - compute: Compute, - region_name: str, - recipe_overrides: Optional[Dict[str, Any]], - requirements: Optional[str], -) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]: - """Get arguments for ModelTrainer from a training recipe. - - Returns a dictionary of arguments to be used with ModelTrainer like: - ```python - { - "source_code": SourceCode, - "training_image": str, - "distributed": DistributedConfig, - "compute": Compute, - "hyperparameters": Dict[str, Any], - } - ``` - - Args: - training_recipe (str): - Name of the training recipe or path to the recipe file. - compute (Compute): - Compute configuration for training. - region_name (str): - Name of the AWS region. - recipe_overrides (Optional[Dict[str, Any]]): - Overrides for the training recipe. - requirements (Optional[str]): - Path to the requirements file. - """ - if compute.instance_type is None: - raise ValueError("Must set `instance_type` in compute when using training recipes.") - - training_recipes_cfg = _load_recipes_cfg() - recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg) - - if "trainer" not in recipe: - raise ValueError("Supplied recipe does not contain required field trainer.") - - # Set instance_count - if compute.instance_count and "num_nodes" in recipe["trainer"]: - logger.warning( - f"Using Compute to set instance_count:\n{compute}." - "\nIgnoring trainer -> num_nodes in recipe." - ) - if compute.instance_count is None: - if "num_nodes" not in recipe["trainer"]: - raise ValueError( - "Must provide Compute with instance_count or" " set trainer -> num_nodes in recipe." 
- ) - compute.instance_count = recipe["trainer"]["num_nodes"] - - if requirements and not os.path.isfile(requirements): - raise ValueError(f"Recipe requirements file {requirements} not found.") - - # Get Training Image, SourceCode, and distributed args - device_type = _determine_device_type(compute.instance_type) - recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_") - if device_type == "gpu": - args = _configure_gpu_args(training_recipes_cfg, region_name, recipe, recipe_train_dir) - elif device_type == "trainium": - args = _configure_trainium_args(training_recipes_cfg, region_name, recipe_train_dir) - else: - raise ValueError(f"Devices of type {device_type} are not supported with training recipes.") - - _register_custom_resolvers() - - # Resolve Final Recipe - final_recipe = _try_resolve_recipe(recipe) - if final_recipe is None: - final_recipe = _try_resolve_recipe(recipe, "recipes") - if final_recipe is None: - final_recipe = _try_resolve_recipe(recipe, "training") - if final_recipe is None: - raise RuntimeError("Could not resolve provided recipe.") - - # Save Final Recipe to source_dir - OmegaConf.save( - config=final_recipe, f=os.path.join(args["source_code"].source_dir, "recipe.yaml") - ) - - # If recipe_requirements is provided, copy it to source_dir - if requirements: - shutil.copy(requirements, args["source_code"].source_dir) - args["source_code"].requirements = os.path.basename(requirements) - - # Update args with compute and hyperparameters - args.update( - { - "compute": compute, - "hyperparameters": {"config-path": ".", "config-name": "recipe.yaml"}, - } - ) - - return args, recipe_train_dir diff --git a/sagemaker-core/tests/unit/modules/train/__init__.py b/sagemaker-core/tests/unit/modules/train/__init__.py deleted file mode 100644 index 6549052177..0000000000 --- a/sagemaker-core/tests/unit/modules/train/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. diff --git a/sagemaker-core/tests/unit/modules/train/container_drivers/__init__.py b/sagemaker-core/tests/unit/modules/train/container_drivers/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/sagemaker-core/tests/unit/modules/train/container_drivers/distributed_drivers/__init__.py b/sagemaker-core/tests/unit/modules/train/container_drivers/distributed_drivers/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/sagemaker-core/tests/unit/modules/train/container_drivers/distributed_drivers/test_mpi_utils.py b/sagemaker-core/tests/unit/modules/train/container_drivers/distributed_drivers/test_mpi_utils.py deleted file mode 100644 index ee3d60c6f2..0000000000 --- a/sagemaker-core/tests/unit/modules/train/container_drivers/distributed_drivers/test_mpi_utils.py +++ /dev/null @@ -1,466 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. 
A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Unit tests for sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils module.""" -from __future__ import absolute_import - -import pytest -import os -import subprocess -import paramiko -from unittest.mock import Mock, patch, MagicMock, call - -from sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils import ( - _write_file_to_host, - write_status_file_to_workers, - _wait_for_status_file, - start_sshd_daemon, - CustomHostKeyPolicy, - _can_connect, - _wait_for_workers, - _wait_for_master, - bootstrap_worker_node, - bootstrap_master_node, - validate_smddprun, - validate_smddpmprun, - write_env_vars_to_file, - get_mpirun_command, - FINISHED_STATUS_FILE, - READY_FILE, - DEFAULT_SSH_PORT, -) - - -class TestWriteFileToHost: - """Test _write_file_to_host function.""" - - @patch("subprocess.run") - def test_write_file_to_host_success(self, mock_run): - """Test successful file write to host.""" - mock_run.return_value = Mock(returncode=0) - - result = _write_file_to_host("algo-1", "/tmp/test.txt") - - assert result is True - mock_run.assert_called_once() - - @patch("subprocess.run") - def test_write_file_to_host_failure(self, mock_run): - """Test failed file write to host.""" - mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") - - result = _write_file_to_host("algo-1", "/tmp/test.txt") - - assert result is False - - -class TestWriteStatusFileToWorkers: - """Test write_status_file_to_workers function.""" - - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._write_file_to_host" - ) - def test_write_status_file_to_workers_success(self, mock_write): - """Test writing status file to workers successfully.""" - mock_write.return_value = True - - write_status_file_to_workers(["algo-1", "algo-2"]) - - assert mock_write.call_count == 2 - - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._write_file_to_host" - ) - @patch("time.sleep") - def test_write_status_file_to_workers_with_retry(self, mock_sleep, mock_write): - """Test writing status file with retry.""" - mock_write.side_effect = [False, False, True] - - write_status_file_to_workers(["algo-1"]) - - assert mock_write.call_count == 3 - - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._write_file_to_host" - ) - @patch("time.sleep") - def test_write_status_file_to_workers_timeout(self, mock_sleep, mock_write): - """Test writing status file timeout.""" - mock_write.return_value = False - - with pytest.raises(TimeoutError): - write_status_file_to_workers(["algo-1"]) - - -class TestWaitForStatusFile: - """Test _wait_for_status_file function.""" - - @patch("os.path.exists") - @patch("time.sleep") - def test_wait_for_status_file_exists(self, mock_sleep, mock_exists): - """Test waiting for status file that exists.""" - mock_exists.return_value = True - - _wait_for_status_file("/tmp/test.txt") - - mock_exists.assert_called_once() - - @patch("os.path.exists") - @patch("time.sleep") - def test_wait_for_status_file_eventually_exists(self, mock_sleep, mock_exists): - """Test waiting for status file that eventually exists.""" - 
mock_exists.side_effect = [False, False, True] - - _wait_for_status_file("/tmp/test.txt") - - assert mock_exists.call_count == 3 - - -class TestStartSshdDaemon: - """Test start_sshd_daemon function.""" - - @patch("os.path.exists") - @patch("subprocess.Popen") - def test_start_sshd_daemon_success(self, mock_popen, mock_exists): - """Test starting SSH daemon successfully.""" - mock_exists.return_value = True - - start_sshd_daemon() - - mock_popen.assert_called_once_with(["/usr/sbin/sshd", "-D"]) - - @patch("os.path.exists") - def test_start_sshd_daemon_not_found(self, mock_exists): - """Test starting SSH daemon when not found.""" - mock_exists.return_value = False - - with pytest.raises(RuntimeError, match="SSH daemon not found"): - start_sshd_daemon() - - -class TestCustomHostKeyPolicy: - """Test CustomHostKeyPolicy class.""" - - def test_custom_host_key_policy_algo_hostname(self): - """Test accepting algo-* hostnames.""" - policy = CustomHostKeyPolicy() - mock_client = Mock() - mock_client.get_host_keys.return_value = Mock() - mock_key = Mock() - mock_key.get_name.return_value = "ssh-rsa" - - # Should not raise exception - policy.missing_host_key(mock_client, "algo-1234", mock_key) - - def test_custom_host_key_policy_unknown_hostname(self): - """Test rejecting unknown hostnames.""" - policy = CustomHostKeyPolicy() - mock_client = Mock() - mock_key = Mock() - - with pytest.raises(paramiko.SSHException): - policy.missing_host_key(mock_client, "unknown-host", mock_key) - - -class TestCanConnect: - """Test _can_connect function.""" - - @patch("paramiko.SSHClient") - def test_can_connect_success(self, mock_ssh_client): - """Test successful connection.""" - mock_client_instance = Mock() - mock_ssh_client.return_value.__enter__.return_value = mock_client_instance - - result = _can_connect("algo-1") - - assert result is True - - @patch("paramiko.SSHClient") - def test_can_connect_failure(self, mock_ssh_client): - """Test failed connection.""" - mock_client_instance = Mock() - mock_client_instance.connect.side_effect = Exception("Connection failed") - mock_ssh_client.return_value.__enter__.return_value = mock_client_instance - - result = _can_connect("algo-1") - - assert result is False - - -class TestWaitForWorkers: - """Test _wait_for_workers function.""" - - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._can_connect" - ) - @patch("os.path.exists") - def test_wait_for_workers_success(self, mock_exists, mock_connect): - """Test waiting for workers successfully.""" - mock_connect.return_value = True - mock_exists.return_value = True - - _wait_for_workers(["algo-1", "algo-2"]) - - assert mock_connect.call_count >= 2 - - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._can_connect" - ) - @patch("os.path.exists") - @patch("time.sleep") - @patch("time.time") - def test_wait_for_workers_timeout(self, mock_time, mock_sleep, mock_exists, mock_connect): - """Test waiting for workers timeout.""" - mock_connect.return_value = False - mock_exists.return_value = False - # Use side_effect with a generator to provide unlimited values - mock_time.side_effect = (i * 200 for i in range(1000)) # Simulate timeout - - with pytest.raises(TimeoutError): - _wait_for_workers(["algo-1"]) - - def test_wait_for_workers_empty_list(self): - """Test waiting for workers with empty list.""" - # Should not raise exception - _wait_for_workers([]) - - -class TestWaitForMaster: - """Test _wait_for_master function.""" - - @patch( - 
"sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._can_connect" - ) - def test_wait_for_master_success(self, mock_connect): - """Test waiting for master successfully.""" - mock_connect.return_value = True - - _wait_for_master("algo-1") - - mock_connect.assert_called() - - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._can_connect" - ) - @patch("time.sleep") - @patch("time.time") - def test_wait_for_master_timeout(self, mock_time, mock_sleep, mock_connect): - """Test waiting for master timeout.""" - mock_connect.return_value = False - # Use side_effect with a generator to provide unlimited values - mock_time.side_effect = (i * 200 for i in range(1000)) # Simulate timeout - - with pytest.raises(TimeoutError): - _wait_for_master("algo-1") - - -class TestBootstrapWorkerNode: - """Test bootstrap_worker_node function.""" - - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._wait_for_master" - ) - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._write_file_to_host" - ) - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._wait_for_status_file" - ) - @patch.dict(os.environ, {"SM_CURRENT_HOST": "algo-2"}) - def test_bootstrap_worker_node(self, mock_wait_status, mock_write, mock_wait_master): - """Test bootstrapping worker node.""" - mock_write.return_value = True - - bootstrap_worker_node("algo-1") - - mock_wait_master.assert_called_once_with("algo-1") - mock_write.assert_called_once() - mock_wait_status.assert_called_once() - - -class TestBootstrapMasterNode: - """Test bootstrap_master_node function.""" - - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._wait_for_workers" - ) - def test_bootstrap_master_node(self, mock_wait): - """Test bootstrapping master node.""" - bootstrap_master_node(["algo-2", "algo-3"]) - - mock_wait.assert_called_once_with(["algo-2", "algo-3"]) - - -class TestValidateSmddprun: - """Test validate_smddprun function.""" - - @patch("subprocess.run") - def test_validate_smddprun_installed(self, mock_run): - """Test validating smddprun when installed.""" - mock_run.return_value = Mock(stdout="smddprun") - - result = validate_smddprun() - - assert result is True - - @patch("subprocess.run") - def test_validate_smddprun_not_installed(self, mock_run): - """Test validating smddprun when not installed.""" - mock_run.side_effect = subprocess.CalledProcessError(1, "which") - - result = validate_smddprun() - - assert result is False - - -class TestValidateSmddpmprun: - """Test validate_smddpmprun function.""" - - @patch("subprocess.run") - def test_validate_smddpmprun_installed(self, mock_run): - """Test validating smddpmprun when installed.""" - mock_run.return_value = Mock(stdout="smddpmprun") - - result = validate_smddpmprun() - - assert result is True - - @patch("subprocess.run") - def test_validate_smddpmprun_not_installed(self, mock_run): - """Test validating smddpmprun when not installed.""" - mock_run.side_effect = subprocess.CalledProcessError(1, "which") - - result = validate_smddpmprun() - - assert result is False - - -class TestWriteEnvVarsToFile: - """Test write_env_vars_to_file function.""" - - @patch("builtins.open", create=True) - @patch.dict(os.environ, {"TEST_VAR": "test_value", "ANOTHER_VAR": "another_value"}) - def test_write_env_vars_to_file(self, mock_open_func): - """Test writing environment variables to file.""" - mock_file = MagicMock() - 
mock_open_func.return_value.__enter__.return_value = mock_file - - write_env_vars_to_file() - - mock_open_func.assert_called_once_with("/etc/environment", "a", encoding="utf-8") - assert mock_file.write.called - - -class TestGetMpirunCommand: - """Test get_mpirun_command function.""" - - @patch.dict( - os.environ, - {"SM_NETWORK_INTERFACE_NAME": "eth0", "SM_CURRENT_INSTANCE_TYPE": "ml.p3.2xlarge"}, - ) - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils.get_python_executable" - ) - def test_get_mpirun_command_basic(self, mock_python): - """Test getting basic mpirun command.""" - mock_python.return_value = "/usr/bin/python3" - - result = get_mpirun_command( - host_count=2, - host_list=["algo-1", "algo-2"], - num_processes=4, - additional_options=[], - entry_script_path="/opt/ml/code/train.py", - ) - - assert "mpirun" in result - assert "--host" in result - assert "algo-1,algo-2" in result - assert "-np" in result - assert "4" in result - - @patch.dict( - os.environ, - { - "SM_NETWORK_INTERFACE_NAME": "eth0", - "SM_CURRENT_INSTANCE_TYPE": "ml.p4d.24xlarge", - "AWS_ACCESS_KEY_ID": "test_key", - "AWS_SECRET_ACCESS_KEY": "test_secret", - }, - ) - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils.get_python_executable" - ) - def test_get_mpirun_command_with_efa(self, mock_python): - """Test getting mpirun command with EFA instance.""" - mock_python.return_value = "/usr/bin/python3" - - result = get_mpirun_command( - host_count=2, - host_list=["algo-1", "algo-2"], - num_processes=4, - additional_options=[], - entry_script_path="/opt/ml/code/train.py", - ) - - assert "FI_PROVIDER=efa" in result - - @patch.dict( - os.environ, - {"SM_NETWORK_INTERFACE_NAME": "eth0", "SM_CURRENT_INSTANCE_TYPE": "ml.p3.2xlarge"}, - ) - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils.get_python_executable" - ) - def test_get_mpirun_command_with_additional_options(self, mock_python): - """Test getting mpirun command with additional options.""" - mock_python.return_value = "/usr/bin/python3" - - result = get_mpirun_command( - host_count=2, - host_list=["algo-1", "algo-2"], - num_processes=4, - additional_options=["-x", "CUSTOM_VAR"], - entry_script_path="/opt/ml/code/train.py", - ) - - assert "-x" in result - assert "CUSTOM_VAR" in result - - @patch.dict( - os.environ, - { - "SM_NETWORK_INTERFACE_NAME": "eth0", - "SM_CURRENT_INSTANCE_TYPE": "ml.p3.2xlarge", - "AWS_ACCESS_KEY_ID": "test_key", - "AWS_SECRET_ACCESS_KEY": "test_secret", - "AWS_SESSION_TOKEN": "test_token", - }, - ) - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils.get_python_executable" - ) - def test_get_mpirun_command_with_credentials(self, mock_python): - """Test getting mpirun command with AWS credentials.""" - mock_python.return_value = "/usr/bin/python3" - - result = get_mpirun_command( - host_count=2, - host_list=["algo-1", "algo-2"], - num_processes=4, - additional_options=[], - entry_script_path="/opt/ml/code/train.py", - ) - - assert "AWS_ACCESS_KEY_ID" in result - assert "AWS_SECRET_ACCESS_KEY" in result - assert "AWS_SESSION_TOKEN" in result diff --git a/sagemaker-core/tests/unit/modules/train/test_environment.py b/sagemaker-core/tests/unit/modules/train/test_environment.py deleted file mode 100644 index e80eeef005..0000000000 --- a/sagemaker-core/tests/unit/modules/train/test_environment.py +++ /dev/null @@ -1,386 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. - -import pytest -import json -import os -from unittest.mock import Mock, patch, mock_open, MagicMock -from sagemaker.core.modules.train.container_drivers.scripts.environment import ( - num_cpus, - num_gpus, - num_neurons, - deserialize_hyperparameters, - set_env, - mask_sensitive_info, - log_key_value, - log_env_variables, -) - - -class TestEnvironment: - """Test cases for environment module""" - - def test_num_cpus(self): - """Test num_cpus returns positive integer""" - result = num_cpus() - assert isinstance(result, int) - assert result > 0 - - @patch("subprocess.check_output") - def test_num_gpus_with_gpus(self, mock_check_output): - """Test num_gpus when GPUs are available""" - mock_check_output.return_value = b"GPU 0: Tesla V100\nGPU 1: Tesla V100\n" - - result = num_gpus() - assert result == 2 - - @patch("subprocess.check_output") - def test_num_gpus_no_gpus(self, mock_check_output): - """Test num_gpus when no GPUs are available""" - mock_check_output.side_effect = OSError("nvidia-smi not found") - - result = num_gpus() - assert result == 0 - - @patch("subprocess.check_output") - def test_num_gpus_command_error(self, mock_check_output): - """Test num_gpus when command fails""" - import subprocess - - mock_check_output.side_effect = subprocess.CalledProcessError(1, "nvidia-smi") - - result = num_gpus() - assert result == 0 - - @patch("subprocess.check_output") - def test_num_neurons_with_neurons(self, mock_check_output): - """Test num_neurons when Neuron cores are available""" - mock_output = json.dumps([{"nc_count": 2}, {"nc_count": 2}]) - mock_check_output.return_value = mock_output.encode() - - result = num_neurons() - assert result == 4 - - @patch("subprocess.check_output") - def test_num_neurons_no_neurons(self, mock_check_output): - """Test num_neurons when no Neuron cores are available""" - mock_check_output.side_effect = OSError("neuron-ls not found") - - result = num_neurons() - assert result == 0 - - @patch("subprocess.check_output") - def test_num_neurons_command_error(self, mock_check_output): - """Test num_neurons when command fails""" - import subprocess - - error = subprocess.CalledProcessError(1, "neuron-ls") - error.output = b"error=No Neuron devices found" - mock_check_output.side_effect = error - - result = num_neurons() - assert result == 0 - - @patch("subprocess.check_output") - def test_num_neurons_command_error_no_output(self, mock_check_output): - """Test num_neurons when command fails without output""" - import subprocess - - error = subprocess.CalledProcessError(1, "neuron-ls") - error.output = None - mock_check_output.side_effect = error - - result = num_neurons() - assert result == 0 - - def test_deserialize_hyperparameters_simple(self): - """Test deserialize_hyperparameters with simple types""" - hyperparameters = {"learning_rate": "0.001", "epochs": "10", "batch_size": "32"} - - result = deserialize_hyperparameters(hyperparameters) - - assert result["learning_rate"] == 0.001 - assert result["epochs"] == 10 - assert result["batch_size"] == 32 - - def 
test_deserialize_hyperparameters_complex(self): - """Test deserialize_hyperparameters with complex types""" - hyperparameters = { - "layers": "[128, 64, 32]", - "config": '{"optimizer": "adam", "loss": "mse"}', - "enabled": "true", - } - - result = deserialize_hyperparameters(hyperparameters) - - assert result["layers"] == [128, 64, 32] - assert result["config"] == {"optimizer": "adam", "loss": "mse"} - assert result["enabled"] is True - - def test_mask_sensitive_info_with_password(self): - """Test mask_sensitive_info masks password fields""" - data = {"username": "user", "password": "secret123", "api_key": "key123"} - - result = mask_sensitive_info(data) - - assert result["username"] == "user" - assert result["password"] == "******" - assert result["api_key"] == "******" - - def test_mask_sensitive_info_nested(self): - """Test mask_sensitive_info with nested dictionaries""" - data = {"config": {"db_password": "secret", "db_host": "localhost"}} - - result = mask_sensitive_info(data) - - assert result["config"]["db_password"] == "******" - assert result["config"]["db_host"] == "localhost" - - def test_mask_sensitive_info_case_insensitive(self): - """Test mask_sensitive_info is case insensitive""" - data = {"API_KEY": "key123", "Secret_Token": "token123"} - - result = mask_sensitive_info(data) - - assert result["API_KEY"] == "******" - assert result["Secret_Token"] == "******" - - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.logger") - def test_log_key_value_sensitive(self, mock_logger): - """Test log_key_value masks sensitive values""" - log_key_value("password", "secret123") - - mock_logger.info.assert_called_once() - call_args = mock_logger.info.call_args[0] - assert "******" in str(call_args) - - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.logger") - def test_log_key_value_dict(self, mock_logger): - """Test log_key_value with dictionary value""" - log_key_value("config", {"key": "value"}) - - mock_logger.info.assert_called_once() - - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.logger") - def test_log_key_value_json_string(self, mock_logger): - """Test log_key_value with JSON string value""" - log_key_value("config", '{"key": "value"}') - - mock_logger.info.assert_called_once() - - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.logger") - def test_log_key_value_regular(self, mock_logger): - """Test log_key_value with regular value""" - log_key_value("learning_rate", "0.001") - - mock_logger.info.assert_called_once() - - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.logger") - @patch.dict(os.environ, {"TEST_VAR": "test_value"}) - def test_log_env_variables(self, mock_logger): - """Test log_env_variables logs both environment and dict variables""" - env_vars_dict = {"CUSTOM_VAR": "custom_value"} - - log_env_variables(env_vars_dict) - - # Should be called for both os.environ and env_vars_dict - assert mock_logger.info.call_count > 0 - - @patch("builtins.open", new_callable=mock_open) - @patch( - "sagemaker.core.modules.train.container_drivers.scripts.environment.read_source_code_json" - ) - @patch( - "sagemaker.core.modules.train.container_drivers.scripts.environment.read_distributed_json" - ) - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_cpus") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_gpus") - 
@patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_neurons") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.log_env_variables") - @patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}) - def test_set_env_minimal( - self, - mock_log_env, - mock_neurons, - mock_gpus, - mock_cpus, - mock_distributed, - mock_source_code, - mock_file, - ): - """Test set_env with minimal configuration""" - mock_cpus.return_value = 4 - mock_gpus.return_value = 0 - mock_neurons.return_value = 0 - mock_source_code.return_value = None - mock_distributed.return_value = None - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.m5.xlarge", - "hosts": ["algo-1"], - "network_interface_name": "eth0", - } - - input_data_config = {"training": {"S3Uri": "s3://bucket/data"}} - - hyperparameters_config = {"learning_rate": "0.001", "epochs": "10"} - - set_env(resource_config, input_data_config, hyperparameters_config) - - # Verify file was written - mock_file.assert_called_once() - handle = mock_file() - assert handle.write.called - - @patch("builtins.open", new_callable=mock_open) - @patch( - "sagemaker.core.modules.train.container_drivers.scripts.environment.read_source_code_json" - ) - @patch( - "sagemaker.core.modules.train.container_drivers.scripts.environment.read_distributed_json" - ) - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_cpus") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_gpus") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_neurons") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.log_env_variables") - @patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}) - def test_set_env_with_source_code( - self, - mock_log_env, - mock_neurons, - mock_gpus, - mock_cpus, - mock_distributed, - mock_source_code, - mock_file, - ): - """Test set_env with source code configuration""" - mock_cpus.return_value = 4 - mock_gpus.return_value = 1 - mock_neurons.return_value = 0 - mock_source_code.return_value = {"entry_script": "train.py"} - mock_distributed.return_value = None - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.p3.2xlarge", - "hosts": ["algo-1", "algo-2"], - "network_interface_name": "eth0", - } - - input_data_config = {"training": {"S3Uri": "s3://bucket/data"}} - - hyperparameters_config = {"learning_rate": "0.001"} - - set_env(resource_config, input_data_config, hyperparameters_config) - - # Verify file was written - mock_file.assert_called_once() - - @patch("builtins.open", new_callable=mock_open) - @patch( - "sagemaker.core.modules.train.container_drivers.scripts.environment.read_source_code_json" - ) - @patch( - "sagemaker.core.modules.train.container_drivers.scripts.environment.read_distributed_json" - ) - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_cpus") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_gpus") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_neurons") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.log_env_variables") - @patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}) - def test_set_env_with_distributed( - self, - mock_log_env, - mock_neurons, - mock_gpus, - mock_cpus, - mock_distributed, - mock_source_code, - mock_file, - ): - """Test set_env with distributed configuration""" - mock_cpus.return_value = 
8 - mock_gpus.return_value = 4 - mock_neurons.return_value = 0 - mock_source_code.return_value = None - mock_distributed.return_value = {"smdistributed": {"dataparallel": {"enabled": True}}} - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.p3.8xlarge", - "hosts": ["algo-1", "algo-2", "algo-3"], - "network_interface_name": "eth0", - } - - input_data_config = { - "training": {"S3Uri": "s3://bucket/data"}, - "validation": {"S3Uri": "s3://bucket/validation"}, - } - - hyperparameters_config = {"learning_rate": "0.001", "batch_size": "64"} - - set_env(resource_config, input_data_config, hyperparameters_config) - - # Verify file was written - mock_file.assert_called_once() - - @patch("builtins.open", new_callable=mock_open) - @patch( - "sagemaker.core.modules.train.container_drivers.scripts.environment.read_source_code_json" - ) - @patch( - "sagemaker.core.modules.train.container_drivers.scripts.environment.read_distributed_json" - ) - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_cpus") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_gpus") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_neurons") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.log_env_variables") - @patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}) - def test_set_env_multiple_channels( - self, - mock_log_env, - mock_neurons, - mock_gpus, - mock_cpus, - mock_distributed, - mock_source_code, - mock_file, - ): - """Test set_env with multiple data channels""" - mock_cpus.return_value = 4 - mock_gpus.return_value = 0 - mock_neurons.return_value = 0 - mock_source_code.return_value = None - mock_distributed.return_value = None - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.m5.xlarge", - "hosts": ["algo-1"], - "network_interface_name": "eth0", - } - - input_data_config = { - "training": {"S3Uri": "s3://bucket/train"}, - "validation": {"S3Uri": "s3://bucket/val"}, - "test": {"S3Uri": "s3://bucket/test"}, - } - - hyperparameters_config = {} - - set_env(resource_config, input_data_config, hyperparameters_config) - - # Verify file was written - mock_file.assert_called_once() diff --git a/sagemaker-core/tests/unit/modules/train/test_sm_recipes_utils.py b/sagemaker-core/tests/unit/modules/train/test_sm_recipes_utils.py deleted file mode 100644 index 737aa60927..0000000000 --- a/sagemaker-core/tests/unit/modules/train/test_sm_recipes_utils.py +++ /dev/null @@ -1,436 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
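The recipe utilities covered by the deleted tests below route a recipe to a GPU, Trainium, or CPU path based on the instance family. As a point of reference, the dispatch reduces to prefix matching on the instance type; a minimal standalone sketch under that assumption (``determine_device_type`` is a hypothetical stand-in for the module's private ``_determine_device_type`` helper):

.. code-block:: python

    def determine_device_type(instance_type: str) -> str:
        """Classify a SageMaker instance type by accelerator family (illustrative sketch)."""
        family = instance_type.split(".")[1]  # e.g. "ml.p3.2xlarge" -> "p3"
        if family.startswith("trn"):
            return "trainium"
        if family.startswith(("p", "g")):
            return "gpu"
        return "cpu"

    # Mirrors the expectations in the deleted tests:
    assert determine_device_type("ml.p3.2xlarge") == "gpu"
    assert determine_device_type("ml.g4dn.xlarge") == "gpu"
    assert determine_device_type("ml.trn1.2xlarge") == "trainium"
    assert determine_device_type("ml.m5.xlarge") == "cpu"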
- -import pytest -import tempfile -from unittest.mock import Mock, patch, MagicMock -from omegaconf import OmegaConf -from sagemaker.core.modules.train.sm_recipes.utils import ( - _try_resolve_recipe, - _determine_device_type, - _load_recipes_cfg, - _load_base_recipe, - _register_custom_resolvers, - _get_trainining_recipe_gpu_model_name_and_script, - _configure_gpu_args, - _configure_trainium_args, - _get_args_from_recipe, -) -from sagemaker.core.modules.configs import Compute - - -class TestSMRecipesUtils: - """Test cases for SM recipes utility functions""" - - def test_try_resolve_recipe_success(self): - """Test _try_resolve_recipe with resolvable recipe""" - recipe = OmegaConf.create({"value": 10, "doubled": "${value}"}) - - result = _try_resolve_recipe(recipe) - - assert result is not None - assert result["doubled"] == 10 - - def test_try_resolve_recipe_with_key(self): - """Test _try_resolve_recipe with key parameter""" - recipe = 10 - - result = _try_resolve_recipe(recipe, key="test") - - assert result is not None - assert result == 10 - - def test_try_resolve_recipe_unresolvable(self): - """Test _try_resolve_recipe with unresolvable recipe""" - recipe = OmegaConf.create({"value": "${missing_var}"}) - - result = _try_resolve_recipe(recipe) - - assert result is None - - def test_determine_device_type_gpu_p_instance(self): - """Test _determine_device_type with P instance (GPU)""" - result = _determine_device_type("ml.p3.2xlarge") - assert result == "gpu" - - def test_determine_device_type_gpu_g_instance(self): - """Test _determine_device_type with G instance (GPU)""" - result = _determine_device_type("ml.g4dn.xlarge") - assert result == "gpu" - - def test_determine_device_type_trainium(self): - """Test _determine_device_type with Trainium instance""" - result = _determine_device_type("ml.trn1.2xlarge") - assert result == "trainium" - - def test_determine_device_type_cpu(self): - """Test _determine_device_type with CPU instance""" - result = _determine_device_type("ml.m5.xlarge") - assert result == "cpu" - - @patch("sagemaker.core.modules.train.sm_recipes.utils.open") - @patch("sagemaker.core.modules.train.sm_recipes.utils.json.load") - def test_load_recipes_cfg(self, mock_json_load, mock_open): - """Test _load_recipes_cfg loads configuration""" - mock_json_load.return_value = {"launcher_repo": "test_repo", "adapter_repo": "test_adapter"} - - result = _load_recipes_cfg() - - assert isinstance(result, dict) - assert "launcher_repo" in result or "adapter_repo" in result or "neuron_dist_repo" in result - - @patch("sagemaker.core.modules.train.sm_recipes.utils.os.path.isfile") - @patch("sagemaker.core.modules.train.sm_recipes.utils.shutil.copy") - @patch("sagemaker.core.modules.train.sm_recipes.utils.OmegaConf.load") - @patch("sagemaker.core.modules.train.sm_recipes.utils.OmegaConf.merge") - @patch("sagemaker.core.modules.train.sm_recipes.utils.os.unlink") - def test_load_base_recipe_from_file( - self, mock_unlink, mock_merge, mock_load, mock_copy, mock_isfile - ): - """Test _load_base_recipe from local file""" - mock_isfile.return_value = True - mock_recipe = OmegaConf.create({"model": {"model_type": "llama_v3"}}) - mock_load.return_value = mock_recipe - mock_merge.return_value = mock_recipe - - result = _load_base_recipe("recipe.yaml") - - assert result is not None - mock_copy.assert_called_once() - - @patch("sagemaker.core.modules.train.sm_recipes.utils.os.path.isfile") - @patch("sagemaker.core.modules.train.sm_recipes.utils.urlretrieve") - 
@patch("sagemaker.core.modules.train.sm_recipes.utils.OmegaConf.load") - @patch("sagemaker.core.modules.train.sm_recipes.utils.OmegaConf.merge") - @patch("sagemaker.core.modules.train.sm_recipes.utils.os.unlink") - def test_load_base_recipe_from_url( - self, mock_unlink, mock_merge, mock_load, mock_urlretrieve, mock_isfile - ): - """Test _load_base_recipe from URL""" - mock_isfile.return_value = False - mock_recipe = OmegaConf.create({"model": {"model_type": "llama_v3"}}) - mock_load.return_value = mock_recipe - mock_merge.return_value = mock_recipe - - result = _load_base_recipe("https://example.com/recipe.yaml") - - assert result is not None - mock_urlretrieve.assert_called_once() - - @patch("sagemaker.core.modules.train.sm_recipes.utils.os.path.isfile") - @patch("sagemaker.core.modules.train.sm_recipes.utils.urlretrieve") - def test_load_base_recipe_url_error(self, mock_urlretrieve, mock_isfile): - """Test _load_base_recipe raises error on URL fetch failure""" - mock_isfile.return_value = False - mock_urlretrieve.side_effect = Exception("Network error") - - with pytest.raises(ValueError, match="Could not fetch the provided recipe"): - _load_base_recipe("https://example.com/recipe.yaml") - - def test_register_custom_resolvers(self): - """Test _register_custom_resolvers registers OmegaConf resolvers""" - _register_custom_resolvers() - - # Test multiply resolver - recipe = OmegaConf.create({"a": 5, "b": "${multiply:${a},2}"}) - OmegaConf.resolve(recipe) - assert recipe["b"] == 10 - - # Test divide_ceil resolver - recipe = OmegaConf.create({"a": 10, "b": "${divide_ceil:${a},3}"}) - OmegaConf.resolve(recipe) - assert recipe["b"] == 4 - - # Test divide_floor resolver - recipe = OmegaConf.create({"a": 10, "b": "${divide_floor:${a},3}"}) - OmegaConf.resolve(recipe) - assert recipe["b"] == 3 - - # Test add resolver - recipe = OmegaConf.create({"a": "${add:1,2,3}"}) - OmegaConf.resolve(recipe) - assert recipe["a"] == 6 - - def test_get_trainining_recipe_gpu_model_name_and_script_llama(self): - """Test _get_trainining_recipe_gpu_model_name_and_script for Llama""" - model_name, script = _get_trainining_recipe_gpu_model_name_and_script("llama_v3_8b") - - assert model_name == "llama" - assert script == "llama_pretrain.py" - - def test_get_trainining_recipe_gpu_model_name_and_script_mistral(self): - """Test _get_trainining_recipe_gpu_model_name_and_script for Mistral""" - model_name, script = _get_trainining_recipe_gpu_model_name_and_script("mistral_7b") - - assert model_name == "mistral" - assert script == "mistral_pretrain.py" - - def test_get_trainining_recipe_gpu_model_name_and_script_mixtral(self): - """Test _get_trainining_recipe_gpu_model_name_and_script for Mixtral""" - model_name, script = _get_trainining_recipe_gpu_model_name_and_script("mixtral_8x7b") - - assert model_name == "mixtral" - assert script == "mixtral_pretrain.py" - - def test_get_trainining_recipe_gpu_model_name_and_script_deepseek(self): - """Test _get_trainining_recipe_gpu_model_name_and_script for DeepSeek""" - model_name, script = _get_trainining_recipe_gpu_model_name_and_script("deepseek_v2") - - assert model_name == "deepseek" - assert script == "deepseek_pretrain.py" - - def test_get_trainining_recipe_gpu_model_name_and_script_unsupported(self): - """Test _get_trainining_recipe_gpu_model_name_and_script with unsupported model""" - with pytest.raises(ValueError, match="Model type .* not supported"): - _get_trainining_recipe_gpu_model_name_and_script("unsupported_model") - - 
@patch("sagemaker.core.modules.train.sm_recipes.utils._run_clone_command_silent") - @patch("sagemaker.core.modules.train.sm_recipes.utils.retrieve") - def test_configure_gpu_args(self, mock_retrieve, mock_clone): - """Test _configure_gpu_args""" - training_recipes_cfg = { - "adapter_repo": "https://github.com/test/adapter", - "gpu_image": {"framework": "pytorch", "version": "2.0", "additional_args": {}}, - } - - recipe = OmegaConf.create({"model": {"model_type": "llama_v3"}}) - - recipe_train_dir = tempfile.TemporaryDirectory() - mock_retrieve.return_value = "test-image:latest" - - result = _configure_gpu_args(training_recipes_cfg, "us-west-2", recipe, recipe_train_dir) - - assert "source_code" in result - assert "training_image" in result - assert "distributed" in result - assert result["training_image"] == "test-image:latest" - - @patch("sagemaker.core.modules.train.sm_recipes.utils._run_clone_command_silent") - @patch("sagemaker.core.modules.train.sm_recipes.utils.retrieve") - def test_configure_gpu_args_string_image(self, mock_retrieve, mock_clone): - """Test _configure_gpu_args with string image config""" - training_recipes_cfg = { - "adapter_repo": "https://github.com/test/adapter", - "gpu_image": "custom-image:latest", - } - - recipe = OmegaConf.create({"model": {"model_type": "mistral"}}) - - recipe_train_dir = tempfile.TemporaryDirectory() - - result = _configure_gpu_args(training_recipes_cfg, "us-west-2", recipe, recipe_train_dir) - - assert result["training_image"] == "custom-image:latest" - - @patch("sagemaker.core.modules.train.sm_recipes.utils._run_clone_command_silent") - @patch("sagemaker.core.modules.train.sm_recipes.utils.retrieve") - def test_configure_gpu_args_missing_model(self, mock_retrieve, mock_clone): - """Test _configure_gpu_args raises error when model field is missing""" - training_recipes_cfg = { - "adapter_repo": "https://github.com/test/adapter", - "gpu_image": "test-image:latest", - } - - recipe = OmegaConf.create({}) - recipe_train_dir = tempfile.TemporaryDirectory() - - with pytest.raises(ValueError, match="does not contain required field model"): - _configure_gpu_args(training_recipes_cfg, "us-west-2", recipe, recipe_train_dir) - - @patch("sagemaker.core.modules.train.sm_recipes.utils._run_clone_command_silent") - @patch("sagemaker.core.modules.train.sm_recipes.utils.retrieve") - def test_configure_trainium_args(self, mock_retrieve, mock_clone): - """Test _configure_trainium_args""" - training_recipes_cfg = { - "neuron_dist_repo": "https://github.com/test/neuron", - "neuron_image": {"framework": "pytorch", "version": "1.13", "additional_args": {}}, - } - - recipe_train_dir = tempfile.TemporaryDirectory() - mock_retrieve.return_value = "neuron-image:latest" - - result = _configure_trainium_args(training_recipes_cfg, "us-west-2", recipe_train_dir) - - assert "source_code" in result - assert "training_image" in result - assert "distributed" in result - assert result["training_image"] == "neuron-image:latest" - - @patch("sagemaker.core.modules.train.sm_recipes.utils._load_recipes_cfg") - @patch("sagemaker.core.modules.train.sm_recipes.utils._load_base_recipe") - @patch("sagemaker.core.modules.train.sm_recipes.utils._configure_gpu_args") - @patch("sagemaker.core.modules.train.sm_recipes.utils._register_custom_resolvers") - @patch("sagemaker.core.modules.train.sm_recipes.utils._try_resolve_recipe") - @patch("sagemaker.core.modules.train.sm_recipes.utils.OmegaConf.save") - def test_get_args_from_recipe_gpu( - self, - mock_save, - mock_resolve, - mock_register, 
- mock_configure_gpu, - mock_load_recipe, - mock_load_cfg, - ): - """Test _get_args_from_recipe for GPU instance""" - compute = Compute(instance_type="ml.p3.2xlarge", instance_count=2) - - mock_load_cfg.return_value = {} - mock_recipe = OmegaConf.create( - {"trainer": {"num_nodes": 1}, "model": {"model_type": "llama_v3"}} - ) - mock_load_recipe.return_value = mock_recipe - mock_resolve.return_value = mock_recipe - - mock_configure_gpu.return_value = { - "source_code": Mock(source_dir="/tmp/source"), - "training_image": "test-image:latest", - "distributed": Mock(), - } - - result, temp_dir = _get_args_from_recipe( - training_recipe="llama_recipe", - compute=compute, - region_name="us-west-2", - recipe_overrides=None, - requirements=None, - ) - - assert "source_code" in result - assert "training_image" in result - assert "compute" in result - assert "hyperparameters" in result - assert result["compute"].instance_count == 2 - - @patch("sagemaker.core.modules.train.sm_recipes.utils._load_recipes_cfg") - @patch("sagemaker.core.modules.train.sm_recipes.utils._load_base_recipe") - @patch("sagemaker.core.modules.train.sm_recipes.utils._configure_trainium_args") - @patch("sagemaker.core.modules.train.sm_recipes.utils._register_custom_resolvers") - @patch("sagemaker.core.modules.train.sm_recipes.utils._try_resolve_recipe") - @patch("sagemaker.core.modules.train.sm_recipes.utils.OmegaConf.save") - def test_get_args_from_recipe_trainium( - self, - mock_save, - mock_resolve, - mock_register, - mock_configure_trainium, - mock_load_recipe, - mock_load_cfg, - ): - """Test _get_args_from_recipe for Trainium instance""" - compute = Compute(instance_type="ml.trn1.2xlarge", instance_count=1) - - mock_load_cfg.return_value = {} - mock_recipe = OmegaConf.create({"trainer": {"num_nodes": 1}}) - mock_load_recipe.return_value = mock_recipe - mock_resolve.return_value = mock_recipe - - mock_configure_trainium.return_value = { - "source_code": Mock(source_dir="/tmp/source"), - "training_image": "neuron-image:latest", - "distributed": Mock(), - } - - result, temp_dir = _get_args_from_recipe( - training_recipe="neuron_recipe", - compute=compute, - region_name="us-west-2", - recipe_overrides=None, - requirements=None, - ) - - assert "source_code" in result - assert "training_image" in result - - def test_get_args_from_recipe_no_instance_type(self): - """Test _get_args_from_recipe raises error without instance_type""" - compute = Compute(instance_count=1) - - with pytest.raises(ValueError, match="Must set `instance_type`"): - _get_args_from_recipe( - training_recipe="test_recipe", - compute=compute, - region_name="us-west-2", - recipe_overrides=None, - requirements=None, - ) - - @patch("sagemaker.core.modules.train.sm_recipes.utils._load_recipes_cfg") - @patch("sagemaker.core.modules.train.sm_recipes.utils._load_base_recipe") - def test_get_args_from_recipe_missing_trainer(self, mock_load_recipe, mock_load_cfg): - """Test _get_args_from_recipe raises error when trainer field is missing""" - compute = Compute(instance_type="ml.p3.2xlarge", instance_count=1) - - mock_load_cfg.return_value = {} - mock_recipe = OmegaConf.create({}) - mock_load_recipe.return_value = mock_recipe - - with pytest.raises(ValueError, match="does not contain required field trainer"): - _get_args_from_recipe( - training_recipe="test_recipe", - compute=compute, - region_name="us-west-2", - recipe_overrides=None, - requirements=None, - ) - - @patch("sagemaker.core.modules.train.sm_recipes.utils._load_recipes_cfg") - 
@patch("sagemaker.core.modules.train.sm_recipes.utils._load_base_recipe") - @patch("sagemaker.core.modules.train.sm_recipes.utils._configure_gpu_args") - @patch("sagemaker.core.modules.train.sm_recipes.utils._register_custom_resolvers") - @patch("sagemaker.core.modules.train.sm_recipes.utils._try_resolve_recipe") - def test_get_args_from_recipe_unresolvable( - self, mock_resolve, mock_register, mock_configure_gpu, mock_load_recipe, mock_load_cfg - ): - """Test _get_args_from_recipe raises error when recipe cannot be resolved""" - compute = Compute(instance_type="ml.p3.2xlarge", instance_count=1) - - mock_load_cfg.return_value = {} - mock_recipe = OmegaConf.create( - {"trainer": {"num_nodes": 1}, "model": {"model_type": "llama_v3"}} - ) - mock_load_recipe.return_value = mock_recipe - mock_resolve.return_value = None # Cannot resolve - - mock_configure_gpu.return_value = { - "source_code": Mock(source_dir="/tmp/source"), - "training_image": "test-image:latest", - "distributed": Mock(), - } - - with pytest.raises(RuntimeError, match="Could not resolve provided recipe"): - _get_args_from_recipe( - training_recipe="test_recipe", - compute=compute, - region_name="us-west-2", - recipe_overrides=None, - requirements=None, - ) - - def test_get_args_from_recipe_cpu_not_supported(self): - """Test _get_args_from_recipe raises error for CPU instances""" - compute = Compute(instance_type="ml.m5.xlarge", instance_count=1) - - with patch("sagemaker.core.modules.train.sm_recipes.utils._load_recipes_cfg"): - with patch( - "sagemaker.core.modules.train.sm_recipes.utils._load_base_recipe" - ) as mock_load: - mock_load.return_value = OmegaConf.create({"trainer": {"num_nodes": 1}}) - - with pytest.raises(ValueError, match="Devices of type cpu are not supported"): - _get_args_from_recipe( - training_recipe="test_recipe", - compute=compute, - region_name="us-west-2", - recipe_overrides=None, - requirements=None, - ) diff --git a/sagemaker-train/src/sagemaker/train/container_drivers/common/utils.py b/sagemaker-train/src/sagemaker/train/container_drivers/common/utils.py index 03146a3bbe..c07aa1359a 100644 --- a/sagemaker-train/src/sagemaker/train/container_drivers/common/utils.py +++ b/sagemaker-train/src/sagemaker/train/container_drivers/common/utils.py @@ -124,8 +124,10 @@ def safe_deserialize(data: Any) -> Any: This function handles the following cases: 1. If `data` is not a string, it returns the input as-is. - 2. If `data` is a JSON-encoded string, it attempts to deserialize it using `json.loads()`. - 3. If `data` is a string but cannot be decoded as JSON, it returns the original string. + 2. If `data` is a string and matches common boolean values ("true" or "false"), + it returns the corresponding boolean value (True or False). + 3. If `data` is a JSON-encoded string, it attempts to deserialize it using `json.loads()`. + 4. If `data` is a string but cannot be decoded as JSON, it returns the original string. Returns: Any: The deserialized data, or the original input if it cannot be JSON-decoded. 
@@ -133,6 +135,12 @@ def safe_deserialize(data: Any) -> Any: if not isinstance(data, str): return data + lower_data = data.lower() + if lower_data in ["true"]: + return True + if lower_data in ["false"]: + return False + try: return json.loads(data) except json.JSONDecodeError: diff --git a/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py b/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py index f7d4b978d4..4744a4d493 100644 --- a/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py +++ b/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py @@ -24,7 +24,7 @@ import omegaconf from omegaconf import OmegaConf, dictconfig -# from sagemaker.utils.image_uris import retrieve +from sagemaker.image_uris import retrieve from sagemaker.train import logger from sagemaker.train.utils import _run_clone_command_silent @@ -176,14 +176,13 @@ def _configure_gpu_args( if isinstance(gpu_image_cfg, str): training_image = gpu_image_cfg else: - # training_image = retrieve( - # gpu_image_cfg.get("framework"), - # region=region_name, - # version=gpu_image_cfg.get("version"), - # image_scope="training", - # **gpu_image_cfg.get("additional_args"), - # ) - training_image = "dummy_image" # Placeholder for actual image retrieval + training_image = retrieve( + gpu_image_cfg.get("framework"), + region=region_name, + version=gpu_image_cfg.get("version"), + image_scope="training", + **gpu_image_cfg.get("additional_args"), + ) # Setting dummy parameters for now torch_distributed = Torchrun(smp=SMP(random_seed="123456")) @@ -214,14 +213,13 @@ def _configure_trainium_args( if isinstance(neuron_image_cfg, str): training_image = neuron_image_cfg else: - # training_image = retrieve( - # neuron_image_cfg.get("framework"), - # region=region_name, - # version=neuron_image_cfg.get("version"), - # image_scope="training", - # **neuron_image_cfg.get("additional_args"), - # ) - training_image = "dummy_image" # Placeholder for actual image retrieval + training_image = retrieve( + neuron_image_cfg.get("framework"), + region=region_name, + version=neuron_image_cfg.get("version"), + image_scope="training", + **neuron_image_cfg.get("additional_args"), + ) args.update( { From 8400a3e6d7dce8978e9f1910575878592e0d08a6 Mon Sep 17 00:00:00 2001 From: jzhaoqwa <52220743+zhaoqizqwang@users.noreply.github.com> Date: Wed, 17 Dec 2025 15:29:07 -0800 Subject: [PATCH 8/8] Remove duplicate code in sagemaker/core/training --- .../src/sagemaker/core/modules/distributed.py | 4 +- .../src/sagemaker/core/training/__init__.py | 14 - .../src/sagemaker/core/training/configs.py | 333 ------------------ .../src/sagemaker/core/training/constants.py | 37 -- .../src/sagemaker/core/training/utils.py | 77 ---- .../src/sagemaker/mlops/workflow/_utils.py | 4 +- .../src/sagemaker/train/base_trainer.py | 2 +- .../src/sagemaker/train/configs.py | 302 +++++++++++++++- .../src/sagemaker/train/constants.py | 10 +- .../sagemaker/train/modules/model_trainer.py | 4 +- .../src/sagemaker/train/sm_recipes/utils.py | 2 +- sagemaker-train/src/sagemaker/train/tuner.py | 4 +- .../tests/unit/test_training_constants.py | 4 +- 13 files changed, 305 insertions(+), 492 deletions(-) delete mode 100644 sagemaker-core/src/sagemaker/core/training/__init__.py delete mode 100644 sagemaker-core/src/sagemaker/core/training/configs.py delete mode 100644 sagemaker-core/src/sagemaker/core/training/constants.py delete mode 100644 sagemaker-core/src/sagemaker/core/training/utils.py rename {sagemaker-core => sagemaker-train}/tests/unit/test_training_constants.py (96%) diff --git 
a/sagemaker-core/src/sagemaker/core/modules/distributed.py b/sagemaker-core/src/sagemaker/core/modules/distributed.py
index 21a33343c3..14c5837846 100644
--- a/sagemaker-core/src/sagemaker/core/modules/distributed.py
+++ b/sagemaker-core/src/sagemaker/core/modules/distributed.py
@@ -19,8 +19,8 @@
 from typing import Optional, Dict, Any, List
 
 from sagemaker.core.modules.utils import safe_serialize
-from sagemaker.core.training.configs import BaseConfig
-from sagemaker.core.training.constants import SM_DRIVERS_LOCAL_PATH
+from sagemaker.train.configs import BaseConfig
+from sagemaker.train.constants import SM_DRIVERS_LOCAL_PATH
 
 
 class SMP(BaseConfig):
diff --git a/sagemaker-core/src/sagemaker/core/training/__init__.py b/sagemaker-core/src/sagemaker/core/training/__init__.py
deleted file mode 100644
index 86ce9b7a0f..0000000000
--- a/sagemaker-core/src/sagemaker/core/training/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-#     http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-"""Training configuration and utilities."""
-from __future__ import absolute_import
diff --git a/sagemaker-core/src/sagemaker/core/training/configs.py b/sagemaker-core/src/sagemaker/core/training/configs.py
deleted file mode 100644
index ad2232d630..0000000000
--- a/sagemaker-core/src/sagemaker/core/training/configs.py
+++ /dev/null
@@ -1,333 +0,0 @@
-# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-#     http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-"""This module provides the configuration classes used in ``sagemaker.modules``.
-
-Some of these classes are re-exported from ``sagemaker.core.shapes``. For convenience,
-users can import these classes directly from ``sagemaker.modules.configs``.
-
-For more documentation on ``sagemaker.core.shapes``, see:
-    - https://sagemaker-core.readthedocs.io/en/stable/#sagemaker-core-shapes
-"""
-
-from __future__ import absolute_import
-
-from typing import Optional, Union
-from pydantic import BaseModel, model_validator, ConfigDict
-
-import sagemaker.core.shapes as shapes
-from sagemaker.core.helper.pipeline_variable import StrPipeVar
-
-# TODO: Can we add custom logic to some of these to set better defaults?
-from sagemaker.core.shapes import ( - StoppingCondition, - RetryStrategy, - Channel, - ShuffleConfig, - DataSource, - S3DataSource, - FileSystemDataSource, - TrainingImageConfig, - TrainingRepositoryAuthConfig, - Tag, - InfraCheckConfig, - RemoteDebugConfig, - SessionChainingConfig, - InstanceGroup, - HubAccessConfig, - ModelAccessConfig, - MetricDefinition, - DatasetSource, -) - -from sagemaker.core.training.utils import convert_unassigned_to_none - -__all__ = [ - "BaseConfig", - "SourceCode", - "StoppingCondition", - "RetryStrategy", - "OutputDataConfig", - "Channel", - "ShuffleConfig", - "DataSource", - "S3DataSource", - "FileSystemDataSource", - "TrainingImageConfig", - "TrainingRepositoryAuthConfig", - "Tag", - "InfraCheckConfig", - "RemoteDebugConfig", - "SessionChainingConfig", - "InstanceGroup", - "TensorBoardOutputConfig", - "CheckpointConfig", - "HubAccessConfig", - "ModelAccessConfig", - "Compute", - "Networking", - "InputData", - "MetricDefinition", - "DatasetSource", -] - - -class BaseConfig(BaseModel): - """BaseConfig""" - - model_config = ConfigDict(validate_assignment=True, extra="forbid") - - -class SourceCode(BaseConfig): - """SourceCode. - - The SourceCode class allows the user to specify the source code location, dependencies, - entry script, or commands to be executed in the training job container. - - Parameters: - source_dir (Optional[StrPipeVar]): - The local directory, s3 uri, or path to tar.gz file stored locally or in s3 that - contains the source code to be used in the training job container. - requirements (Optional[StrPipeVar]): - The path within ``source_dir`` to a ``requirements.txt`` file. If specified, the listed - requirements will be installed in the training job container. - entry_script (Optional[StrPipeVar]): - The path within ``source_dir`` to the entry script that will be executed in the training - job container. If not specified, command must be provided. - command (Optional[StrPipeVar]): - The command(s) to execute in the training job container. Example: "python my_script.py". - If not specified, entry_script must be provided. - """ - - source_dir: Optional[StrPipeVar] = None - requirements: Optional[StrPipeVar] = None - entry_script: Optional[StrPipeVar] = None - command: Optional[StrPipeVar] = None - - -class OutputDataConfig(shapes.OutputDataConfig): - """OutputDataConfig. - - Provides the configuration for the output data location of the training job. - - Parameters: - s3_output_path (Optional[StrPipeVar]): - The S3 URI where the output data will be stored. This is the location where the - training job will save its output data, such as model artifacts and logs. - kms_key_id (Optional[StrPipeVar]): - The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that - SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side - encryption. - compression_type (Optional[StrPipeVar]): - The model output compression type. Select None to output an uncompressed model, - recommended for large model outputs. Defaults to gzip. - """ - - s3_output_path: Optional[StrPipeVar] = None - kms_key_id: Optional[StrPipeVar] = None - compression_type: Optional[StrPipeVar] = None - - -class Compute(shapes.ResourceConfig): - """Compute. - - The Compute class is a subclass of ``sagemaker.core.shapes.ResourceConfig`` - and allows the user to specify the compute resources for the training job. - - Parameters: - instance_type (Optional[StrPipeVar]): - The ML compute instance type. 
For information about available instance types,
-            see https://aws.amazon.com/sagemaker/pricing/.
-        instance_count (Optional[int]): The number of ML compute instances to use. For distributed
-            training, provide a value greater than 1.
-        volume_size_in_gb (Optional[int]):
-            The size of the ML storage volume that you want to provision. ML storage volumes store
-            model artifacts and incremental states. Training algorithms might also use the ML
-            storage volume for scratch space. Default: 30
-        volume_kms_key_id (Optional[StrPipeVar]):
-            The Amazon Web Services KMS key that SageMaker uses to encrypt data on the storage
-            volume attached to the ML compute instance(s) that run the training job.
-        keep_alive_period_in_seconds (Optional[int]):
-            The duration of time in seconds to retain configured resources in a warm pool for
-            subsequent training jobs.
-        instance_groups (Optional[List[InstanceGroup]]):
-            A list of instance groups for heterogeneous clusters to be used in the training job.
-        training_plan_arn (Optional[StrPipeVar]):
-            The Amazon Resource Name (ARN) of the training plan to use for this resource configuration.
-        enable_managed_spot_training (Optional[bool]):
-            To train models using managed spot training, choose True. Managed spot training
-            provides a fully managed and scalable infrastructure for training machine learning
-            models. This option is useful when training jobs can be interrupted and when there
-            is flexibility when the training job is run.
-    """
-
-    volume_size_in_gb: Optional[int] = 30
-    enable_managed_spot_training: Optional[bool] = None
-
-    @model_validator(mode="after")
-    def _model_validator(self) -> "Compute":
-        """Convert Unassigned values to None."""
-        return convert_unassigned_to_none(self)
-
-    def _to_resource_config(self) -> shapes.ResourceConfig:
-        """Convert to a sagemaker.core.shapes.ResourceConfig object."""
-        compute_config_dict = self.model_dump()
-        resource_config_fields = set(shapes.ResourceConfig.__annotations__.keys())
-        filtered_dict = {
-            k: v
-            for k, v in compute_config_dict.items()
-            if k in resource_config_fields and v is not None
-        }
-        if not filtered_dict:
-            return None
-        return shapes.ResourceConfig(**filtered_dict)
-
-
-class Networking(shapes.VpcConfig):
-    """Networking.
-
-    The Networking class is a subclass of ``sagemaker.core.shapes.VpcConfig`` and
-    allows the user to specify the networking configuration for the training job.
-
-    Parameters:
-        security_group_ids (Optional[List[StrPipeVar]]):
-            The VPC security group IDs, in the form sg-xxxxxxxx. Specify the
-            security groups for the VPC that is specified in the Subnets field.
-        subnets (Optional[List[StrPipeVar]]):
-            The ID of the subnets in the VPC to which you want to connect your
-            training job or model.
-        enable_network_isolation (Optional[bool]):
-            Isolates the training container. No inbound or outbound network calls can be made,
-            except for calls between peers within a training cluster for distributed training.
-            If you enable network isolation for training jobs that are configured to use a VPC,
-            SageMaker downloads and uploads customer data and model artifacts through the
-            specified VPC, but the training container does not have network access.
-        enable_inter_container_traffic_encryption (Optional[bool]):
-            To encrypt all communications between ML compute instances in distributed training,
-            choose True. Encryption provides greater security for distributed training, but
-            training might take longer.
How long it takes depends on the amount of
-            communication between compute instances, especially if you use a deep learning
-            algorithm in distributed training.
-    """
-
-    security_group_ids: Optional[list[StrPipeVar]] = None
-    subnets: Optional[list[StrPipeVar]] = None
-    enable_network_isolation: Optional[bool] = None
-    enable_inter_container_traffic_encryption: Optional[bool] = None
-
-    @model_validator(mode="after")
-    def _model_validator(self) -> "Networking":
-        """Convert Unassigned values to None."""
-        return convert_unassigned_to_none(self)
-
-    def _to_vpc_config(self) -> shapes.VpcConfig:
-        """Convert to a sagemaker.core.shapes.VpcConfig object."""
-        compute_config_dict = self.model_dump()
-        vpc_config_fields = set(shapes.VpcConfig.__annotations__.keys())
-        filtered_dict = {
-            k: v for k, v in compute_config_dict.items() if k in vpc_config_fields and v is not None
-        }
-        if not filtered_dict:
-            return None
-        return shapes.VpcConfig(**filtered_dict)
-
-
-class InputData(BaseConfig):
-    """InputData.
-
-    This config allows the user to specify an input data source for the training job.
-
-    Will be found at ``/opt/ml/input/data/<channel_name>`` within the training container.
-    For convenience, it can be referenced inside the training container like:
-
-    .. code:: python
-
-        import os
-        input_data_dir = os.environ['SM_CHANNEL_<channel_name>']
-
-    Parameters:
-        channel_name (StrPipeVar):
-            The name of the input data source channel.
-        data_source (Union[StrPipeVar, S3DataSource, FileSystemDataSource, DatasetSource]):
-            The data source for the channel. Can be an S3 URI string, local file path string,
-            S3DataSource object, FileSystemDataSource object, DatasetSource object, or a
-            pipeline variable (Properties) from a previous step.
-        content_type (StrPipeVar):
-            The MIME type of the data.
-    """
-
-    channel_name: StrPipeVar = None
-    data_source: Union[StrPipeVar, FileSystemDataSource, S3DataSource, DatasetSource] = None
-    content_type: StrPipeVar = None
-
-
-class OutputDataConfig(shapes.OutputDataConfig):
-    """OutputDataConfig.
-
-    The OutputDataConfig class is a subclass of ``sagemaker.core.shapes.OutputDataConfig``
-    and allows the user to specify the output data configuration for the training job.
-
-    Parameters:
-        s3_output_path (Optional[StrPipeVar]):
-            The S3 URI where the output data will be stored. This is the location where the
-            training job will save its output data, such as model artifacts and logs.
-        kms_key_id (Optional[StrPipeVar]):
-            The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that
-            SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side
-            encryption.
-        compression_type (Optional[StrPipeVar]):
-            The model output compression type. Select `NONE` to output an uncompressed model,
-            recommended for large model outputs. Defaults to `GZIP`.
-    """
-
-    s3_output_path: Optional[StrPipeVar] = None
-    kms_key_id: Optional[StrPipeVar] = None
-    compression_type: Optional[StrPipeVar] = None
-
-
-class TensorBoardOutputConfig(shapes.TensorBoardOutputConfig):
-    """TensorBoardOutputConfig.
-
-    The TensorBoardOutputConfig class is a subclass of ``sagemaker.core.shapes.TensorBoardOutputConfig``
-    and allows the user to specify the storage locations for the Amazon SageMaker
-    Debugger TensorBoard.
-
-    Parameters:
-        s3_output_path (Optional[StrPipeVar]):
-            Path to Amazon S3 storage location for TensorBoard output.
-            default to
-            ``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/tensorboard-output``
-        local_path (Optional[StrPipeVar]):
-            Path to local storage location for tensorBoard output. Defaults to /opt/ml/output/tensorboard.
-    """
-
-    s3_output_path: Optional[StrPipeVar] = None
-    local_path: Optional[StrPipeVar] = "/opt/ml/output/tensorboard"
-
-
-class CheckpointConfig(shapes.CheckpointConfig):
-    """CheckpointConfig.
-
-    The CheckpointConfig class is a subclass of ``sagemaker.core.shapes.CheckpointConfig``
-    and allows the user to specify the checkpoint configuration for the training job.
-
-    Parameters:
-        s3_uri (Optional[StrPipeVar]):
-            Path to Amazon S3 storage location for the Checkpoint data. If not specified, will
-            default to
-            ``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/checkpoints``
-        local_path (Optional[StrPipeVar]):
-            The local directory where checkpoints are written. The default directory is /opt/ml/checkpoints.
-    """
-
-    s3_uri: Optional[StrPipeVar] = None
-    local_path: Optional[StrPipeVar] = "/opt/ml/checkpoints"
diff --git a/sagemaker-core/src/sagemaker/core/training/constants.py b/sagemaker-core/src/sagemaker/core/training/constants.py
deleted file mode 100644
index 76a866581e..0000000000
--- a/sagemaker-core/src/sagemaker/core/training/constants.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-#     http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-"""Constants module."""
-from __future__ import absolute_import
-import os
-
-DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge"
-
-SM_CODE = "code"
-SM_CODE_CONTAINER_PATH = "/opt/ml/input/data/code"
-
-SM_DRIVERS = "sm_drivers"
-SM_DRIVERS_CONTAINER_PATH = "/opt/ml/input/data/sm_drivers"
-SM_DRIVERS_LOCAL_PATH = os.path.join(
-    os.path.dirname(os.path.abspath(__file__)), "container_drivers"
-)
-
-SOURCE_CODE_JSON = "sourcecode.json"
-DISTRIBUTED_JSON = "distributed.json"
-TRAIN_SCRIPT = "sm_train.sh"
-
-DEFAULT_CONTAINER_ENTRYPOINT = ["/bin/bash"]
-DEFAULT_CONTAINER_ARGUMENTS = [
-    "-c",
-    f"chmod +x {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT} "
-    + f"&& {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT}",
-]
diff --git a/sagemaker-core/src/sagemaker/core/training/utils.py b/sagemaker-core/src/sagemaker/core/training/utils.py
deleted file mode 100644
index 67009c9131..0000000000
--- a/sagemaker-core/src/sagemaker/core/training/utils.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-#     http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-"""Training utilities.""" -from __future__ import absolute_import - -import os -from typing import Any, Literal -from sagemaker.core.utils.utils import Unassigned - - -def convert_unassigned_to_none(instance) -> Any: - """Convert Unassigned values to None for any instance.""" - for name, value in instance.__dict__.items(): - if isinstance(value, Unassigned): - setattr(instance, name, None) - return instance - - -def _is_valid_path(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool: - """Check if the path is a valid local path. - - Args: - path (str): Local path to validate - path_type (Optional(Literal["File", "Directory", "Any"])): The type of the path to validate. - Defaults to "Any". - - Returns: - bool: True if the path is a valid local path, False otherwise - """ - if not os.path.exists(path): - return False - - if path_type == "File": - return os.path.isfile(path) - if path_type == "Directory": - return os.path.isdir(path) - - return path_type == "Any" - - -def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool: - """Check if the path is a valid S3 URI. - - This method checks if the path is a valid S3 URI. If the path_type is specified, - it will also check if the path is a file or a directory. - This method does not check if the S3 bucket or object exists. - - Args: - path (str): S3 URI to validate - path_type (Optional(Literal["File", "Directory", "Any"])): The type of the path to validate. - Defaults to "Any". - - Returns: - bool: True if the path is a valid S3 URI, False otherwise - """ - # Check if the path is a valid S3 URI - if not path.startswith("s3://"): - return False - - if path_type == "File": - # If it's a file, it should not end with a slash - return not path.endswith("/") - if path_type == "Directory": - # If it's a directory, it should end with a slash - return path.endswith("/") - - return path_type == "Any" diff --git a/sagemaker-mlops/src/sagemaker/mlops/workflow/_utils.py b/sagemaker-mlops/src/sagemaker/mlops/workflow/_utils.py index 3b98b49028..78bb32dfd5 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/workflow/_utils.py +++ b/sagemaker-mlops/src/sagemaker/mlops/workflow/_utils.py @@ -20,7 +20,7 @@ import tempfile from typing import List, Union, Optional, TYPE_CHECKING from sagemaker.core import image_uris -from sagemaker.core.training.configs import InputData +from sagemaker.train.configs import InputData # Lazy import to avoid circular dependency if TYPE_CHECKING: pass @@ -152,7 +152,7 @@ def __init__( requirements_file = self._requirements if self._requirements and self._requirements.endswith('.txt') else None # Configure ModelTrainer components for repacking - from sagemaker.core.training.configs import SourceCode, Compute, Networking + from sagemaker.train.configs import SourceCode, Compute, Networking source_code = SourceCode( source_dir=self._source_dir, diff --git a/sagemaker-train/src/sagemaker/train/base_trainer.py b/sagemaker-train/src/sagemaker/train/base_trainer.py index 873b42f81b..f44c1c9b79 100644 --- a/sagemaker-train/src/sagemaker/train/base_trainer.py +++ b/sagemaker-train/src/sagemaker/train/base_trainer.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional, Dict, Any, List, Union from sagemaker.core.helper.session_helper import Session -from sagemaker.core.training.configs import Tag, Networking, InputData, Channel +from sagemaker.train.configs import Tag, Networking, InputData, Channel from sagemaker.core.shapes import shapes from 
diff --git a/sagemaker-train/src/sagemaker/train/configs.py b/sagemaker-train/src/sagemaker/train/configs.py
index 79b4eedc5e..c164c8ece4 100644
--- a/sagemaker-train/src/sagemaker/train/configs.py
+++ b/sagemaker-train/src/sagemaker/train/configs.py
@@ -10,22 +10,300 @@
 # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
-"""
-DEPRECATED: This module has been moved to sagemaker.core.training.configs
+"""This module provides the configuration classes used in ``sagemaker.train``.
+
+Some of these classes are re-exported from ``sagemaker.core.shapes``. For convenience,
+users can import these classes directly from ``sagemaker.train.configs``.
-
-This is a backward compatibility shim. Please update your imports to:
-    from sagemaker.core.training.configs import ...
+
+For more documentation on ``sagemaker.core.shapes``, see:
+    - https://sagemaker-core.readthedocs.io/en/stable/#sagemaker-core-shapes
 """
+
 from __future__ import absolute_import
-
-import warnings
+from typing import Optional, Union
+from pydantic import BaseModel, model_validator, ConfigDict
-
-# Backward compatibility: re-export from core
-from sagemaker.core.training.configs import *  # noqa: F401, F403
+
+import sagemaker.core.shapes as shapes
+from sagemaker.core.helper.pipeline_variable import StrPipeVar
-
-warnings.warn(
-    "sagemaker.train.configs has been moved to sagemaker.core.training.configs. "
-    "Please update your imports. This shim will be removed in a future version.",
-    DeprecationWarning,
-    stacklevel=2
+
+# TODO: Can we add custom logic to some of these to set better defaults?
+from sagemaker.core.shapes import (
+    StoppingCondition,
+    RetryStrategy,
+    Channel,
+    ShuffleConfig,
+    DataSource,
+    S3DataSource,
+    FileSystemDataSource,
+    TrainingImageConfig,
+    TrainingRepositoryAuthConfig,
+    Tag,
+    InfraCheckConfig,
+    RemoteDebugConfig,
+    SessionChainingConfig,
+    InstanceGroup,
+    HubAccessConfig,
+    ModelAccessConfig,
+    MetricDefinition,
+    DatasetSource,
 )
+
+from sagemaker.train.utils import convert_unassigned_to_none
+
+__all__ = [
+    "BaseConfig",
+    "SourceCode",
+    "StoppingCondition",
+    "RetryStrategy",
+    "OutputDataConfig",
+    "Channel",
+    "ShuffleConfig",
+    "DataSource",
+    "S3DataSource",
+    "FileSystemDataSource",
+    "TrainingImageConfig",
+    "TrainingRepositoryAuthConfig",
+    "Tag",
+    "InfraCheckConfig",
+    "RemoteDebugConfig",
+    "SessionChainingConfig",
+    "InstanceGroup",
+    "TensorBoardOutputConfig",
+    "CheckpointConfig",
+    "HubAccessConfig",
+    "ModelAccessConfig",
+    "Compute",
+    "Networking",
+    "InputData",
+    "MetricDefinition",
+    "DatasetSource",
+]
+
+
+class BaseConfig(BaseModel):
+    """BaseConfig"""
+
+    model_config = ConfigDict(validate_assignment=True, extra="forbid")
+
+
+class SourceCode(BaseConfig):
+    """SourceCode.
+
+    The SourceCode class allows the user to specify the source code location, dependencies,
+    entry script, or commands to be executed in the training job container.
+
+    Parameters:
+        source_dir (Optional[StrPipeVar]):
+            The local directory, s3 uri, or path to tar.gz file stored locally or in s3 that
+            contains the source code to be used in the training job container.
+        requirements (Optional[StrPipeVar]):
+            The path within ``source_dir`` to a ``requirements.txt`` file. If specified, the listed
+            requirements will be installed in the training job container.
+        entry_script (Optional[StrPipeVar]):
+            The path within ``source_dir`` to the entry script that will be executed in the training
+            job container. If not specified, ``command`` must be provided.
+        command (Optional[StrPipeVar]):
+            The command(s) to execute in the training job container. Example: "python my_script.py".
+            If not specified, ``entry_script`` must be provided.
+    """
+
+    source_dir: Optional[StrPipeVar] = None
+    requirements: Optional[StrPipeVar] = None
+    entry_script: Optional[StrPipeVar] = None
+    command: Optional[StrPipeVar] = None
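For reference, a hedged usage sketch of the ``SourceCode`` class added above; the directory and file names are hypothetical:

.. code-block:: python

    from sagemaker.train.configs import SourceCode

    # Hypothetical paths; per the docstring, either entry_script or command
    # must be set.
    source_code = SourceCode(
        source_dir="./my_training_code",   # local dir shipped to the container
        requirements="requirements.txt",   # installed before training starts
        entry_script="train.py",           # executed inside the container
    )
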
+
+
+class OutputDataConfig(shapes.OutputDataConfig):
+    """OutputDataConfig.
+
+    Provides the configuration for the output data location of the training job.
+
+    Parameters:
+        s3_output_path (Optional[StrPipeVar]):
+            The S3 URI where the output data will be stored. This is the location where the
+            training job will save its output data, such as model artifacts and logs.
+        kms_key_id (Optional[StrPipeVar]):
+            The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that
+            SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side
+            encryption.
+        compression_type (Optional[StrPipeVar]):
+            The model output compression type. Select ``NONE`` to output an uncompressed model,
+            recommended for large model outputs. Defaults to ``GZIP``.
+    """
+
+    s3_output_path: Optional[StrPipeVar] = None
+    kms_key_id: Optional[StrPipeVar] = None
+    compression_type: Optional[StrPipeVar] = None
+
+
+class Compute(shapes.ResourceConfig):
+    """Compute.
+
+    The Compute class is a subclass of ``sagemaker.core.shapes.ResourceConfig``
+    and allows the user to specify the compute resources for the training job.
+
+    Parameters:
+        instance_type (Optional[StrPipeVar]):
+            The ML compute instance type. For information about available instance types,
+            see https://aws.amazon.com/sagemaker/pricing/.
+        instance_count (Optional[int]): The number of ML compute instances to use. For distributed
+            training, provide a value greater than 1.
+        volume_size_in_gb (Optional[int]):
+            The size of the ML storage volume that you want to provision. ML storage volumes store
+            model artifacts and incremental states. Training algorithms might also use the ML
+            storage volume for scratch space. Default: 30
+        volume_kms_key_id (Optional[StrPipeVar]):
+            The Amazon Web Services KMS key that SageMaker uses to encrypt data on the storage
+            volume attached to the ML compute instance(s) that run the training job.
+        keep_alive_period_in_seconds (Optional[int]):
+            The duration of time in seconds to retain configured resources in a warm pool for
+            subsequent training jobs.
+        instance_groups (Optional[List[InstanceGroup]]):
+            A list of instance groups for heterogeneous clusters to be used in the training job.
+        training_plan_arn (Optional[StrPipeVar]):
+            The Amazon Resource Name (ARN) of the training plan to use for this resource configuration.
+        enable_managed_spot_training (Optional[bool]):
+            To train models using managed spot training, choose True. Managed spot training
+            provides a fully managed and scalable infrastructure for training machine learning
+            models. This option is useful when training jobs can be interrupted and when there
+            is flexibility in when the training job is run.
+    """
+
+    volume_size_in_gb: Optional[int] = 30
+    enable_managed_spot_training: Optional[bool] = None
+
+    @model_validator(mode="after")
+    def _model_validator(self) -> "Compute":
+        """Convert Unassigned values to None."""
+        return convert_unassigned_to_none(self)
+
+    def _to_resource_config(self) -> shapes.ResourceConfig:
+        """Convert to a sagemaker.core.shapes.ResourceConfig object."""
+        compute_config_dict = self.model_dump()
+        resource_config_fields = set(shapes.ResourceConfig.__annotations__.keys())
+        filtered_dict = {
+            k: v
+            for k, v in compute_config_dict.items()
+            if k in resource_config_fields and v is not None
+        }
+        if not filtered_dict:
+            return None
+        return shapes.ResourceConfig(**filtered_dict)
+ """ + + volume_size_in_gb: Optional[int] = 30 + enable_managed_spot_training: Optional[bool] = None + + @model_validator(mode="after") + def _model_validator(self) -> "Compute": + """Convert Unassigned values to None.""" + return convert_unassigned_to_none(self) + + def _to_resource_config(self) -> shapes.ResourceConfig: + """Convert to a sagemaker.core.shapes.ResourceConfig object.""" + compute_config_dict = self.model_dump() + resource_config_fields = set(shapes.ResourceConfig.__annotations__.keys()) + filtered_dict = { + k: v + for k, v in compute_config_dict.items() + if k in resource_config_fields and v is not None + } + if not filtered_dict: + return None + return shapes.ResourceConfig(**filtered_dict) + + +class Networking(shapes.VpcConfig): + """Networking. + + The Networking class is a subclass of ``sagemaker.core.shapes.VpcConfig`` and + allows the user to specify the networking configuration for the training job. + + Parameters: + security_group_ids (Optional[List[StrPipeVar]]): + The VPC security group IDs, in the form sg-xxxxxxxx. Specify the + security groups for the VPC that is specified in the Subnets field. + subnets (Optional[List[StrPipeVar]]): + The ID of the subnets in the VPC to which you want to connect your + training job or model. + enable_network_isolation (Optional[bool]): + Isolates the training container. No inbound or outbound network calls can be made, + except for calls between peers within a training cluster for distributed training. + If you enable network isolation for training jobs that are configured to use a VPC, + SageMaker downloads and uploads customer data and model artifacts through the + specified VPC, but the training container does not have network access. + enable_inter_container_traffic_encryption (Optional[bool]): + To encrypt all communications between ML compute instances in distributed training + choose True. Encryption provides greater security for distributed training, but + training might take longer. How long it takes depends on the amount of + communication between compute instances, especially if you use a deep learning + algorithm in distributed training. + """ + + security_group_ids: Optional[list[StrPipeVar]] = None + subnets: Optional[list[StrPipeVar]] = None + enable_network_isolation: Optional[bool] = None + enable_inter_container_traffic_encryption: Optional[bool] = None + + @model_validator(mode="after") + def _model_validator(self) -> "Networking": + """Convert Unassigned values to None.""" + return convert_unassigned_to_none(self) + + def _to_vpc_config(self) -> shapes.VpcConfig: + """Convert to a sagemaker.core.shapes.VpcConfig object.""" + compute_config_dict = self.model_dump() + vpc_config_fields = set(shapes.VpcConfig.__annotations__.keys()) + filtered_dict = { + k: v for k, v in compute_config_dict.items() if k in vpc_config_fields and v is not None + } + if not filtered_dict: + return None + return shapes.VpcConfig(**filtered_dict) + + +class InputData(BaseConfig): + """InputData. + + This config allows the user to specify an input data source for the training job. + + Will be found at ``/opt/ml/input/data/`` within the training container. + For convience, can be referenced inside the training container like: + + .. code:: python + + import os + input_data_dir = os.environ['SM_CHANNEL_'] + + Parameters: + channel_name (StrPipeVar): + The name of the input data source channel. + data_source (Union[StrPipeVar, S3DataSource, FileSystemDataSource, DatasetSource]): + The data source for the channel. 
+
+
+class InputData(BaseConfig):
+    """InputData.
+
+    This config allows the user to specify an input data source for the training job.
+
+    Will be found at ``/opt/ml/input/data/<channel_name>`` within the training container.
+    For convenience, can be referenced inside the training container like:
+
+    .. code:: python
+
+        import os
+        input_data_dir = os.environ['SM_CHANNEL_<channel_name>']
+
+    Parameters:
+        channel_name (StrPipeVar):
+            The name of the input data source channel.
+        data_source (Union[StrPipeVar, S3DataSource, FileSystemDataSource, DatasetSource]):
+            The data source for the channel. Can be an S3 URI string, local file path string,
+            S3DataSource object, FileSystemDataSource object, DatasetSource object, or a
+            pipeline variable (Properties) from a previous step.
+        content_type (StrPipeVar):
+            The MIME type of the data.
+    """
+
+    channel_name: StrPipeVar = None
+    data_source: Union[StrPipeVar, FileSystemDataSource, S3DataSource, DatasetSource] = None
+    content_type: StrPipeVar = None
+
+
+class TensorBoardOutputConfig(shapes.TensorBoardOutputConfig):
+    """TensorBoardOutputConfig.
+
+    The TensorBoardOutputConfig class is a subclass of ``sagemaker.core.shapes.TensorBoardOutputConfig``
+    and allows the user to specify the storage locations for the Amazon SageMaker
+    Debugger TensorBoard.
+
+    Parameters:
+        s3_output_path (Optional[StrPipeVar]):
+            Path to Amazon S3 storage location for TensorBoard output. If not specified, will
+            default to
+            ``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/tensorboard-output``
+        local_path (Optional[StrPipeVar]):
+            Path to local storage location for TensorBoard output. Defaults to /opt/ml/output/tensorboard.
+    """
+
+    s3_output_path: Optional[StrPipeVar] = None
+    local_path: Optional[StrPipeVar] = "/opt/ml/output/tensorboard"
+
+
+class CheckpointConfig(shapes.CheckpointConfig):
+    """CheckpointConfig.
+
+    The CheckpointConfig class is a subclass of ``sagemaker.core.shapes.CheckpointConfig``
+    and allows the user to specify the checkpoint configuration for the training job.
+
+    Parameters:
+        s3_uri (Optional[StrPipeVar]):
+            Path to Amazon S3 storage location for the Checkpoint data. If not specified, will
+            default to
+            ``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/checkpoints``
+        local_path (Optional[StrPipeVar]):
+            The local directory where checkpoints are written. The default directory is /opt/ml/checkpoints.
+    """
+
+    s3_uri: Optional[StrPipeVar] = None
+    local_path: Optional[StrPipeVar] = "/opt/ml/checkpoints"
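Rounding out the new module, a sketch that declares an ``InputData`` channel and reads it back via the container-side environment variable described in its docstring; the bucket and channel names are examples:

.. code-block:: python

    from sagemaker.train.configs import InputData

    train_channel = InputData(
        channel_name="train",
        data_source="s3://my-bucket/datasets/train/",  # example S3 prefix
        content_type="text/csv",
    )

    # Inside the training container this channel is materialized at
    # /opt/ml/input/data/train and exposed as SM_CHANNEL_TRAIN:
    #
    #     import os
    #     train_dir = os.environ["SM_CHANNEL_TRAIN"]
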
-""" +"""Constants module.""" from __future__ import absolute_import - import os +DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge" + SM_CODE = "code" SM_CODE_CONTAINER_PATH = "/opt/ml/input/data/code" diff --git a/sagemaker-train/src/sagemaker/train/modules/model_trainer.py b/sagemaker-train/src/sagemaker/train/modules/model_trainer.py index 32af7f6c88..b7b0bd8bf5 100644 --- a/sagemaker-train/src/sagemaker/train/modules/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/modules/model_trainer.py @@ -23,8 +23,8 @@ from graphene.utils.str_converters import to_camel_case, to_snake_case -from sagemaker.core.training.configs import Compute, Networking, InputData, SourceCode -from sagemaker.core.training.constants import DEFAULT_INSTANCE_TYPE, DEFAULT_CONTAINER_ENTRYPOINT, \ +from sagemaker.train.configs import Compute, Networking, InputData, SourceCode +from sagemaker.train.constants import DEFAULT_INSTANCE_TYPE, DEFAULT_CONTAINER_ENTRYPOINT, \ DEFAULT_CONTAINER_ARGUMENTS, SM_DRIVERS, SM_CODE_CONTAINER_PATH, TRAIN_SCRIPT, DISTRIBUTED_JSON, SOURCE_CODE_JSON, \ SM_CODE, SM_DRIVERS_LOCAL_PATH from sagemaker.train.distributed import DistributedConfig, Torchrun diff --git a/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py b/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py index 4744a4d493..09af11549f 100644 --- a/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py +++ b/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py @@ -129,7 +129,7 @@ def _get_trainining_recipe_gpu_model_name_and_script(model_type: str): """Get the model base name and script for the training recipe.""" model_type_to_script = { - "llama_v3": ("llama", "llama_pretrain.py"), + "llama": ("llama", "llama_pretrain.py"), "mistral": ("mistral", "mistral_pretrain.py"), "mixtral": ("mixtral", "mixtral_pretrain.py"), "deepseek": ("deepseek", "deepseek_pretrain.py"), diff --git a/sagemaker-train/src/sagemaker/train/tuner.py b/sagemaker-train/src/sagemaker/train/tuner.py index d1af08e2f1..aa54e36f3d 100644 --- a/sagemaker-train/src/sagemaker/train/tuner.py +++ b/sagemaker-train/src/sagemaker/train/tuner.py @@ -52,8 +52,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from sagemaker.train.model_trainer import ModelTrainer -from sagemaker.core.training.configs import InputData -from sagemaker.core.training.utils import _is_valid_s3_uri +from sagemaker.train.configs import InputData +from sagemaker.train.utils import _is_valid_s3_uri HYPERPARAMETER_TUNING_JOB_NAME = "HyperParameterTuningJobName" PARENT_HYPERPARAMETER_TUNING_JOBS = "ParentHyperParameterTuningJobs" diff --git a/sagemaker-core/tests/unit/test_training_constants.py b/sagemaker-train/tests/unit/test_training_constants.py similarity index 96% rename from sagemaker-core/tests/unit/test_training_constants.py rename to sagemaker-train/tests/unit/test_training_constants.py index ad096fac37..4ff66e2c24 100644 --- a/sagemaker-core/tests/unit/test_training_constants.py +++ b/sagemaker-train/tests/unit/test_training_constants.py @@ -10,11 +10,11 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Unit tests for sagemaker.core.training.constants module.""" +"""Unit tests for sagemaker.train.constants module.""" from __future__ import absolute_import import os -from sagemaker.core.training import constants +from sagemaker.train import constants class TestTrainingConstants: