diff --git a/src/sagemaker/iterators.py b/src/sagemaker/iterators.py
index 38a43121a1..8baf6dc886 100644
--- a/src/sagemaker/iterators.py
+++ b/src/sagemaker/iterators.py
@@ -17,6 +17,7 @@
 import io
 
 from sagemaker.exceptions import ModelStreamError, InternalStreamFailure
+from sagemaker.utils import _MAX_BUFFER_SIZE
 
 
 def handle_stream_errors(chunk):
@@ -182,5 +183,15 @@ def __next__(self):
                 # print and move on to next response byte
                 print("Unknown event type:" + chunk)
                 continue
+
+            # Check buffer size before writing to prevent unbounded memory consumption
+            chunk_size = len(chunk["PayloadPart"]["Bytes"])
+            current_size = self.buffer.getbuffer().nbytes
+            if current_size + chunk_size > _MAX_BUFFER_SIZE:
+                raise RuntimeError(
+                    f"Line buffer exceeded maximum size of {_MAX_BUFFER_SIZE} bytes. "
+                    "No newline found in stream."
+                )
+
             self.buffer.seek(0, io.SEEK_END)
             self.buffer.write(chunk["PayloadPart"]["Bytes"])
diff --git a/src/sagemaker/local/data.py b/src/sagemaker/local/data.py
index 226ce35c45..10c8420559 100644
--- a/src/sagemaker/local/data.py
+++ b/src/sagemaker/local/data.py
@@ -26,6 +26,7 @@
 import sagemaker.amazon.common
 import sagemaker.local.utils
 import sagemaker.utils
+from sagemaker.utils import _is_sensitive_system_path
 
 
 def get_data_source_instance(data_source, sagemaker_session):
@@ -122,6 +123,15 @@ def __init__(self, root_path):
         super(LocalFileDataSource, self).__init__()
 
         self.root_path = os.path.abspath(root_path)
+
+        # Validate that the path is not in restricted locations. Symlinks are resolved
+        # first so that a link pointing into a restricted directory is also rejected.
+        if _is_sensitive_system_path(os.path.realpath(self.root_path)):
+            raise ValueError(
+                "Local Mode does not support mounting from restricted system paths. "
+                f"Got: {root_path}"
+            )
+
         if not os.path.exists(self.root_path):
             raise RuntimeError("Invalid data source: %s does not exist." % self.root_path)
 
diff --git a/src/sagemaker/local/utils.py b/src/sagemaker/local/utils.py
index 3c7c3cda61..53988e2bba 100644
--- a/src/sagemaker/local/utils.py
+++ b/src/sagemaker/local/utils.py
@@ -48,10 +48,7 @@ def copy_directory_structure(destination_directory, relative_path):
             destination_directory
     """
     full_path = os.path.join(destination_directory, relative_path)
-    if os.path.exists(full_path):
-        return
-
-    os.makedirs(destination_directory, relative_path)
+    os.makedirs(full_path, exist_ok=True)
 
 
 def move_to_destination(source, destination, job_name, sagemaker_session, prefix=""):
diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py
index 33744bd455..dce5f2517d 100644
--- a/src/sagemaker/utils.py
+++ b/src/sagemaker/utils.py
@@ -76,6 +76,20 @@
 WAITING_DOT_NUMBER = 10
 MAX_ITEMS = 100
 PAGE_SIZE = 10
+_MAX_BUFFER_SIZE = 100 * 1024 * 1024  # 100 MB - Maximum buffer size for streaming iterators
+
+_SENSITIVE_SYSTEM_PATHS = [
+    abspath(os.path.expanduser("~/.aws")),
+    abspath(os.path.expanduser("~/.ssh")),
+    abspath(os.path.expanduser("~/.kube")),
+    abspath(os.path.expanduser("~/.docker")),
+    abspath(os.path.expanduser("~/.config")),
+    abspath(os.path.expanduser("~/.credentials")),
+    abspath(realpath("/etc")),
+    abspath(realpath("/root")),
+    abspath(realpath("/var/lib")),
+    abspath(realpath("/opt/ml/metadata")),
+]
 
 logger = logging.getLogger(__name__)
 
@@ -601,11 +615,103 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
         shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))
 
 
+def _is_path_within(path, directory):
+    """Check whether ``path`` is ``directory`` itself or is nested inside it.
+
+    The comparison is made on whole path components, so ``/etcetera`` is not
+    considered to be inside ``/etc``. Both arguments are expected to be
+    absolute paths that have already been resolved.
+
+    Args:
+        path (str): The path to test.
+        directory (str): The candidate parent directory.
+
+    Returns:
+        bool: True if ``path`` equals ``directory`` or lies underneath it.
+    """
+    prefix = directory.rstrip(os.sep) + os.sep
+    return path == directory or path.startswith(prefix)
+
+
+def _is_sensitive_system_path(resolved_path):
+    """Check whether a resolved path falls under a sensitive system location.
+
+    Args:
+        resolved_path (str): An absolute path with symlinks already resolved.
+
+    Returns:
+        bool: True if the path is, or is nested under, any entry of
+            ``_SENSITIVE_SYSTEM_PATHS``.
+    """
+    return any(
+        _is_path_within(resolved_path, sensitive_path)
+        for sensitive_path in _SENSITIVE_SYSTEM_PATHS
+    )
+
+
+def _validate_source_directory(source_directory):
+    """Validate that source_directory is safe to use.
+
+    Ensures the source directory path does not access restricted system locations.
+
+    Args:
+        source_directory (str): The source directory path to validate.
+
+    Raises:
+        ValueError: If the path is not allowed.
+    """
+    if not source_directory or source_directory.lower().startswith("s3://"):
+        # S3 paths and None are safe
+        return
+
+    # Resolve symlinks to get the actual path
+    abs_source = abspath(realpath(source_directory))
+
+    if _is_sensitive_system_path(abs_source):
+        raise ValueError(
+            "source_directory cannot access sensitive system paths. "
+            f"Got: {source_directory} (resolved to {abs_source})"
+        )
+
+
+def _validate_dependency_path(dependency):
+    """Validate that a dependency path is safe to use.
+
+    Ensures the dependency path does not access restricted system locations.
+
+    Args:
+        dependency (str): The dependency path to validate.
+
+    Raises:
+        ValueError: If the path is not allowed.
+    """
+    if not dependency:
+        return
+
+    # Resolve symlinks to get the actual path
+    abs_dependency = abspath(realpath(dependency))
+
+    if _is_sensitive_system_path(abs_dependency):
+        raise ValueError(
+            "dependency path cannot access sensitive system paths. "
+            f"Got: {dependency} (resolved to {abs_dependency})"
+        )
+
+
 def _create_or_update_code_dir(
     model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp
 ):
     """Placeholder docstring"""
     code_dir = os.path.join(model_dir, "code")
+    resolved_code_dir = _get_resolved_path(code_dir)
+
+    # Validate that code_dir does not resolve to a sensitive system path
+    if _is_sensitive_system_path(resolved_code_dir):
+        raise ValueError(
+            f"Invalid code_dir path: {code_dir} resolves to sensitive system path "
+            f"{resolved_code_dir}"
+        )
+
     if source_directory and source_directory.lower().startswith("s3://"):
         local_code_path = os.path.join(tmp, "local_code.tar.gz")
         download_file_from_url(source_directory, local_code_path, sagemaker_session)
@@ -614,6 +720,8 @@ def _create_or_update_code_dir(
             custom_extractall_tarfile(t, code_dir)
 
     elif source_directory:
+        # Validate source_directory for security
+        _validate_source_directory(source_directory)
         if os.path.exists(code_dir):
             shutil.rmtree(code_dir)
         shutil.copytree(source_directory, code_dir)
@@ -646,6 +754,8 @@ def _create_or_update_code_dir(
                 )
 
     for dependency in dependencies:
+        # Validate dependency path for security
+        _validate_dependency_path(dependency)
         lib_dir = os.path.join(code_dir, "lib")
         if os.path.isdir(dependency):
             shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency)))
@@ -1620,6 +1730,32 @@ def _get_safe_members(members):
             yield file_info
 
 
+def _validate_extracted_paths(extract_path):
+    """Validate that extracted paths remain within the expected directory.
+
+    Performs post-extraction validation to ensure all extracted files and directories
+    are within the intended extraction path.
+
+    Args:
+        extract_path (str): The path where files were extracted.
+
+    Raises:
+        ValueError: If any extracted file is outside the expected extraction path.
+    """
+    base = _get_resolved_path(extract_path)
+
+    for root, dirs, files in os.walk(extract_path):
+        # Directories and files are validated identically; only the log label differs.
+        for kind, names in (("directory", dirs), ("file", files)):
+            for name in names:
+                entry_path = os.path.join(root, name)
+                # Compare whole path components so a sibling such as "<base>_evil"
+                # is not mistaken for a path inside "<base>".
+                if not _is_path_within(_get_resolved_path(entry_path), base):
+                    logger.error("Extracted %s escaped extraction path: %s", kind, entry_path)
+                    raise ValueError(f"Extracted path outside expected directory: {entry_path}")
+
+
 def custom_extractall_tarfile(tar, extract_path):
     """Extract a tarfile, optionally using data_filter if available.
 
@@ -1640,6 +1776,8 @@ def custom_extractall_tarfile(tar, extract_path):
         tar.extractall(path=extract_path, filter="data")
     else:
         tar.extractall(path=extract_path, members=_get_safe_members(tar))
+    # Re-validate extracted paths to catch symlink race conditions
+    _validate_extracted_paths(extract_path)
 
 
 def can_model_package_source_uri_autopopulate(source_uri: str):
diff --git a/tests/unit/sagemaker/local/test_local_utils.py b/tests/unit/sagemaker/local/test_local_utils.py
index 82e3207266..0deed277c5 100644
--- a/tests/unit/sagemaker/local/test_local_utils.py
+++ b/tests/unit/sagemaker/local/test_local_utils.py
@@ -25,9 +25,9 @@
 @patch("sagemaker.local.utils.os.path")
 @patch("sagemaker.local.utils.os")
 def test_copy_directory_structure(m_os, m_os_path):
-    m_os_path.exists.return_value = False
+    m_os_path.join.return_value = "/tmp/code/"
     sagemaker.local.utils.copy_directory_structure("/tmp/", "code/")
-    m_os.makedirs.assert_called_with("/tmp/", "code/")
+    m_os.makedirs.assert_called_with("/tmp/code/", exist_ok=True)
 
 
 @patch("shutil.rmtree", Mock())
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 5deff5163b..91e96e157b 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -2245,3 +2245,197 @@ def test_get_domain_for_region(self):
         self.assertEqual(get_domain_for_region("us-iso-east-1"), "c2s.ic.gov")
         self.assertEqual(get_domain_for_region("us-isob-east-1"), "sc2s.sgov.gov")
         self.assertEqual(get_domain_for_region("invalid-region"), "amazonaws.com")
+
+
+class TestIsPathWithin(TestCase):
+    """Tests for _is_path_within function"""
+
+    def test_is_path_within_matches_directory_and_children(self):
+        """The directory itself and anything nested under it should match"""
+        from sagemaker.utils import _is_path_within
+
+        base = os.path.join(os.sep, "etc")
+        self.assertTrue(_is_path_within(base, base))
+        self.assertTrue(_is_path_within(os.path.join(base, "ssh", "sshd_config"), base))
+
+    def test_is_path_within_rejects_sibling_with_shared_prefix(self):
+        """A sibling that only shares a string prefix should not match"""
+        from sagemaker.utils import _is_path_within
+
+        base = os.path.join(os.sep, "etc")
+        self.assertFalse(_is_path_within(base + "etera", base))
+        self.assertFalse(_is_path_within(os.sep, base))
+
+
+class TestValidateSourceDirectory(TestCase):
+    """Tests for _validate_source_directory function"""
+
+    def test_validate_source_directory_with_s3_path(self):
+        """S3 paths should be allowed"""
+        from sagemaker.utils import _validate_source_directory
+
+        # Should not raise any exception
+        _validate_source_directory("s3://my-bucket/my-prefix")
+
+    def test_validate_source_directory_with_none(self):
+        """None should be allowed"""
+        from sagemaker.utils import _validate_source_directory
+
+        # Should not raise any exception
+        _validate_source_directory(None)
+
+    def test_validate_source_directory_with_safe_local_path(self):
+        """Safe local paths should be allowed"""
+        from sagemaker.utils import _validate_source_directory
+
+        # Should not raise any exception
+        _validate_source_directory("/tmp/my_code")
+        _validate_source_directory("./my_code")
+        _validate_source_directory("../my_code")
+
+    def test_validate_source_directory_with_sensitive_path_aws(self):
+        """Paths under ~/.aws should be rejected"""
+        from sagemaker.utils import _validate_source_directory
+
+        with pytest.raises(ValueError, match="cannot access sensitive system paths"):
+            _validate_source_directory(os.path.expanduser("~/.aws/credentials"))
+
+    def test_validate_source_directory_with_sensitive_path_ssh(self):
+        """Paths under ~/.ssh should be rejected"""
+        from sagemaker.utils import _validate_source_directory
+
+        with pytest.raises(ValueError, match="cannot access sensitive system paths"):
+            _validate_source_directory(os.path.expanduser("~/.ssh/id_rsa"))
+
+    def test_validate_source_directory_with_root_directory(self):
+        """Root directory itself should be allowed (not rejected)"""
+        from sagemaker.utils import _validate_source_directory
+
+        # Should not raise any exception - root directory is explicitly allowed
+        _validate_source_directory("/")
+
+    def test_validate_source_directory_with_sibling_of_sensitive_path(self):
+        """Paths that only share a name prefix with a sensitive path should be allowed"""
+        from sagemaker.utils import _validate_source_directory
+
+        # Should not raise any exception - "/etcetera" is not inside "/etc"
+        _validate_source_directory(os.path.realpath("/etc") + "etera")
+
+
+class TestValidateDependencyPath(TestCase):
+    """Tests for _validate_dependency_path function"""
+
+    def test_validate_dependency_path_with_none(self):
+        """None should be allowed"""
+        from sagemaker.utils import _validate_dependency_path
+
+        # Should not raise any exception
+        _validate_dependency_path(None)
+
+    def test_validate_dependency_path_with_safe_local_path(self):
+        """Safe local paths should be allowed"""
+        from sagemaker.utils import _validate_dependency_path
+
+        # Should not raise any exception
+        _validate_dependency_path("/tmp/my_lib")
+        _validate_dependency_path("./my_lib")
+        _validate_dependency_path("../my_lib")
+
+    def test_validate_dependency_path_with_sensitive_path_aws(self):
+        """Paths under ~/.aws should be rejected"""
+        from sagemaker.utils import _validate_dependency_path
+
+        with pytest.raises(ValueError, match="cannot access sensitive system paths"):
+            _validate_dependency_path(os.path.expanduser("~/.aws"))
+
+    def test_validate_dependency_path_with_sensitive_path_docker(self):
+        """Paths under ~/.docker should be rejected"""
+        from sagemaker.utils import _validate_dependency_path
+
+        with pytest.raises(ValueError, match="cannot access sensitive system paths"):
+            _validate_dependency_path(os.path.expanduser("~/.docker/config.json"))
+
+    def test_validate_dependency_path_with_root_directory(self):
+        """Root directory itself should be allowed (not rejected)"""
+        from sagemaker.utils import _validate_dependency_path
+
+        # Should not raise any exception - root directory is explicitly allowed
+        _validate_dependency_path("/")
+
+
+class TestCreateOrUpdateCodeDir(TestCase):
+    """Tests for _create_or_update_code_dir function"""
+
+    @patch("sagemaker.utils._validate_source_directory")
+    @patch("sagemaker.utils._validate_dependency_path")
+    @patch("sagemaker.utils.os.path.exists")
+    @patch("sagemaker.utils.os.mkdir")
+    @patch("sagemaker.utils.shutil.copy2")
+    def test_create_or_update_code_dir_with_inference_script(
+        self, mock_copy, mock_mkdir, mock_exists, mock_validate_dep, mock_validate_src
+    ):
+        """Test creating code dir with inference script"""
+        from sagemaker.utils import _create_or_update_code_dir
+
+        mock_exists.return_value = False
+
+        with patch("sagemaker.utils._get_resolved_path") as mock_get_resolved:
+            mock_get_resolved.return_value = "/tmp/model/code"
+
+            _create_or_update_code_dir(
+                model_dir="/tmp/model",
+                inference_script="inference.py",
+                source_directory=None,
+                dependencies=[],
+                sagemaker_session=None,
+                tmp="/tmp",
+            )
+
+            mock_mkdir.assert_called()
+            mock_copy.assert_called_once()
+
+    @patch("sagemaker.utils._validate_source_directory")
+    @patch("sagemaker.utils.os.path.exists")
+    @patch("sagemaker.utils.shutil.rmtree")
+    @patch("sagemaker.utils.shutil.copytree")
+    def test_create_or_update_code_dir_with_source_directory(
+        self, mock_copytree, mock_rmtree, mock_exists, mock_validate_src
+    ):
+        """Test creating code dir with source directory"""
+        from sagemaker.utils import _create_or_update_code_dir
+
+        mock_exists.return_value = True
+
+        with patch("sagemaker.utils._get_resolved_path") as mock_get_resolved:
+            mock_get_resolved.return_value = "/tmp/model/code"
+
+            _create_or_update_code_dir(
+                model_dir="/tmp/model",
+                inference_script=None,
+                source_directory="/tmp/my_code",
+                dependencies=[],
+                sagemaker_session=None,
+                tmp="/tmp",
+            )
+
+            mock_validate_src.assert_called_once_with("/tmp/my_code")
+            mock_rmtree.assert_called_once()
+            mock_copytree.assert_called_once()
+
+    def test_create_or_update_code_dir_with_sensitive_code_dir(self):
+        """Test that code_dir resolving to sensitive path is rejected"""
+        from sagemaker.utils import _create_or_update_code_dir
+
+        with patch("sagemaker.utils._get_resolved_path") as mock_get_resolved:
+            # Simulate code_dir resolving to a sensitive path
+            mock_get_resolved.return_value = os.path.abspath(os.path.expanduser("~/.aws"))
+
+            with pytest.raises(ValueError, match="Invalid code_dir path"):
+                _create_or_update_code_dir(
+                    model_dir="/tmp/model",
+                    inference_script="inference.py",
+                    source_directory=None,
+                    dependencies=[],
+                    sagemaker_session=None,
+                    tmp="/tmp",
+                )