Skip to content

Commit a74f9ab

Browse files
authored
Add input validation and resource management improvements V3 (#5418)
* Add input validation and resource management improvements V3 * Allowing for sym-links, better refactoring * Removing home path and adding additional validation * Including check for root directory * Adding root directory validation to other helpers
1 parent 1c7faf0 commit a74f9ab

File tree

6 files changed

+428
-6
lines changed

6 files changed

+428
-6
lines changed

sagemaker-core/src/sagemaker/core/common_utils.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,20 @@
7575
WAITING_DOT_NUMBER = 10
7676
MAX_ITEMS = 100
7777
PAGE_SIZE = 10
78+
_MAX_BUFFER_SIZE = 100 * 1024 * 1024 # 100 MB - Maximum buffer size for streaming iterators
79+
80+
_SENSITIVE_SYSTEM_PATHS = [
81+
abspath(os.path.expanduser("~/.aws")),
82+
abspath(os.path.expanduser("~/.ssh")),
83+
abspath(os.path.expanduser("~/.kube")),
84+
abspath(os.path.expanduser("~/.docker")),
85+
abspath(os.path.expanduser("~/.config")),
86+
abspath(os.path.expanduser("~/.credentials")),
87+
"/etc",
88+
"/root",
89+
"/var/lib",
90+
"/opt/ml/metadata",
91+
]
7892

7993
logger = logging.getLogger(__name__)
8094

@@ -608,11 +622,73 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
608622
shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))
609623

610624

625+
def _validate_source_directory(source_directory):
    """Validate that source_directory is safe to use.

    Ensures the source directory does not resolve to, or live beneath, any entry
    of ``_SENSITIVE_SYSTEM_PATHS``. Falsy values and S3 URIs are accepted without
    further checks.

    Args:
        source_directory (str): The source directory path to validate.

    Raises:
        ValueError: If the resolved path equals or is nested under a sensitive path.
    """
    if not source_directory or source_directory.lower().startswith("s3://"):
        # S3 paths and None are safe
        return

    # Resolve symlinks to get the actual path
    abs_source = abspath(realpath(source_directory))

    # The filesystem root itself is never treated as a sensitive location
    if abs_source == "/":
        return

    for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
        # Match on a path-component boundary: a bare string-prefix test would also
        # reject unrelated siblings such as "/etcetera" or "/var/library".
        sensitive_root = sensitive_path.rstrip(os.sep)
        if abs_source == sensitive_root or abs_source.startswith(sensitive_root + os.sep):
            raise ValueError(
                f"source_directory cannot access sensitive system paths. "
                f"Got: {source_directory} (resolved to {abs_source})"
            )
650+
651+
652+
def _validate_dependency_path(dependency):
    """Validate that a dependency path is safe to use.

    Ensures the dependency does not resolve to, or live beneath, any entry of
    ``_SENSITIVE_SYSTEM_PATHS``. Falsy values are accepted without further checks.

    Args:
        dependency (str): The dependency path to validate.

    Raises:
        ValueError: If the resolved path equals or is nested under a sensitive path.
    """
    if not dependency:
        return

    # Resolve symlinks to get the actual path
    abs_dependency = abspath(realpath(dependency))

    # The filesystem root itself is never treated as a sensitive location
    if abs_dependency == "/":
        return

    for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
        # Match on a path-component boundary: a bare string-prefix test would also
        # reject unrelated siblings such as "/etcetera" or "/var/library".
        sensitive_root = sensitive_path.rstrip(os.sep)
        if abs_dependency == sensitive_root or abs_dependency.startswith(
            sensitive_root + os.sep
        ):
            raise ValueError(
                f"dependency path cannot access sensitive system paths. "
                f"Got: {dependency} (resolved to {abs_dependency})"
            )
676+
677+
611678
def _create_or_update_code_dir(
612679
model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp
613680
):
614681
"""Placeholder docstring"""
615682
code_dir = os.path.join(model_dir, "code")
683+
resolved_code_dir = _get_resolved_path(code_dir)
684+
685+
# Validate that code_dir does not resolve to a sensitive system path
686+
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
687+
if resolved_code_dir != "/" and resolved_code_dir.startswith(sensitive_path):
688+
raise ValueError(
689+
f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}"
690+
)
691+
616692
if source_directory and source_directory.lower().startswith("s3://"):
617693
local_code_path = os.path.join(tmp, "local_code.tar.gz")
618694
download_file_from_url(source_directory, local_code_path, sagemaker_session)
@@ -621,6 +697,8 @@ def _create_or_update_code_dir(
621697
custom_extractall_tarfile(t, code_dir)
622698

623699
elif source_directory:
700+
# Validate source_directory for security
701+
_validate_source_directory(source_directory)
624702
if os.path.exists(code_dir):
625703
shutil.rmtree(code_dir)
626704
shutil.copytree(source_directory, code_dir)
@@ -636,6 +714,8 @@ def _create_or_update_code_dir(
636714
raise
637715

638716
for dependency in dependencies:
717+
# Validate dependency path for security
718+
_validate_dependency_path(dependency)
639719
lib_dir = os.path.join(code_dir, "lib")
640720
if os.path.isdir(dependency):
641721
shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency)))
@@ -1647,6 +1727,38 @@ def _get_safe_members(members):
16471727
yield file_info
16481728

