Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/sagemaker/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io

from sagemaker.exceptions import ModelStreamError, InternalStreamFailure
from sagemaker.utils import _MAX_BUFFER_SIZE


def handle_stream_errors(chunk):
Expand Down Expand Up @@ -182,5 +183,15 @@ def __next__(self):
# print and move on to next response byte
print("Unknown event type:" + chunk)
continue

# Check buffer size before writing to prevent unbounded memory consumption
chunk_size = len(chunk["PayloadPart"]["Bytes"])
current_size = self.buffer.getbuffer().nbytes
if current_size + chunk_size > _MAX_BUFFER_SIZE:
raise RuntimeError(
f"Line buffer exceeded maximum size of {_MAX_BUFFER_SIZE} bytes. "
f"No newline found in stream."
)

self.buffer.seek(0, io.SEEK_END)
self.buffer.write(chunk["PayloadPart"]["Bytes"])
10 changes: 10 additions & 0 deletions src/sagemaker/local/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import sagemaker.amazon.common
import sagemaker.local.utils
import sagemaker.utils
from sagemaker.utils import _SENSITIVE_SYSTEM_PATHS


def get_data_source_instance(data_source, sagemaker_session):
Expand Down Expand Up @@ -122,6 +123,15 @@ def __init__(self, root_path):
super(LocalFileDataSource, self).__init__()

self.root_path = os.path.abspath(root_path)

# Validate that the path is not in restricted locations
for restricted_path in _SENSITIVE_SYSTEM_PATHS:
if self.root_path != "/" and self.root_path.startswith(restricted_path):
raise ValueError(
f"Local Mode does not support mounting from restricted system paths. "
f"Got: {root_path}"
)

if not os.path.exists(self.root_path):
raise RuntimeError("Invalid data source: %s does not exist." % self.root_path)

