WAITING_DOT_NUMBER = 10
MAX_ITEMS = 100
PAGE_SIZE = 10

# Maximum buffer size for streaming iterators: 100 MB, expressed in bytes.
_MAX_BUFFER_SIZE = 100 * 1024 * 1024

# Directories that caller-supplied local paths are not allowed to resolve into.
# The per-user entries are expanded against the current user's home directory
# once, at import time; the remaining entries are fixed absolute locations.
_SENSITIVE_SYSTEM_PATHS = [
    abspath(os.path.expanduser(home_entry))
    for home_entry in (
        "~/.aws",
        "~/.ssh",
        "~/.kube",
        "~/.docker",
        "~/.config",
        "~/.credentials",
    )
] + [
    "/etc",
    "/root",
    "/var/lib",
    "/opt/ml/metadata",
]

logger = logging.getLogger(__name__)
8094
@@ -608,11 +622,73 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
608622 shutil .move (tmp_model_path , repacked_model_uri .replace ("file://" , "" ))
609623
610624
def _validate_source_directory(source_directory):
    """Validate that source_directory is safe to use.

    Ensures the source directory path does not resolve to, or into, one of the
    directories listed in ``_SENSITIVE_SYSTEM_PATHS``. Symlinks are resolved on
    both sides of the comparison, and matching is done on whole path
    components, so ``/etcetera`` is not mistaken for something under ``/etc``.

    Args:
        source_directory (str): The source directory path to validate. Empty
            values and S3 URIs (``s3://...``, case-insensitive) are accepted
            without further checks.

    Raises:
        ValueError: If the resolved path is a sensitive directory or lies
            inside one.
    """
    if not source_directory or source_directory.lower().startswith("s3://"):
        # Nothing local to inspect.
        return

    # Resolve symlinks so a link pointing into a sensitive directory is caught.
    abs_source = abspath(realpath(source_directory))

    for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
        # Resolve the sensitive root as well; otherwise a symlinked root
        # (e.g. a relocated home directory) could never match a resolved source.
        sensitive_root = abspath(realpath(sensitive_path))
        # Appending the separator forces the match to end on a component
        # boundary instead of being a raw string-prefix comparison.
        root_prefix = sensitive_root.rstrip(os.sep) + os.sep
        if abs_source == sensitive_root or abs_source.startswith(root_prefix):
            raise ValueError(
                f"source_directory cannot access sensitive system paths. "
                f"Got: {source_directory} (resolved to {abs_source})"
            )
650+
651+
def _validate_dependency_path(dependency):
    """Validate that a dependency path is safe to use.

    Ensures the dependency path does not resolve to, or into, one of the
    directories listed in ``_SENSITIVE_SYSTEM_PATHS``. Symlinks are resolved on
    both sides of the comparison, and matching is done on whole path
    components, so ``/var/library`` is not mistaken for something under
    ``/var/lib``.

    Args:
        dependency (str): The dependency path to validate. Empty values are
            accepted without further checks.

    Raises:
        ValueError: If the resolved path is a sensitive directory or lies
            inside one.
    """
    if not dependency:
        return

    # Resolve symlinks so a link pointing into a sensitive directory is caught.
    abs_dependency = abspath(realpath(dependency))

    for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
        # Resolve the sensitive root as well; otherwise a symlinked root could
        # never match a resolved dependency path.
        sensitive_root = abspath(realpath(sensitive_path))
        # Appending the separator forces the match to end on a component
        # boundary instead of being a raw string-prefix comparison.
        root_prefix = sensitive_root.rstrip(os.sep) + os.sep
        if abs_dependency == sensitive_root or abs_dependency.startswith(root_prefix):
            raise ValueError(
                f"dependency path cannot access sensitive system paths. "
                f"Got: {dependency} (resolved to {abs_dependency})"
            )
