@@ -125,7 +125,8 @@ class ModelTrainer(BaseModel):
125125 from sagemaker.train import ModelTrainer
126126 from sagemaker.train.configs import SourceCode, Compute, InputData
127127
128- source_code = SourceCode(source_dir="source", entry_script="train.py")
128+ ignore_patterns = ['.env', '.git', '__pycache__', '.DS_Store', 'data']
129+ source_code = SourceCode(source_dir="source", entry_script="train.py", ignore_patterns=ignore_patterns)
129130 training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image"
130131 model_trainer = ModelTrainer(
131132 training_image=training_image,
@@ -612,6 +613,7 @@ def train(
612613 channel_name = SM_CODE ,
613614 data_source = self .source_code .source_dir ,
614615 key_prefix = input_data_key_prefix ,
616+ ignore_patterns = self .source_code .ignore_patterns ,
615617 )
616618 final_input_data_config .append (source_code_channel )
617619
@@ -633,6 +635,7 @@ def train(
633635 channel_name = SM_DRIVERS ,
634636 data_source = tmp_dir .name ,
635637 key_prefix = input_data_key_prefix ,
638+ ignore_patterns = self .source_code .ignore_patterns ,
636639 )
637640 final_input_data_config .append (sm_drivers_channel )
638641
@@ -742,7 +745,11 @@ def train(
742745 local_container .train (wait )
743746
744747 def create_input_data_channel (
745- self , channel_name : str , data_source : DataSourceType , key_prefix : Optional [str ] = None
748+ self ,
749+ channel_name : str ,
750+ data_source : DataSourceType ,
751+ key_prefix : Optional [str ] = None ,
752+ ignore_patterns : Optional [List [str ]] = None ,
746753 ) -> Channel :
747754 """Create an input data channel for the training job.
748755
@@ -758,6 +765,9 @@ def create_input_data_channel(
758765
759766 If specified, local data will be uploaded to:
760767 ``s3://<default_bucket_path>/<key_prefix>/<channel_name>/``
768+ ignore_patterns: (Optional[List[str]]) :
769+ The ignore patterns to ignore specific files/folders when uploading to S3.
770+ If not specified, default to: ['.env', '.git', '__pycache__', '.DS_Store', '.cache', '.ipynb_checkpoints'].
761771 """
762772 from sagemaker .core .helper .pipeline_variable import PipelineVariable
763773
@@ -807,11 +817,28 @@ def create_input_data_channel(
807817 )
808818 if self .sagemaker_session .default_bucket_prefix :
809819 key_prefix = f"{ self .sagemaker_session .default_bucket_prefix } /{ key_prefix } "
810- s3_uri = self .sagemaker_session .upload_data (
811- path = data_source ,
812- bucket = self .sagemaker_session .default_bucket (),
813- key_prefix = key_prefix ,
814- )
820+ if ignore_patterns and _is_valid_path (data_source , path_type = "Directory" ):
821+ tmp_dir = TemporaryDirectory ()
822+ copied_path = os .path .join (
823+ tmp_dir .name , os .path .basename (os .path .normpath (data_source ))
824+ )
825+ shutil .copytree (
826+ data_source ,
827+ copied_path ,
828+ dirs_exist_ok = True ,
829+ ignore = shutil .ignore_patterns (* ignore_patterns ),
830+ )
831+ s3_uri = self .sagemaker_session .upload_data (
832+ path = copied_path ,
833+ bucket = self .sagemaker_session .default_bucket (),
834+ key_prefix = key_prefix ,
835+ )
836+ else :
837+ s3_uri = self .sagemaker_session .upload_data (
838+ path = data_source ,
839+ bucket = self .sagemaker_session .default_bucket (),
840+ key_prefix = key_prefix ,
841+ )
815842 channel = Channel (
816843 channel_name = channel_name ,
817844 data_source = DataSource (
0 commit comments