Expand Down
5 changes: 1 addition & 4 deletions src/sagemaker/local/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,7 @@ def copy_directory_structure(destination_directory, relative_path):
destination_directory
"""
full_path = os.path.join(destination_directory, relative_path)
if os.path.exists(full_path):
return

os.makedirs(destination_directory, relative_path)
os.makedirs(full_path, exist_ok=True)


def move_to_destination(source, destination, job_name, sagemaker_session, prefix=""):
Expand Down
114 changes: 114 additions & 0 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,20 @@
WAITING_DOT_NUMBER = 10
MAX_ITEMS = 100
PAGE_SIZE = 10
_MAX_BUFFER_SIZE = 100 * 1024 * 1024  # 100 MB - Maximum buffer size for streaming iterators

# Locations that the path validators refuse to read from or mount.
#
# Every location is listed in both its literal absolute form and its
# symlink-resolved form. Some callers compare a realpath()-resolved candidate
# against this list and others compare an abspath()-only candidate; keeping
# both forms lets either kind of caller match even when the home directory or
# a system directory is itself a symlink (for example /etc -> /private/etc).
# dict.fromkeys() removes duplicates while preserving declaration order, so on
# a system without such symlinks the list is the same as a plain enumeration.
_SENSITIVE_SYSTEM_PATHS = list(
    dict.fromkeys(
        candidate
        for location in (
            os.path.expanduser("~/.aws"),
            os.path.expanduser("~/.ssh"),
            os.path.expanduser("~/.kube"),
            os.path.expanduser("~/.docker"),
            os.path.expanduser("~/.config"),
            os.path.expanduser("~/.credentials"),
            "/etc",
            "/root",
            "/var/lib",
            "/opt/ml/metadata",
        )
        for candidate in (abspath(location), abspath(realpath(location)))
    )
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -601,11 +615,73 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))


def _validate_source_directory(source_directory):
    """Validate that source_directory is safe to use.

    Ensures the source directory path does not point at, or inside, one of the
    locations listed in ``_SENSITIVE_SYSTEM_PATHS``. Symlinks are resolved
    before the comparison.

    Args:
        source_directory (str): The source directory path to validate. ``None``,
            an empty string, or an ``s3://`` URI is accepted without checks.

    Raises:
        ValueError: If the resolved path is a sensitive path or lies beneath one.
    """
    if not source_directory or source_directory.lower().startswith("s3://"):
        # S3 paths and None are safe
        return

    # Resolve symlinks so a link that points into a sensitive directory is caught.
    abs_source = abspath(realpath(source_directory))
    if abs_source == "/":
        # Kept from the original check: the filesystem root itself is not rejected.
        return

    for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
        # Match on a path-component boundary. A bare startswith() would also
        # reject unrelated siblings such as "/etcetera" or "/rootfs".
        boundary = sensitive_path.rstrip(os.sep) + os.sep
        if abs_source == sensitive_path or abs_source.startswith(boundary):
            raise ValueError(
                "source_directory cannot access sensitive system paths. "
                f"Got: {source_directory} (resolved to {abs_source})"
            )


def _validate_dependency_path(dependency):
    """Validate that a dependency path is safe to use.

    Ensures the dependency path does not point at, or inside, one of the
    locations listed in ``_SENSITIVE_SYSTEM_PATHS``. Symlinks are resolved
    before the comparison.

    Args:
        dependency (str): The dependency path to validate. ``None`` or an empty
            string is accepted without checks.

    Raises:
        ValueError: If the resolved path is a sensitive path or lies beneath one.
    """
    if not dependency:
        return

    # Resolve symlinks so a link that points into a sensitive directory is caught.
    abs_dependency = abspath(realpath(dependency))
    if abs_dependency == "/":
        # Kept from the original check: the filesystem root itself is not rejected.
        return

    for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
        # Match on a path-component boundary. A bare startswith() would also
        # reject unrelated siblings such as "/etcetera" or "/rootfs".
        boundary = sensitive_path.rstrip(os.sep) + os.sep
        if abs_dependency == sensitive_path or abs_dependency.startswith(boundary):
            raise ValueError(
                "dependency path cannot access sensitive system paths. "
                f"Got: {dependency} (resolved to {abs_dependency})"
            )


def _create_or_update_code_dir(
model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp
):
"""Placeholder docstring"""
code_dir = os.path.join(model_dir, "code")
resolved_code_dir = _get_resolved_path(code_dir)

# Validate that code_dir does not resolve to a sensitive system path
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
if resolved_code_dir != "/" and resolved_code_dir.startswith(sensitive_path):
raise ValueError(
f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}"
)

if source_directory and source_directory.lower().startswith("s3://"):
local_code_path = os.path.join(tmp, "local_code.tar.gz")
download_file_from_url(source_directory, local_code_path, sagemaker_session)
Expand All @@ -614,6 +690,8 @@ def _create_or_update_code_dir(
custom_extractall_tarfile(t, code_dir)

elif source_directory:
# Validate source_directory for security
_validate_source_directory(source_directory)
if os.path.exists(code_dir):
shutil.rmtree(code_dir)
shutil.copytree(source_directory, code_dir)
Expand Down Expand Up @@ -646,6 +724,8 @@ def _create_or_update_code_dir(
)

for dependency in dependencies:
# Validate dependency path for security
_validate_dependency_path(dependency)
lib_dir = os.path.join(code_dir, "lib")
if os.path.isdir(dependency):
shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency)))
Expand Down Expand Up @@ -1620,6 +1700,38 @@ def _get_safe_members(members):
yield file_info


def _validate_extracted_paths(extract_path):
    """Validate that extracted paths remain within the expected directory.

    Walks ``extract_path`` after extraction and checks that every directory and
    file entry, once resolved, is the extraction directory itself or lies
    beneath it.

    Args:
        extract_path (str): The path where files were extracted.

    Raises:
        ValueError: If any extracted entry resolves outside the extraction path.
    """
    # NOTE(review): assumes _get_resolved_path returns an absolute, normalized,
    # symlink-resolved path -- confirm against its definition.
    base = _get_resolved_path(extract_path)
    # Compare against "<base>/" rather than "<base>": a bare prefix test would
    # treat a sibling such as "<base>_evil" as if it were inside the directory.
    base_prefix = base.rstrip(os.sep) + os.sep

    for root, dirs, files in os.walk(extract_path):
        # Directories and files get the same containment check; only the log
        # message differs.
        checks = (
            (dirs, "Extracted directory escaped extraction path: %s"),
            (files, "Extracted file escaped extraction path: %s"),
        )
        for names, log_message in checks:
            for name in names:
                entry_path = os.path.join(root, name)
                resolved = _get_resolved_path(entry_path)
                if resolved != base and not resolved.startswith(base_prefix):
                    logger.error(log_message, entry_path)
                    raise ValueError(
                        f"Extracted path outside expected directory: {entry_path}"
                    )


def custom_extractall_tarfile(tar, extract_path):
"""Extract a tarfile, optionally using data_filter if available.

Expand All @@ -1640,6 +1752,8 @@ def custom_extractall_tarfile(tar, extract_path):
tar.extractall(path=extract_path, filter="data")
else:
tar.extractall(path=extract_path, members=_get_safe_members(tar))
# Re-validate extracted paths to catch symlink race conditions
_validate_extracted_paths(extract_path)


def can_model_package_source_uri_autopopulate(source_uri: str):
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/sagemaker/local/test_local_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
@patch("sagemaker.local.utils.os.path")
@patch("sagemaker.local.utils.os")
def test_copy_directory_structure(m_os, m_os_path):
# Exercises copy_directory_structure with os and os.path patched, then asserts
# on the arguments passed to os.makedirs.
# NOTE(review): this span looks like a flattened diff hunk, so pre-change and
# post-change lines may be interleaved here -- confirm against the real test file.
m_os_path.exists.return_value = False
m_os_path.join.return_value = "/tmp/code/"
sagemaker.local.utils.copy_directory_structure("/tmp/", "code/")
# NOTE(review): the two assertions below expect different makedirs calls and
# cannot both hold; presumably the first is the removed line and the second is
# its replacement -- confirm.
m_os.makedirs.assert_called_with("/tmp/", "code/")
m_os.makedirs.assert_called_with("/tmp/code/", exist_ok=True)


@patch("shutil.rmtree", Mock())
Expand Down
Loading
Loading