@@ -254,6 +254,7 @@ def get_cached_module_file(
254254 token : Optional [Union [bool , str ]] = None ,
255255 revision : Optional [str ] = None ,
256256 local_files_only : bool = False ,
257+ local_dir : Optional [str ] = None ,
257258):
258259 """
259260 Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
@@ -333,6 +334,7 @@ def get_cached_module_file(
333334 force_download = force_download ,
334335 proxies = proxies ,
335336 local_files_only = local_files_only ,
337+ local_dir = local_dir ,
336338 )
337339 submodule = "git"
338340 module_file = pretrained_model_name_or_path + ".py"
@@ -356,6 +358,7 @@ def get_cached_module_file(
356358 force_download = force_download ,
357359 proxies = proxies ,
358360 local_files_only = local_files_only ,
361+ local_dir = local_dir ,
359362 token = token ,
360363 )
361364 submodule = os .path .join ("local" , "--" .join (pretrained_model_name_or_path .split ("/" )))
@@ -416,6 +419,7 @@ def get_cached_module_file(
416419 token = token ,
417420 revision = revision ,
418421 local_files_only = local_files_only ,
422+ local_dir = local_dir ,
419423 )
420424 return os .path .join (full_submodule , module_file )
421425
@@ -432,7 +436,7 @@ def get_class_from_dynamic_module(
432436 token : Optional [Union [bool , str ]] = None ,
433437 revision : Optional [str ] = None ,
434438 local_files_only : bool = False ,
435- ** kwargs ,
439+ local_dir : Optional [ str ] = None ,
436440):
437441 """
438442 Extracts a class from a module file, present in the local folder or repository of a model.
@@ -497,5 +501,6 @@ def get_class_from_dynamic_module(
497501 token = token ,
498502 revision = revision ,
499503 local_files_only = local_files_only ,
504+ local_dir = local_dir ,
500505 )
501506 return get_class_in_module (class_name , final_module )
0 commit comments