diff --git a/sagemaker-core/src/sagemaker/core/base_deserializers.py b/sagemaker-core/src/sagemaker/core/base_deserializers.py deleted file mode 100644 index 69c5be63e4..0000000000 --- a/sagemaker-core/src/sagemaker/core/base_deserializers.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Backward compatibility module for base deserializers. - -This module provides backward compatibility for code importing from -sagemaker.core.base_deserializers. The actual implementation is in -sagemaker.core.deserializers. - -.. deprecated:: 3.0.0 - Use :mod:`sagemaker.core.deserializers` instead. -""" -from __future__ import absolute_import - -import warnings - -# Re-export all deserializers from the correct location -from sagemaker.core.deserializers import * # noqa: F401, F403 - -# Issue deprecation warning -warnings.warn( - "Importing from sagemaker.core.base_deserializers is deprecated. " - "Use sagemaker.core.deserializers instead.", - DeprecationWarning, - stacklevel=2, -) diff --git a/sagemaker-core/src/sagemaker/core/base_serializers.py b/sagemaker-core/src/sagemaker/core/base_serializers.py deleted file mode 100644 index ea9a665866..0000000000 --- a/sagemaker-core/src/sagemaker/core/base_serializers.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Backward compatibility module for base serializers. - -This module provides backward compatibility for code importing from -sagemaker.core.base_serializers. The actual implementation is in -sagemaker.core.serializers. - -.. deprecated:: 3.0.0 - Use :mod:`sagemaker.core.serializers` instead. -""" -from __future__ import absolute_import - -import warnings - -# Re-export all serializers from the correct location -from sagemaker.core.serializers import * # noqa: F401, F403 - -# Issue deprecation warning -warnings.warn( - "Importing from sagemaker.core.base_serializers is deprecated. 
" - "Use sagemaker.core.serializers instead.", - DeprecationWarning, - stacklevel=2, -) diff --git a/sagemaker-core/src/sagemaker/core/helper/session_helper.py b/sagemaker-core/src/sagemaker/core/helper/session_helper.py index bc35c7eb9d..c202efd5c8 100644 --- a/sagemaker-core/src/sagemaker/core/helper/session_helper.py +++ b/sagemaker-core/src/sagemaker/core/helper/session_helper.py @@ -1669,22 +1669,10 @@ def _create_model_request( container_defs, vpc_config=None, enable_network_isolation=False, - primary_container=None, tags=None, ): # pylint: disable=redefined-outer-name """Placeholder docstring""" - if container_defs and primary_container: - raise ValueError("Both container_defs and primary_container can not be passed as input") - - if primary_container: - msg = ( - "primary_container is going to be deprecated in a future release. Please use " - "container_defs instead." - ) - warnings.warn(msg, DeprecationWarning) - container_defs = primary_container - role = self.expand_role(role) if isinstance(container_defs, list): @@ -1726,7 +1714,6 @@ def create_model( container_defs=None, vpc_config=None, enable_network_isolation=None, - primary_container=None, tags=None, ): """Create an Amazon SageMaker ``Model``. @@ -1754,11 +1741,6 @@ def create_model( * 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. enable_network_isolation (bool): Whether the model requires network isolation or not. - primary_container (str or dict[str, str]): Docker image which defines the inference - code. You can also specify the return value of ``sagemaker.container_def()``, - which is used to create more advanced container configurations, including model - containers which need artifacts from S3. This field is deprecated, please use - container_defs instead. tags(Optional[Tags]): Optional. The list of tags to add to the model. Example: @@ -1794,7 +1776,6 @@ def create_model( container_defs=container_defs, vpc_config=vpc_config, enable_network_isolation=enable_network_isolation, - primary_container=primary_container, tags=tags, ) diff --git a/sagemaker-core/src/sagemaker/core/lineage/action.py b/sagemaker-core/src/sagemaker/core/lineage/action.py index d6b06196f5..0a0686074d 100644 --- a/sagemaker-core/src/sagemaker/core/lineage/action.py +++ b/sagemaker-core/src/sagemaker/core/lineage/action.py @@ -38,7 +38,7 @@ class Action(_base_types.Record): Examples: .. code-block:: python - from sagemaker.lineage import action + from sagemaker.core.lineage import action my_action = action.Action.create( action_name='MyAction', diff --git a/sagemaker-core/src/sagemaker/core/lineage/artifact.py b/sagemaker-core/src/sagemaker/core/lineage/artifact.py index bc9522069f..8af1420659 100644 --- a/sagemaker-core/src/sagemaker/core/lineage/artifact.py +++ b/sagemaker-core/src/sagemaker/core/lineage/artifact.py @@ -42,7 +42,7 @@ class Artifact(_base_types.Record): Examples: .. code-block:: python - from sagemaker.lineage import artifact + from sagemaker.core.lineage import artifact my_artifact = artifact.Artifact.create( artifact_name='MyArtifact', diff --git a/sagemaker-core/src/sagemaker/core/lineage/association.py b/sagemaker-core/src/sagemaker/core/lineage/association.py index f175622cd8..60e5c028ea 100644 --- a/sagemaker-core/src/sagemaker/core/lineage/association.py +++ b/sagemaker-core/src/sagemaker/core/lineage/association.py @@ -31,7 +31,7 @@ class Association(_base_types.Record): Examples: .. 
code-block:: python - from sagemaker.lineage import association + from sagemaker.core.lineage import association my_association = association.Association.create( source_arn=artifact_arn, diff --git a/sagemaker-core/src/sagemaker/core/lineage/context.py b/sagemaker-core/src/sagemaker/core/lineage/context.py index 7a086095fa..7017d77992 100644 --- a/sagemaker-core/src/sagemaker/core/lineage/context.py +++ b/sagemaker-core/src/sagemaker/core/lineage/context.py @@ -136,7 +136,7 @@ def load(cls, context_name: str, sagemaker_session=None) -> "Context": Examples: .. code-block:: python - from sagemaker.lineage import context + from sagemaker.core.lineage import context my_context = context.Context.create( context_name='MyContext', diff --git a/sagemaker-core/src/sagemaker/core/lineage/query.py b/sagemaker-core/src/sagemaker/core/lineage/query.py index a539cf621c..8a2d14a642 100644 --- a/sagemaker-core/src/sagemaker/core/lineage/query.py +++ b/sagemaker-core/src/sagemaker/core/lineage/query.py @@ -194,9 +194,9 @@ def to_lineage_object(self): A ``Vertex`` object to its corresponding ``Artifact``,``Action``, ``Context`` or ``TrialComponent`` object. """ - from sagemaker.lineage.context import Context, EndpointContext - from sagemaker.lineage.action import Action - from sagemaker.lineage.lineage_trial_component import LineageTrialComponent + from sagemaker.core.lineage.context import Context, EndpointContext + from sagemaker.core.lineage.action import Action + from sagemaker.core.lineage.lineage_trial_component import LineageTrialComponent if self.lineage_entity == LineageEntityEnum.CONTEXT.value: resource_name = get_resource_name_from_arn(self.arn) @@ -221,8 +221,8 @@ def to_lineage_object(self): def _artifact_to_lineage_object(self): """Convert the ``Vertex`` object to its corresponding ``Artifact``.""" - from sagemaker.lineage.artifact import Artifact, ModelArtifact, ImageArtifact - from sagemaker.lineage.artifact import DatasetArtifact + from sagemaker.core.lineage.artifact import Artifact, ModelArtifact, ImageArtifact + from sagemaker.core.lineage.artifact import DatasetArtifact if self.lineage_source == LineageSourceEnum.MODEL.value: return ModelArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session) diff --git a/sagemaker-core/src/sagemaker/core/modules/distributed.py b/sagemaker-core/src/sagemaker/core/modules/distributed.py index 21a33343c3..14c5837846 100644 --- a/sagemaker-core/src/sagemaker/core/modules/distributed.py +++ b/sagemaker-core/src/sagemaker/core/modules/distributed.py @@ -19,8 +19,8 @@ from typing import Optional, Dict, Any, List from sagemaker.core.modules.utils import safe_serialize -from sagemaker.core.training.configs import BaseConfig -from sagemaker.core.training.constants import SM_DRIVERS_LOCAL_PATH +from sagemaker.train.configs import BaseConfig +from sagemaker.train.constants import SM_DRIVERS_LOCAL_PATH class SMP(BaseConfig): diff --git a/sagemaker-core/src/sagemaker/core/modules/train/__init__.py b/sagemaker-core/src/sagemaker/core/modules/train/__init__.py deleted file mode 100644 index c5b5d01ed4..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. 
This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Sagemaker modules train directory.""" -from __future__ import absolute_import diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/__init__.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/__init__.py deleted file mode 100644 index 864f3663b8..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Sagemaker modules container drivers directory.""" -from __future__ import absolute_import diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/__init__.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/__init__.py deleted file mode 100644 index aab88c6b97..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Sagemaker modules container drivers - common directory.""" -from __future__ import absolute_import diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/utils.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/utils.py deleted file mode 100644 index c07aa1359a..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/utils.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
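
The docstring updates in the lineage modules above, together with the removal of the `base_serializers`/`base_deserializers` shims, all point callers at the canonical `sagemaker.core` module paths. A minimal, illustrative migration — class and argument names beyond those shown in the diff are assumptions, values are examples:

.. code-block:: python

    # Old top-level paths -> new sagemaker.core paths (per the hunks above).
    from sagemaker.core.lineage import artifact            # was: from sagemaker.lineage import artifact
    from sagemaker.core.serializers import JSONSerializer  # was: sagemaker.core.base_serializers

    my_artifact = artifact.Artifact.create(
        artifact_name="MyArtifact",
        source_uri="s3://example-bucket/data",  # example value
        artifact_type="Dataset",                # example value
    )
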
-"""This module provides utility functions for the container drivers.""" -from __future__ import absolute_import - -import os -import logging -import sys -import subprocess -import traceback -import json - -from typing import List, Dict, Any, Tuple, IO, Optional - -# Initialize logger -SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20) -logger = logging.getLogger(__name__) -console_handler = logging.StreamHandler(sys.stdout) -logger.addHandler(console_handler) -logger.setLevel(int(SM_LOG_LEVEL)) - -FAILURE_FILE = "/opt/ml/output/failure" -DEFAULT_FAILURE_MESSAGE = """ -Training Execution failed. -For more details, see CloudWatch logs at 'aws/sagemaker/TrainingJobs'. -TrainingJob - {training_job_name} -""" - -USER_CODE_PATH = "/opt/ml/input/data/code" -SOURCE_CODE_JSON = "/opt/ml/input/data/sm_drivers/sourcecode.json" -DISTRIBUTED_JSON = "/opt/ml/input/data/sm_drivers/distributed.json" - -HYPERPARAMETERS_JSON = "/opt/ml/input/config/hyperparameters.json" - -SM_EFA_NCCL_INSTANCES = [ - "ml.g4dn.8xlarge", - "ml.g4dn.12xlarge", - "ml.g5.48xlarge", - "ml.p3dn.24xlarge", - "ml.p4d.24xlarge", - "ml.p4de.24xlarge", - "ml.p5.48xlarge", - "ml.trn1.32xlarge", -] - -SM_EFA_RDMA_INSTANCES = [ - "ml.p4d.24xlarge", - "ml.p4de.24xlarge", - "ml.trn1.32xlarge", -] - - -def write_failure_file(message: Optional[str] = None): - """Write a failure file with the message.""" - if message is None: - message = DEFAULT_FAILURE_MESSAGE.format(training_job_name=os.environ["TRAINING_JOB_NAME"]) - if not os.path.exists(FAILURE_FILE): - with open(FAILURE_FILE, "w") as f: - f.write(message) - - -def read_source_code_json(source_code_json: Dict[str, Any] = SOURCE_CODE_JSON): - """Read the source code config json file.""" - try: - with open(source_code_json, "r") as f: - source_code_dict = json.load(f) or {} - except FileNotFoundError: - source_code_dict = {} - return source_code_dict - - -def read_distributed_json(distributed_json: Dict[str, Any] = DISTRIBUTED_JSON): - """Read the distribution config json file.""" - try: - with open(distributed_json, "r") as f: - distributed_dict = json.load(f) or {} - except FileNotFoundError: - distributed_dict = {} - return distributed_dict - - -def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAMETERS_JSON): - """Read the hyperparameters config json file.""" - try: - with open(hyperparameters_json, "r") as f: - hyperparameters_dict = json.load(f) or {} - except FileNotFoundError: - hyperparameters_dict = {} - return hyperparameters_dict - - -def get_process_count(process_count: Optional[int] = None) -> int: - """Get the number of processes to run on each node in the training job.""" - return ( - process_count - or int(os.environ.get("SM_NUM_GPUS", 0)) - or int(os.environ.get("SM_NUM_NEURONS", 0)) - or 1 - ) - - -def hyperparameters_to_cli_args(hyperparameters: Dict[str, Any]) -> List[str]: - """Convert the hyperparameters to CLI arguments.""" - cli_args = [] - for key, value in hyperparameters.items(): - value = safe_deserialize(value) - cli_args.extend([f"--{key}", safe_serialize(value)]) - - return cli_args - - -def safe_deserialize(data: Any) -> Any: - """Safely deserialize data from a JSON string. - - This function handles the following cases: - 1. If `data` is not a string, it returns the input as-is. - 2. If `data` is a string and matches common boolean values ("true" or "false"), - it returns the corresponding boolean value (True or False). - 3. If `data` is a JSON-encoded string, it attempts to deserialize it using `json.loads()`. - 4. 
If `data` is a string but cannot be decoded as JSON, it returns the original string. - - Returns: - Any: The deserialized data, or the original input if it cannot be JSON-decoded. - """ - if not isinstance(data, str): - return data - - lower_data = data.lower() - if lower_data in ["true"]: - return True - if lower_data in ["false"]: - return False - - try: - return json.loads(data) - except json.JSONDecodeError: - return data - - -def safe_serialize(data): - """Serialize the data without wrapping strings in quotes. - - This function handles the following cases: - 1. If `data` is a string, it returns the string as-is without wrapping in quotes. - 2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns - the JSON-encoded string using `json.dumps()`. - 3. If `data` cannot be serialized (e.g., a custom object), it returns the string - representation of the data using `str(data)`. - - Args: - data (Any): The data to serialize. - - Returns: - str: The serialized JSON-compatible string or the string representation of the input. - """ - if isinstance(data, str): - return data - try: - return json.dumps(data) - except TypeError: - return str(data) - - -def get_python_executable() -> str: - """Get the python executable path.""" - return sys.executable - - -def log_subprocess_output(pipe: IO[bytes]): - """Log the output from the subprocess.""" - for line in iter(pipe.readline, b""): - logger.info(line.decode("utf-8").strip()) - - -def execute_commands(commands: List[str]) -> Tuple[int, str]: - """Execute the provided commands and return exit code with failure traceback if any.""" - try: - process = subprocess.Popen( - commands, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - ) - with process.stdout: - log_subprocess_output(process.stdout) - exitcode = process.wait() - if exitcode != 0: - raise subprocess.CalledProcessError(exitcode, commands) - return exitcode, "" - except subprocess.CalledProcessError as e: - # Capture the traceback in case of failure - error_traceback = traceback.format_exc() - print(f"Command failed with exit code {e.returncode}. Traceback: {error_traceback}") - return e.returncode, error_traceback - - -def is_worker_node() -> bool: - """Check if the current node is a worker node.""" - return os.environ.get("SM_CURRENT_HOST") != os.environ.get("SM_MASTER_ADDR") - - -def is_master_node() -> bool: - """Check if the current node is the master node.""" - return os.environ.get("SM_CURRENT_HOST") == os.environ.get("SM_MASTER_ADDR") diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py deleted file mode 100644 index a44e7e81a9..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
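
For reference, the hyperparameter helpers in the `common/utils.py` module deleted above perform a deserialize/serialize round trip before handing values to the entry script. A self-contained sketch of that behaviour (not the deleted module itself; the sample hyperparameters are invented):

.. code-block:: python

    import json

    def _deserialize(value: str):
        # "true"/"false" become booleans, JSON text becomes objects,
        # anything else is passed through unchanged.
        if value.lower() == "true":
            return True
        if value.lower() == "false":
            return False
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return value

    def _serialize(value):
        # Plain strings stay unquoted; other values are JSON-encoded.
        return value if isinstance(value, str) else json.dumps(value)

    raw_hps = {"epochs": "10", "use_amp": "true", "optimizer": '{"name": "adam", "lr": 0.001}'}
    cli_args = []
    for key, value in raw_hps.items():
        cli_args.extend([f"--{key}", _serialize(_deserialize(value))])

    print(cli_args)
    # ['--epochs', '10', '--use_amp', 'true',
    #  '--optimizer', '{"name": "adam", "lr": 0.001}']
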
-"""Sagemaker modules container drivers - drivers directory.""" -from __future__ import absolute_import diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py deleted file mode 100644 index 0b086a8e4f..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""This module is the entry point for the Basic Script Driver.""" -from __future__ import absolute_import - -import os -import sys -import json -import shlex - -from pathlib import Path -from typing import List - -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 - logger, - get_python_executable, - execute_commands, - write_failure_file, - hyperparameters_to_cli_args, -) - - -def create_commands() -> List[str]: - """Create the commands to execute.""" - entry_script = os.environ["SM_ENTRY_SCRIPT"] - hyperparameters = json.loads(os.environ["SM_HPS"]) - python_executable = get_python_executable() - - args = hyperparameters_to_cli_args(hyperparameters) - if entry_script.endswith(".py"): - commands = [python_executable, entry_script] - commands += args - elif entry_script.endswith(".sh"): - args_str = " ".join(shlex.quote(arg) for arg in args) - commands = [ - "/bin/sh", - "-c", - f"chmod +x {entry_script} && ./{entry_script} {args_str}", - ] - else: - raise ValueError( - f"Unsupported entry script type: {entry_script}. Only .py and .sh are supported." - ) - return commands - - -def main(): - """Main function for the Basic Script Driver. - - This function is the entry point for the Basic Script Driver. - - Execution Lifecycle: - 1. Read the source code and hyperparameters JSON files. - 2. Set hyperparameters as command line arguments. - 3. Create the commands to execute. - 4. Execute the commands. - """ - - cmd = create_commands() - - logger.info(f"Executing command: {' '.join(cmd)}") - exit_code, traceback = execute_commands(cmd) - if exit_code != 0: - write_failure_file(traceback) - sys.exit(exit_code) - - -if __name__ == "__main__": - main() diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py deleted file mode 100644 index 7d991e30da..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. 
A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""This module is the entry point for the MPI driver script.""" -from __future__ import absolute_import - -import os -import sys -import json -from pathlib import Path - -try: - from mpi_utils import ( - start_sshd_daemon, - bootstrap_master_node, - bootstrap_worker_node, - get_mpirun_command, - write_status_file_to_workers, - write_env_vars_to_file, - ) -except ImportError: - # mpi_utils is an optional external dependency for MPI distributed training - # If not available, provide stub functions that raise helpful errors - def _mpi_not_available(*args, **kwargs): - raise ImportError( - "MPI distributed training requires the 'mpi_utils' package. " - "Please install it to use MPI-based distributed training." - ) - - start_sshd_daemon = _mpi_not_available - bootstrap_master_node = _mpi_not_available - bootstrap_worker_node = _mpi_not_available - get_mpirun_command = _mpi_not_available - write_status_file_to_workers = _mpi_not_available - write_env_vars_to_file = _mpi_not_available - - -sys.path.insert(0, str(Path(__file__).parent.parent)) -from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 - logger, - hyperparameters_to_cli_args, - get_process_count, - execute_commands, - write_failure_file, -) - - -def main(): - """Main function for the MPI driver script. - - The MPI Dirver is responsible for setting up the MPI environment, - generating the correct mpi commands, and launching the MPI job. - - Execution Lifecycle: - 1. Setup General Environment Variables at /etc/environment - 2. Start SSHD Daemon - 3. Bootstrap Worker Nodes - a. Wait to establish connection with Master Node - b. Wait for Master Node to write status file - 4. Bootstrap Master Node - a. Wait to establish connection with Worker Nodes - b. Generate MPI Command - c. Execute MPI Command with user script provided in `entry_script` - d. Write status file to Worker Nodes - 5. 
Exit - - """ - entry_script = os.environ["SM_ENTRY_SCRIPT"] - distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) - hyperparameters = json.loads(os.environ["SM_HPS"]) - - sm_current_host = os.environ["SM_CURRENT_HOST"] - sm_hosts = json.loads(os.environ["SM_HOSTS"]) - sm_master_addr = os.environ["SM_MASTER_ADDR"] - - write_env_vars_to_file() - start_sshd_daemon() - - if sm_current_host != sm_master_addr: - bootstrap_worker_node(sm_master_addr) - else: - worker_hosts = [host for host in sm_hosts if host != sm_master_addr] - bootstrap_master_node(worker_hosts) - - host_list = json.loads(os.environ["SM_HOSTS"]) - host_count = int(os.environ["SM_HOST_COUNT"]) - process_count = int(distributed_config["process_count_per_node"] or 0) - process_count = get_process_count(process_count) - - if process_count > 1: - host_list = ["{}:{}".format(host, process_count) for host in host_list] - - mpi_command = get_mpirun_command( - host_count=host_count, - host_list=host_list, - num_processes=process_count, - additional_options=distributed_config["mpi_additional_options"] or [], - entry_script_path=entry_script, - ) - - args = hyperparameters_to_cli_args(hyperparameters) - mpi_command += args - - logger.info(f"Executing command: {' '.join(mpi_command)}") - exit_code, error_traceback = execute_commands(mpi_command) - write_status_file_to_workers(worker_hosts) - - if exit_code != 0: - write_failure_file(error_traceback) - sys.exit(exit_code) - - -if __name__ == "__main__": - main() diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py deleted file mode 100644 index ec9e1fcef9..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
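
The MPI driver above derives its process layout from the container environment before building the `mpirun` command. A hedged sketch of that calculation, with example host names and an assumed 8-GPU instance:

.. code-block:: python

    import os

    # Assumed environment, for illustration only.
    os.environ.setdefault("SM_NUM_GPUS", "8")

    def get_process_count(process_count=None) -> int:
        # Mirrors the deleted common/utils.py helper: explicit value,
        # then GPUs, then Neuron cores, then a single process.
        return (
            process_count
            or int(os.environ.get("SM_NUM_GPUS", 0))
            or int(os.environ.get("SM_NUM_NEURONS", 0))
            or 1
        )

    hosts = ["algo-1", "algo-2"]
    process_count = get_process_count()               # -> 8
    host_list = [f"{host}:{process_count}" for host in hosts]
    print(host_list)                                  # ['algo-1:8', 'algo-2:8']
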
-"""This module provides mpi related utility functions for the container drivers.""" -from __future__ import absolute_import - -import os -import sys -import subprocess -import time - -from pathlib import Path -from typing import List - -import paramiko - -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 - SM_EFA_NCCL_INSTANCES, - SM_EFA_RDMA_INSTANCES, - get_python_executable, - logger, -) - -FINISHED_STATUS_FILE = "/tmp/done.algo-1" -READY_FILE = "/tmp/ready.%s" -DEFAULT_SSH_PORT = 22 - - -def _write_file_to_host(host: str, status_file: str) -> bool: - """Write the a file to the provided host.""" - try: - logger.info(f"Writing {status_file} to {host}") - subprocess.run( - ["ssh", host, "touch", f"{status_file}"], - capture_output=True, - text=True, - check=True, - ) - logger.info("Finished writing status file") - return True - except subprocess.CalledProcessError: - logger.info(f"Cannot connect to {host}") - return False - - -def write_status_file_to_workers(worker_hosts: List[str], status_file: str = FINISHED_STATUS_FILE): - """Write the status file to all worker nodes.""" - for worker in worker_hosts: - retry = 0 - while not _write_file_to_host(worker, status_file): - time.sleep(5) - retry += 1 - if retry > 5: - raise TimeoutError(f"Timed out waiting for {worker} to be reachable.") - logger.info(f"Retrying to write status file to {worker}") - - -def _wait_for_status_file(status_file: str): - """Wait for the status file to be created.""" - logger.info(f"Waiting for status file {status_file}") - while not os.path.exists(status_file): - time.sleep(30) - logger.info(f"Found status file {status_file}") - - -def start_sshd_daemon(): - """Start the SSH daemon on the current node.""" - sshd_executable = "/usr/sbin/sshd" - - if not os.path.exists(sshd_executable): - raise RuntimeError("SSH daemon not found.") - - # Start the sshd in daemon mode (-D) - subprocess.Popen([sshd_executable, "-D"]) - logger.info("Started SSH daemon.") - - -class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy): - """Class to handle host key policy for SageMaker distributed training SSH connections. - - Example: - >>> client = paramiko.SSHClient() - >>> client.set_missing_host_key_policy(CustomHostKeyPolicy()) - >>> # Will succeed for SageMaker algorithm containers - >>> client.connect('algo-1234.internal') - >>> # Will raise SSHException for other unknown hosts - >>> client.connect('unknown-host') # raises SSHException - """ - - def missing_host_key(self, client, hostname, key): - """Accept host keys for algo-* hostnames, reject others. 
- - Args: - client: The SSHClient instance - hostname: The hostname attempting to connect - key: The host key - - Raises: - paramiko.SSHException: If hostname doesn't match algo-* pattern - """ - if hostname.startswith("algo-"): - client.get_host_keys().add(hostname, key.get_name(), key) - return - raise paramiko.SSHException(f"Unknown host key for {hostname}") - - -def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool: - """Check if the connection to the provided host and port is possible.""" - try: - logger.debug("Testing connection to host %s", host) - with paramiko.SSHClient() as client: - client.load_system_host_keys() - client.set_missing_host_key_policy(CustomHostKeyPolicy()) - client.connect(host, port=port) - logger.info("Can connect to host %s", host) - return True - except Exception as e: # pylint: disable=W0703 - logger.info("Cannot connect to host %s", host) - logger.debug(f"Connection failed with exception: {e}") - return False - - -def _wait_for_workers(worker_hosts: List[str], port: int = DEFAULT_SSH_PORT, timeout: int = 300): - """Master node waits until it can connect to all worker nodes.""" - start_time = time.time() - if not worker_hosts: - logger.info("No worker nodes to connect to.") - return - - while True: - logger.info("Master is attempting to connect to all workers...") - all_workers_connected = all( - _can_connect(worker, port) and os.path.exists(READY_FILE % worker) - for worker in worker_hosts - ) - - if all_workers_connected: - logger.info("Master can connect to all worker nodes.") - break - if time.time() - start_time > timeout: - raise TimeoutError("Timed out waiting for workers to be reachable.") - - time.sleep(5) # Wait for 5 seconds before trying again - - -def _wait_for_master(master_host: str, port: int = DEFAULT_SSH_PORT, timeout: int = 300): - """Worker nodes wait until they can connect to the master node.""" - start_time = time.time() - while True: - logger.info(f"Worker is attempting to connect to the master node {master_host}...") - if _can_connect(master_host, port): - logger.info(f"Worker can connect to master node {master_host}.") - break - if time.time() - start_time > timeout: - raise TimeoutError(f"Timed out waiting for master {master_host} to be reachable.") - - time.sleep(5) # Wait for 5 seconds before trying again - - -def bootstrap_worker_node(master_host: str, status_file: str = FINISHED_STATUS_FILE): - """Bootstrap the worker nodes.""" - logger.info("Bootstrapping worker node...") - _wait_for_master(master_host) - _write_file_to_host(master_host, READY_FILE % os.environ["SM_CURRENT_HOST"]) - _wait_for_status_file(status_file) - - -def bootstrap_master_node(worker_hosts: List[str]): - """Bootstrap the master node.""" - logger.info("Bootstrapping master node...") - _wait_for_workers(worker_hosts) - - -def validate_smddprun() -> bool: - """Whether smddprun is installed. - - Returns: - bool: True if installed - """ - try: - output = subprocess.run( - ["which", "smddprun"], - capture_output=True, - text=True, - check=True, - ) - return output.stdout != "" - except subprocess.CalledProcessError: - return False - - -def validate_smddpmprun() -> bool: - """Whether smddpmprun is installed. 
- - Returns: - bool: True if both are installed - """ - try: - output = subprocess.run( - ["which", "smddpmprun"], - capture_output=True, - text=True, - check=True, - ) - return output.stdout != "" - except subprocess.CalledProcessError: - return False - - -def write_env_vars_to_file(): - """Write environment variables to /etc/environment file.""" - with open("/etc/environment", "a", encoding="utf-8") as f: - for name in os.environ: - f.write(f"{name}={os.environ.get(name)}\n") - - -def get_mpirun_command( - host_count: int, - host_list: List[str], - num_processes: int, - additional_options: List[str], - entry_script_path: str, -): - """Fetch mpi command""" - network_interface_name = os.environ.get("SM_NETWORK_INTERFACE_NAME", "eth0") - - mpirun_command = [ - "mpirun", - "--host", - ",".join(host_list), - "-np", - str(num_processes), - "--allow-run-as-root", - "--tag-output", - "-mca", - "btl_tcp_if_include", - network_interface_name, - "-mca", - "oob_tcp_if_include", - network_interface_name, - "-mca", - "plm_rsh_no_tree_spawn", - "1", - "-mca", - "pml", - "ob1", - "-mca", - "btl", - "^openib", - "-mca", - "orte_abort_on_non_zero_status", - "1", - "-mca", - "btl_vader_single_copy_mechanism", - "none", - "-mca", - "plm_rsh_num_concurrent", - str(host_count), - "-x", - "NCCL_SOCKET_IFNAME=%s" % network_interface_name, - "-x", - "LD_LIBRARY_PATH", - "-x", - "PATH", - ] - - if additional_options: - mpirun_command.extend(additional_options) - - instance_type = os.environ["SM_CURRENT_INSTANCE_TYPE"] - # EFA settings - if instance_type in SM_EFA_NCCL_INSTANCES: - mpirun_command.extend(["-x", "FI_PROVIDER=efa"]) - # Use simple protocol to handle the out-of-order data delivery from EFA - mpirun_command.extend(["-x", "NCCL_PROTO=simple"]) - - if instance_type in SM_EFA_RDMA_INSTANCES: - # Use EFA's RDMA functionality for one-sided and two-sided transfer - mpirun_command.extend(["-x", "FI_EFA_USE_DEVICE_RDMA=1"]) - - for credential in [ - "AWS_ACCESS_KEY_ID", - "AWS_SECRET_ACCESS_KEY", - "AWS_SESSION_TOKEN", - ]: - if credential in os.environ: - mpirun_command.extend(["-x", credential]) - - mpirun_command.extend([get_python_executable()]) - mpirun_command.extend(["-m", "mpi4py", entry_script_path]) - return mpirun_command diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py deleted file mode 100644 index 7fcfabe05d..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
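
Putting the pieces above together, `get_mpirun_command()` assembles a command along these lines for a two-node job with eight processes per node on `eth0` (abridged; the EFA/RDMA and AWS-credential `-x` flags are appended only when applicable, hyperparameter CLI args follow the entry script, and the interpreter and script paths here are placeholders):

.. code-block:: python

    mpirun_command = [
        "mpirun",
        "--host", "algo-1:8,algo-2:8",
        "-np", "8",                    # the driver passes the per-node process count
        "--allow-run-as-root",
        "--tag-output",
        "-mca", "btl_tcp_if_include", "eth0",
        "-mca", "oob_tcp_if_include", "eth0",
        "-mca", "plm_rsh_no_tree_spawn", "1",
        "-mca", "pml", "ob1",
        "-mca", "btl", "^openib",
        "-mca", "orte_abort_on_non_zero_status", "1",
        "-mca", "btl_vader_single_copy_mechanism", "none",
        "-mca", "plm_rsh_num_concurrent", "2",
        "-x", "NCCL_SOCKET_IFNAME=eth0",
        "-x", "LD_LIBRARY_PATH",
        "-x", "PATH",
        "/usr/bin/python3", "-m", "mpi4py", "train.py",
    ]
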
-"""This module is the entry point for the Torchrun driver script.""" -from __future__ import absolute_import - -import os -import sys -import json - -from pathlib import Path -from typing import List, Tuple - -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 - logger, - hyperparameters_to_cli_args, - get_process_count, - get_python_executable, - execute_commands, - write_failure_file, - SM_EFA_NCCL_INSTANCES, - SM_EFA_RDMA_INSTANCES, -) - - -def pytorch_version() -> Tuple[int, int]: - """Get the PyTorch version as a tuple of integers.""" - import torch - - return tuple(map(int, torch.__version__.split(".")[:2])) - - -def get_base_pytorch_command() -> List[str]: - """Get the base Torch Distributed launcher to execute""" - if pytorch_version() >= (1, 9): - return ["torchrun"] - return [f"{get_python_executable()}", "-m", "torch.distributed.launch"] - - -def setup_env(): - """Setup the environment variables for PyTorch distributed training""" - instance_type = os.environ["SM_CURRENT_INSTANCE_TYPE"] - network_interface_name = os.environ.get("SM_NETWORK_INTERFACE_NAME", "eth0") - if instance_type in SM_EFA_NCCL_INSTANCES: - # Enable EFA use - os.environ["FI_PROVIDER"] = "efa" - if instance_type in SM_EFA_RDMA_INSTANCES: - # Use EFA's RDMA functionality for one-sided and two-sided transfer - os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1" - os.environ["RDMAV_FORK_SAFE"] = "1" - os.environ["NCCL_SOCKET_IFNAME"] = str(network_interface_name) - os.environ["NCCL_PROTO"] = "simple" - - -def create_commands(): - """Create the Torch Distributed command to execute""" - entry_script = os.environ["SM_ENTRY_SCRIPT"] - distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) - hyperparameters = json.loads(os.environ["SM_HPS"]) - - process_count = int(distributed_config["process_count_per_node"] or 0) - process_count = get_process_count(process_count) - host_count = int(os.environ["SM_HOST_COUNT"]) - - torch_cmd = [] - if os.environ.get("RUN_NEURON_PARALLEL_COMPILE") == "1": - torch_cmd.append("neuron_parallel_compile") - - torch_cmd.extend(get_base_pytorch_command()) - torch_cmd.extend( - [ - f"--nnodes={host_count}", - f"--nproc_per_node={process_count}", - ] - ) - - # If more than one node is used, add node rank information - if int(host_count) > 1: - torch_cmd.extend( - [ - f"--master_addr={os.environ['SM_MASTER_ADDR']}", - f"--master_port={os.environ['SM_MASTER_PORT']}", - f"--node_rank={os.environ['SM_CURRENT_HOST_RANK']}", - ] - ) - - torch_cmd.extend([entry_script]) - - args = hyperparameters_to_cli_args(hyperparameters) - torch_cmd += args - - return torch_cmd - - -def main(): - """Main function to execute the PyTorch distributed training script. - - This function sets some environment variables and executes the PyTorch - distributed training script. - - Execution Lifecycle: - 1. Setup Environment Variables for PyTorch Distributed Training - 2. Create Torch Distributed Command - 3. Execute Torch Distributed Command with user script provided in `entry_script` - 4. 
Exit - - """ - setup_env() - torch_cmd = create_commands() - logger.info(f"Executing command: {' '.join(torch_cmd)}") - exit_code, traceback = execute_commands(torch_cmd) - if exit_code != 0: - write_failure_file(traceback) - sys.exit(exit_code) - - -if __name__ == "__main__": - main() diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/scripts/__init__.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/scripts/__init__.py deleted file mode 100644 index f04c5b17a0..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/scripts/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Sagemaker modules container drivers - scripts directory.""" -from __future__ import absolute_import diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/scripts/environment.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/scripts/environment.py deleted file mode 100644 index 897b1f8af4..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/scripts/environment.py +++ /dev/null @@ -1,305 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
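
For comparison with the MPI path, the torchrun driver above builds an invocation like the following for a two-node job with eight processes per node (the script name and hyperparameter args are placeholders; the master address and port defaults come from the environment script whose diff follows):

.. code-block:: python

    torch_cmd = [
        "torchrun",
        "--nnodes=2",
        "--nproc_per_node=8",
        "--master_addr=algo-1",   # SM_MASTER_ADDR
        "--master_port=7777",     # SM_MASTER_PORT
        "--node_rank=0",          # SM_CURRENT_HOST_RANK on algo-1
        "train.py",
        "--epochs", "10",
    ]
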
-"""This module is used to define the environment variables for the training job container.""" -from __future__ import absolute_import - -from typing import Dict, Any -import multiprocessing -import subprocess -import json -import os -import sys -from pathlib import Path -import logging - -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 - safe_serialize, - safe_deserialize, - read_distributed_json, - read_source_code_json, -) - -# Initialize logger -SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20) -logger = logging.getLogger(__name__) -console_handler = logging.StreamHandler(sys.stdout) -logger.addHandler(console_handler) -logger.setLevel(int(SM_LOG_LEVEL)) - -SM_MODEL_DIR = "/opt/ml/model" - -SM_INPUT_DIR = "/opt/ml/input" -SM_INPUT_DATA_DIR = "/opt/ml/input/data" -SM_INPUT_CONFIG_DIR = "/opt/ml/input/config" - -SM_OUTPUT_DIR = "/opt/ml/output" -SM_OUTPUT_FAILURE = "/opt/ml/output/failure" -SM_OUTPUT_DATA_DIR = "/opt/ml/output/data" -SM_SOURCE_DIR_PATH = "/opt/ml/input/data/code" -SM_DISTRIBUTED_DRIVER_DIR_PATH = "/opt/ml/input/data/sm_drivers/distributed_drivers" - -SM_MASTER_ADDR = "algo-1" -SM_MASTER_PORT = 7777 - -RESOURCE_CONFIG = f"{SM_INPUT_CONFIG_DIR}/resourceconfig.json" -INPUT_DATA_CONFIG = f"{SM_INPUT_CONFIG_DIR}/inputdataconfig.json" -HYPERPARAMETERS_CONFIG = f"{SM_INPUT_CONFIG_DIR}/hyperparameters.json" - -ENV_OUTPUT_FILE = "/opt/ml/input/sm_training.env" - -SENSITIVE_KEYWORDS = ["SECRET", "PASSWORD", "KEY", "TOKEN", "PRIVATE", "CREDS", "CREDENTIALS"] -HIDDEN_VALUE = "******" - - -def num_cpus() -> int: - """Return the number of CPUs available in the current container. - - Returns: - int: Number of CPUs available in the current container. - """ - return multiprocessing.cpu_count() - - -def num_gpus() -> int: - """Return the number of GPUs available in the current container. - - Returns: - int: Number of GPUs available in the current container. - """ - try: - cmd = ["nvidia-smi", "--list-gpus"] - output = subprocess.check_output(cmd).decode("utf-8") - return sum(1 for line in output.splitlines() if line.startswith("GPU ")) - except (OSError, subprocess.CalledProcessError): - logger.info("No GPUs detected (normal if no gpus installed)") - return 0 - - -def num_neurons() -> int: - """Return the number of neuron cores available in the current container. - - Returns: - int: Number of Neuron Cores available in the current container. - """ - try: - cmd = ["neuron-ls", "-j"] - output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8") - j = json.loads(output) - neuron_cores = 0 - for item in j: - neuron_cores += item.get("nc_count", 0) - logger.info("Found %s neurons on this instance", neuron_cores) - return neuron_cores - except OSError: - logger.info("No Neurons detected (normal if no neurons installed)") - return 0 - except subprocess.CalledProcessError as e: - if e.output is not None: - try: - msg = e.output.decode("utf-8").partition("error=")[2] - logger.info( - "No Neurons detected (normal if no neurons installed). \ - If neuron installed then %s", - msg, - ) - except AttributeError: - logger.info("No Neurons detected (normal if no neurons installed)") - else: - logger.info("No Neurons detected (normal if no neurons installed)") - - return 0 - - -def deserialize_hyperparameters(hyperparameters: Dict[str, str]) -> Dict[str, Any]: - """Deserialize hyperparameters from string to their original types. - - Args: - hyperparameters (Dict[str, str]): Hyperparameters as strings. 
- - Returns: - Dict[str, Any]: Hyperparameters as their original types. - """ - deserialized_hyperparameters = {} - for key, value in hyperparameters.items(): - deserialized_hyperparameters[key] = safe_deserialize(value) - return deserialized_hyperparameters - - -def set_env( - resource_config: Dict[str, Any], - input_data_config: Dict[str, Any], - hyperparameters_config: Dict[str, Any], - output_file: str = ENV_OUTPUT_FILE, -): - """Set environment variables for the training job container. - - Args: - resource_config (Dict[str, Any]): Resource configuration for the training job. - input_data_config (Dict[str, Any]): Input data configuration for the training job. - hyperparameters_config (Dict[str, Any]): Hyperparameters configuration for the training job. - output_file (str): Output file to write the environment variables. - """ - # Constants - env_vars = { - "SM_MODEL_DIR": SM_MODEL_DIR, - "SM_INPUT_DIR": SM_INPUT_DIR, - "SM_INPUT_DATA_DIR": SM_INPUT_DATA_DIR, - "SM_INPUT_CONFIG_DIR": SM_INPUT_CONFIG_DIR, - "SM_OUTPUT_DIR": SM_OUTPUT_DIR, - "SM_OUTPUT_FAILURE": SM_OUTPUT_FAILURE, - "SM_OUTPUT_DATA_DIR": SM_OUTPUT_DATA_DIR, - "SM_LOG_LEVEL": SM_LOG_LEVEL, - "SM_MASTER_ADDR": SM_MASTER_ADDR, - "SM_MASTER_PORT": SM_MASTER_PORT, - } - - # SourceCode and DistributedConfig Environment Variables - source_code = read_source_code_json() - if source_code: - env_vars["SM_SOURCE_DIR"] = SM_SOURCE_DIR_PATH - env_vars["SM_ENTRY_SCRIPT"] = source_code.get("entry_script", "") - - distributed = read_distributed_json() - if distributed: - env_vars["SM_DISTRIBUTED_DRIVER_DIR"] = SM_DISTRIBUTED_DRIVER_DIR_PATH - env_vars["SM_DISTRIBUTED_CONFIG"] = distributed - - # Data Channels - channels = list(input_data_config.keys()) - for channel in channels: - env_vars[f"SM_CHANNEL_{channel.upper()}"] = f"{SM_INPUT_DATA_DIR}/{channel}" - env_vars["SM_CHANNELS"] = channels - - # Hyperparameters - hps = deserialize_hyperparameters(hyperparameters_config) - for key, value in hps.items(): - key_upper = key.replace("-", "_").upper() - env_vars[f"SM_HP_{key_upper}"] = value - env_vars["SM_HPS"] = hps - - # Host Variables - current_host = resource_config["current_host"] - current_instance_type = resource_config["current_instance_type"] - hosts = resource_config["hosts"] - sorted_hosts = sorted(hosts) - - env_vars["SM_CURRENT_HOST"] = current_host - env_vars["SM_CURRENT_INSTANCE_TYPE"] = current_instance_type - env_vars["SM_HOSTS"] = sorted_hosts - env_vars["SM_NETWORK_INTERFACE_NAME"] = resource_config["network_interface_name"] - env_vars["SM_HOST_COUNT"] = len(sorted_hosts) - env_vars["SM_CURRENT_HOST_RANK"] = sorted_hosts.index(current_host) - - env_vars["SM_NUM_CPUS"] = num_cpus() - env_vars["SM_NUM_GPUS"] = num_gpus() - env_vars["SM_NUM_NEURONS"] = num_neurons() - - # Misc. 
- env_vars["SM_RESOURCE_CONFIG"] = resource_config - env_vars["SM_INPUT_DATA_CONFIG"] = input_data_config - - # All Training Environment Variables - env_vars["SM_TRAINING_ENV"] = { - "channel_input_dirs": { - channel: env_vars[f"SM_CHANNEL_{channel.upper()}"] for channel in channels - }, - "current_host": env_vars["SM_CURRENT_HOST"], - "current_instance_type": env_vars["SM_CURRENT_INSTANCE_TYPE"], - "hosts": env_vars["SM_HOSTS"], - "master_addr": env_vars["SM_MASTER_ADDR"], - "master_port": env_vars["SM_MASTER_PORT"], - "hyperparameters": env_vars["SM_HPS"], - "input_data_config": input_data_config, - "input_config_dir": env_vars["SM_INPUT_CONFIG_DIR"], - "input_data_dir": env_vars["SM_INPUT_DATA_DIR"], - "input_dir": env_vars["SM_INPUT_DIR"], - "job_name": os.environ["TRAINING_JOB_NAME"], - "log_level": env_vars["SM_LOG_LEVEL"], - "model_dir": env_vars["SM_MODEL_DIR"], - "network_interface_name": env_vars["SM_NETWORK_INTERFACE_NAME"], - "num_cpus": env_vars["SM_NUM_CPUS"], - "num_gpus": env_vars["SM_NUM_GPUS"], - "num_neurons": env_vars["SM_NUM_NEURONS"], - "output_data_dir": env_vars["SM_OUTPUT_DATA_DIR"], - "resource_config": env_vars["SM_RESOURCE_CONFIG"], - } - with open(output_file, "w") as f: - for key, value in env_vars.items(): - f.write(f"export {key}='{safe_serialize(value)}'\n") - - logger.info("Environment Variables:") - log_env_variables(env_vars_dict=env_vars) - - -def mask_sensitive_info(data): - """Recursively mask sensitive information in a dictionary.""" - if isinstance(data, dict): - for k, v in data.items(): - if isinstance(v, dict): - data[k] = mask_sensitive_info(v) - elif isinstance(v, str) and any( - keyword.lower() in k.lower() for keyword in SENSITIVE_KEYWORDS - ): - data[k] = HIDDEN_VALUE - return data - - -def log_key_value(key: str, value: str): - """Log a key-value pair, masking sensitive values if necessary.""" - if any(keyword.lower() in key.lower() for keyword in SENSITIVE_KEYWORDS): - logger.info("%s=%s", key, HIDDEN_VALUE) - elif isinstance(value, dict): - masked_value = mask_sensitive_info(value) - logger.info("%s=%s", key, json.dumps(masked_value)) - else: - try: - decoded_value = json.loads(value) - if isinstance(decoded_value, dict): - masked_value = mask_sensitive_info(decoded_value) - logger.info("%s=%s", key, json.dumps(masked_value)) - else: - logger.info("%s=%s", key, decoded_value) - except (json.JSONDecodeError, TypeError): - logger.info("%s=%s", key, value) - - -def log_env_variables(env_vars_dict: Dict[str, Any]): - """Log Environment Variables from the environment and an env_vars_dict.""" - for key, value in os.environ.items(): - log_key_value(key, value) - - for key, value in env_vars_dict.items(): - log_key_value(key, value) - - -def main(): - """Main function to set the environment variables for the training job container.""" - with open(RESOURCE_CONFIG, "r") as f: - resource_config = json.load(f) - with open(INPUT_DATA_CONFIG, "r") as f: - input_data_config = json.load(f) - with open(HYPERPARAMETERS_CONFIG, "r") as f: - hyperparameters_config = json.load(f) - - set_env( - resource_config=resource_config, - input_data_config=input_data_config, - hyperparameters_config=hyperparameters_config, - output_file=ENV_OUTPUT_FILE, - ) - - -if __name__ == "__main__": - main() diff --git a/sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/__init__.py b/sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git 
a/sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/training_recipes.json b/sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/training_recipes.json deleted file mode 100644 index a51513f49f..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/training_recipes.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "adapter_repo": "https://github.com/aws/sagemaker-training-adapter-for-nemo.git", - "launcher_repo": "https://github.com/aws/sagemaker-hyperpod-recipes.git", - "neuron_dist_repo": "https://github.com/aws-neuron/neuronx-distributed-training.git", - "gpu_image" : { - "framework": "pytorch-smp", - "version": "2.4.1", - "additional_args": { - "container_version": "cu121" - } - }, - "neuron_image": { - "framework": "hyperpod-recipes-neuron", - "version": "2.1.2", - "additional_args": {} - } -} \ No newline at end of file diff --git a/sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/utils.py b/sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/utils.py deleted file mode 100644 index 1eb0b83e97..0000000000 --- a/sagemaker-core/src/sagemaker/core/modules/train/sm_recipes/utils.py +++ /dev/null @@ -1,330 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Utility functions for SageMaker training recipes.""" -from __future__ import absolute_import - -import math -import os -import json -import shutil -import tempfile -from urllib.request import urlretrieve -from typing import Dict, Any, Optional, Tuple - -import omegaconf -from omegaconf import OmegaConf, dictconfig - -from sagemaker.core.image_uris import retrieve - -from sagemaker.core.modules import logger -from sagemaker.core.modules.utils import _run_clone_command_silent -from sagemaker.core.modules.configs import Compute, SourceCode -from sagemaker.core.modules.distributed import Torchrun, SMP - - -def _try_resolve_recipe(recipe, key=None): - """Try to resolve recipe and return resolved recipe.""" - if key is not None: - recipe = dictconfig.DictConfig({key: recipe}) - try: - OmegaConf.resolve(recipe) - except omegaconf.errors.OmegaConfBaseException: - return None - if key is None: - return recipe - return recipe[key] - - -def _determine_device_type(instance_type: str) -> str: - """Determine device type (gpu, cpu, trainium) based on instance type.""" - instance_family = instance_type.split(".")[1] - if instance_family.startswith(("p", "g")): - return "gpu" - if instance_family.startswith("trn"): - return "trainium" - return "cpu" - - -def _load_recipes_cfg() -> str: - """Load training recipes configuration json.""" - training_recipes_cfg_filename = os.path.join(os.path.dirname(__file__), "training_recipes.json") - with open(training_recipes_cfg_filename) as training_recipes_cfg_file: - training_recipes_cfg = json.load(training_recipes_cfg_file) - return training_recipes_cfg - - -def _load_base_recipe( - training_recipe: str, - recipe_overrides: Optional[Dict[str, Any]] = None, - training_recipes_cfg: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: - """Load recipe 
and apply overrides.""" - if recipe_overrides is None: - recipe_overrides = dict() - - temp_local_recipe = tempfile.NamedTemporaryFile(prefix="recipe_original", suffix=".yaml").name - - if training_recipe.endswith(".yaml"): - if os.path.isfile(training_recipe): - shutil.copy(training_recipe, temp_local_recipe) - else: - try: - urlretrieve(training_recipe, temp_local_recipe) - except Exception as e: - raise ValueError( - f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}" - ) - else: - recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_") - - launcher_repo = os.environ.get("TRAINING_LAUNCHER_GIT", None) or training_recipes_cfg.get( - "launcher_repo" - ) - _run_clone_command_silent(launcher_repo, recipe_launcher_dir.name) - - recipe = os.path.join( - recipe_launcher_dir.name, - "recipes_collection", - "recipes", - training_recipe + ".yaml", - ) - if os.path.isfile(recipe): - shutil.copy(recipe, temp_local_recipe) - else: - raise ValueError(f"Recipe {training_recipe} not found.") - - recipe = OmegaConf.load(temp_local_recipe) - os.unlink(temp_local_recipe) - recipe = OmegaConf.merge(recipe, recipe_overrides) - return recipe - - -def _register_custom_resolvers(): - """Register custom resolvers for OmegaConf.""" - if not OmegaConf.has_resolver("multiply"): - OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) - if not OmegaConf.has_resolver("divide_ceil"): - OmegaConf.register_new_resolver( - "divide_ceil", lambda x, y: int(math.ceil(x / y)), replace=True - ) - if not OmegaConf.has_resolver("divide_floor"): - OmegaConf.register_new_resolver( - "divide_floor", lambda x, y: int(math.floor(x / y)), replace=True - ) - if not OmegaConf.has_resolver("add"): - OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers)) - - -def _get_trainining_recipe_gpu_model_name_and_script(model_type: str): - """Get the model base name and script for the training recipe.""" - - model_type_to_script = { - "llama_v3": ("llama", "llama_pretrain.py"), - "mistral": ("mistral", "mistral_pretrain.py"), - "mixtral": ("mixtral", "mixtral_pretrain.py"), - "deepseek": ("deepseek", "deepseek_pretrain.py"), - } - - for key in model_type_to_script: - if model_type.startswith(key): - model_type = key - break - - if model_type not in model_type_to_script: - raise ValueError(f"Model type {model_type} not supported") - - return model_type_to_script[model_type][0], model_type_to_script[model_type][1] - - -def _configure_gpu_args( - training_recipes_cfg: Dict[str, Any], - region_name: str, - recipe: OmegaConf, - recipe_train_dir: tempfile.TemporaryDirectory, -) -> Dict[str, Any]: - """Configure arguments specific to GPU.""" - source_code = SourceCode() - args = dict() - - adapter_repo = os.environ.get("TRAINING_ADAPTER_GIT", None) or training_recipes_cfg.get( - "adapter_repo" - ) - _run_clone_command_silent(adapter_repo, recipe_train_dir.name) - - if "model" not in recipe: - raise ValueError("Supplied recipe does not contain required field model.") - if "model_type" not in recipe["model"]: - raise ValueError("Supplied recipe does not contain required field model_type.") - model_type = recipe["model"]["model_type"] - - model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(model_type) - - source_code.source_dir = os.path.join(recipe_train_dir.name, "examples", model_base_name) - source_code.entry_script = script - - gpu_image_cfg = training_recipes_cfg.get("gpu_image") - if isinstance(gpu_image_cfg, str): - training_image = gpu_image_cfg - 
else: - training_image = retrieve( - gpu_image_cfg.get("framework"), - region=region_name, - version=gpu_image_cfg.get("version"), - image_scope="training", - **gpu_image_cfg.get("additional_args"), - ) - - # Setting dummy parameters for now - torch_distributed = Torchrun(smp=SMP(random_seed="123456")) - args.update( - { - "source_code": source_code, - "training_image": training_image, - "distributed": torch_distributed, - } - ) - return args - - -def _configure_trainium_args( - training_recipes_cfg: Dict[str, Any], - region_name: str, - recipe_train_dir: tempfile.TemporaryDirectory, -) -> Dict[str, Any]: - """Configure arguments specific to Trainium.""" - source_code = SourceCode() - args = dict() - - _run_clone_command_silent(training_recipes_cfg.get("neuron_dist_repo"), recipe_train_dir.name) - - source_code.source_dir = os.path.join(recipe_train_dir.name, "examples") - source_code.entry_script = "training_orchestrator.py" - neuron_image_cfg = training_recipes_cfg.get("neuron_image") - if isinstance(neuron_image_cfg, str): - training_image = neuron_image_cfg - else: - training_image = retrieve( - neuron_image_cfg.get("framework"), - region=region_name, - version=neuron_image_cfg.get("version"), - image_scope="training", - **neuron_image_cfg.get("additional_args"), - ) - - args.update( - { - "source_code": source_code, - "training_image": training_image, - "distributed": Torchrun(), - } - ) - return args - - -def _get_args_from_recipe( - training_recipe: str, - compute: Compute, - region_name: str, - recipe_overrides: Optional[Dict[str, Any]], - requirements: Optional[str], -) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]: - """Get arguments for ModelTrainer from a training recipe. - - Returns a dictionary of arguments to be used with ModelTrainer like: - ```python - { - "source_code": SourceCode, - "training_image": str, - "distributed": DistributedConfig, - "compute": Compute, - "hyperparameters": Dict[str, Any], - } - ``` - - Args: - training_recipe (str): - Name of the training recipe or path to the recipe file. - compute (Compute): - Compute configuration for training. - region_name (str): - Name of the AWS region. - recipe_overrides (Optional[Dict[str, Any]]): - Overrides for the training recipe. - requirements (Optional[str]): - Path to the requirements file. - """ - if compute.instance_type is None: - raise ValueError("Must set `instance_type` in compute when using training recipes.") - - training_recipes_cfg = _load_recipes_cfg() - recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg) - - if "trainer" not in recipe: - raise ValueError("Supplied recipe does not contain required field trainer.") - - # Set instance_count - if compute.instance_count and "num_nodes" in recipe["trainer"]: - logger.warning( - f"Using Compute to set instance_count:\n{compute}." - "\nIgnoring trainer -> num_nodes in recipe." - ) - if compute.instance_count is None: - if "num_nodes" not in recipe["trainer"]: - raise ValueError( - "Must provide Compute with instance_count or" " set trainer -> num_nodes in recipe." 
- ) - compute.instance_count = recipe["trainer"]["num_nodes"] - - if requirements and not os.path.isfile(requirements): - raise ValueError(f"Recipe requirements file {requirements} not found.") - - # Get Training Image, SourceCode, and distributed args - device_type = _determine_device_type(compute.instance_type) - recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_") - if device_type == "gpu": - args = _configure_gpu_args(training_recipes_cfg, region_name, recipe, recipe_train_dir) - elif device_type == "trainium": - args = _configure_trainium_args(training_recipes_cfg, region_name, recipe_train_dir) - else: - raise ValueError(f"Devices of type {device_type} are not supported with training recipes.") - - _register_custom_resolvers() - - # Resolve Final Recipe - final_recipe = _try_resolve_recipe(recipe) - if final_recipe is None: - final_recipe = _try_resolve_recipe(recipe, "recipes") - if final_recipe is None: - final_recipe = _try_resolve_recipe(recipe, "training") - if final_recipe is None: - raise RuntimeError("Could not resolve provided recipe.") - - # Save Final Recipe to source_dir - OmegaConf.save( - config=final_recipe, f=os.path.join(args["source_code"].source_dir, "recipe.yaml") - ) - - # If recipe_requirements is provided, copy it to source_dir - if requirements: - shutil.copy(requirements, args["source_code"].source_dir) - args["source_code"].requirements = os.path.basename(requirements) - - # Update args with compute and hyperparameters - args.update( - { - "compute": compute, - "hyperparameters": {"config-path": ".", "config-name": "recipe.yaml"}, - } - ) - - return args, recipe_train_dir diff --git a/sagemaker-core/src/sagemaker/core/remote_function/__init__.py b/sagemaker-core/src/sagemaker/core/remote_function/__init__.py deleted file mode 100644 index 6436ddaa22..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Defines classes and helper methods used in remote function executions.""" -from __future__ import absolute_import - -from sagemaker.core.remote_function.client import remote, RemoteExecutor # noqa: F401 -from sagemaker.core.remote_function.checkpoint_location import CheckpointLocation # noqa: F401 -from sagemaker.core.remote_function.custom_file_filter import CustomFileFilter # noqa: F401 -from sagemaker.core.remote_function.spark_config import SparkConfig # noqa: F401 diff --git a/sagemaker-core/src/sagemaker/core/remote_function/checkpoint_location.py b/sagemaker-core/src/sagemaker/core/remote_function/checkpoint_location.py deleted file mode 100644 index 4153fe03d3..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/checkpoint_location.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. 
A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""This module is used to define the CheckpointLocation to remote function.""" -from __future__ import absolute_import - -from os import PathLike -import re - -# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CheckpointConfig.html -S3_URI_REGEX_PATTERN = r"^(https|s3)://([^/]+)/?(.*)$" - -_JOB_CHECKPOINT_LOCATION = "/opt/ml/checkpoints/" - - -def _validate_s3_uri_for_checkpoint(s3_uri: str): - """Validate if checkpoint location is specified with a valid s3 URI.""" - return re.match(S3_URI_REGEX_PATTERN, s3_uri) - - -class CheckpointLocation(PathLike): - """Class to represent the location where checkpoints are accessed in a remote function. - - To save or load checkpoints in a remote function, pass an CheckpointLocation object as a - function parameter and use it as a os.PathLike object. This CheckpointLocation object - represents the local directory (/opt/ml/checkpoints/) of checkpoints in side the job. - """ - - _local_path = _JOB_CHECKPOINT_LOCATION - - def __init__(self, s3_uri): - if not _validate_s3_uri_for_checkpoint(s3_uri): - raise ValueError("CheckpointLocation should be specified with valid s3 URI.") - self._s3_uri = s3_uri - - def __fspath__(self): - """Return job local path where checkpoints are stored.""" - return self._local_path diff --git a/sagemaker-core/src/sagemaker/core/remote_function/client.py b/sagemaker-core/src/sagemaker/core/remote_function/client.py deleted file mode 100644 index a38b57662a..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/client.py +++ /dev/null @@ -1,1285 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-"""SageMaker remote function client.""" -from __future__ import absolute_import - -from concurrent.futures import ThreadPoolExecutor -from collections import deque -import time -import threading -from typing import Callable, Dict, List, Optional, Tuple, Any, Union -import functools -import itertools -import inspect - -from botocore.exceptions import ClientError -from sagemaker.core.exceptions import UnexpectedStatusException -from sagemaker.core.experiments._run_context import _RunContext - -import sagemaker.core.remote_function.core.serialization as serialization -from sagemaker.core.remote_function.errors import ( - RemoteFunctionError, - ServiceError, - DeserializationError, -) -from sagemaker.core.remote_function.core.stored_function import RESULTS_FOLDER, EXCEPTION_FOLDER -from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( - RuntimeEnvironmentError, -) - -from sagemaker.core.helper.session_helper import Session -from sagemaker.core.s3 import s3_path_join -from sagemaker.core.remote_function.job import _JobSettings, _Job, _RunInfo -from sagemaker.core.remote_function import logging_config -from sagemaker.core.common_utils import name_from_base, base_from_name -from sagemaker.core.remote_function.spark_config import SparkConfig -from sagemaker.core.remote_function.custom_file_filter import CustomFileFilter -from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter -from sagemaker.core.telemetry.constants import Feature - -_API_CALL_LIMIT = { - "SubmittingIntervalInSecs": 1, - "MinBatchPollingIntervalInSecs": 10, - "PollingIntervalInSecs": 0.5, -} - -# Possible future states. -_PENDING = "PENDING" -_RUNNING = "RUNNING" -# The future was cancelled by the user... -_CANCELLED = "CANCELLED" -_FINISHED = "FINISHED" - -logger = logging_config.get_logger() - - -@_telemetry_emitter(feature=Feature.REMOTE_FUNCTION, func_name="remote_function.remote") -def remote( - _func=None, - *, - dependencies: str = None, - pre_execution_commands: List[str] = None, - pre_execution_script: str = None, - environment_variables: Dict[str, str] = None, - image_uri: str = None, - include_local_workdir: bool = None, - custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None, - instance_count: int = 1, - instance_type: str = None, - job_conda_env: str = None, - job_name_prefix: str = None, - keep_alive_period_in_seconds: int = 0, - max_retry_attempts: int = 1, - max_runtime_in_seconds: int = 24 * 60 * 60, - role: str = None, - s3_kms_key: str = None, - s3_root_uri: str = None, - sagemaker_session: Session = None, - security_group_ids: List[str] = None, - subnets: List[str] = None, - tags: List[Tuple[str, str]] = None, - volume_kms_key: str = None, - volume_size: int = 30, - encrypt_inter_container_traffic: bool = None, - spark_config: SparkConfig = None, - use_spot_instances=False, - max_wait_time_in_seconds=None, - disable_output_compression: bool = False, - use_torchrun: bool = False, - use_mpirun: bool = False, - nproc_per_node: Optional[int] = None, -): - """Decorator for running the annotated function as a SageMaker training job. - - This decorator wraps the annotated code and runs it as a new SageMaker job synchronously - with the provided runtime settings. - - If a parameter value is not set, the decorator first looks up the value from the SageMaker - configuration file. If no value is specified in the configuration file or no configuration file - is found, the decorator selects the default as specified below. 
For more information, see - `Configuring and using defaults with the SageMaker Python SDK `_. - - Args: - _func (Optional): A Python function to run as a SageMaker training job. - - dependencies (str): Either the path to a dependencies file or the reserved keyword - ``auto_capture``. Defaults to ``None``. - If ``dependencies`` is provided, the value must be one of the following: - - * A path to a conda environment.yml file. The following conditions apply. - - * If job_conda_env is set, then the conda environment is updated by installing - dependencies from the yaml file and the function is invoked within that - conda environment. For this to succeed, the specified conda environment must - already exist in the image. - * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, then the - conda environment is updated by installing dependencies from the yaml file and the - function is invoked within that conda environment. For this to succeed, the - conda environment name must already be set in ``SAGEMAKER_JOB_CONDA_ENV``, and - ``SAGEMAKER_JOB_CONDA_ENV`` must already exist in the image. - * If none of the previous conditions are met, a new conda environment named - ``sagemaker-runtime-env`` is created and the function annotated with the remote - decorator is invoked in that conda environment. - - * A path to a requirements.txt file. The following conditions apply. - - * If ``job_conda_env`` is set in the remote decorator, dependencies are installed - within that conda environment and the function annotated with the remote decorator - is invoked in the same conda environment. For this to succeed, the specified - conda environment must already exist in the image. - * If an environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, - dependencies are installed within that conda environment and the function annotated - with the remote decorator is invoked in the same. For this to succeed, the conda - environment name must already be set in ``SAGEMAKER_JOB_CONDA_ENV``, and - ``SAGEMAKER_JOB_CONDA_ENV`` must already exist in the image. - * If none of the above conditions are met, conda is not used. Dependencies are - installed at the system level, without any virtual environment, and the function - annotated with the remote decorator is invoked using the Python runtime available - in the system path. - - * The parameter dependencies is set to ``auto_capture``. SageMaker will automatically - generate an env_snapshot.yml corresponding to the current active conda environment’s - snapshot. You do not need to provide a dependencies file. The following conditions - apply: - - * You must run the remote function within an active conda environment. - * When installing the dependencies on the training job, the same conditions as when - dependencies is set to a path to a conda environment file apply. These conditions are - as follows: - - * If job_conda_env is set, then the conda environment is updated by installing - dependencies from the yaml file and the function is invoked within that - conda environment. For this to succeed, the specified conda environment must - already exist in the image. - * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, then - the conda environment is updated by installing dependencies from the yaml file - and the function is invoked within that conda environment. 
For this to - succeed, the conda environment name must already be set in - ``SAGEMAKER_JOB_CONDA_ENV``, and ``SAGEMAKER_JOB_CONDA_ENV`` must already exist - in the image. - * If none of the previous conditions are met, a new conda environment with name - ``sagemaker-runtime-env`` is created and the function annotated with the - remote decorator is invoked in that conda environment. - - * ``None``. SageMaker will assume that there are no dependencies to install while - executing the remote annotated function in the training job. - - pre_execution_commands (List[str]): List of commands to be executed prior to executing - remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script`` - can be specified at the same time. Defaults to None. - - pre_execution_script (str): Path to script file to be executed prior to executing - remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script`` - can be specified at the same time. Defaults to None. - - environment_variables (Dict): The environment variables used inside the decorator function. - Defaults to ``None``. - - image_uri (str): The universal resource identifier (URI) location of a Docker image on - Amazon Elastic Container Registry (ECR). Defaults to the following based on where the SDK - is running: - - * For users who specify ``spark_config`` and want to run the function in a Spark - application, the ``image_uri`` should be ``None``. A SageMaker Spark image will - be used for training, otherwise a ``ValueError`` is thrown. - * For users on SageMaker Studio notebooks, the image used as the kernel image for the - notebook is used. - * For other users, it is resolved to base python image with the same python version - as the environment running the local code. - - If no compatible image is found, a ValueError is thrown. - - include_local_workdir (bool): A flag to indicate that the remote function should include - local directories. Set to ``True`` if the remote function code imports local modules and - methods that are not available via PyPI or conda. Only python files are included. - Default value is ``False``. - - custom_file_filter (Callable[[str, List], List], CustomFileFilter): Either a function - that filters job dependencies to be uploaded to S3 or a ``CustomFileFilter`` object - that specifies the local directories and files to be included in the remote function. - If a callable is passed in, the function should follow the protocol of ``ignore`` argument - of ``shutil.copytree``. Defaults to ``None``, which means only python - files are accepted and uploaded to S3. - - instance_count (int): The number of instances to use. Defaults to 1. - NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and - mpirun utilities - - instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run - the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown. - - job_conda_env (str): The name of the conda environment to activate during job's runtime. - Defaults to ``None``. - - job_name_prefix (str): The prefix used used to create the underlying SageMaker job. - - keep_alive_period_in_seconds (int): The duration in seconds to retain and reuse provisioned - infrastructure after the completion of a training job, also known as SageMaker managed - warm pools. The use of warmpools reduces the latency time spent to provision new - resources. The default value for ``keep_alive_period_in_seconds`` is 0. 
- NOTE: Additional charges associated with warm pools may apply. Using this parameter also - activates a new persistent cache feature, which will further reduce job start up - latency than over using SageMaker managed warm pools alone by caching the package source - downloaded in the previous runs. - - max_retry_attempts (int): The max number of times the job is retried on - ``InternalServerFailure`` Error from SageMaker service. Defaults to 1. - - max_runtime_in_seconds (int): The upper limit in seconds to be used for training. After - this specified amount of time, SageMaker terminates the job regardless of its current - status. Defaults to 1 day or (86400 seconds). - - role (str): The IAM role (either name or full ARN) used to run your SageMaker training - job. Defaults to: - - * the SageMaker default IAM role if the SDK is running in SageMaker Notebooks or - SageMaker Studio Notebooks. - * if not above, a ValueError is be thrown. - - s3_kms_key (str): The key used to encrypt the input and output data. Default to ``None``. - - s3_root_uri (str): The root S3 folder to which the code archives and data are - uploaded to. Defaults to ``s3://``. - - sagemaker_session (sagemaker.core.helper.session.Session): The underlying SageMaker session to which - SageMaker service calls are delegated to (default: None). If not provided, one is created - using a default configuration chain. - - security_group_ids (List[str): A list of security group IDs. Defaults to ``None`` and the - training job is created without VPC config. - - subnets (List[str): A list of subnet IDs. Defaults to ``None`` and the job is created - without VPC config. - - tags (List[Tuple[str, str]): A list of tags attached to the job. Defaults to ``None`` and - the training job is created without tags. - - volume_kms_key (str): An Amazon Key Management Service (KMS) key used to encrypt an - Amazon Elastic Block Storage (EBS) volume attached to the training instance. Defaults to - ``None``. - - volume_size (int): The size in GB of the storage volume for storing input and output data - during training. Defaults to ``30``. - - encrypt_inter_container_traffic (bool): A flag that specifies whether traffic between - training containers is encrypted for the training job. Defaults to ``False``. - - spark_config (SparkConfig): Configurations to the Spark application that runs on - Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri - will be used for training. Note that ``image_uri`` can not be specified at the - same time otherwise a ``ValueError`` is thrown. Defaults to ``None``. - - use_spot_instances (bool): Specifies whether to use SageMaker Managed Spot instances for - training. If enabled then the ``max_wait_time_in_seconds`` arg should also be set. - Defaults to ``False``. - - max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job. - After this amount of time Amazon SageMaker will stop waiting for managed spot training - job to complete. Defaults to ``None``. - - disable_output_compression (bool): Optional. When set to true, Model is uploaded to - Amazon S3 without compression after training finishes. - - use_torchrun (bool): Specifies whether to use torchrun for distributed training. - Defaults to ``False``. - - use_mpirun (bool): Specifies whether to use mpirun for distributed training. - Defaults to ``False``. - - nproc_per_node (int): Optional. Specifies the number of processes per node for - distributed training. Defaults to ``None``. 
- This is defined automatically configured on the instance type. - """ - - def _remote(func): - - job_settings = _JobSettings( - dependencies=dependencies, - pre_execution_commands=pre_execution_commands, - pre_execution_script=pre_execution_script, - environment_variables=environment_variables, - image_uri=image_uri, - include_local_workdir=include_local_workdir, - custom_file_filter=custom_file_filter, - instance_count=instance_count, - instance_type=instance_type, - job_conda_env=job_conda_env, - job_name_prefix=job_name_prefix, - keep_alive_period_in_seconds=keep_alive_period_in_seconds, - max_retry_attempts=max_retry_attempts, - max_runtime_in_seconds=max_runtime_in_seconds, - role=role, - s3_kms_key=s3_kms_key, - s3_root_uri=s3_root_uri, - sagemaker_session=sagemaker_session, - security_group_ids=security_group_ids, - subnets=subnets, - tags=tags, - volume_kms_key=volume_kms_key, - volume_size=volume_size, - encrypt_inter_container_traffic=encrypt_inter_container_traffic, - spark_config=spark_config, - use_spot_instances=use_spot_instances, - max_wait_time_in_seconds=max_wait_time_in_seconds, - disable_output_compression=disable_output_compression, - use_torchrun=use_torchrun, - use_mpirun=use_mpirun, - nproc_per_node=nproc_per_node, - ) - - @functools.wraps(func) - def wrapper(*args, **kwargs): - - if instance_count > 1 and not ( - (spark_config is not None and not use_torchrun and not use_mpirun) - or (spark_config is None and use_torchrun and not use_mpirun) - or (spark_config is None and not use_torchrun and use_mpirun) - ): - raise ValueError( - "Remote function do not support training on multi instances " - + "without spark_config or use_torchrun or use_mpirun. " - + "Please provide instance_count = 1" - ) - - RemoteExecutor._validate_submit_args(func, *args, **kwargs) - - job = _Job.start(job_settings, func, args, kwargs) - - try: - job.wait() - except UnexpectedStatusException as usex: - if usex.actual_status == "Failed": - try: - exception = serialization.deserialize_exception_from_s3( - sagemaker_session=job_settings.sagemaker_session, - s3_uri=s3_path_join( - job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER - ), - - ) - except ServiceError as serr: - chained_e = serr.__cause__ - if ( - isinstance(chained_e, ClientError) - and chained_e.response["Error"]["Code"] # pylint: disable=no-member - == "404" - and chained_e.response["Error"]["Message"] # pylint: disable=no-member - == "Not Found" - ): - describe_result = job.describe() - if ( - "FailureReason" in describe_result - and describe_result["FailureReason"] - and "RuntimeEnvironmentError: " in describe_result["FailureReason"] - ): - failure_msg = describe_result["FailureReason"].replace( - "RuntimeEnvironmentError: ", "" - ) - raise RuntimeEnvironmentError(failure_msg) - raise RemoteFunctionError( - "Failed to execute remote function. " - + "Check corresponding job for details." - ) - raise serr - - raise exception - - raise TimeoutError( - "Job for remote function timed out before reaching a termination status." 
- ) - - if job.describe()["TrainingJobStatus"] == "Completed": - return serialization.deserialize_obj_from_s3( - sagemaker_session=job_settings.sagemaker_session, - s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER), - - ) - - if job.describe()["TrainingJobStatus"] == "Stopped": - raise RemoteFunctionError("Job for remote function has been aborted.") - - return None - - wrapper.job_settings = job_settings - wrapper.wrapped_func = func - return wrapper - - if _func is None: - return _remote - return _remote(_func) - - -class _SubmitRequest: - """Class that holds parameters and data for creating a new job.""" - - def __init__( - self, future, job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None - ): - self.future = future - self.job_settings = job_settings - self.func = func - self.args = func_args - self.kwargs = func_kwargs - self.run_info = run_info - - -def _submit_worker(executor): - """Background worker that submits job requests.""" - - def has_work_to_do(): - return ( - len(executor._pending_request_queue) > 0 - and len(executor._running_jobs) < executor.max_parallel_jobs - ) - - try: - while True: - with executor._state_condition: - executor._state_condition.wait_for(has_work_to_do) - request = executor._pending_request_queue[0] - - if request is None: - with executor._state_condition: - # remove the anchor from the pending queue - executor._pending_request_queue.popleft() - return - - time.sleep(_API_CALL_LIMIT["SubmittingIntervalInSecs"]) - # submit a new job - job = request.future._start_and_notify( - request.job_settings, request.func, request.args, request.kwargs, request.run_info - ) - - with executor._state_condition: - if job: - executor._running_jobs[job.job_name] = job - # remove the request from the pending queue - executor._pending_request_queue.popleft() - except Exception: # pylint: disable=broad-except - logger.exception("Error occurred while submitting CreateTrainingJob requests.") - - -def _polling_worker(executor): - """Background worker that polls the status of the running jobs.""" - try: - while True: - with executor._state_condition: - if ( - executor._shutdown - and len(executor._running_jobs) + len(executor._pending_request_queue) == 0 - ): - return - - time.sleep( - max( - _API_CALL_LIMIT["MinBatchPollingIntervalInSecs"] - - len(executor._running_jobs) * _API_CALL_LIMIT["PollingIntervalInSecs"], - 0, - ) - ) - - # check if running jobs are terminated - for job_name in list(executor._running_jobs.keys()): - try: - time.sleep(_API_CALL_LIMIT["PollingIntervalInSecs"]) - if executor._running_jobs[job_name].describe()["TrainingJobStatus"] in [ - "Completed", - "Failed", - "Stopped", - ]: - with executor._state_condition: - del executor._running_jobs[job_name] - executor._state_condition.notify_all() - except Exception as e: # pylint: disable=broad-except - if ( - not isinstance(e, ClientError) - or e.response["Error"]["Code"] # pylint: disable=no-member - != "LimitExceededException" - ): - # Couldn't check the job status, move on - logger.exception( - "Error occurred while checking the status of job %s", job_name - ) - with executor._state_condition: - del executor._running_jobs[job_name] - executor._state_condition.notify_all() - except Exception: # pylint: disable=broad-except - logger.exception("Error occurred while monitoring the job statuses.") - - -class RemoteExecutor(object): - """Run Python functions asynchronously as SageMaker jobs""" - - def __init__( - self, - *, - dependencies: str = None, - 
pre_execution_commands: List[str] = None, - pre_execution_script: str = None, - environment_variables: Dict[str, str] = None, - image_uri: str = None, - include_local_workdir: bool = None, - custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None, - instance_count: int = 1, - instance_type: str = None, - job_conda_env: str = None, - job_name_prefix: str = None, - keep_alive_period_in_seconds: int = 0, - max_parallel_jobs: int = 1, - max_retry_attempts: int = 1, - max_runtime_in_seconds: int = 24 * 60 * 60, - role: str = None, - s3_kms_key: str = None, - s3_root_uri: str = None, - sagemaker_session: Session = None, - security_group_ids: List[str] = None, - subnets: List[str] = None, - tags: List[Tuple[str, str]] = None, - volume_kms_key: str = None, - volume_size: int = 30, - encrypt_inter_container_traffic: bool = None, - spark_config: SparkConfig = None, - use_spot_instances=False, - max_wait_time_in_seconds=None, - disable_output_compression: bool = False, - use_torchrun: bool = False, - use_mpirun: bool = False, - nproc_per_node: Optional[int] = None, - ): - """Constructor for RemoteExecutor - - If a parameter value is not set, the constructor first looks up the value from the - SageMaker configuration file. If no value is specified in the configuration file or - no configuration file is found, the constructor selects the default as specified below. - For more information, see `Configuring and using defaults with the SageMaker Python SDK - `_. - - Args: - _func (Optional): A Python function to run as a SageMaker training job. - - dependencies (str): Either the path to a dependencies file or the reserved keyword - ``auto_capture``. Defaults to ``None``. - If ``dependencies`` is provided, the value must be one of the following: - - * A path to a conda environment.yml file. The following conditions apply. - - * If job_conda_env is set, then the conda environment is updated by installing - dependencies from the yaml file and the function is invoked within that - conda environment. For this to succeed, the specified conda environment must - already exist in the image. - * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, then - the conda environment is updated by installing dependencies from the yaml file and - the function is invoked within that conda environment. For this to succeed, the - conda environment name must already be set in ``SAGEMAKER_JOB_CONDA_ENV``, and - ``SAGEMAKER_JOB_CONDA_ENV`` must already exist in the image. - * If none of the previous conditions are met, a new conda environment named - ``sagemaker-runtime-env`` is created and the function annotated with the remote - decorator is invoked in that conda environment. - - * A path to a requirements.txt file. The following conditions apply. - - * If ``job_conda_env`` is set in the remote decorator, dependencies are installed - within that conda environment and the function annotated with the remote decorator - is invoked in the same conda environment. For this to succeed, the specified - conda environment must already exist in the image. - * If an environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, - dependencies are installed within that conda environment and the function annotated - with the remote decorator is invoked in the same. For this to succeed, the - conda environment name must already be set in ``SAGEMAKER_JOB_CONDA_ENV``, and - ``SAGEMAKER_JOB_CONDA_ENV`` must already exist in the image. 
- * If none of the above conditions are met, conda is not used. Dependencies are - installed at the system level, without any virtual environment, and the function - annotated with the remote decorator is invoked using the Python runtime available - in the system path. - - * The parameter dependencies is set to ``auto_capture``. SageMaker will automatically - generate an env_snapshot.yml corresponding to the current active conda environment’s - snapshot. You do not need to provide a dependencies file. The following conditions - apply: - - * You must run the remote function within an active conda environment. - * When installing the dependencies on the training job, the same conditions as when - dependencies is set to a path to a conda environment file apply. These conditions - are as follows: - - * If job_conda_env is set, then the conda environment is updated by installing - dependencies from the yaml file and the function is invoked within that - conda environment. For this to succeed, the specified conda environment must - already exist in the image. - * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, - then the conda environment is updated by installing dependencies from the yaml - file and the function is invoked within that conda environment. For this to - succeed, the conda environment name must already be set in - ``SAGEMAKER_JOB_CONDA_ENV``, and ``SAGEMAKER_JOB_CONDA_ENV`` must already exist - in the image. - * If none of the previous conditions are met, a new conda environment with name - ``sagemaker-runtime-env`` is created and the function annotated with the - remote decorator is invoked in that conda environment. - - * ``None``. SageMaker will assume that there are no dependencies to install while - executing the remote annotated function in the training job. - - pre_execution_commands (List[str]): List of commands to be executed prior to executing - remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script`` - can be specified at the same time. Defaults to None. - - pre_execution_script (str): Path to script file to be executed prior to executing - remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script`` - can be specified at the same time. Defaults to None. - - environment_variables (Dict): The environment variables used inside the decorator - function. Defaults to ``None``. - - image_uri (str): The universal resource identifier (URI) location of a Docker image on - Amazon Elastic Container Registry (ECR). Defaults to the following based on where the - SDK is running: - - * For users who specify ``spark_config`` and want to run the function in a Spark - application, the ``image_uri`` should be ``None``. A SageMaker Spark image will - be used for training, otherwise a ``ValueError`` is thrown. - * For users on SageMaker Studio notebooks, the image used as the kernel image for - the notebook is used. - * For other users, it is resolved to base python image with the same python - version as the environment running the local code. - - If no compatible image is found, a ValueError is thrown. - - include_local_workdir (bool): A flag to indicate that the remote function should include - local directories. Set to ``True`` if the remote function code imports local modules - and methods that are not available via PyPI or conda. Default value is ``False``. 
- - custom_file_filter (Callable[[str, List], List], CustomFileFilter): Either a function - that filters job dependencies to be uploaded to S3 or a ``CustomFileFilter`` object - that specifies the local directories and files to be included in the remote function. - If a callable is passed in, that function is passed to the ``ignore`` argument of - ``shutil.copytree``. Defaults to ``None``, which means only python - files are accepted and uploaded to S3. - - instance_count (int): The number of instances to use. Defaults to 1. - NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and - mpirun utilities - - instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run - the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown. - - job_conda_env (str): The name of the conda environment to activate during job's runtime. - Defaults to ``None``. - - job_name_prefix (str): The prefix used used to create the underlying SageMaker job. - - keep_alive_period_in_seconds (int): The duration in seconds to retain and reuse - provisioned infrastructure after the completion of a training job, also known as - SageMaker managed warm pools. The use of warmpools reduces the latency time spent to - provision new resources. The default value for ``keep_alive_period_in_seconds`` is 0. - NOTE: Additional charges associated with warm pools may apply. Using this parameter - also activates a new pesistent cache feature, which will further reduce job start - up latency than over using SageMaker managed warm pools alone by caching the package - source downloaded in the previous runs. - - max_parallel_jobs (int): Maximum number of jobs that run in parallel. Defaults to 1. - - max_retry_attempts (int): The max number of times the job is retried on - ``InternalServerFailure`` Error from SageMaker service. Defaults to 1. - - max_runtime_in_seconds (int): The upper limit in seconds to be used for training. After - this specified amount of time, SageMaker terminates the job regardless of its current - status. Defaults to 1 day or (86400 seconds). - - role (str): The IAM role (either name or full ARN) used to run your SageMaker training - job. Defaults to: - - * the SageMaker default IAM role if the SDK is running in SageMaker Notebooks or - SageMaker Studio Notebooks. - * if not above, a ValueError is be thrown. - - s3_kms_key (str): The key used to encrypt the input and output data. - Default to ``None``. - - s3_root_uri (str): The root S3 folder to which the code archives and data are - uploaded to. Defaults to ``s3://``. - - sagemaker_session (sagemaker.core.helper.session.Session): The underlying SageMaker session to which - SageMaker service calls are delegated to (default: None). If not provided, one is - created using a default configuration chain. - - security_group_ids (List[str): A list of security group IDs. Defaults to ``None`` and - the training job is created without VPC config. - - subnets (List[str): A list of subnet IDs. Defaults to ``None`` and the job is - created without VPC config. - - tags (List[Tuple[str, str]): A list of tags attached to the job. Defaults to ``None`` - and the training job is created without tags. - - volume_kms_key (str): An Amazon Key Management Service (KMS) key used to encrypt an - Amazon Elastic Block Storage (EBS) volume attached to the training instance. - Defaults to ``None``. - - volume_size (int): The size in GB of the storage volume for storing input and output - data during training. Defaults to ``30``. 
- - encrypt_inter_container_traffic (bool): A flag that specifies whether traffic between - training containers is encrypted for the training job. Defaults to ``False``. - - spark_config (SparkConfig): Configurations to the Spark application that runs on - Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri - will be used for training. Note that ``image_uri`` can not be specified at the - same time otherwise a ``ValueError`` is thrown. Defaults to ``None``. - - use_spot_instances (bool): Specifies whether to use SageMaker Managed Spot instances for - training. If enabled then the ``max_wait_time_in_seconds`` arg should also be set. - Defaults to ``False``. - - max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job. - After this amount of time Amazon SageMaker will stop waiting for managed spot training - job to complete. Defaults to ``None``. - - disable_output_compression (bool): Optional. When set to true, Model is uploaded to - Amazon S3 without compression after training finishes. - - use_torchrun (bool): Specifies whether to use torchrun for distributed training. - Defaults to ``False``. - - use_mpirun (bool): Specifies whether to use mpirun for distributed training. - Defaults to ``False``. - - nproc_per_node (int): Optional. Specifies the number of processes per node for - distributed training. Defaults to ``None``. - This is defined automatically configured on the instance type. - """ - self.max_parallel_jobs = max_parallel_jobs - - if self.max_parallel_jobs <= 0: - raise ValueError("max_parallel_jobs must be greater than 0.") - - if instance_count > 1 and not ( - (spark_config is not None and not use_torchrun and not use_mpirun) - or (spark_config is None and use_torchrun and not use_mpirun) - or (spark_config is None and not use_torchrun and use_mpirun) - ): - raise ValueError( - "Remote function do not support training on multi instances " - + "without spark_config or use_torchrun or use_mpirun. 
" - + "Please provide instance_count = 1" - ) - - self.job_settings = _JobSettings( - dependencies=dependencies, - pre_execution_commands=pre_execution_commands, - pre_execution_script=pre_execution_script, - environment_variables=environment_variables, - image_uri=image_uri, - include_local_workdir=include_local_workdir, - custom_file_filter=custom_file_filter, - instance_count=instance_count, - instance_type=instance_type, - job_conda_env=job_conda_env, - job_name_prefix=job_name_prefix, - keep_alive_period_in_seconds=keep_alive_period_in_seconds, - max_retry_attempts=max_retry_attempts, - max_runtime_in_seconds=max_runtime_in_seconds, - role=role, - s3_kms_key=s3_kms_key, - s3_root_uri=s3_root_uri, - sagemaker_session=sagemaker_session, - security_group_ids=security_group_ids, - subnets=subnets, - tags=tags, - volume_kms_key=volume_kms_key, - volume_size=volume_size, - encrypt_inter_container_traffic=encrypt_inter_container_traffic, - spark_config=spark_config, - use_spot_instances=use_spot_instances, - max_wait_time_in_seconds=max_wait_time_in_seconds, - disable_output_compression=disable_output_compression, - use_torchrun=use_torchrun, - use_mpirun=use_mpirun, - nproc_per_node=nproc_per_node, - ) - - self._state_condition = threading.Condition() - self._pending_request_queue = deque() - # For thread safety, see - # https://web.archive.org/web/20201108091210/http://effbot.org/pyfaq/what-kinds-of-global-value-mutation-are-thread-safe.htm - self._running_jobs = dict() - self._shutdown = False - - self._workers: ThreadPoolExecutor = None - - def submit(self, func, *args, **kwargs): - """Execute the input function as a SageMaker job asynchronously. - - Args: - func: Python function to run as a SageMaker job. - *args: Positional arguments to the input function. - **kwargs: keyword arguments to the input function - """ - if self._shutdown: - raise RuntimeError("Cannot schedule new remote function executions after shutdown") - - self._validate_submit_args(func, *args, **kwargs) - - with self._state_condition: - future = Future() - - run_info = None - if _RunContext.get_current_run() is not None: - run = _RunContext.get_current_run() - run_info = _RunInfo(run.experiment_name, run.run_name) - - self._pending_request_queue.append( - _SubmitRequest(future, self.job_settings, func, args, kwargs, run_info) - ) - - if self._workers is None: - self._workers = ThreadPoolExecutor(2) - self._workers.submit(_submit_worker, self) - self._workers.submit(_polling_worker, self) - - self._state_condition.notify_all() - - return future - - def map(self, func, *iterables): - """Return an iterator that applies function to every item of iterable, yielding the results. - - If additional iterables arguments are passed, function must take that many arguments and - is applied to the items from all iterables in parallel. With multiple iterables, the - iterator stops when the shortest iterable is exhausted. - - Args: - func: Python function to run as a SageMaker job. - iterables: Arguments of the input python function. 
- """ - - futures = map(self.submit, itertools.repeat(func), *iterables) - return [future.result() for future in futures] - - def shutdown(self): - """Prevent more function executions to be submitted to this executor.""" - with self._state_condition: - self._shutdown = True - - # give a signal to the submitting worker so that it doesn't block on empty queue forever - self._pending_request_queue.append(None) - - self._state_condition.notify_all() - - if self._workers is not None: - self._workers.shutdown(wait=True) - - def __enter__(self): - """Create an executor instance and return it""" - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Make sure the executor instance is shutdown.""" - self.shutdown() - return False - - @staticmethod - def _validate_submit_args(func, *args, **kwargs): - """Validates input args passed to submit method.""" - - full_arg_spec = inspect.getfullargspec(func) - - # args related validations - - is_accepting_variable_positional_args = full_arg_spec.varargs is not None - num_default_positional_args = len(full_arg_spec.defaults) if full_arg_spec.defaults else 0 - minimum_num_expected_positional_args = len(full_arg_spec.args) - num_default_positional_args - - if not is_accepting_variable_positional_args and len(args) > len(full_arg_spec.args): - raise TypeError( - f"{func.__name__}() takes {len(full_arg_spec.args)} positional " - + f"{'arguments' if len(full_arg_spec.args) > 1 else 'argument'} but {len(args)} " - + f"{'were' if len(args) > 1 else 'was'} given." - ) - - if len(args) < minimum_num_expected_positional_args: - missing_positional_args = full_arg_spec.args[ - len(args) : minimum_num_expected_positional_args - ] - missing_args = list(filter(lambda arg: arg not in kwargs, missing_positional_args)) - if missing_args: - missing_args_str = ( - ", ".join(map(lambda x: f"'{x}'", missing_args[:-1])) - + f", and '{missing_args[-1]}'" - if len(missing_args) > 1 - else f"'{missing_args[0]}'" - ) - raise TypeError( - f"{func.__name__}() missing {len(missing_args)} required positional " - + f"{'arguments' if len(missing_args) > 1 else 'argument'}: {missing_args_str}" - ) - - # kwargs related validations - - for k in kwargs: - if k in full_arg_spec.args and len(args) > full_arg_spec.args.index(k): - raise TypeError(f"{func.__name__}() got multiple values for argument '{k}'") - if k not in full_arg_spec.kwonlyargs and k not in full_arg_spec.args: - raise TypeError(f"{func.__name__}() got an unexpected keyword argument '{k}'") - - missing_kwargs = [ - k - for k in full_arg_spec.kwonlyargs - if k not in full_arg_spec.kwonlydefaults and k not in kwargs - ] - if missing_kwargs: - missing_kwargs_string = ( - ", ".join(map(lambda x: f"'{x}'", missing_kwargs[:-1])) - + f", and '{missing_kwargs[-1]}'" - if len(missing_kwargs) > 1 - else f"'{missing_kwargs[0]}'" - ) - - raise TypeError( - f"{func.__name__}() missing {len(missing_kwargs)} required keyword-only " - + f"{'arguments' if len(missing_kwargs) > 1 else 'argument'}: " - + f"{missing_kwargs_string}" - ) - - -class Future(object): - """Class representing a reference to a SageMaker job result. - - Reference to the SageMaker job created as a result of the remote function run. The job may - or may not have finished running. 
- """ - - def __init__(self): - self._condition = threading.Condition() - self._state = _PENDING - self._job = None - self._exception = None - self._return = None - - @staticmethod - def from_describe_response(describe_training_job_response, sagemaker_session): - """Construct a Future from a describe_training_job_response object.""" - future = Future() - job_exception = None - client_exception = None - job_return = None - job = _Job.from_describe_response(describe_training_job_response, sagemaker_session) - if describe_training_job_response["TrainingJobStatus"] in ["Stopping", "Stopped"]: - state = _CANCELLED - elif describe_training_job_response["TrainingJobStatus"] == "Completed": - state = _FINISHED - try: - job_return = serialization.deserialize_obj_from_s3( - sagemaker_session=sagemaker_session, - s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER), - - ) - except DeserializationError as e: - client_exception = e - except ServiceError as e: - client_exception = e - elif describe_training_job_response["TrainingJobStatus"] == "Failed": - state = _FINISHED - try: - job_exception = serialization.deserialize_exception_from_s3( - sagemaker_session=sagemaker_session, - s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER), - - ) - except ServiceError as serr: - chained_e = serr.__cause__ - if ( - isinstance(chained_e, ClientError) - and chained_e.response["Error"]["Code"] == "404" # pylint: disable=no-member - and chained_e.response["Error"]["Message"] # pylint: disable=no-member - == "Not Found" - ): - if ( - "FailureReason" in describe_training_job_response - and describe_training_job_response["FailureReason"] - and "RuntimeEnvironmentError: " - in describe_training_job_response["FailureReason"] - ): - failure_msg = describe_training_job_response["FailureReason"].replace( - "RuntimeEnvironmentError: ", "" - ) - job_exception = RuntimeEnvironmentError(failure_msg) - else: - job_exception = RemoteFunctionError( - "Failed to execute remote function. " - + "Check corresponding job for details." - ) - else: - job_exception = serr - except DeserializationError as e: - client_exception = e - else: - state = _RUNNING - - future._job = job - future._state = state - future._exception = job_exception or client_exception - future._return = job_return - return future - - def _start_and_notify( - self, job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None - ): - """Start and record the newly created job in the future object. - - The job is recorded if one is successfully started. Otherwise, the exception is - recorded. The state update is broadcast to other waiting threads. - """ - with self._condition: - if self._state in [_PENDING]: - - try: - self._job = _Job.start(job_settings, func, func_args, func_kwargs, run_info) - except (Exception,) as e: # pylint: disable=broad-except - self._exception = e - self._state = _FINISHED - self._condition.notify_all() - return None - - self._state = _RUNNING - self._condition.notify_all() - return self._job - return None - - def result(self, timeout: float = None) -> Any: - """Returns the SageMaker job result. - - This method waits for the SageMaker job created from the remote function execution to - complete for up to the timeout value (if specified). If timeout is ``None``, - this method will wait until the SageMaker job completes. - - Args: - timeout (float): Timeout in seconds to wait until the job is completed. ``None`` by - default. - - Returns: - The Python object returned by the remote function. 
- """ - try: - self.wait(timeout) - except UnexpectedStatusException: - pass - - with self._condition: - if self._state == _PENDING: - raise RuntimeError() - - if self._state == _RUNNING: - if self._job.describe()["TrainingJobStatus"] == "Completed": - self._return = serialization.deserialize_obj_from_s3( - sagemaker_session=self._job.sagemaker_session, - s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER), - - ) - self._state = _FINISHED - return self._return - if self._job.describe()["TrainingJobStatus"] == "Failed": - try: - self._exception = serialization.deserialize_exception_from_s3( - sagemaker_session=self._job.sagemaker_session, - s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER), - - ) - except ServiceError as serr: - chained_e = serr.__cause__ - if ( - isinstance(chained_e, ClientError) - and chained_e.response["Error"]["Code"] # pylint: disable=no-member - == "404" - and chained_e.response["Error"]["Message"] # pylint: disable=no-member - == "Not Found" - ): - if ( - "FailureReason" in self._job.describe() - and self._job.describe()["FailureReason"] - and "RuntimeEnvironmentError: " - in self._job.describe()["FailureReason"] - ): - failure_msg = self._job.describe()["FailureReason"].replace( - "RuntimeEnvironmentError: ", "" - ) - self._exception = RuntimeEnvironmentError(failure_msg) - else: - self._exception = RemoteFunctionError( - "Failed to execute remote function. " - + "Check corresponding job for details." - ) - else: - self._exception = serr - self._state = _FINISHED - elif self._job.describe()["TrainingJobStatus"] == "Stopped": - self._state = _CANCELLED - raise RemoteFunctionError("Job for remote function has been aborted.") - else: - raise TimeoutError( - "Job for remote function timed out before reaching a termination status." - ) - - if self._state == _FINISHED: - if self._exception: - raise self._exception - return self._return - - return None - - def wait( - self, - timeout: int = None, - ) -> None: - """Wait for the underlying SageMaker job to complete. - - This method waits for the SageMaker job created as a result of the remote function run - to complete for up to the timeout value (if specified). If timeout is ``None``, this method - will block until the job is completed. - - Args: - timeout (int): Timeout in seconds to wait until the job is completed before it is - stopped. Defaults to ``None``. - - Returns: - None - """ - - with self._condition: - if self._state == _PENDING: - self._condition.wait(timeout=timeout) - - if self._state == _RUNNING: - self._job.wait(timeout=timeout) - - def cancel(self) -> bool: - """Cancel the function execution. - - This method prevents the SageMaker job being created or stops the underlying SageMaker job - early if it is already in progress. - - Returns: - ``True`` if the underlying SageMaker job created as a result of the remote function - run is cancelled. - """ - with self._condition: - if self._state == _FINISHED: - return False - if self._state == _CANCELLED: - return True - - if self._job: - self._job.stop() - self._state = _CANCELLED - return True - - def running(self) -> bool: - """Check if the underlying SageMaker job is running. - - Returns: - ``True`` if the underlying SageMaker job is still running. ``False``, otherwise. - """ - with self._condition: - return self._state == _RUNNING - - def cancelled(self) -> bool: - """Check if the underlying SageMaker job was cancelled. - - Returns: - ``True`` if the underlying SageMaker job was cancelled. ``False``, otherwise. 
- """ - with self._condition: - return self._state == _CANCELLED - - def done(self) -> bool: - """Check if the underlying SageMaker job is finished. - - Returns: - ``True`` if the underlying SageMaker job finished running. ``False``, otherwise. - """ - with self._condition: - if self._state == _RUNNING and self._job.describe()["TrainingJobStatus"] in [ - "Completed", - "Failed", - ]: - self._state = _FINISHED - return True - - if self._state == _FINISHED: - return True - - return False - - -def get_future(job_name, sagemaker_session=None) -> Future: - """Get a future object with information about a job with the given job_name. - - Args: - job_name (str): name of the underlying SageMaker job created as a result of the remote - function run. - - sagemaker_session (sagemaker.core.helper.session.Session): A session object that manages interactions - with Amazon SageMaker APIs and any other AWS services needed. - - Returns: - A `sagemaker.remote_function.client.Future` instance. - """ - if not sagemaker_session: - sagemaker_session = Session() - describe_training_job_response = sagemaker_session.sagemaker_client.describe_training_job( - TrainingJobName=job_name - ) - return Future.from_describe_response(describe_training_job_response, sagemaker_session) - - -def list_futures(job_name_prefix, sagemaker_session=None): - """Generates Future objects with information about jobs with given job_name_prefix. - - Args: - job_name_prefix (str): A prefix used to identify the SageMaker jobs associated with remote - function run. - sagemaker_session (sagemaker.core.helper.session.Session): A session object that manages interactions - with Amazon SageMaker APIs and any other AWS services needed. - - Yields: - A `sagemaker.remote_function.client.Future` instance. - """ - if not sagemaker_session: - sagemaker_session = Session() - job_name = name_from_base(job_name_prefix) - # perform the following transformation because we might have trimmed the job_name_prefix while - # creating the job. - transformed_job_name_prefix = base_from_name(job_name) - next_token = None - list_training_job_kwargs = {"NameContains": transformed_job_name_prefix} - while True: - if next_token: - list_training_job_kwargs["NextToken"] = next_token - list_training_job_response = sagemaker_session.sagemaker_client.list_training_jobs( - **list_training_job_kwargs - ) - training_job_names = [ - job["TrainingJobName"] for job in list_training_job_response["TrainingJobSummaries"] - ] - for training_job_name in training_job_names: - describe_training_job_response = ( - sagemaker_session.sagemaker_client.describe_training_job( - TrainingJobName=training_job_name - ) - ) - yield Future.from_describe_response(describe_training_job_response, sagemaker_session) - if "NextToken" in list_training_job_response: - next_token = list_training_job_response["NextToken"] - else: - break diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/__init__.py b/sagemaker-core/src/sagemaker/core/remote_function/core/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/_custom_dispatch_table.py b/sagemaker-core/src/sagemaker/core/remote_function/core/_custom_dispatch_table.py deleted file mode 100644 index 3217e88672..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/_custom_dispatch_table.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""SageMaker remote function data serializer/deserializer.""" -from __future__ import absolute_import - -from sagemaker.core.remote_function.errors import SerializationError - -from sagemaker.core.helper.pipeline_variable import PipelineVariable -from sagemaker.core.workflow.parameters import ( - ParameterInteger, - ParameterFloat, - ParameterString, - ParameterBoolean, -) -from sagemaker.core.workflow.execution_variables import ExecutionVariable -from sagemaker.core.workflow.properties import ( - Properties, - PropertiesMap, - PropertiesList, -) - - -# Lazy import to avoid circular dependency -# DelayedReturn is in MLOps package which depends on Core -def _get_delayed_return_class(): - """Lazy import of DelayedReturn to avoid circular dependency.""" - try: - from sagemaker.mlops.workflow.function_step import DelayedReturn - - return DelayedReturn - except ImportError: - # If MLOps is not installed, return None - return None - - -def _pipeline_variable_reducer(pipeline_variable): - """Reducer for pipeline variable.""" - - raise SerializationError( - """Please pass the pipeline variable to the function decorated with @step as an argument. - Referencing to a pipeline variable from within the function - or passing a pipeline variable nested in a data structure are not supported.""" - ) - - -# Build dispatch table with lazy loading for DelayedReturn -dispatch_table = { - ParameterInteger: _pipeline_variable_reducer, - ParameterFloat: _pipeline_variable_reducer, - ParameterString: _pipeline_variable_reducer, - ParameterBoolean: _pipeline_variable_reducer, - ExecutionVariable: _pipeline_variable_reducer, - PipelineVariable: _pipeline_variable_reducer, - Properties: _pipeline_variable_reducer, - PropertiesMap: _pipeline_variable_reducer, - PropertiesList: _pipeline_variable_reducer, -} - -# Add DelayedReturn to dispatch table if MLOps is available -_delayed_return_class = _get_delayed_return_class() -if _delayed_return_class is not None: - dispatch_table[_delayed_return_class] = _pipeline_variable_reducer diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/pipeline_variables.py b/sagemaker-core/src/sagemaker/core/remote_function/core/pipeline_variables.py deleted file mode 100644 index 491267b35f..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/pipeline_variables.py +++ /dev/null @@ -1,347 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
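The _custom_dispatch_table module removed above blocks pickling of pipeline-variable types by registering a reducer that raises instead of returning serialized state. Below is a minimal, self-contained sketch of that dispatch-table pattern using only the standard pickle module; FakePipelineVariable and reject_pipeline_variable are hypothetical names used for illustration and are not part of the SDK.

import io
import pickle


class FakePipelineVariable:
    """Hypothetical stand-in for a pipeline-variable type."""


def reject_pipeline_variable(obj):
    # Reducer looked up through the pickler's dispatch_table; raising here
    # blocks serialization with an actionable message instead of emitting
    # bytes that cannot be resolved at job runtime.
    raise TypeError(
        "Pass pipeline variables to the step function as arguments "
        "instead of referencing them from inside the function body."
    )


buffer = io.BytesIO()
pickler = pickle.Pickler(buffer)
pickler.dispatch_table = {FakePipelineVariable: reject_pipeline_variable}

try:
    pickler.dump([1, FakePipelineVariable()])
except TypeError as err:
    print(f"Serialization blocked: {err}")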
-"""SageMaker remote function data serializer/deserializer.""" -from __future__ import absolute_import - -from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass, field -from typing import Any, Union, Dict, List, Tuple - -from sagemaker.core.s3 import s3_path_join -from sagemaker.core.remote_function.core.serialization import deserialize_obj_from_s3 -from sagemaker.core.workflow.step_outputs import get_step - - -@dataclass -class Context: - """Context for an execution.""" - - step_name: str = None - execution_id: str = None - property_references: Dict[str, str] = field(default_factory=dict) - serialize_output_to_json: bool = False - func_step_s3_dir: str = None - - -@dataclass -class _Parameter: - """Parameter to a function.""" - - name: str - - -class _ParameterInteger(_Parameter): - """Integer parameter to a function.""" - - ... - - -class _ParameterFloat(_Parameter): - """Float parameter to a function.""" - - ... - - -class _ParameterString(_Parameter): - """String parameter to a function.""" - - ... - - -class _ParameterBoolean(_Parameter): - """Boolean parameter to a function.""" - - ... - - -@dataclass -class _Properties: - """Properties of classic steps.""" - - path: str - - -@dataclass -class _ExecutionVariable: - """Execution variable.""" - - name: str - - -@dataclass -class _S3BaseUriIdentifier: - """Identifies that the class refers to function step s3 base uri. - - The s3_base_uri = s3_root_uri + pipeline_name. - This identifier is resolved in function step runtime by SDK. - """ - - NAME = "S3_BASE_URI" - - -@dataclass -class _DelayedReturn: - """Delayed return from a function.""" - - uri: Union[_Properties, List[Union[str, _Parameter, _ExecutionVariable]]] - reference_path: Tuple = field(default_factory=tuple) - - -class _ExecutionVariableResolver: - """Resolve execution variables.""" - - def __init__(self, context: Context): - """Resolve execution variables.""" - self._context = context - - def resolve(self, execution_variable: _ExecutionVariable): - """Resolve a single execution variable. - - Args: - execution_variable: execution variable to resolve. - Returns: - resolved value - """ - return self._context.property_references[f"Execution.{execution_variable.name}"] - - -class _ParameterResolver: - """Resolve parameters.""" - - def __init__(self, context: Context): - """Resolve parameters.""" - self._context = context - - def resolve(self, parameter: _Parameter): - """Resolve a single property reference. - - Args: - parameter: parameter to resolve. - Returns: - resolved value - """ - if isinstance(parameter, _ParameterInteger): - return int(self._context.property_references[f"Parameters.{parameter.name}"]) - if isinstance(parameter, _ParameterFloat): - return float(self._context.property_references[f"Parameters.{parameter.name}"]) - if isinstance(parameter, _ParameterString): - return self._context.property_references[f"Parameters.{parameter.name}"] - - return self._context.property_references[f"Parameters.{parameter.name}"] == "true" - - -class _PropertiesResolver: - """Resolve classic step properties.""" - - def __init__(self, context: Context): - """Resolve classic step properties.""" - self._context = context - - def resolve(self, properties: _Properties): - """Resolve classic step properties. - - Args: - properties: classic step properties. 
- Returns: - resolved value - """ - return self._context.property_references[properties.path] - - -class _DelayedReturnResolver: - """Resolve delayed returns.""" - - def __init__( - self, - delayed_returns: List[_DelayedReturn], - properties_resolver: _PropertiesResolver, - parameter_resolver: _ParameterResolver, - execution_variable_resolver: _ExecutionVariableResolver, - s3_base_uri: str, - **settings, - ): - """Resolve delayed return. - - Args: - delayed_returns: list of delayed returns to resolve. - properties_resolver: resolver used to resolve step properties. - parameter_resolver: resolver used to pipeline parameters. - execution_variable_resolver: resolver used to resolve execution variables. - s3_base_uri (str): the s3 base uri of the function step that - the serialized artifacts will be uploaded to. - The s3_base_uri = s3_root_uri + pipeline_name. - **settings: settings to pass to the deserialization function. - """ - self._s3_base_uri = s3_base_uri - self._parameter_resolver = parameter_resolver - self._execution_variable_resolver = execution_variable_resolver - self._properties_resolver = properties_resolver - # different delayed returns can have the same uri, so we need to dedupe - uris = { - self._resolve_delayed_return_uri(delayed_return) for delayed_return in delayed_returns - } - - def deserialization_task(uri): - return uri, deserialize_obj_from_s3( - sagemaker_session=settings["sagemaker_session"], - s3_uri=uri, - ) - - with ThreadPoolExecutor() as executor: - self._deserialized_objects = dict(executor.map(deserialization_task, uris)) - - def resolve(self, delayed_return: _DelayedReturn) -> Any: - """Resolve a single delayed return. - - Args: - delayed_return: delayed return to resolve. - Returns: - resolved delayed return. - """ - deserialized_obj = self._deserialized_objects[ - self._resolve_delayed_return_uri(delayed_return) - ] - return _retrieve_child_item(delayed_return, deserialized_obj) - - def _resolve_delayed_return_uri(self, delayed_return: _DelayedReturn): - """Resolve the s3 uri of the delayed return.""" - if isinstance(delayed_return.uri, _Properties): - return self._properties_resolver.resolve(delayed_return.uri) - - # Keep the following old resolution logics to keep backward compatible - uri = [] - for component in delayed_return.uri: - if isinstance(component, _Parameter): - uri.append(self._parameter_resolver.resolve(component)) - elif isinstance(component, _ExecutionVariable): - uri.append(self._execution_variable_resolver.resolve(component)) - elif isinstance(component, _S3BaseUriIdentifier): - uri.append(self._s3_base_uri) - else: - uri.append(component) - return s3_path_join(*uri) - - -def _retrieve_child_item(delayed_return: _DelayedReturn, deserialized_obj: Any): - """Retrieve child item from deserialized object.""" - result = deserialized_obj - for component in delayed_return.reference_path: - result = result[component[1]] - return result - - -def resolve_pipeline_variables( - context: Context, - func_args: Tuple, - func_kwargs: Dict, - s3_base_uri: str, - **settings, -): - """Resolve pipeline variables. - - Args: - context: context for the execution. - func_args: function args. - func_kwargs: function kwargs. - s3_base_uri: the s3 base uri of the function step that the serialized artifacts - will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name. - **settings: settings to pass to the deserialization function. 
- """ - - delayed_returns = [] - - if func_args is not None: - for arg in func_args: - if isinstance(arg, _DelayedReturn): - delayed_returns.append(arg) - if func_kwargs is not None: - for arg in func_kwargs.values(): - if isinstance(arg, _DelayedReturn): - delayed_returns.append(arg) - - # build the resolvers - parameter_resolver = _ParameterResolver(context) - execution_variable_resolver = _ExecutionVariableResolver(context) - properties_resolver = _PropertiesResolver(context) - delayed_return_resolver = _DelayedReturnResolver( - delayed_returns=delayed_returns, - properties_resolver=properties_resolver, - parameter_resolver=parameter_resolver, - execution_variable_resolver=execution_variable_resolver, - s3_base_uri=s3_base_uri, - **settings, - ) - - # resolve the pipeline variables - resolved_func_args = None - if func_args is not None: - resolved_func_args = [] - for arg in func_args: - if isinstance(arg, _Parameter): - resolved_func_args.append(parameter_resolver.resolve(arg)) - elif isinstance(arg, _ExecutionVariable): - resolved_func_args.append(execution_variable_resolver.resolve(arg)) - elif isinstance(arg, _Properties): - resolved_func_args.append(properties_resolver.resolve(arg)) - elif isinstance(arg, _DelayedReturn): - resolved_func_args.append(delayed_return_resolver.resolve(arg)) - else: - resolved_func_args.append(arg) - resolved_func_args = tuple(resolved_func_args) - - resolved_func_kwargs = None - if func_kwargs is not None: - resolved_func_kwargs = {} - for key, value in func_kwargs.items(): - if isinstance(value, _Parameter): - resolved_func_kwargs[key] = parameter_resolver.resolve(value) - elif isinstance(value, _ExecutionVariable): - resolved_func_kwargs[key] = execution_variable_resolver.resolve(value) - elif isinstance(value, _Properties): - resolved_func_kwargs[key] = properties_resolver.resolve(value) - elif isinstance(value, _DelayedReturn): - resolved_func_kwargs[key] = delayed_return_resolver.resolve(value) - else: - resolved_func_kwargs[key] = value - - return resolved_func_args, resolved_func_kwargs - - -def convert_pipeline_variables_to_pickleable(func_args: Tuple, func_kwargs: Dict): - """Convert pipeline variables to pickleable. - - Args: - func_args: function args. - func_kwargs: function kwargs. - """ - - from sagemaker.core.helper.pipeline_variable import PipelineVariable - - from sagemaker.mlops.workflow.function_step import DelayedReturn - - def convert(arg): - if isinstance(arg, DelayedReturn): - return _DelayedReturn( - uri=get_step(arg)._properties.OutputDataConfig.S3OutputPath._pickleable, - reference_path=arg._reference_path, - ) - - if isinstance(arg, PipelineVariable): - return arg._pickleable - - return arg - - converted_func_args = tuple(convert(arg) for arg in func_args) - converted_func_kwargs = {key: convert(arg) for key, arg in func_kwargs.items()} - - return converted_func_args, converted_func_kwargs diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py b/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py deleted file mode 100644 index 8871f6727f..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py +++ /dev/null @@ -1,410 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. 
A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""SageMaker remote function data serializer/deserializer.""" -from __future__ import absolute_import - -import dataclasses -import json - -import io - -import sys -import hashlib -import pickle - -from typing import Any, Callable, Union - -import cloudpickle -from tblib import pickling_support - -from sagemaker.core.remote_function.errors import ( - ServiceError, - SerializationError, - DeserializationError, -) -from sagemaker.core.s3 import S3Downloader, S3Uploader -from sagemaker.core.helper.session_helper import Session -from ._custom_dispatch_table import dispatch_table - -# Note: do not use os.path.join for s3 uris, fails on windows - - -def _get_python_version(): - """Returns the current python version.""" - return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" - - -@dataclasses.dataclass -class _MetaData: - """Metadata about the serialized data or functions.""" - - sha256_hash: str - version: str = "2023-04-24" - python_version: str = _get_python_version() - serialization_module: str = "cloudpickle" - - def to_json(self): - """Converts metadata to json string.""" - return json.dumps(dataclasses.asdict(self)).encode() - - @staticmethod - def from_json(s): - """Converts json string to metadata object.""" - try: - obj = json.loads(s) - except json.decoder.JSONDecodeError: - raise DeserializationError("Corrupt metadata file. It is not a valid json file.") - - sha256_hash = obj.get("sha256_hash") - metadata = _MetaData(sha256_hash=sha256_hash) - metadata.version = obj.get("version") - metadata.python_version = obj.get("python_version") - metadata.serialization_module = obj.get("serialization_module") - - if not sha256_hash: - raise DeserializationError( - "Corrupt metadata file. SHA256 hash for the serialized data does not exist. " - "Please make sure to install SageMaker SDK version >= 2.156.0 on the client side " - "and try again." - ) - - if not ( - metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle" - ): - raise DeserializationError( - f"Corrupt metadata file. Serialization approach {s} is not supported." - ) - - return metadata - - -class CloudpickleSerializer: - """Serializer using cloudpickle.""" - - @staticmethod - def serialize(obj: Any) -> bytes: - """Serializes data object and uploads it to S3. - - Args: - obj: object to be serialized and persisted - Raises: - SerializationError: when fail to serialize object to bytes. - """ - try: - io_buffer = io.BytesIO() - custom_pickler = cloudpickle.CloudPickler(io_buffer) - dt = pickle.Pickler.dispatch_table.__get__(custom_pickler) # pylint: disable=no-member - new_dt = dt.new_child(dispatch_table) - pickle.Pickler.dispatch_table.__set__( # pylint: disable=no-member - custom_pickler, new_dt - ) - custom_pickler.dump(obj) - return io_buffer.getvalue() - except Exception as e: - if isinstance( - e, NotImplementedError - ) and "Instance of Run type is not allowed to be pickled." in str(e): - raise SerializationError( - """You are trying to pass a sagemaker.experiments.run.Run object to - a remote function - or are trying to access a global sagemaker.experiments.run.Run object - from within the function. 
This is not supported. - You must use `load_run` to load an existing Run in the remote function - or instantiate a new Run in the function.""" - ) - - raise SerializationError( - "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e)) - ) from e - - @staticmethod - def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any: - """Downloads from S3 and then deserializes data objects. - - Args: - s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - bytes_to_deserialize: bytes to be deserialized. - Returns : - List of deserialized python objects. - Raises: - DeserializationError: when fail to serialize object to bytes. - """ - - try: - return cloudpickle.loads(bytes_to_deserialize) - except Exception as e: - raise DeserializationError( - "Error when deserializing bytes downloaded from {}: {}. " - "NOTE: this may be caused by inconsistent sagemaker python sdk versions " - "where remote function runs versus the one used on client side. " - "If the sagemaker versions do not match, a warning message would " - "be logged starting with 'Inconsistent sagemaker versions found'. " - "Please check it to validate.".format(s3_uri, repr(e)) - ) from e - - -# TODO: use dask serializer in case dask distributed is installed in users' environment. -def serialize_func_to_s3( - func: Callable, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None -): - """Serializes function and uploads it to S3. - - Args: - sagemaker_session (sagemaker.core.helper.session.Session): - The underlying Boto3 session which AWS service calls are delegated to. - s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - func: function to be serialized and persisted - Raises: - SerializationError: when fail to serialize function to bytes. - """ - - _upload_payload_and_metadata_to_s3( - bytes_to_upload=CloudpickleSerializer.serialize(func), - s3_uri=s3_uri, - sagemaker_session=sagemaker_session, - s3_kms_key=s3_kms_key, - ) - - -def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callable: - """Downloads from S3 and then deserializes data objects. - - This method downloads the serialized training job outputs to a temporary directory and - then deserializes them using dask. - - Args: - sagemaker_session (sagemaker.core.helper.session.Session): - The underlying sagemaker session which AWS service calls are delegated to. - s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - Returns : - The deserialized function. - Raises: - DeserializationError: when fail to serialize function to bytes. - """ - metadata = _MetaData.from_json( - _read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session) - ) - - bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) - - _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize - ) - - return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) - - -def serialize_obj_to_s3( - obj: Any, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None -): - """Serializes data object and uploads it to S3. - - Args: - sagemaker_session (sagemaker.core.helper.session.Session): - The underlying Boto3 session which AWS service calls are delegated to. - s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. 
- s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - obj: object to be serialized and persisted - Raises: - SerializationError: when fail to serialize object to bytes. - """ - - _upload_payload_and_metadata_to_s3( - bytes_to_upload=CloudpickleSerializer.serialize(obj), - s3_uri=s3_uri, - sagemaker_session=sagemaker_session, - s3_kms_key=s3_kms_key, - ) - - -def json_serialize_obj_to_s3( - obj: Any, - json_key: str, - sagemaker_session: Session, - s3_uri: str, - s3_kms_key: str = None, -): - """Json serializes data object and uploads it to S3. - - If a function step's output is data referenced by other steps via JsonGet, - its output should be json serialized and uploaded to S3. - - Args: - obj: (Any) object to be serialized and persisted. - json_key: (str) the json key pointing to function step output. - sagemaker_session (sagemaker.core.helper.session.Session): - The underlying Boto3 session which AWS service calls are delegated to. - s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - """ - json_serialized_result = {} - try: - to_dump = {json_key: obj, "Exception": None} - json_serialized_result = json.dumps(to_dump) - except TypeError as e: - if "is not JSON serializable" in str(e): - to_dump = { - json_key: None, - "Exception": f"The function return ({obj}) is not JSON serializable.", - } - json_serialized_result = json.dumps(to_dump) - - S3Uploader.upload_string_as_file_body( - body=json_serialized_result, - desired_s3_uri=s3_uri, - sagemaker_session=sagemaker_session, - kms_key=s3_kms_key, - ) - - -def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any: - """Downloads from S3 and then deserializes data objects. - - Args: - sagemaker_session (sagemaker.core.helper.session.Session): - The underlying sagemaker session which AWS service calls are delegated to. - s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - Returns : - Deserialized python objects. - Raises: - DeserializationError: when fail to serialize object to bytes. - """ - - metadata = _MetaData.from_json( - _read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session) - ) - - bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) - - _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize - ) - - return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) - - -def serialize_exception_to_s3( - exc: Exception, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None -): - """Serializes exception with traceback and uploads it to S3. - - Args: - sagemaker_session (sagemaker.core.helper.session.Session): - The underlying Boto3 session which AWS service calls are delegated to. - s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - exc: Exception to be serialized and persisted - Raises: - SerializationError: when fail to serialize object to bytes. 
- """ - pickling_support.install() - - _upload_payload_and_metadata_to_s3( - bytes_to_upload=CloudpickleSerializer.serialize(exc), - s3_uri=s3_uri, - sagemaker_session=sagemaker_session, - s3_kms_key=s3_kms_key, - ) - - -def _upload_payload_and_metadata_to_s3( - bytes_to_upload: Union[bytes, io.BytesIO], - s3_uri: str, - sagemaker_session: Session, - s3_kms_key, -): - """Uploads serialized payload and metadata to s3. - - Args: - bytes_to_upload (bytes): Serialized bytes to upload. - s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - sagemaker_session (sagemaker.core.helper.session.Session): - The underlying Boto3 session which AWS service calls are delegated to. - s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - """ - _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session) - - sha256_hash = _compute_hash(bytes_to_upload) - - _upload_bytes_to_s3( - _MetaData(sha256_hash).to_json(), - f"{s3_uri}/metadata.json", - s3_kms_key, - sagemaker_session, - ) - - -def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> Any: - """Downloads from S3 and then deserializes exception. - - Args: - sagemaker_session (sagemaker.core.helper.session.Session): - The underlying sagemaker session which AWS service calls are delegated to. - s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - Returns : - Deserialized exception with traceback. - Raises: - DeserializationError: when fail to serialize object to bytes. - """ - - metadata = _MetaData.from_json( - _read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session) - ) - - bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) - - _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize - ) - - return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) - - -def _upload_bytes_to_s3(b: Union[bytes, io.BytesIO], s3_uri, s3_kms_key, sagemaker_session): - """Wrapping s3 uploading with exception translation for remote function.""" - try: - S3Uploader.upload_bytes(b, s3_uri, kms_key=s3_kms_key, sagemaker_session=sagemaker_session) - except Exception as e: - raise ServiceError( - "Failed to upload serialized bytes to {}: {}".format(s3_uri, repr(e)) - ) from e - - -def _read_bytes_from_s3(s3_uri, sagemaker_session): - """Wrapping s3 downloading with exception translation for remote function.""" - try: - return S3Downloader.read_bytes(s3_uri, sagemaker_session=sagemaker_session) - except Exception as e: - raise ServiceError( - "Failed to read serialized bytes from {}: {}".format(s3_uri, repr(e)) - ) from e - - -def _compute_hash(buffer: bytes) -> str: - """Compute the sha256 hash""" - return hashlib.sha256(buffer).hexdigest() - - -def _perform_integrity_check(expected_hash_value: str, buffer: bytes): - """Performs integrity checks for serialized code/arguments uploaded to s3. - - Verifies whether the hash read from s3 matches the hash calculated - during remote function execution. - """ - actual_hash_value = _compute_hash(buffer=buffer) - if expected_hash_value != actual_hash_value: - raise DeserializationError( - "Integrity check for the serialized function or data failed. 
" - "Please restrict access to your S3 bucket" - ) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py b/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py deleted file mode 100644 index c7ee86f8a7..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""SageMaker job function serializer/deserializer.""" -from __future__ import absolute_import - -import os -from dataclasses import dataclass -from typing import Any - - -from sagemaker.core.s3 import s3_path_join -from sagemaker.core.remote_function import logging_config -from sagemaker.core.remote_function.core.pipeline_variables import ( - Context, - resolve_pipeline_variables, -) - -import sagemaker.core.remote_function.core.serialization as serialization -from sagemaker.core.helper.session_helper import Session - - -logger = logging_config.get_logger() - - -FUNCTION_FOLDER = "function" -ARGUMENTS_FOLDER = "arguments" -RESULTS_FOLDER = "results" -EXCEPTION_FOLDER = "exception" -JSON_SERIALIZED_RESULT_KEY = "Result" -JSON_RESULTS_FILE = "results.json" - - -@dataclass -class _SerializedData: - """Data class to store serialized function and arguments""" - - func: bytes - args: bytes - - -class StoredFunction: - """Class representing a remote function stored in S3.""" - - def __init__( - self, - sagemaker_session: Session, - s3_base_uri: str, - s3_kms_key: str = None, - context: Context = Context(), - ): - """Construct a StoredFunction object. - - Args: - sagemaker_session: (sagemaker.session.Session): The underlying sagemaker session which - AWS service calls are delegated to. - s3_base_uri: the base uri to which serialized artifacts will be uploaded. - s3_kms_key: KMS key used to encrypt artifacts uploaded to S3. - context: Build or run context of a pipeline step. - """ - self.sagemaker_session = sagemaker_session - self.s3_base_uri = s3_base_uri - self.s3_kms_key = s3_kms_key - self.context = context - - # For pipeline steps, function code is at: base/step_name/build_timestamp/ - # For results, path is: base/step_name/build_timestamp/execution_id/ - # This ensures uniqueness: build_timestamp per build, execution_id per run - if context.step_name and context.func_step_s3_dir: - # Pipeline step: include build timestamp in both paths - self.func_upload_path = s3_path_join( - s3_base_uri, context.step_name, context.func_step_s3_dir - ) - self.results_upload_path = s3_path_join( - s3_base_uri, context.step_name, context.func_step_s3_dir, context.execution_id - ) - else: - # Regular remote function: original behavior - self.func_upload_path = s3_path_join( - s3_base_uri, context.step_name, context.func_step_s3_dir - ) - self.results_upload_path = s3_path_join( - s3_base_uri, context.execution_id, context.step_name - ) - - def save(self, func, *args, **kwargs): - """Serialize and persist the function and arguments. - - Args: - func: the python function. 
- args: the positional arguments to func. - kwargs: the keyword arguments to func. - Returns: - None - """ - - logger.info( - "Serializing function code to %s", s3_path_join(self.func_upload_path, FUNCTION_FOLDER) - ) - serialization.serialize_func_to_s3( - func=func, - sagemaker_session=self.sagemaker_session, - s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), - s3_kms_key=self.s3_kms_key, - - ) - - logger.info( - "Serializing function arguments to %s", - s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), - ) - - serialization.serialize_obj_to_s3( - obj=(args, kwargs), - sagemaker_session=self.sagemaker_session, - s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), - - s3_kms_key=self.s3_kms_key, - ) - - def save_pipeline_step_function(self, serialized_data): - """Upload serialized function and arguments to s3. - - Args: - serialized_data (_SerializedData): The serialized function - and function arguments of a function step. - """ - - logger.info( - "Uploading serialized function code to %s", - s3_path_join(self.func_upload_path, FUNCTION_FOLDER), - ) - serialization._upload_payload_and_metadata_to_s3( - bytes_to_upload=serialized_data.func, - - s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), - sagemaker_session=self.sagemaker_session, - s3_kms_key=self.s3_kms_key, - ) - - logger.info( - "Uploading serialized function arguments to %s", - s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), - ) - serialization._upload_payload_and_metadata_to_s3( - bytes_to_upload=serialized_data.args, - - s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), - sagemaker_session=self.sagemaker_session, - s3_kms_key=self.s3_kms_key, - ) - - def load_and_invoke(self) -> Any: - """Load and deserialize the function and the arguments and then execute it.""" - - logger.info( - "Deserializing function code from %s", - s3_path_join(self.func_upload_path, FUNCTION_FOLDER), - ) - func = serialization.deserialize_func_from_s3( - sagemaker_session=self.sagemaker_session, - s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), - - ) - - logger.info( - "Deserializing function arguments from %s", - s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), - ) - args, kwargs = serialization.deserialize_obj_from_s3( - sagemaker_session=self.sagemaker_session, - s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), - - ) - - logger.info("Resolving pipeline variables") - resolved_args, resolved_kwargs = resolve_pipeline_variables( - self.context, - args, - kwargs, - - s3_base_uri=self.s3_base_uri, - sagemaker_session=self.sagemaker_session, - ) - - logger.info("Invoking the function") - result = func(*resolved_args, **resolved_kwargs) - - logger.info( - "Serializing the function return and uploading to %s", - s3_path_join(self.results_upload_path, RESULTS_FOLDER), - ) - serialization.serialize_obj_to_s3( - obj=result, - sagemaker_session=self.sagemaker_session, - s3_uri=s3_path_join(self.results_upload_path, RESULTS_FOLDER), - - s3_kms_key=self.s3_kms_key, - ) - - if self.context and self.context.serialize_output_to_json: - logger.info( - "JSON Serializing the function return and uploading to %s", - s3_path_join(self.results_upload_path, RESULTS_FOLDER), - ) - serialization.json_serialize_obj_to_s3( - obj=result, - json_key=JSON_SERIALIZED_RESULT_KEY, - sagemaker_session=self.sagemaker_session, - s3_uri=s3_path_join( - os.path.join(self.results_upload_path, RESULTS_FOLDER, JSON_RESULTS_FILE) - ), - s3_kms_key=self.s3_kms_key, - ) diff --git 
a/sagemaker-core/src/sagemaker/core/remote_function/custom_file_filter.py b/sagemaker-core/src/sagemaker/core/remote_function/custom_file_filter.py deleted file mode 100644 index c82cc7eee7..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/custom_file_filter.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""SageMaker remote function client.""" -from __future__ import absolute_import - -import fnmatch -import os -import shutil -from typing import List, Optional, Callable, Union - -from sagemaker.core.common_utils import resolve_value_from_config -from sagemaker.core.config.config_schema import REMOTE_FUNCTION_PATH, CUSTOM_FILE_FILTER - - -class CustomFileFilter: - """Configuration that specifies how the local working directory should be packaged.""" - - def __init__(self, *, ignore_name_patterns: List[str] = None): - """Initialize a CustomFileFilter. - - Args: - ignore_name_patterns (List[str]): ignore files or directories with names - that match one of the glob-style patterns. Defaults to None. - """ - - if ignore_name_patterns is None: - ignore_name_patterns = [] - - self._workdir = os.getcwd() - self._ignore_name_patterns = ignore_name_patterns - - @property - def ignore_name_patterns(self): - """Get the ignore name patterns.""" - return self._ignore_name_patterns - - @property - def workdir(self): - """Get the working directory.""" - return self._workdir - - -def resolve_custom_file_filter_from_config_file( - direct_input: Union[Callable[[str, List], List], CustomFileFilter] = None, - sagemaker_session=None, -) -> Union[Callable[[str, List], List], CustomFileFilter, None]: - """Resolve the CustomFileFilter configuration from the config file. - - Args: - direct_input (Callable[[str, List], List], CustomFileFilter): direct input from the user. - sagemaker_session (sagemaker.core.helper.session.Session): sagemaker session. - Returns: - CustomFileFilter: configuration that specifies how the local - working directory should be packaged. - """ - if direct_input is not None: - return direct_input - ignore_name_patterns = resolve_value_from_config( - direct_input=None, - config_path=".".join([REMOTE_FUNCTION_PATH, CUSTOM_FILE_FILTER, "IgnoreNamePatterns"]), - default_value=None, - sagemaker_session=sagemaker_session, - ) - if ignore_name_patterns is not None: - return CustomFileFilter(ignore_name_patterns=ignore_name_patterns) - return None - - -def copy_workdir( - dst: str, - custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None, -): - """Copy the local working directory to the destination. - - Args: - dst (str): destination path. - custom_file_filter (Union[Callable[[str, List], List], CustomFileFilter): configuration that - specifies how the local working directory should be packaged. 
- """ - - def _ignore_patterns(path: str, names: List): # pylint: disable=unused-argument - ignored_names = set() - if custom_file_filter.ignore_name_patterns is not None: - for pattern in custom_file_filter.ignore_name_patterns: - ignored_names.update(fnmatch.filter(names, pattern)) - return ignored_names - - def _filter_non_python_files(path: str, names: List) -> List: - """Ignore function for filtering out non python files.""" - to_ignore = [] - for name in names: - full_path = os.path.join(path, name) - if os.path.isfile(full_path): - if not name.endswith(".py"): - to_ignore.append(name) - elif os.path.isdir(full_path): - if name == "__pycache__": - to_ignore.append(name) - else: - to_ignore.append(name) - - return to_ignore - - _ignore = None - _src = os.getcwd() - if not custom_file_filter: - _ignore = _filter_non_python_files - elif callable(custom_file_filter): - _ignore = custom_file_filter - elif isinstance(custom_file_filter, CustomFileFilter): - _ignore = _ignore_patterns - _src = custom_file_filter.workdir - - shutil.copytree( - _src, - dst, - ignore=_ignore, - ) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/errors.py b/sagemaker-core/src/sagemaker/core/remote_function/errors.py deleted file mode 100644 index 3f391570cf..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/errors.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Definitions for reomote job errors and error handling""" -from __future__ import absolute_import - -import os - -from tblib import pickling_support -from sagemaker.core.s3 import s3_path_join -import sagemaker.core.remote_function.core.serialization as serialization - - -DEFAULT_FAILURE_CODE = 1 -FAILURE_REASON_PATH = "/opt/ml/output/failure" - - -@pickling_support.install -class RemoteFunctionError(Exception): - """The base exception class for remote function exceptions""" - - def __init__(self, message): - self.message = message - super().__init__(self.message) - - -@pickling_support.install -class ServiceError(RemoteFunctionError): - """Raised when errors encountered during interaction with SageMaker, S3 service APIs""" - - -@pickling_support.install -class SerializationError(RemoteFunctionError): - """Raised when errors encountered during serialization of remote function objects""" - - -@pickling_support.install -class DeserializationError(RemoteFunctionError): - """Raised when errors encountered during deserialization of remote function objects""" - - -def _get_valid_failure_exit_code(exit_code) -> int: - """Normalize exit code for terminating the process""" - try: - valid_exit_code = int(exit_code) - except (TypeError, ValueError): - valid_exit_code = DEFAULT_FAILURE_CODE - - return valid_exit_code - - -def _write_failure_reason_file(failure_msg): - """Create a file 'failure' with failure reason written if remote function execution failed. 
- - See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html - Args: - failure_msg: The content of file to be written. - """ - if not os.path.exists(FAILURE_REASON_PATH): - with open(FAILURE_REASON_PATH, "w") as f: - f.write(failure_msg) - - -def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: - """Handle all exceptions raised during remote function execution. - - Args: - error (Exception): The error to be handled. - sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which - AWS service calls are delegated to. - s3_base_uri (str): S3 root uri to which resulting serialized exception will be uploaded. - s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - Returns : - exit_code (int): Exit code to terminate current job. - """ - - failure_reason = repr(error) - if isinstance(error, RemoteFunctionError): - exit_code = DEFAULT_FAILURE_CODE - else: - error_number = getattr(error, "errno", DEFAULT_FAILURE_CODE) - exit_code = _get_valid_failure_exit_code(error_number) - - _write_failure_reason_file(failure_reason) - - serialization.serialize_exception_to_s3( - exc=error, - sagemaker_session=sagemaker_session, - s3_uri=s3_path_join(s3_base_uri, "exception"), - s3_kms_key=s3_kms_key, - ) - - return exit_code diff --git a/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py b/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py deleted file mode 100644 index 2e69f4f116..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
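The errors module removed above normalizes an arbitrary errno value into an integer exit code and records the failure reason in a file that SageMaker surfaces as the job's FailureReason. The standalone sketch below mirrors that behavior under stated assumptions: the /tmp/failure path is a stand-in for local demonstration only, whereas the real container path is /opt/ml/output/failure.

import os

DEFAULT_FAILURE_CODE = 1


def normalize_exit_code(exit_code, default=DEFAULT_FAILURE_CODE):
    # Mirrors the intent of _get_valid_failure_exit_code: fall back to the
    # generic failure code when errno is missing or not an integer.
    try:
        return int(exit_code)
    except (TypeError, ValueError):
        return default


def write_failure_reason(message, path="/tmp/failure"):
    # Write the file only if it does not already exist, so the first
    # (root-cause) error message is the one that gets reported.
    if not os.path.exists(path):
        with open(path, "w") as f:
            f.write(message)


print(normalize_exit_code("13"))  # 13
print(normalize_exit_code(None))  # 1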
-"""An entry point for invoking remote function inside a job.""" - -from __future__ import absolute_import - -import argparse -import sys -import json -import os -from typing import TYPE_CHECKING - -import boto3 -from sagemaker.core.remote_function.job import ( - KEY_EXPERIMENT_NAME, - KEY_RUN_NAME, -) - -from sagemaker.core.helper.session_helper import Session -from sagemaker.core.s3 import s3_path_join -from sagemaker.core.remote_function.errors import handle_error -from sagemaker.core.remote_function import logging_config -from sagemaker.core.remote_function.core.pipeline_variables import Context - -if TYPE_CHECKING: - from sagemaker.core.experiments.run import Run - - -SUCCESS_EXIT_CODE = 0 - - -def _parse_args(args): - """Parses CLI arguments.""" - parser = argparse.ArgumentParser() - parser.add_argument("--region", type=str, required=True) - parser.add_argument("--s3_base_uri", type=str, required=True) - parser.add_argument("--s3_kms_key", type=str) - parser.add_argument("--run_in_context", type=str) - parser.add_argument("--pipeline_step_name", type=str) - parser.add_argument("--pipeline_execution_id", type=str) - parser.add_argument("--property_references", nargs="+", type=str, default=[]) - parser.add_argument( - "--serialize_output_to_json", default=False, type=lambda x: (str(x).lower() == "true") - ) - parser.add_argument("--func_step_s3_dir", type=str) - - args, _ = parser.parse_known_args(args) - return args - - -def _get_sagemaker_session(region): - """Get sagemaker session for interacting with AWS or Sagemaker services""" - boto_session = boto3.session.Session(region_name=region) - return Session(boto_session=boto_session) - - -def _load_run_object(run_in_context: str, sagemaker_session: Session) -> "Run": - """Load current run in json string into run object""" - from sagemaker.core.experiments.run import Run - - run_dict = json.loads(run_in_context) - return Run( - experiment_name=run_dict.get(KEY_EXPERIMENT_NAME), - run_name=run_dict.get(KEY_RUN_NAME), - sagemaker_session=sagemaker_session, - ) - - -def _load_pipeline_context(args) -> Context: - """Load pipeline build or run context into context object""" - - pipeline_step_name = args.pipeline_step_name - pipeline_execution_id = args.pipeline_execution_id - property_references = args.property_references - serialize_output_to_json = args.serialize_output_to_json - func_step_s3_dir = args.func_step_s3_dir - - property_references_dict = {} - for i in range(0, len(property_references), 2): - property_references_dict[property_references[i]] = property_references[i + 1] - return Context( - step_name=pipeline_step_name, - execution_id=pipeline_execution_id, - property_references=property_references_dict, - serialize_output_to_json=serialize_output_to_json, - func_step_s3_dir=func_step_s3_dir, - ) - - -def _execute_remote_function( - sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, context -): - """Execute stored remote function""" - from sagemaker.core.remote_function.core.stored_function import StoredFunction - - stored_function = StoredFunction( - sagemaker_session=sagemaker_session, - s3_base_uri=s3_base_uri, - s3_kms_key=s3_kms_key, - context=context, - ) - - if run_in_context: - run_obj = _load_run_object(run_in_context, sagemaker_session) - with run_obj: - stored_function.load_and_invoke() - else: - stored_function.load_and_invoke() - - -def main(sys_args=None): - """Entry point for invoke function script - - Args: - sys_args (list): List of arguments to parse. If not specified, sys.argv is used. 
- """ - - logger = logging_config.get_logger() - - exit_code = SUCCESS_EXIT_CODE - - try: - args = _parse_args(sys_args) - region = args.region - s3_base_uri = args.s3_base_uri - s3_kms_key = args.s3_kms_key - run_in_context = args.run_in_context - pipeline_context = _load_pipeline_context(args) - - sagemaker_session = _get_sagemaker_session(region) - _execute_remote_function( - sagemaker_session=sagemaker_session, - s3_base_uri=s3_base_uri, - s3_kms_key=s3_kms_key, - run_in_context=run_in_context, - context=pipeline_context, - ) - - except Exception as e: # pylint: disable=broad-except - logger.exception("Error encountered while invoking the remote function.") - s3_uri = ( - s3_path_join(s3_base_uri, pipeline_context.execution_id, pipeline_context.step_name) - if pipeline_context.step_name - else s3_base_uri - ) - exit_code = handle_error( - error=e, - sagemaker_session=sagemaker_session, - s3_base_uri=s3_uri, - s3_kms_key=s3_kms_key, - ) - finally: - sys.exit(exit_code) - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/job.py b/sagemaker-core/src/sagemaker/core/remote_function/job.py deleted file mode 100644 index 435062db57..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/job.py +++ /dev/null @@ -1,2121 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-"""Helper classes that interact with SageMaker Training service.""" -from __future__ import absolute_import - -import dataclasses -import json -import os -import re -import shutil -import sys -import time -from io import BytesIO -from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING -from urllib.parse import urlparse - -import botocore -from botocore.exceptions import ClientError - -from sagemaker.core.config.config_schema import ( - REMOTE_FUNCTION_ENVIRONMENT_VARIABLES, - REMOTE_FUNCTION_IMAGE_URI, - REMOTE_FUNCTION_DEPENDENCIES, - REMOTE_FUNCTION_PRE_EXECUTION_COMMANDS, - REMOTE_FUNCTION_PRE_EXECUTION_SCRIPT, - REMOTE_FUNCTION_INCLUDE_LOCAL_WORKDIR, - REMOTE_FUNCTION_INSTANCE_TYPE, - REMOTE_FUNCTION_JOB_CONDA_ENV, - REMOTE_FUNCTION_ROLE_ARN, - REMOTE_FUNCTION_S3_ROOT_URI, - REMOTE_FUNCTION_S3_KMS_KEY_ID, - REMOTE_FUNCTION_VOLUME_KMS_KEY_ID, - REMOTE_FUNCTION_TAGS, - REMOTE_FUNCTION_VPC_CONFIG_SUBNETS, - REMOTE_FUNCTION_VPC_CONFIG_SECURITY_GROUP_IDS, - REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, -) -from sagemaker.core.experiments._run_context import _RunContext -from sagemaker.core.experiments.run import Run -from sagemaker.core.image_uris import get_base_python_image_uri -from sagemaker.core import image_uris -from sagemaker.core.remote_function.checkpoint_location import CheckpointLocation -from sagemaker.core.helper.session_helper import get_execution_role, expand_role, Session -from sagemaker.core.common_utils import ( - name_from_base, - _tmpdir, - resolve_value_from_config, - format_tags, - Tags, -) -from sagemaker.core.s3 import s3_path_join, S3Uploader - -from sagemaker.core.remote_function.core.stored_function import StoredFunction, _SerializedData -from sagemaker.core.remote_function.core.pipeline_variables import Context - -from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( - RuntimeEnvironmentManager, - _DependencySettings, -) -from sagemaker.core.remote_function import logging_config -from sagemaker.core.remote_function.spark_config import SparkConfig -from sagemaker.core.remote_function.custom_file_filter import ( - CustomFileFilter, - copy_workdir, - resolve_custom_file_filter_from_config_file, -) - -# Lazy import to avoid circular dependency - DelayedReturn is in MLOps which depends on Core -# from sagemaker.mlops.workflow.function_step import DelayedReturn -from sagemaker.core.workflow.step_outputs import get_step -from sagemaker.core import exceptions -from sagemaker.core import network as vpc_utils - -from sagemaker.core import logs as sagemaker_logs - -from sagemaker.core.common_utils import ( - _wait_until, - secondary_training_status_changed, - secondary_training_status_message, -) -from sagemaker.core.config.config_utils import _append_sagemaker_config_tags - -if TYPE_CHECKING: - from sagemaker.core.helper.pipeline_variable import PipelineVariable - -# runtime script names -BOOTSTRAP_SCRIPT_NAME = "bootstrap_runtime_environment.py" -MPI_UTILS_SCRIPT_NAME = "mpi_utils_remote.py" -ENTRYPOINT_SCRIPT_NAME = "job_driver.sh" -PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh" -RUNTIME_MANAGER_SCRIPT_NAME = "runtime_environment_manager.py" -SPARK_APP_SCRIPT_NAME = "spark_app.py" - -# training channel names -RUNTIME_SCRIPTS_CHANNEL_NAME = "sagemaker_remote_function_bootstrap" -REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws" -JOB_REMOTE_FUNCTION_WORKSPACE = "sagemaker_remote_function_workspace" -SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME = "pre_exec_script_and_dependencies" - -# Spark config channel and file 
name -SPARK_CONF_CHANNEL_NAME = "conf" -SPARK_CONF_FILE_NAME = "configuration.json" - -# Spark submitted files workspace names on S3 -SPARK_SUBMIT_JARS_WORKSPACE = "sm_rf_spark_jars" -SPARK_SUBMIT_PY_FILES_WORKSPACE = "sm_rf_spark_py_files" -SPARK_SUBMIT_FILES_WORKSPACE = "sm_rf_spark_data_files" -SPARK_CONF_WORKSPACE = "sm_rf_spark_conf" - -# default spark version -DEFAULT_SPARK_VERSION = "3.3" -DEFAULT_SPARK_CONTAINER_VERSION = "v1" - -SPARK_NAME = "spark" - -# run context dictionary keys -KEY_EXPERIMENT_NAME = "experiment_name" -KEY_RUN_NAME = "run_name" - -JOBS_CONTAINER_ENTRYPOINT = [ - "/bin/bash", - f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}", -] - -SPARK_APP_SCRIPT_PATH = f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{SPARK_APP_SCRIPT_NAME}" - -ENTRYPOINT_SCRIPT = f""" -#!/bin/bash - -# Entry point for bootstrapping runtime environment and invoking remote function - -set -eu - -PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}} -export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs -printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n" -export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip -printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n" - -printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n" -cat /opt/ml/input/config/resourceconfig.json - -printf "INFO: Bootstraping runtime environment.\\n" -python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" -source /opt/ml/input/sm_training.env - -if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ] -then - if [ -f "remote_function_conda_env.txt" ] - then - cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt - fi - printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n" - cd {JOB_REMOTE_FUNCTION_WORKSPACE} -fi - -if [ -f "remote_function_conda_env.txt" ] -then - conda_env=$(cat remote_function_conda_env.txt) - - if which mamba >/dev/null; then - conda_exe="mamba" - else - conda_exe="conda" - fi - - printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n" - printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function \\n" - $conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function "$@" -else - printf "INFO: No conda env provided. 
Invoking remote function\\n" - printf "INFO: python -m sagemaker.train.remote_function.invoke_function \\n" - python -m sagemaker.train.remote_function.invoke_function "$@" -fi -""" - -ENTRYPOINT_MPIRUN_SCRIPT = f""" -#!/bin/bash - -# Entry point for bootstrapping runtime environment and invoking remote function with mpirun - -set -eu - -PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}} -export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs -printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n" -export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip -printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n" - -printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n" -cat /opt/ml/input/config/resourceconfig.json - -printf "INFO: Bootstraping runtime environment.\\n" -python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" -source /opt/ml/input/sm_training.env - -if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ] -then - if [ -f "remote_function_conda_env.txt" ] - then - cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt - fi - printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n" - cd {JOB_REMOTE_FUNCTION_WORKSPACE} -fi - -if [ -f "remote_function_conda_env.txt" ] -then - conda_env=$(cat remote_function_conda_env.txt) - - if which mamba >/dev/null; then - conda_exe="mamba" - else - conda_exe="conda" - fi - - if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then - python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} - - printf "INFO: Invoking remote function with mpirun inside conda environment: $conda_env.\\n" - printf "INFO: $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ - --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ - -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ - -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ - -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ - - python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n" - $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ - --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ - -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ - -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ - -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ - $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ - python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@" - - python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 - else - printf "INFO: This is the instance $SM_CURRENT_HOST. mpirun command terminated\\n" - python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} - fi -else - if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then - python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} - - printf "INFO: No conda env provided. 
Invoking remote function with mpirun\\n" - printf "INFO: mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ - --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ - -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ - -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ - -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ - $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ - python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n" - - mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ - --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ - -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ - -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ - -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ - $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ - python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@" - - python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 - else - printf "INFO: This is the instance $SM_CURRENT_HOST.\\n" - python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} - fi -fi -""" - -ENTRYPOINT_TORCHRUN_SCRIPT = f""" -#!/bin/bash - -# Entry point for bootstrapping runtime environment and invoking remote function with torchrun - -set -eu - -PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}} -export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs -printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n" -export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip -printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n" - -printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n" -cat /opt/ml/input/config/resourceconfig.json - -printf "INFO: Bootstraping runtime environment.\\n" -python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" -source /opt/ml/input/sm_training.env - -if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ] -then - if [ -f "remote_function_conda_env.txt" ] - then - cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt - fi - printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n" - cd {JOB_REMOTE_FUNCTION_WORKSPACE} -fi - -if [ -f "remote_function_conda_env.txt" ] -then - conda_env=$(cat remote_function_conda_env.txt) - - if which mamba >/dev/null; then - conda_exe="mamba" - else - conda_exe="conda" - fi - - printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n" - printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ - --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ - -m sagemaker.train.remote_function.invoke_function \\n" - - $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ - --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ - -m sagemaker.train.remote_function.invoke_function "$@" -else - printf "INFO: No conda env provided. 
Invoking remote function with torchrun\\n" - printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ - --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function \\n" - - torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ - --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function "$@" -fi -""" - -SPARK_ENTRYPOINT_SCRIPT = f""" -#!/bin/bash - -# Entry point for bootstrapping runtime environment and invoking remote function for Spark - -set -eu - -printf "INFO: Bootstraping Spark runtime environment.\\n" - -python3 /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" - -# Spark Container entry point script to initiate the spark application -smspark-submit "$@" -""" - -_STATUS_CODE_TABLE = { - "COMPLETED": "Completed", - "INPROGRESS": "InProgress", - "IN_PROGRESS": "InProgress", - "FAILED": "Failed", - "STOPPED": "Stopped", - "STOPPING": "Stopping", - "STARTING": "Starting", - "PENDING": "Pending", -} - -logger = logging_config.get_logger() - - -class LogState(object): - """Placeholder docstring""" - - STARTING = 1 - WAIT_IN_PROGRESS = 2 - TAILING = 3 - JOB_COMPLETE = 4 - COMPLETE = 5 - - -class _JobSettings: - """Helper class that processes the job settings. - - It validates the job settings and provides default values if necessary. - """ - - def __init__( - self, - *, - dependencies: str = None, - pre_execution_commands: List[str] = None, - pre_execution_script: str = None, - environment_variables: Dict[str, Union[str, "PipelineVariable"]] = None, - image_uri: Union[str, "PipelineVariable"] = None, - include_local_workdir: bool = None, - custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None, - instance_count: Union[int, "PipelineVariable"] = 1, - instance_type: Union[str, "PipelineVariable"] = None, - job_conda_env: Union[str, "PipelineVariable"] = None, - job_name_prefix: str = None, - keep_alive_period_in_seconds: Union[int, "PipelineVariable"] = 0, - max_retry_attempts: Union[int, "PipelineVariable"] = 1, - max_runtime_in_seconds: Union[int, "PipelineVariable"] = 24 * 60 * 60, - role: str = None, - s3_kms_key: Union[str, "PipelineVariable"] = None, - s3_root_uri: str = None, - sagemaker_session: Session = None, - security_group_ids: List[Union[str, "PipelineVariable"]] = None, - subnets: List[Union[str, "PipelineVariable"]] = None, - tags: Optional[Tags] = None, - volume_kms_key: Union[str, "PipelineVariable"] = None, - volume_size: Union[int, "PipelineVariable"] = 30, - encrypt_inter_container_traffic: Union[bool, "PipelineVariable"] = None, - spark_config: SparkConfig = None, - use_spot_instances=False, - max_wait_time_in_seconds=None, - disable_output_compression: bool = False, - use_torchrun: bool = False, - use_mpirun: bool = False, - nproc_per_node: Optional[int] = None, - ): - """Initialize a _JobSettings instance which configures the remote job. - - Args: - dependencies (str): Either the path to a dependencies file or the reserved keyword - ``auto_capture``. Defaults to ``None``. - If ``dependencies`` is provided, the value must be one of the following: - - * A path to a conda environment.yml file. The following conditions apply. 
- - * If job_conda_env is set, then the conda environment is updated by installing - dependencies from the yaml file and the function is invoked within that - conda environment. For this to succeed, the specified conda environment must - already exist in the image. - * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, - then the conda environment is updated by installing dependencies from the - yaml file and the function is invoked within that conda environment. For - this to succeed, the conda environment name must already be set in - ``SAGEMAKER_JOB_CONDA_ENV``, and ``SAGEMAKER_JOB_CONDA_ENV`` must already - exist in the image. - * If none of the previous conditions are met, a new conda environment named - ``sagemaker-runtime-env`` is created and the function annotated with the remote - decorator is invoked in that conda environment. - - * A path to a requirements.txt file. The following conditions apply. - - * If ``job_conda_env`` is set in the remote decorator, dependencies are installed - within that conda environment and the function annotated with the remote decorator - is invoked in the same conda environment. For this to succeed, the specified - conda environment must already exist in the image. - * If an environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, - dependencies are installed within that conda environment and the function - annotated with the remote decorator is invoked in the same. For this to succeed, - the conda environment name must already be set in ``SAGEMAKER_JOB_CONDA_ENV``, and - ``SAGEMAKER_JOB_CONDA_ENV`` must already exist in the image. - * If none of the above conditions are met, conda is not used. Dependencies are - installed at the system level, without any virtual environment, and the function - annotated with the remote decorator is invoked using the Python runtime available - in the system path. - - * The parameter dependencies is set to ``auto_capture``. SageMaker will automatically - generate an env_snapshot.yml corresponding to the current active conda environment’s - snapshot. You do not need to provide a dependencies file. The following conditions - apply: - - * You must run the remote function within an active conda environment. - * When installing the dependencies on the training job, the same conditions - as when dependencies is set to a path to a conda environment file apply. - These conditions are as follows: - - * If job_conda_env is set, then the conda environment is updated by installing - dependencies from the yaml file and the function is invoked within that - conda environment. For this to succeed, the specified conda environment must - already exist in the image. - * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, - then the conda environment is updated by installing dependencies from the yaml - file and the function is invoked within that conda environment. For this to - succeed, the conda environment name must already be set in - ``SAGEMAKER_JOB_CONDA_ENV``, and ``SAGEMAKER_JOB_CONDA_ENV`` must already exist - in the image. - * If none of the previous conditions are met, a new conda environment with name - ``sagemaker-runtime-env`` is created and the function annotated with the - remote decorator is invoked in that conda environment. - - * ``None``. SageMaker will assume that there are no dependencies to install while - executing the remote annotated function in the training job. 
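                For example, the ``dependencies`` options described above could be passed
                through the ``@remote`` decorator (a minimal sketch; the import path and the
                environment file name are assumptions for illustration only):

                .. code-block:: python

                    from sagemaker.train.remote_function import remote  # assumed import path

                    @remote(dependencies="./environment.yml", instance_type="ml.m5.xlarge")
                    def multiply(x, y):
                        return x * y

                    multiply(2, 3)  # runs as a SageMaker training job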
- - pre_execution_commands (List[str]): List of commands to be executed prior to executing - remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script`` - can be specified at the same time. Defaults to None. - - pre_execution_script (str): Path to script file to be executed prior to executing - remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script`` - can be specified at the same time. Defaults to None. - - environment_variables (dict[str, str] or dict[str, PipelineVariable]): The environment - variables used inside the decorator function. Defaults to ``None``. - - image_uri (str, PipelineVariable): The universal resource identifier (URI) location of - a Docker image on Amazon Elastic Container Registry (ECR). Defaults to the following - based on where the SDK is running: - - * For users who specify ``spark_config`` and want to run the function in a Spark - application, the ``image_uri`` should be ``None``. A SageMaker Spark image will - be used for training, otherwise a ``ValueError`` is thrown. - * For users on SageMaker Studio notebooks, the image used as the kernel image for - the notebook is used. - * For other users, it is resolved to base python image with the same python version - as the environment running the local code. - - If no compatible image is found, a ValueError is thrown. - - include_local_workdir (bool): A flag to indicate that the remote function should include - local directories. Set to ``True`` if the remote function code imports local modules - and methods that are not available via PyPI or conda. Default value is ``False``. - - custom_file_filter (Callable[[str, List], List], CustomFileFilter): Either a function - that filters job dependencies to be uploaded to S3 or a ``CustomFileFilter`` object - that specifies the local directories and files to be included in the remote function. - If a callable is passed in, that function is passed to the ``ignore`` argument of - ``shutil.copytree``. Defaults to ``None``, which means only python - files are accepted and uploaded to S3. - - instance_count (int, PipelineVariable): The number of instances to use. Defaults to 1. - - instance_type (str, PipelineVariable): The Amazon Elastic Compute Cloud (EC2) instance - type to use to run the SageMaker job. e.g. ml.c4.xlarge. If not provided, - a ValueError is thrown. - - job_conda_env (str, PipelineVariable): The name of the conda environment to activate - during job's runtime. Defaults to ``None``. - - job_name_prefix (str, PipelineVariable): The prefix used to create the underlying - SageMaker job. - - keep_alive_period_in_seconds (int, PipelineVariable): The duration in seconds to retain - and reuse provisioned infrastructure after the completion of a training job, also - known as SageMaker managed warm pools. The use of warm pools reduces the latency time - spent to provision new resources. The default value for - ``keep_alive_period_in_seconds`` is 0. - NOTE: Additional charges associated with warm pools may apply. Using this parameter - also activates a new persistent cache feature, which will further reduce job start up - latency than over using SageMaker managed warm pools alone by caching the package - source downloaded in the previous runs. - - max_retry_attempts (int, PipelineVariable): The max number of times the job is retried - on ``InternalServerFailure`` Error from SageMaker service. Defaults to 1. - - max_runtime_in_seconds (int, PipelineVariable): The upper limit in seconds to be used - for training. 
After this specified amount of time, SageMaker terminates the job
-                regardless of its current status. Defaults to 1 day (86400 seconds).
-
-            role (str): The IAM role (either name or full ARN) used to run your SageMaker training
-                job. Defaults to:
-
-                * the SageMaker default IAM role if the SDK is running in SageMaker Notebooks or
-                  SageMaker Studio Notebooks.
-                * otherwise, a ValueError is thrown.
-
-            s3_kms_key (str): The key used to encrypt the input and output data.
-                Defaults to ``None``.
-
-            s3_root_uri (str): The root S3 folder to which the code archives and data are
-                uploaded. Defaults to ``s3://``.
-
-            sagemaker_session (sagemaker.core.helper.session.Session): The underlying SageMaker session
-                to which SageMaker service calls are delegated (default: None). If not provided,
-                one is created using a default configuration chain.
-
-            security_group_ids (List[str, PipelineVariable]): A list of security group IDs.
-                Defaults to ``None`` and the training job is created without VPC config.
-
-            subnets (List[str, PipelineVariable]): A list of subnet IDs. Defaults to ``None``
-                and the job is created without VPC config.
-
-            tags (Optional[Tags]): Tags attached to the job. Defaults to ``None``
-                and the training job is created without tags.
-
-            volume_kms_key (str, PipelineVariable): An Amazon Key Management Service (KMS) key
-                used to encrypt an Amazon Elastic Block Storage (EBS) volume attached to the
-                training instance. Defaults to ``None``.
-
-            volume_size (int, PipelineVariable): The size in GB of the storage volume for storing
-                input and output data during training. Defaults to ``30``.
-
-            encrypt_inter_container_traffic (bool, PipelineVariable): A flag that specifies
-                whether traffic between training containers is encrypted for the training job.
-                Defaults to ``False``.
-
-            spark_config (SparkConfig): Configurations for the Spark application that runs on
-                the Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri
-                will be used for training. Note that ``image_uri`` cannot be specified at the
-                same time; otherwise a ``ValueError`` is thrown. Defaults to ``None``.
-
-            use_spot_instances (bool, PipelineVariable): Specifies whether to use SageMaker
-                Managed Spot instances for training. If enabled, the ``max_wait_time_in_seconds``
-                arg should also be set. Defaults to ``False``.
-
-            max_wait_time_in_seconds (int): Timeout in seconds waiting for the spot training job.
-                After this amount of time Amazon SageMaker stops waiting for the managed spot
-                training job to complete. Defaults to ``None``.
-
-            disable_output_compression (bool): Optional. When set to true, the model is uploaded
-                to Amazon S3 without compression after training finishes.
-
-            use_torchrun (bool): Specifies whether to use torchrun for distributed training.
-                Defaults to ``False``.
-
-            use_mpirun (bool): Specifies whether to use mpirun for distributed training.
-                Defaults to ``False``.
-
-            nproc_per_node (int): Optional. Specifies the number of processes per node for
-                distributed training. Defaults to ``None``, in which case it is automatically
-                configured based on the instance type.
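        Example:
            A minimal sketch of constructing the job settings with a few of the arguments
            described above (the role ARN below is a hypothetical placeholder):

            .. code-block:: python

                job_settings = _JobSettings(
                    instance_type="ml.m5.xlarge",
                    instance_count=1,
                    max_runtime_in_seconds=3600,
                    keep_alive_period_in_seconds=300,
                    role="arn:aws:iam::111122223333:role/ExampleSageMakerRole",
                )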
- """ - self.sagemaker_session = sagemaker_session or Session() - self.environment_variables = resolve_value_from_config( - direct_input=environment_variables, - config_path=REMOTE_FUNCTION_ENVIRONMENT_VARIABLES, - default_value={}, - sagemaker_session=self.sagemaker_session, - ) - self.environment_variables.update( - {"AWS_DEFAULT_REGION": self.sagemaker_session.boto_region_name} - ) - - if spark_config and image_uri: - raise ValueError("spark_config and image_uri cannot be specified at the same time!") - - if spark_config and job_conda_env: - raise ValueError("Remote Spark jobs do not support job_conda_env.") - - if spark_config and dependencies == "auto_capture": - raise ValueError( - "Remote Spark jobs do not support automatically capturing dependencies." - ) - - _image_uri = resolve_value_from_config( - direct_input=image_uri, - config_path=REMOTE_FUNCTION_IMAGE_URI, - sagemaker_session=self.sagemaker_session, - ) - - if spark_config: - self.image_uri = self._get_default_spark_image(self.sagemaker_session) - logger.info( - "Set the image uri as %s because value of spark_config is " - "indicating this is a remote spark job.", - self.image_uri, - ) - elif _image_uri: - self.image_uri = _image_uri - else: - self.image_uri = self._get_default_image(self.sagemaker_session) - - self.dependencies = resolve_value_from_config( - direct_input=dependencies, - config_path=REMOTE_FUNCTION_DEPENDENCIES, - sagemaker_session=self.sagemaker_session, - ) - - self.pre_execution_commands = resolve_value_from_config( - direct_input=pre_execution_commands, - config_path=REMOTE_FUNCTION_PRE_EXECUTION_COMMANDS, - sagemaker_session=self.sagemaker_session, - ) - - self.pre_execution_script = resolve_value_from_config( - direct_input=pre_execution_script, - config_path=REMOTE_FUNCTION_PRE_EXECUTION_SCRIPT, - sagemaker_session=self.sagemaker_session, - ) - - if self.pre_execution_commands is not None and self.pre_execution_script is not None: - raise ValueError( - "Only one of pre_execution_commands or pre_execution_script can be specified!" 
- ) - - self.include_local_workdir = resolve_value_from_config( - direct_input=include_local_workdir, - config_path=REMOTE_FUNCTION_INCLUDE_LOCAL_WORKDIR, - default_value=False, - sagemaker_session=self.sagemaker_session, - ) - - self.custom_file_filter = resolve_custom_file_filter_from_config_file( - custom_file_filter, self.sagemaker_session - ) - - self.instance_type = resolve_value_from_config( - direct_input=instance_type, - config_path=REMOTE_FUNCTION_INSTANCE_TYPE, - sagemaker_session=self.sagemaker_session, - ) - if not self.instance_type: - raise ValueError("instance_type is a required parameter!") - - self.instance_count = instance_count - self.volume_size = volume_size - self.max_runtime_in_seconds = max_runtime_in_seconds - self.max_retry_attempts = max_retry_attempts - self.keep_alive_period_in_seconds = keep_alive_period_in_seconds - self.spark_config = spark_config - self.use_spot_instances = use_spot_instances - self.max_wait_time_in_seconds = max_wait_time_in_seconds - self.job_conda_env = resolve_value_from_config( - direct_input=job_conda_env, - config_path=REMOTE_FUNCTION_JOB_CONDA_ENV, - sagemaker_session=self.sagemaker_session, - ) - self.job_name_prefix = job_name_prefix - self.encrypt_inter_container_traffic = resolve_value_from_config( - direct_input=encrypt_inter_container_traffic, - config_path=REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, - default_value=False, - sagemaker_session=self.sagemaker_session, - ) - self.enable_network_isolation = False - - _role = resolve_value_from_config( - direct_input=role, - config_path=REMOTE_FUNCTION_ROLE_ARN, - sagemaker_session=self.sagemaker_session, - ) - if _role: - self.role = expand_role(self.sagemaker_session.boto_session, _role) - else: - self.role = get_execution_role(self.sagemaker_session) - - self.s3_root_uri = resolve_value_from_config( - direct_input=s3_root_uri, - config_path=REMOTE_FUNCTION_S3_ROOT_URI, - default_value=s3_path_join( - "s3://", - self.sagemaker_session.default_bucket(), - self.sagemaker_session.default_bucket_prefix, - ), - sagemaker_session=self.sagemaker_session, - ) - - self.s3_kms_key = resolve_value_from_config( - direct_input=s3_kms_key, - config_path=REMOTE_FUNCTION_S3_KMS_KEY_ID, - sagemaker_session=self.sagemaker_session, - ) - self.volume_kms_key = resolve_value_from_config( - direct_input=volume_kms_key, - config_path=REMOTE_FUNCTION_VOLUME_KMS_KEY_ID, - sagemaker_session=self.sagemaker_session, - ) - - _subnets = resolve_value_from_config( - direct_input=subnets, - config_path=REMOTE_FUNCTION_VPC_CONFIG_SUBNETS, - sagemaker_session=self.sagemaker_session, - ) - _security_group_ids = resolve_value_from_config( - direct_input=security_group_ids, - config_path=REMOTE_FUNCTION_VPC_CONFIG_SECURITY_GROUP_IDS, - sagemaker_session=self.sagemaker_session, - ) - vpc_config = vpc_utils.to_dict(subnets=_subnets, security_group_ids=_security_group_ids) - self.vpc_config = vpc_utils.sanitize(vpc_config) - - tags = format_tags(tags) - self.tags = _append_sagemaker_config_tags( - self.sagemaker_session, tags, REMOTE_FUNCTION_TAGS - ) - - self.disable_output_compression = disable_output_compression - self.use_torchrun = use_torchrun - self.use_mpirun = use_mpirun - self.nproc_per_node = nproc_per_node - - @staticmethod - def _get_default_image(session): - """Return Studio notebook image, if in Studio env. Else, base python. - - Args: - session (Session): Boto session. - - Returns: - Default SageMaker base python image. 
- """ - - if ( - "SAGEMAKER_INTERNAL_IMAGE_URI" in os.environ - and os.environ["SAGEMAKER_INTERNAL_IMAGE_URI"] - ): - return os.environ["SAGEMAKER_INTERNAL_IMAGE_URI"] - - py_version = str(sys.version_info[0]) + str(sys.version_info[1]) - - if py_version not in ["310", "38"]: - raise ValueError( - "Default image is supported only for Python versions 3.8 and 3.10. If you " - "are using any other python version, you must provide a compatible image_uri." - ) - - region = session.boto_region_name - image_uri = get_base_python_image_uri(region=region, py_version=py_version) - - return image_uri - - @staticmethod - def _get_default_spark_image(session): - """Return the Spark image. - - Args: - session (Session): Boto session. - - Returns: - SageMaker Spark container image uri. - """ - - region = session.boto_region_name - - py_version = str(sys.version_info[0]) + str(sys.version_info[1]) - - if py_version not in ["39"]: - raise ValueError( - "The SageMaker Spark image for remote job only supports Python version 3.9. " - ) - - image_uri = image_uris.retrieve( - framework=SPARK_NAME, - region=region, - version=DEFAULT_SPARK_VERSION, - instance_type=None, - py_version=f"py{py_version}", - container_version=DEFAULT_SPARK_CONTAINER_VERSION, - ) - - return image_uri - - -class _Job: - """Helper class that interacts with the SageMaker training service.""" - - def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session): - """Initialize a _Job object. - - Args: - job_name (str): The training job name. - s3_uri (str): The training job output S3 uri. - sagemaker_session (Session): SageMaker boto session. - """ - self.job_name = job_name - self.s3_uri = s3_uri - self.sagemaker_session = sagemaker_session - self._last_describe_response = None - - @staticmethod - def from_describe_response(describe_training_job_response, sagemaker_session): - """Construct a _Job from a describe_training_job_response object. - - Args: - describe_training_job_response (Dict): Describe training job response. - sagemaker_session (Session): SageMaker boto session. - - Returns: - the _Job object. - """ - job_name = describe_training_job_response["TrainingJobName"] - s3_uri = describe_training_job_response["OutputDataConfig"]["S3OutputPath"] - - job = _Job(job_name, s3_uri, sagemaker_session) - job._last_describe_response = describe_training_job_response - return job - - @staticmethod - def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None): - """Start a training job. - - Args: - job_settings (_JobSettings): the job settings. - func: the function to be executed. - func_args: the positional arguments to the function. - func_kwargs: the keyword arguments to the function - - Returns: - the _Job object. 
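        Example:
            A minimal sketch of starting and monitoring a job from existing settings
            (``job_settings`` and ``my_func`` are placeholders defined elsewhere):

            .. code-block:: python

                job = _Job.start(job_settings, my_func, func_args=(1, 2), func_kwargs={})
                job.wait()
                status = job.describe()["TrainingJobStatus"]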
- """ - job_name = _Job._get_job_name(job_settings, func) - s3_base_uri = s3_path_join(job_settings.s3_root_uri, job_name) - - training_job_request = _Job.compile( - job_settings=job_settings, - job_name=job_name, - s3_base_uri=s3_base_uri, - func=func, - func_args=func_args, - func_kwargs=func_kwargs, - run_info=run_info, - ) - - logger.info("Creating job: %s", job_name) - - job_settings.sagemaker_session.sagemaker_client.create_training_job(**training_job_request) - - return _Job( - job_name, - s3_base_uri, - job_settings.sagemaker_session, - ) - - @staticmethod - def compile( - job_settings: _JobSettings, - job_name: str, - s3_base_uri: str, - func: Callable, - func_args: tuple, - func_kwargs: dict, - run_info=None, - serialized_data: _SerializedData = None, - ) -> dict: - """Build the artifacts and generate the training job request.""" - from sagemaker.core.workflow.properties import Properties - from sagemaker.core.workflow.parameters import Parameter - from sagemaker.core.workflow.functions import Join - from sagemaker.core.workflow.execution_variables import ( - ExecutionVariables, - ExecutionVariable, - ) - from sagemaker.core.workflow.utilities import load_step_compilation_context - - step_compilation_context = load_step_compilation_context() - - jobs_container_entrypoint = JOBS_CONTAINER_ENTRYPOINT[:] - - # serialize function and arguments - if step_compilation_context is None: - stored_function = StoredFunction( - sagemaker_session=job_settings.sagemaker_session, - s3_base_uri=s3_base_uri, - s3_kms_key=job_settings.s3_kms_key, - ) - stored_function.save(func, *func_args, **func_kwargs) - else: - stored_function = StoredFunction( - sagemaker_session=job_settings.sagemaker_session, - s3_base_uri=s3_base_uri, - s3_kms_key=job_settings.s3_kms_key, - context=Context( - step_name=step_compilation_context.step_name, - func_step_s3_dir=step_compilation_context.pipeline_build_time, - ), - ) - - stored_function.save_pipeline_step_function(serialized_data) - - stopping_condition = { - "MaxRuntimeInSeconds": job_settings.max_runtime_in_seconds, - } - if job_settings.max_wait_time_in_seconds is not None: - stopping_condition["MaxWaitTimeInSeconds"] = job_settings.max_wait_time_in_seconds - - request_dict = dict( - TrainingJobName=job_name, - RoleArn=job_settings.role, - StoppingCondition=stopping_condition, - RetryStrategy={"MaximumRetryAttempts": job_settings.max_retry_attempts}, - ) - - _update_job_request_with_checkpoint_config(func_args, func_kwargs, request_dict) - - if job_settings.tags: - request_dict["Tags"] = job_settings.tags - - # generate other build artifacts including workspace, requirements.txt - request_dict["InputDataConfig"] = _generate_input_data_config( - job_settings=job_settings, s3_base_uri=s3_base_uri - ) - - if step_compilation_context: - # Path format: base/step_name/build_timestamp/execution_id/results - # This matches the path construction in stored_function.py - s3_output_path = Join( - on="/", - values=[ - s3_base_uri, - step_compilation_context.step_name, - step_compilation_context.pipeline_build_time, - ExecutionVariables.PIPELINE_EXECUTION_ID, - "results", - ], - ) - output_config = {"S3OutputPath": s3_output_path} - else: - output_config = {"S3OutputPath": s3_base_uri} - if job_settings.s3_kms_key is not None: - output_config["KmsKeyId"] = job_settings.s3_kms_key - if job_settings.disable_output_compression: - output_config["CompressionType"] = "NONE" - request_dict["OutputDataConfig"] = output_config - - container_args = ["--s3_base_uri", s3_base_uri] - 
container_args.extend(["--region", job_settings.sagemaker_session.boto_region_name]) - container_args.extend( - ["--client_python_version", RuntimeEnvironmentManager()._current_python_version()] - ) - container_args.extend( - [ - "--client_sagemaker_pysdk_version", - RuntimeEnvironmentManager()._current_sagemaker_pysdk_version(), - ] - ) - container_args.extend( - [ - "--dependency_settings", - _DependencySettings.from_dependency_file_path( - job_settings.dependencies - ).to_string(), - ] - ) - if job_settings.use_torchrun: - container_args.extend(["--distribution", "torchrun"]) - elif job_settings.use_mpirun: - container_args.extend(["--distribution", "mpirun"]) - if job_settings.nproc_per_node is not None and int(job_settings.nproc_per_node) > 0: - container_args.extend(["--user_nproc_per_node", str(job_settings.nproc_per_node)]) - if job_settings.s3_kms_key: - container_args.extend(["--s3_kms_key", job_settings.s3_kms_key]) - - if job_settings.job_conda_env: - container_args.extend(["--job_conda_env", job_settings.job_conda_env]) - - if step_compilation_context: - # TODO: remove the duplicates in the list - container_args.extend(["--pipeline_step_name", step_compilation_context.step_name]) - container_args.extend( - ["--pipeline_execution_id", ExecutionVariables.PIPELINE_EXECUTION_ID] - ) - container_args.extend( - ["--func_step_s3_dir", step_compilation_context.pipeline_build_time] - ) - container_args.extend(["--property_references"]) - container_args.extend( - [ - ExecutionVariables.PIPELINE_EXECUTION_ID.expr["Get"], - ExecutionVariables.PIPELINE_EXECUTION_ID.to_string(), - ] - ) - for arg in func_args + tuple(func_kwargs.values()): - if isinstance(arg, (Parameter, ExecutionVariable, Properties)): - container_args.extend([arg.expr["Get"], arg.to_string()]) - - # Lazy import to avoid circular dependency - try: - from sagemaker.mlops.workflow.function_step import DelayedReturn - - if isinstance(arg, DelayedReturn): - # The uri is a Properties object - uri = get_step(arg)._properties.OutputDataConfig.S3OutputPath - container_args.extend([uri.expr["Get"], uri.to_string()]) - except ImportError: - # MLOps not installed, skip DelayedReturn handling - pass - - if run_info is not None: - container_args.extend(["--run_in_context", json.dumps(dataclasses.asdict(run_info))]) - elif _RunContext.get_current_run() is not None: - container_args.extend( - ["--run_in_context", _convert_run_to_json(_RunContext.get_current_run())] - ) - - algorithm_spec = dict( - TrainingImage=job_settings.image_uri, - TrainingInputMode="File", - ContainerEntrypoint=jobs_container_entrypoint, - ContainerArguments=container_args, - ) - - request_dict["AlgorithmSpecification"] = algorithm_spec - - resource_config = dict( - VolumeSizeInGB=job_settings.volume_size, - InstanceCount=job_settings.instance_count, - InstanceType=job_settings.instance_type, - ) - if job_settings.volume_kms_key is not None: - resource_config["VolumeKmsKeyId"] = job_settings.volume_kms_key - if job_settings.keep_alive_period_in_seconds is not None: - resource_config["KeepAlivePeriodInSeconds"] = job_settings.keep_alive_period_in_seconds - - request_dict["ResourceConfig"] = resource_config - - if job_settings.enable_network_isolation is not None: - request_dict["EnableNetworkIsolation"] = job_settings.enable_network_isolation - - if job_settings.encrypt_inter_container_traffic is not None: - request_dict["EnableInterContainerTrafficEncryption"] = ( - job_settings.encrypt_inter_container_traffic - ) - - if job_settings.vpc_config: - 
request_dict["VpcConfig"] = job_settings.vpc_config - - request_dict["EnableManagedSpotTraining"] = job_settings.use_spot_instances - - request_dict["Environment"] = job_settings.environment_variables - - extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri) - extended_request = _extend_mpirun_to_request(extended_request, job_settings) - extended_request = _extend_torchrun_to_request(extended_request, job_settings) - - return extended_request - - def describe(self): - """Describe the underlying sagemaker training job. - - Returns: - Dict: Describe training job response. - """ - if self._last_describe_response is not None and self._last_describe_response[ - "TrainingJobStatus" - ] in ["Completed", "Failed", "Stopped"]: - return self._last_describe_response - - self._last_describe_response = ( - self.sagemaker_session.sagemaker_client.describe_training_job( - TrainingJobName=self.job_name - ) - ) - - return self._last_describe_response - - def stop(self): - """Stop the underlying sagemaker training job.""" - self.sagemaker_session.sagemaker_client.stop_training_job(TrainingJobName=self.job_name) - - def wait(self, timeout: int = None): - """Wait for the underlying sagemaker job to finish and displays its logs . - - This method blocks on the sagemaker job completing for up to the timeout value (if - specified). If timeout is ``None``, this method will block until the job is completed. - - Args: - timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by - default. - - Returns: None - """ - - self._last_describe_response = _logs_for_job( - sagemaker_session=self.sagemaker_session, - job_name=self.job_name, - wait=True, - timeout=timeout, - ) - - @staticmethod - def _get_job_name(job_settings, func): - """Get the underlying SageMaker job name from job_name_prefix or func. - - Args: - job_settings (_JobSettings): the job settings. - func: the function to be executed. - - Returns: - str : the training job name. - """ - from sagemaker.core.workflow.utilities import load_step_compilation_context - - step_complication_context = load_step_compilation_context() - - job_name_prefix = job_settings.job_name_prefix - if not job_name_prefix: - job_name_prefix = func.__name__ - # remove all special characters in the beginning of function name - job_name_prefix = re.sub(r"^[^a-zA-Z0-9]+", "", job_name_prefix) - # convert all remaining special characters to '-' - job_name_prefix = re.sub(r"[^a-zA-Z0-9-]", "-", job_name_prefix) - - if step_complication_context: - return job_name_prefix - return name_from_base(job_name_prefix) - - -def _prepare_and_upload_runtime_scripts( - spark_config: SparkConfig, - s3_base_uri: str, - s3_kms_key: str, - sagemaker_session: Session, - use_torchrun: bool = False, - use_mpirun: bool = False, -): - """Copy runtime scripts to a folder and upload to S3. - - In case of remote function, s3_base_uri is s3_root_uri + function_name. - In case of pipeline, s3_base_uri is s3_root_uri + pipeline_name. The runtime scripts are - uploaded only once per pipeline. - - Args: - spark_config (SparkConfig): remote Spark job configurations. - - s3_base_uri (str): S3 location that the runtime scripts will be uploaded to. - - s3_kms_key (str): kms key used to encrypt the files uploaded to S3. - - sagemaker_session (str): SageMaker boto client session. - - use_torchrun (bool): Whether to use torchrun or not. - - use_mpirun (bool): Whether to use mpirun or not. 
- - nproc_per_node (Optional[int]): Number of processes per node - """ - - from sagemaker.core.workflow.utilities import load_step_compilation_context - - step_compilation_context = load_step_compilation_context() - - if step_compilation_context and not step_compilation_context.upload_runtime_scripts: - return s3_path_join(s3_base_uri, RUNTIME_SCRIPTS_CHANNEL_NAME) - - with _tmpdir() as bootstrap_scripts: - - # write entrypoint script to tmpdir - entrypoint_script_path = os.path.join(bootstrap_scripts, ENTRYPOINT_SCRIPT_NAME) - entry_point_script = ENTRYPOINT_SCRIPT - if spark_config: - entry_point_script = SPARK_ENTRYPOINT_SCRIPT - spark_script_path = os.path.join( - os.path.dirname(__file__), "runtime_environment", SPARK_APP_SCRIPT_NAME - ) - shutil.copy2(spark_script_path, bootstrap_scripts) - - if use_torchrun: - entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT - - if use_mpirun: - entry_point_script = ENTRYPOINT_MPIRUN_SCRIPT - - with open(entrypoint_script_path, "w", newline="\n") as file: - file.writelines(entry_point_script) - - bootstrap_script_path = os.path.join( - os.path.dirname(__file__), "runtime_environment", BOOTSTRAP_SCRIPT_NAME - ) - mpi_utils_path = os.path.join( - os.path.dirname(__file__), "runtime_environment", MPI_UTILS_SCRIPT_NAME - ) - runtime_manager_script_path = os.path.join( - os.path.dirname(__file__), "runtime_environment", RUNTIME_MANAGER_SCRIPT_NAME - ) - - # copy runtime scripts to tmpdir - shutil.copy2(bootstrap_script_path, bootstrap_scripts) - shutil.copy2(mpi_utils_path, bootstrap_scripts) - shutil.copy2(runtime_manager_script_path, bootstrap_scripts) - - upload_path = S3Uploader.upload( - bootstrap_scripts, - s3_path_join(s3_base_uri, RUNTIME_SCRIPTS_CHANNEL_NAME), - s3_kms_key, - sagemaker_session, - ) - - if step_compilation_context: - step_compilation_context.upload_runtime_scripts = False - return upload_path - - -def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str): - """Generates input data config""" - from sagemaker.core.workflow.utilities import load_step_compilation_context - - step_compilation_context = load_step_compilation_context() - - bootstrap_scripts_s3uri = _prepare_and_upload_runtime_scripts( - spark_config=job_settings.spark_config, - s3_base_uri=s3_base_uri, - s3_kms_key=job_settings.s3_kms_key, - sagemaker_session=job_settings.sagemaker_session, - use_torchrun=job_settings.use_torchrun, - use_mpirun=job_settings.use_mpirun, - ) - - input_data_config = [ - dict( - ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, - DataSource={ - "S3DataSource": { - "S3Uri": bootstrap_scripts_s3uri, - "S3DataType": "S3Prefix", - } - }, - ) - ] - - local_dependencies_path = RuntimeEnvironmentManager().snapshot(job_settings.dependencies) - - if step_compilation_context: - with _tmpdir() as tmp_dir: - script_and_dependencies_s3uri = _prepare_dependencies_and_pre_execution_scripts( - local_dependencies_path=local_dependencies_path, - pre_execution_commands=job_settings.pre_execution_commands, - pre_execution_script_local_path=job_settings.pre_execution_script, - s3_base_uri=s3_base_uri, - s3_kms_key=job_settings.s3_kms_key, - sagemaker_session=job_settings.sagemaker_session, - tmp_dir=tmp_dir, - ) - - if script_and_dependencies_s3uri: - input_data_config.append( - dict( - ChannelName=SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME, - DataSource={ - "S3DataSource": { - "S3Uri": script_and_dependencies_s3uri, - "S3DataType": "S3Prefix", - } - }, - ) - ) - - user_workspace_s3uri = _prepare_and_upload_workspace( - 
local_dependencies_path=local_dependencies_path, - include_local_workdir=job_settings.include_local_workdir, - pre_execution_commands=job_settings.pre_execution_commands, - pre_execution_script_local_path=job_settings.pre_execution_script, - s3_base_uri=s3_base_uri, - s3_kms_key=job_settings.s3_kms_key, - sagemaker_session=job_settings.sagemaker_session, - custom_file_filter=job_settings.custom_file_filter, - ) - - if user_workspace_s3uri: - input_data_config.append( - dict( - ChannelName=( - REMOTE_FUNCTION_WORKSPACE - if not step_compilation_context - else step_compilation_context.pipeline_build_time - ), - DataSource={ - "S3DataSource": { - "S3Uri": user_workspace_s3uri, - "S3DataType": "S3Prefix", - } - }, - ) - ) - - return input_data_config - - -def _prepare_dependencies_and_pre_execution_scripts( - local_dependencies_path: str, - pre_execution_commands: List[str], - pre_execution_script_local_path: str, - s3_base_uri: str, - s3_kms_key: str, - sagemaker_session: Session, - tmp_dir: str, -): - """Prepare pre-execution scripts and dependencies and upload them to s3. - - If pre execution commands are provided, a new bash file will be created - with those commands in tmp directory. - If pre execution script is provided, it copies that file from local file path - to tmp directory. - If local dependencies file is provided, it copies that file from local file path - to tmp directory. - If under pipeline context, tmp directory with copied dependencies and scripts is - uploaded to S3. - """ - from sagemaker.core.workflow.utilities import load_step_compilation_context - - if not (local_dependencies_path or pre_execution_commands or pre_execution_script_local_path): - return None - - if local_dependencies_path: - dst_path = shutil.copy2(local_dependencies_path, tmp_dir) - logger.info("Copied dependencies file at '%s' to '%s'", local_dependencies_path, dst_path) - - if pre_execution_commands or pre_execution_script_local_path: - pre_execution_script = os.path.join(tmp_dir, PRE_EXECUTION_SCRIPT_NAME) - if pre_execution_commands: - with open(pre_execution_script, "w") as target_script: - commands = [cmd + "\n" for cmd in pre_execution_commands] - target_script.writelines(commands) - logger.info( - "Generated pre-execution script from commands to '%s'", pre_execution_script - ) - else: - shutil.copy2(pre_execution_script_local_path, pre_execution_script) - logger.info( - "Copied pre-execution commands from script at '%s' to '%s'", - pre_execution_script_local_path, - pre_execution_script, - ) - - step_compilation_context = load_step_compilation_context() - if step_compilation_context: - upload_path = S3Uploader.upload( - tmp_dir, - s3_path_join( - s3_base_uri, - step_compilation_context.step_name, - step_compilation_context.pipeline_build_time, - SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME, - ), - s3_kms_key, - sagemaker_session, - ) - logger.info( - "Successfully uploaded dependencies and pre execution scripts to '%s'", upload_path - ) - return upload_path - return None - - -def _prepare_and_upload_workspace( - local_dependencies_path: str, - include_local_workdir: bool, - pre_execution_commands: List[str], - pre_execution_script_local_path: str, - s3_base_uri: str, - s3_kms_key: str, - sagemaker_session: Session, - custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None, -) -> str: - """Prepare and upload the workspace to S3. - - Under pipeline context, only workdir is packaged in the workspace folder and uploaded to s3. 
- Under remote function context, workdir along with pre execution scripts and dependencies - are packaged together into the workspace folder and uploaded to S3. - """ - from sagemaker.core.workflow.utilities import load_step_compilation_context - - step_compilation_context = load_step_compilation_context() - - if not ( - local_dependencies_path - or include_local_workdir - or pre_execution_commands - or pre_execution_script_local_path - ): - return None - - func_step_s3_dir = None - if step_compilation_context: - func_step_s3_dir = step_compilation_context.pipeline_build_time - if not include_local_workdir: - return None - if not step_compilation_context.upload_workspace: - return s3_path_join(s3_base_uri, REMOTE_FUNCTION_WORKSPACE, func_step_s3_dir) - - with _tmpdir() as tmp_dir: - tmp_workspace_dir = os.path.join(tmp_dir, "temp_workspace/") - os.mkdir(tmp_workspace_dir) - # TODO Remove the following hack to avoid dir_exists error in the copy_tree call below. - tmp_workspace = os.path.join(tmp_workspace_dir, JOB_REMOTE_FUNCTION_WORKSPACE) - - if include_local_workdir: - copy_workdir(tmp_workspace, custom_file_filter) - logger.info("Copied user workspace to '%s'", tmp_workspace) - - if not os.path.isdir(tmp_workspace): - # create the directory if no workdir_path was provided in the input. - os.mkdir(tmp_workspace) - - if not step_compilation_context: - _prepare_dependencies_and_pre_execution_scripts( - local_dependencies_path=local_dependencies_path, - pre_execution_commands=pre_execution_commands, - pre_execution_script_local_path=pre_execution_script_local_path, - s3_base_uri=s3_base_uri, - s3_kms_key=s3_kms_key, - sagemaker_session=sagemaker_session, - tmp_dir=tmp_workspace, - ) - - workspace_archive_path = os.path.join(tmp_dir, "workspace") - workspace_archive_path = shutil.make_archive( - workspace_archive_path, "zip", tmp_workspace_dir - ) - logger.info("Successfully created workdir archive at '%s'", workspace_archive_path) - - upload_path = S3Uploader.upload( - workspace_archive_path, - s3_path_join(s3_base_uri, REMOTE_FUNCTION_WORKSPACE, func_step_s3_dir), - s3_kms_key, - sagemaker_session, - ) - logger.info("Successfully uploaded workdir to '%s'", upload_path) - if step_compilation_context: - step_compilation_context.upload_workspace = False - return upload_path - - -def _convert_run_to_json(run: Run) -> str: - """Convert current run into json string""" - run_info = _RunInfo(run.experiment_name, run.run_name) - return json.dumps(dataclasses.asdict(run_info)) - - -def _prepare_and_upload_spark_dependent_files( - spark_config: SparkConfig, - s3_base_uri: str, - s3_kms_key: str, - sagemaker_session: Session, -) -> Tuple: - """Upload the Spark dependencies to S3 if present. - - Args: - spark_config (SparkConfig): The remote Spark job configurations. - s3_base_uri (str): The S3 location that the Spark dependencies will be uploaded to. - s3_kms_key (str): The kms key used to encrypt the files uploaded to S3. - sagemaker_session (str): SageMaker boto client session. 
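    Returns:
        Tuple: The uploaded jars, py-files and files S3 paths plus the serialized
        configuration file S3 URI; each element may be ``None``. A minimal usage sketch
        (the S3 prefix below is a hypothetical placeholder):

        .. code-block:: python

            jars, py_files, files, conf_uri = _prepare_and_upload_spark_dependent_files(
                spark_config=spark_config,
                s3_base_uri="s3://amzn-s3-demo-bucket/remote-function",
                s3_kms_key=None,
                sagemaker_session=sagemaker_session,
            )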
- """ - if not spark_config: - return None, None, None, None - - submit_jars_s3_paths = _upload_spark_submit_deps( - spark_config.submit_jars, - SPARK_SUBMIT_JARS_WORKSPACE, - s3_base_uri, - s3_kms_key, - sagemaker_session, - ) - submit_py_files_s3_paths = _upload_spark_submit_deps( - spark_config.submit_py_files, - SPARK_SUBMIT_PY_FILES_WORKSPACE, - s3_base_uri, - s3_kms_key, - sagemaker_session, - ) - submit_files_s3_path = _upload_spark_submit_deps( - spark_config.submit_files, - SPARK_SUBMIT_FILES_WORKSPACE, - s3_base_uri, - s3_kms_key, - sagemaker_session, - ) - config_file_s3_uri = _upload_serialized_spark_configuration( - s3_base_uri, s3_kms_key, spark_config.configuration, sagemaker_session - ) - - return submit_jars_s3_paths, submit_py_files_s3_paths, submit_files_s3_path, config_file_s3_uri - - -def _upload_spark_submit_deps( - submit_deps: List[str], - workspace_name: str, - s3_base_uri: str, - s3_kms_key: str, - sagemaker_session: Session, -) -> str: - """Upload the Spark submit dependencies to S3. - - Args: - submit_deps (List[str]): A list of path which points to the Spark dependency files. - The path can be either a local path or S3 uri. For example ``/local/deps.jar`` or - ``s3:///deps.jar``. - - workspace_name (str): workspace name for Spark dependency. - s3_base_uri (str): S3 location that the Spark dependencies will be uploaded to. - s3_kms_key (str): kms key used to encrypt the files uploaded to S3. - sagemaker_session (str): SageMaker boto client session. - - Returns: - str : The concatenated path of all dependencies which will be passed to Spark. - """ - spark_opt_s3_uris = [] - if not submit_deps: - return None - - if not workspace_name or not s3_base_uri: - raise ValueError("workspace_name or s3_base_uri may not be empty.") - - for dep_path in submit_deps: - dep_url = urlparse(dep_path) - - if dep_url.scheme in ["s3", "s3a"]: - spark_opt_s3_uris.append(dep_path) - elif not dep_url.scheme or dep_url.scheme == "file": - if not os.path.isfile(dep_path): - raise ValueError(f"submit_deps path {dep_path} is not a valid local file.") - - upload_path = S3Uploader.upload( - local_path=dep_path, - desired_s3_uri=s3_path_join(s3_base_uri, workspace_name), - kms_key=s3_kms_key, - sagemaker_session=sagemaker_session, - ) - - spark_opt_s3_uris.append(upload_path) - logger.info("Uploaded the local file %s to %s", dep_path, upload_path) - return str.join(",", spark_opt_s3_uris) - - -def _upload_serialized_spark_configuration( - s3_base_uri: str, s3_kms_key: str, configuration: Dict, sagemaker_session: Session -) -> str: - """Upload the Spark configuration json to S3""" - if not configuration: - return None - - serialized_configuration = BytesIO(json.dumps(configuration).encode("utf-8")) - config_file_s3_uri = s3_path_join(s3_base_uri, SPARK_CONF_WORKSPACE, SPARK_CONF_FILE_NAME) - - S3Uploader.upload_string_as_file_body( - body=serialized_configuration, - desired_s3_uri=config_file_s3_uri, - kms_key=s3_kms_key, - sagemaker_session=sagemaker_session, - ) - - logger.info("Uploaded spark configuration json %s to %s", configuration, config_file_s3_uri) - - return config_file_s3_uri - - -def _extend_mpirun_to_request( - request_dict: Dict, - job_settings: _JobSettings, -) -> Dict: - """Extend the create training job request with mpirun configuration. - - Args: - request_dict (Dict): create training job request dict. - job_settings (_JobSettings): the job settings. 
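    A minimal sketch of the effect when ``use_mpirun`` is enabled and ``instance_count`` is
    greater than one (the request dict below is truncated and hypothetical):

    .. code-block:: python

        request = {
            "InputDataConfig": [
                {"DataSource": {"S3DataSource": {"S3Uri": "s3://amzn-s3-demo-bucket/scripts"}}}
            ]
        }
        extended = _extend_mpirun_to_request(request, job_settings)
        # Each S3DataSource in the returned request now carries
        # "S3DataDistributionType": "FullyReplicated".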
- """ - use_mpirun = job_settings.use_mpirun - instance_count = job_settings.instance_count - - if not use_mpirun: - return request_dict - - if instance_count == 1: - return request_dict - - extended_request = request_dict.copy() - - for input_channel in extended_request["InputDataConfig"]: - s3_data_source = input_channel["DataSource"].get("S3DataSource", None) - if s3_data_source: - s3_data_source["S3DataDistributionType"] = "FullyReplicated" - - return extended_request - - -def _extend_torchrun_to_request( - request_dict: Dict, - job_settings: _JobSettings, -) -> Dict: - """Extend the create training job request with torchrun configuration. - - Args: - request_dict (Dict): create training job request dict. - job_settings (_JobSettings): the job settings. - """ - use_torchrun = job_settings.use_torchrun - instance_count = job_settings.instance_count - - if not use_torchrun: - return request_dict - - if instance_count == 1: - return request_dict - - extended_request = request_dict.copy() - - for input_channel in extended_request["InputDataConfig"]: - s3_data_source = input_channel["DataSource"].get("S3DataSource", None) - if s3_data_source: - s3_data_source["S3DataDistributionType"] = "FullyReplicated" - - return extended_request - - -def _extend_spark_config_to_request( - request_dict: Dict, - job_settings: _JobSettings, - s3_base_uri: str, -) -> Dict: - """Extend the create training job request with spark configurations. - - Args: - request_dict (Dict): create training job request dict. - job_settings (_JobSettings): the job settings. - s3_base_uri (str): S3 location that the Spark dependencies will be uploaded to. - """ - spark_config = job_settings.spark_config - - if not spark_config: - return request_dict - - extended_request = request_dict.copy() - container_entrypoint = extended_request["AlgorithmSpecification"]["ContainerEntrypoint"] - - ( - submit_jars_s3_paths, - submit_py_files_s3_paths, - submit_files_s3_path, - config_file_s3_uri, - ) = _prepare_and_upload_spark_dependent_files( - spark_config=spark_config, - s3_base_uri=s3_base_uri, - s3_kms_key=job_settings.s3_kms_key, - sagemaker_session=job_settings.sagemaker_session, - ) - - input_data_config = extended_request["InputDataConfig"] - - if config_file_s3_uri: - input_data_config.append( - dict( - ChannelName=SPARK_CONF_CHANNEL_NAME, - DataSource={ - "S3DataSource": { - "S3Uri": config_file_s3_uri, - "S3DataType": "S3Prefix", - } - }, - ) - ) - - for input_channel in extended_request["InputDataConfig"]: - s3_data_source = input_channel["DataSource"].get("S3DataSource", None) - if s3_data_source: - s3_data_source["S3DataDistributionType"] = "FullyReplicated" - - if spark_config.spark_event_logs_uri: - container_entrypoint.extend( - ["--spark-event-logs-s3-uri", spark_config.spark_event_logs_uri] - ) - - if submit_jars_s3_paths: - container_entrypoint.extend(["--jars", submit_jars_s3_paths]) - - if submit_py_files_s3_paths: - container_entrypoint.extend(["--py-files", submit_py_files_s3_paths]) - - if submit_files_s3_path: - container_entrypoint.extend(["--files", submit_files_s3_path]) - - if spark_config: - container_entrypoint.extend([SPARK_APP_SCRIPT_PATH]) - - return extended_request - - -def _update_job_request_with_checkpoint_config(args, kwargs, request_dict): - """Extend job request with checkpoint config based on CheckpointLocation in function args. - - Args: - args (tuple): The positional arguments of the remote function. - kwargs (Dict): The keyword arguments of the remote function. 
- request_dict (Dict): create training job request dict. - """ - checkpoint_location_index_in_args = None - checkpoint_location_key_in_kwargs = None - checkpoint_location_count = 0 - - for index, arg in enumerate(args): - if isinstance(arg, CheckpointLocation): - checkpoint_location_index_in_args = index - checkpoint_location_count += 1 - - for key, value in kwargs.items(): - if isinstance(value, CheckpointLocation): - checkpoint_location_key_in_kwargs = key - checkpoint_location_count += 1 - - if checkpoint_location_count < 1: - return - - if checkpoint_location_count > 1: - raise ValueError( - "Remote function cannot have more than one argument of type CheckpointLocation." - ) - - if checkpoint_location_index_in_args is not None: - checkpoint_location_arg = args[checkpoint_location_index_in_args] - else: - checkpoint_location_arg = kwargs[checkpoint_location_key_in_kwargs] - - checkpoint_s3_uri = checkpoint_location_arg._s3_uri - checkpoint_local_path = checkpoint_location_arg._local_path - - request_dict["CheckpointConfig"] = { - "LocalPath": checkpoint_local_path, - "S3Uri": checkpoint_s3_uri, - } - - -@dataclasses.dataclass -class _RunInfo: - """Data class to hold information of the run object from context.""" - - experiment_name: str - run_name: str - - -def _get_initial_job_state(description, status_key, wait): - """Placeholder docstring""" - status = description[status_key] - job_already_completed = status in ("Completed", "Failed", "Stopped") - return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE - - -def _logs_for_job( # noqa: C901 - suppress complexity warning for this method - sagemaker_session, job_name, wait=False, poll=10, log_type="All", timeout=None -): - """Display logs for a given training job, optionally tailing them until job is complete. - - If the output is a tty or a Jupyter cell, it will be color-coded - based on which instance the log entry is from. - - Args: - sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session - object, used for SageMaker interactions. - job_name (str): Name of the training job to display the logs for. - wait (bool): Whether to keep looking for new log entries until the job completes - (default: False). - poll (int): The interval in seconds between polling for new log entries and job - completion (default: 5). - log_type ([str]): A list of strings specifying which logs to print. Acceptable - strings are "All", "None", "Training", or "Rules". To maintain backwards - compatibility, boolean values are also accepted and converted to strings. - timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by - default. - Returns: - Last call to sagemaker DescribeTrainingJob - Raises: - exceptions.CapacityError: If the training job fails with CapacityError. - exceptions.UnexpectedStatusException: If waiting and the training job fails. 
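    Example:
        A minimal sketch of tailing the logs of a training job until completion
        (``sagemaker_session`` already exists and the job name is a placeholder):

        .. code-block:: python

            description = _logs_for_job(
                sagemaker_session, "my-training-job", wait=True, timeout=3600
            )
            print(description["TrainingJobStatus"])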
- """ - sagemaker_client = sagemaker_session.sagemaker_client - request_end_time = time.time() + timeout if timeout else None - description = _wait_until( - lambda: sagemaker_client.describe_training_job(TrainingJobName=job_name) - ) - print(secondary_training_status_message(description, None), end="") - - instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init( - sagemaker_session.boto_session, description, job="Training" - ) - - state = _get_initial_job_state(description, "TrainingJobStatus", wait) - - # The loop below implements a state machine that alternates between checking the job status - # and reading whatever is available in the logs at this point. Note, that if we were - # called with wait == False, we never check the job status. - # - # If wait == TRUE and job is not completed, the initial state is TAILING - # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is - # complete). - # - # The state table: - # - # STATE ACTIONS CONDITION NEW STATE - # ---------------- ---------------- ----------------- ---------------- - # TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE - # Else TAILING - # JOB_COMPLETE Read logs, Pause Any COMPLETE - # COMPLETE Read logs, Exit N/A - # - # Notes: - # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to - # Cloudwatch after the job was marked complete. - last_describe_job_call = time.time() - last_description = description - last_debug_rule_statuses = None - last_profiler_rule_statuses = None - - while True: - _flush_log_streams( - stream_names, - instance_count, - client, - log_group, - job_name, - positions, - dot, - color_wrap, - ) - if timeout and time.time() > request_end_time: - print("Timeout Exceeded. {} seconds elapsed.".format(timeout)) - break - - if state == LogState.COMPLETE: - break - - time.sleep(poll) - - if state == LogState.JOB_COMPLETE: - state = LogState.COMPLETE - elif time.time() - last_describe_job_call >= 30: - description = sagemaker_client.describe_training_job(TrainingJobName=job_name) - last_describe_job_call = time.time() - - if secondary_training_status_changed(description, last_description): - print() - print(secondary_training_status_message(description, last_description), end="") - last_description = description - - status = description["TrainingJobStatus"] - - if status in ("Completed", "Failed", "Stopped"): - print() - state = LogState.JOB_COMPLETE - - # Print prettified logs related to the status of SageMaker Debugger rules. - debug_rule_statuses = description.get("DebugRuleEvaluationStatuses", {}) - if ( - debug_rule_statuses - and _rule_statuses_changed(debug_rule_statuses, last_debug_rule_statuses) - and (log_type in {"All", "Rules"}) - ): - for status in debug_rule_statuses: - rule_log = ( - f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}" - ) - print(rule_log) - - last_debug_rule_statuses = debug_rule_statuses - - # Print prettified logs related to the status of SageMaker Profiler rules. 
- profiler_rule_statuses = description.get("ProfilerRuleEvaluationStatuses", {}) - if ( - profiler_rule_statuses - and _rule_statuses_changed(profiler_rule_statuses, last_profiler_rule_statuses) - and (log_type in {"All", "Rules"}) - ): - for status in profiler_rule_statuses: - rule_log = ( - f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}" - ) - print(rule_log) - - last_profiler_rule_statuses = profiler_rule_statuses - - if wait: - _check_job_status(job_name, description, "TrainingJobStatus") - if dot: - print() - # Customers are not billed for hardware provisioning, so billable time is less than - # total time - training_time = description.get("TrainingTimeInSeconds") - billable_time = description.get("BillableTimeInSeconds") - if training_time is not None: - print("Training seconds:", training_time * instance_count) - if billable_time is not None: - print("Billable seconds:", billable_time * instance_count) - if description.get("EnableManagedSpotTraining"): - saving = (1 - float(billable_time) / training_time) * 100 - print("Managed Spot Training savings: {:.1f}%".format(saving)) - return last_description - - -def _check_job_status(job, desc, status_key_name): - """Check to see if the job completed successfully. - - If not, construct and raise a exceptions. (UnexpectedStatusException). - - Args: - job (str): The name of the job to check. - desc (dict[str, str]): The result of ``describe_training_job()``. - status_key_name (str): Status key name to check for. - - Raises: - exceptions.CapacityError: If the training job fails with CapacityError. - exceptions.UnexpectedStatusException: If the training job fails. - """ - status = desc[status_key_name] - # If the status is capital case, then convert it to Camel case - status = _STATUS_CODE_TABLE.get(status, status) - - if status == "Stopped": - logger.warning( - "Job ended with status 'Stopped' rather than 'Completed'. " - "This could mean the job timed out or stopped early for some other reason: " - "Consider checking whether it completed as you expect." - ) - elif status != "Completed": - reason = desc.get("FailureReason", "(No reason provided)") - job_type = status_key_name.replace("JobStatus", " job") - troubleshooting = ( - "https://docs.aws.amazon.com/sagemaker/latest/dg/" - "sagemaker-python-sdk-troubleshooting.html" - ) - message = ( - "Error for {job_type} {job_name}: {status}. Reason: {reason}. " - "Check troubleshooting guide for common errors: {troubleshooting}" - ).format( - job_type=job_type, - job_name=job, - status=status, - reason=reason, - troubleshooting=troubleshooting, - ) - if "CapacityError" in str(reason): - raise exceptions.CapacityError( - message=message, - allowed_statuses=["Completed", "Stopped"], - actual_status=status, - ) - raise exceptions.UnexpectedStatusException( - message=message, - allowed_statuses=["Completed", "Stopped"], - actual_status=status, - ) - - -def _flush_log_streams( - stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap -): - """Placeholder docstring""" - if len(stream_names) < instance_count: - # Log streams are created whenever a container starts writing to stdout/err, so this list - # may be dynamic until we have a stream for every instance. 
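The billing summary printed above multiplies the per-job seconds by the instance count and derives the managed spot saving as ``(1 - billable/training) * 100``. A worked sketch with made-up numbers:

.. code-block:: python

    def billing_summary(training_seconds, billable_seconds, instance_count, spot=True):
        """Reproduce the arithmetic behind the 'Training seconds' summary above."""
        lines = [
            "Training seconds: {}".format(training_seconds * instance_count),
            "Billable seconds: {}".format(billable_seconds * instance_count),
        ]
        if spot:
            saving = (1 - float(billable_seconds) / training_seconds) * 100
            lines.append("Managed Spot Training savings: {:.1f}%".format(saving))
        return lines

    # Hypothetical spot job: 2 instances, 1000s of training, 700s billed.
    assert billing_summary(1000, 700, 2) == [
        "Training seconds: 2000",
        "Billable seconds: 1400",
        "Managed Spot Training savings: 30.0%",
    ]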
- try: - streams = client.describe_log_streams( - logGroupName=log_group, - logStreamNamePrefix=job_name + "/", - orderBy="LogStreamName", - limit=min(instance_count, 50), - ) - stream_names = [s["logStreamName"] for s in streams["logStreams"]] - - while "nextToken" in streams: - streams = client.describe_log_streams( - logGroupName=log_group, - logStreamNamePrefix=job_name + "/", - orderBy="LogStreamName", - limit=50, - ) - - stream_names.extend([s["logStreamName"] for s in streams["logStreams"]]) - - positions.update( - [ - (s, sagemaker_logs.Position(timestamp=0, skip=0)) - for s in stream_names - if s not in positions - ] - ) - except ClientError as e: - # On the very first training job run on an account, there's no log group until - # the container starts logging, so ignore any errors thrown about that - err = e.response.get("Error", {}) - if err.get("Code", None) != "ResourceNotFoundException": - raise - - if len(stream_names) > 0: - if dot: - print("") - dot = False - for idx, event in sagemaker_logs.multi_stream_iter( - client, log_group, stream_names, positions - ): - color_wrap(idx, event["message"]) - ts, count = positions[stream_names[idx]] - if event["timestamp"] == ts: - positions[stream_names[idx]] = sagemaker_logs.Position(timestamp=ts, skip=count + 1) - else: - positions[stream_names[idx]] = sagemaker_logs.Position( - timestamp=event["timestamp"], skip=1 - ) - else: - dot = True - print(".", end="") - sys.stdout.flush() - - -def _rule_statuses_changed(current_statuses, last_statuses): - """Checks the rule evaluation statuses for SageMaker Debugger and Profiler rules.""" - if not last_statuses: - return True - - for current, last in zip(current_statuses, last_statuses): - if (current["RuleConfigurationName"] == last["RuleConfigurationName"]) and ( - current["RuleEvaluationStatus"] != last["RuleEvaluationStatus"] - ): - return True - - return False - - -def _get_initial_job_state(description, status_key, wait): - """Placeholder docstring""" - status = description[status_key] - job_already_completed = status in ("Completed", "Failed", "Stopped") - return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE - - -def _logs_init(boto_session, description, job): - """Placeholder docstring""" - if job == "Training": - if "InstanceGroups" in description["ResourceConfig"]: - instance_count = 0 - for instanceGroup in description["ResourceConfig"]["InstanceGroups"]: - instance_count += instanceGroup["InstanceCount"] - else: - instance_count = description["ResourceConfig"]["InstanceCount"] - elif job == "Transform": - instance_count = description["TransformResources"]["InstanceCount"] - elif job == "Processing": - instance_count = description["ProcessingResources"]["ClusterConfig"]["InstanceCount"] - elif job == "AutoML": - instance_count = 0 - - stream_names = [] # The list of log streams - positions = {} # The current position in each stream, map of stream name -> position - - # Increase retries allowed (from default of 4), as we don't want waiting for a training job - # to be interrupted by a transient exception. 
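The stream discovery above pages through ``describe_log_streams`` by hand. For readers who only need the stream names (one per instance), boto3's built-in paginator for the same CloudWatch Logs API is an alternative; this is a hedged sketch, not the SDK's implementation, and it assumes AWS credentials and an existing ``/aws/sagemaker/TrainingJobs`` log group:

.. code-block:: python

    import boto3
    import botocore.config

    def list_training_log_streams(job_name, boto_session=None):
        """List CloudWatch log streams for a training job (one stream per instance),
        using the service paginator instead of manual nextToken handling."""
        session = boto_session or boto3.Session()
        config = botocore.config.Config(retries={"max_attempts": 15})
        client = session.client("logs", config=config)
        paginator = client.get_paginator("describe_log_streams")
        pages = paginator.paginate(
            logGroupName="/aws/sagemaker/TrainingJobs",
            logStreamNamePrefix=job_name + "/",
            orderBy="LogStreamName",
        )
        return [s["logStreamName"] for page in pages for s in page["logStreams"]]

    # e.g. list_training_log_streams("my-training-job-2024-01-01-00-00-00-000")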
- config = botocore.config.Config(retries={"max_attempts": 15}) - client = boto_session.client("logs", config=config) - log_group = "/aws/sagemaker/" + job + "Jobs" - - dot = False - - from sagemaker.core.logs import ColorWrap - - color_wrap = ColorWrap() - - return instance_count, stream_names, positions, client, log_group, dot, color_wrap diff --git a/sagemaker-core/src/sagemaker/core/remote_function/logging_config.py b/sagemaker-core/src/sagemaker/core/remote_function/logging_config.py deleted file mode 100644 index 875fabf6e0..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/logging_config.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Utilities related to logging.""" -from __future__ import absolute_import - -import logging -import time - - -class _UTCFormatter(logging.Formatter): - """Class that overrides the default local time provider in log formatter.""" - - converter = time.gmtime - - -def get_logger(): - """Return a logger with the name 'sagemaker'""" - sagemaker_logger = logging.getLogger("sagemaker.remote_function") - if len(sagemaker_logger.handlers) == 0: - sagemaker_logger.setLevel(logging.INFO) - handler = logging.StreamHandler() - formatter = _UTCFormatter("%(asctime)s %(name)s %(levelname)-8s %(message)s") - handler.setFormatter(formatter) - sagemaker_logger.addHandler(handler) - # don't stream logs with the root logger handler - sagemaker_logger.propagate = 0 - - return sagemaker_logger diff --git a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/__init__.py b/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/__init__.py deleted file mode 100644 index 18557a2eb5..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Sagemaker modules container_drivers directory.""" -from __future__ import absolute_import diff --git a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py b/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py deleted file mode 100644 index 2c20151ed1..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ /dev/null @@ -1,605 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""An entry point for runtime environment. This must be kept independent of SageMaker PySDK""" -from __future__ import absolute_import - -import argparse -import getpass -import json -import multiprocessing -import os -import pathlib -import shutil -import subprocess -import sys -from typing import Any, Dict - -if __package__ is None or __package__ == "": - from runtime_environment_manager import ( - RuntimeEnvironmentManager, - _DependencySettings, - get_logger, - ) -else: - from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( - RuntimeEnvironmentManager, - _DependencySettings, - get_logger, - ) - -SUCCESS_EXIT_CODE = 0 -DEFAULT_FAILURE_CODE = 1 - -REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws" -BASE_CHANNEL_PATH = "/opt/ml/input/data" -FAILURE_REASON_PATH = "/opt/ml/output/failure" -JOB_OUTPUT_DIRS = ["/opt/ml/input", "/opt/ml/output", "/opt/ml/model", "/tmp"] -PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh" -JOB_REMOTE_FUNCTION_WORKSPACE = "sagemaker_remote_function_workspace" -SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME = "pre_exec_script_and_dependencies" - -SM_MODEL_DIR = "/opt/ml/model" - -SM_INPUT_DIR = "/opt/ml/input" -SM_INPUT_DATA_DIR = "/opt/ml/input/data" -SM_INPUT_CONFIG_DIR = "/opt/ml/input/config" - -SM_OUTPUT_DIR = "/opt/ml/output" -SM_OUTPUT_FAILURE = "/opt/ml/output/failure" -SM_OUTPUT_DATA_DIR = "/opt/ml/output/data" - -SM_MASTER_ADDR = "algo-1" -SM_MASTER_PORT = 7777 - -RESOURCE_CONFIG = f"{SM_INPUT_CONFIG_DIR}/resourceconfig.json" -ENV_OUTPUT_FILE = "/opt/ml/input/sm_training.env" - -SENSITIVE_KEYWORDS = ["SECRET", "PASSWORD", "KEY", "TOKEN", "PRIVATE", "CREDS", "CREDENTIALS"] -HIDDEN_VALUE = "******" - -SM_EFA_NCCL_INSTANCES = [ - "ml.g4dn.8xlarge", - "ml.g4dn.12xlarge", - "ml.g5.48xlarge", - "ml.p3dn.24xlarge", - "ml.p4d.24xlarge", - "ml.p4de.24xlarge", - "ml.p5.48xlarge", - "ml.trn1.32xlarge", -] - -SM_EFA_RDMA_INSTANCES = [ - "ml.p4d.24xlarge", - "ml.p4de.24xlarge", - "ml.trn1.32xlarge", -] - -logger = get_logger() - - -def _bootstrap_runtime_env_for_remote_function( - client_python_version: str, - conda_env: str = None, - dependency_settings: _DependencySettings = None, -): - """Bootstrap runtime environment for remote function invocation. - - Args: - client_python_version (str): Python version at the client side. - conda_env (str): conda environment to be activated. Default is None. - dependency_settings (dict): Settings for installing dependencies. - """ - - workspace_unpack_dir = _unpack_user_workspace() - if not workspace_unpack_dir: - logger.info("No workspace to unpack and setup.") - return - - _handle_pre_exec_scripts(workspace_unpack_dir) - - _install_dependencies( - workspace_unpack_dir, - conda_env, - client_python_version, - REMOTE_FUNCTION_WORKSPACE, - dependency_settings, - ) - - -def _bootstrap_runtime_env_for_pipeline_step( - client_python_version: str, - func_step_workspace: str, - conda_env: str = None, - dependency_settings: _DependencySettings = None, -): - """Bootstrap runtime environment for pipeline step invocation. 
- - Args: - client_python_version (str): Python version at the client side. - func_step_workspace (str): s3 folder where workspace for FunctionStep is stored - conda_env (str): conda environment to be activated. Default is None. - dependency_settings (dict): Name of the dependency file. Default is None. - """ - - workspace_dir = _unpack_user_workspace(func_step_workspace) - if not workspace_dir: - os.mkdir(JOB_REMOTE_FUNCTION_WORKSPACE) - workspace_dir = pathlib.Path(os.getcwd(), JOB_REMOTE_FUNCTION_WORKSPACE).absolute() - - pre_exec_script_and_dependencies_dir = os.path.join( - BASE_CHANNEL_PATH, SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME - ) - - if not os.path.exists(pre_exec_script_and_dependencies_dir): - logger.info("No dependencies to bootstrap") - return - for file in os.listdir(pre_exec_script_and_dependencies_dir): - src_path = os.path.join(pre_exec_script_and_dependencies_dir, file) - dest_path = os.path.join(workspace_dir, file) - shutil.copy(src_path, dest_path) - - _handle_pre_exec_scripts(workspace_dir) - - _install_dependencies( - workspace_dir, - conda_env, - client_python_version, - SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME, - dependency_settings, - ) - - -def _handle_pre_exec_scripts(script_file_dir: str): - """Run the pre execution scripts. - - Args: - script_file_dir (str): Directory in the container where pre-execution scripts exists. - """ - - path_to_pre_exec_script = os.path.join(script_file_dir, PRE_EXECUTION_SCRIPT_NAME) - if os.path.isfile(path_to_pre_exec_script): - RuntimeEnvironmentManager().run_pre_exec_script( - pre_exec_script_path=path_to_pre_exec_script - ) - - -def _install_dependencies( - dependency_file_dir: str, - conda_env: str, - client_python_version: str, - channel_name: str, - dependency_settings: _DependencySettings = None, -): - """Install dependencies in the job container - - Args: - dependency_file_dir (str): Directory in the container where dependency file exists. - conda_env (str): conda environment to be activated. - client_python_version (str): Python version at the client side. - channel_name (str): Channel where dependency file was uploaded. - dependency_settings (dict): Settings for installing dependencies. - """ - - if dependency_settings is not None and dependency_settings.dependency_file is None: - # an empty dict is passed when no dependencies are specified - logger.info("No dependencies to install.") - elif dependency_settings is not None: - dependencies_file = os.path.join(dependency_file_dir, dependency_settings.dependency_file) - RuntimeEnvironmentManager().bootstrap( - local_dependencies_file=dependencies_file, - conda_env=conda_env, - client_python_version=client_python_version, - ) - else: - # no dependency file name is passed when an legacy version of the SDK is used - # we look for a file with .txt, .yml or .yaml extension in the workspace directory - dependencies_file = None - for file in os.listdir(dependency_file_dir): - if file.endswith(".txt") or file.endswith(".yml") or file.endswith(".yaml"): - dependencies_file = os.path.join(dependency_file_dir, file) - break - - if dependencies_file: - RuntimeEnvironmentManager().bootstrap( - local_dependencies_file=dependencies_file, - conda_env=conda_env, - client_python_version=client_python_version, - ) - else: - logger.info( - "Did not find any dependency file in the directory at '%s'." 
- " Assuming no additional dependencies to install.", - os.path.join(BASE_CHANNEL_PATH, channel_name), - ) - - -def _unpack_user_workspace(func_step_workspace: str = None): - """Unzip the user workspace""" - - workspace_archive_dir_path = ( - os.path.join(BASE_CHANNEL_PATH, REMOTE_FUNCTION_WORKSPACE) - if not func_step_workspace - else os.path.join(BASE_CHANNEL_PATH, func_step_workspace) - ) - if not os.path.exists(workspace_archive_dir_path): - logger.info( - "Directory '%s' does not exist.", - workspace_archive_dir_path, - ) - return None - - workspace_archive_path = os.path.join(workspace_archive_dir_path, "workspace.zip") - if not os.path.isfile(workspace_archive_path): - logger.info( - "Workspace archive '%s' does not exist.", - workspace_archive_dir_path, - ) - return None - - workspace_unpack_dir = pathlib.Path(os.getcwd()).absolute() - shutil.unpack_archive(filename=workspace_archive_path, extract_dir=workspace_unpack_dir) - logger.info("Successfully unpacked workspace archive at '%s'.", workspace_unpack_dir) - workspace_unpack_dir = pathlib.Path(workspace_unpack_dir, JOB_REMOTE_FUNCTION_WORKSPACE) - return workspace_unpack_dir - - -def _write_failure_reason_file(failure_msg): - """Create a file 'failure' with failure reason written if bootstrap runtime env failed. - - See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html - Args: - failure_msg: The content of file to be written. - """ - if not os.path.exists(FAILURE_REASON_PATH): - with open(FAILURE_REASON_PATH, "w") as f: - f.write("RuntimeEnvironmentError: " + failure_msg) - - -def _parse_args(sys_args): - """Parses CLI arguments.""" - parser = argparse.ArgumentParser() - parser.add_argument("--job_conda_env", type=str) - parser.add_argument("--client_python_version", type=str) - parser.add_argument("--client_sagemaker_pysdk_version", type=str, default=None) - parser.add_argument("--pipeline_execution_id", type=str) - parser.add_argument("--dependency_settings", type=str) - parser.add_argument("--func_step_s3_dir", type=str) - parser.add_argument("--distribution", type=str, default=None) - parser.add_argument("--user_nproc_per_node", type=str, default=None) - args, _ = parser.parse_known_args(sys_args) - return args - - -def log_key_value(key: str, value: str): - """Log a key-value pair, masking sensitive values if necessary.""" - if any(keyword.lower() in key.lower() for keyword in SENSITIVE_KEYWORDS): - logger.info("%s=%s", key, HIDDEN_VALUE) - elif isinstance(value, dict): - masked_value = mask_sensitive_info(value) - logger.info("%s=%s", key, json.dumps(masked_value)) - else: - try: - decoded_value = json.loads(value) - if isinstance(decoded_value, dict): - masked_value = mask_sensitive_info(decoded_value) - logger.info("%s=%s", key, json.dumps(masked_value)) - else: - logger.info("%s=%s", key, decoded_value) - except (json.JSONDecodeError, TypeError): - logger.info("%s=%s", key, value) - - -def log_env_variables(env_vars_dict: Dict[str, Any]): - """Log Environment Variables from the environment and an env_vars_dict.""" - for key, value in os.environ.items(): - log_key_value(key, value) - - for key, value in env_vars_dict.items(): - log_key_value(key, value) - - -def mask_sensitive_info(data): - """Recursively mask sensitive information in a dictionary.""" - if isinstance(data, dict): - for k, v in data.items(): - if isinstance(v, dict): - data[k] = mask_sensitive_info(v) - elif isinstance(v, str) and any( - keyword.lower() in k.lower() for keyword in SENSITIVE_KEYWORDS - ): - data[k] = 
HIDDEN_VALUE - return data - - -def num_cpus() -> int: - """Return the number of CPUs available in the current container. - - Returns: - int: Number of CPUs available in the current container. - """ - return multiprocessing.cpu_count() - - -def num_gpus() -> int: - """Return the number of GPUs available in the current container. - - Returns: - int: Number of GPUs available in the current container. - """ - try: - cmd = ["nvidia-smi", "--list-gpus"] - output = subprocess.check_output(cmd).decode("utf-8") - return sum(1 for line in output.splitlines() if line.startswith("GPU ")) - except (OSError, subprocess.CalledProcessError): - logger.info("No GPUs detected (normal if no gpus installed)") - return 0 - - -def num_neurons() -> int: - """Return the number of neuron cores available in the current container. - - Returns: - int: Number of Neuron Cores available in the current container. - """ - try: - cmd = ["neuron-ls", "-j"] - output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8") - j = json.loads(output) - neuron_cores = 0 - for item in j: - neuron_cores += item.get("nc_count", 0) - logger.info("Found %s neurons on this instance", neuron_cores) - return neuron_cores - except OSError: - logger.info("No Neurons detected (normal if no neurons installed)") - return 0 - except subprocess.CalledProcessError as e: - if e.output is not None: - try: - msg = e.output.decode("utf-8").partition("error=")[2] - logger.info( - "No Neurons detected (normal if no neurons installed). \ - If neuron installed then %s", - msg, - ) - except AttributeError: - logger.info("No Neurons detected (normal if no neurons installed)") - else: - logger.info("No Neurons detected (normal if no neurons installed)") - - return 0 - - -def safe_serialize(data): - """Serialize the data without wrapping strings in quotes. - - This function handles the following cases: - 1. If `data` is a string, it returns the string as-is without wrapping in quotes. - 2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns - the JSON-encoded string using `json.dumps()`. - 3. If `data` cannot be serialized (e.g., a custom object), it returns the string - representation of the data using `str(data)`. - - Args: - data (Any): The data to serialize. - - Returns: - str: The serialized JSON-compatible string or the string representation of the input. - """ - if isinstance(data, str): - return data - try: - return json.dumps(data) - except TypeError: - return str(data) - - -def set_env( - resource_config: Dict[str, Any], - distribution: str = None, - user_nproc_per_node: bool = None, - output_file: str = ENV_OUTPUT_FILE, -): - """Set environment variables for the training job container. - - Args: - resource_config (Dict[str, Any]): Resource configuration for the training job. - output_file (str): Output file to write the environment variables. 
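``safe_serialize`` above is what turns each environment value into the string written to ``sm_training.env``: strings pass through unquoted, JSON-serializable values are dumped, and everything else falls back to ``str()``. A quick standalone demonstration of those three cases:

.. code-block:: python

    import json

    def safe_serialize(value):
        """Mirror of the serialization rules above: strings pass through unquoted,
        JSON-serializable values are dumped, everything else falls back to str()."""
        if isinstance(value, str):
            return value
        try:
            return json.dumps(value)
        except TypeError:
            return str(value)

    assert safe_serialize("eth0") == "eth0"                                   # no extra quotes
    assert safe_serialize(["algo-1", "algo-2"]) == '["algo-1", "algo-2"]'
    assert safe_serialize(object()).startswith("<object object")              # non-JSON fallback

    # This is the shape of each line written to /opt/ml/input/sm_training.env:
    print("export SM_HOSTS='{}'".format(safe_serialize(["algo-1", "algo-2"])))
    # export SM_HOSTS='["algo-1", "algo-2"]'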
- """ - # Constants - env_vars = { - "SM_MODEL_DIR": SM_MODEL_DIR, - "SM_INPUT_DIR": SM_INPUT_DIR, - "SM_INPUT_DATA_DIR": SM_INPUT_DATA_DIR, - "SM_INPUT_CONFIG_DIR": SM_INPUT_CONFIG_DIR, - "SM_OUTPUT_DIR": SM_OUTPUT_DIR, - "SM_OUTPUT_FAILURE": SM_OUTPUT_FAILURE, - "SM_OUTPUT_DATA_DIR": SM_OUTPUT_DATA_DIR, - "SM_MASTER_ADDR": SM_MASTER_ADDR, - "SM_MASTER_PORT": SM_MASTER_PORT, - } - - # Host Variables - current_host = resource_config["current_host"] - current_instance_type = resource_config["current_instance_type"] - hosts = resource_config["hosts"] - sorted_hosts = sorted(hosts) - - env_vars["SM_CURRENT_HOST"] = current_host - env_vars["SM_CURRENT_INSTANCE_TYPE"] = current_instance_type - env_vars["SM_HOSTS"] = sorted_hosts - env_vars["SM_NETWORK_INTERFACE_NAME"] = resource_config["network_interface_name"] - env_vars["SM_HOST_COUNT"] = len(sorted_hosts) - env_vars["SM_CURRENT_HOST_RANK"] = sorted_hosts.index(current_host) - - env_vars["SM_NUM_CPUS"] = num_cpus() - env_vars["SM_NUM_GPUS"] = num_gpus() - env_vars["SM_NUM_NEURONS"] = num_neurons() - - # Misc. - env_vars["SM_RESOURCE_CONFIG"] = resource_config - - if user_nproc_per_node is not None and int(user_nproc_per_node) > 0: - env_vars["SM_NPROC_PER_NODE"] = int(user_nproc_per_node) - else: - if int(env_vars["SM_NUM_GPUS"]) > 0: - env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_GPUS"]) - elif int(env_vars["SM_NUM_NEURONS"]) > 0: - env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_NEURONS"]) - else: - env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_CPUS"]) - - # All Training Environment Variables - env_vars["SM_TRAINING_ENV"] = { - "current_host": env_vars["SM_CURRENT_HOST"], - "current_instance_type": env_vars["SM_CURRENT_INSTANCE_TYPE"], - "hosts": env_vars["SM_HOSTS"], - "host_count": env_vars["SM_HOST_COUNT"], - "nproc_per_node": env_vars["SM_NPROC_PER_NODE"], - "master_addr": env_vars["SM_MASTER_ADDR"], - "master_port": env_vars["SM_MASTER_PORT"], - "input_config_dir": env_vars["SM_INPUT_CONFIG_DIR"], - "input_data_dir": env_vars["SM_INPUT_DATA_DIR"], - "input_dir": env_vars["SM_INPUT_DIR"], - "job_name": os.environ["TRAINING_JOB_NAME"], - "model_dir": env_vars["SM_MODEL_DIR"], - "network_interface_name": env_vars["SM_NETWORK_INTERFACE_NAME"], - "num_cpus": env_vars["SM_NUM_CPUS"], - "num_gpus": env_vars["SM_NUM_GPUS"], - "num_neurons": env_vars["SM_NUM_NEURONS"], - "output_data_dir": env_vars["SM_OUTPUT_DATA_DIR"], - "resource_config": env_vars["SM_RESOURCE_CONFIG"], - } - - if distribution and distribution == "torchrun": - logger.info("Distribution: torchrun") - - instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"] - network_interface_name = env_vars.get("SM_NETWORK_INTERFACE_NAME", "eth0") - - if instance_type in SM_EFA_NCCL_INSTANCES: - # Enable EFA use - env_vars["FI_PROVIDER"] = "efa" - if instance_type in SM_EFA_RDMA_INSTANCES: - # Use EFA's RDMA functionality for one-sided and two-sided transfer - env_vars["FI_EFA_USE_DEVICE_RDMA"] = "1" - env_vars["RDMAV_FORK_SAFE"] = "1" - env_vars["NCCL_SOCKET_IFNAME"] = str(network_interface_name) - env_vars["NCCL_PROTO"] = "simple" - elif distribution and distribution == "mpirun": - logger.info("Distribution: mpirun") - - env_vars["MASTER_ADDR"] = env_vars["SM_MASTER_ADDR"] - env_vars["MASTER_PORT"] = str(env_vars["SM_MASTER_PORT"]) - - host_list = [ - "{}:{}".format(host, int(env_vars["SM_NPROC_PER_NODE"])) for host in sorted_hosts - ] - env_vars["SM_HOSTS_LIST"] = ",".join(host_list) - - instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"] - - if instance_type in 
SM_EFA_NCCL_INSTANCES: - env_vars["SM_FI_PROVIDER"] = "-x FI_PROVIDER=efa" - env_vars["SM_NCCL_PROTO"] = "-x NCCL_PROTO=simple" - else: - env_vars["SM_FI_PROVIDER"] = "" - env_vars["SM_NCCL_PROTO"] = "" - - if instance_type in SM_EFA_RDMA_INSTANCES: - env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = "-x FI_EFA_USE_DEVICE_RDMA=1" - else: - env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = "" - - with open(output_file, "w") as f: - for key, value in env_vars.items(): - f.write(f"export {key}='{safe_serialize(value)}'\n") - - logger.info("Environment Variables:") - log_env_variables(env_vars_dict=env_vars) - - -def main(sys_args=None): - """Entry point for bootstrap script""" - - exit_code = DEFAULT_FAILURE_CODE - - try: - args = _parse_args(sys_args) - - logger.info("Arguments:") - for arg in vars(args): - logger.info("%s=%s", arg, getattr(args, arg)) - - client_python_version = args.client_python_version - client_sagemaker_pysdk_version = args.client_sagemaker_pysdk_version - job_conda_env = args.job_conda_env - pipeline_execution_id = args.pipeline_execution_id - dependency_settings = _DependencySettings.from_string(args.dependency_settings) - func_step_workspace = args.func_step_s3_dir - distribution = args.distribution - user_nproc_per_node = args.user_nproc_per_node - - conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") - - RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env) - - user = getpass.getuser() - if user != "root": - log_message = ( - "The job is running on non-root user: %s. Adding write permissions to the " - "following job output directories: %s." - ) - logger.info(log_message, user, JOB_OUTPUT_DIRS) - RuntimeEnvironmentManager().change_dir_permission( - dirs=JOB_OUTPUT_DIRS, new_permission="777" - ) - - if pipeline_execution_id: - _bootstrap_runtime_env_for_pipeline_step( - client_python_version, func_step_workspace, conda_env, dependency_settings - ) - else: - _bootstrap_runtime_env_for_remote_function( - client_python_version, conda_env, dependency_settings - ) - - RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( - client_sagemaker_pysdk_version - ) - - if os.path.exists(RESOURCE_CONFIG): - try: - logger.info("Found %s", RESOURCE_CONFIG) - with open(RESOURCE_CONFIG, "r") as f: - resource_config = json.load(f) - set_env( - resource_config=resource_config, - distribution=distribution, - user_nproc_per_node=user_nproc_per_node, - ) - except (json.JSONDecodeError, FileNotFoundError) as e: - # Optionally, you might want to log this error - logger.info("ERROR: Error processing %s: %s", RESOURCE_CONFIG, str(e)) - - exit_code = SUCCESS_EXIT_CODE - except Exception as e: # pylint: disable=broad-except - logger.exception("Error encountered while bootstrapping runtime environment: %s", e) - - _write_failure_reason_file(str(e)) - finally: - sys.exit(exit_code) - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py b/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py deleted file mode 100644 index f36e17a04c..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. 
A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""An utils function for runtime environment. This must be kept independent of SageMaker PySDK""" -from __future__ import absolute_import - -import argparse -import json -import os -import subprocess -import sys -import time -from typing import List - -import paramiko - -if __package__ is None or __package__ == "": - from runtime_environment_manager import ( - get_logger, - ) -else: - from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( - get_logger, - ) - -SUCCESS_EXIT_CODE = 0 -DEFAULT_FAILURE_CODE = 1 - -FINISHED_STATUS_FILE = "/tmp/done.algo-1" -READY_FILE = "/tmp/ready.%s" -DEFAULT_SSH_PORT = 22 - -FAILURE_REASON_PATH = "/opt/ml/output/failure" -FINISHED_STATUS_FILE = "/tmp/done.algo-1" - -logger = get_logger() - - -class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy): - """Class to handle host key policy for SageMaker distributed training SSH connections. - - Example: - >>> client = paramiko.SSHClient() - >>> client.set_missing_host_key_policy(CustomHostKeyPolicy()) - >>> # Will succeed for SageMaker algorithm containers - >>> client.connect('algo-1234.internal') - >>> # Will raise SSHException for other unknown hosts - >>> client.connect('unknown-host') # raises SSHException - """ - - def missing_host_key(self, client, hostname, key): - """Accept host keys for algo-* hostnames, reject others. - - Args: - client: The SSHClient instance - hostname: The hostname attempting to connect - key: The host key - Raises: - paramiko.SSHException: If hostname doesn't match algo-* pattern - """ - if hostname.startswith("algo-"): - client.get_host_keys().add(hostname, key.get_name(), key) - return - raise paramiko.SSHException(f"Unknown host key for {hostname}") - - -def _parse_args(sys_args): - """Parses CLI arguments.""" - parser = argparse.ArgumentParser() - parser.add_argument("--job_ended", type=str, default="0") - args, _ = parser.parse_known_args(sys_args) - return args - - -def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool: - """Check if the connection to the provided host and port is possible.""" - try: - with paramiko.SSHClient() as client: - client.load_system_host_keys() - client.set_missing_host_key_policy(CustomHostKeyPolicy()) - client.connect(host, port=port) - logger.info("Can connect to host %s", host) - return True - except Exception as e: # pylint: disable=W0703 - logger.info("Cannot connect to host %s", host) - logger.debug("Connection failed with exception: %s", e) - return False - - -def _write_file_to_host(host: str, status_file: str) -> bool: - """Write the a file to the provided host.""" - try: - logger.info("Writing %s to %s", status_file, host) - subprocess.run( - ["ssh", host, "touch", f"{status_file}"], - capture_output=True, - text=True, - check=True, - ) - logger.info("Finished writing status file") - return True - except subprocess.CalledProcessError: - logger.info("Cannot connect to %s", host) - return False - - -def _write_failure_reason_file(failure_msg): - """Create a file 'failure' with failure reason written if bootstrap runtime env failed. 
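The node-to-node signalling above boils down to touching marker files on a peer over SSH. A minimal standalone copy of that ssh-touch pattern (host and file names are the templates used above; the example call is only meaningful inside a job container where ``algo-*`` hosts resolve):

.. code-block:: python

    import subprocess

    READY_FILE = "/tmp/ready.%s"   # same template as above; %s is the writing host's name

    def write_marker_file(host, path):
        """Touch a marker file on `host` over SSH; returns True on success."""
        try:
            subprocess.run(
                ["ssh", host, "touch", path],
                capture_output=True,
                text=True,
                check=True,
            )
            return True
        except subprocess.CalledProcessError:
            return False

    # A worker would signal readiness to the master with e.g.:
    #   write_marker_file("algo-1", READY_FILE % "algo-2")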
- - See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html - Args: - failure_msg: The content of file to be written. - """ - if not os.path.exists(FAILURE_REASON_PATH): - with open(FAILURE_REASON_PATH, "w") as f: - f.write("RuntimeEnvironmentError: " + failure_msg) - - -def _wait_for_master(master_host: str, port: int = DEFAULT_SSH_PORT, timeout: int = 300): - """Worker nodes wait until they can connect to the master node.""" - start_time = time.time() - while True: - logger.info("Worker is attempting to connect to the master node %s...", master_host) - if _can_connect(master_host, port): - logger.info("Worker can connect to master node %s.", master_host) - break - if time.time() - start_time > timeout: - raise TimeoutError("Timed out waiting for master %s to be reachable." % master_host) - - time.sleep(5) # Wait for 5 seconds before trying again - - -def _wait_for_status_file(status_file: str): - """Wait for the status file to be created.""" - logger.info("Waiting for status file %s", status_file) - while not os.path.exists(status_file): - time.sleep(30) - logger.info("Found status file %s", status_file) - - -def _wait_for_workers(worker_hosts: List[str], port: int = DEFAULT_SSH_PORT, timeout: int = 300): - """Master node waits until it can connect to all worker nodes.""" - start_time = time.time() - if not worker_hosts: - logger.info("No worker nodes to connect to.") - return - - while True: - logger.info("Master is attempting to connect to all workers...") - all_workers_connected = all( - _can_connect(worker, port) and os.path.exists(READY_FILE % worker) - for worker in worker_hosts - ) - - if all_workers_connected: - logger.info("Master can connect to all worker nodes.") - break - if time.time() - start_time > timeout: - raise TimeoutError("Timed out waiting for workers to be reachable.") - - time.sleep(5) # Wait for 5 seconds before trying again - - -def bootstrap_master_node(worker_hosts: List[str]): - """Bootstrap the master node.""" - logger.info("Bootstrapping master node...") - _wait_for_workers(worker_hosts) - - -def bootstrap_worker_node( - master_host: str, current_host: str, status_file: str = FINISHED_STATUS_FILE -): - """Bootstrap the worker nodes.""" - logger.info("Bootstrapping worker node...") - _wait_for_master(master_host) - _write_file_to_host(master_host, READY_FILE % current_host) - _wait_for_status_file(status_file) - - -def start_sshd_daemon(): - """Start the SSH daemon on the current node.""" - sshd_executable = "/usr/sbin/sshd" - - if not os.path.exists(sshd_executable): - raise RuntimeError("SSH daemon not found.") - - # Start the sshd in daemon mode (-D) - subprocess.Popen([sshd_executable, "-D"]) - logger.info("Started SSH daemon.") - - -def write_status_file_to_workers(worker_hosts: List[str], status_file: str = FINISHED_STATUS_FILE): - """Write the status file to all worker nodes.""" - for worker in worker_hosts: - retry = 0 - while not _write_file_to_host(worker, status_file): - time.sleep(5) - retry += 1 - if retry > 5: - raise TimeoutError("Timed out waiting for %s to be reachable." 
% worker) - logger.info("Retrying to write status file to %s", worker) - - -def main(sys_args=None): - """Entry point for bootstrap script""" - try: - args = _parse_args(sys_args) - - job_ended = args.job_ended - - main_host = os.environ["SM_MASTER_ADDR"] - current_host = os.environ["SM_CURRENT_HOST"] - - if job_ended == "0": - logger.info("Job is running, bootstrapping nodes") - - start_sshd_daemon() - - if current_host != main_host: - bootstrap_worker_node(main_host, current_host) - else: - sorted_hosts = json.loads(os.environ["SM_HOSTS"]) - worker_hosts = [host for host in sorted_hosts if host != main_host] - - bootstrap_master_node(worker_hosts) - else: - logger.info("Job ended, writing status file to workers") - - if current_host == main_host: - sorted_hosts = json.loads(os.environ["SM_HOSTS"]) - worker_hosts = [host for host in sorted_hosts if host != main_host] - - write_status_file_to_workers(worker_hosts) - except Exception as e: # pylint: disable=broad-except - logger.exception("Error encountered while bootstrapping runtime environment: %s", e) - - _write_failure_reason_file(str(e)) - - sys.exit(DEFAULT_FAILURE_CODE) - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py b/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py deleted file mode 100644 index 5f00317c23..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +++ /dev/null @@ -1,554 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""SageMaker runtime environment module. This must be kept independent of SageMaker PySDK""" - -from __future__ import absolute_import - - -import logging -import sys -import shlex -import os -import subprocess -import time -import dataclasses -import json - - -class _UTCFormatter(logging.Formatter): - """Class that overrides the default local time provider in log formatter.""" - - converter = time.gmtime - - -def get_logger(): - """Return a logger with the name 'sagemaker'""" - sagemaker_logger = logging.getLogger("sagemaker.remote_function") - if len(sagemaker_logger.handlers) == 0: - sagemaker_logger.setLevel(logging.INFO) - handler = logging.StreamHandler() - formatter = _UTCFormatter("%(asctime)s %(name)s %(levelname)-8s %(message)s") - handler.setFormatter(formatter) - sagemaker_logger.addHandler(handler) - # don't stream logs with the root logger handler - sagemaker_logger.propagate = 0 - - return sagemaker_logger - - -logger = get_logger() - - -@dataclasses.dataclass -class _DependencySettings: - """Dependency settings for the remote function. - - Instructs the runtime environment script on how to handle dependencies. - If ``dependency_file`` is set, the runtime environment script will attempt - to install the dependencies. If ``dependency_file`` is not set, the runtime - environment script will assume no dependencies are required. 
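The ``_UTCFormatter`` defined above is the standard logging recipe of swapping the formatter's time converter so records are timestamped in UTC rather than container-local time. A self-contained sketch of the same setup, using a demo logger name rather than the SDK's:

.. code-block:: python

    import logging
    import time

    class _UTCFormatter(logging.Formatter):
        """Same trick as above: swap the formatter's time converter to UTC."""
        converter = time.gmtime

    handler = logging.StreamHandler()
    handler.setFormatter(_UTCFormatter("%(asctime)s %(name)s %(levelname)-8s %(message)s"))
    demo_logger = logging.getLogger("utc_formatter_demo")   # demo name, not the SDK's logger
    demo_logger.setLevel(logging.INFO)
    demo_logger.addHandler(handler)
    demo_logger.propagate = False    # keep records away from the root handler
    demo_logger.info("timestamps on this record are in UTC")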
- """ - - dependency_file: str = None - - def to_string(self): - """Converts the dependency settings to a string.""" - return json.dumps(dataclasses.asdict(self)) - - @staticmethod - def from_string(dependency_settings_string): - """Converts a json string to dependency settings. - - Args: - dependency_settings_string (str): The json string to convert. - """ - if dependency_settings_string is None: - return None - dependency_settings_dict = json.loads(dependency_settings_string) - return _DependencySettings(dependency_settings_dict.get("dependency_file")) - - @staticmethod - def from_dependency_file_path(dependency_file_path): - """Converts a dependency file path to dependency settings. - - Args: - dependency_file_path (str): The path to the dependency file. - """ - if dependency_file_path is None: - return _DependencySettings() - if dependency_file_path == "auto_capture": - return _DependencySettings("env_snapshot.yml") - return _DependencySettings(os.path.basename(dependency_file_path)) - - -class RuntimeEnvironmentManager: - """Runtime Environment Manager class to manage runtime environment.""" - - def _validate_path(self, path: str) -> str: - """Validate and sanitize file path to prevent path traversal attacks. - - Args: - path (str): The file path to validate - - Returns: - str: The validated absolute path - - Raises: - ValueError: If the path is invalid or contains suspicious patterns - """ - if not path: - raise ValueError("Path cannot be empty") - - # Get absolute path to prevent path traversal - abs_path = os.path.abspath(path) - - # Check for null bytes (common in path traversal attacks) - if '\x00' in path: - raise ValueError(f"Invalid path contains null byte: {path}") - - return abs_path - - def _validate_env_name(self, env_name: str) -> None: - """Validate conda environment name to prevent command injection. - - Args: - env_name (str): The environment name to validate - - Raises: - ValueError: If the environment name contains invalid characters - """ - if not env_name: - raise ValueError("Environment name cannot be empty") - - # Allow only alphanumeric, underscore, and hyphen - import re - if not re.match(r'^[a-zA-Z0-9_-]+$', env_name): - raise ValueError( - f"Invalid environment name '{env_name}'. " - "Only alphanumeric characters, underscores, and hyphens are allowed." - ) - - def snapshot(self, dependencies: str = None) -> str: - """Creates snapshot of the user's environment - - If a req.txt or conda.yml file is provided, it verifies their existence and - returns the local file path - If ``auto_capture`` is set, this method will take the snapshot of - user's dependencies installed in the local runtime. - Current support for ``auto_capture``: - * conda env, generate a yml file and return it's local path - - Args: - dependencies (str): Local path where dependencies file exists. - - Returns: - file path of the existing or generated dependencies file - """ - - # No additional dependencies specified - if dependencies is None: - return None - - if dependencies == "auto_capture": - return self._capture_from_local_runtime() - - # Dependencies specified as either req.txt or conda_env.yml - if ( - dependencies.endswith(".txt") - or dependencies.endswith(".yml") - or dependencies.endswith(".yaml") - ): - self._is_file_exists(dependencies) - return dependencies - - raise ValueError(f'Invalid dependencies provided: "{dependencies}"') - - def _capture_from_local_runtime(self) -> str: - """Generates dependencies list from the user's local runtime. 
- - Raises RuntimeEnvironmentError if not able to. - - Currently supports: conda environments - """ - - # Try to capture dependencies from the conda environment, if any. - conda_env_name = self._get_active_conda_env_name() - conda_env_prefix = self._get_active_conda_env_prefix() - if conda_env_name: - logger.info("Found conda_env_name: '%s'", conda_env_name) - elif conda_env_prefix: - logger.info("Found conda_env_prefix: '%s'", conda_env_prefix) - else: - raise ValueError("No conda environment seems to be active.") - - if conda_env_name == "base": - logger.warning( - "We recommend using an environment other than base to " - "isolate your project dependencies from conda dependencies" - ) - - local_dependencies_path = os.path.join(os.getcwd(), "env_snapshot.yml") - self._export_conda_env_from_prefix(conda_env_prefix, local_dependencies_path) - - return local_dependencies_path - - def _get_active_conda_env_prefix(self) -> str: - """Returns the conda prefix from the set environment variable. None otherwise.""" - return os.getenv("CONDA_PREFIX") - - def _get_active_conda_env_name(self) -> str: - """Returns the conda environment name from the set environment variable. None otherwise.""" - return os.getenv("CONDA_DEFAULT_ENV") - - def bootstrap( - self, local_dependencies_file: str, client_python_version: str, conda_env: str = None - ): - """Bootstraps the runtime environment by installing the additional dependencies if any. - - Args: - local_dependencies_file (str): path where dependencies file exists. - conda_env (str): conda environment to be activated. Default is None. - - Returns: None - """ - - if local_dependencies_file.endswith(".txt"): - if conda_env: - self._install_req_txt_in_conda_env(conda_env, local_dependencies_file) - self._write_conda_env_to_file(conda_env) - - else: - self._install_requirements_txt(local_dependencies_file, _python_executable()) - - elif local_dependencies_file.endswith(".yml") or local_dependencies_file.endswith(".yaml"): - if conda_env: - self._update_conda_env(conda_env, local_dependencies_file) - else: - conda_env = "sagemaker-runtime-env" - self._create_conda_env(conda_env, local_dependencies_file) - self._validate_python_version(client_python_version, conda_env) - self._write_conda_env_to_file(conda_env) - - def run_pre_exec_script(self, pre_exec_script_path: str): - """Runs script of pre-execution commands if existing. - - Args: - pre_exec_script_path (str): Path to pre-execution command script file. - """ - if os.path.isfile(pre_exec_script_path): - logger.info("Running pre-execution commands in '%s'", pre_exec_script_path) - return_code, error_logs = _run_pre_execution_command_script(pre_exec_script_path) - - if return_code: - error_message = ( - f"Encountered error while running pre-execution commands. Reason: {error_logs}" - ) - raise RuntimeEnvironmentError(error_message) - else: - logger.info( - "'%s' does not exist. Assuming no pre-execution commands to run", - pre_exec_script_path, - ) - - def change_dir_permission(self, dirs: list, new_permission: str): - """Change the permission of given directories - - Args: - dirs (list[str]): A list of directories for permission update. - new_permission (str): The new permission for the given directories. 
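The ``bootstrap`` dispatch above branches purely on the dependency file's extension and on whether a conda environment is supplied. A toy sketch that returns a description of the branch taken instead of shelling out to pip/conda, just to make the decision table explicit (the returned strings are illustrative, not SDK output):

.. code-block:: python

    def plan_bootstrap(dependencies_file, conda_env=None):
        """Describe the action the bootstrap dispatch above would take."""
        if dependencies_file.endswith(".txt"):
            if conda_env:
                return "pip install -r inside conda env '{}'".format(conda_env)
            return "pip install -r with the current python executable"
        if dependencies_file.endswith((".yml", ".yaml")):
            if conda_env:
                return "conda env update '{}'".format(conda_env)
            return "conda env create 'sagemaker-runtime-env'"
        raise ValueError("Unsupported dependencies file: {}".format(dependencies_file))

    assert plan_bootstrap("requirements.txt") == "pip install -r with the current python executable"
    assert plan_bootstrap("env_snapshot.yml", conda_env="base") == "conda env update 'base'"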
- """ - - _ERROR_MSG_PREFIX = "Failed to change directory permissions due to: " - command = ["sudo", "chmod", "-R", new_permission] + dirs - logger.info("Executing '%s'.", " ".join(command)) - - try: - subprocess.run(command, check=True, stderr=subprocess.PIPE) - except subprocess.CalledProcessError as called_process_err: - err_msg = called_process_err.stderr.decode("utf-8") - raise RuntimeEnvironmentError(f"{_ERROR_MSG_PREFIX} {err_msg}") - except FileNotFoundError as file_not_found_err: - if "[Errno 2] No such file or directory: 'sudo'" in str(file_not_found_err): - raise RuntimeEnvironmentError( - f"{_ERROR_MSG_PREFIX} {file_not_found_err}. " - "Please contact the image owner to install 'sudo' in the job container " - "and provide sudo privilege to the container user." - ) - raise RuntimeEnvironmentError(file_not_found_err) - - def _is_file_exists(self, dependencies): - """Check whether the dependencies file exists at the given location. - - Raises error if not - """ - if not os.path.isfile(dependencies): - raise ValueError(f'No dependencies file named "{dependencies}" was found.') - - def _install_requirements_txt(self, local_path, python_executable): - """Install requirements.txt file""" - # Validate path to prevent command injection - validated_path = self._validate_path(local_path) - cmd = [python_executable, "-m", "pip", "install", "-r", validated_path, "-U"] - logger.info("Running command: '%s' in the dir: '%s' ", " ".join(cmd), os.getcwd()) - _run_shell_cmd(cmd) - logger.info("Command %s ran successfully", " ".join(cmd)) - - def _create_conda_env(self, env_name, local_path): - """Create conda env using conda yml file""" - # Validate inputs to prevent command injection - self._validate_env_name(env_name) - validated_path = self._validate_path(local_path) - - cmd = [self._get_conda_exe(), "env", "create", "-n", env_name, "--file", validated_path] - logger.info("Creating conda environment %s using: %s.", env_name, " ".join(cmd)) - _run_shell_cmd(cmd) - logger.info("Conda environment %s created successfully.", env_name) - - def _install_req_txt_in_conda_env(self, env_name, local_path): - """Install requirements.txt in the given conda environment""" - # Validate inputs to prevent command injection - self._validate_env_name(env_name) - validated_path = self._validate_path(local_path) - - cmd = [self._get_conda_exe(), "run", "-n", env_name, "pip", "install", "-r", validated_path, "-U"] - logger.info("Activating conda env and installing requirements: %s", " ".join(cmd)) - _run_shell_cmd(cmd) - logger.info("Requirements installed successfully in conda env %s", env_name) - - def _update_conda_env(self, env_name, local_path): - """Update conda env using conda yml file""" - # Validate inputs to prevent command injection - self._validate_env_name(env_name) - validated_path = self._validate_path(local_path) - - cmd = [self._get_conda_exe(), "env", "update", "-n", env_name, "--file", validated_path] - logger.info("Updating conda env: %s", " ".join(cmd)) - _run_shell_cmd(cmd) - logger.info("Conda env %s updated succesfully", env_name) - - def _export_conda_env_from_prefix(self, prefix, local_path): - """Export the conda env to a conda yml file""" - # Validate inputs to prevent command injection - validated_prefix = self._validate_path(prefix) - validated_path = self._validate_path(local_path) - - cmd = [self._get_conda_exe(), "env", "export", "-p", validated_prefix, "--no-builds"] - logger.info("Exporting conda environment: %s", " ".join(cmd)) - - # Capture output and write to file instead of 
using shell redirection - try: - process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=False - ) - output, error_output = process.communicate() - return_code = process.wait() - - if return_code: - error_message = f"Encountered error while running command '{' '.join(cmd)}'. Reason: {error_output.decode('utf-8')}" - raise RuntimeEnvironmentError(error_message) - - # Write the captured output to the file - with open(validated_path, 'w') as f: - f.write(output.decode('utf-8')) - - logger.info("Conda environment %s exported successfully", validated_prefix) - except Exception as e: - raise RuntimeEnvironmentError(f"Failed to export conda environment: {str(e)}") - - def _write_conda_env_to_file(self, env_name): - """Writes conda env to the text file""" - - file_name = "remote_function_conda_env.txt" - file_path = os.path.join(os.getcwd(), file_name) - with open(file_path, "w") as output_file: - output_file.write(env_name) - - def _get_conda_exe(self): - """Checks whether conda or mamba is available to use""" - - if not subprocess.Popen(["which", "mamba"]).wait(): - return "mamba" - if not subprocess.Popen(["which", "conda"]).wait(): - return "conda" - raise ValueError("Neither conda nor mamba is installed on the image") - - def _python_version_in_conda_env(self, env_name): - """Returns python version inside a conda environment""" - cmd = f"{self._get_conda_exe()} run -n {env_name} python --version" - try: - output = ( - subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT) - .decode("utf-8") - .strip() - ) - # convert 'Python 3.7.16' to [3, 7, 16] - version = output.split("Python ")[1].split(".") - return version[0] + "." + version[1] - except subprocess.CalledProcessError as e: - raise RuntimeEnvironmentError(e.output) - - def _current_python_version(self): - """Returns the current python version where program is running""" - - return f"{sys.version_info.major}.{sys.version_info.minor}".strip() - - def _current_sagemaker_pysdk_version(self): - """Returns the current sagemaker python sdk version where program is running""" - try: - from importlib import metadata - - return metadata.version("sagemaker") - except Exception: - return "3.0.0.dev0" # Development version fallback - - def _validate_python_version(self, client_python_version: str, conda_env: str = None): - """Validate the python version - - Validates if the python version where remote function runs - matches the one used on client side. - """ - if conda_env: - job_python_version = self._python_version_in_conda_env(conda_env) - else: - job_python_version = self._current_python_version() - if client_python_version.strip() != job_python_version.strip(): - raise RuntimeEnvironmentError( - f"Python version found in the container is '{job_python_version}' which " - f"does not match python version '{client_python_version}' on the local client. " - f"Please make sure that the python version used in the training container " - f"is same as the local python version." - ) - - def _validate_sagemaker_pysdk_version(self, client_sagemaker_pysdk_version): - """Validate the sagemaker python sdk version - - Validates if the sagemaker python sdk version where remote function runs - matches the one used on client side. - Otherwise, log a warning to call out that unexpected behaviors - may occur in this case. 
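The python-version check above compares only major.minor, reducing ``python --version`` output such as ``Python 3.7.16`` to ``3.7``. A small demonstration of that parsing and of the client-side value it is compared against:

.. code-block:: python

    import sys

    def major_minor(python_version_output):
        """Reduce `python --version` output such as 'Python 3.10.14' to '3.10',
        the granularity at which the client/job version check above compares."""
        version = python_version_output.split("Python ")[1].split(".")
        return version[0] + "." + version[1]

    assert major_minor("Python 3.10.14") == "3.10"
    assert major_minor("Python 3.7.16") == "3.7"

    client_version = "{}.{}".format(sys.version_info.major, sys.version_info.minor)
    # A mismatch between client_version and major_minor(...) inside the job container
    # is what triggers the RuntimeEnvironmentError raised above.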
- """ - job_sagemaker_pysdk_version = self._current_sagemaker_pysdk_version() - if ( - client_sagemaker_pysdk_version - and client_sagemaker_pysdk_version != job_sagemaker_pysdk_version - ): - logger.warning( - "Inconsistent sagemaker versions found: " - "sagemaker python sdk version found in the container is " - "'%s' which does not match the '%s' on the local client. " - "Please make sure that the sagemaker version used in the training container " - "is the same as the local sagemaker version in case of unexpected behaviors.", - job_sagemaker_pysdk_version, - client_sagemaker_pysdk_version, - ) - - -def _run_and_get_output_shell_cmd(cmd: str) -> str: - """Run and return the output of the given shell command""" - return subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT).decode("utf-8") - - -def _run_pre_execution_command_script(script_path: str): - """This method runs a given shell script using subprocess - - Raises RuntimeEnvironmentError if the shell script fails - """ - current_dir = os.path.dirname(script_path) - - process = subprocess.Popen( - ["/bin/bash", "-eu", script_path], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - cwd=current_dir, - ) - - _log_output(process) - error_logs = _log_error(process) - return_code = process.wait() - - return return_code, error_logs - - -def _run_shell_cmd(cmd: list): - """This method runs a given shell command using subprocess - - Args: - cmd (list): Command and arguments as a list (e.g., ['pip', 'install', '-r', 'requirements.txt']) - - Raises: - RuntimeEnvironmentError: If the command fails - ValueError: If cmd is not a list - """ - if not isinstance(cmd, list): - raise ValueError("Command must be a list of arguments for security reasons") - - process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False) - - _log_output(process) - error_logs = _log_error(process) - return_code = process.wait() - if return_code: - error_message = f"Encountered error while running command '{' '.join(cmd)}'. Reason: {error_logs}" - raise RuntimeEnvironmentError(error_message) - - -def _log_output(process): - """This method takes in Popen process and logs the output of that process""" - with process.stdout as pipe: - for line in iter(pipe.readline, b""): - logger.info(str(line, "UTF-8")) - - -def _log_error(process): - """This method takes in Popen process and logs the error of that process. - - Returns those logs as a string - """ - - error_logs = "" - with process.stderr as pipe: - for line in iter(pipe.readline, b""): - error_str = str(line, "UTF-8") - if "ERROR:" in error_str: - logger.error(error_str) - else: - logger.warning(error_str) - error_logs = error_logs + error_str - - return error_logs - - -def _python_executable(): - """Return the real path for the Python executable, if it exists. - - Return RuntimeEnvironmentError otherwise. - - Returns: - (str): The real path of the current Python executable. 
- """ - if not sys.executable: - raise RuntimeEnvironmentError( - "Failed to retrieve the path for the Python executable binary" - ) - return sys.executable - - -class RuntimeEnvironmentError(Exception): - """The base exception class for bootstrap env excepitons""" - - def __init__(self, message): - self.message = message - super().__init__(self.message) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/spark_app.py b/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/spark_app.py deleted file mode 100644 index 21eef068b9..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/spark_app.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""This is a simple scrip of spark which invokes the pickled remote function""" -from __future__ import absolute_import - -from sagemaker.core.remote_function import invoke_function - -invoke_function.main() diff --git a/sagemaker-core/src/sagemaker/core/remote_function/spark_config.py b/sagemaker-core/src/sagemaker/core/remote_function/spark_config.py deleted file mode 100644 index 6b25d5da8b..0000000000 --- a/sagemaker-core/src/sagemaker/core/remote_function/spark_config.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""This module is used to define the Spark job config to remote function.""" -from __future__ import absolute_import - -from typing import Optional, List, Dict, Union -import attr -from urllib.parse import urlparse -from sagemaker.core.workflow import is_pipeline_variable - - -def _validate_configuration(instance, attribute, configuration): - # pylint: disable=unused-argument - """This is the helper method to validate the spark configuration""" - if configuration: - SparkConfigUtils.validate_configuration(configuration=configuration) - - -def _validate_s3_uri(instance, attribute, s3_uri): - # pylint: disable=unused-argument - """This is the helper method to validate the s3 uri""" - if s3_uri: - SparkConfigUtils.validate_s3_uri(s3_uri) - - -@attr.s(frozen=True) -class SparkConfig: - """This is the class to initialize the spark configurations for remote function - - Attributes: - submit_jars (Optional[List[str]]): A list which contains paths to the jars which - are going to be submitted to Spark job. The location can be a valid s3 uri or - local path to the jar. Defaults to ``None``. 
- submit_py_files (Optional[List[str]]): A list which contains paths to the python - files which are going to be submitted to the Spark job. The location can be a - valid s3 uri or local path to the python file. Defaults to ``None``. - submit_files (Optional[List[str]]): A list which contains paths to the files which - are going to be submitted to the Spark job. The location can be a valid s3 uri or - local path to the file. Defaults to ``None``. - configuration (list[dict] or dict): Configuration for Hadoop, Spark, or Hive. - List or dictionary of EMR-style classifications. - https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html - spark_event_logs_uri (str): S3 path where Spark application events will - be published to. - """ - - submit_jars: Optional[List[str]] = attr.ib(default=None) - submit_py_files: Optional[List[str]] = attr.ib(default=None) - submit_files: Optional[List[str]] = attr.ib(default=None) - configuration: Optional[Union[List[Dict], Dict]] = attr.ib( - default=None, validator=_validate_configuration - ) - spark_event_logs_uri: Optional[str] = attr.ib(default=None, validator=_validate_s3_uri) - - -class SparkConfigUtils: - """Util class for spark configurations""" - - _valid_configuration_keys = ["Classification", "Properties", "Configurations"] - _valid_configuration_classifications = [ - "core-site", - "hadoop-env", - "hadoop-log4j", - "hive-env", - "hive-log4j", - "hive-exec-log4j", - "hive-site", - "spark-defaults", - "spark-env", - "spark-log4j", - "spark-hive-site", - "spark-metrics", - "yarn-env", - "yarn-site", - "export", - ] - - @staticmethod - def validate_configuration(configuration: Dict): - """Validates the user-provided Hadoop/Spark/Hive configuration. - - This ensures that the list or dictionary the user provides will serialize to - JSON matching the schema of EMR's application configuration. - - Args: - configuration (Dict): A dict that contains the configuration overrides to - the default values. For more information, please visit: - https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html - """ - emr_configure_apps_url = ( - "https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html" - ) - if isinstance(configuration, dict): - keys = configuration.keys() - if "Classification" not in keys or "Properties" not in keys: - raise ValueError( - f"Missing one or more required keys in configuration dictionary " - f"{configuration}. Please see {emr_configure_apps_url} for more information" - ) - - for key in keys: - if key not in SparkConfigUtils._valid_configuration_keys: - raise ValueError( - f"Invalid key: {key}. " - f"Must be one of {SparkConfigUtils._valid_configuration_keys}. " - f"Please see {emr_configure_apps_url} for more information." - ) - if key == "Classification": - if ( - configuration[key] - not in SparkConfigUtils._valid_configuration_classifications - ): - raise ValueError( - f"Invalid classification: {configuration[key]}. Must be one of " - f"{SparkConfigUtils._valid_configuration_classifications}" - ) - - if isinstance(configuration, list): - for item in configuration: - SparkConfigUtils.validate_configuration(item) - - # TODO (guoqioa@): method only checks urlparse scheme, need to perform deep s3 validation - @staticmethod - def validate_s3_uri(spark_output_s3_path): - """Validate whether the URI uses an S3 scheme. - - In the future, this validation will perform deeper S3 validation. - - Args: - spark_output_s3_path (str): The URI of the Spark output S3 Path.
- """ - if is_pipeline_variable(spark_output_s3_path): - return - - if urlparse(spark_output_s3_path).scheme != "s3": - raise ValueError( - f"Invalid s3 path: {spark_output_s3_path}. Please enter something like " - "s3://bucket-name/folder-name" - ) diff --git a/sagemaker-core/src/sagemaker/core/training/__init__.py b/sagemaker-core/src/sagemaker/core/training/__init__.py deleted file mode 100644 index 86ce9b7a0f..0000000000 --- a/sagemaker-core/src/sagemaker/core/training/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Training configuration and utilities.""" -from __future__ import absolute_import diff --git a/sagemaker-core/src/sagemaker/core/training/configs.py b/sagemaker-core/src/sagemaker/core/training/configs.py deleted file mode 100644 index ad2232d630..0000000000 --- a/sagemaker-core/src/sagemaker/core/training/configs.py +++ /dev/null @@ -1,333 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""This module provides the configuration classes used in ``sagemaker.modules``. - -Some of these classes are re-exported from ``sagemaker.core.shapes``. For convinence, -users can import these classes directly from ``sagemaker.modules.configs``. - -For more documentation on ``sagemaker.core.shapes``, see: - - https://sagemaker-core.readthedocs.io/en/stable/#sagemaker-core-shapes -""" - -from __future__ import absolute_import - -from typing import Optional, Union -from pydantic import BaseModel, model_validator, ConfigDict - -import sagemaker.core.shapes as shapes -from sagemaker.core.helper.pipeline_variable import StrPipeVar - -# TODO: Can we add custom logic to some of these to set better defaults? 
-from sagemaker.core.shapes import ( - StoppingCondition, - RetryStrategy, - Channel, - ShuffleConfig, - DataSource, - S3DataSource, - FileSystemDataSource, - TrainingImageConfig, - TrainingRepositoryAuthConfig, - Tag, - InfraCheckConfig, - RemoteDebugConfig, - SessionChainingConfig, - InstanceGroup, - HubAccessConfig, - ModelAccessConfig, - MetricDefinition, - DatasetSource, -) - -from sagemaker.core.training.utils import convert_unassigned_to_none - -__all__ = [ - "BaseConfig", - "SourceCode", - "StoppingCondition", - "RetryStrategy", - "OutputDataConfig", - "Channel", - "ShuffleConfig", - "DataSource", - "S3DataSource", - "FileSystemDataSource", - "TrainingImageConfig", - "TrainingRepositoryAuthConfig", - "Tag", - "InfraCheckConfig", - "RemoteDebugConfig", - "SessionChainingConfig", - "InstanceGroup", - "TensorBoardOutputConfig", - "CheckpointConfig", - "HubAccessConfig", - "ModelAccessConfig", - "Compute", - "Networking", - "InputData", - "MetricDefinition", - "DatasetSource", -] - - -class BaseConfig(BaseModel): - """BaseConfig""" - - model_config = ConfigDict(validate_assignment=True, extra="forbid") - - -class SourceCode(BaseConfig): - """SourceCode. - - The SourceCode class allows the user to specify the source code location, dependencies, - entry script, or commands to be executed in the training job container. - - Parameters: - source_dir (Optional[StrPipeVar]): - The local directory, s3 uri, or path to tar.gz file stored locally or in s3 that - contains the source code to be used in the training job container. - requirements (Optional[StrPipeVar]): - The path within ``source_dir`` to a ``requirements.txt`` file. If specified, the listed - requirements will be installed in the training job container. - entry_script (Optional[StrPipeVar]): - The path within ``source_dir`` to the entry script that will be executed in the training - job container. If not specified, command must be provided. - command (Optional[StrPipeVar]): - The command(s) to execute in the training job container. Example: "python my_script.py". - If not specified, entry_script must be provided. - """ - - source_dir: Optional[StrPipeVar] = None - requirements: Optional[StrPipeVar] = None - entry_script: Optional[StrPipeVar] = None - command: Optional[StrPipeVar] = None - - -class OutputDataConfig(shapes.OutputDataConfig): - """OutputDataConfig. - - Provides the configuration for the output data location of the training job. - - Parameters: - s3_output_path (Optional[StrPipeVar]): - The S3 URI where the output data will be stored. This is the location where the - training job will save its output data, such as model artifacts and logs. - kms_key_id (Optional[StrPipeVar]): - The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that - SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side - encryption. - compression_type (Optional[StrPipeVar]): - The model output compression type. Select None to output an uncompressed model, - recommended for large model outputs. Defaults to gzip. - """ - - s3_output_path: Optional[StrPipeVar] = None - kms_key_id: Optional[StrPipeVar] = None - compression_type: Optional[StrPipeVar] = None - - -class Compute(shapes.ResourceConfig): - """Compute. - - The Compute class is a subclass of ``sagemaker.core.shapes.ResourceConfig`` - and allows the user to specify the compute resources for the training job. - - Parameters: - instance_type (Optional[StrPipeVar]): - The ML compute instance type. 
For information about available instance types, - see https://aws.amazon.com/sagemaker/pricing/. - instance_count (Optional[int]): The number of ML compute instances to use. For distributed - training, provide a value greater than 1. - volume_size_in_gb (Optional[int]): - The size of the ML storage volume that you want to provision. ML storage volumes store - model artifacts and incremental states. Training algorithms might also use the ML - storage volume for scratch space. Default: 30 - volume_kms_key_id (Optional[StrPipeVar]): - The Amazon Web Services KMS key that SageMaker uses to encrypt data on the storage - volume attached to the ML compute instance(s) that run the training job. - keep_alive_period_in_seconds (Optional[int]): - The duration of time in seconds to retain configured resources in a warm pool for - subsequent training jobs. - instance_groups (Optional[List[InstanceGroup]]): - A list of instance groups for heterogeneous clusters to be used in the training job. - training_plan_arn (Optional[StrPipeVar]): - The Amazon Resource Name (ARN) of the training plan to use for this resource configuration. - enable_managed_spot_training (Optional[bool]): - To train models using managed spot training, choose True. Managed spot training - provides a fully managed and scalable infrastructure for training machine learning - models. this option is useful when training jobs can be interrupted and when there - is flexibility when the training job is run. - """ - - volume_size_in_gb: Optional[int] = 30 - enable_managed_spot_training: Optional[bool] = None - - @model_validator(mode="after") - def _model_validator(self) -> "Compute": - """Convert Unassigned values to None.""" - return convert_unassigned_to_none(self) - - def _to_resource_config(self) -> shapes.ResourceConfig: - """Convert to a sagemaker.core.shapes.ResourceConfig object.""" - compute_config_dict = self.model_dump() - resource_config_fields = set(shapes.ResourceConfig.__annotations__.keys()) - filtered_dict = { - k: v - for k, v in compute_config_dict.items() - if k in resource_config_fields and v is not None - } - if not filtered_dict: - return None - return shapes.ResourceConfig(**filtered_dict) - - -class Networking(shapes.VpcConfig): - """Networking. - - The Networking class is a subclass of ``sagemaker.core.shapes.VpcConfig`` and - allows the user to specify the networking configuration for the training job. - - Parameters: - security_group_ids (Optional[List[StrPipeVar]]): - The VPC security group IDs, in the form sg-xxxxxxxx. Specify the - security groups for the VPC that is specified in the Subnets field. - subnets (Optional[List[StrPipeVar]]): - The ID of the subnets in the VPC to which you want to connect your - training job or model. - enable_network_isolation (Optional[bool]): - Isolates the training container. No inbound or outbound network calls can be made, - except for calls between peers within a training cluster for distributed training. - If you enable network isolation for training jobs that are configured to use a VPC, - SageMaker downloads and uploads customer data and model artifacts through the - specified VPC, but the training container does not have network access. - enable_inter_container_traffic_encryption (Optional[bool]): - To encrypt all communications between ML compute instances in distributed training - choose True. Encryption provides greater security for distributed training, but - training might take longer. 
How long it takes depends on the amount of - communication between compute instances, especially if you use a deep learning - algorithm in distributed training. - """ - - security_group_ids: Optional[list[StrPipeVar]] = None - subnets: Optional[list[StrPipeVar]] = None - enable_network_isolation: Optional[bool] = None - enable_inter_container_traffic_encryption: Optional[bool] = None - - @model_validator(mode="after") - def _model_validator(self) -> "Networking": - """Convert Unassigned values to None.""" - return convert_unassigned_to_none(self) - - def _to_vpc_config(self) -> shapes.VpcConfig: - """Convert to a sagemaker.core.shapes.VpcConfig object.""" - compute_config_dict = self.model_dump() - vpc_config_fields = set(shapes.VpcConfig.__annotations__.keys()) - filtered_dict = { - k: v for k, v in compute_config_dict.items() if k in vpc_config_fields and v is not None - } - if not filtered_dict: - return None - return shapes.VpcConfig(**filtered_dict) - - -class InputData(BaseConfig): - """InputData. - - This config allows the user to specify an input data source for the training job. - - Will be found at ``/opt/ml/input/data/`` within the training container. - For convience, can be referenced inside the training container like: - - .. code:: python - - import os - input_data_dir = os.environ['SM_CHANNEL_'] - - Parameters: - channel_name (StrPipeVar): - The name of the input data source channel. - data_source (Union[StrPipeVar, S3DataSource, FileSystemDataSource, DatasetSource]): - The data source for the channel. Can be an S3 URI string, local file path string, - S3DataSource object, FileSystemDataSource object, DatasetSource object, or a - pipeline variable (Properties) from a previous step. - content_type (StrPipeVar): - The MIME type of the data. - """ - - channel_name: StrPipeVar = None - data_source: Union[StrPipeVar, FileSystemDataSource, S3DataSource, DatasetSource] = None - content_type: StrPipeVar = None - - -class OutputDataConfig(shapes.OutputDataConfig): - """OutputDataConfig. - - The OutputDataConfig class is a subclass of ``sagemaker.core.shapes.OutputDataConfig`` - and allows the user to specify the output data configuration for the training job. - - Parameters: - s3_output_path (Optional[StrPipeVar]): - The S3 URI where the output data will be stored. This is the location where the - training job will save its output data, such as model artifacts and logs. - kms_key_id (Optional[StrPipeVar]): - The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that - SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side - encryption. - compression_type (Optional[StrPipeVar]): - The model output compression type. Select `NONE` to output an uncompressed model, - recommended for large model outputs. Defaults to `GZIP`. - """ - - s3_output_path: Optional[StrPipeVar] = None - kms_key_id: Optional[StrPipeVar] = None - compression_type: Optional[StrPipeVar] = None - - -class TensorBoardOutputConfig(shapes.TensorBoardOutputConfig): - """TensorBoardOutputConfig. - - The TensorBoardOutputConfig class is a subclass of ``sagemaker.core.shapes.TensorBoardOutputConfig`` - and allows the user to specify the storage locations for the Amazon SageMaker - Debugger TensorBoard. - - Parameters: - s3_output_path (Optional[StrPipeVar]): - Path to Amazon S3 storage location for TensorBoard output. 
If not specified, will - default to - ``s3://////tensorboard-output`` - local_path (Optional[StrPipeVar]): - Path to local storage location for tensorBoard output. Defaults to /opt/ml/output/tensorboard. - """ - - s3_output_path: Optional[StrPipeVar] = None - local_path: Optional[StrPipeVar] = "/opt/ml/output/tensorboard" - - -class CheckpointConfig(shapes.CheckpointConfig): - """CheckpointConfig. - - The CheckpointConfig class is a subclass of ``sagemaker.core.shapes.CheckpointConfig`` - and allows the user to specify the checkpoint configuration for the training job. - - Parameters: - s3_uri (Optional[StrPipeVar]): - Path to Amazon S3 storage location for the Checkpoint data. If not specified, will - default to - ``s3://////checkpoints`` - local_path (Optional[StrPipeVar]): - The local directory where checkpoints are written. The default directory is /opt/ml/checkpoints. - """ - - s3_uri: Optional[StrPipeVar] = None - local_path: Optional[StrPipeVar] = "/opt/ml/checkpoints" diff --git a/sagemaker-core/src/sagemaker/core/training/constants.py b/sagemaker-core/src/sagemaker/core/training/constants.py deleted file mode 100644 index 76a866581e..0000000000 --- a/sagemaker-core/src/sagemaker/core/training/constants.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Constants module.""" -from __future__ import absolute_import -import os - -DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge" - -SM_CODE = "code" -SM_CODE_CONTAINER_PATH = "/opt/ml/input/data/code" - -SM_DRIVERS = "sm_drivers" -SM_DRIVERS_CONTAINER_PATH = "/opt/ml/input/data/sm_drivers" -SM_DRIVERS_LOCAL_PATH = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "container_drivers" -) - -SOURCE_CODE_JSON = "sourcecode.json" -DISTRIBUTED_JSON = "distributed.json" -TRAIN_SCRIPT = "sm_train.sh" - -DEFAULT_CONTAINER_ENTRYPOINT = ["/bin/bash"] -DEFAULT_CONTAINER_ARGUMENTS = [ - "-c", - f"chmod +x {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT} " - + f"&& {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT}", -] diff --git a/sagemaker-core/src/sagemaker/core/training/utils.py b/sagemaker-core/src/sagemaker/core/training/utils.py deleted file mode 100644 index 67009c9131..0000000000 --- a/sagemaker-core/src/sagemaker/core/training/utils.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
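For reference, a rough reconstruction (not part of the original module) of the single shell command implied by the ``DEFAULT_CONTAINER_ENTRYPOINT`` and ``DEFAULT_CONTAINER_ARGUMENTS`` constants above:

.. code-block:: python

    # Values copied from the removed constants module above.
    SM_DRIVERS_CONTAINER_PATH = "/opt/ml/input/data/sm_drivers"
    TRAIN_SCRIPT = "sm_train.sh"

    entrypoint = ["/bin/bash"]
    arguments = [
        "-c",
        f"chmod +x {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT} "
        f"&& {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT}",
    ]
    # Effective command:
    # /bin/bash -c "chmod +x /opt/ml/input/data/sm_drivers/sm_train.sh \
    #               && /opt/ml/input/data/sm_drivers/sm_train.sh"
    print(entrypoint + arguments)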
-"""Training utilities.""" -from __future__ import absolute_import - -import os -from typing import Any, Literal -from sagemaker.core.utils.utils import Unassigned - - -def convert_unassigned_to_none(instance) -> Any: - """Convert Unassigned values to None for any instance.""" - for name, value in instance.__dict__.items(): - if isinstance(value, Unassigned): - setattr(instance, name, None) - return instance - - -def _is_valid_path(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool: - """Check if the path is a valid local path. - - Args: - path (str): Local path to validate - path_type (Optional(Literal["File", "Directory", "Any"])): The type of the path to validate. - Defaults to "Any". - - Returns: - bool: True if the path is a valid local path, False otherwise - """ - if not os.path.exists(path): - return False - - if path_type == "File": - return os.path.isfile(path) - if path_type == "Directory": - return os.path.isdir(path) - - return path_type == "Any" - - -def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool: - """Check if the path is a valid S3 URI. - - This method checks if the path is a valid S3 URI. If the path_type is specified, - it will also check if the path is a file or a directory. - This method does not check if the S3 bucket or object exists. - - Args: - path (str): S3 URI to validate - path_type (Optional(Literal["File", "Directory", "Any"])): The type of the path to validate. - Defaults to "Any". - - Returns: - bool: True if the path is a valid S3 URI, False otherwise - """ - # Check if the path is a valid S3 URI - if not path.startswith("s3://"): - return False - - if path_type == "File": - # If it's a file, it should not end with a slash - return not path.endswith("/") - if path_type == "Directory": - # If it's a directory, it should end with a slash - return path.endswith("/") - - return path_type == "Any" diff --git a/sagemaker-core/src/sagemaker/core/workflow/execution_variables.py b/sagemaker-core/src/sagemaker/core/workflow/execution_variables.py index efb0b8b6ef..380ad0c280 100644 --- a/sagemaker-core/src/sagemaker/core/workflow/execution_variables.py +++ b/sagemaker-core/src/sagemaker/core/workflow/execution_variables.py @@ -56,7 +56,7 @@ def expr(self) -> RequestType: def _pickleable(self): """The pickleable object that can be passed to a remote function invocation.""" - from sagemaker.core.remote_function.core.pipeline_variables import _ExecutionVariable + from sagemaker.train.remote_function.core.pipeline_variables import _ExecutionVariable return _ExecutionVariable(name=self.name) diff --git a/sagemaker-core/src/sagemaker/core/workflow/parameters.py b/sagemaker-core/src/sagemaker/core/workflow/parameters.py index 90505c99cc..81d6d59d94 100644 --- a/sagemaker-core/src/sagemaker/core/workflow/parameters.py +++ b/sagemaker-core/src/sagemaker/core/workflow/parameters.py @@ -96,7 +96,7 @@ def expr(self) -> Dict[str, str]: def _pickleable(self): """The pickleable object that can be passed to a remote function invocation.""" - from sagemaker.core.remote_function.core.pipeline_variables import ( + from sagemaker.train.remote_function.core.pipeline_variables import ( _ParameterString, _ParameterInteger, _ParameterBoolean, diff --git a/sagemaker-core/src/sagemaker/core/workflow/properties.py b/sagemaker-core/src/sagemaker/core/workflow/properties.py index c9e897e178..366cecfd3f 100644 --- a/sagemaker-core/src/sagemaker/core/workflow/properties.py +++ 
b/sagemaker-core/src/sagemaker/core/workflow/properties.py @@ -137,7 +137,7 @@ def __reduce__(self): def _pickleable(self): """The pickleable object that can be passed to a remote function invocation.""" - from sagemaker.core.remote_function.core.pipeline_variables import _Properties + from sagemaker.train.remote_function.core.pipeline_variables import _Properties prefix = f"Steps.{self.step_name}" full_path = prefix if self.path is None else f"{prefix}.{self.path}" diff --git a/sagemaker-core/src/sagemaker/lineage/__init__.py b/sagemaker-core/src/sagemaker/lineage/__init__.py deleted file mode 100644 index 4d9cec4b6c..0000000000 --- a/sagemaker-core/src/sagemaker/lineage/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Legacy lineage module - compatibility shim. - -This module provides backward compatibility for code using the old -`sagemaker.lineage` import path. All functionality has been moved to -`sagemaker.core.lineage`. - -DEPRECATED: This module is deprecated. Use `sagemaker.core.lineage` instead. -""" -from __future__ import absolute_import - -import warnings - -# Show deprecation warning -warnings.warn( - "The 'sagemaker.lineage' module is deprecated. " "Please use 'sagemaker.core.lineage' instead.", - DeprecationWarning, - stacklevel=2, -) - -# Re-export from core.lineage for backward compatibility -from sagemaker.core.lineage import * # noqa: F401, F403 diff --git a/sagemaker-core/src/sagemaker/lineage/action.py b/sagemaker-core/src/sagemaker/lineage/action.py deleted file mode 100644 index c14ffa2a69..0000000000 --- a/sagemaker-core/src/sagemaker/lineage/action.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Legacy lineage.action module - compatibility shim. - -DEPRECATED: Use `sagemaker.core.lineage.action` instead. -""" -from __future__ import absolute_import - -import warnings - -warnings.warn( - "The 'sagemaker.lineage.action' module is deprecated. " - "Please use 'sagemaker.core.lineage.action' instead.", - DeprecationWarning, - stacklevel=2, -) - -from sagemaker.core.lineage.action import * # noqa: F401, F403 diff --git a/sagemaker-core/src/sagemaker/lineage/artifact.py b/sagemaker-core/src/sagemaker/lineage/artifact.py deleted file mode 100644 index 4d74205fc5..0000000000 --- a/sagemaker-core/src/sagemaker/lineage/artifact.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Legacy lineage.artifact module - compatibility shim. - -DEPRECATED: Use `sagemaker.core.lineage.artifact` instead. -""" -from __future__ import absolute_import - -import warnings - -warnings.warn( - "The 'sagemaker.lineage.artifact' module is deprecated. " - "Please use 'sagemaker.core.lineage.artifact' instead.", - DeprecationWarning, - stacklevel=2, -) - -from sagemaker.core.lineage.artifact import * # noqa: F401, F403 diff --git a/sagemaker-core/src/sagemaker/lineage/context.py b/sagemaker-core/src/sagemaker/lineage/context.py deleted file mode 100644 index d5fe8b3884..0000000000 --- a/sagemaker-core/src/sagemaker/lineage/context.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Legacy lineage.context module - compatibility shim. - -DEPRECATED: Use `sagemaker.core.lineage.context` instead. -""" -from __future__ import absolute_import - -import warnings - -warnings.warn( - "The 'sagemaker.lineage.context' module is deprecated. " - "Please use 'sagemaker.core.lineage.context' instead.", - DeprecationWarning, - stacklevel=2, -) - -from sagemaker.core.lineage.context import * # noqa: F401, F403 diff --git a/sagemaker-core/src/sagemaker/lineage/lineage_trial_component.py b/sagemaker-core/src/sagemaker/lineage/lineage_trial_component.py deleted file mode 100644 index b729166f2c..0000000000 --- a/sagemaker-core/src/sagemaker/lineage/lineage_trial_component.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Legacy lineage.lineage_trial_component module - compatibility shim. - -DEPRECATED: Use `sagemaker.core.lineage.lineage_trial_component` instead. -""" -from __future__ import absolute_import - -import warnings - -warnings.warn( - "The 'sagemaker.lineage.lineage_trial_component' module is deprecated. 
" - "Please use 'sagemaker.core.lineage.lineage_trial_component' instead.", - DeprecationWarning, - stacklevel=2, -) - -from sagemaker.core.lineage.lineage_trial_component import * # noqa: F401, F403 diff --git a/sagemaker-core/tests/unit/lineage/test_query.py b/sagemaker-core/tests/unit/lineage/test_query.py index 2dbeea28d4..e6fdfc5eb9 100644 --- a/sagemaker-core/tests/unit/lineage/test_query.py +++ b/sagemaker-core/tests/unit/lineage/test_query.py @@ -146,7 +146,7 @@ def test_vertex_hash(self): vertex_set = {vertex1, vertex2} assert len(vertex_set) == 1 - @patch("sagemaker.lineage.context.EndpointContext") + @patch("sagemaker.core.lineage.context.EndpointContext") def test_to_lineage_object_context(self, mock_endpoint_context_class): """Test converting vertex to Context""" mock_session = Mock() @@ -164,7 +164,7 @@ def test_to_lineage_object_context(self, mock_endpoint_context_class): # Should call EndpointContext.load for Endpoint source assert result is not None - @patch("sagemaker.lineage.action.Action") + @patch("sagemaker.core.lineage.action.Action") def test_to_lineage_object_action(self, mock_action_class): """Test converting vertex to Action""" mock_session = Mock() diff --git a/sagemaker-core/tests/unit/modules/train/__init__.py b/sagemaker-core/tests/unit/modules/train/__init__.py deleted file mode 100644 index 6549052177..0000000000 --- a/sagemaker-core/tests/unit/modules/train/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. diff --git a/sagemaker-core/tests/unit/modules/train/container_drivers/__init__.py b/sagemaker-core/tests/unit/modules/train/container_drivers/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/sagemaker-core/tests/unit/modules/train/container_drivers/distributed_drivers/__init__.py b/sagemaker-core/tests/unit/modules/train/container_drivers/distributed_drivers/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/sagemaker-core/tests/unit/modules/train/container_drivers/distributed_drivers/test_mpi_utils.py b/sagemaker-core/tests/unit/modules/train/container_drivers/distributed_drivers/test_mpi_utils.py deleted file mode 100644 index ee3d60c6f2..0000000000 --- a/sagemaker-core/tests/unit/modules/train/container_drivers/distributed_drivers/test_mpi_utils.py +++ /dev/null @@ -1,466 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-"""Unit tests for sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils module.""" -from __future__ import absolute_import - -import pytest -import os -import subprocess -import paramiko -from unittest.mock import Mock, patch, MagicMock, call - -from sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils import ( - _write_file_to_host, - write_status_file_to_workers, - _wait_for_status_file, - start_sshd_daemon, - CustomHostKeyPolicy, - _can_connect, - _wait_for_workers, - _wait_for_master, - bootstrap_worker_node, - bootstrap_master_node, - validate_smddprun, - validate_smddpmprun, - write_env_vars_to_file, - get_mpirun_command, - FINISHED_STATUS_FILE, - READY_FILE, - DEFAULT_SSH_PORT, -) - - -class TestWriteFileToHost: - """Test _write_file_to_host function.""" - - @patch("subprocess.run") - def test_write_file_to_host_success(self, mock_run): - """Test successful file write to host.""" - mock_run.return_value = Mock(returncode=0) - - result = _write_file_to_host("algo-1", "/tmp/test.txt") - - assert result is True - mock_run.assert_called_once() - - @patch("subprocess.run") - def test_write_file_to_host_failure(self, mock_run): - """Test failed file write to host.""" - mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") - - result = _write_file_to_host("algo-1", "/tmp/test.txt") - - assert result is False - - -class TestWriteStatusFileToWorkers: - """Test write_status_file_to_workers function.""" - - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._write_file_to_host" - ) - def test_write_status_file_to_workers_success(self, mock_write): - """Test writing status file to workers successfully.""" - mock_write.return_value = True - - write_status_file_to_workers(["algo-1", "algo-2"]) - - assert mock_write.call_count == 2 - - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._write_file_to_host" - ) - @patch("time.sleep") - def test_write_status_file_to_workers_with_retry(self, mock_sleep, mock_write): - """Test writing status file with retry.""" - mock_write.side_effect = [False, False, True] - - write_status_file_to_workers(["algo-1"]) - - assert mock_write.call_count == 3 - - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._write_file_to_host" - ) - @patch("time.sleep") - def test_write_status_file_to_workers_timeout(self, mock_sleep, mock_write): - """Test writing status file timeout.""" - mock_write.return_value = False - - with pytest.raises(TimeoutError): - write_status_file_to_workers(["algo-1"]) - - -class TestWaitForStatusFile: - """Test _wait_for_status_file function.""" - - @patch("os.path.exists") - @patch("time.sleep") - def test_wait_for_status_file_exists(self, mock_sleep, mock_exists): - """Test waiting for status file that exists.""" - mock_exists.return_value = True - - _wait_for_status_file("/tmp/test.txt") - - mock_exists.assert_called_once() - - @patch("os.path.exists") - @patch("time.sleep") - def test_wait_for_status_file_eventually_exists(self, mock_sleep, mock_exists): - """Test waiting for status file that eventually exists.""" - mock_exists.side_effect = [False, False, True] - - _wait_for_status_file("/tmp/test.txt") - - assert mock_exists.call_count == 3 - - -class TestStartSshdDaemon: - """Test start_sshd_daemon function.""" - - @patch("os.path.exists") - @patch("subprocess.Popen") - def test_start_sshd_daemon_success(self, mock_popen, mock_exists): - """Test starting SSH daemon 
successfully.""" - mock_exists.return_value = True - - start_sshd_daemon() - - mock_popen.assert_called_once_with(["/usr/sbin/sshd", "-D"]) - - @patch("os.path.exists") - def test_start_sshd_daemon_not_found(self, mock_exists): - """Test starting SSH daemon when not found.""" - mock_exists.return_value = False - - with pytest.raises(RuntimeError, match="SSH daemon not found"): - start_sshd_daemon() - - -class TestCustomHostKeyPolicy: - """Test CustomHostKeyPolicy class.""" - - def test_custom_host_key_policy_algo_hostname(self): - """Test accepting algo-* hostnames.""" - policy = CustomHostKeyPolicy() - mock_client = Mock() - mock_client.get_host_keys.return_value = Mock() - mock_key = Mock() - mock_key.get_name.return_value = "ssh-rsa" - - # Should not raise exception - policy.missing_host_key(mock_client, "algo-1234", mock_key) - - def test_custom_host_key_policy_unknown_hostname(self): - """Test rejecting unknown hostnames.""" - policy = CustomHostKeyPolicy() - mock_client = Mock() - mock_key = Mock() - - with pytest.raises(paramiko.SSHException): - policy.missing_host_key(mock_client, "unknown-host", mock_key) - - -class TestCanConnect: - """Test _can_connect function.""" - - @patch("paramiko.SSHClient") - def test_can_connect_success(self, mock_ssh_client): - """Test successful connection.""" - mock_client_instance = Mock() - mock_ssh_client.return_value.__enter__.return_value = mock_client_instance - - result = _can_connect("algo-1") - - assert result is True - - @patch("paramiko.SSHClient") - def test_can_connect_failure(self, mock_ssh_client): - """Test failed connection.""" - mock_client_instance = Mock() - mock_client_instance.connect.side_effect = Exception("Connection failed") - mock_ssh_client.return_value.__enter__.return_value = mock_client_instance - - result = _can_connect("algo-1") - - assert result is False - - -class TestWaitForWorkers: - """Test _wait_for_workers function.""" - - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._can_connect" - ) - @patch("os.path.exists") - def test_wait_for_workers_success(self, mock_exists, mock_connect): - """Test waiting for workers successfully.""" - mock_connect.return_value = True - mock_exists.return_value = True - - _wait_for_workers(["algo-1", "algo-2"]) - - assert mock_connect.call_count >= 2 - - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._can_connect" - ) - @patch("os.path.exists") - @patch("time.sleep") - @patch("time.time") - def test_wait_for_workers_timeout(self, mock_time, mock_sleep, mock_exists, mock_connect): - """Test waiting for workers timeout.""" - mock_connect.return_value = False - mock_exists.return_value = False - # Use side_effect with a generator to provide unlimited values - mock_time.side_effect = (i * 200 for i in range(1000)) # Simulate timeout - - with pytest.raises(TimeoutError): - _wait_for_workers(["algo-1"]) - - def test_wait_for_workers_empty_list(self): - """Test waiting for workers with empty list.""" - # Should not raise exception - _wait_for_workers([]) - - -class TestWaitForMaster: - """Test _wait_for_master function.""" - - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._can_connect" - ) - def test_wait_for_master_success(self, mock_connect): - """Test waiting for master successfully.""" - mock_connect.return_value = True - - _wait_for_master("algo-1") - - mock_connect.assert_called() - - @patch( - 
"sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._can_connect" - ) - @patch("time.sleep") - @patch("time.time") - def test_wait_for_master_timeout(self, mock_time, mock_sleep, mock_connect): - """Test waiting for master timeout.""" - mock_connect.return_value = False - # Use side_effect with a generator to provide unlimited values - mock_time.side_effect = (i * 200 for i in range(1000)) # Simulate timeout - - with pytest.raises(TimeoutError): - _wait_for_master("algo-1") - - -class TestBootstrapWorkerNode: - """Test bootstrap_worker_node function.""" - - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._wait_for_master" - ) - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._write_file_to_host" - ) - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._wait_for_status_file" - ) - @patch.dict(os.environ, {"SM_CURRENT_HOST": "algo-2"}) - def test_bootstrap_worker_node(self, mock_wait_status, mock_write, mock_wait_master): - """Test bootstrapping worker node.""" - mock_write.return_value = True - - bootstrap_worker_node("algo-1") - - mock_wait_master.assert_called_once_with("algo-1") - mock_write.assert_called_once() - mock_wait_status.assert_called_once() - - -class TestBootstrapMasterNode: - """Test bootstrap_master_node function.""" - - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils._wait_for_workers" - ) - def test_bootstrap_master_node(self, mock_wait): - """Test bootstrapping master node.""" - bootstrap_master_node(["algo-2", "algo-3"]) - - mock_wait.assert_called_once_with(["algo-2", "algo-3"]) - - -class TestValidateSmddprun: - """Test validate_smddprun function.""" - - @patch("subprocess.run") - def test_validate_smddprun_installed(self, mock_run): - """Test validating smddprun when installed.""" - mock_run.return_value = Mock(stdout="smddprun") - - result = validate_smddprun() - - assert result is True - - @patch("subprocess.run") - def test_validate_smddprun_not_installed(self, mock_run): - """Test validating smddprun when not installed.""" - mock_run.side_effect = subprocess.CalledProcessError(1, "which") - - result = validate_smddprun() - - assert result is False - - -class TestValidateSmddpmprun: - """Test validate_smddpmprun function.""" - - @patch("subprocess.run") - def test_validate_smddpmprun_installed(self, mock_run): - """Test validating smddpmprun when installed.""" - mock_run.return_value = Mock(stdout="smddpmprun") - - result = validate_smddpmprun() - - assert result is True - - @patch("subprocess.run") - def test_validate_smddpmprun_not_installed(self, mock_run): - """Test validating smddpmprun when not installed.""" - mock_run.side_effect = subprocess.CalledProcessError(1, "which") - - result = validate_smddpmprun() - - assert result is False - - -class TestWriteEnvVarsToFile: - """Test write_env_vars_to_file function.""" - - @patch("builtins.open", create=True) - @patch.dict(os.environ, {"TEST_VAR": "test_value", "ANOTHER_VAR": "another_value"}) - def test_write_env_vars_to_file(self, mock_open_func): - """Test writing environment variables to file.""" - mock_file = MagicMock() - mock_open_func.return_value.__enter__.return_value = mock_file - - write_env_vars_to_file() - - mock_open_func.assert_called_once_with("/etc/environment", "a", encoding="utf-8") - assert mock_file.write.called - - -class TestGetMpirunCommand: - """Test get_mpirun_command function.""" - - @patch.dict( - 
os.environ, - {"SM_NETWORK_INTERFACE_NAME": "eth0", "SM_CURRENT_INSTANCE_TYPE": "ml.p3.2xlarge"}, - ) - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils.get_python_executable" - ) - def test_get_mpirun_command_basic(self, mock_python): - """Test getting basic mpirun command.""" - mock_python.return_value = "/usr/bin/python3" - - result = get_mpirun_command( - host_count=2, - host_list=["algo-1", "algo-2"], - num_processes=4, - additional_options=[], - entry_script_path="/opt/ml/code/train.py", - ) - - assert "mpirun" in result - assert "--host" in result - assert "algo-1,algo-2" in result - assert "-np" in result - assert "4" in result - - @patch.dict( - os.environ, - { - "SM_NETWORK_INTERFACE_NAME": "eth0", - "SM_CURRENT_INSTANCE_TYPE": "ml.p4d.24xlarge", - "AWS_ACCESS_KEY_ID": "test_key", - "AWS_SECRET_ACCESS_KEY": "test_secret", - }, - ) - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils.get_python_executable" - ) - def test_get_mpirun_command_with_efa(self, mock_python): - """Test getting mpirun command with EFA instance.""" - mock_python.return_value = "/usr/bin/python3" - - result = get_mpirun_command( - host_count=2, - host_list=["algo-1", "algo-2"], - num_processes=4, - additional_options=[], - entry_script_path="/opt/ml/code/train.py", - ) - - assert "FI_PROVIDER=efa" in result - - @patch.dict( - os.environ, - {"SM_NETWORK_INTERFACE_NAME": "eth0", "SM_CURRENT_INSTANCE_TYPE": "ml.p3.2xlarge"}, - ) - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils.get_python_executable" - ) - def test_get_mpirun_command_with_additional_options(self, mock_python): - """Test getting mpirun command with additional options.""" - mock_python.return_value = "/usr/bin/python3" - - result = get_mpirun_command( - host_count=2, - host_list=["algo-1", "algo-2"], - num_processes=4, - additional_options=["-x", "CUSTOM_VAR"], - entry_script_path="/opt/ml/code/train.py", - ) - - assert "-x" in result - assert "CUSTOM_VAR" in result - - @patch.dict( - os.environ, - { - "SM_NETWORK_INTERFACE_NAME": "eth0", - "SM_CURRENT_INSTANCE_TYPE": "ml.p3.2xlarge", - "AWS_ACCESS_KEY_ID": "test_key", - "AWS_SECRET_ACCESS_KEY": "test_secret", - "AWS_SESSION_TOKEN": "test_token", - }, - ) - @patch( - "sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils.get_python_executable" - ) - def test_get_mpirun_command_with_credentials(self, mock_python): - """Test getting mpirun command with AWS credentials.""" - mock_python.return_value = "/usr/bin/python3" - - result = get_mpirun_command( - host_count=2, - host_list=["algo-1", "algo-2"], - num_processes=4, - additional_options=[], - entry_script_path="/opt/ml/code/train.py", - ) - - assert "AWS_ACCESS_KEY_ID" in result - assert "AWS_SECRET_ACCESS_KEY" in result - assert "AWS_SESSION_TOKEN" in result diff --git a/sagemaker-core/tests/unit/modules/train/test_environment.py b/sagemaker-core/tests/unit/modules/train/test_environment.py deleted file mode 100644 index e80eeef005..0000000000 --- a/sagemaker-core/tests/unit/modules/train/test_environment.py +++ /dev/null @@ -1,386 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. 
This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. - -import pytest -import json -import os -from unittest.mock import Mock, patch, mock_open, MagicMock -from sagemaker.core.modules.train.container_drivers.scripts.environment import ( - num_cpus, - num_gpus, - num_neurons, - deserialize_hyperparameters, - set_env, - mask_sensitive_info, - log_key_value, - log_env_variables, -) - - -class TestEnvironment: - """Test cases for environment module""" - - def test_num_cpus(self): - """Test num_cpus returns positive integer""" - result = num_cpus() - assert isinstance(result, int) - assert result > 0 - - @patch("subprocess.check_output") - def test_num_gpus_with_gpus(self, mock_check_output): - """Test num_gpus when GPUs are available""" - mock_check_output.return_value = b"GPU 0: Tesla V100\nGPU 1: Tesla V100\n" - - result = num_gpus() - assert result == 2 - - @patch("subprocess.check_output") - def test_num_gpus_no_gpus(self, mock_check_output): - """Test num_gpus when no GPUs are available""" - mock_check_output.side_effect = OSError("nvidia-smi not found") - - result = num_gpus() - assert result == 0 - - @patch("subprocess.check_output") - def test_num_gpus_command_error(self, mock_check_output): - """Test num_gpus when command fails""" - import subprocess - - mock_check_output.side_effect = subprocess.CalledProcessError(1, "nvidia-smi") - - result = num_gpus() - assert result == 0 - - @patch("subprocess.check_output") - def test_num_neurons_with_neurons(self, mock_check_output): - """Test num_neurons when Neuron cores are available""" - mock_output = json.dumps([{"nc_count": 2}, {"nc_count": 2}]) - mock_check_output.return_value = mock_output.encode() - - result = num_neurons() - assert result == 4 - - @patch("subprocess.check_output") - def test_num_neurons_no_neurons(self, mock_check_output): - """Test num_neurons when no Neuron cores are available""" - mock_check_output.side_effect = OSError("neuron-ls not found") - - result = num_neurons() - assert result == 0 - - @patch("subprocess.check_output") - def test_num_neurons_command_error(self, mock_check_output): - """Test num_neurons when command fails""" - import subprocess - - error = subprocess.CalledProcessError(1, "neuron-ls") - error.output = b"error=No Neuron devices found" - mock_check_output.side_effect = error - - result = num_neurons() - assert result == 0 - - @patch("subprocess.check_output") - def test_num_neurons_command_error_no_output(self, mock_check_output): - """Test num_neurons when command fails without output""" - import subprocess - - error = subprocess.CalledProcessError(1, "neuron-ls") - error.output = None - mock_check_output.side_effect = error - - result = num_neurons() - assert result == 0 - - def test_deserialize_hyperparameters_simple(self): - """Test deserialize_hyperparameters with simple types""" - hyperparameters = {"learning_rate": "0.001", "epochs": "10", "batch_size": "32"} - - result = deserialize_hyperparameters(hyperparameters) - - assert result["learning_rate"] == 0.001 - assert result["epochs"] == 10 - assert result["batch_size"] == 32 - - def test_deserialize_hyperparameters_complex(self): - """Test deserialize_hyperparameters with complex types""" - hyperparameters = { - "layers": "[128, 64, 32]", - "config": '{"optimizer": "adam", "loss": "mse"}', - "enabled": "true", - } - - result = 
deserialize_hyperparameters(hyperparameters) - - assert result["layers"] == [128, 64, 32] - assert result["config"] == {"optimizer": "adam", "loss": "mse"} - assert result["enabled"] is True - - def test_mask_sensitive_info_with_password(self): - """Test mask_sensitive_info masks password fields""" - data = {"username": "user", "password": "secret123", "api_key": "key123"} - - result = mask_sensitive_info(data) - - assert result["username"] == "user" - assert result["password"] == "******" - assert result["api_key"] == "******" - - def test_mask_sensitive_info_nested(self): - """Test mask_sensitive_info with nested dictionaries""" - data = {"config": {"db_password": "secret", "db_host": "localhost"}} - - result = mask_sensitive_info(data) - - assert result["config"]["db_password"] == "******" - assert result["config"]["db_host"] == "localhost" - - def test_mask_sensitive_info_case_insensitive(self): - """Test mask_sensitive_info is case insensitive""" - data = {"API_KEY": "key123", "Secret_Token": "token123"} - - result = mask_sensitive_info(data) - - assert result["API_KEY"] == "******" - assert result["Secret_Token"] == "******" - - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.logger") - def test_log_key_value_sensitive(self, mock_logger): - """Test log_key_value masks sensitive values""" - log_key_value("password", "secret123") - - mock_logger.info.assert_called_once() - call_args = mock_logger.info.call_args[0] - assert "******" in str(call_args) - - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.logger") - def test_log_key_value_dict(self, mock_logger): - """Test log_key_value with dictionary value""" - log_key_value("config", {"key": "value"}) - - mock_logger.info.assert_called_once() - - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.logger") - def test_log_key_value_json_string(self, mock_logger): - """Test log_key_value with JSON string value""" - log_key_value("config", '{"key": "value"}') - - mock_logger.info.assert_called_once() - - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.logger") - def test_log_key_value_regular(self, mock_logger): - """Test log_key_value with regular value""" - log_key_value("learning_rate", "0.001") - - mock_logger.info.assert_called_once() - - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.logger") - @patch.dict(os.environ, {"TEST_VAR": "test_value"}) - def test_log_env_variables(self, mock_logger): - """Test log_env_variables logs both environment and dict variables""" - env_vars_dict = {"CUSTOM_VAR": "custom_value"} - - log_env_variables(env_vars_dict) - - # Should be called for both os.environ and env_vars_dict - assert mock_logger.info.call_count > 0 - - @patch("builtins.open", new_callable=mock_open) - @patch( - "sagemaker.core.modules.train.container_drivers.scripts.environment.read_source_code_json" - ) - @patch( - "sagemaker.core.modules.train.container_drivers.scripts.environment.read_distributed_json" - ) - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_cpus") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_gpus") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_neurons") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.log_env_variables") - @patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}) - def test_set_env_minimal( - self, - mock_log_env, - mock_neurons, - mock_gpus, 
- mock_cpus, - mock_distributed, - mock_source_code, - mock_file, - ): - """Test set_env with minimal configuration""" - mock_cpus.return_value = 4 - mock_gpus.return_value = 0 - mock_neurons.return_value = 0 - mock_source_code.return_value = None - mock_distributed.return_value = None - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.m5.xlarge", - "hosts": ["algo-1"], - "network_interface_name": "eth0", - } - - input_data_config = {"training": {"S3Uri": "s3://bucket/data"}} - - hyperparameters_config = {"learning_rate": "0.001", "epochs": "10"} - - set_env(resource_config, input_data_config, hyperparameters_config) - - # Verify file was written - mock_file.assert_called_once() - handle = mock_file() - assert handle.write.called - - @patch("builtins.open", new_callable=mock_open) - @patch( - "sagemaker.core.modules.train.container_drivers.scripts.environment.read_source_code_json" - ) - @patch( - "sagemaker.core.modules.train.container_drivers.scripts.environment.read_distributed_json" - ) - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_cpus") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_gpus") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_neurons") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.log_env_variables") - @patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}) - def test_set_env_with_source_code( - self, - mock_log_env, - mock_neurons, - mock_gpus, - mock_cpus, - mock_distributed, - mock_source_code, - mock_file, - ): - """Test set_env with source code configuration""" - mock_cpus.return_value = 4 - mock_gpus.return_value = 1 - mock_neurons.return_value = 0 - mock_source_code.return_value = {"entry_script": "train.py"} - mock_distributed.return_value = None - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.p3.2xlarge", - "hosts": ["algo-1", "algo-2"], - "network_interface_name": "eth0", - } - - input_data_config = {"training": {"S3Uri": "s3://bucket/data"}} - - hyperparameters_config = {"learning_rate": "0.001"} - - set_env(resource_config, input_data_config, hyperparameters_config) - - # Verify file was written - mock_file.assert_called_once() - - @patch("builtins.open", new_callable=mock_open) - @patch( - "sagemaker.core.modules.train.container_drivers.scripts.environment.read_source_code_json" - ) - @patch( - "sagemaker.core.modules.train.container_drivers.scripts.environment.read_distributed_json" - ) - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_cpus") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_gpus") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_neurons") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.log_env_variables") - @patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}) - def test_set_env_with_distributed( - self, - mock_log_env, - mock_neurons, - mock_gpus, - mock_cpus, - mock_distributed, - mock_source_code, - mock_file, - ): - """Test set_env with distributed configuration""" - mock_cpus.return_value = 8 - mock_gpus.return_value = 4 - mock_neurons.return_value = 0 - mock_source_code.return_value = None - mock_distributed.return_value = {"smdistributed": {"dataparallel": {"enabled": True}}} - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.p3.8xlarge", - "hosts": ["algo-1", "algo-2", 
"algo-3"], - "network_interface_name": "eth0", - } - - input_data_config = { - "training": {"S3Uri": "s3://bucket/data"}, - "validation": {"S3Uri": "s3://bucket/validation"}, - } - - hyperparameters_config = {"learning_rate": "0.001", "batch_size": "64"} - - set_env(resource_config, input_data_config, hyperparameters_config) - - # Verify file was written - mock_file.assert_called_once() - - @patch("builtins.open", new_callable=mock_open) - @patch( - "sagemaker.core.modules.train.container_drivers.scripts.environment.read_source_code_json" - ) - @patch( - "sagemaker.core.modules.train.container_drivers.scripts.environment.read_distributed_json" - ) - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_cpus") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_gpus") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.num_neurons") - @patch("sagemaker.core.modules.train.container_drivers.scripts.environment.log_env_variables") - @patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}) - def test_set_env_multiple_channels( - self, - mock_log_env, - mock_neurons, - mock_gpus, - mock_cpus, - mock_distributed, - mock_source_code, - mock_file, - ): - """Test set_env with multiple data channels""" - mock_cpus.return_value = 4 - mock_gpus.return_value = 0 - mock_neurons.return_value = 0 - mock_source_code.return_value = None - mock_distributed.return_value = None - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.m5.xlarge", - "hosts": ["algo-1"], - "network_interface_name": "eth0", - } - - input_data_config = { - "training": {"S3Uri": "s3://bucket/train"}, - "validation": {"S3Uri": "s3://bucket/val"}, - "test": {"S3Uri": "s3://bucket/test"}, - } - - hyperparameters_config = {} - - set_env(resource_config, input_data_config, hyperparameters_config) - - # Verify file was written - mock_file.assert_called_once() diff --git a/sagemaker-core/tests/unit/modules/train/test_sm_recipes_utils.py b/sagemaker-core/tests/unit/modules/train/test_sm_recipes_utils.py deleted file mode 100644 index 737aa60927..0000000000 --- a/sagemaker-core/tests/unit/modules/train/test_sm_recipes_utils.py +++ /dev/null @@ -1,436 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
- -import pytest -import tempfile -from unittest.mock import Mock, patch, MagicMock -from omegaconf import OmegaConf -from sagemaker.core.modules.train.sm_recipes.utils import ( - _try_resolve_recipe, - _determine_device_type, - _load_recipes_cfg, - _load_base_recipe, - _register_custom_resolvers, - _get_trainining_recipe_gpu_model_name_and_script, - _configure_gpu_args, - _configure_trainium_args, - _get_args_from_recipe, -) -from sagemaker.core.modules.configs import Compute - - -class TestSMRecipesUtils: - """Test cases for SM recipes utility functions""" - - def test_try_resolve_recipe_success(self): - """Test _try_resolve_recipe with resolvable recipe""" - recipe = OmegaConf.create({"value": 10, "doubled": "${value}"}) - - result = _try_resolve_recipe(recipe) - - assert result is not None - assert result["doubled"] == 10 - - def test_try_resolve_recipe_with_key(self): - """Test _try_resolve_recipe with key parameter""" - recipe = 10 - - result = _try_resolve_recipe(recipe, key="test") - - assert result is not None - assert result == 10 - - def test_try_resolve_recipe_unresolvable(self): - """Test _try_resolve_recipe with unresolvable recipe""" - recipe = OmegaConf.create({"value": "${missing_var}"}) - - result = _try_resolve_recipe(recipe) - - assert result is None - - def test_determine_device_type_gpu_p_instance(self): - """Test _determine_device_type with P instance (GPU)""" - result = _determine_device_type("ml.p3.2xlarge") - assert result == "gpu" - - def test_determine_device_type_gpu_g_instance(self): - """Test _determine_device_type with G instance (GPU)""" - result = _determine_device_type("ml.g4dn.xlarge") - assert result == "gpu" - - def test_determine_device_type_trainium(self): - """Test _determine_device_type with Trainium instance""" - result = _determine_device_type("ml.trn1.2xlarge") - assert result == "trainium" - - def test_determine_device_type_cpu(self): - """Test _determine_device_type with CPU instance""" - result = _determine_device_type("ml.m5.xlarge") - assert result == "cpu" - - @patch("sagemaker.core.modules.train.sm_recipes.utils.open") - @patch("sagemaker.core.modules.train.sm_recipes.utils.json.load") - def test_load_recipes_cfg(self, mock_json_load, mock_open): - """Test _load_recipes_cfg loads configuration""" - mock_json_load.return_value = {"launcher_repo": "test_repo", "adapter_repo": "test_adapter"} - - result = _load_recipes_cfg() - - assert isinstance(result, dict) - assert "launcher_repo" in result or "adapter_repo" in result or "neuron_dist_repo" in result - - @patch("sagemaker.core.modules.train.sm_recipes.utils.os.path.isfile") - @patch("sagemaker.core.modules.train.sm_recipes.utils.shutil.copy") - @patch("sagemaker.core.modules.train.sm_recipes.utils.OmegaConf.load") - @patch("sagemaker.core.modules.train.sm_recipes.utils.OmegaConf.merge") - @patch("sagemaker.core.modules.train.sm_recipes.utils.os.unlink") - def test_load_base_recipe_from_file( - self, mock_unlink, mock_merge, mock_load, mock_copy, mock_isfile - ): - """Test _load_base_recipe from local file""" - mock_isfile.return_value = True - mock_recipe = OmegaConf.create({"model": {"model_type": "llama_v3"}}) - mock_load.return_value = mock_recipe - mock_merge.return_value = mock_recipe - - result = _load_base_recipe("recipe.yaml") - - assert result is not None - mock_copy.assert_called_once() - - @patch("sagemaker.core.modules.train.sm_recipes.utils.os.path.isfile") - @patch("sagemaker.core.modules.train.sm_recipes.utils.urlretrieve") - 
@patch("sagemaker.core.modules.train.sm_recipes.utils.OmegaConf.load") - @patch("sagemaker.core.modules.train.sm_recipes.utils.OmegaConf.merge") - @patch("sagemaker.core.modules.train.sm_recipes.utils.os.unlink") - def test_load_base_recipe_from_url( - self, mock_unlink, mock_merge, mock_load, mock_urlretrieve, mock_isfile - ): - """Test _load_base_recipe from URL""" - mock_isfile.return_value = False - mock_recipe = OmegaConf.create({"model": {"model_type": "llama_v3"}}) - mock_load.return_value = mock_recipe - mock_merge.return_value = mock_recipe - - result = _load_base_recipe("https://example.com/recipe.yaml") - - assert result is not None - mock_urlretrieve.assert_called_once() - - @patch("sagemaker.core.modules.train.sm_recipes.utils.os.path.isfile") - @patch("sagemaker.core.modules.train.sm_recipes.utils.urlretrieve") - def test_load_base_recipe_url_error(self, mock_urlretrieve, mock_isfile): - """Test _load_base_recipe raises error on URL fetch failure""" - mock_isfile.return_value = False - mock_urlretrieve.side_effect = Exception("Network error") - - with pytest.raises(ValueError, match="Could not fetch the provided recipe"): - _load_base_recipe("https://example.com/recipe.yaml") - - def test_register_custom_resolvers(self): - """Test _register_custom_resolvers registers OmegaConf resolvers""" - _register_custom_resolvers() - - # Test multiply resolver - recipe = OmegaConf.create({"a": 5, "b": "${multiply:${a},2}"}) - OmegaConf.resolve(recipe) - assert recipe["b"] == 10 - - # Test divide_ceil resolver - recipe = OmegaConf.create({"a": 10, "b": "${divide_ceil:${a},3}"}) - OmegaConf.resolve(recipe) - assert recipe["b"] == 4 - - # Test divide_floor resolver - recipe = OmegaConf.create({"a": 10, "b": "${divide_floor:${a},3}"}) - OmegaConf.resolve(recipe) - assert recipe["b"] == 3 - - # Test add resolver - recipe = OmegaConf.create({"a": "${add:1,2,3}"}) - OmegaConf.resolve(recipe) - assert recipe["a"] == 6 - - def test_get_trainining_recipe_gpu_model_name_and_script_llama(self): - """Test _get_trainining_recipe_gpu_model_name_and_script for Llama""" - model_name, script = _get_trainining_recipe_gpu_model_name_and_script("llama_v3_8b") - - assert model_name == "llama" - assert script == "llama_pretrain.py" - - def test_get_trainining_recipe_gpu_model_name_and_script_mistral(self): - """Test _get_trainining_recipe_gpu_model_name_and_script for Mistral""" - model_name, script = _get_trainining_recipe_gpu_model_name_and_script("mistral_7b") - - assert model_name == "mistral" - assert script == "mistral_pretrain.py" - - def test_get_trainining_recipe_gpu_model_name_and_script_mixtral(self): - """Test _get_trainining_recipe_gpu_model_name_and_script for Mixtral""" - model_name, script = _get_trainining_recipe_gpu_model_name_and_script("mixtral_8x7b") - - assert model_name == "mixtral" - assert script == "mixtral_pretrain.py" - - def test_get_trainining_recipe_gpu_model_name_and_script_deepseek(self): - """Test _get_trainining_recipe_gpu_model_name_and_script for DeepSeek""" - model_name, script = _get_trainining_recipe_gpu_model_name_and_script("deepseek_v2") - - assert model_name == "deepseek" - assert script == "deepseek_pretrain.py" - - def test_get_trainining_recipe_gpu_model_name_and_script_unsupported(self): - """Test _get_trainining_recipe_gpu_model_name_and_script with unsupported model""" - with pytest.raises(ValueError, match="Model type .* not supported"): - _get_trainining_recipe_gpu_model_name_and_script("unsupported_model") - - 
@patch("sagemaker.core.modules.train.sm_recipes.utils._run_clone_command_silent") - @patch("sagemaker.core.modules.train.sm_recipes.utils.retrieve") - def test_configure_gpu_args(self, mock_retrieve, mock_clone): - """Test _configure_gpu_args""" - training_recipes_cfg = { - "adapter_repo": "https://github.com/test/adapter", - "gpu_image": {"framework": "pytorch", "version": "2.0", "additional_args": {}}, - } - - recipe = OmegaConf.create({"model": {"model_type": "llama_v3"}}) - - recipe_train_dir = tempfile.TemporaryDirectory() - mock_retrieve.return_value = "test-image:latest" - - result = _configure_gpu_args(training_recipes_cfg, "us-west-2", recipe, recipe_train_dir) - - assert "source_code" in result - assert "training_image" in result - assert "distributed" in result - assert result["training_image"] == "test-image:latest" - - @patch("sagemaker.core.modules.train.sm_recipes.utils._run_clone_command_silent") - @patch("sagemaker.core.modules.train.sm_recipes.utils.retrieve") - def test_configure_gpu_args_string_image(self, mock_retrieve, mock_clone): - """Test _configure_gpu_args with string image config""" - training_recipes_cfg = { - "adapter_repo": "https://github.com/test/adapter", - "gpu_image": "custom-image:latest", - } - - recipe = OmegaConf.create({"model": {"model_type": "mistral"}}) - - recipe_train_dir = tempfile.TemporaryDirectory() - - result = _configure_gpu_args(training_recipes_cfg, "us-west-2", recipe, recipe_train_dir) - - assert result["training_image"] == "custom-image:latest" - - @patch("sagemaker.core.modules.train.sm_recipes.utils._run_clone_command_silent") - @patch("sagemaker.core.modules.train.sm_recipes.utils.retrieve") - def test_configure_gpu_args_missing_model(self, mock_retrieve, mock_clone): - """Test _configure_gpu_args raises error when model field is missing""" - training_recipes_cfg = { - "adapter_repo": "https://github.com/test/adapter", - "gpu_image": "test-image:latest", - } - - recipe = OmegaConf.create({}) - recipe_train_dir = tempfile.TemporaryDirectory() - - with pytest.raises(ValueError, match="does not contain required field model"): - _configure_gpu_args(training_recipes_cfg, "us-west-2", recipe, recipe_train_dir) - - @patch("sagemaker.core.modules.train.sm_recipes.utils._run_clone_command_silent") - @patch("sagemaker.core.modules.train.sm_recipes.utils.retrieve") - def test_configure_trainium_args(self, mock_retrieve, mock_clone): - """Test _configure_trainium_args""" - training_recipes_cfg = { - "neuron_dist_repo": "https://github.com/test/neuron", - "neuron_image": {"framework": "pytorch", "version": "1.13", "additional_args": {}}, - } - - recipe_train_dir = tempfile.TemporaryDirectory() - mock_retrieve.return_value = "neuron-image:latest" - - result = _configure_trainium_args(training_recipes_cfg, "us-west-2", recipe_train_dir) - - assert "source_code" in result - assert "training_image" in result - assert "distributed" in result - assert result["training_image"] == "neuron-image:latest" - - @patch("sagemaker.core.modules.train.sm_recipes.utils._load_recipes_cfg") - @patch("sagemaker.core.modules.train.sm_recipes.utils._load_base_recipe") - @patch("sagemaker.core.modules.train.sm_recipes.utils._configure_gpu_args") - @patch("sagemaker.core.modules.train.sm_recipes.utils._register_custom_resolvers") - @patch("sagemaker.core.modules.train.sm_recipes.utils._try_resolve_recipe") - @patch("sagemaker.core.modules.train.sm_recipes.utils.OmegaConf.save") - def test_get_args_from_recipe_gpu( - self, - mock_save, - mock_resolve, - mock_register, 
- mock_configure_gpu, - mock_load_recipe, - mock_load_cfg, - ): - """Test _get_args_from_recipe for GPU instance""" - compute = Compute(instance_type="ml.p3.2xlarge", instance_count=2) - - mock_load_cfg.return_value = {} - mock_recipe = OmegaConf.create( - {"trainer": {"num_nodes": 1}, "model": {"model_type": "llama_v3"}} - ) - mock_load_recipe.return_value = mock_recipe - mock_resolve.return_value = mock_recipe - - mock_configure_gpu.return_value = { - "source_code": Mock(source_dir="/tmp/source"), - "training_image": "test-image:latest", - "distributed": Mock(), - } - - result, temp_dir = _get_args_from_recipe( - training_recipe="llama_recipe", - compute=compute, - region_name="us-west-2", - recipe_overrides=None, - requirements=None, - ) - - assert "source_code" in result - assert "training_image" in result - assert "compute" in result - assert "hyperparameters" in result - assert result["compute"].instance_count == 2 - - @patch("sagemaker.core.modules.train.sm_recipes.utils._load_recipes_cfg") - @patch("sagemaker.core.modules.train.sm_recipes.utils._load_base_recipe") - @patch("sagemaker.core.modules.train.sm_recipes.utils._configure_trainium_args") - @patch("sagemaker.core.modules.train.sm_recipes.utils._register_custom_resolvers") - @patch("sagemaker.core.modules.train.sm_recipes.utils._try_resolve_recipe") - @patch("sagemaker.core.modules.train.sm_recipes.utils.OmegaConf.save") - def test_get_args_from_recipe_trainium( - self, - mock_save, - mock_resolve, - mock_register, - mock_configure_trainium, - mock_load_recipe, - mock_load_cfg, - ): - """Test _get_args_from_recipe for Trainium instance""" - compute = Compute(instance_type="ml.trn1.2xlarge", instance_count=1) - - mock_load_cfg.return_value = {} - mock_recipe = OmegaConf.create({"trainer": {"num_nodes": 1}}) - mock_load_recipe.return_value = mock_recipe - mock_resolve.return_value = mock_recipe - - mock_configure_trainium.return_value = { - "source_code": Mock(source_dir="/tmp/source"), - "training_image": "neuron-image:latest", - "distributed": Mock(), - } - - result, temp_dir = _get_args_from_recipe( - training_recipe="neuron_recipe", - compute=compute, - region_name="us-west-2", - recipe_overrides=None, - requirements=None, - ) - - assert "source_code" in result - assert "training_image" in result - - def test_get_args_from_recipe_no_instance_type(self): - """Test _get_args_from_recipe raises error without instance_type""" - compute = Compute(instance_count=1) - - with pytest.raises(ValueError, match="Must set `instance_type`"): - _get_args_from_recipe( - training_recipe="test_recipe", - compute=compute, - region_name="us-west-2", - recipe_overrides=None, - requirements=None, - ) - - @patch("sagemaker.core.modules.train.sm_recipes.utils._load_recipes_cfg") - @patch("sagemaker.core.modules.train.sm_recipes.utils._load_base_recipe") - def test_get_args_from_recipe_missing_trainer(self, mock_load_recipe, mock_load_cfg): - """Test _get_args_from_recipe raises error when trainer field is missing""" - compute = Compute(instance_type="ml.p3.2xlarge", instance_count=1) - - mock_load_cfg.return_value = {} - mock_recipe = OmegaConf.create({}) - mock_load_recipe.return_value = mock_recipe - - with pytest.raises(ValueError, match="does not contain required field trainer"): - _get_args_from_recipe( - training_recipe="test_recipe", - compute=compute, - region_name="us-west-2", - recipe_overrides=None, - requirements=None, - ) - - @patch("sagemaker.core.modules.train.sm_recipes.utils._load_recipes_cfg") - 
@patch("sagemaker.core.modules.train.sm_recipes.utils._load_base_recipe") - @patch("sagemaker.core.modules.train.sm_recipes.utils._configure_gpu_args") - @patch("sagemaker.core.modules.train.sm_recipes.utils._register_custom_resolvers") - @patch("sagemaker.core.modules.train.sm_recipes.utils._try_resolve_recipe") - def test_get_args_from_recipe_unresolvable( - self, mock_resolve, mock_register, mock_configure_gpu, mock_load_recipe, mock_load_cfg - ): - """Test _get_args_from_recipe raises error when recipe cannot be resolved""" - compute = Compute(instance_type="ml.p3.2xlarge", instance_count=1) - - mock_load_cfg.return_value = {} - mock_recipe = OmegaConf.create( - {"trainer": {"num_nodes": 1}, "model": {"model_type": "llama_v3"}} - ) - mock_load_recipe.return_value = mock_recipe - mock_resolve.return_value = None # Cannot resolve - - mock_configure_gpu.return_value = { - "source_code": Mock(source_dir="/tmp/source"), - "training_image": "test-image:latest", - "distributed": Mock(), - } - - with pytest.raises(RuntimeError, match="Could not resolve provided recipe"): - _get_args_from_recipe( - training_recipe="test_recipe", - compute=compute, - region_name="us-west-2", - recipe_overrides=None, - requirements=None, - ) - - def test_get_args_from_recipe_cpu_not_supported(self): - """Test _get_args_from_recipe raises error for CPU instances""" - compute = Compute(instance_type="ml.m5.xlarge", instance_count=1) - - with patch("sagemaker.core.modules.train.sm_recipes.utils._load_recipes_cfg"): - with patch( - "sagemaker.core.modules.train.sm_recipes.utils._load_base_recipe" - ) as mock_load: - mock_load.return_value = OmegaConf.create({"trainer": {"num_nodes": 1}}) - - with pytest.raises(ValueError, match="Devices of type cpu are not supported"): - _get_args_from_recipe( - training_recipe="test_recipe", - compute=compute, - region_name="us-west-2", - recipe_overrides=None, - requirements=None, - ) diff --git a/sagemaker-core/tests/unit/remote_function/__init__.py b/sagemaker-core/tests/unit/remote_function/__init__.py deleted file mode 100644 index 6549052177..0000000000 --- a/sagemaker-core/tests/unit/remote_function/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. diff --git a/sagemaker-core/tests/unit/remote_function/runtime_environment/__init__.py b/sagemaker-core/tests/unit/remote_function/runtime_environment/__init__.py deleted file mode 100644 index 6549052177..0000000000 --- a/sagemaker-core/tests/unit/remote_function/runtime_environment/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. 
This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. diff --git a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_bootstrap_runtime_environment.py b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_bootstrap_runtime_environment.py deleted file mode 100644 index 461a3ecb73..0000000000 --- a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_bootstrap_runtime_environment.py +++ /dev/null @@ -1,679 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tests for bootstrap_runtime_environment module.""" -from __future__ import absolute_import - -import json -import os -import pytest -import subprocess -from unittest.mock import patch, MagicMock, mock_open, call - -from sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment import ( - _parse_args, - _bootstrap_runtime_env_for_remote_function, - _bootstrap_runtime_env_for_pipeline_step, - _handle_pre_exec_scripts, - _install_dependencies, - _unpack_user_workspace, - _write_failure_reason_file, - log_key_value, - log_env_variables, - mask_sensitive_info, - num_cpus, - num_gpus, - num_neurons, - safe_serialize, - set_env, - main, - SUCCESS_EXIT_CODE, - DEFAULT_FAILURE_CODE, - FAILURE_REASON_PATH, - REMOTE_FUNCTION_WORKSPACE, - BASE_CHANNEL_PATH, - JOB_REMOTE_FUNCTION_WORKSPACE, - SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME, - SENSITIVE_KEYWORDS, - HIDDEN_VALUE, -) -from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( - _DependencySettings, -) - - -class TestParseArgs: - """Test _parse_args function.""" - - def test_parse_required_args(self): - """Test parsing required arguments.""" - args = [ - "--client_python_version", "3.8", - ] - parsed = _parse_args(args) - assert parsed.client_python_version == "3.8" - - def test_parse_all_args(self): - """Test parsing all arguments.""" - args = [ - "--job_conda_env", "my-env", - "--client_python_version", "3.9", - "--client_sagemaker_pysdk_version", "2.100.0", - "--pipeline_execution_id", "exec-123", - "--dependency_settings", '{"dependency_file": "requirements.txt"}', - "--func_step_s3_dir", "s3://bucket/func", - "--distribution", "torchrun", - "--user_nproc_per_node", "4", - ] - parsed = _parse_args(args) - assert parsed.job_conda_env == "my-env" - assert parsed.client_python_version == "3.9" - assert parsed.client_sagemaker_pysdk_version == "2.100.0" - assert parsed.pipeline_execution_id == "exec-123" - assert parsed.dependency_settings == '{"dependency_file": "requirements.txt"}' - assert parsed.func_step_s3_dir == "s3://bucket/func" - assert parsed.distribution == "torchrun" - assert parsed.user_nproc_per_node == "4" - - def test_parse_default_values(self): - """Test default values for optional arguments.""" - args = [ - "--client_python_version", "3.8", - ] - parsed = _parse_args(args) - assert 
parsed.job_conda_env is None - assert parsed.client_sagemaker_pysdk_version is None - assert parsed.pipeline_execution_id is None - assert parsed.dependency_settings is None - assert parsed.func_step_s3_dir is None - assert parsed.distribution is None - assert parsed.user_nproc_per_node is None - - -class TestLogKeyValue: - """Test log_key_value function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger") - def test_logs_regular_value(self, mock_logger): - """Test logs regular key-value pair.""" - log_key_value("my_name", "my_value") - mock_logger.info.assert_called_once_with("%s=%s", "my_name", "my_value") - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger") - def test_masks_sensitive_key(self, mock_logger): - """Test masks sensitive keywords.""" - for keyword in ["PASSWORD", "SECRET", "TOKEN", "KEY", "PRIVATE", "CREDENTIALS"]: - mock_logger.reset_mock() - log_key_value(f"my_{keyword}", "sensitive_value") - mock_logger.info.assert_called_once_with("%s=%s", f"my_{keyword}", HIDDEN_VALUE) - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger") - def test_logs_dict_value(self, mock_logger): - """Test logs dictionary value.""" - value = {"field1": "value1", "field2": "value2"} - log_key_value("my_config", value) - mock_logger.info.assert_called_once_with("%s=%s", "my_config", json.dumps(value)) - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger") - def test_logs_json_string_value(self, mock_logger): - """Test logs JSON string value.""" - value = '{"key1": "value1"}' - log_key_value("my_key", value) - mock_logger.info.assert_called_once() - - -class TestLogEnvVariables: - """Test log_env_variables function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_key_value") - @patch.dict("os.environ", {"ENV_VAR1": "value1", "ENV_VAR2": "value2"}) - def test_logs_env_and_dict_variables(self, mock_log_kv): - """Test logs both environment and dictionary variables.""" - env_dict = {"DICT_VAR1": "dict_value1", "DICT_VAR2": "dict_value2"} - log_env_variables(env_dict) - - # Should be called for env vars and dict vars - assert mock_log_kv.call_count >= 4 - - -class TestMaskSensitiveInfo: - """Test mask_sensitive_info function.""" - - def test_masks_sensitive_keys_in_dict(self): - """Test masks sensitive keys in dictionary.""" - data = { - "username": "user", - "password": "secret123", - "api_key": "key123", - } - result = mask_sensitive_info(data) - assert result["username"] == "user" - assert result["password"] == HIDDEN_VALUE - assert result["api_key"] == HIDDEN_VALUE - - def test_masks_nested_dict(self): - """Test masks sensitive keys in nested dictionary.""" - data = { - "config": { - "username": "user", - "secret": "secret123", - } - } - result = mask_sensitive_info(data) - assert result["config"]["username"] == "user" - assert result["config"]["secret"] == HIDDEN_VALUE - - def test_returns_non_dict_unchanged(self): - """Test returns non-dictionary unchanged.""" - data = "string_value" - result = mask_sensitive_info(data) - assert result == "string_value" - - -class TestNumCpus: - """Test num_cpus function.""" - - @patch("multiprocessing.cpu_count") - def test_returns_cpu_count(self, mock_cpu_count): - """Test returns CPU count.""" - mock_cpu_count.return_value = 8 - assert num_cpus() == 8 - - -class TestNumGpus: - """Test num_gpus function.""" - - 
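The TestNumGpus cases below assert that num_gpus reports one device per "GPU ..." line of nvidia-smi output and falls back to zero when the tool is missing or exits with an error. A minimal sketch consistent with that contract follows; the exact nvidia-smi invocation is an assumption.

import subprocess

def num_gpus():
    """Return the number of visible GPUs, or 0 if nvidia-smi is unavailable."""
    try:
        output = subprocess.check_output(["nvidia-smi", "-L"]).decode("utf-8")  # assumed flag
        return sum(1 for line in output.splitlines() if line.startswith("GPU "))
    except (OSError, subprocess.CalledProcessError):
        return 0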
@patch("subprocess.check_output") - def test_returns_gpu_count(self, mock_check_output): - """Test returns GPU count.""" - mock_check_output.return_value = b"GPU 0: Tesla V100\nGPU 1: Tesla V100\n" - assert num_gpus() == 2 - - @patch("subprocess.check_output") - def test_returns_zero_on_error(self, mock_check_output): - """Test returns zero when nvidia-smi fails.""" - mock_check_output.side_effect = subprocess.CalledProcessError(1, "nvidia-smi") - assert num_gpus() == 0 - - @patch("subprocess.check_output") - def test_returns_zero_on_os_error(self, mock_check_output): - """Test returns zero when nvidia-smi not found.""" - mock_check_output.side_effect = OSError() - assert num_gpus() == 0 - - -class TestNumNeurons: - """Test num_neurons function.""" - - @patch("subprocess.check_output") - def test_returns_neuron_count(self, mock_check_output): - """Test returns neuron core count.""" - mock_output = json.dumps([{"nc_count": 2}, {"nc_count": 4}]) - mock_check_output.return_value = mock_output.encode("utf-8") - assert num_neurons() == 6 - - @patch("subprocess.check_output") - def test_returns_zero_on_os_error(self, mock_check_output): - """Test returns zero when neuron-ls not found.""" - mock_check_output.side_effect = OSError() - assert num_neurons() == 0 - - @patch("subprocess.check_output") - def test_returns_zero_on_called_process_error(self, mock_check_output): - """Test returns zero when neuron-ls fails.""" - error = subprocess.CalledProcessError(1, "neuron-ls") - error.output = b"error=No neuron devices found" - mock_check_output.side_effect = error - assert num_neurons() == 0 - - -class TestSafeSerialize: - """Test safe_serialize function.""" - - def test_returns_string_as_is(self): - """Test returns string without quotes.""" - assert safe_serialize("test_string") == "test_string" - - def test_serializes_dict(self): - """Test serializes dictionary.""" - data = {"key": "value"} - assert safe_serialize(data) == '{"key": "value"}' - - def test_serializes_list(self): - """Test serializes list.""" - data = [1, 2, 3] - assert safe_serialize(data) == "[1, 2, 3]" - - def test_returns_str_for_non_serializable(self): - """Test returns str() for non-serializable objects.""" - class CustomObj: - def __str__(self): - return "custom_object" - - obj = CustomObj() - assert safe_serialize(obj) == "custom_object" - - -class TestSetEnv: - """Test set_env function.""" - - @patch("builtins.open", new_callable=mock_open) - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") - @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) - def test_sets_basic_env_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test sets basic environment variables.""" - mock_cpus.return_value = 8 - mock_gpus.return_value = 2 - mock_neurons.return_value = 0 - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.p3.2xlarge", - "hosts": ["algo-1", "algo-2"], - "network_interface_name": "eth0", - } - - set_env(resource_config) - - mock_file.assert_called_once() - mock_log_env.assert_called_once() - - @patch("builtins.open", new_callable=mock_open) - 
@patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") - @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) - def test_sets_torchrun_distribution_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test sets torchrun distribution environment variables.""" - mock_cpus.return_value = 8 - mock_gpus.return_value = 2 - mock_neurons.return_value = 0 - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.p4d.24xlarge", - "hosts": ["algo-1"], - "network_interface_name": "eth0", - } - - set_env(resource_config, distribution="torchrun") - - # Verify file was written - mock_file.assert_called_once() - - @patch("builtins.open", new_callable=mock_open) - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") - @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) - def test_sets_mpirun_distribution_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test sets mpirun distribution environment variables.""" - mock_cpus.return_value = 8 - mock_gpus.return_value = 2 - mock_neurons.return_value = 0 - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.p3.2xlarge", - "hosts": ["algo-1", "algo-2"], - "network_interface_name": "eth0", - } - - set_env(resource_config, distribution="mpirun") - - mock_file.assert_called_once() - - @patch("builtins.open", new_callable=mock_open) - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") - @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) - def test_uses_user_nproc_per_node(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test uses user-specified nproc_per_node.""" - mock_cpus.return_value = 8 - mock_gpus.return_value = 2 - mock_neurons.return_value = 0 - - resource_config = { - "current_host": "algo-1", - "current_instance_type": "ml.p3.2xlarge", - "hosts": ["algo-1"], - "network_interface_name": "eth0", - } - - set_env(resource_config, user_nproc_per_node="4") - - mock_file.assert_called_once() - - -class TestWriteFailureReasonFile: - """Test _write_failure_reason_file function.""" - - @patch("builtins.open", new_callable=mock_open) - @patch("os.path.exists") - def test_writes_failure_file(self, mock_exists, mock_file): - """Test writes failure reason file.""" - mock_exists.return_value = False - - _write_failure_reason_file("Test error message") - - mock_file.assert_called_once_with(FAILURE_REASON_PATH, "w") - 
mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error message") - - @patch("builtins.open", new_callable=mock_open) - @patch("os.path.exists") - def test_does_not_write_if_exists(self, mock_exists, mock_file): - """Test does not write if failure file already exists.""" - mock_exists.return_value = True - - _write_failure_reason_file("Test error message") - - mock_file.assert_not_called() - - -class TestUnpackUserWorkspace: - """Test _unpack_user_workspace function.""" - - @patch("os.path.exists") - def test_returns_none_if_dir_not_exists(self, mock_exists): - """Test returns None if workspace directory doesn't exist.""" - mock_exists.return_value = False - - result = _unpack_user_workspace() - - assert result is None - - @patch("os.path.isfile") - @patch("os.path.exists") - def test_returns_none_if_archive_not_exists(self, mock_exists, mock_isfile): - """Test returns None if workspace archive doesn't exist.""" - mock_exists.return_value = True - mock_isfile.return_value = False - - result = _unpack_user_workspace() - - assert result is None - - @patch("shutil.unpack_archive") - @patch("os.path.isfile") - @patch("os.path.exists") - @patch("os.getcwd") - def test_unpacks_workspace_successfully(self, mock_getcwd, mock_exists, mock_isfile, mock_unpack): - """Test unpacks workspace successfully.""" - mock_getcwd.return_value = "/tmp/workspace" - mock_exists.return_value = True - mock_isfile.return_value = True - - result = _unpack_user_workspace() - - mock_unpack.assert_called_once() - assert result is not None - - -class TestHandlePreExecScripts: - """Test _handle_pre_exec_scripts function.""" - - @patch("os.path.isfile") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - def test_runs_pre_exec_script(self, mock_manager_class, mock_isfile): - """Test runs pre-execution script.""" - mock_isfile.return_value = True - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - _handle_pre_exec_scripts("/tmp/scripts") - - mock_manager.run_pre_exec_script.assert_called_once() - - -class TestInstallDependencies: - """Test _install_dependencies function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - def test_installs_with_dependency_settings(self, mock_manager_class): - """Test installs dependencies with dependency settings.""" - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - dep_settings = _DependencySettings(dependency_file="requirements.txt") - - _install_dependencies( - "/tmp/deps", - "my-env", - "3.8", - "channel", - dep_settings - ) - - mock_manager.bootstrap.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - def test_skips_if_no_dependency_file(self, mock_manager_class): - """Test skips installation if no dependency file.""" - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - dep_settings = _DependencySettings(dependency_file=None) - - _install_dependencies( - "/tmp/deps", - "my-env", - "3.8", - "channel", - dep_settings - ) - - mock_manager.bootstrap.assert_not_called() - - @patch("os.listdir") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - def test_finds_dependency_file_legacy(self, mock_manager_class, mock_listdir): - """Test finds dependency file in legacy mode.""" - 
mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - mock_listdir.return_value = ["requirements.txt", "script.py"] - - _install_dependencies( - "/tmp/deps", - "my-env", - "3.8", - "channel", - None - ) - - mock_manager.bootstrap.assert_called_once() - - -class TestBootstrapRuntimeEnvForRemoteFunction: - """Test _bootstrap_runtime_env_for_remote_function function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") - def test_bootstraps_successfully(self, mock_unpack, mock_handle_scripts, mock_install): - """Test bootstraps runtime environment successfully.""" - mock_unpack.return_value = "/tmp/workspace" - - _bootstrap_runtime_env_for_remote_function("3.8", "my-env", None) - - mock_unpack.assert_called_once() - mock_handle_scripts.assert_called_once() - mock_install.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") - def test_returns_early_if_no_workspace(self, mock_unpack): - """Test returns early if no workspace to unpack.""" - mock_unpack.return_value = None - - _bootstrap_runtime_env_for_remote_function("3.8", "my-env", None) - - mock_unpack.assert_called_once() - - -class TestBootstrapRuntimeEnvForPipelineStep: - """Test _bootstrap_runtime_env_for_pipeline_step function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts") - @patch("shutil.copy") - @patch("os.listdir") - @patch("os.path.exists") - @patch("os.mkdir") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") - def test_bootstraps_with_workspace(self, mock_unpack, mock_mkdir, mock_exists, mock_listdir, mock_copy, mock_handle_scripts, mock_install): - """Test bootstraps pipeline step with workspace.""" - mock_unpack.return_value = "/tmp/workspace" - mock_exists.return_value = True - mock_listdir.return_value = ["requirements.txt"] - - _bootstrap_runtime_env_for_pipeline_step("3.8", "func_step", "my-env", None) - - mock_unpack.assert_called_once() - mock_handle_scripts.assert_called_once() - mock_install.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts") - @patch("os.path.exists") - @patch("os.mkdir") - @patch("os.getcwd") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") - def test_creates_workspace_if_none(self, mock_unpack, mock_getcwd, mock_mkdir, mock_exists, mock_handle_scripts, mock_install): - """Test creates workspace directory if none exists.""" - mock_unpack.return_value = None - mock_getcwd.return_value = "/tmp" - mock_exists.return_value = False - - _bootstrap_runtime_env_for_pipeline_step("3.8", "func_step", "my-env", None) - - mock_mkdir.assert_called_once() - - -class TestMain: - """Test main function.""" - - 
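Read together, the TestBootstrapRuntimeEnvForRemoteFunction cases above imply a simple control flow: unpack the user workspace, return early when there is nothing to unpack, and otherwise run pre-execution scripts and install dependencies. A rough sketch under those assumptions is shown below; the helper signatures and the channel constant are inferred from the tests, not taken from the deleted module.

def _bootstrap_runtime_env_for_remote_function(client_python_version,
                                               conda_env=None,
                                               dependency_settings=None):
    """Prepare the runtime environment for a remote-function job (sketch)."""
    workspace_dir = _unpack_user_workspace()
    if workspace_dir is None:
        return  # nothing was shipped with the job; keep the base environment
    _handle_pre_exec_scripts(workspace_dir)
    _install_dependencies(
        workspace_dir,
        conda_env,
        client_python_version,
        SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME,  # channel name is an assumption
        dependency_settings,
    )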
@patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.set_env") - @patch("builtins.open", new_callable=mock_open, read_data='{"current_host": "algo-1", "current_instance_type": "ml.m5.xlarge", "hosts": ["algo-1"], "network_interface_name": "eth0"}') - @patch("os.path.exists") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._bootstrap_runtime_env_for_remote_function") - @patch("getpass.getuser") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args") - def test_main_success(self, mock_parse_args, mock_getuser, mock_bootstrap, mock_manager_class, mock_exists, mock_file, mock_set_env): - """Test main function successful execution.""" - mock_getuser.return_value = "root" - mock_exists.return_value = True - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - # Mock parsed args - mock_args = MagicMock() - mock_args.client_python_version = "3.8" - mock_args.client_sagemaker_pysdk_version = None - mock_args.job_conda_env = None - mock_args.pipeline_execution_id = None - mock_args.dependency_settings = None - mock_args.func_step_s3_dir = None - mock_args.distribution = None - mock_args.user_nproc_per_node = None - mock_parse_args.return_value = mock_args - - args = [ - "--client_python_version", "3.8", - ] - - with pytest.raises(SystemExit) as exc_info: - main(args) - - assert exc_info.value.code == SUCCESS_EXIT_CODE - mock_bootstrap.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._write_failure_reason_file") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - @patch("getpass.getuser") - def test_main_handles_exception(self, mock_getuser, mock_manager_class, mock_write_failure): - """Test main function handles exceptions.""" - mock_getuser.return_value = "root" - mock_manager = MagicMock() - mock_manager._validate_python_version.side_effect = Exception("Test error") - mock_manager_class.return_value = mock_manager - - args = [ - "--client_python_version", "3.8", - ] - - with pytest.raises(SystemExit) as exc_info: - main(args) - - assert exc_info.value.code == DEFAULT_FAILURE_CODE - mock_write_failure.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.set_env") - @patch("builtins.open", new_callable=mock_open, read_data='{"current_host": "algo-1", "current_instance_type": "ml.m5.xlarge", "hosts": ["algo-1"], "network_interface_name": "eth0"}') - @patch("os.path.exists") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._bootstrap_runtime_env_for_pipeline_step") - @patch("getpass.getuser") - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args") - def test_main_pipeline_execution(self, mock_parse_args, mock_getuser, mock_bootstrap, mock_manager_class, mock_exists, mock_file, mock_set_env): - """Test main function for pipeline execution.""" - mock_getuser.return_value = "root" - mock_exists.return_value = True - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - # Mock parsed args - mock_args = MagicMock() 
- mock_args.client_python_version = "3.8" - mock_args.client_sagemaker_pysdk_version = None - mock_args.job_conda_env = None - mock_args.pipeline_execution_id = "exec-123" - mock_args.dependency_settings = None - mock_args.func_step_s3_dir = "s3://bucket/func" - mock_args.distribution = None - mock_args.user_nproc_per_node = None - mock_parse_args.return_value = mock_args - - args = [ - "--client_python_version", "3.8", - "--pipeline_execution_id", "exec-123", - "--func_step_s3_dir", "s3://bucket/func", - ] - - with pytest.raises(SystemExit) as exc_info: - main(args) - - assert exc_info.value.code == SUCCESS_EXIT_CODE - mock_bootstrap.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") - @patch("getpass.getuser") - def test_main_non_root_user(self, mock_getuser, mock_manager_class): - """Test main function with non-root user.""" - mock_getuser.return_value = "ubuntu" - mock_manager = MagicMock() - mock_manager_class.return_value = mock_manager - - args = [ - "--client_python_version", "3.8", - ] - - with pytest.raises(SystemExit): - main(args) - - mock_manager.change_dir_permission.assert_called_once() diff --git a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_mpi_utils_remote.py b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_mpi_utils_remote.py deleted file mode 100644 index b84dda5c1a..0000000000 --- a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_mpi_utils_remote.py +++ /dev/null @@ -1,424 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
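The TestCustomHostKeyPolicy and TestCanConnect cases below describe the MPI SSH plumbing: unknown host keys are accepted only for hostnames beginning with "algo-", and _can_connect simply reports whether an SSH connection to host:port succeeds. A minimal sketch consistent with those assertions follows; the deleted module may differ in details such as key loading.

import paramiko

class CustomHostKeyPolicy(paramiko.MissingHostKeyPolicy):
    """Accept host keys only for SageMaker training hosts (algo-*)."""

    def missing_host_key(self, client, hostname, key):
        if hostname.startswith("algo-"):
            client.get_host_keys().add(hostname, key.get_name(), key)
            return
        raise paramiko.SSHException(f"Unknown host key for {hostname}")

def _can_connect(host, port):
    """Return True when an SSH connection to host:port can be established."""
    try:
        with paramiko.SSHClient() as client:
            client.load_system_host_keys()
            client.set_missing_host_key_policy(CustomHostKeyPolicy())
            client.connect(host, port=port)
            return True
    except Exception:
        return False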
-"""Tests for mpi_utils_remote module.""" -from __future__ import absolute_import - -import os -import pytest -import subprocess -import time -from unittest.mock import patch, MagicMock, mock_open, call -import paramiko - -from sagemaker.core.remote_function.runtime_environment.mpi_utils_remote import ( - CustomHostKeyPolicy, - _parse_args, - _can_connect, - _write_file_to_host, - _write_failure_reason_file, - _wait_for_master, - _wait_for_status_file, - _wait_for_workers, - bootstrap_master_node, - bootstrap_worker_node, - start_sshd_daemon, - write_status_file_to_workers, - main, - SUCCESS_EXIT_CODE, - DEFAULT_FAILURE_CODE, - FAILURE_REASON_PATH, - FINISHED_STATUS_FILE, - READY_FILE, - DEFAULT_SSH_PORT, -) - - -class TestCustomHostKeyPolicy: - """Test CustomHostKeyPolicy class.""" - - def test_accepts_algo_hostname(self): - """Test accepts hostnames starting with algo-.""" - policy = CustomHostKeyPolicy() - mock_client = MagicMock() - mock_hostname = "algo-1234" - mock_key = MagicMock() - mock_key.get_name.return_value = "ssh-rsa" - - # Should not raise exception - policy.missing_host_key(mock_client, mock_hostname, mock_key) - - mock_client.get_host_keys().add.assert_called_once_with(mock_hostname, "ssh-rsa", mock_key) - - def test_rejects_non_algo_hostname(self): - """Test rejects hostnames not starting with algo-.""" - policy = CustomHostKeyPolicy() - mock_client = MagicMock() - mock_hostname = "unknown-host" - mock_key = MagicMock() - - with pytest.raises(paramiko.SSHException): - policy.missing_host_key(mock_client, mock_hostname, mock_key) - - -class TestParseArgs: - """Test _parse_args function.""" - - def test_parse_default_args(self): - """Test parsing with default arguments.""" - args = [] - parsed = _parse_args(args) - assert parsed.job_ended == "0" - - def test_parse_job_ended_true(self): - """Test parsing with job_ended set to true.""" - args = ["--job_ended", "1"] - parsed = _parse_args(args) - assert parsed.job_ended == "1" - - def test_parse_job_ended_false(self): - """Test parsing with job_ended set to false.""" - args = ["--job_ended", "0"] - parsed = _parse_args(args) - assert parsed.job_ended == "0" - - -class TestCanConnect: - """Test _can_connect function.""" - - @patch("paramiko.SSHClient") - def test_can_connect_success(self, mock_ssh_client_class): - """Test successful connection.""" - mock_client = MagicMock() - mock_ssh_client_class.return_value.__enter__.return_value = mock_client - - result = _can_connect("algo-1", DEFAULT_SSH_PORT) - - assert result is True - mock_client.connect.assert_called_once_with("algo-1", port=DEFAULT_SSH_PORT) - - @patch("paramiko.SSHClient") - def test_can_connect_failure(self, mock_ssh_client_class): - """Test failed connection.""" - mock_client = MagicMock() - mock_client.connect.side_effect = Exception("Connection failed") - mock_ssh_client_class.return_value.__enter__.return_value = mock_client - - result = _can_connect("algo-1", DEFAULT_SSH_PORT) - - assert result is False - - @patch("paramiko.SSHClient") - def test_can_connect_uses_custom_port(self, mock_ssh_client_class): - """Test connection with custom port.""" - mock_client = MagicMock() - mock_ssh_client_class.return_value.__enter__.return_value = mock_client - - _can_connect("algo-1", 2222) - - mock_client.connect.assert_called_once_with("algo-1", port=2222) - - -class TestWriteFileToHost: - """Test _write_file_to_host function.""" - - @patch("subprocess.run") - def test_write_file_success(self, mock_run): - """Test successful file write.""" - mock_run.return_value = 
MagicMock(returncode=0) - - result = _write_file_to_host("algo-1", "/tmp/status") - - assert result is True - mock_run.assert_called_once() - - @patch("subprocess.run") - def test_write_file_failure(self, mock_run): - """Test failed file write.""" - mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") - - result = _write_file_to_host("algo-1", "/tmp/status") - - assert result is False - - -class TestWriteFailureReasonFile: - """Test _write_failure_reason_file function.""" - - @patch("builtins.open", new_callable=mock_open) - @patch("os.path.exists") - def test_writes_failure_file(self, mock_exists, mock_file): - """Test writes failure reason file.""" - mock_exists.return_value = False - - _write_failure_reason_file("Test error message") - - mock_file.assert_called_once_with(FAILURE_REASON_PATH, "w") - mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error message") - - @patch("builtins.open", new_callable=mock_open) - @patch("os.path.exists") - def test_does_not_write_if_exists(self, mock_exists, mock_file): - """Test does not write if failure file already exists.""" - mock_exists.return_value = True - - _write_failure_reason_file("Test error message") - - mock_file.assert_not_called() - - -class TestWaitForMaster: - """Test _wait_for_master function.""" - - @patch("time.sleep") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_master_success(self, mock_can_connect, mock_sleep): - """Test successful wait for master.""" - mock_can_connect.return_value = True - - _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300) - - mock_can_connect.assert_called_once_with("algo-1", DEFAULT_SSH_PORT) - - @patch("time.time") - @patch("time.sleep") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_master_timeout(self, mock_can_connect, mock_sleep, mock_time): - """Test timeout waiting for master.""" - mock_can_connect.return_value = False - # Need enough values for all time.time() calls in the loop - mock_time.side_effect = [0] + [i * 5 for i in range(1, 100)] # Simulate time passing - - with pytest.raises(TimeoutError): - _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300) - - @patch("time.time") - @patch("time.sleep") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_master_retries(self, mock_can_connect, mock_sleep, mock_time): - """Test retries before successful connection.""" - mock_can_connect.side_effect = [False, False, True] - # Return value instead of side_effect for time.time() - mock_time.return_value = 0 - - _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300) - - assert mock_can_connect.call_count == 3 - - -class TestWaitForStatusFile: - """Test _wait_for_status_file function.""" - - @patch("time.sleep") - @patch("os.path.exists") - def test_wait_for_status_file_exists(self, mock_exists, mock_sleep): - """Test wait for status file that exists.""" - mock_exists.return_value = True - - _wait_for_status_file("/tmp/status") - - mock_exists.assert_called_once_with("/tmp/status") - - @patch("time.sleep") - @patch("os.path.exists") - def test_wait_for_status_file_waits(self, mock_exists, mock_sleep): - """Test waits until status file exists.""" - mock_exists.side_effect = [False, False, True] - - _wait_for_status_file("/tmp/status") - - assert mock_exists.call_count == 3 - assert mock_sleep.call_count == 2 - - -class TestWaitForWorkers: - """Test _wait_for_workers 
function.""" - - @patch("os.path.exists") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_workers_empty_list(self, mock_can_connect, mock_exists): - """Test wait for workers with empty list.""" - _wait_for_workers([], DEFAULT_SSH_PORT, timeout=300) - - mock_can_connect.assert_not_called() - - @patch("time.sleep") - @patch("os.path.exists") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_workers_success(self, mock_can_connect, mock_exists, mock_sleep): - """Test successful wait for workers.""" - mock_can_connect.return_value = True - mock_exists.return_value = True - - _wait_for_workers(["algo-2", "algo-3"], DEFAULT_SSH_PORT, timeout=300) - - assert mock_can_connect.call_count == 2 - - @patch("time.time") - @patch("time.sleep") - @patch("os.path.exists") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - def test_wait_for_workers_timeout(self, mock_can_connect, mock_exists, mock_sleep, mock_time): - """Test timeout waiting for workers.""" - mock_can_connect.return_value = False - mock_exists.return_value = False - # Need enough values for all time.time() calls in the loop - mock_time.side_effect = [0] + [i * 5 for i in range(1, 100)] - - with pytest.raises(TimeoutError): - _wait_for_workers(["algo-2"], DEFAULT_SSH_PORT, timeout=300) - - -class TestBootstrapMasterNode: - """Test bootstrap_master_node function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_workers") - def test_bootstrap_master_node(self, mock_wait): - """Test bootstrap master node.""" - worker_hosts = ["algo-2", "algo-3"] - - bootstrap_master_node(worker_hosts) - - mock_wait.assert_called_once_with(worker_hosts) - - -class TestBootstrapWorkerNode: - """Test bootstrap_worker_node function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_status_file") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_master") - def test_bootstrap_worker_node(self, mock_wait_master, mock_write, mock_wait_status): - """Test bootstrap worker node.""" - bootstrap_worker_node("algo-1", "algo-2", "/tmp/status") - - mock_wait_master.assert_called_once_with("algo-1") - mock_write.assert_called_once() - mock_wait_status.assert_called_once_with("/tmp/status") - - -class TestStartSshdDaemon: - """Test start_sshd_daemon function.""" - - @patch("subprocess.Popen") - @patch("os.path.exists") - def test_starts_sshd_successfully(self, mock_exists, mock_popen): - """Test starts SSH daemon successfully.""" - mock_exists.return_value = True - - start_sshd_daemon() - - mock_popen.assert_called_once_with(["/usr/sbin/sshd", "-D"]) - - @patch("os.path.exists") - def test_raises_error_if_sshd_not_found(self, mock_exists): - """Test raises error if SSH daemon not found.""" - mock_exists.return_value = False - - with pytest.raises(RuntimeError): - start_sshd_daemon() - - -class TestWriteStatusFileToWorkers: - """Test write_status_file_to_workers function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") - def test_writes_to_all_workers(self, mock_write): - """Test writes status file to all workers.""" - mock_write.return_value = True - worker_hosts = ["algo-2", "algo-3"] - - write_status_file_to_workers(worker_hosts, 
"/tmp/status") - - assert mock_write.call_count == 2 - - @patch("time.sleep") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") - def test_retries_on_failure(self, mock_write, mock_sleep): - """Test retries writing status file on failure.""" - mock_write.side_effect = [False, False, True] - worker_hosts = ["algo-2"] - - write_status_file_to_workers(worker_hosts, "/tmp/status") - - assert mock_write.call_count == 3 - assert mock_sleep.call_count == 2 - - @patch("time.sleep") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") - def test_raises_timeout_after_retries(self, mock_write, mock_sleep): - """Test raises timeout after max retries.""" - mock_write.return_value = False - worker_hosts = ["algo-2"] - - with pytest.raises(TimeoutError): - write_status_file_to_workers(worker_hosts, "/tmp/status") - - -class TestMain: - """Test main function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") - @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}) - def test_main_worker_node_running(self, mock_start_sshd, mock_bootstrap_worker): - """Test main function for worker node during job run.""" - args = ["--job_ended", "0"] - - main(args) - - mock_start_sshd.assert_called_once() - mock_bootstrap_worker.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") - @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-1", "SM_HOSTS": '["algo-1", "algo-2"]'}) - def test_main_master_node_running(self, mock_start_sshd, mock_bootstrap_master): - """Test main function for master node during job run.""" - args = ["--job_ended", "0"] - - main(args) - - mock_start_sshd.assert_called_once() - mock_bootstrap_master.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.write_status_file_to_workers") - @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-1", "SM_HOSTS": '["algo-1", "algo-2"]'}) - def test_main_master_node_job_ended(self, mock_write_status): - """Test main function for master node after job ends.""" - args = ["--job_ended", "1"] - - main(args) - - mock_write_status.assert_called_once() - - @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}) - def test_main_worker_node_job_ended(self): - """Test main function for worker node after job ends.""" - args = ["--job_ended", "1"] - - # Should not raise any exceptions - main(args) - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_failure_reason_file") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") - @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}) - def test_main_handles_exception(self, mock_start_sshd, mock_write_failure): - """Test main function handles exceptions.""" - mock_start_sshd.side_effect = Exception("Test error") - args = ["--job_ended", "0"] - - with pytest.raises(SystemExit) as exc_info: - main(args) - - assert exc_info.value.code == DEFAULT_FAILURE_CODE - mock_write_failure.assert_called_once() diff --git 
a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py deleted file mode 100644 index a300daf2b3..0000000000 --- a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py +++ /dev/null @@ -1,572 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tests for runtime_environment_manager module.""" -from __future__ import absolute_import - -import json -import os -import subprocess -import sys -import pytest -from unittest.mock import patch, MagicMock, mock_open, call - -from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( - _DependencySettings, - RuntimeEnvironmentManager, - RuntimeEnvironmentError, - get_logger, - _run_and_get_output_shell_cmd, - _run_pre_execution_command_script, - _run_shell_cmd, - _log_output, - _log_error, - _python_executable, -) - - -class TestDependencySettings: - """Test _DependencySettings class.""" - - def test_init_with_no_file(self): - """Test initialization without dependency file.""" - settings = _DependencySettings() - assert settings.dependency_file is None - - def test_init_with_file(self): - """Test initialization with dependency file.""" - settings = _DependencySettings(dependency_file="requirements.txt") - assert settings.dependency_file == "requirements.txt" - - def test_to_string(self): - """Test converts to JSON string.""" - settings = _DependencySettings(dependency_file="requirements.txt") - result = settings.to_string() - assert result == '{"dependency_file": "requirements.txt"}' - - def test_from_string_with_file(self): - """Test creates from JSON string with file.""" - json_str = '{"dependency_file": "requirements.txt"}' - settings = _DependencySettings.from_string(json_str) - assert settings.dependency_file == "requirements.txt" - - def test_from_string_with_none(self): - """Test creates from None.""" - settings = _DependencySettings.from_string(None) - assert settings is None - - def test_from_dependency_file_path_with_none(self): - """Test creates from None file path.""" - settings = _DependencySettings.from_dependency_file_path(None) - assert settings.dependency_file is None - - def test_from_dependency_file_path_with_auto_capture(self): - """Test creates from auto_capture.""" - settings = _DependencySettings.from_dependency_file_path("auto_capture") - assert settings.dependency_file == "env_snapshot.yml" - - def test_from_dependency_file_path_with_path(self): - """Test creates from file path.""" - settings = _DependencySettings.from_dependency_file_path("/path/to/requirements.txt") - assert settings.dependency_file == "requirements.txt" - - -class TestGetLogger: - """Test get_logger function.""" - - def test_returns_logger(self): - """Test returns logger instance.""" - logger = get_logger() - assert logger is not None - assert logger.name == "sagemaker.remote_function" - - -class 
TestRuntimeEnvironmentManager: - """Test RuntimeEnvironmentManager class.""" - - def test_init(self): - """Test initialization.""" - manager = RuntimeEnvironmentManager() - assert manager is not None - - @patch("os.path.isfile") - def test_snapshot_returns_none_for_none(self, mock_isfile): - """Test snapshot returns None when dependencies is None.""" - manager = RuntimeEnvironmentManager() - result = manager.snapshot(None) - assert result is None - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._capture_from_local_runtime") - def test_snapshot_auto_capture(self, mock_capture): - """Test snapshot with auto_capture.""" - mock_capture.return_value = "/path/to/env_snapshot.yml" - manager = RuntimeEnvironmentManager() - result = manager.snapshot("auto_capture") - assert result == "/path/to/env_snapshot.yml" - mock_capture.assert_called_once() - - @patch("os.path.isfile") - def test_snapshot_with_txt_file(self, mock_isfile): - """Test snapshot with requirements.txt file.""" - mock_isfile.return_value = True - manager = RuntimeEnvironmentManager() - result = manager.snapshot("requirements.txt") - assert result == "requirements.txt" - - @patch("os.path.isfile") - def test_snapshot_with_yml_file(self, mock_isfile): - """Test snapshot with conda.yml file.""" - mock_isfile.return_value = True - manager = RuntimeEnvironmentManager() - result = manager.snapshot("environment.yml") - assert result == "environment.yml" - - @patch("os.path.isfile") - def test_snapshot_raises_error_for_invalid_file(self, mock_isfile): - """Test snapshot raises error for invalid file.""" - mock_isfile.return_value = False - manager = RuntimeEnvironmentManager() - with pytest.raises(ValueError): - manager.snapshot("requirements.txt") - - def test_snapshot_raises_error_for_invalid_format(self): - """Test snapshot raises error for invalid format.""" - manager = RuntimeEnvironmentManager() - with pytest.raises(ValueError): - manager.snapshot("invalid.json") - - @patch("os.getenv") - def test_get_active_conda_env_prefix(self, mock_getenv): - """Test gets active conda environment prefix.""" - mock_getenv.return_value = "/opt/conda/envs/myenv" - manager = RuntimeEnvironmentManager() - result = manager._get_active_conda_env_prefix() - assert result == "/opt/conda/envs/myenv" - - @patch("os.getenv") - def test_get_active_conda_env_name(self, mock_getenv): - """Test gets active conda environment name.""" - mock_getenv.return_value = "myenv" - manager = RuntimeEnvironmentManager() - result = manager._get_active_conda_env_name() - assert result == "myenv" - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._export_conda_env_from_prefix") - @patch("os.getcwd") - @patch("os.getenv") - def test_capture_from_local_runtime(self, mock_getenv, mock_getcwd, mock_export): - """Test captures from local runtime.""" - mock_getenv.side_effect = lambda x: "myenv" if x == "CONDA_DEFAULT_ENV" else "/opt/conda/envs/myenv" - mock_getcwd.return_value = "/tmp" - manager = RuntimeEnvironmentManager() - result = manager._capture_from_local_runtime() - assert result == "/tmp/env_snapshot.yml" - mock_export.assert_called_once() - - @patch("os.getenv") - def test_capture_from_local_runtime_raises_error_no_conda(self, mock_getenv): - """Test raises error when no conda environment active.""" - mock_getenv.return_value = None - manager = RuntimeEnvironmentManager() - with pytest.raises(ValueError): - 
manager._capture_from_local_runtime() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._install_requirements_txt") - def test_bootstrap_with_txt_file_no_conda(self, mock_install): - """Test bootstrap with requirements.txt without conda.""" - manager = RuntimeEnvironmentManager() - manager.bootstrap("requirements.txt", "3.8", None) - mock_install.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._install_req_txt_in_conda_env") - def test_bootstrap_with_txt_file_with_conda(self, mock_install, mock_write): - """Test bootstrap with requirements.txt with conda.""" - manager = RuntimeEnvironmentManager() - manager.bootstrap("requirements.txt", "3.8", "myenv") - mock_install.assert_called_once() - mock_write.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._update_conda_env") - def test_bootstrap_with_yml_file_with_conda(self, mock_update, mock_write): - """Test bootstrap with conda.yml with existing conda env.""" - manager = RuntimeEnvironmentManager() - manager.bootstrap("environment.yml", "3.8", "myenv") - mock_update.assert_called_once() - mock_write.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._validate_python_version") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._create_conda_env") - def test_bootstrap_with_yml_file_without_conda(self, mock_create, mock_validate, mock_write): - """Test bootstrap with conda.yml without existing conda env.""" - manager = RuntimeEnvironmentManager() - manager.bootstrap("environment.yml", "3.8", None) - mock_create.assert_called_once() - mock_validate.assert_called_once() - mock_write.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script") - @patch("os.path.isfile") - def test_run_pre_exec_script_exists(self, mock_isfile, mock_run_script): - """Test runs pre-execution script when it exists.""" - mock_isfile.return_value = True - mock_run_script.return_value = (0, "") - manager = RuntimeEnvironmentManager() - manager.run_pre_exec_script("/path/to/script.sh") - mock_run_script.assert_called_once() - - @patch("os.path.isfile") - def test_run_pre_exec_script_not_exists(self, mock_isfile): - """Test handles pre-execution script not existing.""" - mock_isfile.return_value = False - manager = RuntimeEnvironmentManager() - # Should not raise exception - manager.run_pre_exec_script("/path/to/script.sh") - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script") - @patch("os.path.isfile") - def test_run_pre_exec_script_raises_error_on_failure(self, mock_isfile, mock_run_script): - """Test raises error when pre-execution script fails.""" - mock_isfile.return_value = True - 
mock_run_script.return_value = (1, "Error message") - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager.run_pre_exec_script("/path/to/script.sh") - - @patch("subprocess.run") - def test_change_dir_permission_success(self, mock_run): - """Test changes directory permissions successfully.""" - manager = RuntimeEnvironmentManager() - manager.change_dir_permission(["/tmp/dir1", "/tmp/dir2"], "777") - mock_run.assert_called_once() - - @patch("subprocess.run") - def test_change_dir_permission_raises_error_on_failure(self, mock_run): - """Test raises error when permission change fails.""" - mock_run.side_effect = subprocess.CalledProcessError(1, "chmod", stderr=b"Permission denied") - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager.change_dir_permission(["/tmp/dir1"], "777") - - @patch("subprocess.run") - def test_change_dir_permission_raises_error_no_sudo(self, mock_run): - """Test raises error when sudo not found.""" - mock_run.side_effect = FileNotFoundError("[Errno 2] No such file or directory: 'sudo'") - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager.change_dir_permission(["/tmp/dir1"], "777") - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") - def test_install_requirements_txt(self, mock_run_cmd): - """Test installs requirements.txt.""" - manager = RuntimeEnvironmentManager() - manager._install_requirements_txt("/path/to/requirements.txt", "/usr/bin/python") - mock_run_cmd.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_create_conda_env(self, mock_get_conda, mock_run_cmd): - """Test creates conda environment.""" - mock_get_conda.return_value = "conda" - manager = RuntimeEnvironmentManager() - manager._create_conda_env("myenv", "/path/to/environment.yml") - mock_run_cmd.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_install_req_txt_in_conda_env(self, mock_get_conda, mock_run_cmd): - """Test installs requirements.txt in conda environment.""" - mock_get_conda.return_value = "conda" - manager = RuntimeEnvironmentManager() - manager._install_req_txt_in_conda_env("myenv", "/path/to/requirements.txt") - mock_run_cmd.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_update_conda_env(self, mock_get_conda, mock_run_cmd): - """Test updates conda environment.""" - mock_get_conda.return_value = "conda" - manager = RuntimeEnvironmentManager() - manager._update_conda_env("myenv", "/path/to/environment.yml") - mock_run_cmd.assert_called_once() - - @patch("builtins.open", new_callable=mock_open) - @patch("subprocess.Popen") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_export_conda_env_from_prefix(self, mock_get_conda, mock_popen, mock_file): - """Test exports 
conda environment.""" - mock_get_conda.return_value = "conda" - mock_process = MagicMock() - mock_process.communicate.return_value = (b"env output", b"") - mock_process.wait.return_value = 0 - mock_popen.return_value = mock_process - - manager = RuntimeEnvironmentManager() - manager._export_conda_env_from_prefix("/opt/conda/envs/myenv", "/tmp/env.yml") - - mock_popen.assert_called_once() - mock_file.assert_called_once_with("/tmp/env.yml", "w") - - @patch("builtins.open", new_callable=mock_open) - @patch("os.getcwd") - def test_write_conda_env_to_file(self, mock_getcwd, mock_file): - """Test writes conda environment name to file.""" - mock_getcwd.return_value = "/tmp" - manager = RuntimeEnvironmentManager() - manager._write_conda_env_to_file("myenv") - mock_file.assert_called_once_with("/tmp/remote_function_conda_env.txt", "w") - mock_file().write.assert_called_once_with("myenv") - - @patch("subprocess.Popen") - def test_get_conda_exe_returns_mamba(self, mock_popen): - """Test returns mamba when available.""" - mock_popen.return_value.wait.side_effect = [0, 1] # mamba exists, conda doesn't - manager = RuntimeEnvironmentManager() - result = manager._get_conda_exe() - assert result == "mamba" - - @patch("subprocess.Popen") - def test_get_conda_exe_returns_conda(self, mock_popen): - """Test returns conda when mamba not available.""" - mock_popen.return_value.wait.side_effect = [1, 0] # mamba doesn't exist, conda does - manager = RuntimeEnvironmentManager() - result = manager._get_conda_exe() - assert result == "conda" - - @patch("subprocess.Popen") - def test_get_conda_exe_raises_error(self, mock_popen): - """Test raises error when neither conda nor mamba available.""" - mock_popen.return_value.wait.return_value = 1 - manager = RuntimeEnvironmentManager() - with pytest.raises(ValueError): - manager._get_conda_exe() - - @patch("subprocess.check_output") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_python_version_in_conda_env(self, mock_get_conda, mock_check_output): - """Test gets Python version in conda environment.""" - mock_get_conda.return_value = "conda" - mock_check_output.return_value = b"Python 3.8.10" - manager = RuntimeEnvironmentManager() - result = manager._python_version_in_conda_env("myenv") - assert result == "3.8" - - @patch("subprocess.check_output") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") - def test_python_version_in_conda_env_raises_error(self, mock_get_conda, mock_check_output): - """Test raises error when getting Python version fails.""" - mock_get_conda.return_value = "conda" - mock_check_output.side_effect = subprocess.CalledProcessError(1, "conda", output=b"Error") - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager._python_version_in_conda_env("myenv") - - def test_current_python_version(self): - """Test gets current Python version.""" - manager = RuntimeEnvironmentManager() - result = manager._current_python_version() - expected = f"{sys.version_info.major}.{sys.version_info.minor}" - assert result == expected - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._python_version_in_conda_env") - def test_validate_python_version_with_conda(self, mock_python_version): - """Test validates Python version with conda environment.""" - mock_python_version.return_value = "3.8" - 
manager = RuntimeEnvironmentManager() - # Should not raise exception - manager._validate_python_version("3.8", "myenv") - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._python_version_in_conda_env") - def test_validate_python_version_mismatch_with_conda(self, mock_python_version): - """Test raises error on Python version mismatch with conda.""" - mock_python_version.return_value = "3.9" - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager._validate_python_version("3.8", "myenv") - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_python_version") - def test_validate_python_version_without_conda(self, mock_current_version): - """Test validates Python version without conda environment.""" - mock_current_version.return_value = "3.8" - manager = RuntimeEnvironmentManager() - # Should not raise exception - manager._validate_python_version("3.8", None) - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_python_version") - def test_validate_python_version_mismatch_without_conda(self, mock_current_version): - """Test raises error on Python version mismatch without conda.""" - mock_current_version.return_value = "3.9" - manager = RuntimeEnvironmentManager() - with pytest.raises(RuntimeEnvironmentError): - manager._validate_python_version("3.8", None) - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version") - def test_validate_sagemaker_pysdk_version_match(self, mock_current_version): - """Test validates matching SageMaker SDK version.""" - mock_current_version.return_value = "2.100.0" - manager = RuntimeEnvironmentManager() - # Should not raise exception or warning - manager._validate_sagemaker_pysdk_version("2.100.0") - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version") - def test_validate_sagemaker_pysdk_version_mismatch(self, mock_current_version): - """Test logs warning on SageMaker SDK version mismatch.""" - mock_current_version.return_value = "2.101.0" - manager = RuntimeEnvironmentManager() - # Should log warning but not raise exception - manager._validate_sagemaker_pysdk_version("2.100.0") - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version") - def test_validate_sagemaker_pysdk_version_none(self, mock_current_version): - """Test handles None client version.""" - mock_current_version.return_value = "2.100.0" - manager = RuntimeEnvironmentManager() - # Should not raise exception - manager._validate_sagemaker_pysdk_version(None) - - -class TestRunAndGetOutputShellCmd: - """Test _run_and_get_output_shell_cmd function.""" - - @patch("subprocess.check_output") - def test_runs_command_successfully(self, mock_check_output): - """Test runs command and returns output.""" - mock_check_output.return_value = b"command output" - result = _run_and_get_output_shell_cmd("echo test") - assert result == "command output" - - -class TestRunPreExecutionCommandScript: - """Test _run_pre_execution_command_script function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error") - 
@patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output") - @patch("subprocess.Popen") - @patch("os.path.dirname") - def test_runs_script_successfully(self, mock_dirname, mock_popen, mock_log_output, mock_log_error): - """Test runs script successfully.""" - mock_dirname.return_value = "/tmp" - mock_process = MagicMock() - mock_process.wait.return_value = 0 - mock_popen.return_value = mock_process - mock_log_error.return_value = "" - - return_code, error_logs = _run_pre_execution_command_script("/tmp/script.sh") - - assert return_code == 0 - assert error_logs == "" - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output") - @patch("subprocess.Popen") - @patch("os.path.dirname") - def test_runs_script_with_error(self, mock_dirname, mock_popen, mock_log_output, mock_log_error): - """Test runs script that returns error.""" - mock_dirname.return_value = "/tmp" - mock_process = MagicMock() - mock_process.wait.return_value = 1 - mock_popen.return_value = mock_process - mock_log_error.return_value = "Error message" - - return_code, error_logs = _run_pre_execution_command_script("/tmp/script.sh") - - assert return_code == 1 - assert error_logs == "Error message" - - -class TestRunShellCmd: - """Test _run_shell_cmd function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output") - @patch("subprocess.Popen") - def test_runs_command_successfully(self, mock_popen, mock_log_output, mock_log_error): - """Test runs command successfully.""" - mock_process = MagicMock() - mock_process.wait.return_value = 0 - mock_popen.return_value = mock_process - mock_log_error.return_value = "" - - _run_shell_cmd(["echo", "test"]) - - mock_popen.assert_called_once() - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error") - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output") - @patch("subprocess.Popen") - def test_runs_command_raises_error_on_failure(self, mock_popen, mock_log_output, mock_log_error): - """Test raises error when command fails.""" - mock_process = MagicMock() - mock_process.wait.return_value = 1 - mock_popen.return_value = mock_process - mock_log_error.return_value = "Error message" - - with pytest.raises(RuntimeEnvironmentError): - _run_shell_cmd(["false"]) - - -class TestLogOutput: - """Test _log_output function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.logger") - def test_logs_output(self, mock_logger): - """Test logs process output.""" - from io import BytesIO - mock_process = MagicMock() - mock_process.stdout = BytesIO(b"line1\nline2\n") - - _log_output(mock_process) - - assert mock_logger.info.call_count == 2 - - -class TestLogError: - """Test _log_error function.""" - - @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.logger") - def test_logs_error(self, mock_logger): - """Test logs process errors.""" - from io import BytesIO - mock_process = MagicMock() - mock_process.stderr = BytesIO(b"ERROR: error message\nwarning message\n") - - error_logs = _log_error(mock_process) - - assert "ERROR: error message" in error_logs - assert "warning message" in error_logs - - -class 
TestPythonExecutable: - """Test _python_executable function.""" - - def test_returns_python_executable(self): - """Test returns Python executable path.""" - result = _python_executable() - assert result == sys.executable - - @patch("sys.executable", None) - def test_raises_error_if_no_executable(self): - """Test raises error if no Python executable.""" - with pytest.raises(RuntimeEnvironmentError): - _python_executable() - - -class TestRuntimeEnvironmentError: - """Test RuntimeEnvironmentError class.""" - - def test_creates_error_with_message(self): - """Test creates error with message.""" - error = RuntimeEnvironmentError("Test error") - assert str(error) == "Test error" - assert error.message == "Test error" diff --git a/sagemaker-core/tests/unit/remote_function/test_checkpoint_location.py b/sagemaker-core/tests/unit/remote_function/test_checkpoint_location.py deleted file mode 100644 index 98a5f8bcc8..0000000000 --- a/sagemaker-core/tests/unit/remote_function/test_checkpoint_location.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tests for checkpoint_location module.""" -from __future__ import absolute_import - -import pytest -from sagemaker.core.remote_function.checkpoint_location import ( - CheckpointLocation, - _validate_s3_uri_for_checkpoint, - _JOB_CHECKPOINT_LOCATION, -) - - -class TestValidateS3Uri: - """Test _validate_s3_uri_for_checkpoint function.""" - - def test_valid_s3_uri(self): - """Test valid s3:// URI.""" - assert _validate_s3_uri_for_checkpoint("s3://my-bucket/path/to/checkpoints") - - def test_valid_https_uri(self): - """Test valid https:// URI.""" - assert _validate_s3_uri_for_checkpoint("https://my-bucket.s3.amazonaws.com/path") - - def test_valid_s3_uri_no_path(self): - """Test valid s3:// URI without path.""" - assert _validate_s3_uri_for_checkpoint("s3://my-bucket") - - def test_invalid_uri_no_protocol(self): - """Test invalid URI without protocol.""" - assert not _validate_s3_uri_for_checkpoint("my-bucket/path") - - def test_invalid_uri_wrong_protocol(self): - """Test invalid URI with wrong protocol.""" - assert not _validate_s3_uri_for_checkpoint("http://my-bucket/path") - - def test_invalid_uri_empty(self): - """Test invalid empty URI.""" - assert not _validate_s3_uri_for_checkpoint("") - - -class TestCheckpointLocation: - """Test CheckpointLocation class.""" - - def test_init_with_valid_s3_uri(self): - """Test initialization with valid s3 URI.""" - s3_uri = "s3://my-bucket/checkpoints" - checkpoint_loc = CheckpointLocation(s3_uri) - assert checkpoint_loc._s3_uri == s3_uri - - def test_init_with_valid_https_uri(self): - """Test initialization with valid https URI.""" - s3_uri = "https://my-bucket.s3.amazonaws.com/checkpoints" - checkpoint_loc = CheckpointLocation(s3_uri) - assert checkpoint_loc._s3_uri == s3_uri - - def test_init_with_invalid_uri_raises_error(self): - """Test initialization with invalid URI raises ValueError.""" - with pytest.raises(ValueError, match="CheckpointLocation should be specified with 
valid s3 URI"): - CheckpointLocation("invalid-uri") - - def test_fspath_returns_local_path(self): - """Test __fspath__ returns the job local path.""" - checkpoint_loc = CheckpointLocation("s3://my-bucket/checkpoints") - assert checkpoint_loc.__fspath__() == _JOB_CHECKPOINT_LOCATION - - def test_can_be_used_as_pathlike(self): - """Test CheckpointLocation can be used as os.PathLike.""" - import os - checkpoint_loc = CheckpointLocation("s3://my-bucket/checkpoints") - path = os.fspath(checkpoint_loc) - assert path == _JOB_CHECKPOINT_LOCATION diff --git a/sagemaker-core/tests/unit/remote_function/test_client.py b/sagemaker-core/tests/unit/remote_function/test_client.py deleted file mode 100644 index 83e1a2db80..0000000000 --- a/sagemaker-core/tests/unit/remote_function/test_client.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. - -import pytest -from unittest.mock import Mock -from collections import deque - -from sagemaker.core.remote_function.client import ( - RemoteExecutor, - _submit_worker, - _polling_worker, - _API_CALL_LIMIT, - _PENDING, - _RUNNING, - _CANCELLED, - _FINISHED, -) - - -class TestConstants: - """Test module constants""" - - def test_api_call_limit_constants(self): - assert _API_CALL_LIMIT["SubmittingIntervalInSecs"] == 1 - assert _API_CALL_LIMIT["MinBatchPollingIntervalInSecs"] == 10 - assert _API_CALL_LIMIT["PollingIntervalInSecs"] == 0.5 - - def test_future_state_constants(self): - assert _PENDING == "PENDING" - assert _RUNNING == "RUNNING" - assert _CANCELLED == "CANCELLED" - assert _FINISHED == "FINISHED" - - -class TestRemoteExecutorValidation: - """Test RemoteExecutor argument validation""" - - def test_validate_submit_args_with_valid_args(self): - def my_function(x, y, z=10): - return x + y + z - - RemoteExecutor._validate_submit_args(my_function, 1, 2, z=3) - - def test_validate_submit_args_with_missing_args(self): - def my_function(x, y): - return x + y - - with pytest.raises(TypeError): - RemoteExecutor._validate_submit_args(my_function, 1) - - def test_validate_submit_args_with_extra_args(self): - def my_function(x): - return x - - with pytest.raises(TypeError): - RemoteExecutor._validate_submit_args(my_function, 1, 2) - - -class TestWorkerFunctions: - """Test worker thread functions""" - - def test_submit_worker_exits_on_none(self): - """Test that submit worker exits when None is in queue""" - executor = Mock() - executor._pending_request_queue = deque([None]) - executor._running_jobs = {} - executor.max_parallel_jobs = 1 - - mock_condition = Mock() - mock_condition.__enter__ = Mock(return_value=mock_condition) - mock_condition.__exit__ = Mock(return_value=False) - mock_condition.wait_for = Mock(return_value=True) - executor._state_condition = mock_condition - - _submit_worker(executor) - - assert len(executor._pending_request_queue) == 0 - - def test_polling_worker_exits_on_shutdown(self): - """Test that polling worker exits when shutdown flag is set""" - executor = Mock() - executor._running_jobs = {} - 
executor._pending_request_queue = deque() - executor._shutdown = True - executor._state_condition = Mock() - - _polling_worker(executor) diff --git a/sagemaker-core/tests/unit/remote_function/test_custom_file_filter.py b/sagemaker-core/tests/unit/remote_function/test_custom_file_filter.py deleted file mode 100644 index 5145a77adf..0000000000 --- a/sagemaker-core/tests/unit/remote_function/test_custom_file_filter.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tests for custom_file_filter module.""" -from __future__ import absolute_import - -import os -import tempfile -import shutil -from unittest.mock import patch, MagicMock -import pytest - -from sagemaker.core.remote_function.custom_file_filter import ( - CustomFileFilter, - resolve_custom_file_filter_from_config_file, - copy_workdir, -) - - -class TestCustomFileFilter: - """Test CustomFileFilter class.""" - - def test_init_with_no_patterns(self): - """Test initialization without ignore patterns.""" - filter_obj = CustomFileFilter() - assert filter_obj.ignore_name_patterns == [] - assert filter_obj.workdir == os.getcwd() - - def test_init_with_patterns(self): - """Test initialization with ignore patterns.""" - patterns = ["*.pyc", "__pycache__", "*.log"] - filter_obj = CustomFileFilter(ignore_name_patterns=patterns) - assert filter_obj.ignore_name_patterns == patterns - - def test_ignore_name_patterns_property(self): - """Test ignore_name_patterns property.""" - patterns = ["*.txt", "temp*"] - filter_obj = CustomFileFilter(ignore_name_patterns=patterns) - assert filter_obj.ignore_name_patterns == patterns - - def test_workdir_property(self): - """Test workdir property.""" - filter_obj = CustomFileFilter() - assert filter_obj.workdir == os.getcwd() - - -class TestResolveCustomFileFilterFromConfigFile: - """Test resolve_custom_file_filter_from_config_file function.""" - - def test_returns_direct_input_when_provided_as_filter(self): - """Test returns direct input when CustomFileFilter is provided.""" - filter_obj = CustomFileFilter(ignore_name_patterns=["*.pyc"]) - result = resolve_custom_file_filter_from_config_file(direct_input=filter_obj) - assert result is filter_obj - - def test_returns_direct_input_when_provided_as_callable(self): - """Test returns direct input when callable is provided.""" - def custom_filter(path, names): - return [] - result = resolve_custom_file_filter_from_config_file(direct_input=custom_filter) - assert result is custom_filter - - @patch("sagemaker.core.remote_function.custom_file_filter.resolve_value_from_config") - def test_returns_none_when_no_config(self, mock_resolve): - """Test returns None when no config is found.""" - mock_resolve.return_value = None - result = resolve_custom_file_filter_from_config_file() - assert result is None - - @patch("sagemaker.core.remote_function.custom_file_filter.resolve_value_from_config") - def test_creates_filter_from_config(self, mock_resolve): - """Test creates CustomFileFilter from config.""" - patterns = ["*.pyc", 
"*.log"] - mock_resolve.return_value = patterns - result = resolve_custom_file_filter_from_config_file() - assert isinstance(result, CustomFileFilter) - assert result.ignore_name_patterns == patterns - - @patch("sagemaker.core.remote_function.custom_file_filter.resolve_value_from_config") - def test_passes_sagemaker_session_to_resolve(self, mock_resolve): - """Test passes sagemaker_session to resolve_value_from_config.""" - mock_session = MagicMock() - mock_resolve.return_value = None - resolve_custom_file_filter_from_config_file(sagemaker_session=mock_session) - mock_resolve.assert_called_once() - assert mock_resolve.call_args[1]["sagemaker_session"] == mock_session - - -class TestCopyWorkdir: - """Test copy_workdir function.""" - - def setup_method(self): - """Set up test fixtures.""" - self.temp_src = tempfile.mkdtemp() - self.temp_dst = tempfile.mkdtemp() - - # Create test files - with open(os.path.join(self.temp_src, "test.py"), "w") as f: - f.write("print('test')") - with open(os.path.join(self.temp_src, "test.txt"), "w") as f: - f.write("text file") - os.makedirs(os.path.join(self.temp_src, "__pycache__")) - with open(os.path.join(self.temp_src, "__pycache__", "test.pyc"), "w") as f: - f.write("compiled") - - def teardown_method(self): - """Clean up test fixtures.""" - if os.path.exists(self.temp_src): - shutil.rmtree(self.temp_src) - if os.path.exists(self.temp_dst): - shutil.rmtree(self.temp_dst) - - @patch("os.getcwd") - def test_copy_workdir_without_filter_only_python_files(self, mock_getcwd): - """Test copy_workdir without filter copies only Python files.""" - mock_getcwd.return_value = self.temp_src - dst = os.path.join(self.temp_dst, "output") - - copy_workdir(dst) - - assert os.path.exists(os.path.join(dst, "test.py")) - assert not os.path.exists(os.path.join(dst, "test.txt")) - assert not os.path.exists(os.path.join(dst, "__pycache__")) - - @patch("os.getcwd") - def test_copy_workdir_with_callable_filter(self, mock_getcwd): - """Test copy_workdir with callable filter.""" - mock_getcwd.return_value = self.temp_src - dst = os.path.join(self.temp_dst, "output") - - def custom_filter(path, names): - return ["test.txt"] - - copy_workdir(dst, custom_file_filter=custom_filter) - - assert os.path.exists(os.path.join(dst, "test.py")) - assert not os.path.exists(os.path.join(dst, "test.txt")) - - def test_copy_workdir_with_custom_file_filter_object(self): - """Test copy_workdir with CustomFileFilter object.""" - filter_obj = CustomFileFilter(ignore_name_patterns=["*.py"]) - filter_obj._workdir = self.temp_src - dst = os.path.join(self.temp_dst, "output") - - copy_workdir(dst, custom_file_filter=filter_obj) - - assert not os.path.exists(os.path.join(dst, "test.py")) - assert os.path.exists(os.path.join(dst, "test.txt")) - - def test_copy_workdir_with_pattern_matching(self): - """Test copy_workdir with pattern matching in CustomFileFilter.""" - filter_obj = CustomFileFilter(ignore_name_patterns=["*.txt", "__pycache__"]) - filter_obj._workdir = self.temp_src - dst = os.path.join(self.temp_dst, "output") - - copy_workdir(dst, custom_file_filter=filter_obj) - - assert os.path.exists(os.path.join(dst, "test.py")) - assert not os.path.exists(os.path.join(dst, "test.txt")) - assert not os.path.exists(os.path.join(dst, "__pycache__")) diff --git a/sagemaker-core/tests/unit/remote_function/test_invoke_function.py b/sagemaker-core/tests/unit/remote_function/test_invoke_function.py deleted file mode 100644 index 4810eba2e0..0000000000 --- 
a/sagemaker-core/tests/unit/remote_function/test_invoke_function.py +++ /dev/null @@ -1,280 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tests for invoke_function module.""" -from __future__ import absolute_import - -import json -import pytest -from unittest.mock import patch, MagicMock, call - -from sagemaker.core.remote_function.invoke_function import ( - _parse_args, - _get_sagemaker_session, - _load_run_object, - _load_pipeline_context, - _execute_remote_function, - main, - SUCCESS_EXIT_CODE, -) -from sagemaker.core.remote_function.job import KEY_EXPERIMENT_NAME, KEY_RUN_NAME - - -class TestParseArgs: - """Test _parse_args function.""" - - def test_parse_required_args(self): - """Test parsing required arguments.""" - args = [ - "--region", "us-west-2", - "--s3_base_uri", "s3://my-bucket/path", - ] - parsed = _parse_args(args) - assert parsed.region == "us-west-2" - assert parsed.s3_base_uri == "s3://my-bucket/path" - - def test_parse_all_args(self): - """Test parsing all arguments.""" - args = [ - "--region", "us-east-1", - "--s3_base_uri", "s3://bucket/path", - "--s3_kms_key", "key-123", - "--run_in_context", '{"experiment": "exp1"}', - "--pipeline_step_name", "step1", - "--pipeline_execution_id", "exec-123", - "--property_references", "prop1", "val1", "prop2", "val2", - "--serialize_output_to_json", "true", - "--func_step_s3_dir", "s3://bucket/func", - ] - parsed = _parse_args(args) - assert parsed.region == "us-east-1" - assert parsed.s3_base_uri == "s3://bucket/path" - assert parsed.s3_kms_key == "key-123" - assert parsed.run_in_context == '{"experiment": "exp1"}' - assert parsed.pipeline_step_name == "step1" - assert parsed.pipeline_execution_id == "exec-123" - assert parsed.property_references == ["prop1", "val1", "prop2", "val2"] - assert parsed.serialize_output_to_json is True - assert parsed.func_step_s3_dir == "s3://bucket/func" - - def test_parse_serialize_output_false(self): - """Test parsing serialize_output_to_json as false.""" - args = [ - "--region", "us-west-2", - "--s3_base_uri", "s3://bucket/path", - "--serialize_output_to_json", "false", - ] - parsed = _parse_args(args) - assert parsed.serialize_output_to_json is False - - def test_parse_default_values(self): - """Test default values for optional arguments.""" - args = [ - "--region", "us-west-2", - "--s3_base_uri", "s3://bucket/path", - ] - parsed = _parse_args(args) - assert parsed.s3_kms_key is None - assert parsed.run_in_context is None - assert parsed.pipeline_step_name is None - assert parsed.pipeline_execution_id is None - assert parsed.property_references == [] - assert parsed.serialize_output_to_json is False - assert parsed.func_step_s3_dir is None - - -class TestGetSagemakerSession: - """Test _get_sagemaker_session function.""" - - @patch("sagemaker.core.remote_function.invoke_function.boto3.session.Session") - @patch("sagemaker.core.remote_function.invoke_function.Session") - def test_creates_session_with_region(self, mock_session_class, mock_boto_session): - """Test creates 
SageMaker session with correct region.""" - mock_boto = MagicMock() - mock_boto_session.return_value = mock_boto - - _get_sagemaker_session("us-west-2") - - mock_boto_session.assert_called_once_with(region_name="us-west-2") - mock_session_class.assert_called_once_with(boto_session=mock_boto) - - -class TestLoadRunObject: - """Test _load_run_object function.""" - - @patch("sagemaker.core.experiments.run.Run") - def test_loads_run_from_json(self, mock_run_class): - """Test loads Run object from JSON string.""" - run_dict = { - KEY_EXPERIMENT_NAME: "my-experiment", - KEY_RUN_NAME: "my-run", - } - run_json = json.dumps(run_dict) - mock_session = MagicMock() - - _load_run_object(run_json, mock_session) - - mock_run_class.assert_called_once_with( - experiment_name="my-experiment", - run_name="my-run", - sagemaker_session=mock_session, - ) - - -class TestLoadPipelineContext: - """Test _load_pipeline_context function.""" - - def test_loads_context_with_all_fields(self): - """Test loads pipeline context with all fields.""" - args = MagicMock() - args.pipeline_step_name = "step1" - args.pipeline_execution_id = "exec-123" - args.property_references = ["prop1", "val1", "prop2", "val2"] - args.serialize_output_to_json = True - args.func_step_s3_dir = "s3://bucket/func" - - context = _load_pipeline_context(args) - - assert context.step_name == "step1" - assert context.execution_id == "exec-123" - assert context.property_references == {"prop1": "val1", "prop2": "val2"} - assert context.serialize_output_to_json is True - assert context.func_step_s3_dir == "s3://bucket/func" - - def test_loads_context_with_empty_property_references(self): - """Test loads pipeline context with empty property references.""" - args = MagicMock() - args.pipeline_step_name = "step1" - args.pipeline_execution_id = "exec-123" - args.property_references = [] - args.serialize_output_to_json = False - args.func_step_s3_dir = None - - context = _load_pipeline_context(args) - - assert context.property_references == {} - - -class TestExecuteRemoteFunction: - """Test _execute_remote_function function.""" - - @patch("sagemaker.core.remote_function.core.stored_function.StoredFunction") - def test_executes_without_run_context(self, mock_stored_function_class): - """Test executes stored function without run context.""" - mock_stored_func = MagicMock() - mock_stored_function_class.return_value = mock_stored_func - mock_session = MagicMock() - mock_context = MagicMock() - - _execute_remote_function( - sagemaker_session=mock_session, - s3_base_uri="s3://bucket/path", - s3_kms_key="key-123", - run_in_context=None, - context=mock_context, - ) - - mock_stored_function_class.assert_called_once_with( - sagemaker_session=mock_session, - s3_base_uri="s3://bucket/path", - s3_kms_key="key-123", - context=mock_context, - ) - mock_stored_func.load_and_invoke.assert_called_once() - - @patch("sagemaker.core.remote_function.invoke_function._load_run_object") - @patch("sagemaker.core.remote_function.core.stored_function.StoredFunction") - def test_executes_with_run_context(self, mock_stored_function_class, mock_load_run): - """Test executes stored function with run context.""" - mock_stored_func = MagicMock() - mock_stored_function_class.return_value = mock_stored_func - mock_run = MagicMock() - mock_load_run.return_value = mock_run - mock_session = MagicMock() - mock_context = MagicMock() - run_json = '{"experiment": "exp1"}' - - _execute_remote_function( - sagemaker_session=mock_session, - s3_base_uri="s3://bucket/path", - s3_kms_key=None, - 
run_in_context=run_json, - context=mock_context, - ) - - # Verify run object was loaded and used as context manager - mock_load_run.assert_called_once_with(run_json, mock_session) - mock_run.__enter__.assert_called_once() - mock_run.__exit__.assert_called_once() - - -class TestMain: - """Test main function.""" - - @patch("sagemaker.core.remote_function.invoke_function._execute_remote_function") - @patch("sagemaker.core.remote_function.invoke_function._get_sagemaker_session") - @patch("sagemaker.core.remote_function.invoke_function._load_pipeline_context") - @patch("sagemaker.core.remote_function.invoke_function._parse_args") - def test_main_success(self, mock_parse, mock_load_context, mock_get_session, mock_execute): - """Test main function successful execution.""" - mock_args = MagicMock() - mock_args.region = "us-west-2" - mock_args.s3_base_uri = "s3://bucket/path" - mock_args.s3_kms_key = None - mock_args.run_in_context = None - mock_parse.return_value = mock_args - - mock_context = MagicMock() - mock_context.step_name = None - mock_load_context.return_value = mock_context - - mock_session = MagicMock() - mock_get_session.return_value = mock_session - - with pytest.raises(SystemExit) as exc_info: - main(["--region", "us-west-2", "--s3_base_uri", "s3://bucket/path"]) - - assert exc_info.value.code == SUCCESS_EXIT_CODE - mock_execute.assert_called_once() - - @patch("sagemaker.core.remote_function.invoke_function.handle_error") - @patch("sagemaker.core.remote_function.invoke_function._execute_remote_function") - @patch("sagemaker.core.remote_function.invoke_function._get_sagemaker_session") - @patch("sagemaker.core.remote_function.invoke_function._load_pipeline_context") - @patch("sagemaker.core.remote_function.invoke_function._parse_args") - def test_main_handles_exception( - self, mock_parse, mock_load_context, mock_get_session, mock_execute, mock_handle_error - ): - """Test main function handles exceptions.""" - mock_args = MagicMock() - mock_args.region = "us-west-2" - mock_args.s3_base_uri = "s3://bucket/path" - mock_args.s3_kms_key = None - mock_args.run_in_context = None - mock_parse.return_value = mock_args - - mock_context = MagicMock() - mock_context.step_name = None - mock_load_context.return_value = mock_context - - mock_session = MagicMock() - mock_get_session.return_value = mock_session - - test_exception = Exception("Test error") - mock_execute.side_effect = test_exception - mock_handle_error.return_value = 1 - - with pytest.raises(SystemExit) as exc_info: - main(["--region", "us-west-2", "--s3_base_uri", "s3://bucket/path"]) - - assert exc_info.value.code == 1 - mock_handle_error.assert_called_once() diff --git a/sagemaker-core/tests/unit/remote_function/test_job.py b/sagemaker-core/tests/unit/remote_function/test_job.py deleted file mode 100644 index 6f10016643..0000000000 --- a/sagemaker-core/tests/unit/remote_function/test_job.py +++ /dev/null @@ -1,932 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-"""Unit tests for sagemaker.core.remote_function.job module.""" -from __future__ import absolute_import - -import json -import os -import pytest -import sys -from unittest.mock import Mock, patch, MagicMock, call, mock_open -from io import BytesIO - -from sagemaker.core.remote_function.job import ( - _JobSettings, - _Job, - _prepare_and_upload_runtime_scripts, - _generate_input_data_config, - _prepare_dependencies_and_pre_execution_scripts, - _prepare_and_upload_workspace, - _convert_run_to_json, - _prepare_and_upload_spark_dependent_files, - _upload_spark_submit_deps, - _upload_serialized_spark_configuration, - _extend_mpirun_to_request, - _extend_torchrun_to_request, - _extend_spark_config_to_request, - _update_job_request_with_checkpoint_config, - _RunInfo, - _get_initial_job_state, - _logs_for_job, - _check_job_status, - _flush_log_streams, - _rule_statuses_changed, - _logs_init, - LogState, -) -from sagemaker.core.remote_function.spark_config import SparkConfig -from sagemaker.core.remote_function.checkpoint_location import CheckpointLocation - - -@pytest.fixture -def mock_session(): - session = Mock() - session.boto_region_name = "us-west-2" - session.default_bucket.return_value = "test-bucket" - session.default_bucket_prefix = "prefix" - session.sagemaker_client = Mock() - session.boto_session = Mock() - session.sagemaker_config = None - return session - - -class TestJobSettings: - """Test _JobSettings class.""" - - def test_init_with_spark_and_image_raises_error(self, mock_session): - """Test that spark_config and image_uri cannot be set together.""" - spark_config = SparkConfig() - with pytest.raises(ValueError, match="spark_config and image_uri cannot be specified"): - _JobSettings( - sagemaker_session=mock_session, - spark_config=spark_config, - image_uri="test-image", - instance_type="ml.m5.xlarge", - ) - - def test_init_with_spark_and_conda_env_raises_error(self, mock_session): - """Test that spark_config and job_conda_env cannot be set together.""" - spark_config = SparkConfig() - with pytest.raises(ValueError, match="Remote Spark jobs do not support job_conda_env"): - _JobSettings( - sagemaker_session=mock_session, - spark_config=spark_config, - job_conda_env="test-env", - instance_type="ml.m5.xlarge", - ) - - def test_init_with_spark_and_auto_capture_raises_error(self, mock_session): - """Test that spark_config and auto_capture dependencies cannot be set together.""" - spark_config = SparkConfig() - with pytest.raises(ValueError, match="Remote Spark jobs do not support automatically"): - _JobSettings( - sagemaker_session=mock_session, - spark_config=spark_config, - dependencies="auto_capture", - instance_type="ml.m5.xlarge", - ) - - def test_init_with_pre_execution_commands_and_script_raises_error(self, mock_session): - """Test that pre_execution_commands and pre_execution_script cannot be set together.""" - with pytest.raises( - ValueError, match="Only one of pre_execution_commands or pre_execution_script" - ): - _JobSettings( - sagemaker_session=mock_session, - pre_execution_commands=["echo test"], - pre_execution_script="/path/to/script.sh", - instance_type="ml.m5.xlarge", - image_uri="test-image", - ) - - def test_init_without_instance_type_raises_error(self, mock_session): - """Test that instance_type is required.""" - with pytest.raises(ValueError, match="instance_type is a required parameter"): - _JobSettings(sagemaker_session=mock_session, image_uri="test-image") - - @patch.dict(os.environ, {"SAGEMAKER_INTERNAL_IMAGE_URI": "custom-image"}) - def 
test_get_default_image_from_env(self, mock_session): - """Test getting default image from environment variable.""" - image = _JobSettings._get_default_image(mock_session) - assert image == "custom-image" - - def test_get_default_image_unsupported_python_raises_error(self, mock_session): - """Test that unsupported Python version raises error.""" - with patch.object(sys, "version_info", (3, 7, 0)): - with pytest.raises( - ValueError, match="Default image is supported only for Python versions" - ): - _JobSettings._get_default_image(mock_session) - - def test_get_default_spark_image_unsupported_python_raises_error(self, mock_session): - """Test that unsupported Python version for Spark raises error.""" - with patch.object(sys, "version_info", (3, 8, 0)): - with pytest.raises( - ValueError, - match="SageMaker Spark image for remote job only supports Python version 3.9", - ): - _JobSettings._get_default_spark_image(mock_session) - - -class TestJob: - """Test _Job class.""" - - def test_init(self, mock_session): - """Test _Job initialization.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - assert job.job_name == "test-job" - assert job.s3_uri == "s3://bucket/output" - - def test_from_describe_response(self, mock_session): - """Test creating _Job from describe response.""" - response = { - "TrainingJobName": "test-job", - "OutputDataConfig": {"S3OutputPath": "s3://bucket/output"}, - } - job = _Job.from_describe_response(response, mock_session) - assert job.job_name == "test-job" - assert job.s3_uri == "s3://bucket/output" - - def test_describe_returns_cached_response(self, mock_session): - """Test that describe returns cached response for completed jobs.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - job._last_describe_response = {"TrainingJobStatus": "Completed"} - - result = job.describe() - assert result["TrainingJobStatus"] == "Completed" - mock_session.sagemaker_client.describe_training_job.assert_not_called() - - def test_describe_calls_api_for_in_progress_jobs(self, mock_session): - """Test that describe calls API for in-progress jobs.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - mock_session.sagemaker_client.describe_training_job.return_value = { - "TrainingJobStatus": "InProgress" - } - - result = job.describe() - assert result["TrainingJobStatus"] == "InProgress" - mock_session.sagemaker_client.describe_training_job.assert_called_once() - - def test_stop(self, mock_session): - """Test stopping a job.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - job.stop() - mock_session.sagemaker_client.stop_training_job.assert_called_once_with( - TrainingJobName="test-job" - ) - - @patch("sagemaker.core.remote_function.job._logs_for_job") - def test_wait(self, mock_logs, mock_session): - """Test waiting for job completion.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - mock_logs.return_value = {"TrainingJobStatus": "Completed"} - - job.wait(timeout=100) - mock_logs.assert_called_once_with( - sagemaker_session=mock_session, job_name="test-job", wait=True, timeout=100 - ) - - -class TestUpdateJobRequestWithCheckpointConfig: - """Test _update_job_request_with_checkpoint_config function.""" - - def test_with_checkpoint_in_args(self): - """Test checkpoint config in positional args.""" - checkpoint = CheckpointLocation(s3_uri="s3://bucket/checkpoint") - args = (checkpoint,) - kwargs = {} - request_dict = {} - - _update_job_request_with_checkpoint_config(args, kwargs, request_dict) - - assert "CheckpointConfig" in 
request_dict - assert request_dict["CheckpointConfig"]["S3Uri"] == "s3://bucket/checkpoint" - assert request_dict["CheckpointConfig"]["LocalPath"] == "/opt/ml/checkpoints/" - - def test_with_checkpoint_in_kwargs(self): - """Test checkpoint config in keyword args.""" - checkpoint = CheckpointLocation(s3_uri="s3://bucket/checkpoint") - args = () - kwargs = {"checkpoint": checkpoint} - request_dict = {} - - _update_job_request_with_checkpoint_config(args, kwargs, request_dict) - - assert "CheckpointConfig" in request_dict - - def test_with_multiple_checkpoints_raises_error(self): - """Test that multiple checkpoints raise error.""" - checkpoint1 = CheckpointLocation(s3_uri="s3://bucket/checkpoint1") - checkpoint2 = CheckpointLocation(s3_uri="s3://bucket/checkpoint2") - args = (checkpoint1,) - kwargs = {"checkpoint": checkpoint2} - request_dict = {} - - with pytest.raises( - ValueError, match="cannot have more than one argument of type CheckpointLocation" - ): - _update_job_request_with_checkpoint_config(args, kwargs, request_dict) - - def test_without_checkpoint(self): - """Test without checkpoint location.""" - args = ("arg1", "arg2") - kwargs = {"key": "value"} - request_dict = {} - - _update_job_request_with_checkpoint_config(args, kwargs, request_dict) - - assert "CheckpointConfig" not in request_dict - - -class TestConvertRunToJson: - """Test _convert_run_to_json function.""" - - def test_convert_run_to_json(self): - """Test converting run to JSON.""" - mock_run = Mock() - mock_run.experiment_name = "test-experiment" - mock_run.run_name = "test-run" - - result = _convert_run_to_json(mock_run) - data = json.loads(result) - - assert data["experiment_name"] == "test-experiment" - assert data["run_name"] == "test-run" - - -class TestUploadSerializedSparkConfiguration: - """Test _upload_serialized_spark_configuration function.""" - - @patch("sagemaker.core.remote_function.job.S3Uploader") - def test_upload_spark_config(self, mock_uploader, mock_session): - """Test uploading Spark configuration.""" - config = {"spark.executor.memory": "4g"} - - _upload_serialized_spark_configuration("s3://bucket/base", "kms-key", config, mock_session) - - mock_uploader.upload_string_as_file_body.assert_called_once() - - def test_upload_spark_config_none(self, mock_session): - """Test uploading None Spark configuration.""" - result = _upload_serialized_spark_configuration( - "s3://bucket/base", "kms-key", None, mock_session - ) - - assert result is None - - -class TestUploadSparkSubmitDeps: - """Test _upload_spark_submit_deps function.""" - - def test_with_none_deps(self, mock_session): - """Test with None dependencies.""" - result = _upload_spark_submit_deps( - None, "workspace", "s3://bucket", "kms-key", mock_session - ) - assert result is None - - def test_with_s3_uri(self, mock_session): - """Test with S3 URI.""" - deps = ["s3://bucket/dep.jar"] - result = _upload_spark_submit_deps( - deps, "workspace", "s3://bucket", "kms-key", mock_session - ) - assert "s3://bucket/dep.jar" in result - - def test_with_empty_workspace_raises_error(self, mock_session): - """Test with empty workspace name.""" - deps = ["s3://bucket/dep.jar"] - with pytest.raises(ValueError, match="workspace_name or s3_base_uri may not be empty"): - _upload_spark_submit_deps(deps, "", "s3://bucket", "kms-key", mock_session) - - @patch("os.path.isfile", return_value=False) - def test_with_invalid_local_file_raises_error(self, mock_isfile, mock_session): - """Test with invalid local file.""" - deps = ["/invalid/path.jar"] - with 
pytest.raises(ValueError, match="is not a valid local file"): - _upload_spark_submit_deps(deps, "workspace", "s3://bucket", "kms-key", mock_session) - - -class TestExtendMpirunToRequest: - """Test _extend_mpirun_to_request function.""" - - def test_without_mpirun(self, mock_session): - """Test without mpirun enabled.""" - job_settings = Mock() - job_settings.use_mpirun = False - request_dict = {"InputDataConfig": []} - - result = _extend_mpirun_to_request(request_dict, job_settings) - assert result == request_dict - - def test_with_single_instance(self, mock_session): - """Test with single instance.""" - job_settings = Mock() - job_settings.use_mpirun = True - job_settings.instance_count = 1 - request_dict = {"InputDataConfig": []} - - result = _extend_mpirun_to_request(request_dict, job_settings) - assert result == request_dict - - def test_with_multiple_instances(self, mock_session): - """Test with multiple instances.""" - job_settings = Mock() - job_settings.use_mpirun = True - job_settings.instance_count = 2 - request_dict = { - "InputDataConfig": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/data"}}}] - } - - result = _extend_mpirun_to_request(request_dict, job_settings) - assert ( - result["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3DataDistributionType"] - == "FullyReplicated" - ) - - -class TestExtendTorchrunToRequest: - """Test _extend_torchrun_to_request function.""" - - def test_without_torchrun(self, mock_session): - """Test without torchrun enabled.""" - job_settings = Mock() - job_settings.use_torchrun = False - request_dict = {"InputDataConfig": []} - - result = _extend_torchrun_to_request(request_dict, job_settings) - assert result == request_dict - - def test_with_single_instance(self, mock_session): - """Test with single instance.""" - job_settings = Mock() - job_settings.use_torchrun = True - job_settings.instance_count = 1 - request_dict = {"InputDataConfig": []} - - result = _extend_torchrun_to_request(request_dict, job_settings) - assert result == request_dict - - def test_with_multiple_instances(self, mock_session): - """Test with multiple instances.""" - job_settings = Mock() - job_settings.use_torchrun = True - job_settings.instance_count = 2 - request_dict = { - "InputDataConfig": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/data"}}}] - } - - result = _extend_torchrun_to_request(request_dict, job_settings) - assert ( - result["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3DataDistributionType"] - == "FullyReplicated" - ) - - -class TestExtendSparkConfigToRequest: - """Test _extend_spark_config_to_request function.""" - - def test_without_spark_config(self, mock_session): - """Test without spark config.""" - job_settings = Mock() - job_settings.spark_config = None - request_dict = {"AlgorithmSpecification": {"ContainerEntrypoint": []}} - - result = _extend_spark_config_to_request(request_dict, job_settings, "s3://bucket") - assert result == request_dict - - @patch("sagemaker.core.remote_function.job._prepare_and_upload_spark_dependent_files") - def test_with_spark_config(self, mock_upload, mock_session): - """Test with spark config.""" - mock_upload.return_value = (None, None, None, "s3://bucket/config.json") - - job_settings = Mock() - spark_config = SparkConfig(spark_event_logs_uri="s3://bucket/logs") - job_settings.spark_config = spark_config - job_settings.s3_kms_key = None - job_settings.sagemaker_session = mock_session - - request_dict = { - "AlgorithmSpecification": {"ContainerEntrypoint": []}, - "InputDataConfig": 
[{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/data"}}}], - } - - result = _extend_spark_config_to_request(request_dict, job_settings, "s3://bucket") - assert ( - "--spark-event-logs-s3-uri" in result["AlgorithmSpecification"]["ContainerEntrypoint"] - ) - - -class TestGetInitialJobState: - """Test _get_initial_job_state function.""" - - def test_with_completed_job_and_wait(self): - """Test with completed job and wait=True.""" - description = {"TrainingJobStatus": "Completed"} - state = _get_initial_job_state(description, "TrainingJobStatus", True) - assert state == LogState.COMPLETE - - def test_with_in_progress_job_and_wait(self): - """Test with in-progress job and wait=True.""" - description = {"TrainingJobStatus": "InProgress"} - state = _get_initial_job_state(description, "TrainingJobStatus", True) - assert state == LogState.TAILING - - def test_with_in_progress_job_and_no_wait(self): - """Test with in-progress job and wait=False.""" - description = {"TrainingJobStatus": "InProgress"} - state = _get_initial_job_state(description, "TrainingJobStatus", False) - assert state == LogState.COMPLETE - - -class TestCheckJobStatus: - """Test _check_job_status function.""" - - def test_with_completed_status(self): - """Test with completed status.""" - desc = {"TrainingJobStatus": "Completed"} - _check_job_status("test-job", desc, "TrainingJobStatus") - - def test_with_stopped_status(self): - """Test with stopped status.""" - desc = {"TrainingJobStatus": "Stopped"} - with patch("sagemaker.core.remote_function.job.logger") as mock_logger: - _check_job_status("test-job", desc, "TrainingJobStatus") - mock_logger.warning.assert_called_once() - - def test_with_failed_status_raises_error(self): - """Test with failed status.""" - desc = {"TrainingJobStatus": "Failed", "FailureReason": "Test failure"} - with pytest.raises(Exception): - _check_job_status("test-job", desc, "TrainingJobStatus") - - def test_with_capacity_error_raises_capacity_error(self): - """Test with CapacityError.""" - desc = { - "TrainingJobStatus": "Failed", - "FailureReason": "CapacityError: Insufficient capacity", - } - from sagemaker.core import exceptions - - with pytest.raises(exceptions.CapacityError): - _check_job_status("test-job", desc, "TrainingJobStatus") - - -class TestRuleStatusesChanged: - """Test _rule_statuses_changed function.""" - - def test_with_no_last_statuses(self): - """Test with no last statuses.""" - current = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "InProgress"}] - result = _rule_statuses_changed(current, None) - assert result is True - - def test_with_changed_status(self): - """Test with changed status.""" - current = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "Completed"}] - last = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "InProgress"}] - result = _rule_statuses_changed(current, last) - assert result is True - - def test_with_unchanged_status(self): - """Test with unchanged status.""" - current = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "InProgress"}] - last = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "InProgress"}] - result = _rule_statuses_changed(current, last) - assert result is False - - -class TestLogsInit: - """Test _logs_init function.""" - - def test_with_training_job(self, mock_session): - """Test with training job.""" - description = {"ResourceConfig": {"InstanceCount": 2}} - result = _logs_init(mock_session.boto_session, description, "Training") - instance_count, stream_names, positions, client, 
log_group, dot, color_wrap = result - assert instance_count == 2 - assert log_group == "/aws/sagemaker/TrainingJobs" - - def test_with_training_job_instance_groups(self, mock_session): - """Test with training job using instance groups.""" - description = { - "ResourceConfig": {"InstanceGroups": [{"InstanceCount": 2}, {"InstanceCount": 3}]} - } - result = _logs_init(mock_session.boto_session, description, "Training") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 5 - - def test_with_transform_job(self, mock_session): - """Test with transform job.""" - description = {"TransformResources": {"InstanceCount": 1}} - result = _logs_init(mock_session.boto_session, description, "Transform") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 1 - assert log_group == "/aws/sagemaker/TransformJobs" - - def test_with_processing_job(self, mock_session): - """Test with processing job.""" - description = {"ProcessingResources": {"ClusterConfig": {"InstanceCount": 3}}} - result = _logs_init(mock_session.boto_session, description, "Processing") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 3 - assert log_group == "/aws/sagemaker/ProcessingJobs" - - def test_with_automl_job(self, mock_session): - """Test with AutoML job.""" - description = {} - result = _logs_init(mock_session.boto_session, description, "AutoML") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 0 - assert log_group == "/aws/sagemaker/AutoMLJobs" - - -class TestFlushLogStreams: - """Test _flush_log_streams function.""" - - @patch("sagemaker.core.remote_function.job.sagemaker_logs") - def test_with_no_streams(self, mock_logs, mock_session): - """Test with no log streams.""" - stream_names = [] - positions = {} - client = Mock() - client.describe_log_streams.return_value = {"logStreams": []} - - _flush_log_streams( - stream_names, - 1, - client, - "/aws/sagemaker/TrainingJobs", - "test-job", - positions, - False, - lambda x, y: None, - ) - - @patch("sagemaker.core.remote_function.job.sagemaker_logs") - def test_with_client_error_resource_not_found(self, mock_logs, mock_session): - """Test with ResourceNotFoundException.""" - from botocore.exceptions import ClientError - - stream_names = [] - positions = {} - client = Mock() - error_response = {"Error": {"Code": "ResourceNotFoundException"}} - client.describe_log_streams.side_effect = ClientError( - error_response, "describe_log_streams" - ) - - _flush_log_streams( - stream_names, - 1, - client, - "/aws/sagemaker/TrainingJobs", - "test-job", - positions, - False, - lambda x, y: None, - ) - - @patch("sagemaker.core.remote_function.job.sagemaker_logs") - def test_with_client_error_other(self, mock_logs, mock_session): - """Test with other ClientError.""" - from botocore.exceptions import ClientError - - stream_names = [] - positions = {} - client = Mock() - error_response = {"Error": {"Code": "OtherError"}} - client.describe_log_streams.side_effect = ClientError( - error_response, "describe_log_streams" - ) - - with pytest.raises(ClientError): - _flush_log_streams( - stream_names, - 1, - client, - "/aws/sagemaker/TrainingJobs", - "test-job", - positions, - False, - lambda x, y: None, - ) - - -class TestPrepareAndUploadRuntimeScripts: - """Test _prepare_and_upload_runtime_scripts function.""" - - 
@patch("sagemaker.core.remote_function.job.S3Uploader") - @patch("sagemaker.core.remote_function.job._tmpdir") - @patch("sagemaker.core.remote_function.job.shutil") - @patch("builtins.open", new_callable=mock_open) - def test_without_spark_or_distributed( - self, mock_file, mock_shutil, mock_tmpdir, mock_uploader, mock_session - ): - """Test without Spark or distributed training.""" - mock_tmpdir.return_value.__enter__ = Mock(return_value="/tmp/test") - mock_tmpdir.return_value.__exit__ = Mock(return_value=False) - mock_uploader.upload.return_value = "s3://bucket/scripts" - - result = _prepare_and_upload_runtime_scripts( - None, "s3://bucket", "kms-key", mock_session, False, False - ) - - assert result == "s3://bucket/scripts" - - @patch("sagemaker.core.remote_function.job.S3Uploader") - @patch("sagemaker.core.remote_function.job._tmpdir") - @patch("sagemaker.core.remote_function.job.shutil") - @patch("builtins.open", new_callable=mock_open) - def test_with_spark(self, mock_file, mock_shutil, mock_tmpdir, mock_uploader, mock_session): - """Test with Spark config.""" - mock_tmpdir.return_value.__enter__ = Mock(return_value="/tmp/test") - mock_tmpdir.return_value.__exit__ = Mock(return_value=False) - mock_uploader.upload.return_value = "s3://bucket/scripts" - - spark_config = SparkConfig() - result = _prepare_and_upload_runtime_scripts( - spark_config, "s3://bucket", "kms-key", mock_session, False, False - ) - - assert result == "s3://bucket/scripts" - - @patch("sagemaker.core.remote_function.job.S3Uploader") - @patch("sagemaker.core.remote_function.job._tmpdir") - @patch("sagemaker.core.remote_function.job.shutil") - @patch("builtins.open", new_callable=mock_open) - def test_with_torchrun(self, mock_file, mock_shutil, mock_tmpdir, mock_uploader, mock_session): - """Test with torchrun.""" - mock_tmpdir.return_value.__enter__ = Mock(return_value="/tmp/test") - mock_tmpdir.return_value.__exit__ = Mock(return_value=False) - mock_uploader.upload.return_value = "s3://bucket/scripts" - - result = _prepare_and_upload_runtime_scripts( - None, "s3://bucket", "kms-key", mock_session, True, False - ) - - assert result == "s3://bucket/scripts" - - @patch("sagemaker.core.remote_function.job.S3Uploader") - @patch("sagemaker.core.remote_function.job._tmpdir") - @patch("sagemaker.core.remote_function.job.shutil") - @patch("builtins.open", new_callable=mock_open) - def test_with_mpirun(self, mock_file, mock_shutil, mock_tmpdir, mock_uploader, mock_session): - """Test with mpirun.""" - mock_tmpdir.return_value.__enter__ = Mock(return_value="/tmp/test") - mock_tmpdir.return_value.__exit__ = Mock(return_value=False) - mock_uploader.upload.return_value = "s3://bucket/scripts" - - result = _prepare_and_upload_runtime_scripts( - None, "s3://bucket", "kms-key", mock_session, False, True - ) - - assert result == "s3://bucket/scripts" - - -class TestPrepareAndUploadWorkspace: - """Test _prepare_and_upload_workspace function.""" - - def test_without_dependencies_or_workdir(self, mock_session): - """Test without dependencies or workdir.""" - result = _prepare_and_upload_workspace( - None, False, None, None, "s3://bucket", "kms-key", mock_session, None - ) - assert result is None - - @patch("sagemaker.core.remote_function.job.S3Uploader") - @patch("sagemaker.core.remote_function.job._tmpdir") - @patch("sagemaker.core.remote_function.job.shutil") - @patch("sagemaker.core.remote_function.job.copy_workdir") - @patch("os.mkdir") - @patch("os.path.isdir", return_value=False) - def test_with_workdir( - self, - mock_isdir, - 
mock_mkdir, - mock_copy, - mock_shutil, - mock_tmpdir, - mock_uploader, - mock_session, - ): - """Test with workdir.""" - mock_tmpdir.return_value.__enter__ = Mock(return_value="/tmp/test") - mock_tmpdir.return_value.__exit__ = Mock(return_value=False) - mock_shutil.make_archive.return_value = "/tmp/test/workspace.zip" - mock_uploader.upload.return_value = "s3://bucket/workspace.zip" - - result = _prepare_and_upload_workspace( - None, True, None, None, "s3://bucket", "kms-key", mock_session, None - ) - - assert result == "s3://bucket/workspace.zip" - - -class TestPrepareDependenciesAndPreExecutionScripts: - """Test _prepare_dependencies_and_pre_execution_scripts function.""" - - def test_without_dependencies_or_scripts(self, mock_session): - """Test without dependencies or scripts.""" - result = _prepare_dependencies_and_pre_execution_scripts( - None, None, None, "s3://bucket", "kms-key", mock_session, "/tmp" - ) - assert result is None - - @patch("sagemaker.core.workflow.utilities.load_step_compilation_context") - @patch("sagemaker.core.remote_function.job.shutil") - @patch("sagemaker.core.remote_function.job.S3Uploader") - def test_with_dependencies(self, mock_uploader, mock_shutil, mock_context, mock_session): - """Test with dependencies file.""" - mock_shutil.copy2.return_value = "/tmp/requirements.txt" - mock_uploader.upload.return_value = "s3://bucket/deps" - mock_context.return_value = Mock(step_name="step", pipeline_build_time="123") - - result = _prepare_dependencies_and_pre_execution_scripts( - "/path/to/requirements.txt", None, None, "s3://bucket", "kms-key", mock_session, "/tmp" - ) - - assert result == "s3://bucket/deps" - - @patch("sagemaker.core.workflow.utilities.load_step_compilation_context") - @patch("builtins.open", create=True) - @patch("sagemaker.core.remote_function.job.S3Uploader") - def test_with_pre_execution_commands( - self, mock_uploader, mock_open, mock_context, mock_session - ): - """Test with pre-execution commands.""" - mock_uploader.upload.return_value = "s3://bucket/scripts" - mock_context.return_value = Mock(step_name="step", pipeline_build_time="123") - - result = _prepare_dependencies_and_pre_execution_scripts( - None, ["echo test"], None, "s3://bucket", "kms-key", mock_session, "/tmp" - ) - - assert result == "s3://bucket/scripts" - - @patch("sagemaker.core.workflow.utilities.load_step_compilation_context") - @patch("sagemaker.core.remote_function.job.shutil") - @patch("sagemaker.core.remote_function.job.S3Uploader") - def test_with_pre_execution_script( - self, mock_uploader, mock_shutil, mock_context, mock_session - ): - """Test with pre-execution script.""" - mock_shutil.copy2.return_value = "/tmp/pre_exec.sh" - mock_uploader.upload.return_value = "s3://bucket/scripts" - mock_context.return_value = Mock(step_name="step", pipeline_build_time="123") - - result = _prepare_dependencies_and_pre_execution_scripts( - None, None, "/path/to/script.sh", "s3://bucket", "kms-key", mock_session, "/tmp" - ) - - assert result == "s3://bucket/scripts" - - -class TestPrepareAndUploadSparkDependentFiles: - """Test _prepare_and_upload_spark_dependent_files function.""" - - def test_without_spark_config(self, mock_session): - """Test without Spark config.""" - result = _prepare_and_upload_spark_dependent_files( - None, "s3://bucket", "kms-key", mock_session - ) - assert result == (None, None, None, None) - - @patch("sagemaker.core.remote_function.job._upload_spark_submit_deps") - @patch("sagemaker.core.remote_function.job._upload_serialized_spark_configuration") - def 
test_with_spark_config(self, mock_upload_config, mock_upload_deps, mock_session): - """Test with Spark config.""" - mock_upload_deps.return_value = "s3://bucket/deps" - mock_upload_config.return_value = "s3://bucket/config.json" - - spark_config = SparkConfig( - submit_jars=["test.jar"], - submit_py_files=["test.py"], - submit_files=["test.txt"], - configuration={"Classification": "spark-defaults", "Properties": {"key": "value"}}, - ) - - result = _prepare_and_upload_spark_dependent_files( - spark_config, "s3://bucket", "kms-key", mock_session - ) - - assert len(result) == 4 - - -class TestJobCompile: - """Test _Job.compile method.""" - - @patch("sagemaker.core.remote_function.job.StoredFunction") - @patch("sagemaker.core.remote_function.job._generate_input_data_config") - def test_compile_basic(self, mock_input_config, mock_stored_func, mock_session): - """Test basic compile.""" - mock_input_config.return_value = [] - mock_stored_func.return_value.save = Mock() - - job_settings = Mock() - job_settings.max_runtime_in_seconds = 3600 - job_settings.max_wait_time_in_seconds = None - job_settings.max_retry_attempts = 1 - job_settings.role = "arn:aws:iam::123456789012:role/test" - job_settings.tags = None - job_settings.s3_kms_key = None - job_settings.disable_output_compression = False - job_settings.volume_size = 30 - job_settings.instance_count = 1 - job_settings.instance_type = "ml.m5.xlarge" - job_settings.volume_kms_key = None - job_settings.keep_alive_period_in_seconds = None - job_settings.enable_network_isolation = False - job_settings.encrypt_inter_container_traffic = False - job_settings.vpc_config = None - job_settings.use_spot_instances = False - job_settings.environment_variables = {} - job_settings.image_uri = "test-image" - job_settings.sagemaker_session = mock_session - job_settings.use_torchrun = False - job_settings.use_mpirun = False - job_settings.nproc_per_node = None - job_settings.job_conda_env = None - job_settings.spark_config = None - job_settings.dependencies = None - - def test_func(): - pass - - result = _Job.compile(job_settings, "test-job", "s3://bucket", test_func, (), {}) - - assert result["TrainingJobName"] == "test-job" - assert result["RoleArn"] == "arn:aws:iam::123456789012:role/test" - - -class TestJobStart: - """Test _Job.start method.""" - - @patch("sagemaker.core.remote_function.job._Job.compile") - @patch("sagemaker.core.remote_function.job._Job._get_job_name") - def test_start(self, mock_get_name, mock_compile, mock_session): - """Test starting a job.""" - mock_get_name.return_value = "test-job" - mock_compile.return_value = { - "TrainingJobName": "test-job", - "Environment": {}, - } - - job_settings = Mock() - job_settings.s3_root_uri = "s3://bucket" - job_settings.sagemaker_session = mock_session - - def test_func(): - pass - - job = _Job.start(job_settings, test_func, (), {}) - - assert job.job_name == "test-job" - mock_session.sagemaker_client.create_training_job.assert_called_once() - - -class TestJobGetJobName: - """Test _Job._get_job_name method.""" - - def test_with_job_name_prefix(self, mock_session): - """Test with job_name_prefix.""" - job_settings = Mock() - job_settings.job_name_prefix = "my-job" - - def test_func(): - pass - - result = _Job._get_job_name(job_settings, test_func) - assert "my-job" in result - - def test_without_job_name_prefix(self, mock_session): - """Test without job_name_prefix.""" - job_settings = Mock() - job_settings.job_name_prefix = None - - def test_func(): - pass - - result = _Job._get_job_name(job_settings, 
test_func) - assert "test-func" in result - - def test_with_special_characters_in_func_name(self, mock_session): - """Test with special characters in function name.""" - job_settings = Mock() - job_settings.job_name_prefix = None - - def _test_func(): - pass - - result = _Job._get_job_name(job_settings, _test_func) - assert result.startswith("test-func") diff --git a/sagemaker-core/tests/unit/remote_function/test_job_comprehensive.py b/sagemaker-core/tests/unit/remote_function/test_job_comprehensive.py deleted file mode 100644 index bc8d5a8e56..0000000000 --- a/sagemaker-core/tests/unit/remote_function/test_job_comprehensive.py +++ /dev/null @@ -1,533 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Comprehensive unit tests for uncovered lines in sagemaker.core.remote_function.job module.""" -from __future__ import absolute_import - -import json -import os -import pytest -import sys -import tempfile -from unittest.mock import Mock, patch, MagicMock, mock_open -from io import BytesIO - -from sagemaker.core.remote_function.job import ( - _JobSettings, - _Job, - _update_job_request_with_checkpoint_config, - _convert_run_to_json, - _upload_spark_submit_deps, - _upload_serialized_spark_configuration, - _extend_mpirun_to_request, - _extend_torchrun_to_request, - _check_job_status, - _rule_statuses_changed, - _logs_init, - _get_initial_job_state, - LogState, - _RunInfo, -) -from sagemaker.core.remote_function.checkpoint_location import CheckpointLocation - - -@pytest.fixture -def mock_session(): - session = Mock() - session.boto_region_name = "us-west-2" - session.default_bucket.return_value = "test-bucket" - session.default_bucket_prefix = "prefix" - session.sagemaker_client = Mock() - session.boto_session = Mock() - session.sagemaker_config = {} - return session - - -class TestJobSettingsValidation: - """Test _JobSettings validation logic for uncovered lines.""" - - def test_spark_config_with_image_uri_raises_error(self, mock_session): - """Test lines 619-620: spark_config and image_uri validation.""" - from sagemaker.core.remote_function.spark_config import SparkConfig - - spark_config = SparkConfig() - with pytest.raises(ValueError, match="spark_config and image_uri cannot be specified"): - _JobSettings( - sagemaker_session=mock_session, - spark_config=spark_config, - image_uri="test-image", - instance_type="ml.m5.xlarge", - ) - - def test_spark_config_with_conda_env_raises_error(self, mock_session): - """Test lines 622-623: spark_config and job_conda_env validation.""" - from sagemaker.core.remote_function.spark_config import SparkConfig - - spark_config = SparkConfig() - with pytest.raises(ValueError, match="Remote Spark jobs do not support job_conda_env"): - _JobSettings( - sagemaker_session=mock_session, - spark_config=spark_config, - job_conda_env="test-env", - instance_type="ml.m5.xlarge", - ) - - def test_spark_config_with_auto_capture_raises_error(self, mock_session): - """Test lines 625-628: spark_config and auto_capture validation.""" - from 
sagemaker.core.remote_function.spark_config import SparkConfig - - spark_config = SparkConfig() - with pytest.raises(ValueError, match="Remote Spark jobs do not support automatically"): - _JobSettings( - sagemaker_session=mock_session, - spark_config=spark_config, - dependencies="auto_capture", - instance_type="ml.m5.xlarge", - ) - - def test_pre_execution_commands_and_script_raises_error(self, mock_session): - """Test lines 651-653: pre_execution validation.""" - with pytest.raises( - ValueError, match="Only one of pre_execution_commands or pre_execution_script" - ): - _JobSettings( - sagemaker_session=mock_session, - pre_execution_commands=["echo test"], - pre_execution_script="/path/to/script.sh", - instance_type="ml.m5.xlarge", - image_uri="test-image", - ) - - def test_instance_type_required(self, mock_session): - """Test lines 665-666: instance_type validation.""" - with pytest.raises(ValueError, match="instance_type is a required parameter"): - _JobSettings(sagemaker_session=mock_session, image_uri="test-image") - - @patch.dict(os.environ, {"SAGEMAKER_INTERNAL_IMAGE_URI": "custom-image"}) - def test_get_default_image_from_env(self, mock_session): - """Test lines 785-788: get default image from environment.""" - image = _JobSettings._get_default_image(mock_session) - assert image == "custom-image" - - def test_get_default_image_unsupported_python(self, mock_session): - """Test lines 792-795: unsupported Python version.""" - with patch.object(sys, "version_info", (3, 7, 0)): - with pytest.raises( - ValueError, match="Default image is supported only for Python versions" - ): - _JobSettings._get_default_image(mock_session) - - def test_get_default_spark_image_unsupported_python(self, mock_session): - """Test lines 815-817: unsupported Python for Spark.""" - with patch.object(sys, "version_info", (3, 8, 0)): - with pytest.raises( - ValueError, - match="SageMaker Spark image for remote job only supports Python version 3.9", - ): - _JobSettings._get_default_spark_image(mock_session) - - -class TestJobMethods: - """Test _Job class methods for uncovered lines.""" - - def test_from_describe_response(self, mock_session): - """Test lines 848-852: from_describe_response method.""" - response = { - "TrainingJobName": "test-job", - "OutputDataConfig": {"S3OutputPath": "s3://bucket/output"}, - } - job = _Job.from_describe_response(response, mock_session) - assert job.job_name == "test-job" - assert job.s3_uri == "s3://bucket/output" - assert job._last_describe_response == response - - def test_describe_cached_completed(self, mock_session): - """Test lines 865-871: describe with cached completed job.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - job._last_describe_response = {"TrainingJobStatus": "Completed"} - - result = job.describe() - assert result["TrainingJobStatus"] == "Completed" - mock_session.sagemaker_client.describe_training_job.assert_not_called() - - def test_describe_cached_failed(self, mock_session): - """Test lines 865-871: describe with cached failed job.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - job._last_describe_response = {"TrainingJobStatus": "Failed"} - - result = job.describe() - assert result["TrainingJobStatus"] == "Failed" - mock_session.sagemaker_client.describe_training_job.assert_not_called() - - def test_describe_cached_stopped(self, mock_session): - """Test lines 865-871: describe with cached stopped job.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - job._last_describe_response = {"TrainingJobStatus": 
"Stopped"} - - result = job.describe() - assert result["TrainingJobStatus"] == "Stopped" - mock_session.sagemaker_client.describe_training_job.assert_not_called() - - def test_stop(self, mock_session): - """Test lines 886-887: stop method.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - job.stop() - mock_session.sagemaker_client.stop_training_job.assert_called_once_with( - TrainingJobName="test-job" - ) - - @patch("sagemaker.core.remote_function.job._logs_for_job") - def test_wait(self, mock_logs, mock_session): - """Test lines 889-903: wait method.""" - job = _Job("test-job", "s3://bucket/output", mock_session) - mock_logs.return_value = {"TrainingJobStatus": "Completed"} - - job.wait(timeout=100) - mock_logs.assert_called_once_with( - sagemaker_session=mock_session, job_name="test-job", wait=True, timeout=100 - ) - assert job._last_describe_response["TrainingJobStatus"] == "Completed" - - -class TestCheckpointConfig: - """Test checkpoint configuration for uncovered lines.""" - - def test_checkpoint_in_args(self): - """Test lines 1219-1227: checkpoint in positional args.""" - checkpoint = CheckpointLocation(s3_uri="s3://bucket/checkpoint") - args = (checkpoint,) - kwargs = {} - request_dict = {} - - _update_job_request_with_checkpoint_config(args, kwargs, request_dict) - - assert "CheckpointConfig" in request_dict - assert request_dict["CheckpointConfig"]["S3Uri"] == "s3://bucket/checkpoint" - assert request_dict["CheckpointConfig"]["LocalPath"] == "/opt/ml/checkpoints/" - - def test_checkpoint_in_kwargs(self): - """Test lines 1228-1230: checkpoint in keyword args.""" - checkpoint = CheckpointLocation(s3_uri="s3://bucket/checkpoint") - args = () - kwargs = {"checkpoint": checkpoint} - request_dict = {} - - _update_job_request_with_checkpoint_config(args, kwargs, request_dict) - - assert "CheckpointConfig" in request_dict - assert request_dict["CheckpointConfig"]["S3Uri"] == "s3://bucket/checkpoint" - - def test_multiple_checkpoints_raises_error(self): - """Test lines 1237-1239: multiple checkpoints error.""" - checkpoint1 = CheckpointLocation(s3_uri="s3://bucket/checkpoint1") - checkpoint2 = CheckpointLocation(s3_uri="s3://bucket/checkpoint2") - args = (checkpoint1,) - kwargs = {"checkpoint": checkpoint2} - request_dict = {} - - with pytest.raises( - ValueError, match="cannot have more than one argument of type CheckpointLocation" - ): - _update_job_request_with_checkpoint_config(args, kwargs, request_dict) - - def test_no_checkpoint(self): - """Test lines 1232-1233: no checkpoint location.""" - args = ("arg1", "arg2") - kwargs = {"key": "value"} - request_dict = {} - - _update_job_request_with_checkpoint_config(args, kwargs, request_dict) - - assert "CheckpointConfig" not in request_dict - - -class TestConvertRunToJson: - """Test _convert_run_to_json for uncovered lines.""" - - def test_convert_run(self): - """Test lines 1276-1278: convert run to JSON.""" - mock_run = Mock() - mock_run.experiment_name = "test-experiment" - mock_run.run_name = "test-run" - - result = _convert_run_to_json(mock_run) - data = json.loads(result) - - assert data["experiment_name"] == "test-experiment" - assert data["run_name"] == "test-run" - - -class TestSparkDependencies: - """Test Spark dependency functions for uncovered lines.""" - - def test_upload_spark_config_none(self, mock_session): - """Test lines 1356: upload None Spark configuration.""" - result = _upload_serialized_spark_configuration( - "s3://bucket/base", "kms-key", None, mock_session - ) - assert result is None - - 
@patch("sagemaker.core.remote_function.job.S3Uploader") - def test_upload_spark_config(self, mock_uploader, mock_session): - """Test lines 1339-1356: upload Spark configuration.""" - config = {"spark.executor.memory": "4g"} - mock_uploader.upload_string_as_file_body = Mock() - - _upload_serialized_spark_configuration("s3://bucket/base", "kms-key", config, mock_session) - - mock_uploader.upload_string_as_file_body.assert_called_once() - - def test_upload_spark_deps_none(self, mock_session): - """Test lines 1379-1380: None dependencies.""" - result = _upload_spark_submit_deps( - None, "workspace", "s3://bucket", "kms-key", mock_session - ) - assert result is None - - def test_upload_spark_deps_s3_uri(self, mock_session): - """Test lines 1388-1389: S3 URI dependency.""" - deps = ["s3://bucket/dep.jar"] - result = _upload_spark_submit_deps( - deps, "workspace", "s3://bucket", "kms-key", mock_session - ) - assert "s3://bucket/dep.jar" in result - - def test_upload_spark_deps_s3a_uri(self, mock_session): - """Test lines 1388-1389: S3A URI dependency.""" - deps = ["s3a://bucket/dep.jar"] - result = _upload_spark_submit_deps( - deps, "workspace", "s3://bucket", "kms-key", mock_session - ) - assert "s3a://bucket/dep.jar" in result - - def test_upload_spark_deps_empty_workspace_raises_error(self, mock_session): - """Test lines 1382-1383: empty workspace validation.""" - deps = ["s3://bucket/dep.jar"] - with pytest.raises(ValueError, match="workspace_name or s3_base_uri may not be empty"): - _upload_spark_submit_deps(deps, "", "s3://bucket", "kms-key", mock_session) - - @patch("os.path.isfile", return_value=False) - def test_upload_spark_deps_invalid_file_raises_error(self, mock_isfile, mock_session): - """Test lines 1391-1392: invalid local file.""" - deps = ["/invalid/path.jar"] - with pytest.raises(ValueError, match="is not a valid local file"): - _upload_spark_submit_deps(deps, "workspace", "s3://bucket", "kms-key", mock_session) - - -class TestDistributedTraining: - """Test distributed training functions for uncovered lines.""" - - def test_extend_mpirun_no_mpirun(self, mock_session): - """Test lines 1441-1442: mpirun disabled.""" - job_settings = Mock() - job_settings.use_mpirun = False - request_dict = {"InputDataConfig": []} - - result = _extend_mpirun_to_request(request_dict, job_settings) - assert result == request_dict - - def test_extend_mpirun_single_instance(self, mock_session): - """Test lines 1444-1445: single instance.""" - job_settings = Mock() - job_settings.use_mpirun = True - job_settings.instance_count = 1 - request_dict = {"InputDataConfig": []} - - result = _extend_mpirun_to_request(request_dict, job_settings) - assert result == request_dict - - def test_extend_mpirun_multiple_instances(self, mock_session): - """Test lines 1447-1453: multiple instances.""" - job_settings = Mock() - job_settings.use_mpirun = True - job_settings.instance_count = 2 - request_dict = { - "InputDataConfig": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/data"}}}] - } - - result = _extend_mpirun_to_request(request_dict, job_settings) - assert ( - result["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3DataDistributionType"] - == "FullyReplicated" - ) - - def test_extend_torchrun_no_torchrun(self, mock_session): - """Test lines 1506-1507: torchrun disabled.""" - job_settings = Mock() - job_settings.use_torchrun = False - request_dict = {"InputDataConfig": []} - - result = _extend_torchrun_to_request(request_dict, job_settings) - assert result == request_dict - - def 
test_extend_torchrun_single_instance(self, mock_session): - """Test lines 1524-1525: single instance.""" - job_settings = Mock() - job_settings.use_torchrun = True - job_settings.instance_count = 1 - request_dict = {"InputDataConfig": []} - - result = _extend_torchrun_to_request(request_dict, job_settings) - assert result == request_dict - - def test_extend_torchrun_multiple_instances(self, mock_session): - """Test lines 1527-1533: multiple instances.""" - job_settings = Mock() - job_settings.use_torchrun = True - job_settings.instance_count = 2 - request_dict = { - "InputDataConfig": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/data"}}}] - } - - result = _extend_torchrun_to_request(request_dict, job_settings) - assert ( - result["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3DataDistributionType"] - == "FullyReplicated" - ) - - -class TestJobStatus: - """Test job status functions for uncovered lines.""" - - def test_check_job_status_completed(self): - """Test lines 1978-1979: completed status.""" - desc = {"TrainingJobStatus": "Completed"} - _check_job_status("test-job", desc, "TrainingJobStatus") - - def test_check_job_status_stopped(self): - """Test lines 1978-1986: stopped status.""" - desc = {"TrainingJobStatus": "Stopped"} - with patch("sagemaker.core.remote_function.job.logger") as mock_logger: - _check_job_status("test-job", desc, "TrainingJobStatus") - mock_logger.warning.assert_called_once() - - def test_check_job_status_failed(self): - """Test lines 1987-2011: failed status.""" - desc = {"TrainingJobStatus": "Failed", "FailureReason": "Test failure"} - from sagemaker.core import exceptions - - with pytest.raises(exceptions.UnexpectedStatusException): - _check_job_status("test-job", desc, "TrainingJobStatus") - - def test_check_job_status_capacity_error(self): - """Test lines 2002-2007: CapacityError.""" - desc = { - "TrainingJobStatus": "Failed", - "FailureReason": "CapacityError: Insufficient capacity", - } - from sagemaker.core import exceptions - - with pytest.raises(exceptions.CapacityError): - _check_job_status("test-job", desc, "TrainingJobStatus") - - -class TestRuleStatuses: - """Test rule status functions for uncovered lines.""" - - def test_rule_statuses_no_last(self): - """Test lines 2092-2093: no last statuses.""" - current = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "InProgress"}] - result = _rule_statuses_changed(current, None) - assert result is True - - def test_rule_statuses_changed(self): - """Test lines 2095-2098: changed status.""" - current = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "Completed"}] - last = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "InProgress"}] - result = _rule_statuses_changed(current, last) - assert result is True - - def test_rule_statuses_unchanged(self): - """Test lines 2100: unchanged status.""" - current = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "InProgress"}] - last = [{"RuleConfigurationName": "rule1", "RuleEvaluationStatus": "InProgress"}] - result = _rule_statuses_changed(current, last) - assert result is False - - -class TestLogsInit: - """Test _logs_init function for uncovered lines.""" - - def test_logs_init_training_job(self, mock_session): - """Test lines 2098-2105: training job.""" - description = {"ResourceConfig": {"InstanceCount": 2}} - result = _logs_init(mock_session.boto_session, description, "Training") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 2 - assert 
log_group == "/aws/sagemaker/TrainingJobs" - - def test_logs_init_training_job_instance_groups(self, mock_session): - """Test lines 2098-2103: training job with instance groups.""" - description = { - "ResourceConfig": {"InstanceGroups": [{"InstanceCount": 2}, {"InstanceCount": 3}]} - } - result = _logs_init(mock_session.boto_session, description, "Training") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 5 - - def test_logs_init_transform_job(self, mock_session): - """Test lines 2106-2107: transform job.""" - description = {"TransformResources": {"InstanceCount": 1}} - result = _logs_init(mock_session.boto_session, description, "Transform") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 1 - assert log_group == "/aws/sagemaker/TransformJobs" - - def test_logs_init_processing_job(self, mock_session): - """Test lines 2108-2109: processing job.""" - description = {"ProcessingResources": {"ClusterConfig": {"InstanceCount": 3}}} - result = _logs_init(mock_session.boto_session, description, "Processing") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 3 - assert log_group == "/aws/sagemaker/ProcessingJobs" - - def test_logs_init_automl_job(self, mock_session): - """Test lines 2110-2111: AutoML job.""" - description = {} - result = _logs_init(mock_session.boto_session, description, "AutoML") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = result - assert instance_count == 0 - assert log_group == "/aws/sagemaker/AutoMLJobs" - - -class TestGetInitialJobState: - """Test _get_initial_job_state for uncovered lines.""" - - def test_completed_with_wait(self): - """Test lines 2021-2023: completed job with wait.""" - description = {"TrainingJobStatus": "Completed"} - state = _get_initial_job_state(description, "TrainingJobStatus", True) - assert state == LogState.COMPLETE - - def test_failed_with_wait(self): - """Test lines 2021-2023: failed job with wait.""" - description = {"TrainingJobStatus": "Failed"} - state = _get_initial_job_state(description, "TrainingJobStatus", True) - assert state == LogState.COMPLETE - - def test_stopped_with_wait(self): - """Test lines 2021-2023: stopped job with wait.""" - description = {"TrainingJobStatus": "Stopped"} - state = _get_initial_job_state(description, "TrainingJobStatus", True) - assert state == LogState.COMPLETE - - def test_in_progress_with_wait(self): - """Test lines 2022: in-progress job with wait.""" - description = {"TrainingJobStatus": "InProgress"} - state = _get_initial_job_state(description, "TrainingJobStatus", True) - assert state == LogState.TAILING - - def test_in_progress_without_wait(self): - """Test lines 2022: in-progress job without wait.""" - description = {"TrainingJobStatus": "InProgress"} - state = _get_initial_job_state(description, "TrainingJobStatus", False) - assert state == LogState.COMPLETE diff --git a/sagemaker-core/tests/unit/remote_function/test_logging_config.py b/sagemaker-core/tests/unit/remote_function/test_logging_config.py deleted file mode 100644 index 6454ea1071..0000000000 --- a/sagemaker-core/tests/unit/remote_function/test_logging_config.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. 
A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Tests for logging_config module.""" -from __future__ import absolute_import - -import logging -import time -from unittest.mock import patch -from sagemaker.core.remote_function.logging_config import _UTCFormatter, get_logger - - -class TestUTCFormatter: - """Test _UTCFormatter class.""" - - def test_converter_is_gmtime(self): - """Test that converter is set to gmtime.""" - formatter = _UTCFormatter() - assert formatter.converter == time.gmtime - - def test_formats_time_in_utc(self): - """Test that time is formatted in UTC.""" - formatter = _UTCFormatter("%(asctime)s") - record = logging.LogRecord( - name="test", - level=logging.INFO, - pathname="", - lineno=0, - msg="test message", - args=(), - exc_info=None, - ) - formatted = formatter.format(record) - # Should contain UTC time format - assert formatted - - -class TestGetLogger: - """Test get_logger function.""" - - def test_returns_logger_with_correct_name(self): - """Test that logger has correct name.""" - logger = get_logger() - assert logger.name == "sagemaker.remote_function" - - def test_logger_has_info_level(self): - """Test that logger is set to INFO level.""" - logger = get_logger() - assert logger.level == logging.INFO - - def test_logger_has_handler(self): - """Test that logger has at least one handler.""" - logger = get_logger() - assert len(logger.handlers) > 0 - - def test_logger_handler_has_utc_formatter(self): - """Test that logger handler uses UTC formatter.""" - logger = get_logger() - handler = logger.handlers[0] - # Check that formatter has gmtime converter (UTC formatter characteristic) - assert handler.formatter.converter == time.gmtime - - def test_logger_does_not_propagate(self): - """Test that logger does not propagate to root logger.""" - logger = get_logger() - assert logger.propagate == 0 - - def test_get_logger_is_idempotent(self): - """Test that calling get_logger multiple times returns same logger.""" - logger1 = get_logger() - logger2 = get_logger() - assert logger1 is logger2 - - def test_logger_handler_is_stream_handler(self): - """Test that logger uses StreamHandler.""" - logger = get_logger() - assert isinstance(logger.handlers[0], logging.StreamHandler) diff --git a/sagemaker-core/tests/unit/test_base_deserializers.py b/sagemaker-core/tests/unit/test_base_deserializers.py deleted file mode 100644 index b6c94370f0..0000000000 --- a/sagemaker-core/tests/unit/test_base_deserializers.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-from __future__ import absolute_import - -import pytest -import warnings - - -def test_base_deserializers_deprecation_warning(): - """Test that importing from base_deserializers raises DeprecationWarning.""" - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - # Import the module which should trigger the warning - import sagemaker.core.base_deserializers # noqa: F401 - - # Check that a warning was raised - assert len(w) >= 1 - - # Find the deprecation warning - deprecation_warnings = [ - warning for warning in w if issubclass(warning.category, DeprecationWarning) - ] - assert len(deprecation_warnings) >= 1 - - # Check the warning message - assert "base_deserializers is deprecated" in str(deprecation_warnings[0].message) - assert "sagemaker.core.deserializers" in str(deprecation_warnings[0].message) - - -def test_base_deserializers_imports_from_deserializers(): - """Test that base_deserializers re-exports from deserializers module.""" - import sagemaker.core.base_deserializers as base_deser - import sagemaker.core.deserializers as deser - - # Check that the modules have the same attributes - # (excluding private attributes and module-specific ones) - base_attrs = {attr for attr in dir(base_deser) if not attr.startswith("_")} - deser_attrs = {attr for attr in dir(deser) if not attr.startswith("_")} - - # base_deserializers should have at least the public attributes from deserializers - assert base_attrs.intersection(deser_attrs) == deser_attrs diff --git a/sagemaker-mlops/src/sagemaker/mlops/workflow/_utils.py b/sagemaker-mlops/src/sagemaker/mlops/workflow/_utils.py index 3b98b49028..78bb32dfd5 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/workflow/_utils.py +++ b/sagemaker-mlops/src/sagemaker/mlops/workflow/_utils.py @@ -20,7 +20,7 @@ import tempfile from typing import List, Union, Optional, TYPE_CHECKING from sagemaker.core import image_uris -from sagemaker.core.training.configs import InputData +from sagemaker.train.configs import InputData # Lazy import to avoid circular dependency if TYPE_CHECKING: pass @@ -152,7 +152,7 @@ def __init__( requirements_file = self._requirements if self._requirements and self._requirements.endswith('.txt') else None # Configure ModelTrainer components for repacking - from sagemaker.core.training.configs import SourceCode, Compute, Networking + from sagemaker.train.configs import SourceCode, Compute, Networking source_code = SourceCode( source_dir=self._source_dir, diff --git a/sagemaker-mlops/src/sagemaker/mlops/workflow/function_step.py b/sagemaker-mlops/src/sagemaker/mlops/workflow/function_step.py index 1f51612c59..ecdf1135a6 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/workflow/function_step.py +++ b/sagemaker-mlops/src/sagemaker/mlops/workflow/function_step.py @@ -44,8 +44,8 @@ from sagemaker.core.common_utils import unique_name_from_base_uuid4, format_tags, Tags if TYPE_CHECKING: - from sagemaker.core.remote_function.spark_config import SparkConfig - from sagemaker.core.remote_function.job import _JobSettings + from sagemaker.train.remote_function.spark_config import SparkConfig + from sagemaker.train.remote_function.job import _JobSettings logger = logging.getLogger(__name__) @@ -83,11 +83,11 @@ def __init__( func_kwargs (dict): keyword arguments of the python function. **kwargs: Additional arguments to be passed to the `step` decorator. 
""" - from sagemaker.core.remote_function.core.pipeline_variables import ( + from sagemaker.train.remote_function.core.pipeline_variables import ( convert_pipeline_variables_to_pickleable, ) - from sagemaker.core.remote_function.core.serialization import CloudpickleSerializer - from sagemaker.core.remote_function.core.stored_function import _SerializedData + from sagemaker.train.remote_function.core.serialization import CloudpickleSerializer + from sagemaker.train.remote_function.core.stored_function import _SerializedData super(_FunctionStep, self).__init__( name, StepTypeEnum.TRAINING, display_name, description, depends_on, retry_policies @@ -151,7 +151,7 @@ def depends_on(self, depends_on: List[Union[str, "Step", StepOutput]]): def _job_settings(self) -> "_JobSettings": """Returns the job settings for the step.""" - from sagemaker.core.remote_function.job import _JobSettings + from sagemaker.train.remote_function.job import _JobSettings context = load_step_compilation_context() @@ -193,7 +193,7 @@ def _job_settings(self) -> "_JobSettings": @property def arguments(self) -> RequestType: """Generates the arguments dictionary that is used to call `create_training_job`.""" - from sagemaker.core.remote_function.job import _Job + from sagemaker.train.remote_function.job import _Job step_compilation_context = load_step_compilation_context() @@ -274,7 +274,7 @@ def expr(self) -> RequestType: def _to_json_get(self) -> JsonGet: """Expression structure for workflow service calls using JsonGet resolution.""" - from sagemaker.core.remote_function.core.stored_function import ( + from sagemaker.train.remote_function.core.stored_function import ( JSON_SERIALIZED_RESULT_KEY, JSON_RESULTS_FILE, ) @@ -547,7 +547,7 @@ def _step(func): raise ValueError("Auto Capture of dependencies is not supported for pipeline steps.") # avoid circular import - from sagemaker.core.remote_function.client import RemoteExecutor + from sagemaker.train.remote_function.client import RemoteExecutor @wraps(func) def wrapper(*args, **kwargs): diff --git a/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py b/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py index 9b4b9a191b..4d9f485ace 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py +++ b/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py @@ -28,10 +28,10 @@ from sagemaker.core.local.local_session import LocalSession from sagemaker.core._studio import _append_project_tags from sagemaker.core.config.config_schema import PIPELINE_ROLE_ARN_PATH, PIPELINE_TAGS_PATH -from sagemaker.core.remote_function.core.serialization import deserialize_obj_from_s3 -from sagemaker.core.remote_function.core.stored_function import RESULTS_FOLDER -from sagemaker.core.remote_function.errors import RemoteFunctionError -from sagemaker.core.remote_function.job import JOBS_CONTAINER_ENTRYPOINT +from sagemaker.train.remote_function.core.serialization import deserialize_obj_from_s3 +from sagemaker.train.remote_function.core.stored_function import RESULTS_FOLDER +from sagemaker.train.remote_function.errors import RemoteFunctionError +from sagemaker.train.remote_function.job import JOBS_CONTAINER_ENTRYPOINT from sagemaker.core.s3 import s3_path_join from sagemaker.core.helper.session_helper import Session from sagemaker.core.common_utils import resolve_value_from_config, retry_with_backoff, format_tags, Tags diff --git a/sagemaker-mlops/tests/unit/workflow/test_pipeline.py b/sagemaker-mlops/tests/unit/workflow/test_pipeline.py index 9169f1ce7f..eb2abe916d 100644 --- 
a/sagemaker-mlops/tests/unit/workflow/test_pipeline.py +++ b/sagemaker-mlops/tests/unit/workflow/test_pipeline.py @@ -352,8 +352,8 @@ def test_get_function_step_result_wrong_container(mock_session): def test_get_function_step_result_incomplete_job(mock_session): from sagemaker.mlops.workflow.pipeline import get_function_step_result - from sagemaker.core.remote_function.job import JOBS_CONTAINER_ENTRYPOINT - from sagemaker.core.remote_function.errors import RemoteFunctionError + from sagemaker.train.remote_function.job import JOBS_CONTAINER_ENTRYPOINT + from sagemaker.train.remote_function.errors import RemoteFunctionError step_list = [{"StepName": "step1", "Metadata": {"TrainingJob": {"Arn": "arn:aws:sagemaker:us-west-2:123456789012:training-job/job"}}}] mock_session.describe_training_job.return_value = { @@ -368,7 +368,7 @@ def test_get_function_step_result_incomplete_job(mock_session): def test_get_function_step_result_success(mock_session): from sagemaker.mlops.workflow.pipeline import get_function_step_result - from sagemaker.core.remote_function.job import JOBS_CONTAINER_ENTRYPOINT + from sagemaker.train.remote_function.job import JOBS_CONTAINER_ENTRYPOINT step_list = [{"StepName": "step1", "Metadata": {"TrainingJob": {"Arn": "arn:aws:sagemaker:us-west-2:123456789012:training-job/job"}}}] mock_session.describe_training_job.return_value = { @@ -431,7 +431,7 @@ def test_pipeline_execution_result_waiter_error(mock_session): def test_pipeline_execution_result_terminal_failure(mock_session): from sagemaker.mlops.workflow.pipeline import _PipelineExecution from botocore.exceptions import WaiterError - from sagemaker.core.remote_function.job import JOBS_CONTAINER_ENTRYPOINT + from sagemaker.train.remote_function.job import JOBS_CONTAINER_ENTRYPOINT execution = _PipelineExecution(arn="arn:aws:sagemaker:us-west-2:123456789012:pipeline/test/execution/exec-id", sagemaker_session=mock_session) mock_session.sagemaker_client.list_pipeline_execution_steps.return_value = { @@ -451,7 +451,7 @@ def test_pipeline_execution_result_terminal_failure(mock_session): def test_get_function_step_result_obsolete_s3_path(mock_session): from sagemaker.mlops.workflow.pipeline import get_function_step_result - from sagemaker.core.remote_function.job import JOBS_CONTAINER_ENTRYPOINT + from sagemaker.train.remote_function.job import JOBS_CONTAINER_ENTRYPOINT step_list = [{"StepName": "step1", "Metadata": {"TrainingJob": {"Arn": "arn:aws:sagemaker:us-west-2:123456789012:training-job/job"}}}] mock_session.describe_training_job.return_value = { diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index 1c3016cf86..75cb927ca0 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -126,7 +126,7 @@ def build(self): from sagemaker.core.helper.pipeline_variable import PipelineVariable from sagemaker.core import model_uris from sagemaker.serve.utils.local_hardware import _get_available_gpus -from sagemaker.core.base_serializers import JSONSerializer +from sagemaker.core.serializers import JSONSerializer from sagemaker.core.deserializers import JSONDeserializer from sagemaker.serve.detector.pickler import save_pkl from sagemaker.serve.builder.requirements_manager import RequirementsManager diff --git a/sagemaker-train/src/sagemaker/train/base_trainer.py b/sagemaker-train/src/sagemaker/train/base_trainer.py index 873b42f81b..f44c1c9b79 100644 --- 
a/sagemaker-train/src/sagemaker/train/base_trainer.py +++ b/sagemaker-train/src/sagemaker/train/base_trainer.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional, Dict, Any, List, Union from sagemaker.core.helper.session_helper import Session -from sagemaker.core.training.configs import Tag, Networking, InputData, Channel +from sagemaker.train.configs import Tag, Networking, InputData, Channel from sagemaker.core.shapes import shapes from sagemaker.core.resources import TrainingJob diff --git a/sagemaker-train/src/sagemaker/train/configs.py b/sagemaker-train/src/sagemaker/train/configs.py index 79b4eedc5e..c164c8ece4 100644 --- a/sagemaker-train/src/sagemaker/train/configs.py +++ b/sagemaker-train/src/sagemaker/train/configs.py @@ -10,22 +10,300 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.training.configs +"""This module provides the configuration classes used in ``sagemaker.train``. + +Some of these classes are re-exported from ``sagemaker.core.shapes``. For convenience, +users can import these classes directly from ``sagemaker.train.configs``. -This is a backward compatibility shim. Please update your imports to: - from sagemaker.core.training.configs import ... +For more documentation on ``sagemaker.core.shapes``, see: + - https://sagemaker-core.readthedocs.io/en/stable/#sagemaker-core-shapes """ + from __future__ import absolute_import -import warnings +from typing import Optional, Union +from pydantic import BaseModel, model_validator, ConfigDict -# Backward compatibility: re-export from core -from sagemaker.core.training.configs import * # noqa: F401, F403 +import sagemaker.core.shapes as shapes +from sagemaker.core.helper.pipeline_variable import StrPipeVar -warnings.warn( - "sagemaker.train.configs has been moved to sagemaker.core.training.configs. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 +# TODO: Can we add custom logic to some of these to set better defaults? +from sagemaker.core.shapes import ( + StoppingCondition, + RetryStrategy, + Channel, + ShuffleConfig, + DataSource, + S3DataSource, + FileSystemDataSource, + TrainingImageConfig, + TrainingRepositoryAuthConfig, + Tag, + InfraCheckConfig, + RemoteDebugConfig, + SessionChainingConfig, + InstanceGroup, + HubAccessConfig, + ModelAccessConfig, + MetricDefinition, + DatasetSource, ) + +from sagemaker.train.utils import convert_unassigned_to_none + +__all__ = [ + "BaseConfig", + "SourceCode", + "StoppingCondition", + "RetryStrategy", + "OutputDataConfig", + "Channel", + "ShuffleConfig", + "DataSource", + "S3DataSource", + "FileSystemDataSource", + "TrainingImageConfig", + "TrainingRepositoryAuthConfig", + "Tag", + "InfraCheckConfig", + "RemoteDebugConfig", + "SessionChainingConfig", + "InstanceGroup", + "TensorBoardOutputConfig", + "CheckpointConfig", + "HubAccessConfig", + "ModelAccessConfig", + "Compute", + "Networking", + "InputData", + "MetricDefinition", + "DatasetSource", +] + + +class BaseConfig(BaseModel): + """BaseConfig""" + + model_config = ConfigDict(validate_assignment=True, extra="forbid") + + +class SourceCode(BaseConfig): + """SourceCode.
+ + The SourceCode class allows the user to specify the source code location, dependencies, + entry script, or commands to be executed in the training job container. + + Parameters: + source_dir (Optional[StrPipeVar]): + The local directory, s3 uri, or path to tar.gz file stored locally or in s3 that + contains the source code to be used in the training job container. + requirements (Optional[StrPipeVar]): + The path within ``source_dir`` to a ``requirements.txt`` file. If specified, the listed + requirements will be installed in the training job container. + entry_script (Optional[StrPipeVar]): + The path within ``source_dir`` to the entry script that will be executed in the training + job container. If not specified, command must be provided. + command (Optional[StrPipeVar]): + The command(s) to execute in the training job container. Example: "python my_script.py". + If not specified, entry_script must be provided. + """ + + source_dir: Optional[StrPipeVar] = None + requirements: Optional[StrPipeVar] = None + entry_script: Optional[StrPipeVar] = None + command: Optional[StrPipeVar] = None + + +class OutputDataConfig(shapes.OutputDataConfig): + """OutputDataConfig. + + Provides the configuration for the output data location of the training job. + + Parameters: + s3_output_path (Optional[StrPipeVar]): + The S3 URI where the output data will be stored. This is the location where the + training job will save its output data, such as model artifacts and logs. + kms_key_id (Optional[StrPipeVar]): + The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that + SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side + encryption. + compression_type (Optional[StrPipeVar]): + The model output compression type. Select None to output an uncompressed model, + recommended for large model outputs. Defaults to gzip. + """ + + s3_output_path: Optional[StrPipeVar] = None + kms_key_id: Optional[StrPipeVar] = None + compression_type: Optional[StrPipeVar] = None + + +class Compute(shapes.ResourceConfig): + """Compute. + + The Compute class is a subclass of ``sagemaker.core.shapes.ResourceConfig`` + and allows the user to specify the compute resources for the training job. + + Parameters: + instance_type (Optional[StrPipeVar]): + The ML compute instance type. For information about available instance types, + see https://aws.amazon.com/sagemaker/pricing/. + instance_count (Optional[int]): The number of ML compute instances to use. For distributed + training, provide a value greater than 1. + volume_size_in_gb (Optional[int]): + The size of the ML storage volume that you want to provision. ML storage volumes store + model artifacts and incremental states. Training algorithms might also use the ML + storage volume for scratch space. Default: 30 + volume_kms_key_id (Optional[StrPipeVar]): + The Amazon Web Services KMS key that SageMaker uses to encrypt data on the storage + volume attached to the ML compute instance(s) that run the training job. + keep_alive_period_in_seconds (Optional[int]): + The duration of time in seconds to retain configured resources in a warm pool for + subsequent training jobs. + instance_groups (Optional[List[InstanceGroup]]): + A list of instance groups for heterogeneous clusters to be used in the training job. + training_plan_arn (Optional[StrPipeVar]): + The Amazon Resource Name (ARN) of the training plan to use for this resource configuration. 
+ enable_managed_spot_training (Optional[bool]): + To train models using managed spot training, choose True. Managed spot training + provides a fully managed and scalable infrastructure for training machine learning + models. This option is useful when training jobs can be interrupted and when there + is flexibility in when the training job runs. + """ + + volume_size_in_gb: Optional[int] = 30 + enable_managed_spot_training: Optional[bool] = None + + @model_validator(mode="after") + def _model_validator(self) -> "Compute": + """Convert Unassigned values to None.""" + return convert_unassigned_to_none(self) + + def _to_resource_config(self) -> shapes.ResourceConfig: + """Convert to a sagemaker.core.shapes.ResourceConfig object.""" + compute_config_dict = self.model_dump() + resource_config_fields = set(shapes.ResourceConfig.__annotations__.keys()) + filtered_dict = { + k: v + for k, v in compute_config_dict.items() + if k in resource_config_fields and v is not None + } + if not filtered_dict: + return None + return shapes.ResourceConfig(**filtered_dict) + + +class Networking(shapes.VpcConfig): + """Networking. + + The Networking class is a subclass of ``sagemaker.core.shapes.VpcConfig`` and + allows the user to specify the networking configuration for the training job. + + Parameters: + security_group_ids (Optional[List[StrPipeVar]]): + The VPC security group IDs, in the form sg-xxxxxxxx. Specify the + security groups for the VPC that is specified in the Subnets field. + subnets (Optional[List[StrPipeVar]]): + The IDs of the subnets in the VPC to which you want to connect your + training job or model. + enable_network_isolation (Optional[bool]): + Isolates the training container. No inbound or outbound network calls can be made, + except for calls between peers within a training cluster for distributed training. + If you enable network isolation for training jobs that are configured to use a VPC, + SageMaker downloads and uploads customer data and model artifacts through the + specified VPC, but the training container does not have network access. + enable_inter_container_traffic_encryption (Optional[bool]): + To encrypt all communications between ML compute instances in distributed training, + choose True. Encryption provides greater security for distributed training, but + training might take longer. How long it takes depends on the amount of + communication between compute instances, especially if you use a deep learning + algorithm in distributed training. + """ + + security_group_ids: Optional[list[StrPipeVar]] = None + subnets: Optional[list[StrPipeVar]] = None + enable_network_isolation: Optional[bool] = None + enable_inter_container_traffic_encryption: Optional[bool] = None + + @model_validator(mode="after") + def _model_validator(self) -> "Networking": + """Convert Unassigned values to None.""" + return convert_unassigned_to_none(self) + + def _to_vpc_config(self) -> shapes.VpcConfig: + """Convert to a sagemaker.core.shapes.VpcConfig object.""" + compute_config_dict = self.model_dump() + vpc_config_fields = set(shapes.VpcConfig.__annotations__.keys()) + filtered_dict = { + k: v for k, v in compute_config_dict.items() if k in vpc_config_fields and v is not None + } + if not filtered_dict: + return None + return shapes.VpcConfig(**filtered_dict) + + +class InputData(BaseConfig): + """InputData. + + This config allows the user to specify an input data source for the training job. + + Will be found at ``/opt/ml/input/data/`` within the training container.
+ For convenience, it can be referenced inside the training container like: + + .. code:: python + + import os + input_data_dir = os.environ['SM_CHANNEL_'] + + Parameters: + channel_name (StrPipeVar): + The name of the input data source channel. + data_source (Union[StrPipeVar, S3DataSource, FileSystemDataSource, DatasetSource]): + The data source for the channel. Can be an S3 URI string, local file path string, + S3DataSource object, FileSystemDataSource object, DatasetSource object, or a + pipeline variable (Properties) from a previous step. + content_type (StrPipeVar): + The MIME type of the data. + """ + + channel_name: StrPipeVar = None + data_source: Union[StrPipeVar, FileSystemDataSource, S3DataSource, DatasetSource] = None + content_type: StrPipeVar = None + + +class TensorBoardOutputConfig(shapes.TensorBoardOutputConfig): + """TensorBoardOutputConfig. + + The TensorBoardOutputConfig class is a subclass of ``sagemaker.core.shapes.TensorBoardOutputConfig`` + and allows the user to specify the storage locations for the Amazon SageMaker + Debugger TensorBoard. + + Parameters: + s3_output_path (Optional[StrPipeVar]): + Path to Amazon S3 storage location for TensorBoard output. If not specified, will + default to + ``s3://////tensorboard-output`` + local_path (Optional[StrPipeVar]): + Path to local storage location for TensorBoard output. Defaults to /opt/ml/output/tensorboard. + """ + + s3_output_path: Optional[StrPipeVar] = None + local_path: Optional[StrPipeVar] = "/opt/ml/output/tensorboard" + + +class CheckpointConfig(shapes.CheckpointConfig): + """CheckpointConfig. + + The CheckpointConfig class is a subclass of ``sagemaker.core.shapes.CheckpointConfig`` + and allows the user to specify the checkpoint configuration for the training job. + + Parameters: + s3_uri (Optional[StrPipeVar]): + Path to Amazon S3 storage location for the checkpoint data. If not specified, will + default to + ``s3://////checkpoints`` + local_path (Optional[StrPipeVar]): + The local directory where checkpoints are written. The default directory is /opt/ml/checkpoints. + """ + + s3_uri: Optional[StrPipeVar] = None + local_path: Optional[StrPipeVar] = "/opt/ml/checkpoints" diff --git a/sagemaker-train/src/sagemaker/train/constants.py b/sagemaker-train/src/sagemaker/train/constants.py index 309265d659..2e136d4277 100644 --- a/sagemaker-train/src/sagemaker/train/constants.py +++ b/sagemaker-train/src/sagemaker/train/constants.py @@ -10,16 +10,12 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.training.constants - -This is a backward compatibility shim. Please update your imports to: - from sagemaker.core.training.constants import ... -""" +"""Constants module.""" from __future__ import absolute_import - import os +DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge" + SM_CODE = "code" SM_CODE_CONTAINER_PATH = "/opt/ml/input/data/code" diff --git a/sagemaker-train/src/sagemaker/train/container_drivers/common/utils.py b/sagemaker-train/src/sagemaker/train/container_drivers/common/utils.py index 03146a3bbe..c07aa1359a 100644 --- a/sagemaker-train/src/sagemaker/train/container_drivers/common/utils.py +++ b/sagemaker-train/src/sagemaker/train/container_drivers/common/utils.py @@ -124,8 +124,10 @@ def safe_deserialize(data: Any) -> Any: This function handles the following cases: 1.
If `data` is not a string, it returns the input as-is. - 2. If `data` is a JSON-encoded string, it attempts to deserialize it using `json.loads()`. - 3. If `data` is a string but cannot be decoded as JSON, it returns the original string. + 2. If `data` is a string and matches common boolean values ("true" or "false"), + it returns the corresponding boolean value (True or False). + 3. If `data` is a JSON-encoded string, it attempts to deserialize it using `json.loads()`. + 4. If `data` is a string but cannot be decoded as JSON, it returns the original string. Returns: Any: The deserialized data, or the original input if it cannot be JSON-decoded. @@ -133,6 +135,12 @@ def safe_deserialize(data: Any) -> Any: if not isinstance(data, str): return data + lower_data = data.lower() + if lower_data in ["true"]: + return True + if lower_data in ["false"]: + return False + try: return json.loads(data) except json.JSONDecodeError: diff --git a/sagemaker-train/src/sagemaker/train/modules/model_trainer.py b/sagemaker-train/src/sagemaker/train/modules/model_trainer.py index 32af7f6c88..b7b0bd8bf5 100644 --- a/sagemaker-train/src/sagemaker/train/modules/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/modules/model_trainer.py @@ -23,8 +23,8 @@ from graphene.utils.str_converters import to_camel_case, to_snake_case -from sagemaker.core.training.configs import Compute, Networking, InputData, SourceCode -from sagemaker.core.training.constants import DEFAULT_INSTANCE_TYPE, DEFAULT_CONTAINER_ENTRYPOINT, \ +from sagemaker.train.configs import Compute, Networking, InputData, SourceCode +from sagemaker.train.constants import DEFAULT_INSTANCE_TYPE, DEFAULT_CONTAINER_ENTRYPOINT, \ DEFAULT_CONTAINER_ARGUMENTS, SM_DRIVERS, SM_CODE_CONTAINER_PATH, TRAIN_SCRIPT, DISTRIBUTED_JSON, SOURCE_CODE_JSON, \ SM_CODE, SM_DRIVERS_LOCAL_PATH from sagemaker.train.distributed import DistributedConfig, Torchrun diff --git a/sagemaker-train/src/sagemaker/train/remote_function/__init__.py b/sagemaker-train/src/sagemaker/train/remote_function/__init__.py index bf29079921..87e9aca383 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/__init__.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/__init__.py @@ -10,25 +10,10 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function - -This is a backward compatibility shim. Please update your imports to: - from sagemaker.core.remote_function import ... -""" +"""Defines classes and helper methods used in remote function executions.""" from __future__ import absolute_import -import warnings - -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.client import remote, RemoteExecutor # noqa: F401 -from sagemaker.core.remote_function.checkpoint_location import CheckpointLocation # noqa: F401 -from sagemaker.core.remote_function.custom_file_filter import CustomFileFilter # noqa: F401 -from sagemaker.core.remote_function.spark_config import SparkConfig # noqa: F401 - -warnings.warn( - "sagemaker.train.remote_function has been moved to sagemaker.core.remote_function. " - "Please update your imports. 
This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) +from sagemaker.train.remote_function.client import remote, RemoteExecutor # noqa: F401 +from sagemaker.train.remote_function.checkpoint_location import CheckpointLocation # noqa: F401 +from sagemaker.train.remote_function.custom_file_filter import CustomFileFilter # noqa: F401 +from sagemaker.train.remote_function.spark_config import SparkConfig # noqa: F401 diff --git a/sagemaker-train/src/sagemaker/train/remote_function/client.py b/sagemaker-train/src/sagemaker/train/remote_function/client.py index eb99d14c1e..ecc193b8b4 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/client.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/client.py @@ -10,21 +10,1276 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.client - -This is a backward compatibility shim. -""" +"""SageMaker remote function client.""" from __future__ import absolute_import -import warnings +from concurrent.futures import ThreadPoolExecutor +from collections import deque +import time +import threading +from typing import Callable, Dict, List, Optional, Tuple, Any, Union +import functools +import itertools +import inspect -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.client import * # noqa: F401, F403 +from botocore.exceptions import ClientError +from sagemaker.core.exceptions import UnexpectedStatusException +from sagemaker.core.experiments._run_context import _RunContext -warnings.warn( - "sagemaker.train.remote_function.client has been moved to sagemaker.core.remote_function.client. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 +import sagemaker.train.remote_function.core.serialization as serialization +from sagemaker.train.remote_function.errors import ( + RemoteFunctionError, + ServiceError, + DeserializationError, +) +from sagemaker.train.remote_function.core.stored_function import RESULTS_FOLDER, EXCEPTION_FOLDER +from sagemaker.train.remote_function.runtime_environment.runtime_environment_manager import ( + RuntimeEnvironmentError, ) + +from sagemaker.core.helper.session_helper import Session +from sagemaker.core.s3 import s3_path_join +from sagemaker.train.remote_function.job import _JobSettings, _Job, _RunInfo +from sagemaker.train.remote_function import logging_config +from sagemaker.core.common_utils import name_from_base, base_from_name +from sagemaker.train.remote_function.spark_config import SparkConfig +from sagemaker.train.remote_function.custom_file_filter import CustomFileFilter +from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter +from sagemaker.core.telemetry.constants import Feature + +_API_CALL_LIMIT = { + "SubmittingIntervalInSecs": 1, + "MinBatchPollingIntervalInSecs": 10, + "PollingIntervalInSecs": 0.5, +} + +# Possible future states. +_PENDING = "PENDING" +_RUNNING = "RUNNING" +# The future was cancelled by the user... 
+_CANCELLED = "CANCELLED" +_FINISHED = "FINISHED" + +logger = logging_config.get_logger() + + +@_telemetry_emitter(feature=Feature.REMOTE_FUNCTION, func_name="remote_function.remote") +def remote( + _func=None, + *, + dependencies: str = None, + pre_execution_commands: List[str] = None, + pre_execution_script: str = None, + environment_variables: Dict[str, str] = None, + image_uri: str = None, + include_local_workdir: bool = None, + custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None, + instance_count: int = 1, + instance_type: str = None, + job_conda_env: str = None, + job_name_prefix: str = None, + keep_alive_period_in_seconds: int = 0, + max_retry_attempts: int = 1, + max_runtime_in_seconds: int = 24 * 60 * 60, + role: str = None, + s3_kms_key: str = None, + s3_root_uri: str = None, + sagemaker_session: Session = None, + security_group_ids: List[str] = None, + subnets: List[str] = None, + tags: List[Tuple[str, str]] = None, + volume_kms_key: str = None, + volume_size: int = 30, + encrypt_inter_container_traffic: bool = None, + spark_config: SparkConfig = None, + use_spot_instances=False, + max_wait_time_in_seconds=None, + disable_output_compression: bool = False, + use_torchrun: bool = False, + use_mpirun: bool = False, + nproc_per_node: Optional[int] = None, +): + """Decorator for running the annotated function as a SageMaker training job. + + This decorator wraps the annotated code and runs it as a new SageMaker job synchronously + with the provided runtime settings. + + If a parameter value is not set, the decorator first looks up the value from the SageMaker + configuration file. If no value is specified in the configuration file or no configuration file + is found, the decorator selects the default as specified below. For more information, see + `Configuring and using defaults with the SageMaker Python SDK `_. + + Args: + _func (Optional): A Python function to run as a SageMaker training job. + + dependencies (str): Either the path to a dependencies file or the reserved keyword + ``auto_capture``. Defaults to ``None``. + If ``dependencies`` is provided, the value must be one of the following: + + * A path to a conda environment.yml file. The following conditions apply. + + * If job_conda_env is set, then the conda environment is updated by installing + dependencies from the yaml file and the function is invoked within that + conda environment. For this to succeed, the specified conda environment must + already exist in the image. + * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, then the + conda environment is updated by installing dependencies from the yaml file and the + function is invoked within that conda environment. For this to succeed, the + conda environment name must already be set in ``SAGEMAKER_JOB_CONDA_ENV``, and + ``SAGEMAKER_JOB_CONDA_ENV`` must already exist in the image. + * If none of the previous conditions are met, a new conda environment named + ``sagemaker-runtime-env`` is created and the function annotated with the remote + decorator is invoked in that conda environment. + + * A path to a requirements.txt file. The following conditions apply. + + * If ``job_conda_env`` is set in the remote decorator, dependencies are installed + within that conda environment and the function annotated with the remote decorator + is invoked in the same conda environment. For this to succeed, the specified + conda environment must already exist in the image. 
+ * If an environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, + dependencies are installed within that conda environment and the function annotated + with the remote decorator is invoked in the same. For this to succeed, the conda + environment name must already be set in ``SAGEMAKER_JOB_CONDA_ENV``, and + ``SAGEMAKER_JOB_CONDA_ENV`` must already exist in the image. + * If none of the above conditions are met, conda is not used. Dependencies are + installed at the system level, without any virtual environment, and the function + annotated with the remote decorator is invoked using the Python runtime available + in the system path. + + * The parameter dependencies is set to ``auto_capture``. SageMaker will automatically + generate an env_snapshot.yml corresponding to the current active conda environment’s + snapshot. You do not need to provide a dependencies file. The following conditions + apply: + + * You must run the remote function within an active conda environment. + * When installing the dependencies on the training job, the same conditions as when + dependencies is set to a path to a conda environment file apply. These conditions are + as follows: + + * If job_conda_env is set, then the conda environment is updated by installing + dependencies from the yaml file and the function is invoked within that + conda environment. For this to succeed, the specified conda environment must + already exist in the image. + * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, then + the conda environment is updated by installing dependencies from the yaml file + and the function is invoked within that conda environment. For this to + succeed, the conda environment name must already be set in + ``SAGEMAKER_JOB_CONDA_ENV``, and ``SAGEMAKER_JOB_CONDA_ENV`` must already exist + in the image. + * If none of the previous conditions are met, a new conda environment with name + ``sagemaker-runtime-env`` is created and the function annotated with the + remote decorator is invoked in that conda environment. + + * ``None``. SageMaker will assume that there are no dependencies to install while + executing the remote annotated function in the training job. + + pre_execution_commands (List[str]): List of commands to be executed prior to executing + remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script`` + can be specified at the same time. Defaults to None. + + pre_execution_script (str): Path to script file to be executed prior to executing + remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script`` + can be specified at the same time. Defaults to None. + + environment_variables (Dict): The environment variables used inside the decorator function. + Defaults to ``None``. + + image_uri (str): The universal resource identifier (URI) location of a Docker image on + Amazon Elastic Container Registry (ECR). Defaults to the following based on where the SDK + is running: + + * For users who specify ``spark_config`` and want to run the function in a Spark + application, the ``image_uri`` should be ``None``. A SageMaker Spark image will + be used for training, otherwise a ``ValueError`` is thrown. + * For users on SageMaker Studio notebooks, the image used as the kernel image for the + notebook is used. + * For other users, it is resolved to base python image with the same python version + as the environment running the local code. + + If no compatible image is found, a ValueError is thrown. 
+ + include_local_workdir (bool): A flag to indicate that the remote function should include + local directories. Set to ``True`` if the remote function code imports local modules and + methods that are not available via PyPI or conda. Only python files are included. + Default value is ``False``. + + custom_file_filter (Callable[[str, List], List], CustomFileFilter): Either a function + that filters job dependencies to be uploaded to S3 or a ``CustomFileFilter`` object + that specifies the local directories and files to be included in the remote function. + If a callable is passed in, the function should follow the protocol of ``ignore`` argument + of ``shutil.copytree``. Defaults to ``None``, which means only python + files are accepted and uploaded to S3. + + instance_count (int): The number of instances to use. Defaults to 1. + NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and + mpirun utilities. + + instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run + the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown. + + job_conda_env (str): The name of the conda environment to activate during job's runtime. + Defaults to ``None``. + + job_name_prefix (str): The prefix used to create the underlying SageMaker job. + + keep_alive_period_in_seconds (int): The duration in seconds to retain and reuse provisioned + infrastructure after the completion of a training job, also known as SageMaker managed + warm pools. The use of warm pools reduces the time spent provisioning new + resources. The default value for ``keep_alive_period_in_seconds`` is 0. + NOTE: Additional charges associated with warm pools may apply. Using this parameter also + activates a new persistent cache feature, which further reduces job start-up + latency compared to using SageMaker managed warm pools alone by caching the package source + downloaded in previous runs. + + max_retry_attempts (int): The max number of times the job is retried on + ``InternalServerFailure`` Error from SageMaker service. Defaults to 1. + + max_runtime_in_seconds (int): The upper limit in seconds to be used for training. After + this specified amount of time, SageMaker terminates the job regardless of its current + status. Defaults to 1 day (86400 seconds). + + role (str): The IAM role (either name or full ARN) used to run your SageMaker training + job. Defaults to: + + * the SageMaker default IAM role if the SDK is running in SageMaker Notebooks or + SageMaker Studio Notebooks. + * otherwise, a ValueError is thrown. + + s3_kms_key (str): The key used to encrypt the input and output data. Defaults to ``None``. + + s3_root_uri (str): The root S3 folder to which the code archives and data are + uploaded. Defaults to ``s3://``. + + sagemaker_session (sagemaker.core.helper.session.Session): The underlying SageMaker session to which + SageMaker service calls are delegated (default: None). If not provided, one is created + using a default configuration chain. + + security_group_ids (List[str]): A list of security group IDs. Defaults to ``None`` and the + training job is created without VPC config. + + subnets (List[str]): A list of subnet IDs. Defaults to ``None`` and the job is created + without VPC config. + + tags (List[Tuple[str, str]]): A list of tags attached to the job. Defaults to ``None`` and + the training job is created without tags.
+ + volume_kms_key (str): An Amazon Key Management Service (KMS) key used to encrypt an + Amazon Elastic Block Storage (EBS) volume attached to the training instance. Defaults to + ``None``. + + volume_size (int): The size in GB of the storage volume for storing input and output data + during training. Defaults to ``30``. + + encrypt_inter_container_traffic (bool): A flag that specifies whether traffic between + training containers is encrypted for the training job. Defaults to ``False``. + + spark_config (SparkConfig): Configurations for the Spark application that runs on + Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri + will be used for training. Note that ``image_uri`` cannot be specified at the + same time, otherwise a ``ValueError`` is thrown. Defaults to ``None``. + + use_spot_instances (bool): Specifies whether to use SageMaker Managed Spot instances for + training. If enabled, then the ``max_wait_time_in_seconds`` arg should also be set. + Defaults to ``False``. + + max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job. + After this amount of time, Amazon SageMaker will stop waiting for the managed spot training + job to complete. Defaults to ``None``. + + disable_output_compression (bool): Optional. When set to true, the model is uploaded to + Amazon S3 without compression after training finishes. + + use_torchrun (bool): Specifies whether to use torchrun for distributed training. + Defaults to ``False``. + + use_mpirun (bool): Specifies whether to use mpirun for distributed training. + Defaults to ``False``. + + nproc_per_node (int): Optional. Specifies the number of processes per node for + distributed training. Defaults to ``None``. + If not set, this is automatically configured based on the instance type. + """ + + def _remote(func): + + job_settings = _JobSettings( + dependencies=dependencies, + pre_execution_commands=pre_execution_commands, + pre_execution_script=pre_execution_script, + environment_variables=environment_variables, + image_uri=image_uri, + include_local_workdir=include_local_workdir, + custom_file_filter=custom_file_filter, + instance_count=instance_count, + instance_type=instance_type, + job_conda_env=job_conda_env, + job_name_prefix=job_name_prefix, + keep_alive_period_in_seconds=keep_alive_period_in_seconds, + max_retry_attempts=max_retry_attempts, + max_runtime_in_seconds=max_runtime_in_seconds, + role=role, + s3_kms_key=s3_kms_key, + s3_root_uri=s3_root_uri, + sagemaker_session=sagemaker_session, + security_group_ids=security_group_ids, + subnets=subnets, + tags=tags, + volume_kms_key=volume_kms_key, + volume_size=volume_size, + encrypt_inter_container_traffic=encrypt_inter_container_traffic, + spark_config=spark_config, + use_spot_instances=use_spot_instances, + max_wait_time_in_seconds=max_wait_time_in_seconds, + disable_output_compression=disable_output_compression, + use_torchrun=use_torchrun, + use_mpirun=use_mpirun, + nproc_per_node=nproc_per_node, + ) + + @functools.wraps(func) + def wrapper(*args, **kwargs): + + if instance_count > 1 and not ( + (spark_config is not None and not use_torchrun and not use_mpirun) + or (spark_config is None and use_torchrun and not use_mpirun) + or (spark_config is None and not use_torchrun and use_mpirun) + ): + raise ValueError( + "Remote function does not support training on multiple instances " + + "without spark_config or use_torchrun or use_mpirun. 
" + + "Please provide instance_count = 1" + ) + + RemoteExecutor._validate_submit_args(func, *args, **kwargs) + + job = _Job.start(job_settings, func, args, kwargs) + + try: + job.wait() + except UnexpectedStatusException as usex: + if usex.actual_status == "Failed": + try: + exception = serialization.deserialize_exception_from_s3( + sagemaker_session=job_settings.sagemaker_session, + s3_uri=s3_path_join( + job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER + ), + + ) + except ServiceError as serr: + chained_e = serr.__cause__ + if ( + isinstance(chained_e, ClientError) + and chained_e.response["Error"]["Code"] # pylint: disable=no-member + == "404" + and chained_e.response["Error"]["Message"] # pylint: disable=no-member + == "Not Found" + ): + describe_result = job.describe() + if ( + "FailureReason" in describe_result + and describe_result["FailureReason"] + and "RuntimeEnvironmentError: " in describe_result["FailureReason"] + ): + failure_msg = describe_result["FailureReason"].replace( + "RuntimeEnvironmentError: ", "" + ) + raise RuntimeEnvironmentError(failure_msg) + raise RemoteFunctionError( + "Failed to execute remote function. " + + "Check corresponding job for details." + ) + raise serr + + raise exception + + raise TimeoutError( + "Job for remote function timed out before reaching a termination status." + ) + + if job.describe()["TrainingJobStatus"] == "Completed": + return serialization.deserialize_obj_from_s3( + sagemaker_session=job_settings.sagemaker_session, + s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER), + + ) + + if job.describe()["TrainingJobStatus"] == "Stopped": + raise RemoteFunctionError("Job for remote function has been aborted.") + + return None + + wrapper.job_settings = job_settings + wrapper.wrapped_func = func + return wrapper + + if _func is None: + return _remote + return _remote(_func) + + +class _SubmitRequest: + """Class that holds parameters and data for creating a new job.""" + + def __init__( + self, future, job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None + ): + self.future = future + self.job_settings = job_settings + self.func = func + self.args = func_args + self.kwargs = func_kwargs + self.run_info = run_info + + +def _submit_worker(executor): + """Background worker that submits job requests.""" + + def has_work_to_do(): + return ( + len(executor._pending_request_queue) > 0 + and len(executor._running_jobs) < executor.max_parallel_jobs + ) + + try: + while True: + with executor._state_condition: + executor._state_condition.wait_for(has_work_to_do) + request = executor._pending_request_queue[0] + + if request is None: + with executor._state_condition: + # remove the anchor from the pending queue + executor._pending_request_queue.popleft() + return + + time.sleep(_API_CALL_LIMIT["SubmittingIntervalInSecs"]) + # submit a new job + job = request.future._start_and_notify( + request.job_settings, request.func, request.args, request.kwargs, request.run_info + ) + + with executor._state_condition: + if job: + executor._running_jobs[job.job_name] = job + # remove the request from the pending queue + executor._pending_request_queue.popleft() + except Exception: # pylint: disable=broad-except + logger.exception("Error occurred while submitting CreateTrainingJob requests.") + + +def _polling_worker(executor): + """Background worker that polls the status of the running jobs.""" + try: + while True: + with executor._state_condition: + if ( + executor._shutdown + and len(executor._running_jobs) + 
len(executor._pending_request_queue) == 0 + ): + return + + time.sleep( + max( + _API_CALL_LIMIT["MinBatchPollingIntervalInSecs"] + - len(executor._running_jobs) * _API_CALL_LIMIT["PollingIntervalInSecs"], + 0, + ) + ) + + # check if running jobs are terminated + for job_name in list(executor._running_jobs.keys()): + try: + time.sleep(_API_CALL_LIMIT["PollingIntervalInSecs"]) + if executor._running_jobs[job_name].describe()["TrainingJobStatus"] in [ + "Completed", + "Failed", + "Stopped", + ]: + with executor._state_condition: + del executor._running_jobs[job_name] + executor._state_condition.notify_all() + except Exception as e: # pylint: disable=broad-except + if ( + not isinstance(e, ClientError) + or e.response["Error"]["Code"] # pylint: disable=no-member + != "LimitExceededException" + ): + # Couldn't check the job status, move on + logger.exception( + "Error occurred while checking the status of job %s", job_name + ) + with executor._state_condition: + del executor._running_jobs[job_name] + executor._state_condition.notify_all() + except Exception: # pylint: disable=broad-except + logger.exception("Error occurred while monitoring the job statuses.") + + +class RemoteExecutor(object): + """Run Python functions asynchronously as SageMaker jobs""" + + def __init__( + self, + *, + dependencies: str = None, + pre_execution_commands: List[str] = None, + pre_execution_script: str = None, + environment_variables: Dict[str, str] = None, + image_uri: str = None, + include_local_workdir: bool = None, + custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None, + instance_count: int = 1, + instance_type: str = None, + job_conda_env: str = None, + job_name_prefix: str = None, + keep_alive_period_in_seconds: int = 0, + max_parallel_jobs: int = 1, + max_retry_attempts: int = 1, + max_runtime_in_seconds: int = 24 * 60 * 60, + role: str = None, + s3_kms_key: str = None, + s3_root_uri: str = None, + sagemaker_session: Session = None, + security_group_ids: List[str] = None, + subnets: List[str] = None, + tags: List[Tuple[str, str]] = None, + volume_kms_key: str = None, + volume_size: int = 30, + encrypt_inter_container_traffic: bool = None, + spark_config: SparkConfig = None, + use_spot_instances=False, + max_wait_time_in_seconds=None, + disable_output_compression: bool = False, + use_torchrun: bool = False, + use_mpirun: bool = False, + nproc_per_node: Optional[int] = None, + ): + """Constructor for RemoteExecutor + + If a parameter value is not set, the constructor first looks up the value from the + SageMaker configuration file. If no value is specified in the configuration file or + no configuration file is found, the constructor selects the default as specified below. + For more information, see `Configuring and using defaults with the SageMaker Python SDK + `_. + + Args: + _func (Optional): A Python function to run as a SageMaker training job. + + dependencies (str): Either the path to a dependencies file or the reserved keyword + ``auto_capture``. Defaults to ``None``. + If ``dependencies`` is provided, the value must be one of the following: + + * A path to a conda environment.yml file. The following conditions apply. + + * If job_conda_env is set, then the conda environment is updated by installing + dependencies from the yaml file and the function is invoked within that + conda environment. For this to succeed, the specified conda environment must + already exist in the image. 
+ * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, then + the conda environment is updated by installing dependencies from the yaml file and + the function is invoked within that conda environment. For this to succeed, the + conda environment name must already be set in ``SAGEMAKER_JOB_CONDA_ENV``, and + ``SAGEMAKER_JOB_CONDA_ENV`` must already exist in the image. + * If none of the previous conditions are met, a new conda environment named + ``sagemaker-runtime-env`` is created and the function annotated with the remote + decorator is invoked in that conda environment. + + * A path to a requirements.txt file. The following conditions apply. + + * If ``job_conda_env`` is set in the remote decorator, dependencies are installed + within that conda environment and the function annotated with the remote decorator + is invoked in the same conda environment. For this to succeed, the specified + conda environment must already exist in the image. + * If an environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, + dependencies are installed within that conda environment and the function annotated + with the remote decorator is invoked in the same. For this to succeed, the + conda environment name must already be set in ``SAGEMAKER_JOB_CONDA_ENV``, and + ``SAGEMAKER_JOB_CONDA_ENV`` must already exist in the image. + * If none of the above conditions are met, conda is not used. Dependencies are + installed at the system level, without any virtual environment, and the function + annotated with the remote decorator is invoked using the Python runtime available + in the system path. + + * The parameter dependencies is set to ``auto_capture``. SageMaker will automatically + generate an env_snapshot.yml corresponding to the current active conda environment’s + snapshot. You do not need to provide a dependencies file. The following conditions + apply: + + * You must run the remote function within an active conda environment. + * When installing the dependencies on the training job, the same conditions as when + dependencies is set to a path to a conda environment file apply. These conditions + are as follows: + + * If job_conda_env is set, then the conda environment is updated by installing + dependencies from the yaml file and the function is invoked within that + conda environment. For this to succeed, the specified conda environment must + already exist in the image. + * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, + then the conda environment is updated by installing dependencies from the yaml + file and the function is invoked within that conda environment. For this to + succeed, the conda environment name must already be set in + ``SAGEMAKER_JOB_CONDA_ENV``, and ``SAGEMAKER_JOB_CONDA_ENV`` must already exist + in the image. + * If none of the previous conditions are met, a new conda environment with name + ``sagemaker-runtime-env`` is created and the function annotated with the + remote decorator is invoked in that conda environment. + + * ``None``. SageMaker will assume that there are no dependencies to install while + executing the remote annotated function in the training job. + + pre_execution_commands (List[str]): List of commands to be executed prior to executing + remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script`` + can be specified at the same time. Defaults to None. + + pre_execution_script (str): Path to script file to be executed prior to executing + remote function. 
Only one of ``pre_execution_commands`` or ``pre_execution_script`` + can be specified at the same time. Defaults to None. + + environment_variables (Dict): The environment variables used inside the decorator + function. Defaults to ``None``. + + image_uri (str): The universal resource identifier (URI) location of a Docker image on + Amazon Elastic Container Registry (ECR). Defaults to the following based on where the + SDK is running: + + * For users who specify ``spark_config`` and want to run the function in a Spark + application, the ``image_uri`` should be ``None``. A SageMaker Spark image will + be used for training, otherwise a ``ValueError`` is thrown. + * For users on SageMaker Studio notebooks, the image used as the kernel image for + the notebook is used. + * For other users, it is resolved to base python image with the same python + version as the environment running the local code. + + If no compatible image is found, a ValueError is thrown. + + include_local_workdir (bool): A flag to indicate that the remote function should include + local directories. Set to ``True`` if the remote function code imports local modules + and methods that are not available via PyPI or conda. Default value is ``False``. + + custom_file_filter (Callable[[str, List], List], CustomFileFilter): Either a function + that filters job dependencies to be uploaded to S3 or a ``CustomFileFilter`` object + that specifies the local directories and files to be included in the remote function. + If a callable is passed in, that function is passed to the ``ignore`` argument of + ``shutil.copytree``. Defaults to ``None``, which means only python + files are accepted and uploaded to S3. + + instance_count (int): The number of instances to use. Defaults to 1. + NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and + mpirun utilities. + + instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run + the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown. + + job_conda_env (str): The name of the conda environment to activate during job's runtime. + Defaults to ``None``. + + job_name_prefix (str): The prefix used to create the underlying SageMaker job. + + keep_alive_period_in_seconds (int): The duration in seconds to retain and reuse + provisioned infrastructure after the completion of a training job, also known as + SageMaker managed warm pools. The use of warm pools reduces the time spent to + provision new resources. The default value for ``keep_alive_period_in_seconds`` is 0. + NOTE: Additional charges associated with warm pools may apply. Using this parameter + also activates a new persistent cache feature, which further reduces job start-up + latency compared to using SageMaker managed warm pools alone by caching the package + source downloaded in previous runs. + + max_parallel_jobs (int): Maximum number of jobs that run in parallel. Defaults to 1. + + max_retry_attempts (int): The max number of times the job is retried on + ``InternalServerFailure`` Error from SageMaker service. Defaults to 1. + + max_runtime_in_seconds (int): The upper limit in seconds to be used for training. After + this specified amount of time, SageMaker terminates the job regardless of its current + status. Defaults to 1 day (86400 seconds). + + role (str): The IAM role (either name or full ARN) used to run your SageMaker training + job.
Defaults to: + + * the SageMaker default IAM role if the SDK is running in SageMaker Notebooks or + SageMaker Studio Notebooks. + * otherwise, a ValueError is thrown. + + s3_kms_key (str): The key used to encrypt the input and output data. + Defaults to ``None``. + + s3_root_uri (str): The root S3 folder to which the code archives and data are + uploaded. Defaults to ``s3://``. + + sagemaker_session (sagemaker.core.helper.session.Session): The underlying SageMaker session to which + SageMaker service calls are delegated (default: None). If not provided, one is + created using a default configuration chain. + + security_group_ids (List[str]): A list of security group IDs. Defaults to ``None`` and + the training job is created without VPC config. + + subnets (List[str]): A list of subnet IDs. Defaults to ``None`` and the job is + created without VPC config. + + tags (List[Tuple[str, str]]): A list of tags attached to the job. Defaults to ``None`` + and the training job is created without tags. + + volume_kms_key (str): An Amazon Key Management Service (KMS) key used to encrypt an + Amazon Elastic Block Storage (EBS) volume attached to the training instance. + Defaults to ``None``. + + volume_size (int): The size in GB of the storage volume for storing input and output + data during training. Defaults to ``30``. + + encrypt_inter_container_traffic (bool): A flag that specifies whether traffic between + training containers is encrypted for the training job. Defaults to ``False``. + + spark_config (SparkConfig): Configurations for the Spark application that runs on + Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri + will be used for training. Note that ``image_uri`` cannot be specified at the + same time, otherwise a ``ValueError`` is thrown. Defaults to ``None``. + + use_spot_instances (bool): Specifies whether to use SageMaker Managed Spot instances for + training. If enabled, then the ``max_wait_time_in_seconds`` arg should also be set. + Defaults to ``False``. + + max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job. + After this amount of time, Amazon SageMaker will stop waiting for the managed spot training + job to complete. Defaults to ``None``. + + disable_output_compression (bool): Optional. When set to true, the model is uploaded to + Amazon S3 without compression after training finishes. + + use_torchrun (bool): Specifies whether to use torchrun for distributed training. + Defaults to ``False``. + + use_mpirun (bool): Specifies whether to use mpirun for distributed training. + Defaults to ``False``. + + nproc_per_node (int): Optional. Specifies the number of processes per node for + distributed training. Defaults to ``None``. + If not set, this is automatically configured based on the instance type. + """ + self.max_parallel_jobs = max_parallel_jobs + + if self.max_parallel_jobs <= 0: + raise ValueError("max_parallel_jobs must be greater than 0.") + + if instance_count > 1 and not ( + (spark_config is not None and not use_torchrun and not use_mpirun) + or (spark_config is None and use_torchrun and not use_mpirun) + or (spark_config is None and not use_torchrun and use_mpirun) + ): + raise ValueError( + "Remote function does not support training on multiple instances " + + "without spark_config or use_torchrun or use_mpirun. 
" + + "Please provide instance_count = 1" + ) + + self.job_settings = _JobSettings( + dependencies=dependencies, + pre_execution_commands=pre_execution_commands, + pre_execution_script=pre_execution_script, + environment_variables=environment_variables, + image_uri=image_uri, + include_local_workdir=include_local_workdir, + custom_file_filter=custom_file_filter, + instance_count=instance_count, + instance_type=instance_type, + job_conda_env=job_conda_env, + job_name_prefix=job_name_prefix, + keep_alive_period_in_seconds=keep_alive_period_in_seconds, + max_retry_attempts=max_retry_attempts, + max_runtime_in_seconds=max_runtime_in_seconds, + role=role, + s3_kms_key=s3_kms_key, + s3_root_uri=s3_root_uri, + sagemaker_session=sagemaker_session, + security_group_ids=security_group_ids, + subnets=subnets, + tags=tags, + volume_kms_key=volume_kms_key, + volume_size=volume_size, + encrypt_inter_container_traffic=encrypt_inter_container_traffic, + spark_config=spark_config, + use_spot_instances=use_spot_instances, + max_wait_time_in_seconds=max_wait_time_in_seconds, + disable_output_compression=disable_output_compression, + use_torchrun=use_torchrun, + use_mpirun=use_mpirun, + nproc_per_node=nproc_per_node, + ) + + self._state_condition = threading.Condition() + self._pending_request_queue = deque() + # For thread safety, see + # https://web.archive.org/web/20201108091210/http://effbot.org/pyfaq/what-kinds-of-global-value-mutation-are-thread-safe.htm + self._running_jobs = dict() + self._shutdown = False + + self._workers: ThreadPoolExecutor = None + + def submit(self, func, *args, **kwargs): + """Execute the input function as a SageMaker job asynchronously. + + Args: + func: Python function to run as a SageMaker job. + *args: Positional arguments to the input function. + **kwargs: keyword arguments to the input function + """ + if self._shutdown: + raise RuntimeError("Cannot schedule new remote function executions after shutdown") + + self._validate_submit_args(func, *args, **kwargs) + + with self._state_condition: + future = Future() + + run_info = None + if _RunContext.get_current_run() is not None: + run = _RunContext.get_current_run() + run_info = _RunInfo(run.experiment_name, run.run_name) + + self._pending_request_queue.append( + _SubmitRequest(future, self.job_settings, func, args, kwargs, run_info) + ) + + if self._workers is None: + self._workers = ThreadPoolExecutor(2) + self._workers.submit(_submit_worker, self) + self._workers.submit(_polling_worker, self) + + self._state_condition.notify_all() + + return future + + def map(self, func, *iterables): + """Return an iterator that applies function to every item of iterable, yielding the results. + + If additional iterables arguments are passed, function must take that many arguments and + is applied to the items from all iterables in parallel. With multiple iterables, the + iterator stops when the shortest iterable is exhausted. + + Args: + func: Python function to run as a SageMaker job. + iterables: Arguments of the input python function. 
+ """ + + futures = map(self.submit, itertools.repeat(func), *iterables) + return [future.result() for future in futures] + + def shutdown(self): + """Prevent more function executions to be submitted to this executor.""" + with self._state_condition: + self._shutdown = True + + # give a signal to the submitting worker so that it doesn't block on empty queue forever + self._pending_request_queue.append(None) + + self._state_condition.notify_all() + + if self._workers is not None: + self._workers.shutdown(wait=True) + + def __enter__(self): + """Create an executor instance and return it""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Make sure the executor instance is shutdown.""" + self.shutdown() + return False + + @staticmethod + def _validate_submit_args(func, *args, **kwargs): + """Validates input args passed to submit method.""" + + full_arg_spec = inspect.getfullargspec(func) + + # args related validations + + is_accepting_variable_positional_args = full_arg_spec.varargs is not None + num_default_positional_args = len(full_arg_spec.defaults) if full_arg_spec.defaults else 0 + minimum_num_expected_positional_args = len(full_arg_spec.args) - num_default_positional_args + + if not is_accepting_variable_positional_args and len(args) > len(full_arg_spec.args): + raise TypeError( + f"{func.__name__}() takes {len(full_arg_spec.args)} positional " + + f"{'arguments' if len(full_arg_spec.args) > 1 else 'argument'} but {len(args)} " + + f"{'were' if len(args) > 1 else 'was'} given." + ) + + if len(args) < minimum_num_expected_positional_args: + missing_positional_args = full_arg_spec.args[ + len(args) : minimum_num_expected_positional_args + ] + missing_args = list(filter(lambda arg: arg not in kwargs, missing_positional_args)) + if missing_args: + missing_args_str = ( + ", ".join(map(lambda x: f"'{x}'", missing_args[:-1])) + + f", and '{missing_args[-1]}'" + if len(missing_args) > 1 + else f"'{missing_args[0]}'" + ) + raise TypeError( + f"{func.__name__}() missing {len(missing_args)} required positional " + + f"{'arguments' if len(missing_args) > 1 else 'argument'}: {missing_args_str}" + ) + + # kwargs related validations + + for k in kwargs: + if k in full_arg_spec.args and len(args) > full_arg_spec.args.index(k): + raise TypeError(f"{func.__name__}() got multiple values for argument '{k}'") + if k not in full_arg_spec.kwonlyargs and k not in full_arg_spec.args: + raise TypeError(f"{func.__name__}() got an unexpected keyword argument '{k}'") + + missing_kwargs = [ + k + for k in full_arg_spec.kwonlyargs + if k not in full_arg_spec.kwonlydefaults and k not in kwargs + ] + if missing_kwargs: + missing_kwargs_string = ( + ", ".join(map(lambda x: f"'{x}'", missing_kwargs[:-1])) + + f", and '{missing_kwargs[-1]}'" + if len(missing_kwargs) > 1 + else f"'{missing_kwargs[0]}'" + ) + + raise TypeError( + f"{func.__name__}() missing {len(missing_kwargs)} required keyword-only " + + f"{'arguments' if len(missing_kwargs) > 1 else 'argument'}: " + + f"{missing_kwargs_string}" + ) + + +class Future(object): + """Class representing a reference to a SageMaker job result. + + Reference to the SageMaker job created as a result of the remote function run. The job may + or may not have finished running. 
+ """ + + def __init__(self): + self._condition = threading.Condition() + self._state = _PENDING + self._job = None + self._exception = None + self._return = None + + @staticmethod + def from_describe_response(describe_training_job_response, sagemaker_session): + """Construct a Future from a describe_training_job_response object.""" + future = Future() + job_exception = None + client_exception = None + job_return = None + job = _Job.from_describe_response(describe_training_job_response, sagemaker_session) + if describe_training_job_response["TrainingJobStatus"] in ["Stopping", "Stopped"]: + state = _CANCELLED + elif describe_training_job_response["TrainingJobStatus"] == "Completed": + state = _FINISHED + try: + job_return = serialization.deserialize_obj_from_s3( + sagemaker_session=sagemaker_session, + s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER), + + ) + except DeserializationError as e: + client_exception = e + except ServiceError as e: + client_exception = e + elif describe_training_job_response["TrainingJobStatus"] == "Failed": + state = _FINISHED + try: + job_exception = serialization.deserialize_exception_from_s3( + sagemaker_session=sagemaker_session, + s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER), + + ) + except ServiceError as serr: + chained_e = serr.__cause__ + if ( + isinstance(chained_e, ClientError) + and chained_e.response["Error"]["Code"] == "404" # pylint: disable=no-member + and chained_e.response["Error"]["Message"] # pylint: disable=no-member + == "Not Found" + ): + if ( + "FailureReason" in describe_training_job_response + and describe_training_job_response["FailureReason"] + and "RuntimeEnvironmentError: " + in describe_training_job_response["FailureReason"] + ): + failure_msg = describe_training_job_response["FailureReason"].replace( + "RuntimeEnvironmentError: ", "" + ) + job_exception = RuntimeEnvironmentError(failure_msg) + else: + job_exception = RemoteFunctionError( + "Failed to execute remote function. " + + "Check corresponding job for details." + ) + else: + job_exception = serr + except DeserializationError as e: + client_exception = e + else: + state = _RUNNING + + future._job = job + future._state = state + future._exception = job_exception or client_exception + future._return = job_return + return future + + def _start_and_notify( + self, job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None + ): + """Start and record the newly created job in the future object. + + The job is recorded if one is successfully started. Otherwise, the exception is + recorded. The state update is broadcast to other waiting threads. + """ + with self._condition: + if self._state in [_PENDING]: + + try: + self._job = _Job.start(job_settings, func, func_args, func_kwargs, run_info) + except (Exception,) as e: # pylint: disable=broad-except + self._exception = e + self._state = _FINISHED + self._condition.notify_all() + return None + + self._state = _RUNNING + self._condition.notify_all() + return self._job + return None + + def result(self, timeout: float = None) -> Any: + """Returns the SageMaker job result. + + This method waits for the SageMaker job created from the remote function execution to + complete for up to the timeout value (if specified). If timeout is ``None``, + this method will wait until the SageMaker job completes. + + Args: + timeout (float): Timeout in seconds to wait until the job is completed. ``None`` by + default. + + Returns: + The Python object returned by the remote function. 
+ """ + try: + self.wait(timeout) + except UnexpectedStatusException: + pass + + with self._condition: + if self._state == _PENDING: + raise RuntimeError() + + if self._state == _RUNNING: + if self._job.describe()["TrainingJobStatus"] == "Completed": + self._return = serialization.deserialize_obj_from_s3( + sagemaker_session=self._job.sagemaker_session, + s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER), + + ) + self._state = _FINISHED + return self._return + if self._job.describe()["TrainingJobStatus"] == "Failed": + try: + self._exception = serialization.deserialize_exception_from_s3( + sagemaker_session=self._job.sagemaker_session, + s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER), + + ) + except ServiceError as serr: + chained_e = serr.__cause__ + if ( + isinstance(chained_e, ClientError) + and chained_e.response["Error"]["Code"] # pylint: disable=no-member + == "404" + and chained_e.response["Error"]["Message"] # pylint: disable=no-member + == "Not Found" + ): + if ( + "FailureReason" in self._job.describe() + and self._job.describe()["FailureReason"] + and "RuntimeEnvironmentError: " + in self._job.describe()["FailureReason"] + ): + failure_msg = self._job.describe()["FailureReason"].replace( + "RuntimeEnvironmentError: ", "" + ) + self._exception = RuntimeEnvironmentError(failure_msg) + else: + self._exception = RemoteFunctionError( + "Failed to execute remote function. " + + "Check corresponding job for details." + ) + else: + self._exception = serr + self._state = _FINISHED + elif self._job.describe()["TrainingJobStatus"] == "Stopped": + self._state = _CANCELLED + raise RemoteFunctionError("Job for remote function has been aborted.") + else: + raise TimeoutError( + "Job for remote function timed out before reaching a termination status." + ) + + if self._state == _FINISHED: + if self._exception: + raise self._exception + return self._return + + return None + + def wait( + self, + timeout: int = None, + ) -> None: + """Wait for the underlying SageMaker job to complete. + + This method waits for the SageMaker job created as a result of the remote function run + to complete for up to the timeout value (if specified). If timeout is ``None``, this method + will block until the job is completed. + + Args: + timeout (int): Timeout in seconds to wait until the job is completed before it is + stopped. Defaults to ``None``. + + Returns: + None + """ + + with self._condition: + if self._state == _PENDING: + self._condition.wait(timeout=timeout) + + if self._state == _RUNNING: + self._job.wait(timeout=timeout) + + def cancel(self) -> bool: + """Cancel the function execution. + + This method prevents the SageMaker job being created or stops the underlying SageMaker job + early if it is already in progress. + + Returns: + ``True`` if the underlying SageMaker job created as a result of the remote function + run is cancelled. + """ + with self._condition: + if self._state == _FINISHED: + return False + if self._state == _CANCELLED: + return True + + if self._job: + self._job.stop() + self._state = _CANCELLED + return True + + def running(self) -> bool: + """Check if the underlying SageMaker job is running. + + Returns: + ``True`` if the underlying SageMaker job is still running. ``False``, otherwise. + """ + with self._condition: + return self._state == _RUNNING + + def cancelled(self) -> bool: + """Check if the underlying SageMaker job was cancelled. + + Returns: + ``True`` if the underlying SageMaker job was cancelled. ``False``, otherwise. 
+ """ + with self._condition: + return self._state == _CANCELLED + + def done(self) -> bool: + """Check if the underlying SageMaker job is finished. + + Returns: + ``True`` if the underlying SageMaker job finished running. ``False``, otherwise. + """ + with self._condition: + if self._state == _RUNNING and self._job.describe()["TrainingJobStatus"] in [ + "Completed", + "Failed", + ]: + self._state = _FINISHED + return True + + if self._state == _FINISHED: + return True + + return False + + +def get_future(job_name, sagemaker_session=None) -> Future: + """Get a future object with information about a job with the given job_name. + + Args: + job_name (str): name of the underlying SageMaker job created as a result of the remote + function run. + + sagemaker_session (sagemaker.core.helper.session.Session): A session object that manages interactions + with Amazon SageMaker APIs and any other AWS services needed. + + Returns: + A `sagemaker.remote_function.client.Future` instance. + """ + if not sagemaker_session: + sagemaker_session = Session() + describe_training_job_response = sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=job_name + ) + return Future.from_describe_response(describe_training_job_response, sagemaker_session) + + +def list_futures(job_name_prefix, sagemaker_session=None): + """Generates Future objects with information about jobs with given job_name_prefix. + + Args: + job_name_prefix (str): A prefix used to identify the SageMaker jobs associated with remote + function run. + sagemaker_session (sagemaker.core.helper.session.Session): A session object that manages interactions + with Amazon SageMaker APIs and any other AWS services needed. + + Yields: + A `sagemaker.remote_function.client.Future` instance. + """ + if not sagemaker_session: + sagemaker_session = Session() + job_name = name_from_base(job_name_prefix) + # perform the following transformation because we might have trimmed the job_name_prefix while + # creating the job. + transformed_job_name_prefix = base_from_name(job_name) + next_token = None + list_training_job_kwargs = {"NameContains": transformed_job_name_prefix} + while True: + if next_token: + list_training_job_kwargs["NextToken"] = next_token + list_training_job_response = sagemaker_session.sagemaker_client.list_training_jobs( + **list_training_job_kwargs + ) + training_job_names = [ + job["TrainingJobName"] for job in list_training_job_response["TrainingJobSummaries"] + ] + for training_job_name in training_job_names: + describe_training_job_response = ( + sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=training_job_name + ) + ) + yield Future.from_describe_response(describe_training_job_response, sagemaker_session) + if "NextToken" in list_training_job_response: + next_token = list_training_job_response["NextToken"] + else: + break diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/__init__.py b/sagemaker-train/src/sagemaker/train/remote_function/core/__init__.py index 7e9f2d30da..e69de29bb2 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/__init__.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/core/__init__.py @@ -1,27 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. 
A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.core - -This is a backward compatibility shim. -""" -from __future__ import absolute_import - -import warnings - -warnings.warn( - "sagemaker.train.remote_function.core has been moved to sagemaker.core.remote_function.core. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/_custom_dispatch_table.py b/sagemaker-train/src/sagemaker/train/remote_function/core/_custom_dispatch_table.py index 20b7a297b5..857ac40eb0 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/_custom_dispatch_table.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/core/_custom_dispatch_table.py @@ -1,4 +1,3 @@ - # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You @@ -24,7 +23,6 @@ ParameterBoolean, ) from sagemaker.core.workflow.execution_variables import ExecutionVariable -from sagemaker.mlops.workflow.function_step import DelayedReturn from sagemaker.core.workflow.properties import ( Properties, PropertiesMap, @@ -32,6 +30,19 @@ ) +# Lazy import to avoid circular dependency +# DelayedReturn is in MLOps package which depends on Core +def _get_delayed_return_class(): + """Lazy import of DelayedReturn to avoid circular dependency.""" + try: + from sagemaker.mlops.workflow.function_step import DelayedReturn + + return DelayedReturn + except ImportError: + # If MLOps is not installed, return None + return None + + def _pipeline_variable_reducer(pipeline_variable): """Reducer for pipeline variable.""" @@ -42,6 +53,7 @@ def _pipeline_variable_reducer(pipeline_variable): ) +# Build dispatch table with lazy loading for DelayedReturn dispatch_table = { ParameterInteger: _pipeline_variable_reducer, ParameterFloat: _pipeline_variable_reducer, @@ -52,5 +64,9 @@ def _pipeline_variable_reducer(pipeline_variable): Properties: _pipeline_variable_reducer, PropertiesMap: _pipeline_variable_reducer, PropertiesList: _pipeline_variable_reducer, - DelayedReturn: _pipeline_variable_reducer, } + +# Add DelayedReturn to dispatch table if MLOps is available +_delayed_return_class = _get_delayed_return_class() +if _delayed_return_class is not None: + dispatch_table[_delayed_return_class] = _pipeline_variable_reducer diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/pipeline_variables.py b/sagemaker-train/src/sagemaker/train/remote_function/core/pipeline_variables.py index 5767a07596..2497302080 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/pipeline_variables.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/core/pipeline_variables.py @@ -10,21 +10,338 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.core.pipeline_variables - -This is a backward compatibility shim. 
-""" +"""SageMaker remote function data serializer/deserializer.""" from __future__ import absolute_import -import warnings +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import Any, Union, Dict, List, Tuple + +from sagemaker.core.s3 import s3_path_join +from sagemaker.train.remote_function.core.serialization import deserialize_obj_from_s3 +from sagemaker.core.workflow.step_outputs import get_step + + +@dataclass +class Context: + """Context for an execution.""" + + step_name: str = None + execution_id: str = None + property_references: Dict[str, str] = field(default_factory=dict) + serialize_output_to_json: bool = False + func_step_s3_dir: str = None + + +@dataclass +class _Parameter: + """Parameter to a function.""" + + name: str + + +class _ParameterInteger(_Parameter): + """Integer parameter to a function.""" + + ... + + +class _ParameterFloat(_Parameter): + """Float parameter to a function.""" + + ... + + +class _ParameterString(_Parameter): + """String parameter to a function.""" + + ... + + +class _ParameterBoolean(_Parameter): + """Boolean parameter to a function.""" + + ... + + +@dataclass +class _Properties: + """Properties of classic steps.""" + + path: str + + +@dataclass +class _ExecutionVariable: + """Execution variable.""" + + name: str + + +@dataclass +class _S3BaseUriIdentifier: + """Identifies that the class refers to function step s3 base uri. + + The s3_base_uri = s3_root_uri + pipeline_name. + This identifier is resolved in function step runtime by SDK. + """ + + NAME = "S3_BASE_URI" + + +@dataclass +class _DelayedReturn: + """Delayed return from a function.""" + + uri: Union[_Properties, List[Union[str, _Parameter, _ExecutionVariable]]] + reference_path: Tuple = field(default_factory=tuple) + + +class _ExecutionVariableResolver: + """Resolve execution variables.""" + + def __init__(self, context: Context): + """Resolve execution variables.""" + self._context = context + + def resolve(self, execution_variable: _ExecutionVariable): + """Resolve a single execution variable. + + Args: + execution_variable: execution variable to resolve. + Returns: + resolved value + """ + return self._context.property_references[f"Execution.{execution_variable.name}"] + + +class _ParameterResolver: + """Resolve parameters.""" + + def __init__(self, context: Context): + """Resolve parameters.""" + self._context = context + + def resolve(self, parameter: _Parameter): + """Resolve a single property reference. + + Args: + parameter: parameter to resolve. + Returns: + resolved value + """ + if isinstance(parameter, _ParameterInteger): + return int(self._context.property_references[f"Parameters.{parameter.name}"]) + if isinstance(parameter, _ParameterFloat): + return float(self._context.property_references[f"Parameters.{parameter.name}"]) + if isinstance(parameter, _ParameterString): + return self._context.property_references[f"Parameters.{parameter.name}"] + + return self._context.property_references[f"Parameters.{parameter.name}"] == "true" + + +class _PropertiesResolver: + """Resolve classic step properties.""" + + def __init__(self, context: Context): + """Resolve classic step properties.""" + self._context = context + + def resolve(self, properties: _Properties): + """Resolve classic step properties. + + Args: + properties: classic step properties. 
+ Returns: + resolved value + """ + return self._context.property_references[properties.path] + + +class _DelayedReturnResolver: + """Resolve delayed returns.""" + + def __init__( + self, + delayed_returns: List[_DelayedReturn], + properties_resolver: _PropertiesResolver, + parameter_resolver: _ParameterResolver, + execution_variable_resolver: _ExecutionVariableResolver, + s3_base_uri: str, + **settings, + ): + """Resolve delayed return. + + Args: + delayed_returns: list of delayed returns to resolve. + properties_resolver: resolver used to resolve step properties. + parameter_resolver: resolver used to pipeline parameters. + execution_variable_resolver: resolver used to resolve execution variables. + s3_base_uri (str): the s3 base uri of the function step that + the serialized artifacts will be uploaded to. + The s3_base_uri = s3_root_uri + pipeline_name. + **settings: settings to pass to the deserialization function. + """ + self._s3_base_uri = s3_base_uri + self._parameter_resolver = parameter_resolver + self._execution_variable_resolver = execution_variable_resolver + self._properties_resolver = properties_resolver + # different delayed returns can have the same uri, so we need to dedupe + uris = { + self._resolve_delayed_return_uri(delayed_return) for delayed_return in delayed_returns + } + + def deserialization_task(uri): + return uri, deserialize_obj_from_s3( + sagemaker_session=settings["sagemaker_session"], + s3_uri=uri, + ) + + with ThreadPoolExecutor() as executor: + self._deserialized_objects = dict(executor.map(deserialization_task, uris)) + + def resolve(self, delayed_return: _DelayedReturn) -> Any: + """Resolve a single delayed return. + + Args: + delayed_return: delayed return to resolve. + Returns: + resolved delayed return. + """ + deserialized_obj = self._deserialized_objects[ + self._resolve_delayed_return_uri(delayed_return) + ] + return _retrieve_child_item(delayed_return, deserialized_obj) + + def _resolve_delayed_return_uri(self, delayed_return: _DelayedReturn): + """Resolve the s3 uri of the delayed return.""" + if isinstance(delayed_return.uri, _Properties): + return self._properties_resolver.resolve(delayed_return.uri) + + # Keep the following old resolution logics to keep backward compatible + uri = [] + for component in delayed_return.uri: + if isinstance(component, _Parameter): + uri.append(self._parameter_resolver.resolve(component)) + elif isinstance(component, _ExecutionVariable): + uri.append(self._execution_variable_resolver.resolve(component)) + elif isinstance(component, _S3BaseUriIdentifier): + uri.append(self._s3_base_uri) + else: + uri.append(component) + return s3_path_join(*uri) + + +def _retrieve_child_item(delayed_return: _DelayedReturn, deserialized_obj: Any): + """Retrieve child item from deserialized object.""" + result = deserialized_obj + for component in delayed_return.reference_path: + result = result[component[1]] + return result + + +def resolve_pipeline_variables( + context: Context, + func_args: Tuple, + func_kwargs: Dict, + s3_base_uri: str, + **settings, +): + """Resolve pipeline variables. + + Args: + context: context for the execution. + func_args: function args. + func_kwargs: function kwargs. + s3_base_uri: the s3 base uri of the function step that the serialized artifacts + will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name. + **settings: settings to pass to the deserialization function. 
+ """ + + delayed_returns = [] + + if func_args is not None: + for arg in func_args: + if isinstance(arg, _DelayedReturn): + delayed_returns.append(arg) + if func_kwargs is not None: + for arg in func_kwargs.values(): + if isinstance(arg, _DelayedReturn): + delayed_returns.append(arg) + + # build the resolvers + parameter_resolver = _ParameterResolver(context) + execution_variable_resolver = _ExecutionVariableResolver(context) + properties_resolver = _PropertiesResolver(context) + delayed_return_resolver = _DelayedReturnResolver( + delayed_returns=delayed_returns, + properties_resolver=properties_resolver, + parameter_resolver=parameter_resolver, + execution_variable_resolver=execution_variable_resolver, + s3_base_uri=s3_base_uri, + **settings, + ) + + # resolve the pipeline variables + resolved_func_args = None + if func_args is not None: + resolved_func_args = [] + for arg in func_args: + if isinstance(arg, _Parameter): + resolved_func_args.append(parameter_resolver.resolve(arg)) + elif isinstance(arg, _ExecutionVariable): + resolved_func_args.append(execution_variable_resolver.resolve(arg)) + elif isinstance(arg, _Properties): + resolved_func_args.append(properties_resolver.resolve(arg)) + elif isinstance(arg, _DelayedReturn): + resolved_func_args.append(delayed_return_resolver.resolve(arg)) + else: + resolved_func_args.append(arg) + resolved_func_args = tuple(resolved_func_args) + + resolved_func_kwargs = None + if func_kwargs is not None: + resolved_func_kwargs = {} + for key, value in func_kwargs.items(): + if isinstance(value, _Parameter): + resolved_func_kwargs[key] = parameter_resolver.resolve(value) + elif isinstance(value, _ExecutionVariable): + resolved_func_kwargs[key] = execution_variable_resolver.resolve(value) + elif isinstance(value, _Properties): + resolved_func_kwargs[key] = properties_resolver.resolve(value) + elif isinstance(value, _DelayedReturn): + resolved_func_kwargs[key] = delayed_return_resolver.resolve(value) + else: + resolved_func_kwargs[key] = value + + return resolved_func_args, resolved_func_kwargs + + +def convert_pipeline_variables_to_pickleable(func_args: Tuple, func_kwargs: Dict): + """Convert pipeline variables to pickleable. + + Args: + func_args: function args. + func_kwargs: function kwargs. + """ + + from sagemaker.core.helper.pipeline_variable import PipelineVariable + + from sagemaker.mlops.workflow.function_step import DelayedReturn + + def convert(arg): + if isinstance(arg, DelayedReturn): + return _DelayedReturn( + uri=get_step(arg)._properties.OutputDataConfig.S3OutputPath._pickleable, + reference_path=arg._reference_path, + ) + + if isinstance(arg, PipelineVariable): + return arg._pickleable + + return arg -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.core.pipeline_variables import * # noqa: F401, F403 + converted_func_args = tuple(convert(arg) for arg in func_args) + converted_func_kwargs = {key: convert(arg) for key, arg in func_kwargs.items()} -warnings.warn( - "sagemaker.train.remote_function.core.pipeline_variables has been moved to sagemaker.core.remote_function.core.pipeline_variables. " - "Please update your imports. 
This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) + return converted_func_args, converted_func_kwargs diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/serialization.py b/sagemaker-train/src/sagemaker/train/remote_function/core/serialization.py index d30d1494d5..7eed7e0d21 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/serialization.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/core/serialization.py @@ -10,21 +10,401 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.core.serialization - -This is a backward compatibility shim. -""" +"""SageMaker remote function data serializer/deserializer.""" from __future__ import absolute_import -import warnings +import dataclasses +import json + +import io + +import sys +import hashlib +import pickle -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.core.serialization import * # noqa: F401, F403 +from typing import Any, Callable, Union -warnings.warn( - "sagemaker.train.remote_function.core.serialization has been moved to sagemaker.core.remote_function.core.serialization. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 +import cloudpickle +from tblib import pickling_support + +from sagemaker.train.remote_function.errors import ( + ServiceError, + SerializationError, + DeserializationError, ) +from sagemaker.core.s3 import S3Downloader, S3Uploader +from sagemaker.core.helper.session_helper import Session +from ._custom_dispatch_table import dispatch_table + +# Note: do not use os.path.join for s3 uris, fails on windows + + +def _get_python_version(): + """Returns the current python version.""" + return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + + +@dataclasses.dataclass +class _MetaData: + """Metadata about the serialized data or functions.""" + + sha256_hash: str + version: str = "2023-04-24" + python_version: str = _get_python_version() + serialization_module: str = "cloudpickle" + + def to_json(self): + """Converts metadata to json string.""" + return json.dumps(dataclasses.asdict(self)).encode() + + @staticmethod + def from_json(s): + """Converts json string to metadata object.""" + try: + obj = json.loads(s) + except json.decoder.JSONDecodeError: + raise DeserializationError("Corrupt metadata file. It is not a valid json file.") + + sha256_hash = obj.get("sha256_hash") + metadata = _MetaData(sha256_hash=sha256_hash) + metadata.version = obj.get("version") + metadata.python_version = obj.get("python_version") + metadata.serialization_module = obj.get("serialization_module") + + if not sha256_hash: + raise DeserializationError( + "Corrupt metadata file. SHA256 hash for the serialized data does not exist. " + "Please make sure to install SageMaker SDK version >= 2.156.0 on the client side " + "and try again." + ) + + if not ( + metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle" + ): + raise DeserializationError( + f"Corrupt metadata file. Serialization approach {s} is not supported." 
+ ) + + return metadata + + +class CloudpickleSerializer: + """Serializer using cloudpickle.""" + + @staticmethod + def serialize(obj: Any) -> bytes: + """Serializes data object and uploads it to S3. + + Args: + obj: object to be serialized and persisted + Raises: + SerializationError: when fail to serialize object to bytes. + """ + try: + io_buffer = io.BytesIO() + custom_pickler = cloudpickle.CloudPickler(io_buffer) + dt = pickle.Pickler.dispatch_table.__get__(custom_pickler) # pylint: disable=no-member + new_dt = dt.new_child(dispatch_table) + pickle.Pickler.dispatch_table.__set__( # pylint: disable=no-member + custom_pickler, new_dt + ) + custom_pickler.dump(obj) + return io_buffer.getvalue() + except Exception as e: + if isinstance( + e, NotImplementedError + ) and "Instance of Run type is not allowed to be pickled." in str(e): + raise SerializationError( + """You are trying to pass a sagemaker.experiments.run.Run object to + a remote function + or are trying to access a global sagemaker.experiments.run.Run object + from within the function. This is not supported. + You must use `load_run` to load an existing Run in the remote function + or instantiate a new Run in the function.""" + ) + + raise SerializationError( + "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e)) + ) from e + + @staticmethod + def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any: + """Downloads from S3 and then deserializes data objects. + + Args: + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + bytes_to_deserialize: bytes to be deserialized. + Returns : + List of deserialized python objects. + Raises: + DeserializationError: when fail to serialize object to bytes. + """ + + try: + return cloudpickle.loads(bytes_to_deserialize) + except Exception as e: + raise DeserializationError( + "Error when deserializing bytes downloaded from {}: {}. " + "NOTE: this may be caused by inconsistent sagemaker python sdk versions " + "where remote function runs versus the one used on client side. " + "If the sagemaker versions do not match, a warning message would " + "be logged starting with 'Inconsistent sagemaker versions found'. " + "Please check it to validate.".format(s3_uri, repr(e)) + ) from e + + +# TODO: use dask serializer in case dask distributed is installed in users' environment. +def serialize_func_to_s3( + func: Callable, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None +): + """Serializes function and uploads it to S3. + + Args: + sagemaker_session (sagemaker.core.helper.session.Session): + The underlying Boto3 session which AWS service calls are delegated to. + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + func: function to be serialized and persisted + Raises: + SerializationError: when fail to serialize function to bytes. + """ + + _upload_payload_and_metadata_to_s3( + bytes_to_upload=CloudpickleSerializer.serialize(func), + s3_uri=s3_uri, + sagemaker_session=sagemaker_session, + s3_kms_key=s3_kms_key, + ) + + +def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callable: + """Downloads from S3 and then deserializes data objects. + + This method downloads the serialized training job outputs to a temporary directory and + then deserializes them using dask. 
+ + Args: + sagemaker_session (sagemaker.core.helper.session.Session): + The underlying sagemaker session which AWS service calls are delegated to. + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + Returns : + The deserialized function. + Raises: + DeserializationError: when fail to serialize function to bytes. + """ + metadata = _MetaData.from_json( + _read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session) + ) + + bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) + + _perform_integrity_check( + expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize + ) + + return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) + + +def serialize_obj_to_s3( + obj: Any, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None +): + """Serializes data object and uploads it to S3. + + Args: + sagemaker_session (sagemaker.core.helper.session.Session): + The underlying Boto3 session which AWS service calls are delegated to. + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + obj: object to be serialized and persisted + Raises: + SerializationError: when fail to serialize object to bytes. + """ + + _upload_payload_and_metadata_to_s3( + bytes_to_upload=CloudpickleSerializer.serialize(obj), + s3_uri=s3_uri, + sagemaker_session=sagemaker_session, + s3_kms_key=s3_kms_key, + ) + + +def json_serialize_obj_to_s3( + obj: Any, + json_key: str, + sagemaker_session: Session, + s3_uri: str, + s3_kms_key: str = None, +): + """Json serializes data object and uploads it to S3. + + If a function step's output is data referenced by other steps via JsonGet, + its output should be json serialized and uploaded to S3. + + Args: + obj: (Any) object to be serialized and persisted. + json_key: (str) the json key pointing to function step output. + sagemaker_session (sagemaker.core.helper.session.Session): + The underlying Boto3 session which AWS service calls are delegated to. + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + """ + json_serialized_result = {} + try: + to_dump = {json_key: obj, "Exception": None} + json_serialized_result = json.dumps(to_dump) + except TypeError as e: + if "is not JSON serializable" in str(e): + to_dump = { + json_key: None, + "Exception": f"The function return ({obj}) is not JSON serializable.", + } + json_serialized_result = json.dumps(to_dump) + + S3Uploader.upload_string_as_file_body( + body=json_serialized_result, + desired_s3_uri=s3_uri, + sagemaker_session=sagemaker_session, + kms_key=s3_kms_key, + ) + + +def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any: + """Downloads from S3 and then deserializes data objects. + + Args: + sagemaker_session (sagemaker.core.helper.session.Session): + The underlying sagemaker session which AWS service calls are delegated to. + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + Returns : + Deserialized python objects. + Raises: + DeserializationError: when fail to serialize object to bytes. 
+ """ + + metadata = _MetaData.from_json( + _read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session) + ) + + bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) + + _perform_integrity_check( + expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize + ) + + return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) + + +def serialize_exception_to_s3( + exc: Exception, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None +): + """Serializes exception with traceback and uploads it to S3. + + Args: + sagemaker_session (sagemaker.core.helper.session.Session): + The underlying Boto3 session which AWS service calls are delegated to. + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + exc: Exception to be serialized and persisted + Raises: + SerializationError: when fail to serialize object to bytes. + """ + pickling_support.install() + + _upload_payload_and_metadata_to_s3( + bytes_to_upload=CloudpickleSerializer.serialize(exc), + s3_uri=s3_uri, + sagemaker_session=sagemaker_session, + s3_kms_key=s3_kms_key, + ) + + +def _upload_payload_and_metadata_to_s3( + bytes_to_upload: Union[bytes, io.BytesIO], + s3_uri: str, + sagemaker_session: Session, + s3_kms_key, +): + """Uploads serialized payload and metadata to s3. + + Args: + bytes_to_upload (bytes): Serialized bytes to upload. + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + sagemaker_session (sagemaker.core.helper.session.Session): + The underlying Boto3 session which AWS service calls are delegated to. + s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + """ + _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session) + + sha256_hash = _compute_hash(bytes_to_upload) + + _upload_bytes_to_s3( + _MetaData(sha256_hash).to_json(), + f"{s3_uri}/metadata.json", + s3_kms_key, + sagemaker_session, + ) + + +def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> Any: + """Downloads from S3 and then deserializes exception. + + Args: + sagemaker_session (sagemaker.core.helper.session.Session): + The underlying sagemaker session which AWS service calls are delegated to. + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + Returns : + Deserialized exception with traceback. + Raises: + DeserializationError: when fail to serialize object to bytes. 
+ """ + + metadata = _MetaData.from_json( + _read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session) + ) + + bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) + + _perform_integrity_check( + expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize + ) + + return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) + + +def _upload_bytes_to_s3(b: Union[bytes, io.BytesIO], s3_uri, s3_kms_key, sagemaker_session): + """Wrapping s3 uploading with exception translation for remote function.""" + try: + S3Uploader.upload_bytes(b, s3_uri, kms_key=s3_kms_key, sagemaker_session=sagemaker_session) + except Exception as e: + raise ServiceError( + "Failed to upload serialized bytes to {}: {}".format(s3_uri, repr(e)) + ) from e + + +def _read_bytes_from_s3(s3_uri, sagemaker_session): + """Wrapping s3 downloading with exception translation for remote function.""" + try: + return S3Downloader.read_bytes(s3_uri, sagemaker_session=sagemaker_session) + except Exception as e: + raise ServiceError( + "Failed to read serialized bytes from {}: {}".format(s3_uri, repr(e)) + ) from e + + +def _compute_hash(buffer: bytes) -> str: + """Compute the sha256 hash""" + return hashlib.sha256(buffer).hexdigest() + + +def _perform_integrity_check(expected_hash_value: str, buffer: bytes): + """Performs integrity checks for serialized code/arguments uploaded to s3. + + Verifies whether the hash read from s3 matches the hash calculated + during remote function execution. + """ + actual_hash_value = _compute_hash(buffer=buffer) + if expected_hash_value != actual_hash_value: + raise DeserializationError( + "Integrity check for the serialized function or data failed. " + "Please restrict access to your S3 bucket" + ) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/core/stored_function.py b/sagemaker-train/src/sagemaker/train/remote_function/core/stored_function.py index 34915a4d42..ff5d8ddad4 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/core/stored_function.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/core/stored_function.py @@ -10,21 +10,214 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.core.stored_function - -This is a backward compatibility shim. -""" +"""SageMaker job function serializer/deserializer.""" from __future__ import absolute_import -import warnings +import os +from dataclasses import dataclass +from typing import Any -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.core.stored_function import * # noqa: F401, F403 -warnings.warn( - "sagemaker.train.remote_function.core.stored_function has been moved to sagemaker.core.remote_function.core.stored_function. " - "Please update your imports. 
This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 +from sagemaker.core.s3 import s3_path_join +from sagemaker.train.remote_function import logging_config +from sagemaker.train.remote_function.core.pipeline_variables import ( + Context, + resolve_pipeline_variables, ) + +import sagemaker.train.remote_function.core.serialization as serialization +from sagemaker.core.helper.session_helper import Session + + +logger = logging_config.get_logger() + + +FUNCTION_FOLDER = "function" +ARGUMENTS_FOLDER = "arguments" +RESULTS_FOLDER = "results" +EXCEPTION_FOLDER = "exception" +JSON_SERIALIZED_RESULT_KEY = "Result" +JSON_RESULTS_FILE = "results.json" + + +@dataclass +class _SerializedData: + """Data class to store serialized function and arguments""" + + func: bytes + args: bytes + + +class StoredFunction: + """Class representing a remote function stored in S3.""" + + def __init__( + self, + sagemaker_session: Session, + s3_base_uri: str, + s3_kms_key: str = None, + context: Context = Context(), + ): + """Construct a StoredFunction object. + + Args: + sagemaker_session: (sagemaker.session.Session): The underlying sagemaker session which + AWS service calls are delegated to. + s3_base_uri: the base uri to which serialized artifacts will be uploaded. + s3_kms_key: KMS key used to encrypt artifacts uploaded to S3. + context: Build or run context of a pipeline step. + """ + self.sagemaker_session = sagemaker_session + self.s3_base_uri = s3_base_uri + self.s3_kms_key = s3_kms_key + self.context = context + + # For pipeline steps, function code is at: base/step_name/build_timestamp/ + # For results, path is: base/step_name/build_timestamp/execution_id/ + # This ensures uniqueness: build_timestamp per build, execution_id per run + if context.step_name and context.func_step_s3_dir: + # Pipeline step: include build timestamp in both paths + self.func_upload_path = s3_path_join( + s3_base_uri, context.step_name, context.func_step_s3_dir + ) + self.results_upload_path = s3_path_join( + s3_base_uri, context.step_name, context.func_step_s3_dir, context.execution_id + ) + else: + # Regular remote function: original behavior + self.func_upload_path = s3_path_join( + s3_base_uri, context.step_name, context.func_step_s3_dir + ) + self.results_upload_path = s3_path_join( + s3_base_uri, context.execution_id, context.step_name + ) + + def save(self, func, *args, **kwargs): + """Serialize and persist the function and arguments. + + Args: + func: the python function. + args: the positional arguments to func. + kwargs: the keyword arguments to func. + Returns: + None + """ + + logger.info( + "Serializing function code to %s", s3_path_join(self.func_upload_path, FUNCTION_FOLDER) + ) + serialization.serialize_func_to_s3( + func=func, + sagemaker_session=self.sagemaker_session, + s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), + s3_kms_key=self.s3_kms_key, + + ) + + logger.info( + "Serializing function arguments to %s", + s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), + ) + + serialization.serialize_obj_to_s3( + obj=(args, kwargs), + sagemaker_session=self.sagemaker_session, + s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), + + s3_kms_key=self.s3_kms_key, + ) + + def save_pipeline_step_function(self, serialized_data): + """Upload serialized function and arguments to s3. + + Args: + serialized_data (_SerializedData): The serialized function + and function arguments of a function step. 
+ """ + + logger.info( + "Uploading serialized function code to %s", + s3_path_join(self.func_upload_path, FUNCTION_FOLDER), + ) + serialization._upload_payload_and_metadata_to_s3( + bytes_to_upload=serialized_data.func, + + s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), + sagemaker_session=self.sagemaker_session, + s3_kms_key=self.s3_kms_key, + ) + + logger.info( + "Uploading serialized function arguments to %s", + s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), + ) + serialization._upload_payload_and_metadata_to_s3( + bytes_to_upload=serialized_data.args, + + s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), + sagemaker_session=self.sagemaker_session, + s3_kms_key=self.s3_kms_key, + ) + + def load_and_invoke(self) -> Any: + """Load and deserialize the function and the arguments and then execute it.""" + + logger.info( + "Deserializing function code from %s", + s3_path_join(self.func_upload_path, FUNCTION_FOLDER), + ) + func = serialization.deserialize_func_from_s3( + sagemaker_session=self.sagemaker_session, + s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), + + ) + + logger.info( + "Deserializing function arguments from %s", + s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), + ) + args, kwargs = serialization.deserialize_obj_from_s3( + sagemaker_session=self.sagemaker_session, + s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), + + ) + + logger.info("Resolving pipeline variables") + resolved_args, resolved_kwargs = resolve_pipeline_variables( + self.context, + args, + kwargs, + + s3_base_uri=self.s3_base_uri, + sagemaker_session=self.sagemaker_session, + ) + + logger.info("Invoking the function") + result = func(*resolved_args, **resolved_kwargs) + + logger.info( + "Serializing the function return and uploading to %s", + s3_path_join(self.results_upload_path, RESULTS_FOLDER), + ) + serialization.serialize_obj_to_s3( + obj=result, + sagemaker_session=self.sagemaker_session, + s3_uri=s3_path_join(self.results_upload_path, RESULTS_FOLDER), + + s3_kms_key=self.s3_kms_key, + ) + + if self.context and self.context.serialize_output_to_json: + logger.info( + "JSON Serializing the function return and uploading to %s", + s3_path_join(self.results_upload_path, RESULTS_FOLDER), + ) + serialization.json_serialize_obj_to_s3( + obj=result, + json_key=JSON_SERIALIZED_RESULT_KEY, + sagemaker_session=self.sagemaker_session, + s3_uri=s3_path_join( + os.path.join(self.results_upload_path, RESULTS_FOLDER, JSON_RESULTS_FILE) + ), + s3_kms_key=self.s3_kms_key, + ) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/custom_file_filter.py b/sagemaker-train/src/sagemaker/train/remote_function/custom_file_filter.py index 9c1b1e1baa..c82cc7eee7 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/custom_file_filter.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/custom_file_filter.py @@ -125,4 +125,4 @@ def _filter_non_python_files(path: str, names: List) -> List: _src, dst, ignore=_ignore, - ) \ No newline at end of file + ) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/errors.py b/sagemaker-train/src/sagemaker/train/remote_function/errors.py index e67fcf7d9f..bfebb0726a 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/errors.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/errors.py @@ -10,21 +10,93 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. 
See the License for the specific # language governing permissions and limitations under the License. -""" -DEPRECATED: This module has been moved to sagemaker.core.remote_function.errors - -This is a backward compatibility shim. -""" +"""Definitions for reomote job errors and error handling""" from __future__ import absolute_import -import warnings +import os + +from tblib import pickling_support +from sagemaker.core.s3 import s3_path_join +import sagemaker.train.remote_function.core.serialization as serialization + + +DEFAULT_FAILURE_CODE = 1 +FAILURE_REASON_PATH = "/opt/ml/output/failure" + + +@pickling_support.install +class RemoteFunctionError(Exception): + """The base exception class for remote function exceptions""" + + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +@pickling_support.install +class ServiceError(RemoteFunctionError): + """Raised when errors encountered during interaction with SageMaker, S3 service APIs""" + + +@pickling_support.install +class SerializationError(RemoteFunctionError): + """Raised when errors encountered during serialization of remote function objects""" + + +@pickling_support.install +class DeserializationError(RemoteFunctionError): + """Raised when errors encountered during deserialization of remote function objects""" + + +def _get_valid_failure_exit_code(exit_code) -> int: + """Normalize exit code for terminating the process""" + try: + valid_exit_code = int(exit_code) + except (TypeError, ValueError): + valid_exit_code = DEFAULT_FAILURE_CODE + + return valid_exit_code + + +def _write_failure_reason_file(failure_msg): + """Create a file 'failure' with failure reason written if remote function execution failed. + + See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html + Args: + failure_msg: The content of file to be written. + """ + if not os.path.exists(FAILURE_REASON_PATH): + with open(FAILURE_REASON_PATH, "w") as f: + f.write(failure_msg) + + +def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: + """Handle all exceptions raised during remote function execution. + + Args: + error (Exception): The error to be handled. + sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which + AWS service calls are delegated to. + s3_base_uri (str): S3 root uri to which resulting serialized exception will be uploaded. + s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + Returns : + exit_code (int): Exit code to terminate current job. + """ + + failure_reason = repr(error) + if isinstance(error, RemoteFunctionError): + exit_code = DEFAULT_FAILURE_CODE + else: + error_number = getattr(error, "errno", DEFAULT_FAILURE_CODE) + exit_code = _get_valid_failure_exit_code(error_number) + + _write_failure_reason_file(failure_reason) -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.errors import * # noqa: F401, F403 + serialization.serialize_exception_to_s3( + exc=error, + sagemaker_session=sagemaker_session, + s3_uri=s3_path_join(s3_base_uri, "exception"), + s3_kms_key=s3_kms_key, + ) -warnings.warn( - "sagemaker.train.remote_function.errors has been moved to sagemaker.core.remote_function.errors. " - "Please update your imports. 
This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) + return exit_code diff --git a/sagemaker-train/src/sagemaker/train/remote_function/invoke_function.py b/sagemaker-train/src/sagemaker/train/remote_function/invoke_function.py index 3bafeffd5b..4606b73459 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/invoke_function.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/invoke_function.py @@ -98,7 +98,7 @@ def _load_pipeline_context(args) -> Context: def _execute_remote_function( - sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, hmac_key, context + sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, context ): """Execute stored remote function""" from sagemaker.train.remote_function.core.stored_function import StoredFunction @@ -107,7 +107,6 @@ def _execute_remote_function( sagemaker_session=sagemaker_session, s3_base_uri=s3_base_uri, s3_kms_key=s3_kms_key, - hmac_key=hmac_key, context=context, ) @@ -138,15 +137,12 @@ def main(sys_args=None): run_in_context = args.run_in_context pipeline_context = _load_pipeline_context(args) - hmac_key = os.getenv("REMOTE_FUNCTION_SECRET_KEY") - sagemaker_session = _get_sagemaker_session(region) _execute_remote_function( sagemaker_session=sagemaker_session, s3_base_uri=s3_base_uri, s3_kms_key=s3_kms_key, run_in_context=run_in_context, - hmac_key=hmac_key, context=pipeline_context, ) @@ -162,11 +158,10 @@ def main(sys_args=None): sagemaker_session=sagemaker_session, s3_base_uri=s3_uri, s3_kms_key=s3_kms_key, - hmac_key=hmac_key, ) finally: sys.exit(exit_code) if __name__ == "__main__": - main(sys.argv[1:]) \ No newline at end of file + main(sys.argv[1:]) diff --git a/sagemaker-train/src/sagemaker/train/remote_function/job.py b/sagemaker-train/src/sagemaker/train/remote_function/job.py index 33bf62af86..90c6807b53 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/job.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/job.py @@ -10,21 +10,2112 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
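Editor's note: a minimal, illustrative sketch of how the pieces added above (StoredFunction.load_and_invoke and errors.handle_error) fit together on the job side. The real entry point is invoke_function.main shown earlier; run_driver and its arguments are hypothetical placeholders, not part of this change.

import sys

from sagemaker.core.helper.session_helper import Session
from sagemaker.train.remote_function.core.stored_function import StoredFunction
from sagemaker.train.remote_function.errors import handle_error


def run_driver(s3_base_uri, s3_kms_key=None):
    """Hypothetical driver: execute the stored function and map failures to an exit code."""
    session = Session()
    exit_code = 0
    try:
        # Download the serialized function and arguments from S3, run them, and upload the result.
        StoredFunction(
            sagemaker_session=session,
            s3_base_uri=s3_base_uri,
            s3_kms_key=s3_kms_key,
        ).load_and_invoke()
    except Exception as e:  # pylint: disable=broad-except
        # Serialize the exception under <s3_base_uri>/exception and translate it to an exit code.
        exit_code = handle_error(
            error=e,
            sagemaker_session=session,
            s3_base_uri=s3_base_uri,
            s3_kms_key=s3_kms_key,
        )
    finally:
        sys.exit(exit_code)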
+"""Helper classes that interact with SageMaker Training service.""" +from __future__ import absolute_import + +import dataclasses +import json +import os +import re +import shutil +import sys +import time +from io import BytesIO +from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING +from urllib.parse import urlparse + +import botocore +from botocore.exceptions import ClientError + +from sagemaker.core.config.config_schema import ( + REMOTE_FUNCTION_ENVIRONMENT_VARIABLES, + REMOTE_FUNCTION_IMAGE_URI, + REMOTE_FUNCTION_DEPENDENCIES, + REMOTE_FUNCTION_PRE_EXECUTION_COMMANDS, + REMOTE_FUNCTION_PRE_EXECUTION_SCRIPT, + REMOTE_FUNCTION_INCLUDE_LOCAL_WORKDIR, + REMOTE_FUNCTION_INSTANCE_TYPE, + REMOTE_FUNCTION_JOB_CONDA_ENV, + REMOTE_FUNCTION_ROLE_ARN, + REMOTE_FUNCTION_S3_ROOT_URI, + REMOTE_FUNCTION_S3_KMS_KEY_ID, + REMOTE_FUNCTION_VOLUME_KMS_KEY_ID, + REMOTE_FUNCTION_TAGS, + REMOTE_FUNCTION_VPC_CONFIG_SUBNETS, + REMOTE_FUNCTION_VPC_CONFIG_SECURITY_GROUP_IDS, + REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, +) +from sagemaker.core.experiments._run_context import _RunContext +from sagemaker.core.experiments.run import Run +from sagemaker.core.image_uris import get_base_python_image_uri +from sagemaker.core import image_uris +from sagemaker.train.remote_function.checkpoint_location import CheckpointLocation +from sagemaker.core.helper.session_helper import get_execution_role, expand_role, Session +from sagemaker.core.common_utils import ( + name_from_base, + _tmpdir, + resolve_value_from_config, + format_tags, + Tags, +) +from sagemaker.core.s3 import s3_path_join, S3Uploader + +from sagemaker.train.remote_function.core.stored_function import StoredFunction, _SerializedData +from sagemaker.train.remote_function.core.pipeline_variables import Context + +from sagemaker.train.remote_function.runtime_environment.runtime_environment_manager import ( + RuntimeEnvironmentManager, + _DependencySettings, +) +from sagemaker.train.remote_function import logging_config +from sagemaker.train.remote_function.spark_config import SparkConfig +from sagemaker.train.remote_function.custom_file_filter import ( + CustomFileFilter, + copy_workdir, + resolve_custom_file_filter_from_config_file, +) + +# Lazy import to avoid circular dependency - DelayedReturn is in MLOps which depends on Core +# from sagemaker.mlops.workflow.function_step import DelayedReturn +from sagemaker.core.workflow.step_outputs import get_step +from sagemaker.core import exceptions +from sagemaker.core import network as vpc_utils + +from sagemaker.core import logs as sagemaker_logs + +from sagemaker.core.common_utils import ( + _wait_until, + secondary_training_status_changed, + secondary_training_status_message, +) +from sagemaker.core.config.config_utils import _append_sagemaker_config_tags + +if TYPE_CHECKING: + from sagemaker.core.helper.pipeline_variable import PipelineVariable + +# runtime script names +BOOTSTRAP_SCRIPT_NAME = "bootstrap_runtime_environment.py" +MPI_UTILS_SCRIPT_NAME = "mpi_utils_remote.py" +ENTRYPOINT_SCRIPT_NAME = "job_driver.sh" +PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh" +RUNTIME_MANAGER_SCRIPT_NAME = "runtime_environment_manager.py" +SPARK_APP_SCRIPT_NAME = "spark_app.py" + +# training channel names +RUNTIME_SCRIPTS_CHANNEL_NAME = "sagemaker_remote_function_bootstrap" +REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws" +JOB_REMOTE_FUNCTION_WORKSPACE = "sagemaker_remote_function_workspace" +SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME = "pre_exec_script_and_dependencies" + +# Spark config channel and 
file name +SPARK_CONF_CHANNEL_NAME = "conf" +SPARK_CONF_FILE_NAME = "configuration.json" + +# Spark submitted files workspace names on S3 +SPARK_SUBMIT_JARS_WORKSPACE = "sm_rf_spark_jars" +SPARK_SUBMIT_PY_FILES_WORKSPACE = "sm_rf_spark_py_files" +SPARK_SUBMIT_FILES_WORKSPACE = "sm_rf_spark_data_files" +SPARK_CONF_WORKSPACE = "sm_rf_spark_conf" + +# default spark version +DEFAULT_SPARK_VERSION = "3.3" +DEFAULT_SPARK_CONTAINER_VERSION = "v1" + +SPARK_NAME = "spark" + +# run context dictionary keys +KEY_EXPERIMENT_NAME = "experiment_name" +KEY_RUN_NAME = "run_name" + +JOBS_CONTAINER_ENTRYPOINT = [ + "/bin/bash", + f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}", +] + +SPARK_APP_SCRIPT_PATH = f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{SPARK_APP_SCRIPT_NAME}" + +ENTRYPOINT_SCRIPT = f""" +#!/bin/bash + +# Entry point for bootstrapping runtime environment and invoking remote function + +set -eu + +PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}} +export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs +printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n" +export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip +printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n" + +printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n" +cat /opt/ml/input/config/resourceconfig.json + +printf "INFO: Bootstraping runtime environment.\\n" +python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" +source /opt/ml/input/sm_training.env + +if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ] +then + if [ -f "remote_function_conda_env.txt" ] + then + cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt + fi + printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n" + cd {JOB_REMOTE_FUNCTION_WORKSPACE} +fi + +if [ -f "remote_function_conda_env.txt" ] +then + conda_env=$(cat remote_function_conda_env.txt) + + if which mamba >/dev/null; then + conda_exe="mamba" + else + conda_exe="conda" + fi + + printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n" + printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function \\n" + $conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function "$@" +else + printf "INFO: No conda env provided. Invoking remote function\\n" + printf "INFO: python -m sagemaker.train.remote_function.invoke_function \\n" + python -m sagemaker.train.remote_function.invoke_function "$@" +fi """ -DEPRECATED: This module has been moved to sagemaker.core.remote_function.job -This is a backward compatibility shim. 
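Editorial aside (not part of this change): the channel and script-name constants above are spliced into fixed paths under /opt/ml/input/data inside the training container. A minimal, self-contained sketch of what the f-string constants resolve to at import time:

    # Standalone sketch: how JOBS_CONTAINER_ENTRYPOINT and SPARK_APP_SCRIPT_PATH
    # above resolve once the channel and script-name constants are substituted.
    RUNTIME_SCRIPTS_CHANNEL_NAME = "sagemaker_remote_function_bootstrap"
    ENTRYPOINT_SCRIPT_NAME = "job_driver.sh"
    SPARK_APP_SCRIPT_NAME = "spark_app.py"

    JOBS_CONTAINER_ENTRYPOINT = [
        "/bin/bash",
        f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}",
    ]
    SPARK_APP_SCRIPT_PATH = (
        f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{SPARK_APP_SCRIPT_NAME}"
    )

    assert JOBS_CONTAINER_ENTRYPOINT[1] == (
        "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh"
    )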
+ENTRYPOINT_MPIRUN_SCRIPT = f""" +#!/bin/bash + +# Entry point for bootstrapping runtime environment and invoking remote function with mpirun + +set -eu + +PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}} +export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs +printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n" +export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip +printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n" + +printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n" +cat /opt/ml/input/config/resourceconfig.json + +printf "INFO: Bootstraping runtime environment.\\n" +python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" +source /opt/ml/input/sm_training.env + +if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ] +then + if [ -f "remote_function_conda_env.txt" ] + then + cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt + fi + printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n" + cd {JOB_REMOTE_FUNCTION_WORKSPACE} +fi + +if [ -f "remote_function_conda_env.txt" ] +then + conda_env=$(cat remote_function_conda_env.txt) + + if which mamba >/dev/null; then + conda_exe="mamba" + else + conda_exe="conda" + fi + + if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + + printf "INFO: Invoking remote function with mpirun inside conda environment: $conda_env.\\n" + printf "INFO: $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + + python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n" + $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ + python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@" + + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 + else + printf "INFO: This is the instance $SM_CURRENT_HOST. mpirun command terminated\\n" + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + fi +else + if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + + printf "INFO: No conda env provided. 
Invoking remote function with mpirun\\n" + printf "INFO: mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ + python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n" + + mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ + python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@" + + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 + else + printf "INFO: This is the instance $SM_CURRENT_HOST.\\n" + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + fi +fi """ -from __future__ import absolute_import -import warnings +ENTRYPOINT_TORCHRUN_SCRIPT = f""" +#!/bin/bash -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.job import * # noqa: F401, F403 +# Entry point for bootstrapping runtime environment and invoking remote function with torchrun -warnings.warn( - "sagemaker.train.remote_function.job has been moved to sagemaker.core.remote_function.job. " - "Please update your imports. 
This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) +set -eu + +PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}} +export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs +printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n" +export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip +printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n" + +printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n" +cat /opt/ml/input/config/resourceconfig.json + +printf "INFO: Bootstraping runtime environment.\\n" +python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" +source /opt/ml/input/sm_training.env + +if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ] +then + if [ -f "remote_function_conda_env.txt" ] + then + cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt + fi + printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n" + cd {JOB_REMOTE_FUNCTION_WORKSPACE} +fi + +if [ -f "remote_function_conda_env.txt" ] +then + conda_env=$(cat remote_function_conda_env.txt) + + if which mamba >/dev/null; then + conda_exe="mamba" + else + conda_exe="conda" + fi + + printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n" + printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ + --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ + -m sagemaker.train.remote_function.invoke_function \\n" + + $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ + --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ + -m sagemaker.train.remote_function.invoke_function "$@" +else + printf "INFO: No conda env provided. Invoking remote function with torchrun\\n" + printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ + --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function \\n" + + torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ + --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function "$@" +fi +""" + +SPARK_ENTRYPOINT_SCRIPT = f""" +#!/bin/bash + +# Entry point for bootstrapping runtime environment and invoking remote function for Spark + +set -eu + +printf "INFO: Bootstraping Spark runtime environment.\\n" + +python3 /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" + +# Spark Container entry point script to initiate the spark application +smspark-submit "$@" +""" + +_STATUS_CODE_TABLE = { + "COMPLETED": "Completed", + "INPROGRESS": "InProgress", + "IN_PROGRESS": "InProgress", + "FAILED": "Failed", + "STOPPED": "Stopped", + "STOPPING": "Stopping", + "STARTING": "Starting", + "PENDING": "Pending", +} + +logger = logging_config.get_logger() + + +class LogState(object): + """Placeholder docstring""" + + STARTING = 1 + WAIT_IN_PROGRESS = 2 + TAILING = 3 + JOB_COMPLETE = 4 + COMPLETE = 5 + + +class _JobSettings: + """Helper class that processes the job settings. + + It validates the job settings and provides default values if necessary. 
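A rough usage sketch, assuming this module is importable as sagemaker.train.remote_function.job (the path referenced by the entrypoint scripts above); the workload function, instance type, and dependency file are hypothetical, and in practice the remote decorator wires these pieces up on the user's behalf:

    # Editorial sketch, not part of this change. Assumes the module path below.
    from sagemaker.train.remote_function.job import _Job, _JobSettings

    def train(epochs):
        return f"trained for {epochs} epochs"

    settings = _JobSettings(
        instance_type="ml.m5.xlarge",       # required, otherwise a ValueError is raised
        dependencies="./requirements.txt",  # optional, see the parameter docs below
        include_local_workdir=True,
    )
    job = _Job.start(settings, train, func_args=(10,), func_kwargs={})
    job.wait()  # tails CloudWatch logs until the training job finishes
    print(job.describe()["TrainingJobStatus"])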
+ """ + + def __init__( + self, + *, + dependencies: str = None, + pre_execution_commands: List[str] = None, + pre_execution_script: str = None, + environment_variables: Dict[str, Union[str, "PipelineVariable"]] = None, + image_uri: Union[str, "PipelineVariable"] = None, + include_local_workdir: bool = None, + custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None, + instance_count: Union[int, "PipelineVariable"] = 1, + instance_type: Union[str, "PipelineVariable"] = None, + job_conda_env: Union[str, "PipelineVariable"] = None, + job_name_prefix: str = None, + keep_alive_period_in_seconds: Union[int, "PipelineVariable"] = 0, + max_retry_attempts: Union[int, "PipelineVariable"] = 1, + max_runtime_in_seconds: Union[int, "PipelineVariable"] = 24 * 60 * 60, + role: str = None, + s3_kms_key: Union[str, "PipelineVariable"] = None, + s3_root_uri: str = None, + sagemaker_session: Session = None, + security_group_ids: List[Union[str, "PipelineVariable"]] = None, + subnets: List[Union[str, "PipelineVariable"]] = None, + tags: Optional[Tags] = None, + volume_kms_key: Union[str, "PipelineVariable"] = None, + volume_size: Union[int, "PipelineVariable"] = 30, + encrypt_inter_container_traffic: Union[bool, "PipelineVariable"] = None, + spark_config: SparkConfig = None, + use_spot_instances=False, + max_wait_time_in_seconds=None, + disable_output_compression: bool = False, + use_torchrun: bool = False, + use_mpirun: bool = False, + nproc_per_node: Optional[int] = None, + ): + """Initialize a _JobSettings instance which configures the remote job. + + Args: + dependencies (str): Either the path to a dependencies file or the reserved keyword + ``auto_capture``. Defaults to ``None``. + If ``dependencies`` is provided, the value must be one of the following: + + * A path to a conda environment.yml file. The following conditions apply. + + * If job_conda_env is set, then the conda environment is updated by installing + dependencies from the yaml file and the function is invoked within that + conda environment. For this to succeed, the specified conda environment must + already exist in the image. + * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, + then the conda environment is updated by installing dependencies from the + yaml file and the function is invoked within that conda environment. For + this to succeed, the conda environment name must already be set in + ``SAGEMAKER_JOB_CONDA_ENV``, and ``SAGEMAKER_JOB_CONDA_ENV`` must already + exist in the image. + * If none of the previous conditions are met, a new conda environment named + ``sagemaker-runtime-env`` is created and the function annotated with the remote + decorator is invoked in that conda environment. + + * A path to a requirements.txt file. The following conditions apply. + + * If ``job_conda_env`` is set in the remote decorator, dependencies are installed + within that conda environment and the function annotated with the remote decorator + is invoked in the same conda environment. For this to succeed, the specified + conda environment must already exist in the image. + * If an environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, + dependencies are installed within that conda environment and the function + annotated with the remote decorator is invoked in the same. For this to succeed, + the conda environment name must already be set in ``SAGEMAKER_JOB_CONDA_ENV``, and + ``SAGEMAKER_JOB_CONDA_ENV`` must already exist in the image. 
+ * If none of the above conditions are met, conda is not used. Dependencies are + installed at the system level, without any virtual environment, and the function + annotated with the remote decorator is invoked using the Python runtime available + in the system path. + + * The parameter dependencies is set to ``auto_capture``. SageMaker will automatically + generate an env_snapshot.yml corresponding to the current active conda environment’s + snapshot. You do not need to provide a dependencies file. The following conditions + apply: + + * You must run the remote function within an active conda environment. + * When installing the dependencies on the training job, the same conditions + as when dependencies is set to a path to a conda environment file apply. + These conditions are as follows: + + * If job_conda_env is set, then the conda environment is updated by installing + dependencies from the yaml file and the function is invoked within that + conda environment. For this to succeed, the specified conda environment must + already exist in the image. + * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, + then the conda environment is updated by installing dependencies from the yaml + file and the function is invoked within that conda environment. For this to + succeed, the conda environment name must already be set in + ``SAGEMAKER_JOB_CONDA_ENV``, and ``SAGEMAKER_JOB_CONDA_ENV`` must already exist + in the image. + * If none of the previous conditions are met, a new conda environment with name + ``sagemaker-runtime-env`` is created and the function annotated with the + remote decorator is invoked in that conda environment. + + * ``None``. SageMaker will assume that there are no dependencies to install while + executing the remote annotated function in the training job. + + pre_execution_commands (List[str]): List of commands to be executed prior to executing + remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script`` + can be specified at the same time. Defaults to None. + + pre_execution_script (str): Path to script file to be executed prior to executing + remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script`` + can be specified at the same time. Defaults to None. + + environment_variables (dict[str, str] or dict[str, PipelineVariable]): The environment + variables used inside the decorator function. Defaults to ``None``. + + image_uri (str, PipelineVariable): The universal resource identifier (URI) location of + a Docker image on Amazon Elastic Container Registry (ECR). Defaults to the following + based on where the SDK is running: + + * For users who specify ``spark_config`` and want to run the function in a Spark + application, the ``image_uri`` should be ``None``. A SageMaker Spark image will + be used for training, otherwise a ``ValueError`` is thrown. + * For users on SageMaker Studio notebooks, the image used as the kernel image for + the notebook is used. + * For other users, it is resolved to base python image with the same python version + as the environment running the local code. + + If no compatible image is found, a ValueError is thrown. + + include_local_workdir (bool): A flag to indicate that the remote function should include + local directories. Set to ``True`` if the remote function code imports local modules + and methods that are not available via PyPI or conda. Default value is ``False``. 
+ + custom_file_filter (Callable[[str, List], List], CustomFileFilter): Either a function + that filters job dependencies to be uploaded to S3 or a ``CustomFileFilter`` object + that specifies the local directories and files to be included in the remote function. + If a callable is passed in, that function is passed to the ``ignore`` argument of + ``shutil.copytree``. Defaults to ``None``, which means only python + files are accepted and uploaded to S3. + + instance_count (int, PipelineVariable): The number of instances to use. Defaults to 1. + + instance_type (str, PipelineVariable): The Amazon Elastic Compute Cloud (EC2) instance + type to use to run the SageMaker job. e.g. ml.c4.xlarge. If not provided, + a ValueError is thrown. + + job_conda_env (str, PipelineVariable): The name of the conda environment to activate + during job's runtime. Defaults to ``None``. + + job_name_prefix (str, PipelineVariable): The prefix used to create the underlying + SageMaker job. + + keep_alive_period_in_seconds (int, PipelineVariable): The duration in seconds to retain + and reuse provisioned infrastructure after the completion of a training job, also + known as SageMaker managed warm pools. The use of warm pools reduces the latency time + spent to provision new resources. The default value for + ``keep_alive_period_in_seconds`` is 0. + NOTE: Additional charges associated with warm pools may apply. Using this parameter + also activates a new persistent cache feature, which will further reduce job start up + latency than over using SageMaker managed warm pools alone by caching the package + source downloaded in the previous runs. + + max_retry_attempts (int, PipelineVariable): The max number of times the job is retried + on ``InternalServerFailure`` Error from SageMaker service. Defaults to 1. + + max_runtime_in_seconds (int, PipelineVariable): The upper limit in seconds to be used + for training. After this specified amount of time, SageMaker terminates the job + regardless of its current status. Defaults to 1 day or (86400 seconds). + + role (str): The IAM role (either name or full ARN) used to run your SageMaker training + job. Defaults to: + + * the SageMaker default IAM role if the SDK is running in SageMaker Notebooks or + SageMaker Studio Notebooks. + * if not above, a ValueError is thrown. + + s3_kms_key (str): The key used to encrypt the input and output data. + Default to ``None``. + + s3_root_uri (str): The root S3 folder to which the code archives and data are + uploaded to. Defaults to ``s3://``. + + sagemaker_session (sagemaker.core.helper.session.Session): The underlying SageMaker session to + which SageMaker service calls are delegated to (default: None). If not provided, + one is created using a default configuration chain. + + security_group_ids (List[str, PipelineVariable]): A list of security group IDs. + Defaults to ``None`` and the training job is created without VPC config. + + subnets (List[str, PipelineVariable]): A list of subnet IDs. Defaults to ``None`` + and the job is created without VPC config. + + tags (Optional[Tags]): Tags attached to the job. Defaults to ``None`` + and the training job is created without tags. + + volume_kms_key (str, PipelineVariable): An Amazon Key Management Service (KMS) key + used to encrypt an Amazon Elastic Block Storage (EBS) volume attached to the + training instance. Defaults to ``None``. + + volume_size (int, PipelineVariable): The size in GB of the storage volume for storing + input and output data during training. Defaults to ``30``. 
+ + encrypt_inter_container_traffic (bool, PipelineVariable): A flag that specifies + whether traffic between training containers is encrypted for the training job. + Defaults to ``False``. + + spark_config (SparkConfig): Configurations to the Spark application that runs on + Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri + will be used for training. Note that ``image_uri`` can not be specified at the + same time otherwise a ``ValueError`` is thrown. Defaults to ``None``. + + use_spot_instances (bool, PipelineVariable): Specifies whether to use SageMaker + Managed Spot instances for training. If enabled then the ``max_wait`` arg should + also be set. Defaults to ``False``. + + max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job. + After this amount of time Amazon SageMaker will stop waiting for managed spot + training job to complete. Defaults to ``None``. + + disable_output_compression (bool): Optional. When set to true, Model is uploaded to + Amazon S3 without compression after training finishes. + + use_torchrun (bool): Specifies whether to use torchrun for distributed training. + Defaults to ``False``. + + use_mpirun (bool): Specifies whether to use mpirun for distributed training. + Defaults to ``False``. + + nproc_per_node (int): Optional. Specifies the number of processes per node for + distributed training. Defaults to ``None``. + This is defined automatically configured on the instance type. + """ + self.sagemaker_session = sagemaker_session or Session() + self.environment_variables = resolve_value_from_config( + direct_input=environment_variables, + config_path=REMOTE_FUNCTION_ENVIRONMENT_VARIABLES, + default_value={}, + sagemaker_session=self.sagemaker_session, + ) + self.environment_variables.update( + {"AWS_DEFAULT_REGION": self.sagemaker_session.boto_region_name} + ) + + if spark_config and image_uri: + raise ValueError("spark_config and image_uri cannot be specified at the same time!") + + if spark_config and job_conda_env: + raise ValueError("Remote Spark jobs do not support job_conda_env.") + + if spark_config and dependencies == "auto_capture": + raise ValueError( + "Remote Spark jobs do not support automatically capturing dependencies." + ) + + _image_uri = resolve_value_from_config( + direct_input=image_uri, + config_path=REMOTE_FUNCTION_IMAGE_URI, + sagemaker_session=self.sagemaker_session, + ) + + if spark_config: + self.image_uri = self._get_default_spark_image(self.sagemaker_session) + logger.info( + "Set the image uri as %s because value of spark_config is " + "indicating this is a remote spark job.", + self.image_uri, + ) + elif _image_uri: + self.image_uri = _image_uri + else: + self.image_uri = self._get_default_image(self.sagemaker_session) + + self.dependencies = resolve_value_from_config( + direct_input=dependencies, + config_path=REMOTE_FUNCTION_DEPENDENCIES, + sagemaker_session=self.sagemaker_session, + ) + + self.pre_execution_commands = resolve_value_from_config( + direct_input=pre_execution_commands, + config_path=REMOTE_FUNCTION_PRE_EXECUTION_COMMANDS, + sagemaker_session=self.sagemaker_session, + ) + + self.pre_execution_script = resolve_value_from_config( + direct_input=pre_execution_script, + config_path=REMOTE_FUNCTION_PRE_EXECUTION_SCRIPT, + sagemaker_session=self.sagemaker_session, + ) + + if self.pre_execution_commands is not None and self.pre_execution_script is not None: + raise ValueError( + "Only one of pre_execution_commands or pre_execution_script can be specified!" 
+ ) + + self.include_local_workdir = resolve_value_from_config( + direct_input=include_local_workdir, + config_path=REMOTE_FUNCTION_INCLUDE_LOCAL_WORKDIR, + default_value=False, + sagemaker_session=self.sagemaker_session, + ) + + self.custom_file_filter = resolve_custom_file_filter_from_config_file( + custom_file_filter, self.sagemaker_session + ) + + self.instance_type = resolve_value_from_config( + direct_input=instance_type, + config_path=REMOTE_FUNCTION_INSTANCE_TYPE, + sagemaker_session=self.sagemaker_session, + ) + if not self.instance_type: + raise ValueError("instance_type is a required parameter!") + + self.instance_count = instance_count + self.volume_size = volume_size + self.max_runtime_in_seconds = max_runtime_in_seconds + self.max_retry_attempts = max_retry_attempts + self.keep_alive_period_in_seconds = keep_alive_period_in_seconds + self.spark_config = spark_config + self.use_spot_instances = use_spot_instances + self.max_wait_time_in_seconds = max_wait_time_in_seconds + self.job_conda_env = resolve_value_from_config( + direct_input=job_conda_env, + config_path=REMOTE_FUNCTION_JOB_CONDA_ENV, + sagemaker_session=self.sagemaker_session, + ) + self.job_name_prefix = job_name_prefix + self.encrypt_inter_container_traffic = resolve_value_from_config( + direct_input=encrypt_inter_container_traffic, + config_path=REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, + default_value=False, + sagemaker_session=self.sagemaker_session, + ) + self.enable_network_isolation = False + + _role = resolve_value_from_config( + direct_input=role, + config_path=REMOTE_FUNCTION_ROLE_ARN, + sagemaker_session=self.sagemaker_session, + ) + if _role: + self.role = expand_role(self.sagemaker_session.boto_session, _role) + else: + self.role = get_execution_role(self.sagemaker_session) + + self.s3_root_uri = resolve_value_from_config( + direct_input=s3_root_uri, + config_path=REMOTE_FUNCTION_S3_ROOT_URI, + default_value=s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, + ), + sagemaker_session=self.sagemaker_session, + ) + + self.s3_kms_key = resolve_value_from_config( + direct_input=s3_kms_key, + config_path=REMOTE_FUNCTION_S3_KMS_KEY_ID, + sagemaker_session=self.sagemaker_session, + ) + self.volume_kms_key = resolve_value_from_config( + direct_input=volume_kms_key, + config_path=REMOTE_FUNCTION_VOLUME_KMS_KEY_ID, + sagemaker_session=self.sagemaker_session, + ) + + _subnets = resolve_value_from_config( + direct_input=subnets, + config_path=REMOTE_FUNCTION_VPC_CONFIG_SUBNETS, + sagemaker_session=self.sagemaker_session, + ) + _security_group_ids = resolve_value_from_config( + direct_input=security_group_ids, + config_path=REMOTE_FUNCTION_VPC_CONFIG_SECURITY_GROUP_IDS, + sagemaker_session=self.sagemaker_session, + ) + vpc_config = vpc_utils.to_dict(subnets=_subnets, security_group_ids=_security_group_ids) + self.vpc_config = vpc_utils.sanitize(vpc_config) + + tags = format_tags(tags) + self.tags = _append_sagemaker_config_tags( + self.sagemaker_session, tags, REMOTE_FUNCTION_TAGS + ) + + self.disable_output_compression = disable_output_compression + self.use_torchrun = use_torchrun + self.use_mpirun = use_mpirun + self.nproc_per_node = nproc_per_node + + @staticmethod + def _get_default_image(session): + """Return Studio notebook image, if in Studio env. Else, base python. + + Args: + session (Session): Boto session. + + Returns: + Default SageMaker base python image. 
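For readers following the __init__ body above: a hedged, standalone sketch of the precedence that resolve_value_from_config applies when populating these attributes. This is a simplified stand-in for illustration, not the real helper:

    # Simplified stand-in: an explicit argument wins, then the value found at the
    # sagemaker_config path (e.g. REMOTE_FUNCTION_INSTANCE_TYPE), then the default.
    def resolve(direct_input, config_value=None, default_value=None):
        if direct_input is not None:
            return direct_input
        if config_value is not None:
            return config_value
        return default_value

    assert resolve("ml.m5.xlarge", config_value="ml.c5.xlarge") == "ml.m5.xlarge"
    assert resolve(None, config_value="ml.c5.xlarge") == "ml.c5.xlarge"
    assert resolve(None, default_value=False) is False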
+ """ + + if ( + "SAGEMAKER_INTERNAL_IMAGE_URI" in os.environ + and os.environ["SAGEMAKER_INTERNAL_IMAGE_URI"] + ): + return os.environ["SAGEMAKER_INTERNAL_IMAGE_URI"] + + py_version = str(sys.version_info[0]) + str(sys.version_info[1]) + + if py_version not in ["310", "38"]: + raise ValueError( + "Default image is supported only for Python versions 3.8 and 3.10. If you " + "are using any other python version, you must provide a compatible image_uri." + ) + + region = session.boto_region_name + image_uri = get_base_python_image_uri(region=region, py_version=py_version) + + return image_uri + + @staticmethod + def _get_default_spark_image(session): + """Return the Spark image. + + Args: + session (Session): Boto session. + + Returns: + SageMaker Spark container image uri. + """ + + region = session.boto_region_name + + py_version = str(sys.version_info[0]) + str(sys.version_info[1]) + + if py_version not in ["39"]: + raise ValueError( + "The SageMaker Spark image for remote job only supports Python version 3.9. " + ) + + image_uri = image_uris.retrieve( + framework=SPARK_NAME, + region=region, + version=DEFAULT_SPARK_VERSION, + instance_type=None, + py_version=f"py{py_version}", + container_version=DEFAULT_SPARK_CONTAINER_VERSION, + ) + + return image_uri + + +class _Job: + """Helper class that interacts with the SageMaker training service.""" + + def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session): + """Initialize a _Job object. + + Args: + job_name (str): The training job name. + s3_uri (str): The training job output S3 uri. + sagemaker_session (Session): SageMaker boto session. + """ + self.job_name = job_name + self.s3_uri = s3_uri + self.sagemaker_session = sagemaker_session + self._last_describe_response = None + + @staticmethod + def from_describe_response(describe_training_job_response, sagemaker_session): + """Construct a _Job from a describe_training_job_response object. + + Args: + describe_training_job_response (Dict): Describe training job response. + sagemaker_session (Session): SageMaker boto session. + + Returns: + the _Job object. + """ + job_name = describe_training_job_response["TrainingJobName"] + s3_uri = describe_training_job_response["OutputDataConfig"]["S3OutputPath"] + + job = _Job(job_name, s3_uri, sagemaker_session) + job._last_describe_response = describe_training_job_response + return job + + @staticmethod + def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None): + """Start a training job. + + Args: + job_settings (_JobSettings): the job settings. + func: the function to be executed. + func_args: the positional arguments to the function. + func_kwargs: the keyword arguments to the function + + Returns: + the _Job object. 
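A short editorial sketch of re-attaching to an existing training job through from_describe_response; the job name is hypothetical, the module path is assumed as above, and default AWS credentials and region are presumed available:

    from sagemaker.core.helper.session_helper import Session
    from sagemaker.train.remote_function.job import _Job  # assumed module path

    session = Session()
    response = session.sagemaker_client.describe_training_job(
        TrainingJobName="my-remote-function-job"  # hypothetical job name
    )
    job = _Job.from_describe_response(response, session)
    print(job.job_name, job.s3_uri)  # s3_uri comes from OutputDataConfig.S3OutputPath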
+ """ + job_name = _Job._get_job_name(job_settings, func) + s3_base_uri = s3_path_join(job_settings.s3_root_uri, job_name) + + training_job_request = _Job.compile( + job_settings=job_settings, + job_name=job_name, + s3_base_uri=s3_base_uri, + func=func, + func_args=func_args, + func_kwargs=func_kwargs, + run_info=run_info, + ) + + logger.info("Creating job: %s", job_name) + + job_settings.sagemaker_session.sagemaker_client.create_training_job(**training_job_request) + + return _Job( + job_name, + s3_base_uri, + job_settings.sagemaker_session, + ) + + @staticmethod + def compile( + job_settings: _JobSettings, + job_name: str, + s3_base_uri: str, + func: Callable, + func_args: tuple, + func_kwargs: dict, + run_info=None, + serialized_data: _SerializedData = None, + ) -> dict: + """Build the artifacts and generate the training job request.""" + from sagemaker.core.workflow.properties import Properties + from sagemaker.core.workflow.parameters import Parameter + from sagemaker.core.workflow.functions import Join + from sagemaker.core.workflow.execution_variables import ( + ExecutionVariables, + ExecutionVariable, + ) + from sagemaker.core.workflow.utilities import load_step_compilation_context + + step_compilation_context = load_step_compilation_context() + + jobs_container_entrypoint = JOBS_CONTAINER_ENTRYPOINT[:] + + # serialize function and arguments + if step_compilation_context is None: + stored_function = StoredFunction( + sagemaker_session=job_settings.sagemaker_session, + s3_base_uri=s3_base_uri, + s3_kms_key=job_settings.s3_kms_key, + ) + stored_function.save(func, *func_args, **func_kwargs) + else: + stored_function = StoredFunction( + sagemaker_session=job_settings.sagemaker_session, + s3_base_uri=s3_base_uri, + s3_kms_key=job_settings.s3_kms_key, + context=Context( + step_name=step_compilation_context.step_name, + func_step_s3_dir=step_compilation_context.pipeline_build_time, + ), + ) + + stored_function.save_pipeline_step_function(serialized_data) + + stopping_condition = { + "MaxRuntimeInSeconds": job_settings.max_runtime_in_seconds, + } + if job_settings.max_wait_time_in_seconds is not None: + stopping_condition["MaxWaitTimeInSeconds"] = job_settings.max_wait_time_in_seconds + + request_dict = dict( + TrainingJobName=job_name, + RoleArn=job_settings.role, + StoppingCondition=stopping_condition, + RetryStrategy={"MaximumRetryAttempts": job_settings.max_retry_attempts}, + ) + + _update_job_request_with_checkpoint_config(func_args, func_kwargs, request_dict) + + if job_settings.tags: + request_dict["Tags"] = job_settings.tags + + # generate other build artifacts including workspace, requirements.txt + request_dict["InputDataConfig"] = _generate_input_data_config( + job_settings=job_settings, s3_base_uri=s3_base_uri + ) + + if step_compilation_context: + # Path format: base/step_name/build_timestamp/execution_id/results + # This matches the path construction in stored_function.py + s3_output_path = Join( + on="/", + values=[ + s3_base_uri, + step_compilation_context.step_name, + step_compilation_context.pipeline_build_time, + ExecutionVariables.PIPELINE_EXECUTION_ID, + "results", + ], + ) + output_config = {"S3OutputPath": s3_output_path} + else: + output_config = {"S3OutputPath": s3_base_uri} + if job_settings.s3_kms_key is not None: + output_config["KmsKeyId"] = job_settings.s3_kms_key + if job_settings.disable_output_compression: + output_config["CompressionType"] = "NONE" + request_dict["OutputDataConfig"] = output_config + + container_args = ["--s3_base_uri", s3_base_uri] + 
container_args.extend(["--region", job_settings.sagemaker_session.boto_region_name]) + container_args.extend( + ["--client_python_version", RuntimeEnvironmentManager()._current_python_version()] + ) + container_args.extend( + [ + "--client_sagemaker_pysdk_version", + RuntimeEnvironmentManager()._current_sagemaker_pysdk_version(), + ] + ) + container_args.extend( + [ + "--dependency_settings", + _DependencySettings.from_dependency_file_path( + job_settings.dependencies + ).to_string(), + ] + ) + if job_settings.use_torchrun: + container_args.extend(["--distribution", "torchrun"]) + elif job_settings.use_mpirun: + container_args.extend(["--distribution", "mpirun"]) + if job_settings.nproc_per_node is not None and int(job_settings.nproc_per_node) > 0: + container_args.extend(["--user_nproc_per_node", str(job_settings.nproc_per_node)]) + if job_settings.s3_kms_key: + container_args.extend(["--s3_kms_key", job_settings.s3_kms_key]) + + if job_settings.job_conda_env: + container_args.extend(["--job_conda_env", job_settings.job_conda_env]) + + if step_compilation_context: + # TODO: remove the duplicates in the list + container_args.extend(["--pipeline_step_name", step_compilation_context.step_name]) + container_args.extend( + ["--pipeline_execution_id", ExecutionVariables.PIPELINE_EXECUTION_ID] + ) + container_args.extend( + ["--func_step_s3_dir", step_compilation_context.pipeline_build_time] + ) + container_args.extend(["--property_references"]) + container_args.extend( + [ + ExecutionVariables.PIPELINE_EXECUTION_ID.expr["Get"], + ExecutionVariables.PIPELINE_EXECUTION_ID.to_string(), + ] + ) + for arg in func_args + tuple(func_kwargs.values()): + if isinstance(arg, (Parameter, ExecutionVariable, Properties)): + container_args.extend([arg.expr["Get"], arg.to_string()]) + + # Lazy import to avoid circular dependency + try: + from sagemaker.mlops.workflow.function_step import DelayedReturn + + if isinstance(arg, DelayedReturn): + # The uri is a Properties object + uri = get_step(arg)._properties.OutputDataConfig.S3OutputPath + container_args.extend([uri.expr["Get"], uri.to_string()]) + except ImportError: + # MLOps not installed, skip DelayedReturn handling + pass + + if run_info is not None: + container_args.extend(["--run_in_context", json.dumps(dataclasses.asdict(run_info))]) + elif _RunContext.get_current_run() is not None: + container_args.extend( + ["--run_in_context", _convert_run_to_json(_RunContext.get_current_run())] + ) + + algorithm_spec = dict( + TrainingImage=job_settings.image_uri, + TrainingInputMode="File", + ContainerEntrypoint=jobs_container_entrypoint, + ContainerArguments=container_args, + ) + + request_dict["AlgorithmSpecification"] = algorithm_spec + + resource_config = dict( + VolumeSizeInGB=job_settings.volume_size, + InstanceCount=job_settings.instance_count, + InstanceType=job_settings.instance_type, + ) + if job_settings.volume_kms_key is not None: + resource_config["VolumeKmsKeyId"] = job_settings.volume_kms_key + if job_settings.keep_alive_period_in_seconds is not None: + resource_config["KeepAlivePeriodInSeconds"] = job_settings.keep_alive_period_in_seconds + + request_dict["ResourceConfig"] = resource_config + + if job_settings.enable_network_isolation is not None: + request_dict["EnableNetworkIsolation"] = job_settings.enable_network_isolation + + if job_settings.encrypt_inter_container_traffic is not None: + request_dict["EnableInterContainerTrafficEncryption"] = ( + job_settings.encrypt_inter_container_traffic + ) + + if job_settings.vpc_config: + 
request_dict["VpcConfig"] = job_settings.vpc_config + + request_dict["EnableManagedSpotTraining"] = job_settings.use_spot_instances + + request_dict["Environment"] = job_settings.environment_variables + + extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri) + extended_request = _extend_mpirun_to_request(extended_request, job_settings) + extended_request = _extend_torchrun_to_request(extended_request, job_settings) + + return extended_request + + def describe(self): + """Describe the underlying sagemaker training job. + + Returns: + Dict: Describe training job response. + """ + if self._last_describe_response is not None and self._last_describe_response[ + "TrainingJobStatus" + ] in ["Completed", "Failed", "Stopped"]: + return self._last_describe_response + + self._last_describe_response = ( + self.sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=self.job_name + ) + ) + + return self._last_describe_response + + def stop(self): + """Stop the underlying sagemaker training job.""" + self.sagemaker_session.sagemaker_client.stop_training_job(TrainingJobName=self.job_name) + + def wait(self, timeout: int = None): + """Wait for the underlying sagemaker job to finish and displays its logs . + + This method blocks on the sagemaker job completing for up to the timeout value (if + specified). If timeout is ``None``, this method will block until the job is completed. + + Args: + timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by + default. + + Returns: None + """ + + self._last_describe_response = _logs_for_job( + sagemaker_session=self.sagemaker_session, + job_name=self.job_name, + wait=True, + timeout=timeout, + ) + + @staticmethod + def _get_job_name(job_settings, func): + """Get the underlying SageMaker job name from job_name_prefix or func. + + Args: + job_settings (_JobSettings): the job settings. + func: the function to be executed. + + Returns: + str : the training job name. + """ + from sagemaker.core.workflow.utilities import load_step_compilation_context + + step_complication_context = load_step_compilation_context() + + job_name_prefix = job_settings.job_name_prefix + if not job_name_prefix: + job_name_prefix = func.__name__ + # remove all special characters in the beginning of function name + job_name_prefix = re.sub(r"^[^a-zA-Z0-9]+", "", job_name_prefix) + # convert all remaining special characters to '-' + job_name_prefix = re.sub(r"[^a-zA-Z0-9-]", "-", job_name_prefix) + + if step_complication_context: + return job_name_prefix + return name_from_base(job_name_prefix) + + +def _prepare_and_upload_runtime_scripts( + spark_config: SparkConfig, + s3_base_uri: str, + s3_kms_key: str, + sagemaker_session: Session, + use_torchrun: bool = False, + use_mpirun: bool = False, +): + """Copy runtime scripts to a folder and upload to S3. + + In case of remote function, s3_base_uri is s3_root_uri + function_name. + In case of pipeline, s3_base_uri is s3_root_uri + pipeline_name. The runtime scripts are + uploaded only once per pipeline. + + Args: + spark_config (SparkConfig): remote Spark job configurations. + + s3_base_uri (str): S3 location that the runtime scripts will be uploaded to. + + s3_kms_key (str): kms key used to encrypt the files uploaded to S3. + + sagemaker_session (str): SageMaker boto client session. + + use_torchrun (bool): Whether to use torchrun or not. + + use_mpirun (bool): Whether to use mpirun or not. 
+ + nproc_per_node (Optional[int]): Number of processes per node + """ + + from sagemaker.core.workflow.utilities import load_step_compilation_context + + step_compilation_context = load_step_compilation_context() + + if step_compilation_context and not step_compilation_context.upload_runtime_scripts: + return s3_path_join(s3_base_uri, RUNTIME_SCRIPTS_CHANNEL_NAME) + + with _tmpdir() as bootstrap_scripts: + + # write entrypoint script to tmpdir + entrypoint_script_path = os.path.join(bootstrap_scripts, ENTRYPOINT_SCRIPT_NAME) + entry_point_script = ENTRYPOINT_SCRIPT + if spark_config: + entry_point_script = SPARK_ENTRYPOINT_SCRIPT + spark_script_path = os.path.join( + os.path.dirname(__file__), "runtime_environment", SPARK_APP_SCRIPT_NAME + ) + shutil.copy2(spark_script_path, bootstrap_scripts) + + if use_torchrun: + entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT + + if use_mpirun: + entry_point_script = ENTRYPOINT_MPIRUN_SCRIPT + + with open(entrypoint_script_path, "w", newline="\n") as file: + file.writelines(entry_point_script) + + bootstrap_script_path = os.path.join( + os.path.dirname(__file__), "runtime_environment", BOOTSTRAP_SCRIPT_NAME + ) + mpi_utils_path = os.path.join( + os.path.dirname(__file__), "runtime_environment", MPI_UTILS_SCRIPT_NAME + ) + runtime_manager_script_path = os.path.join( + os.path.dirname(__file__), "runtime_environment", RUNTIME_MANAGER_SCRIPT_NAME + ) + + # copy runtime scripts to tmpdir + shutil.copy2(bootstrap_script_path, bootstrap_scripts) + shutil.copy2(mpi_utils_path, bootstrap_scripts) + shutil.copy2(runtime_manager_script_path, bootstrap_scripts) + + upload_path = S3Uploader.upload( + bootstrap_scripts, + s3_path_join(s3_base_uri, RUNTIME_SCRIPTS_CHANNEL_NAME), + s3_kms_key, + sagemaker_session, + ) + + if step_compilation_context: + step_compilation_context.upload_runtime_scripts = False + return upload_path + + +def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str): + """Generates input data config""" + from sagemaker.core.workflow.utilities import load_step_compilation_context + + step_compilation_context = load_step_compilation_context() + + bootstrap_scripts_s3uri = _prepare_and_upload_runtime_scripts( + spark_config=job_settings.spark_config, + s3_base_uri=s3_base_uri, + s3_kms_key=job_settings.s3_kms_key, + sagemaker_session=job_settings.sagemaker_session, + use_torchrun=job_settings.use_torchrun, + use_mpirun=job_settings.use_mpirun, + ) + + input_data_config = [ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": bootstrap_scripts_s3uri, + "S3DataType": "S3Prefix", + } + }, + ) + ] + + local_dependencies_path = RuntimeEnvironmentManager().snapshot(job_settings.dependencies) + + if step_compilation_context: + with _tmpdir() as tmp_dir: + script_and_dependencies_s3uri = _prepare_dependencies_and_pre_execution_scripts( + local_dependencies_path=local_dependencies_path, + pre_execution_commands=job_settings.pre_execution_commands, + pre_execution_script_local_path=job_settings.pre_execution_script, + s3_base_uri=s3_base_uri, + s3_kms_key=job_settings.s3_kms_key, + sagemaker_session=job_settings.sagemaker_session, + tmp_dir=tmp_dir, + ) + + if script_and_dependencies_s3uri: + input_data_config.append( + dict( + ChannelName=SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": script_and_dependencies_s3uri, + "S3DataType": "S3Prefix", + } + }, + ) + ) + + user_workspace_s3uri = _prepare_and_upload_workspace( + 
local_dependencies_path=local_dependencies_path, + include_local_workdir=job_settings.include_local_workdir, + pre_execution_commands=job_settings.pre_execution_commands, + pre_execution_script_local_path=job_settings.pre_execution_script, + s3_base_uri=s3_base_uri, + s3_kms_key=job_settings.s3_kms_key, + sagemaker_session=job_settings.sagemaker_session, + custom_file_filter=job_settings.custom_file_filter, + ) + + if user_workspace_s3uri: + input_data_config.append( + dict( + ChannelName=( + REMOTE_FUNCTION_WORKSPACE + if not step_compilation_context + else step_compilation_context.pipeline_build_time + ), + DataSource={ + "S3DataSource": { + "S3Uri": user_workspace_s3uri, + "S3DataType": "S3Prefix", + } + }, + ) + ) + + return input_data_config + + +def _prepare_dependencies_and_pre_execution_scripts( + local_dependencies_path: str, + pre_execution_commands: List[str], + pre_execution_script_local_path: str, + s3_base_uri: str, + s3_kms_key: str, + sagemaker_session: Session, + tmp_dir: str, +): + """Prepare pre-execution scripts and dependencies and upload them to s3. + + If pre execution commands are provided, a new bash file will be created + with those commands in tmp directory. + If pre execution script is provided, it copies that file from local file path + to tmp directory. + If local dependencies file is provided, it copies that file from local file path + to tmp directory. + If under pipeline context, tmp directory with copied dependencies and scripts is + uploaded to S3. + """ + from sagemaker.core.workflow.utilities import load_step_compilation_context + + if not (local_dependencies_path or pre_execution_commands or pre_execution_script_local_path): + return None + + if local_dependencies_path: + dst_path = shutil.copy2(local_dependencies_path, tmp_dir) + logger.info("Copied dependencies file at '%s' to '%s'", local_dependencies_path, dst_path) + + if pre_execution_commands or pre_execution_script_local_path: + pre_execution_script = os.path.join(tmp_dir, PRE_EXECUTION_SCRIPT_NAME) + if pre_execution_commands: + with open(pre_execution_script, "w") as target_script: + commands = [cmd + "\n" for cmd in pre_execution_commands] + target_script.writelines(commands) + logger.info( + "Generated pre-execution script from commands to '%s'", pre_execution_script + ) + else: + shutil.copy2(pre_execution_script_local_path, pre_execution_script) + logger.info( + "Copied pre-execution commands from script at '%s' to '%s'", + pre_execution_script_local_path, + pre_execution_script, + ) + + step_compilation_context = load_step_compilation_context() + if step_compilation_context: + upload_path = S3Uploader.upload( + tmp_dir, + s3_path_join( + s3_base_uri, + step_compilation_context.step_name, + step_compilation_context.pipeline_build_time, + SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME, + ), + s3_kms_key, + sagemaker_session, + ) + logger.info( + "Successfully uploaded dependencies and pre execution scripts to '%s'", upload_path + ) + return upload_path + return None + + +def _prepare_and_upload_workspace( + local_dependencies_path: str, + include_local_workdir: bool, + pre_execution_commands: List[str], + pre_execution_script_local_path: str, + s3_base_uri: str, + s3_kms_key: str, + sagemaker_session: Session, + custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None, +) -> str: + """Prepare and upload the workspace to S3. + + Under pipeline context, only workdir is packaged in the workspace folder and uploaded to s3. 
+ Under remote function context, workdir along with pre execution scripts and dependencies + are packaged together into the workspace folder and uploaded to S3. + """ + from sagemaker.core.workflow.utilities import load_step_compilation_context + + step_compilation_context = load_step_compilation_context() + + if not ( + local_dependencies_path + or include_local_workdir + or pre_execution_commands + or pre_execution_script_local_path + ): + return None + + func_step_s3_dir = None + if step_compilation_context: + func_step_s3_dir = step_compilation_context.pipeline_build_time + if not include_local_workdir: + return None + if not step_compilation_context.upload_workspace: + return s3_path_join(s3_base_uri, REMOTE_FUNCTION_WORKSPACE, func_step_s3_dir) + + with _tmpdir() as tmp_dir: + tmp_workspace_dir = os.path.join(tmp_dir, "temp_workspace/") + os.mkdir(tmp_workspace_dir) + # TODO Remove the following hack to avoid dir_exists error in the copy_tree call below. + tmp_workspace = os.path.join(tmp_workspace_dir, JOB_REMOTE_FUNCTION_WORKSPACE) + + if include_local_workdir: + copy_workdir(tmp_workspace, custom_file_filter) + logger.info("Copied user workspace to '%s'", tmp_workspace) + + if not os.path.isdir(tmp_workspace): + # create the directory if no workdir_path was provided in the input. + os.mkdir(tmp_workspace) + + if not step_compilation_context: + _prepare_dependencies_and_pre_execution_scripts( + local_dependencies_path=local_dependencies_path, + pre_execution_commands=pre_execution_commands, + pre_execution_script_local_path=pre_execution_script_local_path, + s3_base_uri=s3_base_uri, + s3_kms_key=s3_kms_key, + sagemaker_session=sagemaker_session, + tmp_dir=tmp_workspace, + ) + + workspace_archive_path = os.path.join(tmp_dir, "workspace") + workspace_archive_path = shutil.make_archive( + workspace_archive_path, "zip", tmp_workspace_dir + ) + logger.info("Successfully created workdir archive at '%s'", workspace_archive_path) + + upload_path = S3Uploader.upload( + workspace_archive_path, + s3_path_join(s3_base_uri, REMOTE_FUNCTION_WORKSPACE, func_step_s3_dir), + s3_kms_key, + sagemaker_session, + ) + logger.info("Successfully uploaded workdir to '%s'", upload_path) + if step_compilation_context: + step_compilation_context.upload_workspace = False + return upload_path + + +def _convert_run_to_json(run: Run) -> str: + """Convert current run into json string""" + run_info = _RunInfo(run.experiment_name, run.run_name) + return json.dumps(dataclasses.asdict(run_info)) + + +def _prepare_and_upload_spark_dependent_files( + spark_config: SparkConfig, + s3_base_uri: str, + s3_kms_key: str, + sagemaker_session: Session, +) -> Tuple: + """Upload the Spark dependencies to S3 if present. + + Args: + spark_config (SparkConfig): The remote Spark job configurations. + s3_base_uri (str): The S3 location that the Spark dependencies will be uploaded to. + s3_kms_key (str): The kms key used to encrypt the files uploaded to S3. + sagemaker_session (str): SageMaker boto client session. 
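For reference, a self-contained sketch of the run-context payload produced by _convert_run_to_json and passed to the job via --run_in_context; _RunInfo is restated here so the snippet runs on its own, and the experiment and run names are made up:

    import dataclasses
    import json

    @dataclasses.dataclass
    class _RunInfo:  # mirrors the dataclass defined later in this module
        experiment_name: str
        run_name: str

    payload = json.dumps(dataclasses.asdict(_RunInfo("my-experiment", "my-run")))
    print(payload)  # {"experiment_name": "my-experiment", "run_name": "my-run"}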
+ """ + if not spark_config: + return None, None, None, None + + submit_jars_s3_paths = _upload_spark_submit_deps( + spark_config.submit_jars, + SPARK_SUBMIT_JARS_WORKSPACE, + s3_base_uri, + s3_kms_key, + sagemaker_session, + ) + submit_py_files_s3_paths = _upload_spark_submit_deps( + spark_config.submit_py_files, + SPARK_SUBMIT_PY_FILES_WORKSPACE, + s3_base_uri, + s3_kms_key, + sagemaker_session, + ) + submit_files_s3_path = _upload_spark_submit_deps( + spark_config.submit_files, + SPARK_SUBMIT_FILES_WORKSPACE, + s3_base_uri, + s3_kms_key, + sagemaker_session, + ) + config_file_s3_uri = _upload_serialized_spark_configuration( + s3_base_uri, s3_kms_key, spark_config.configuration, sagemaker_session + ) + + return submit_jars_s3_paths, submit_py_files_s3_paths, submit_files_s3_path, config_file_s3_uri + + +def _upload_spark_submit_deps( + submit_deps: List[str], + workspace_name: str, + s3_base_uri: str, + s3_kms_key: str, + sagemaker_session: Session, +) -> str: + """Upload the Spark submit dependencies to S3. + + Args: + submit_deps (List[str]): A list of path which points to the Spark dependency files. + The path can be either a local path or S3 uri. For example ``/local/deps.jar`` or + ``s3:///deps.jar``. + + workspace_name (str): workspace name for Spark dependency. + s3_base_uri (str): S3 location that the Spark dependencies will be uploaded to. + s3_kms_key (str): kms key used to encrypt the files uploaded to S3. + sagemaker_session (str): SageMaker boto client session. + + Returns: + str : The concatenated path of all dependencies which will be passed to Spark. + """ + spark_opt_s3_uris = [] + if not submit_deps: + return None + + if not workspace_name or not s3_base_uri: + raise ValueError("workspace_name or s3_base_uri may not be empty.") + + for dep_path in submit_deps: + dep_url = urlparse(dep_path) + + if dep_url.scheme in ["s3", "s3a"]: + spark_opt_s3_uris.append(dep_path) + elif not dep_url.scheme or dep_url.scheme == "file": + if not os.path.isfile(dep_path): + raise ValueError(f"submit_deps path {dep_path} is not a valid local file.") + + upload_path = S3Uploader.upload( + local_path=dep_path, + desired_s3_uri=s3_path_join(s3_base_uri, workspace_name), + kms_key=s3_kms_key, + sagemaker_session=sagemaker_session, + ) + + spark_opt_s3_uris.append(upload_path) + logger.info("Uploaded the local file %s to %s", dep_path, upload_path) + return str.join(",", spark_opt_s3_uris) + + +def _upload_serialized_spark_configuration( + s3_base_uri: str, s3_kms_key: str, configuration: Dict, sagemaker_session: Session +) -> str: + """Upload the Spark configuration json to S3""" + if not configuration: + return None + + serialized_configuration = BytesIO(json.dumps(configuration).encode("utf-8")) + config_file_s3_uri = s3_path_join(s3_base_uri, SPARK_CONF_WORKSPACE, SPARK_CONF_FILE_NAME) + + S3Uploader.upload_string_as_file_body( + body=serialized_configuration, + desired_s3_uri=config_file_s3_uri, + kms_key=s3_kms_key, + sagemaker_session=sagemaker_session, + ) + + logger.info("Uploaded spark configuration json %s to %s", configuration, config_file_s3_uri) + + return config_file_s3_uri + + +def _extend_mpirun_to_request( + request_dict: Dict, + job_settings: _JobSettings, +) -> Dict: + """Extend the create training job request with mpirun configuration. + + Args: + request_dict (Dict): create training job request dict. + job_settings (_JobSettings): the job settings. 
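Relating to _upload_spark_submit_deps above, a hedged sketch of its path handling with the S3 upload stubbed out by a lambda; bucket names and paths are illustrative, and the local-file existence check is omitted:

    from urllib.parse import urlparse

    def join_spark_deps(submit_deps, upload):
        # upload() stands in for S3Uploader.upload: S3/S3A URIs pass through
        # untouched, local paths are uploaded, and the result is comma-joined
        # so it can be handed to spark-submit.
        uris = []
        for dep_path in submit_deps:
            scheme = urlparse(dep_path).scheme
            uris.append(dep_path if scheme in ("s3", "s3a") else upload(dep_path))
        return ",".join(uris)

    print(join_spark_deps(
        ["s3://my-bucket/deps.jar", "/local/extra.jar"],
        upload=lambda path: "s3://my-bucket/sm_rf_spark_jars/extra.jar",
    ))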
+ """ + use_mpirun = job_settings.use_mpirun + instance_count = job_settings.instance_count + + if not use_mpirun: + return request_dict + + if instance_count == 1: + return request_dict + + extended_request = request_dict.copy() + + for input_channel in extended_request["InputDataConfig"]: + s3_data_source = input_channel["DataSource"].get("S3DataSource", None) + if s3_data_source: + s3_data_source["S3DataDistributionType"] = "FullyReplicated" + + return extended_request + + +def _extend_torchrun_to_request( + request_dict: Dict, + job_settings: _JobSettings, +) -> Dict: + """Extend the create training job request with torchrun configuration. + + Args: + request_dict (Dict): create training job request dict. + job_settings (_JobSettings): the job settings. + """ + use_torchrun = job_settings.use_torchrun + instance_count = job_settings.instance_count + + if not use_torchrun: + return request_dict + + if instance_count == 1: + return request_dict + + extended_request = request_dict.copy() + + for input_channel in extended_request["InputDataConfig"]: + s3_data_source = input_channel["DataSource"].get("S3DataSource", None) + if s3_data_source: + s3_data_source["S3DataDistributionType"] = "FullyReplicated" + + return extended_request + + +def _extend_spark_config_to_request( + request_dict: Dict, + job_settings: _JobSettings, + s3_base_uri: str, +) -> Dict: + """Extend the create training job request with spark configurations. + + Args: + request_dict (Dict): create training job request dict. + job_settings (_JobSettings): the job settings. + s3_base_uri (str): S3 location that the Spark dependencies will be uploaded to. + """ + spark_config = job_settings.spark_config + + if not spark_config: + return request_dict + + extended_request = request_dict.copy() + container_entrypoint = extended_request["AlgorithmSpecification"]["ContainerEntrypoint"] + + ( + submit_jars_s3_paths, + submit_py_files_s3_paths, + submit_files_s3_path, + config_file_s3_uri, + ) = _prepare_and_upload_spark_dependent_files( + spark_config=spark_config, + s3_base_uri=s3_base_uri, + s3_kms_key=job_settings.s3_kms_key, + sagemaker_session=job_settings.sagemaker_session, + ) + + input_data_config = extended_request["InputDataConfig"] + + if config_file_s3_uri: + input_data_config.append( + dict( + ChannelName=SPARK_CONF_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": config_file_s3_uri, + "S3DataType": "S3Prefix", + } + }, + ) + ) + + for input_channel in extended_request["InputDataConfig"]: + s3_data_source = input_channel["DataSource"].get("S3DataSource", None) + if s3_data_source: + s3_data_source["S3DataDistributionType"] = "FullyReplicated" + + if spark_config.spark_event_logs_uri: + container_entrypoint.extend( + ["--spark-event-logs-s3-uri", spark_config.spark_event_logs_uri] + ) + + if submit_jars_s3_paths: + container_entrypoint.extend(["--jars", submit_jars_s3_paths]) + + if submit_py_files_s3_paths: + container_entrypoint.extend(["--py-files", submit_py_files_s3_paths]) + + if submit_files_s3_path: + container_entrypoint.extend(["--files", submit_files_s3_path]) + + if spark_config: + container_entrypoint.extend([SPARK_APP_SCRIPT_PATH]) + + return extended_request + + +def _update_job_request_with_checkpoint_config(args, kwargs, request_dict): + """Extend job request with checkpoint config based on CheckpointLocation in function args. + + Args: + args (tuple): The positional arguments of the remote function. + kwargs (Dict): The keyword arguments of the remote function. 
+        request_dict (Dict): create training job request dict.
+    """
+    checkpoint_location_index_in_args = None
+    checkpoint_location_key_in_kwargs = None
+    checkpoint_location_count = 0
+
+    for index, arg in enumerate(args):
+        if isinstance(arg, CheckpointLocation):
+            checkpoint_location_index_in_args = index
+            checkpoint_location_count += 1
+
+    for key, value in kwargs.items():
+        if isinstance(value, CheckpointLocation):
+            checkpoint_location_key_in_kwargs = key
+            checkpoint_location_count += 1
+
+    if checkpoint_location_count < 1:
+        return
+
+    if checkpoint_location_count > 1:
+        raise ValueError(
+            "Remote function cannot have more than one argument of type CheckpointLocation."
+        )
+
+    if checkpoint_location_index_in_args is not None:
+        checkpoint_location_arg = args[checkpoint_location_index_in_args]
+    else:
+        checkpoint_location_arg = kwargs[checkpoint_location_key_in_kwargs]
+
+    checkpoint_s3_uri = checkpoint_location_arg._s3_uri
+    checkpoint_local_path = checkpoint_location_arg._local_path
+
+    request_dict["CheckpointConfig"] = {
+        "LocalPath": checkpoint_local_path,
+        "S3Uri": checkpoint_s3_uri,
+    }
+
+
+@dataclasses.dataclass
+class _RunInfo:
+    """Data class to hold information of the run object from context."""
+
+    experiment_name: str
+    run_name: str
+
+
+def _get_initial_job_state(description, status_key, wait):
+    """Return the initial LogState for the log-tailing loop."""
+    status = description[status_key]
+    job_already_completed = status in ("Completed", "Failed", "Stopped")
+    return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
+
+
+def _logs_for_job(  # noqa: C901 - suppress complexity warning for this method
+    sagemaker_session, job_name, wait=False, poll=10, log_type="All", timeout=None
+):
+    """Display logs for a given training job, optionally tailing them until job is complete.
+
+    If the output is a tty or a Jupyter cell, it will be color-coded
+    based on which instance the log entry is from.
+
+    Args:
+        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session
+            object, used for SageMaker interactions.
+        job_name (str): Name of the training job to display the logs for.
+        wait (bool): Whether to keep looking for new log entries until the job completes
+            (default: False).
+        poll (int): The interval in seconds between polling for new log entries and job
+            completion (default: 10).
+        log_type (str): Which logs to print. Acceptable values are "All", "None",
+            "Training", or "Rules". To maintain backwards compatibility, boolean values
+            are also accepted and converted to strings.
+        timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by
+            default.
+    Returns:
+        dict: The response from the final ``DescribeTrainingJob`` call.
+    Raises:
+        exceptions.CapacityError: If the training job fails with CapacityError.
+        exceptions.UnexpectedStatusException: If waiting and the training job fails.
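+
+    Example (illustrative only; the session and job name below are placeholders):
+
+        .. code-block:: python
+
+            # Tail the logs of an in-progress training job until it finishes.
+            last_description = _logs_for_job(
+                sagemaker_session=sagemaker_session,
+                job_name="my-training-job",
+                wait=True,
+                poll=10,
+                log_type="All",
+            )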
+ """ + sagemaker_client = sagemaker_session.sagemaker_client + request_end_time = time.time() + timeout if timeout else None + description = _wait_until( + lambda: sagemaker_client.describe_training_job(TrainingJobName=job_name) + ) + print(secondary_training_status_message(description, None), end="") + + instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init( + sagemaker_session.boto_session, description, job="Training" + ) + + state = _get_initial_job_state(description, "TrainingJobStatus", wait) + + # The loop below implements a state machine that alternates between checking the job status + # and reading whatever is available in the logs at this point. Note, that if we were + # called with wait == False, we never check the job status. + # + # If wait == TRUE and job is not completed, the initial state is TAILING + # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is + # complete). + # + # The state table: + # + # STATE ACTIONS CONDITION NEW STATE + # ---------------- ---------------- ----------------- ---------------- + # TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE + # Else TAILING + # JOB_COMPLETE Read logs, Pause Any COMPLETE + # COMPLETE Read logs, Exit N/A + # + # Notes: + # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to + # Cloudwatch after the job was marked complete. + last_describe_job_call = time.time() + last_description = description + last_debug_rule_statuses = None + last_profiler_rule_statuses = None + + while True: + _flush_log_streams( + stream_names, + instance_count, + client, + log_group, + job_name, + positions, + dot, + color_wrap, + ) + if timeout and time.time() > request_end_time: + print("Timeout Exceeded. {} seconds elapsed.".format(timeout)) + break + + if state == LogState.COMPLETE: + break + + time.sleep(poll) + + if state == LogState.JOB_COMPLETE: + state = LogState.COMPLETE + elif time.time() - last_describe_job_call >= 30: + description = sagemaker_client.describe_training_job(TrainingJobName=job_name) + last_describe_job_call = time.time() + + if secondary_training_status_changed(description, last_description): + print() + print(secondary_training_status_message(description, last_description), end="") + last_description = description + + status = description["TrainingJobStatus"] + + if status in ("Completed", "Failed", "Stopped"): + print() + state = LogState.JOB_COMPLETE + + # Print prettified logs related to the status of SageMaker Debugger rules. + debug_rule_statuses = description.get("DebugRuleEvaluationStatuses", {}) + if ( + debug_rule_statuses + and _rule_statuses_changed(debug_rule_statuses, last_debug_rule_statuses) + and (log_type in {"All", "Rules"}) + ): + for status in debug_rule_statuses: + rule_log = ( + f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}" + ) + print(rule_log) + + last_debug_rule_statuses = debug_rule_statuses + + # Print prettified logs related to the status of SageMaker Profiler rules. 
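+            # Each status entry carries a "RuleConfigurationName" and a "RuleEvaluationStatus"
+            # (for example "InProgress", "NoIssuesFound", "IssuesFound", or "Error"); a rule is
+            # re-printed only when its evaluation status changes.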
+ profiler_rule_statuses = description.get("ProfilerRuleEvaluationStatuses", {}) + if ( + profiler_rule_statuses + and _rule_statuses_changed(profiler_rule_statuses, last_profiler_rule_statuses) + and (log_type in {"All", "Rules"}) + ): + for status in profiler_rule_statuses: + rule_log = ( + f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}" + ) + print(rule_log) + + last_profiler_rule_statuses = profiler_rule_statuses + + if wait: + _check_job_status(job_name, description, "TrainingJobStatus") + if dot: + print() + # Customers are not billed for hardware provisioning, so billable time is less than + # total time + training_time = description.get("TrainingTimeInSeconds") + billable_time = description.get("BillableTimeInSeconds") + if training_time is not None: + print("Training seconds:", training_time * instance_count) + if billable_time is not None: + print("Billable seconds:", billable_time * instance_count) + if description.get("EnableManagedSpotTraining"): + saving = (1 - float(billable_time) / training_time) * 100 + print("Managed Spot Training savings: {:.1f}%".format(saving)) + return last_description + + +def _check_job_status(job, desc, status_key_name): + """Check to see if the job completed successfully. + + If not, construct and raise a exceptions. (UnexpectedStatusException). + + Args: + job (str): The name of the job to check. + desc (dict[str, str]): The result of ``describe_training_job()``. + status_key_name (str): Status key name to check for. + + Raises: + exceptions.CapacityError: If the training job fails with CapacityError. + exceptions.UnexpectedStatusException: If the training job fails. + """ + status = desc[status_key_name] + # If the status is capital case, then convert it to Camel case + status = _STATUS_CODE_TABLE.get(status, status) + + if status == "Stopped": + logger.warning( + "Job ended with status 'Stopped' rather than 'Completed'. " + "This could mean the job timed out or stopped early for some other reason: " + "Consider checking whether it completed as you expect." + ) + elif status != "Completed": + reason = desc.get("FailureReason", "(No reason provided)") + job_type = status_key_name.replace("JobStatus", " job") + troubleshooting = ( + "https://docs.aws.amazon.com/sagemaker/latest/dg/" + "sagemaker-python-sdk-troubleshooting.html" + ) + message = ( + "Error for {job_type} {job_name}: {status}. Reason: {reason}. " + "Check troubleshooting guide for common errors: {troubleshooting}" + ).format( + job_type=job_type, + job_name=job, + status=status, + reason=reason, + troubleshooting=troubleshooting, + ) + if "CapacityError" in str(reason): + raise exceptions.CapacityError( + message=message, + allowed_statuses=["Completed", "Stopped"], + actual_status=status, + ) + raise exceptions.UnexpectedStatusException( + message=message, + allowed_statuses=["Completed", "Stopped"], + actual_status=status, + ) + + +def _flush_log_streams( + stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap +): + """Placeholder docstring""" + if len(stream_names) < instance_count: + # Log streams are created whenever a container starts writing to stdout/err, so this list + # may be dynamic until we have a stream for every instance. 
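+        # Streams live in the CloudWatch log group resolved by _logs_init (for example
+        # "/aws/sagemaker/TrainingJobs") and are named with a "<job_name>/" prefix, which is
+        # why the describe_log_streams calls below filter on that prefix.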
+        try:
+            streams = client.describe_log_streams(
+                logGroupName=log_group,
+                logStreamNamePrefix=job_name + "/",
+                orderBy="LogStreamName",
+                limit=min(instance_count, 50),
+            )
+            stream_names = [s["logStreamName"] for s in streams["logStreams"]]
+
+            while "nextToken" in streams:
+                # Pass the pagination token so each call fetches the next page of streams.
+                streams = client.describe_log_streams(
+                    logGroupName=log_group,
+                    logStreamNamePrefix=job_name + "/",
+                    orderBy="LogStreamName",
+                    nextToken=streams["nextToken"],
+                    limit=50,
+                )
+
+                stream_names.extend([s["logStreamName"] for s in streams["logStreams"]])
+
+            positions.update(
+                [
+                    (s, sagemaker_logs.Position(timestamp=0, skip=0))
+                    for s in stream_names
+                    if s not in positions
+                ]
+            )
+        except ClientError as e:
+            # On the very first training job run on an account, there's no log group until
+            # the container starts logging, so ignore any errors thrown about that
+            err = e.response.get("Error", {})
+            if err.get("Code", None) != "ResourceNotFoundException":
+                raise
+
+    if len(stream_names) > 0:
+        if dot:
+            print("")
+            dot = False
+        for idx, event in sagemaker_logs.multi_stream_iter(
+            client, log_group, stream_names, positions
+        ):
+            color_wrap(idx, event["message"])
+            ts, count = positions[stream_names[idx]]
+            if event["timestamp"] == ts:
+                positions[stream_names[idx]] = sagemaker_logs.Position(timestamp=ts, skip=count + 1)
+            else:
+                positions[stream_names[idx]] = sagemaker_logs.Position(
+                    timestamp=event["timestamp"], skip=1
+                )
+    else:
+        dot = True
+        print(".", end="")
+        sys.stdout.flush()
+
+
+def _rule_statuses_changed(current_statuses, last_statuses):
+    """Checks the rule evaluation statuses for SageMaker Debugger and Profiler rules."""
+    if not last_statuses:
+        return True
+
+    for current, last in zip(current_statuses, last_statuses):
+        if (current["RuleConfigurationName"] == last["RuleConfigurationName"]) and (
+            current["RuleEvaluationStatus"] != last["RuleEvaluationStatus"]
+        ):
+            return True
+
+    return False
+
+
+def _logs_init(boto_session, description, job):
+    """Initialize the CloudWatch Logs client and log stream bookkeeping for a job."""
+    if job == "Training":
+        if "InstanceGroups" in description["ResourceConfig"]:
+            instance_count = 0
+            for instanceGroup in description["ResourceConfig"]["InstanceGroups"]:
+                instance_count += instanceGroup["InstanceCount"]
+        else:
+            instance_count = description["ResourceConfig"]["InstanceCount"]
+    elif job == "Transform":
+        instance_count = description["TransformResources"]["InstanceCount"]
+    elif job == "Processing":
+        instance_count = description["ProcessingResources"]["ClusterConfig"]["InstanceCount"]
+    elif job == "AutoML":
+        instance_count = 0
+
+    stream_names = []  # The list of log streams
+    positions = {}  # The current position in each stream, map of stream name -> position
+
+    # Increase retries allowed (from default of 4), as we don't want waiting for a training job
+    # to be interrupted by a transient exception.
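+    # Note: this retry configuration only applies to the CloudWatch Logs client created
+    # below; it does not change the retry behavior of the SageMaker client.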
+    config = botocore.config.Config(retries={"max_attempts": 15})
+    client = boto_session.client("logs", config=config)
+    log_group = "/aws/sagemaker/" + job + "Jobs"
+
+    dot = False
+
+    from sagemaker.core.logs import ColorWrap
+
+    color_wrap = ColorWrap()
+
+    return instance_count, stream_names, positions, client, log_group, dot, color_wrap
diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py
index afe0f80012..f07e860cf9 100644
--- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py
+++ b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py
@@ -167,7 +167,10 @@ def _handle_pre_exec_scripts(script_file_dir: str):
     """
     path_to_pre_exec_script = os.path.join(script_file_dir, PRE_EXECUTION_SCRIPT_NAME)
 
-    RuntimeEnvironmentManager().run_pre_exec_script(pre_exec_script_path=path_to_pre_exec_script)
+    if os.path.isfile(path_to_pre_exec_script):
+        RuntimeEnvironmentManager().run_pre_exec_script(
+            pre_exec_script_path=path_to_pre_exec_script
+        )
 
 
 def _install_dependencies(
@@ -599,4 +602,4 @@ def main(sys_args=None):
 
 
 if __name__ == "__main__":
-    main(sys.argv[1:])
\ No newline at end of file
+    main(sys.argv[1:])
diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py
index 79ddd4020b..c5d9f15ee2 100644
--- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py
+++ b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py
@@ -249,4 +249,4 @@ def main(sys_args=None):
 
 
 if __name__ == "__main__":
-    main(sys.argv[1:])
\ No newline at end of file
+    main(sys.argv[1:])
diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py
index 9cb0c7aee4..5f00317c23 100644
--- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py
+++ b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py
@@ -338,11 +338,35 @@ def _update_conda_env(self, env_name, local_path):
 
     def _export_conda_env_from_prefix(self, prefix, local_path):
         """Export the conda env to a conda yml file"""
+        # Validate inputs to prevent command injection
+        validated_prefix = self._validate_path(prefix)
+        validated_path = self._validate_path(local_path)
 
-        cmd = [self._get_conda_exe(), "env", "export", "-p", prefix, "--no-builds", ">", local_path]
-        logger.info("Exporting conda environment: %s", cmd)
-        _run_shell_cmd(cmd)
-        logger.info("Conda environment %s exported successfully", prefix)
+        cmd = [self._get_conda_exe(), "env", "export", "-p", validated_prefix, "--no-builds"]
+        logger.info("Exporting conda environment: %s", " ".join(cmd))
+
+        # Capture output and write to file instead of using shell redirection
+        try:
+            process = subprocess.Popen(
+                cmd,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                shell=False
+            )
+            output, error_output = process.communicate()
+            return_code = process.wait()
+
+            if return_code:
+                error_message = f"Encountered error while running command '{' '.join(cmd)}'. Reason: {error_output.decode('utf-8')}"
+                raise RuntimeEnvironmentError(error_message)
+
+            # Write the captured output to the file
+            with open(validated_path, 'w') as f:
+                f.write(output.decode('utf-8'))
+
+            logger.info("Conda environment %s exported successfully", validated_prefix)
+        except Exception as e:
+            raise RuntimeEnvironmentError(f"Failed to export conda environment: {str(e)}")
 
     def _write_conda_env_to_file(self, env_name):
         """Writes conda env to the text file"""
@@ -385,6 +409,7 @@ def _current_sagemaker_pysdk_version(self):
         """Returns the current sagemaker python sdk version where program is running"""
         try:
             from importlib import metadata
+
            return metadata.version("sagemaker")
         except Exception:
             return "3.0.0.dev0"  # Development version fallback
@@ -526,4 +551,4 @@ class RuntimeEnvironmentError(Exception):
 
     def __init__(self, message):
         self.message = message
-        super().__init__(self.message)
\ No newline at end of file
+        super().__init__(self.message)
diff --git a/sagemaker-train/src/sagemaker/train/remote_function/spark_config.py b/sagemaker-train/src/sagemaker/train/remote_function/spark_config.py
index b5083b0566..6b25d5da8b 100644
--- a/sagemaker-train/src/sagemaker/train/remote_function/spark_config.py
+++ b/sagemaker-train/src/sagemaker/train/remote_function/spark_config.py
@@ -10,21 +10,140 @@
 # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
-"""
-DEPRECATED: This module has been moved to sagemaker.core.remote_function.spark_config
-
-This is a backward compatibility shim.
-"""
+"""This module defines the Spark job configuration for remote functions."""
 from __future__ import absolute_import
 
-import warnings
+from typing import Optional, List, Dict, Union
+import attr
+from urllib.parse import urlparse
+from sagemaker.core.workflow import is_pipeline_variable
+
+
+def _validate_configuration(instance, attribute, configuration):
+    # pylint: disable=unused-argument
+    """Attrs validator that checks the user-provided Spark configuration."""
+    if configuration:
+        SparkConfigUtils.validate_configuration(configuration=configuration)
+
+
+def _validate_s3_uri(instance, attribute, s3_uri):
+    # pylint: disable=unused-argument
+    """Attrs validator that checks the S3 uri."""
+    if s3_uri:
+        SparkConfigUtils.validate_s3_uri(s3_uri)
+
+
+@attr.s(frozen=True)
+class SparkConfig:
+    """The Spark configuration for a remote function job.
+
+    Attributes:
+        submit_jars (Optional[List[str]]): A list of paths to the jar files that will be
+            submitted to the Spark job. Each path can be a valid S3 uri or a local path.
+            Defaults to ``None``.
+        submit_py_files (Optional[List[str]]): A list of paths to the python files that
+            will be submitted to the Spark job. Each path can be a valid S3 uri or a
+            local path. Defaults to ``None``.
+        submit_files (Optional[List[str]]): A list of paths to the files that will be
+            submitted to the Spark job. Each path can be a valid S3 uri or a local path.
+            Defaults to ``None``.
+        configuration (list[dict] or dict): Configuration for Hadoop, Spark, or Hive.
+            List or dictionary of EMR-style classifications:
+            https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
+        spark_event_logs_uri (str): S3 path where Spark application event logs will
+            be published.
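+
+    Example (illustrative only; the bucket and property values below are placeholders):
+
+        .. code-block:: python
+
+            spark_config = SparkConfig(
+                submit_jars=["s3://my-bucket/deps/my-deps.jar"],
+                configuration=[
+                    {
+                        "Classification": "spark-defaults",
+                        "Properties": {"spark.executor.memory": "2g"},
+                    }
+                ],
+                spark_event_logs_uri="s3://my-bucket/spark-event-logs",
+            )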
+ """ + + submit_jars: Optional[List[str]] = attr.ib(default=None) + submit_py_files: Optional[List[str]] = attr.ib(default=None) + submit_files: Optional[List[str]] = attr.ib(default=None) + configuration: Optional[Union[List[Dict], Dict]] = attr.ib( + default=None, validator=_validate_configuration + ) + spark_event_logs_uri: Optional[str] = attr.ib(default=None, validator=_validate_s3_uri) + + +class SparkConfigUtils: + """Util class for spark configurations""" + + _valid_configuration_keys = ["Classification", "Properties", "Configurations"] + _valid_configuration_classifications = [ + "core-site", + "hadoop-env", + "hadoop-log4j", + "hive-env", + "hive-log4j", + "hive-exec-log4j", + "hive-site", + "spark-defaults", + "spark-env", + "spark-log4j", + "spark-hive-site", + "spark-metrics", + "yarn-env", + "yarn-site", + "export", + ] + + @staticmethod + def validate_configuration(configuration: Dict): + """Validates the user-provided Hadoop/Spark/Hive configuration. + + This ensures that the list or dictionary the user provides will serialize to + JSON matching the schema of EMR's application configuration + + Args: + configuration (Dict): A dict that contains the configuration overrides to + the default values. For more information, please visit: + https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html + """ + emr_configure_apps_url = ( + "https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html" + ) + if isinstance(configuration, dict): + keys = configuration.keys() + if "Classification" not in keys or "Properties" not in keys: + raise ValueError( + f"Missing one or more required keys in configuration dictionary " + f"{configuration} Please see {emr_configure_apps_url} for more information" + ) + + for key in keys: + if key not in SparkConfigUtils._valid_configuration_keys: + raise ValueError( + f"Invalid key: {key}. " + f"Must be one of {SparkConfigUtils._valid_configuration_keys}. " + f"Please see {emr_configure_apps_url} for more information." + ) + if key == "Classification": + if ( + configuration[key] + not in SparkConfigUtils._valid_configuration_classifications + ): + raise ValueError( + f"Invalid classification: {key}. Must be one of " + f"{SparkConfigUtils._valid_configuration_classifications}" + ) + + if isinstance(configuration, list): + for item in configuration: + SparkConfigUtils.validate_configuration(item) + + # TODO (guoqioa@): method only checks urlparse scheme, need to perform deep s3 validation + @staticmethod + def validate_s3_uri(spark_output_s3_path): + """Validate whether the URI uses an S3 scheme. + + In the future, this validation will perform deeper S3 validation. -# Backward compatibility: re-export from core -from sagemaker.core.remote_function.spark_config import * # noqa: F401, F403 + Args: + spark_output_s3_path (str): The URI of the Spark output S3 Path. + """ + if is_pipeline_variable(spark_output_s3_path): + return -warnings.warn( - "sagemaker.train.remote_function.spark_config has been moved to sagemaker.core.remote_function.spark_config. " - "Please update your imports. This shim will be removed in a future version.", - DeprecationWarning, - stacklevel=2 -) + if urlparse(spark_output_s3_path).scheme != "s3": + raise ValueError( + f"Invalid s3 path: {spark_output_s3_path}. 
Please enter something like " + "s3://bucket-name/folder-name" + ) diff --git a/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py b/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py index f7d4b978d4..09af11549f 100644 --- a/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py +++ b/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py @@ -24,7 +24,7 @@ import omegaconf from omegaconf import OmegaConf, dictconfig -# from sagemaker.utils.image_uris import retrieve +from sagemaker.image_uris import retrieve from sagemaker.train import logger from sagemaker.train.utils import _run_clone_command_silent @@ -129,7 +129,7 @@ def _get_trainining_recipe_gpu_model_name_and_script(model_type: str): """Get the model base name and script for the training recipe.""" model_type_to_script = { - "llama_v3": ("llama", "llama_pretrain.py"), + "llama": ("llama", "llama_pretrain.py"), "mistral": ("mistral", "mistral_pretrain.py"), "mixtral": ("mixtral", "mixtral_pretrain.py"), "deepseek": ("deepseek", "deepseek_pretrain.py"), @@ -176,14 +176,13 @@ def _configure_gpu_args( if isinstance(gpu_image_cfg, str): training_image = gpu_image_cfg else: - # training_image = retrieve( - # gpu_image_cfg.get("framework"), - # region=region_name, - # version=gpu_image_cfg.get("version"), - # image_scope="training", - # **gpu_image_cfg.get("additional_args"), - # ) - training_image = "dummy_image" # Placeholder for actual image retrieval + training_image = retrieve( + gpu_image_cfg.get("framework"), + region=region_name, + version=gpu_image_cfg.get("version"), + image_scope="training", + **gpu_image_cfg.get("additional_args"), + ) # Setting dummy parameters for now torch_distributed = Torchrun(smp=SMP(random_seed="123456")) @@ -214,14 +213,13 @@ def _configure_trainium_args( if isinstance(neuron_image_cfg, str): training_image = neuron_image_cfg else: - # training_image = retrieve( - # neuron_image_cfg.get("framework"), - # region=region_name, - # version=neuron_image_cfg.get("version"), - # image_scope="training", - # **neuron_image_cfg.get("additional_args"), - # ) - training_image = "dummy_image" # Placeholder for actual image retrieval + training_image = retrieve( + neuron_image_cfg.get("framework"), + region=region_name, + version=neuron_image_cfg.get("version"), + image_scope="training", + **neuron_image_cfg.get("additional_args"), + ) args.update( { diff --git a/sagemaker-train/src/sagemaker/train/tuner.py b/sagemaker-train/src/sagemaker/train/tuner.py index d1af08e2f1..aa54e36f3d 100644 --- a/sagemaker-train/src/sagemaker/train/tuner.py +++ b/sagemaker-train/src/sagemaker/train/tuner.py @@ -52,8 +52,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from sagemaker.train.model_trainer import ModelTrainer -from sagemaker.core.training.configs import InputData -from sagemaker.core.training.utils import _is_valid_s3_uri +from sagemaker.train.configs import InputData +from sagemaker.train.utils import _is_valid_s3_uri HYPERPARAMETER_TUNING_JOB_NAME = "HyperParameterTuningJobName" PARENT_HYPERPARAMETER_TUNING_JOBS = "ParentHyperParameterTuningJobs" diff --git a/sagemaker-core/tests/unit/test_training_constants.py b/sagemaker-train/tests/unit/test_training_constants.py similarity index 96% rename from sagemaker-core/tests/unit/test_training_constants.py rename to sagemaker-train/tests/unit/test_training_constants.py index ad096fac37..4ff66e2c24 100644 --- a/sagemaker-core/tests/unit/test_training_constants.py +++ b/sagemaker-train/tests/unit/test_training_constants.py @@ -10,11 +10,11 @@ # distributed on 
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Unit tests for sagemaker.core.training.constants module.""" +"""Unit tests for sagemaker.train.constants module.""" from __future__ import absolute_import import os -from sagemaker.core.training import constants +from sagemaker.train import constants class TestTrainingConstants: