From 78d88af027de9769756de6cacb0076b49e40fe7c Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Mon, 15 Dec 2025 16:19:47 -0800 Subject: [PATCH 1/9] Add input validation and resource management improvements --- src/sagemaker/iterators.py | 13 ++ src/sagemaker/local/data.py | 24 ++++ src/sagemaker/local/utils.py | 5 +- src/sagemaker/utils.py | 127 ++++++++++++++++++ .../unit/sagemaker/local/test_local_utils.py | 6 +- 5 files changed, 167 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/iterators.py b/src/sagemaker/iterators.py index 38a43121a1..d28ece537b 100644 --- a/src/sagemaker/iterators.py +++ b/src/sagemaker/iterators.py @@ -114,6 +114,9 @@ def __next__(self): class LineIterator(BaseIterator): """A helper class for parsing the byte Event Stream input to provide Line iteration.""" + # Maximum buffer size to prevent unbounded memory consumption (10 MB) + MAX_BUFFER_SIZE = 10 * 1024 * 1024 + def __init__(self, event_stream): """Initialises a LineIterator Iterator object @@ -182,5 +185,15 @@ def __next__(self): # print and move on to next response byte print("Unknown event type:" + chunk) continue + + # Check buffer size before writing to prevent unbounded memory consumption + chunk_size = len(chunk["PayloadPart"]["Bytes"]) + current_size = self.buffer.getbuffer().nbytes + if current_size + chunk_size > self.MAX_BUFFER_SIZE: + raise RuntimeError( + f"Line buffer exceeded maximum size of {self.MAX_BUFFER_SIZE} bytes. " + f"No newline found in stream." + ) + self.buffer.seek(0, io.SEEK_END) self.buffer.write(chunk["PayloadPart"]["Bytes"]) diff --git a/src/sagemaker/local/data.py b/src/sagemaker/local/data.py index 226ce35c45..5b634738ac 100644 --- a/src/sagemaker/local/data.py +++ b/src/sagemaker/local/data.py @@ -118,10 +118,34 @@ def get_root_dir(self): class LocalFileDataSource(DataSource): """Represents a data source within the local filesystem.""" + # Blocklist of sensitive directories that should not be accessible + RESTRICTED_PATHS = [ + os.path.abspath(os.path.expanduser("~/.aws")), + os.path.abspath(os.path.expanduser("~/.ssh")), + os.path.abspath(os.path.expanduser("~/.kube")), + os.path.abspath(os.path.expanduser("~/.docker")), + os.path.abspath(os.path.expanduser("~/.config")), + os.path.abspath(os.path.expanduser("~/.credentials")), + "/etc", + "/root", + "/home", + "/var/lib", + "/opt/ml/metadata", + ] + def __init__(self, root_path): super(LocalFileDataSource, self).__init__() self.root_path = os.path.abspath(root_path) + + # Validate that the path is not in restricted locations + for restricted_path in self.RESTRICTED_PATHS: + if self.root_path.startswith(restricted_path): + raise ValueError( + f"Local Mode does not support mounting from restricted system paths. " + f"Got: {root_path}" + ) + if not os.path.exists(self.root_path): raise RuntimeError("Invalid data source: %s does not exist." % self.root_path) diff --git a/src/sagemaker/local/utils.py b/src/sagemaker/local/utils.py index 3c7c3cda61..53988e2bba 100644 --- a/src/sagemaker/local/utils.py +++ b/src/sagemaker/local/utils.py @@ -48,10 +48,7 @@ def copy_directory_structure(destination_directory, relative_path): destination_directory """ full_path = os.path.join(destination_directory, relative_path) - if os.path.exists(full_path): - return - - os.makedirs(destination_directory, relative_path) + os.makedirs(full_path, exist_ok=True) def move_to_destination(source, destination, job_name, sagemaker_session, prefix=""): diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 33744bd455..76d3086db4 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -601,6 +601,95 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key): shutil.move(tmp_model_path, repacked_model_uri.replace("file://", "")) +def _validate_source_directory(source_directory): + """Validate that source_directory is safe to use. + + Ensures the source directory path does not access restricted system locations. + + Args: + source_directory (str): The source directory path to validate. + + Raises: + ValueError: If the path is not allowed. + """ + if not source_directory or source_directory.lower().startswith("s3://"): + # S3 paths and None are safe + return + + abs_source = abspath(source_directory) + + # Blocklist of sensitive directories that should not be accessible + sensitive_paths = [ + abspath(os.path.expanduser("~/.aws")), + abspath(os.path.expanduser("~/.ssh")), + abspath(os.path.expanduser("~/.kube")), + abspath(os.path.expanduser("~/.docker")), + abspath(os.path.expanduser("~/.config")), + abspath(os.path.expanduser("~/.credentials")), + "/etc", + "/root", + "/home", + "/var/lib", + "/opt/ml/metadata", + ] + + # Check if the source path is under any sensitive directory + for sensitive_path in sensitive_paths: + if abs_source.startswith(sensitive_path): + raise ValueError( + f"source_directory cannot access sensitive system paths. " + f"Got: {source_directory} (resolved to {abs_source})" + ) + + # Check for symlinks to prevent symlink-based escapes + if os.path.islink(abs_source): + raise ValueError(f"source_directory cannot be a symlink: {source_directory}") + + +def _validate_dependency_path(dependency): + """Validate that a dependency path is safe to use. + + Ensures the dependency path does not access restricted system locations. + + Args: + dependency (str): The dependency path to validate. + + Raises: + ValueError: If the path is not allowed. + """ + if not dependency: + return + + abs_dependency = abspath(dependency) + + # Blocklist of sensitive directories that should not be accessible + sensitive_paths = [ + abspath(os.path.expanduser("~/.aws")), + abspath(os.path.expanduser("~/.ssh")), + abspath(os.path.expanduser("~/.kube")), + abspath(os.path.expanduser("~/.docker")), + abspath(os.path.expanduser("~/.config")), + abspath(os.path.expanduser("~/.credentials")), + "/etc", + "/root", + "/home", + "/var/lib", + "/opt/ml/metadata", + ] + + # Check if the dependency path is under any sensitive directory + for sensitive_path in sensitive_paths: + if abs_dependency.startswith(sensitive_path): + raise ValueError( + f"dependency path cannot access sensitive system paths. " + f"Got: {dependency} (resolved to {abs_dependency})" + ) + + # Check for symlinks to prevent symlink-based escapes + if os.path.islink(abs_dependency): + raise ValueError(f"dependency path cannot be a symlink: {dependency}") + + def _create_or_update_code_dir( model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp ): @@ -614,6 +703,8 @@ def _create_or_update_code_dir( custom_extractall_tarfile(t, code_dir) elif source_directory: + # Validate source_directory for security + _validate_source_directory(source_directory) if os.path.exists(code_dir): shutil.rmtree(code_dir) shutil.copytree(source_directory, code_dir) @@ -646,6 +737,8 @@ def _create_or_update_code_dir( ) for dependency in dependencies: + # Validate dependency path for security + _validate_dependency_path(dependency) lib_dir = os.path.join(code_dir, "lib") if os.path.isdir(dependency): shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency))) @@ -1620,6 +1713,38 @@ def _get_safe_members(members): yield file_info +def _validate_extracted_paths(extract_path): + """Validate that extracted paths remain within the expected directory. + + Performs post-extraction validation to ensure all extracted files and directories + are within the intended extraction path. + + Args: + extract_path (str): The path where files were extracted. + + Raises: + ValueError: If any extracted file is outside the expected extraction path. + """ + base = _get_resolved_path(extract_path) + + for root, dirs, files in os.walk(extract_path): + # Check directories + for dir_name in dirs: + dir_path = os.path.join(root, dir_name) + resolved = _get_resolved_path(dir_path) + if not resolved.startswith(base): + logger.error("Extracted directory escaped extraction path: %s", dir_path) + raise ValueError(f"Extracted path outside expected directory: {dir_path}") + + # Check files + for file_name in files: + file_path = os.path.join(root, file_name) + resolved = _get_resolved_path(file_path) + if not resolved.startswith(base): + logger.error("Extracted file escaped extraction path: %s", file_path) + raise ValueError(f"Extracted path outside expected directory: {file_path}") + + def custom_extractall_tarfile(tar, extract_path): """Extract a tarfile, optionally using data_filter if available. @@ -1640,6 +1765,8 @@ def custom_extractall_tarfile(tar, extract_path): tar.extractall(path=extract_path, filter="data") else: tar.extractall(path=extract_path, members=_get_safe_members(tar)) + # Re-validate extracted paths to catch symlink race conditions + _validate_extracted_paths(extract_path) def can_model_package_source_uri_autopopulate(source_uri: str): diff --git a/tests/unit/sagemaker/local/test_local_utils.py b/tests/unit/sagemaker/local/test_local_utils.py index 82e3207266..83fe8de555 100644 --- a/tests/unit/sagemaker/local/test_local_utils.py +++ b/tests/unit/sagemaker/local/test_local_utils.py @@ -22,12 +22,10 @@ from sagemaker.session_settings import SessionSettings -@patch("sagemaker.local.utils.os.path") @patch("sagemaker.local.utils.os") -def test_copy_directory_structure(m_os, m_os_path): - m_os_path.exists.return_value = False +def test_copy_directory_structure(m_os): sagemaker.local.utils.copy_directory_structure("/tmp/", "code/") - m_os.makedirs.assert_called_with("/tmp/", "code/") + m_os.makedirs.assert_called_with("/tmp/code/", exist_ok=True) @patch("shutil.rmtree", Mock()) From 7efb032b289a6d1003a8d34fed3c3369c437abb3 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Mon, 15 Dec 2025 17:25:36 -0800 Subject: [PATCH 2/9] Fix failing unit test --- tests/unit/sagemaker/local/test_local_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/sagemaker/local/test_local_utils.py b/tests/unit/sagemaker/local/test_local_utils.py index 83fe8de555..0deed277c5 100644 --- a/tests/unit/sagemaker/local/test_local_utils.py +++ b/tests/unit/sagemaker/local/test_local_utils.py @@ -22,8 +22,10 @@ from sagemaker.session_settings import SessionSettings +@patch("sagemaker.local.utils.os.path") @patch("sagemaker.local.utils.os") -def test_copy_directory_structure(m_os): +def test_copy_directory_structure(m_os, m_os_path): + m_os_path.join.return_value = "/tmp/code/" sagemaker.local.utils.copy_directory_structure("/tmp/", "code/") m_os.makedirs.assert_called_with("/tmp/code/", exist_ok=True) From 4cd09ae54644b05908391275c565fe5b3f1d14a0 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Mon, 15 Dec 2025 18:24:53 -0800 Subject: [PATCH 3/9] Fix codestyle issues --- src/sagemaker/iterators.py | 2 ++ src/sagemaker/local/data.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/sagemaker/iterators.py b/src/sagemaker/iterators.py index d28ece537b..ffe0044cf2 100644 --- a/src/sagemaker/iterators.py +++ b/src/sagemaker/iterators.py @@ -186,6 +186,7 @@ def __next__(self): print("Unknown event type:" + chunk) continue + # Check buffer size before writing to prevent unbounded memory consumption chunk_size = len(chunk["PayloadPart"]["Bytes"]) current_size = self.buffer.getbuffer().nbytes @@ -195,5 +196,6 @@ def __next__(self): f"No newline found in stream." ) + self.buffer.seek(0, io.SEEK_END) self.buffer.write(chunk["PayloadPart"]["Bytes"]) diff --git a/src/sagemaker/local/data.py b/src/sagemaker/local/data.py index 5b634738ac..51679ca9bd 100644 --- a/src/sagemaker/local/data.py +++ b/src/sagemaker/local/data.py @@ -138,6 +138,7 @@ def __init__(self, root_path): self.root_path = os.path.abspath(root_path) + # Validate that the path is not in restricted locations for restricted_path in self.RESTRICTED_PATHS: if self.root_path.startswith(restricted_path): @@ -146,6 +147,7 @@ def __init__(self, root_path): f"Got: {root_path}" ) + if not os.path.exists(self.root_path): raise RuntimeError("Invalid data source: %s does not exist." % self.root_path) From db68076dce76b695af8d2f14ffcb0b11b5457c74 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Mon, 15 Dec 2025 18:43:16 -0800 Subject: [PATCH 4/9] More codestyle fixes --- src/sagemaker/iterators.py | 4 +--- src/sagemaker/local/data.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/iterators.py b/src/sagemaker/iterators.py index ffe0044cf2..867448d25e 100644 --- a/src/sagemaker/iterators.py +++ b/src/sagemaker/iterators.py @@ -185,7 +185,6 @@ def __next__(self): # print and move on to next response byte print("Unknown event type:" + chunk) continue - # Check buffer size before writing to prevent unbounded memory consumption chunk_size = len(chunk["PayloadPart"]["Bytes"]) @@ -195,7 +194,6 @@ def __next__(self): f"Line buffer exceeded maximum size of {self.MAX_BUFFER_SIZE} bytes. " f"No newline found in stream." ) - - + self.buffer.seek(0, io.SEEK_END) self.buffer.write(chunk["PayloadPart"]["Bytes"]) diff --git a/src/sagemaker/local/data.py b/src/sagemaker/local/data.py index 51679ca9bd..88273672d9 100644 --- a/src/sagemaker/local/data.py +++ b/src/sagemaker/local/data.py @@ -137,7 +137,6 @@ def __init__(self, root_path): super(LocalFileDataSource, self).__init__() self.root_path = os.path.abspath(root_path) - # Validate that the path is not in restricted locations for restricted_path in self.RESTRICTED_PATHS: @@ -146,8 +145,7 @@ def __init__(self, root_path): f"Local Mode does not support mounting from restricted system paths. " f"Got: {root_path}" ) - - + if not os.path.exists(self.root_path): raise RuntimeError("Invalid data source: %s does not exist." % self.root_path) From 206c07e2da077ddb8264b5cca1d798387cefdc21 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 17 Dec 2025 16:08:57 -0800 Subject: [PATCH 5/9] Allowing for sym-links, better refactoring --- src/sagemaker/iterators.py | 8 ++--- src/sagemaker/local/data.py | 18 ++--------- src/sagemaker/utils.py | 63 +++++++++++++------------------------ 3 files changed, 26 insertions(+), 63 deletions(-) diff --git a/src/sagemaker/iterators.py b/src/sagemaker/iterators.py index 867448d25e..8baf6dc886 100644 --- a/src/sagemaker/iterators.py +++ b/src/sagemaker/iterators.py @@ -17,6 +17,7 @@ import io from sagemaker.exceptions import ModelStreamError, InternalStreamFailure +from sagemaker.utils import _MAX_BUFFER_SIZE def handle_stream_errors(chunk): @@ -114,9 +115,6 @@ def __next__(self): class LineIterator(BaseIterator): """A helper class for parsing the byte Event Stream input to provide Line iteration.""" - # Maximum buffer size to prevent unbounded memory consumption (10 MB) - MAX_BUFFER_SIZE = 10 * 1024 * 1024 - def __init__(self, event_stream): """Initialises a LineIterator Iterator object @@ -189,9 +187,9 @@ def __next__(self): # Check buffer size before writing to prevent unbounded memory consumption chunk_size = len(chunk["PayloadPart"]["Bytes"]) current_size = self.buffer.getbuffer().nbytes - if current_size + chunk_size > self.MAX_BUFFER_SIZE: + if current_size + chunk_size > _MAX_BUFFER_SIZE: raise RuntimeError( - f"Line buffer exceeded maximum size of {self.MAX_BUFFER_SIZE} bytes. " + f"Line buffer exceeded maximum size of {_MAX_BUFFER_SIZE} bytes. " f"No newline found in stream." ) diff --git a/src/sagemaker/local/data.py b/src/sagemaker/local/data.py index 88273672d9..b0bd7dc71d 100644 --- a/src/sagemaker/local/data.py +++ b/src/sagemaker/local/data.py @@ -26,6 +26,7 @@ import sagemaker.amazon.common import sagemaker.local.utils import sagemaker.utils +from sagemaker.utils import _SENSITIVE_SYSTEM_PATHS def get_data_source_instance(data_source, sagemaker_session): @@ -118,28 +119,13 @@ def get_root_dir(self): class LocalFileDataSource(DataSource): """Represents a data source within the local filesystem.""" - # Blocklist of sensitive directories that should not be accessible - RESTRICTED_PATHS = [ - os.path.abspath(os.path.expanduser("~/.aws")), - os.path.abspath(os.path.expanduser("~/.ssh")), - os.path.abspath(os.path.expanduser("~/.kube")), - os.path.abspath(os.path.expanduser("~/.docker")), - os.path.abspath(os.path.expanduser("~/.config")), - os.path.abspath(os.path.expanduser("~/.credentials")), - "/etc", - "/root", - "/home", - "/var/lib", - "/opt/ml/metadata", - ] - def __init__(self, root_path): super(LocalFileDataSource, self).__init__() self.root_path = os.path.abspath(root_path) # Validate that the path is not in restricted locations - for restricted_path in self.RESTRICTED_PATHS: + for restricted_path in _SENSITIVE_SYSTEM_PATHS: if self.root_path.startswith(restricted_path): raise ValueError( f"Local Mode does not support mounting from restricted system paths. " diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 76d3086db4..d5e5b21a01 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -76,6 +76,21 @@ WAITING_DOT_NUMBER = 10 MAX_ITEMS = 100 PAGE_SIZE = 10 +_MAX_BUFFER_SIZE = 100 * 1024 * 1024 # 100 MB - Maximum buffer size for streaming iterators + +_SENSITIVE_SYSTEM_PATHS = [ + abspath(os.path.expanduser("~/.aws")), + abspath(os.path.expanduser("~/.ssh")), + abspath(os.path.expanduser("~/.kube")), + abspath(os.path.expanduser("~/.docker")), + abspath(os.path.expanduser("~/.config")), + abspath(os.path.expanduser("~/.credentials")), + "/etc", + "/root", + "/home", + "/var/lib", + "/opt/ml/metadata", +] logger = logging.getLogger(__name__) @@ -616,35 +631,17 @@ def _validate_source_directory(source_directory): # S3 paths and None are safe return - abs_source = abspath(source_directory) - - # Blocklist of sensitive directories that should not be accessible - sensitive_paths = [ - abspath(os.path.expanduser("~/.aws")), - abspath(os.path.expanduser("~/.ssh")), - abspath(os.path.expanduser("~/.kube")), - abspath(os.path.expanduser("~/.docker")), - abspath(os.path.expanduser("~/.config")), - abspath(os.path.expanduser("~/.credentials")), - "/etc", - "/root", - "/home", - "/var/lib", - "/opt/ml/metadata", - ] + # Resolve symlinks to get the actual path + abs_source = abspath(realpath(source_directory)) # Check if the source path is under any sensitive directory - for sensitive_path in sensitive_paths: + for sensitive_path in _SENSITIVE_SYSTEM_PATHS: if abs_source.startswith(sensitive_path): raise ValueError( f"source_directory cannot access sensitive system paths. " f"Got: {source_directory} (resolved to {abs_source})" ) - # Check for symlinks to prevent symlink-based escapes - if os.path.islink(abs_source): - raise ValueError(f"source_directory cannot be a symlink: {source_directory}") - def _validate_dependency_path(dependency): """Validate that a dependency path is safe to use. @@ -660,35 +657,17 @@ def _validate_dependency_path(dependency): if not dependency: return - abs_dependency = abspath(dependency) - - # Blocklist of sensitive directories that should not be accessible - sensitive_paths = [ - abspath(os.path.expanduser("~/.aws")), - abspath(os.path.expanduser("~/.ssh")), - abspath(os.path.expanduser("~/.kube")), - abspath(os.path.expanduser("~/.docker")), - abspath(os.path.expanduser("~/.config")), - abspath(os.path.expanduser("~/.credentials")), - "/etc", - "/root", - "/home", - "/var/lib", - "/opt/ml/metadata", - ] + # Resolve symlinks to get the actual path + abs_dependency = abspath(realpath(dependency)) # Check if the dependency path is under any sensitive directory - for sensitive_path in sensitive_paths: + for sensitive_path in _SENSITIVE_SYSTEM_PATHS: if abs_dependency.startswith(sensitive_path): raise ValueError( f"dependency path cannot access sensitive system paths. " f"Got: {dependency} (resolved to {abs_dependency})" ) - # Check for symlinks to prevent symlink-based escapes - if os.path.islink(abs_dependency): - raise ValueError(f"dependency path cannot be a symlink: {dependency}") - def _create_or_update_code_dir( model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp From de8bc1eafeec18e51040628114c1c964a6640aac Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Thu, 18 Dec 2025 12:40:47 -0800 Subject: [PATCH 6/9] Adding additional validation and removing home as sensitive path --- src/sagemaker/utils.py | 22 ++++-- tests/unit/test_utils.py | 157 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 172 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index d5e5b21a01..34bf961a7b 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -85,11 +85,10 @@ abspath(os.path.expanduser("~/.docker")), abspath(os.path.expanduser("~/.config")), abspath(os.path.expanduser("~/.credentials")), - "/etc", - "/root", - "/home", - "/var/lib", - "/opt/ml/metadata", + abspath(realpath("/etc")), + abspath(realpath("/root")), + abspath(realpath("/var/lib")), + abspath(realpath("/opt/ml/metadata")), ] logger = logging.getLogger(__name__) @@ -636,7 +635,7 @@ def _validate_source_directory(source_directory): # Check if the source path is under any sensitive directory for sensitive_path in _SENSITIVE_SYSTEM_PATHS: - if abs_source.startswith(sensitive_path): + if abs_source != "/" and abs_source.startswith(sensitive_path): raise ValueError( f"source_directory cannot access sensitive system paths. " f"Got: {source_directory} (resolved to {abs_source})" @@ -662,7 +661,7 @@ def _validate_dependency_path(dependency): # Check if the dependency path is under any sensitive directory for sensitive_path in _SENSITIVE_SYSTEM_PATHS: - if abs_dependency.startswith(sensitive_path): + if abs_dependency != "/" and abs_dependency.startswith(sensitive_path): raise ValueError( f"dependency path cannot access sensitive system paths. " f"Got: {dependency} (resolved to {abs_dependency})" @@ -674,6 +673,15 @@ def _create_or_update_code_dir( ): """Placeholder docstring""" code_dir = os.path.join(model_dir, "code") + resolved_code_dir = _get_resolved_path(code_dir) + + # Validate that code_dir does not resolve to a sensitive system path + for sensitive_path in _SENSITIVE_SYSTEM_PATHS: + if resolved_code_dir.startswith(sensitive_path): + raise ValueError( + f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}" + ) + if source_directory and source_directory.lower().startswith("s3://"): local_code_path = os.path.join(tmp, "local_code.tar.gz") download_file_from_url(source_directory, local_code_path, sagemaker_session) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 5deff5163b..d5cb284ea1 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -2245,3 +2245,160 @@ def test_get_domain_for_region(self): self.assertEqual(get_domain_for_region("us-iso-east-1"), "c2s.ic.gov") self.assertEqual(get_domain_for_region("us-isob-east-1"), "sc2s.sgov.gov") self.assertEqual(get_domain_for_region("invalid-region"), "amazonaws.com") + + + +class TestValidateSourceDirectory(TestCase): + """Tests for _validate_source_directory function""" + + def test_validate_source_directory_with_s3_path(self): + """S3 paths should be allowed""" + from sagemaker.utils import _validate_source_directory + # Should not raise any exception + _validate_source_directory("s3://my-bucket/my-prefix") + + def test_validate_source_directory_with_none(self): + """None should be allowed""" + from sagemaker.utils import _validate_source_directory + # Should not raise any exception + _validate_source_directory(None) + + def test_validate_source_directory_with_safe_local_path(self): + """Safe local paths should be allowed""" + from sagemaker.utils import _validate_source_directory + # Should not raise any exception + _validate_source_directory("/tmp/my_code") + _validate_source_directory("./my_code") + _validate_source_directory("../my_code") + + def test_validate_source_directory_with_sensitive_path_aws(self): + """Paths under ~/.aws should be rejected""" + from sagemaker.utils import _validate_source_directory + with pytest.raises(ValueError, match="cannot access sensitive system paths"): + _validate_source_directory(os.path.expanduser("~/.aws/credentials")) + + def test_validate_source_directory_with_sensitive_path_ssh(self): + """Paths under ~/.ssh should be rejected""" + from sagemaker.utils import _validate_source_directory + with pytest.raises(ValueError, match="cannot access sensitive system paths"): + _validate_source_directory(os.path.expanduser("~/.ssh/id_rsa")) + + def test_validate_source_directory_with_root_directory(self): + """Root directory itself should be allowed (not rejected)""" + from sagemaker.utils import _validate_source_directory + # Should not raise any exception - root directory is explicitly allowed + _validate_source_directory("/") + + +class TestValidateDependencyPath(TestCase): + """Tests for _validate_dependency_path function""" + + def test_validate_dependency_path_with_none(self): + """None should be allowed""" + from sagemaker.utils import _validate_dependency_path + # Should not raise any exception + _validate_dependency_path(None) + + def test_validate_dependency_path_with_safe_local_path(self): + """Safe local paths should be allowed""" + from sagemaker.utils import _validate_dependency_path + # Should not raise any exception + _validate_dependency_path("/tmp/my_lib") + _validate_dependency_path("./my_lib") + _validate_dependency_path("../my_lib") + + def test_validate_dependency_path_with_sensitive_path_aws(self): + """Paths under ~/.aws should be rejected""" + from sagemaker.utils import _validate_dependency_path + with pytest.raises(ValueError, match="cannot access sensitive system paths"): + _validate_dependency_path(os.path.expanduser("~/.aws")) + + def test_validate_dependency_path_with_sensitive_path_docker(self): + """Paths under ~/.docker should be rejected""" + from sagemaker.utils import _validate_dependency_path + with pytest.raises(ValueError, match="cannot access sensitive system paths"): + _validate_dependency_path(os.path.expanduser("~/.docker/config.json")) + + def test_validate_dependency_path_with_root_directory(self): + """Root directory itself should be allowed (not rejected)""" + from sagemaker.utils import _validate_dependency_path + # Should not raise any exception - root directory is explicitly allowed + _validate_dependency_path("/") + + +class TestCreateOrUpdateCodeDir(TestCase): + """Tests for _create_or_update_code_dir function""" + + @patch("sagemaker.utils._validate_source_directory") + @patch("sagemaker.utils._validate_dependency_path") + @patch("sagemaker.utils.os.path.exists") + @patch("sagemaker.utils.os.mkdir") + @patch("sagemaker.utils.shutil.copy2") + def test_create_or_update_code_dir_with_inference_script( + self, mock_copy, mock_mkdir, mock_exists, mock_validate_dep, mock_validate_src + ): + """Test creating code dir with inference script""" + from sagemaker.utils import _create_or_update_code_dir + + mock_exists.return_value = False + + with patch("sagemaker.utils._get_resolved_path") as mock_get_resolved: + mock_get_resolved.return_value = "/tmp/model/code" + + _create_or_update_code_dir( + model_dir="/tmp/model", + inference_script="inference.py", + source_directory=None, + dependencies=[], + sagemaker_session=None, + tmp="/tmp" + ) + + mock_mkdir.assert_called() + mock_copy.assert_called_once() + + @patch("sagemaker.utils._validate_source_directory") + @patch("sagemaker.utils.os.path.exists") + @patch("sagemaker.utils.shutil.rmtree") + @patch("sagemaker.utils.shutil.copytree") + def test_create_or_update_code_dir_with_source_directory( + self, mock_copytree, mock_rmtree, mock_exists, mock_validate_src + ): + """Test creating code dir with source directory""" + from sagemaker.utils import _create_or_update_code_dir + + mock_exists.return_value = True + + with patch("sagemaker.utils._get_resolved_path") as mock_get_resolved: + mock_get_resolved.return_value = "/tmp/model/code" + + _create_or_update_code_dir( + model_dir="/tmp/model", + inference_script=None, + source_directory="/tmp/my_code", + dependencies=[], + sagemaker_session=None, + tmp="/tmp" + ) + + mock_validate_src.assert_called_once_with("/tmp/my_code") + mock_rmtree.assert_called_once() + mock_copytree.assert_called_once() + + def test_create_or_update_code_dir_with_sensitive_code_dir(self): + """Test that code_dir resolving to sensitive path is rejected""" + from sagemaker.utils import _create_or_update_code_dir + + with patch("sagemaker.utils._get_resolved_path") as mock_get_resolved: + # Simulate code_dir resolving to a sensitive path + mock_get_resolved.return_value = os.path.abspath(os.path.expanduser("~/.aws")) + + with pytest.raises(ValueError, match="Invalid code_dir path"): + _create_or_update_code_dir( + model_dir="/tmp/model", + inference_script="inference.py", + source_directory=None, + dependencies=[], + sagemaker_session=None, + tmp="/tmp" + ) From 9606bea6e79f46fce137f4a3def1f823db018011 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Thu, 18 Dec 2025 13:06:27 -0800 Subject: [PATCH 7/9] Adding root directory validation to other helpers --- src/sagemaker/local/data.py | 2 +- src/sagemaker/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/local/data.py b/src/sagemaker/local/data.py index b0bd7dc71d..10c8420559 100644 --- a/src/sagemaker/local/data.py +++ b/src/sagemaker/local/data.py @@ -126,7 +126,7 @@ def __init__(self, root_path): # Validate that the path is not in restricted locations for restricted_path in _SENSITIVE_SYSTEM_PATHS: - if self.root_path.startswith(restricted_path): + if self.root_path != "/" and self.root_path.startswith(restricted_path): raise ValueError( f"Local Mode does not support mounting from restricted system paths. " f"Got: {root_path}" diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 34bf961a7b..3de4b87c34 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -677,7 +677,7 @@ def _create_or_update_code_dir( # Validate that code_dir does not resolve to a sensitive system path for sensitive_path in _SENSITIVE_SYSTEM_PATHS: - if resolved_code_dir.startswith(sensitive_path): + if resolved_code_dir != "/" and resolved_code_dir.startswith(sensitive_path): raise ValueError( f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}" ) From 74f2b4f866be022d23e401ab097263e1b40514f3 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Thu, 18 Dec 2025 14:13:47 -0800 Subject: [PATCH 8/9] Fixing codestyle changes --- src/sagemaker/utils.py | 2 +- tests/unit/test_utils.py | 34 ++++++++++++++++++++++------------ 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 3de4b87c34..dce5f2517d 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -674,7 +674,7 @@ def _create_or_update_code_dir( """Placeholder docstring""" code_dir = os.path.join(model_dir, "code") resolved_code_dir = _get_resolved_path(code_dir) - + # Validate that code_dir does not resolve to a sensitive system path for sensitive_path in _SENSITIVE_SYSTEM_PATHS: if resolved_code_dir != "/" and resolved_code_dir.startswith(sensitive_path): diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index d5cb284ea1..1005219c62 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -2247,25 +2247,27 @@ def test_get_domain_for_region(self): self.assertEqual(get_domain_for_region("invalid-region"), "amazonaws.com") - class TestValidateSourceDirectory(TestCase): """Tests for _validate_source_directory function""" def test_validate_source_directory_with_s3_path(self): """S3 paths should be allowed""" from sagemaker.utils import _validate_source_directory + # Should not raise any exception _validate_source_directory("s3://my-bucket/my-prefix") def test_validate_source_directory_with_none(self): """None should be allowed""" from sagemaker.utils import _validate_source_directory + # Should not raise any exception _validate_source_directory(None) def test_validate_source_directory_with_safe_local_path(self): """Safe local paths should be allowed""" from sagemaker.utils import _validate_source_directory + # Should not raise any exception _validate_source_directory("/tmp/my_code") _validate_source_directory("./my_code") @@ -2274,18 +2276,21 @@ def test_validate_source_directory_with_safe_local_path(self): def test_validate_source_directory_with_sensitive_path_aws(self): """Paths under ~/.aws should be rejected""" from sagemaker.utils import _validate_source_directory + with pytest.raises(ValueError, match="cannot access sensitive system paths"): _validate_source_directory(os.path.expanduser("~/.aws/credentials")) def test_validate_source_directory_with_sensitive_path_ssh(self): """Paths under ~/.ssh should be rejected""" from sagemaker.utils import _validate_source_directory + with pytest.raises(ValueError, match="cannot access sensitive system paths"): _validate_source_directory(os.path.expanduser("~/.ssh/id_rsa")) def test_validate_source_directory_with_root_directory(self): """Root directory itself should be allowed (not rejected)""" from sagemaker.utils import _validate_source_directory + # Should not raise any exception - root directory is explicitly allowed _validate_source_directory("/") @@ -2296,12 +2301,14 @@ class TestValidateDependencyPath(TestCase): def test_validate_dependency_path_with_none(self): """None should be allowed""" from sagemaker.utils import _validate_dependency_path + # Should not raise any exception _validate_dependency_path(None) def test_validate_dependency_path_with_safe_local_path(self): """Safe local paths should be allowed""" from sagemaker.utils import _validate_dependency_path + # Should not raise any exception _validate_dependency_path("/tmp/my_lib") _validate_dependency_path("./my_lib") @@ -2310,18 +2317,21 @@ def test_validate_dependency_path_with_safe_local_path(self): def test_validate_dependency_path_with_sensitive_path_aws(self): """Paths under ~/.aws should be rejected""" from sagemaker.utils import _validate_dependency_path + with pytest.raises(ValueError, match="cannot access sensitive system paths"): _validate_dependency_path(os.path.expanduser("~/.aws")) def test_validate_dependency_path_with_sensitive_path_docker(self): """Paths under ~/.docker should be rejected""" from sagemaker.utils import _validate_dependency_path + with pytest.raises(ValueError, match="cannot access sensitive system paths"): _validate_dependency_path(os.path.expanduser("~/.docker/config.json")) def test_validate_dependency_path_with_root_directory(self): """Root directory itself should be allowed (not rejected)""" from sagemaker.utils import _validate_dependency_path + # Should not raise any exception - root directory is explicitly allowed _validate_dependency_path("/") @@ -2339,9 +2349,9 @@ def test_create_or_update_code_dir_with_inference_script( ): """Test creating code dir with inference script""" from sagemaker.utils import _create_or_update_code_dir - + mock_exists.return_value = False - + with patch("sagemaker.utils._get_resolved_path") as mock_get_resolved: mock_get_resolved.return_value = "/tmp/model/code" @@ -2351,9 +2361,9 @@ def test_create_or_update_code_dir_with_inference_script( source_directory=None, dependencies=[], sagemaker_session=None, - tmp="/tmp" + tmp="/tmp", ) - + mock_mkdir.assert_called() mock_copy.assert_called_once() @@ -2366,12 +2376,12 @@ def test_create_or_update_code_dir_with_source_directory( ): """Test creating code dir with source directory""" from sagemaker.utils import _create_or_update_code_dir - + mock_exists.return_value = True - + with patch("sagemaker.utils._get_resolved_path") as mock_get_resolved: mock_get_resolved.return_value = "/tmp/model/code" - + _create_or_update_code_dir( model_dir="/tmp/model", inference_script=None, @@ -2380,7 +2390,7 @@ def test_create_or_update_code_dir_with_source_directory( sagemaker_session=None, tmp="/tmp" ) - + mock_validate_src.assert_called_once_with("/tmp/my_code") mock_rmtree.assert_called_once() mock_copytree.assert_called_once() @@ -2388,11 +2398,11 @@ def test_create_or_update_code_dir_with_source_directory( def test_create_or_update_code_dir_with_sensitive_code_dir(self): """Test that code_dir resolving to sensitive path is rejected""" from sagemaker.utils import _create_or_update_code_dir - + with patch("sagemaker.utils._get_resolved_path") as mock_get_resolved: # Simulate code_dir resolving to a sensitive path mock_get_resolved.return_value = os.path.abspath(os.path.expanduser("~/.aws")) - + with pytest.raises(ValueError, match="Invalid code_dir path"): _create_or_update_code_dir( model_dir="/tmp/model", @@ -2400,5 +2410,5 @@ def test_create_or_update_code_dir_with_sensitive_code_dir(self): source_directory=None, dependencies=[], sagemaker_session=None, - tmp="/tmp" + tmp="/tmp", ) From 3353ff46830a23bc9ffd0c150ae5c6a36e7bc94a Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Thu, 18 Dec 2025 14:31:01 -0800 Subject: [PATCH 9/9] Fixes for missed codestyle changes --- tests/unit/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 1005219c62..91e96e157b 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -2354,7 +2354,7 @@ def test_create_or_update_code_dir_with_inference_script( with patch("sagemaker.utils._get_resolved_path") as mock_get_resolved: mock_get_resolved.return_value = "/tmp/model/code" - + _create_or_update_code_dir( model_dir="/tmp/model", inference_script="inference.py", @@ -2388,7 +2388,7 @@ def test_create_or_update_code_dir_with_source_directory( source_directory="/tmp/my_code", dependencies=[], sagemaker_session=None, - tmp="/tmp" + tmp="/tmp", ) mock_validate_src.assert_called_once_with("/tmp/my_code")