676+
677+
611678def _create_or_update_code_dir (
612679 model_dir , inference_script , source_directory , dependencies , sagemaker_session , tmp
613680):
614681 """Placeholder docstring"""
615682 code_dir = os .path .join (model_dir , "code" )
683+ resolved_code_dir = _get_resolved_path (code_dir )
684+
685+ # Validate that code_dir does not resolve to a sensitive system path
686+ for sensitive_path in _SENSITIVE_SYSTEM_PATHS :
687+ if resolved_code_dir != "/" and resolved_code_dir .startswith (sensitive_path ):
688+ raise ValueError (
689+ f"Invalid code_dir path: { code_dir } resolves to sensitive system path { resolved_code_dir } "
690+ )
691+
616692 if source_directory and source_directory .lower ().startswith ("s3://" ):
617693 local_code_path = os .path .join (tmp , "local_code.tar.gz" )
618694 download_file_from_url (source_directory , local_code_path , sagemaker_session )
@@ -621,6 +697,8 @@ def _create_or_update_code_dir(
621697 custom_extractall_tarfile (t , code_dir )
622698
623699 elif source_directory :
700+ # Validate source_directory for security
701+ _validate_source_directory (source_directory )
624702 if os .path .exists (code_dir ):
625703 shutil .rmtree (code_dir )
626704 shutil .copytree (source_directory , code_dir )
@@ -636,6 +714,8 @@ def _create_or_update_code_dir(
636714 raise
637715
638716 for dependency in dependencies :
717+ # Validate dependency path for security
718+ _validate_dependency_path (dependency )
639719 lib_dir = os .path .join (code_dir , "lib" )
640720 if os .path .isdir (dependency ):
641721 shutil .copytree (dependency , os .path .join (lib_dir , os .path .basename (dependency )))
@@ -1647,6 +1727,38 @@ def _get_safe_members(members):
16471727 yield file_info
16481728
16491729
def _validate_extracted_paths(extract_path):
    """Validate that extracted paths remain within the expected directory.

    Walks ``extract_path`` after extraction and resolves every directory and
    file entry with ``_get_resolved_path``. An entry is accepted only when its
    resolved location is the extraction directory itself or lies beneath it.
    Containment is checked on a path-component boundary, so a sibling such as
    ``<extract_path>_other`` is correctly treated as outside.

    Args:
        extract_path (str): The path where files were extracted.

    Raises:
        ValueError: If any extracted file or directory resolves to a location
            outside the expected extraction path.
    """
    base = _get_resolved_path(extract_path)
    # A bare ``startswith(base)`` would also accept "<base>_suffix/..."; the
    # trailing separator restricts matches to real descendants of ``base``.
    base_prefix = base.rstrip(os.sep) + os.sep

    for root, dirs, files in os.walk(extract_path):
        # Directories are checked before files within each level; the label
        # only selects the wording of the log message.
        for label, names in (("directory", dirs), ("file", files)):
            for name in names:
                entry_path = os.path.join(root, name)
                resolved = _get_resolved_path(entry_path)
                if resolved != base and not resolved.startswith(base_prefix):
                    logger.error(
                        "Extracted %s escaped extraction path: %s", label, entry_path
                    )
                    raise ValueError(
                        f"Extracted path outside expected directory: {entry_path}"
                    )
1760+
1761+
16501762def custom_extractall_tarfile (tar , extract_path ):
16511763 """Extract a tarfile, optionally using data_filter if available.
16521764
@@ -1667,6 +1779,8 @@ def custom_extractall_tarfile(tar, extract_path):
16671779 tar .extractall (path = extract_path , filter = "data" )
16681780 else :
16691781 tar .extractall (path = extract_path , members = _get_safe_members (tar ))
1782+ # Re-validate extracted paths to catch symlink race conditions
1783+ _validate_extracted_paths (extract_path )
16701784
16711785
16721786def can_model_package_source_uri_autopopulate (source_uri : str ):
0 commit comments