@@ -62,7 +62,7 @@ def get_hyperparams(log_dir, timestamp):
6262 "patch_size" : values .get ('patch_size' ), "embed_dim" : values .get ('embed_dim' )}
6363 return extract_hyperparams
6464
65- def main ():
65+ def main_run ():
6666 log_dir = 'logs'
6767 timestamp = get_latest_timestamp (os .path .join (log_dir , 'train/multiruns' ))
6868
@@ -177,5 +177,39 @@ def main():
177177 with open ('best_model_checkpoint.txt' , 'w' ) as f :
178178 f .write (f"./model_storage/epoch-checkpoint_patch_size-{ hparams_data ['best_params' ]['model.patch_size' ]} _embed_dim-{ hparams_data ['best_params' ]['model.embed_dim' ]} .ckpt" )
179179
180+ import shutil
181+
182+ # Define the source file and destination folder
183+ source_file = 'best_model_checkpoint.txt'
184+ destination_folder = 'model_storage/'
185+
186+ # Copy the file to the destination folder
187+ shutil .copy (source_file , destination_folder )
188+
189+ print (f"{ source_file } has been copied to { destination_folder } " )
190+
191+
192+ # Define the path to the checkpoint file and folder containing .ckpt files
193+ checkpoint_file = 'best_model_checkpoint.txt'
194+ checkpoint_folder = 'model_storage'
195+
196+ # Read the first line of the checkpoint file to get the file to keep
197+ with open (checkpoint_file , 'r' ) as f :
198+ keep_file = f .readline ().strip ()
199+
200+ # Get the full path of the file to keep
201+ keep_file_path = os .path .join (checkpoint_folder , os .path .basename (keep_file ))
202+
203+ # Iterate over files in the checkpoint folder and delete unwanted .ckpt files
204+ for file in os .listdir (checkpoint_folder ):
205+ file_path = os .path .join (checkpoint_folder , file )
206+ if file_path .endswith ('.ckpt' ) and file_path != keep_file_path :
207+ os .remove (file_path )
208+ print (f"Removed: { file_path } " )
209+
210+ print (f"Kept: { keep_file_path } " )
211+
212+
213+
180214if __name__ == "__main__" :
181- main ()
215+ main_run ()
0 commit comments