16491729

1730+
def _validate_extracted_paths(extract_path):
    """Validate that extracted paths remain within the expected directory.

    Performs post-extraction validation by walking ``extract_path`` and checking
    that every directory and file entry, once resolved with
    ``_get_resolved_path``, is the extraction root itself or lies beneath it.

    Args:
        extract_path (str): The path where files were extracted.

    Raises:
        ValueError: If any extracted file or directory resolves outside the
            expected extraction path.
    """
    base = _get_resolved_path(extract_path)
    # Anchor the prefix on a separator so a sibling that merely shares the
    # prefix (e.g. "<base>_other/...") is not mistaken for a descendant.
    base_prefix = base.rstrip(os.sep) + os.sep

    for root, dirs, files in os.walk(extract_path):
        # Directories are checked before files, matching the original order
        for kind, names in (("directory", dirs), ("file", files)):
            for name in names:
                entry_path = os.path.join(root, name)
                resolved = _get_resolved_path(entry_path)
                if resolved != base and not resolved.startswith(base_prefix):
                    logger.error("Extracted %s escaped extraction path: %s", kind, entry_path)
                    raise ValueError(
                        f"Extracted path outside expected directory: {entry_path}"
                    )
1760+
1761+
16501762
def custom_extractall_tarfile(tar, extract_path):
16511763
"""Extract a tarfile, optionally using data_filter if available.
16521764
@@ -1667,6 +1779,8 @@ def custom_extractall_tarfile(tar, extract_path):
16671779
tar.extractall(path=extract_path, filter="data")
16681780
else:
16691781
tar.extractall(path=extract_path, members=_get_safe_members(tar))
1782+
# Re-validate extracted paths to catch symlink race conditions
1783+
_validate_extracted_paths(extract_path)
16701784

16711785

16721786
def can_model_package_source_uri_autopopulate(source_uri: str):

sagemaker-core/src/sagemaker/core/iterators.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import io
1818

1919
from sagemaker.core.exceptions import ModelStreamError, InternalStreamFailure
20+
from sagemaker.core.common_utils import _MAX_BUFFER_SIZE
2021

2122

2223
def handle_stream_errors(chunk):
@@ -182,5 +183,15 @@ def __next__(self):
182183
# print and move on to next response byte
183184
print("Unknown event type:" + chunk)
184185
continue
186+
187+
# Check buffer size before writing to prevent unbounded memory consumption
188+
chunk_size = len(chunk["PayloadPart"]["Bytes"])
189+
current_size = self.buffer.getbuffer().nbytes
190+
if current_size + chunk_size > _MAX_BUFFER_SIZE:
191+
raise RuntimeError(
192+
f"Line buffer exceeded maximum size of {_MAX_BUFFER_SIZE} bytes. "
193+
f"No newline found in stream."
194+
)
195+
185196
self.buffer.seek(0, io.SEEK_END)
186197
self.buffer.write(chunk["PayloadPart"]["Bytes"])

sagemaker-core/src/sagemaker/core/local/data.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from six.moves.urllib.parse import urlparse
2525

2626
import sagemaker.core
27+
from sagemaker.core.common_utils import _SENSITIVE_SYSTEM_PATHS
2728

2829

2930
def get_data_source_instance(data_source, sagemaker_session):
@@ -120,6 +121,15 @@ def __init__(self, root_path):
120121
super(LocalFileDataSource, self).__init__()
121122

122123
self.root_path = os.path.abspath(root_path)
124+
125+
# Validate that the path is not in restricted locations
126+
for restricted_path in _SENSITIVE_SYSTEM_PATHS:
127+
if self.root_path != "/" and self.root_path.startswith(restricted_path):
128+
raise ValueError(
129+
f"Local Mode does not support mounting from restricted system paths. "
130+
f"Got: {root_path}"
131+
)
132+
123133
if not os.path.exists(self.root_path):
124134
raise RuntimeError("Invalid data source: %s does not exist." % self.root_path)
125135

sagemaker-core/src/sagemaker/core/local/utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,7 @@ def copy_directory_structure(destination_directory, relative_path):
4848
destination_directory
4949
"""
5050
full_path = os.path.join(destination_directory, relative_path)
51-
if os.path.exists(full_path):
52-
return
53-
54-
os.makedirs(destination_directory, relative_path)
51+
os.makedirs(full_path, exist_ok=True)
5552

5653

5754
def move_to_destination(source, destination, job_name, sagemaker_session, prefix=""):

sagemaker-core/tests/unit/local/test_local_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@
3535
@patch("sagemaker.core.local.utils.os.path")
3636
@patch("sagemaker.core.local.utils.os")
3737
def test_copy_directory_structure(m_os, m_os_path):
38-
m_os_path.exists.return_value = False
38+
m_os_path.join.return_value = "/tmp/code/"
3939
copy_directory_structure("/tmp/", "code/")
40-
m_os.makedirs.assert_called_with("/tmp/", "code/")
40+
m_os.makedirs.assert_called_with("/tmp/code/", exist_ok=True)
4141

4242

4343
@patch("shutil.rmtree", Mock())

0 commit comments

Comments
 (0)