Skip to content

Commit ea4f29f

Browse files
authored
Merge branch 'main' into custom-modular-tests
2 parents b8809f7 + d54622c commit ea4f29f

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def from_pretrained(
305305
"cache_dir",
306306
"force_download",
307307
"local_files_only",
308+
"local_dir",
308309
"proxies",
309310
"resume_download",
310311
"revision",
@@ -331,11 +332,10 @@ def from_pretrained(
331332
module_file=module_file,
332333
class_name=class_name,
333334
**hub_kwargs,
334-
**kwargs,
335335
)
336336
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
337337
block_kwargs = {
338-
name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
338+
name: kwargs.get(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
339339
}
340340

341341
return block_cls(**block_kwargs)

src/diffusers/utils/dynamic_modules_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def get_cached_module_file(
254254
token: Optional[Union[bool, str]] = None,
255255
revision: Optional[str] = None,
256256
local_files_only: bool = False,
257+
local_dir: Optional[str] = None,
257258
):
258259
"""
259260
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
@@ -333,6 +334,7 @@ def get_cached_module_file(
333334
force_download=force_download,
334335
proxies=proxies,
335336
local_files_only=local_files_only,
337+
local_dir=local_dir,
336338
)
337339
submodule = "git"
338340
module_file = pretrained_model_name_or_path + ".py"
@@ -356,6 +358,7 @@ def get_cached_module_file(
356358
force_download=force_download,
357359
proxies=proxies,
358360
local_files_only=local_files_only,
361+
local_dir=local_dir,
359362
token=token,
360363
)
361364
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
@@ -416,6 +419,7 @@ def get_cached_module_file(
416419
token=token,
417420
revision=revision,
418421
local_files_only=local_files_only,
422+
local_dir=local_dir,
419423
)
420424
return os.path.join(full_submodule, module_file)
421425

@@ -432,7 +436,7 @@ def get_class_from_dynamic_module(
432436
token: Optional[Union[bool, str]] = None,
433437
revision: Optional[str] = None,
434438
local_files_only: bool = False,
435-
**kwargs,
439+
local_dir: Optional[str] = None,
436440
):
437441
"""
438442
Extracts a class from a module file, present in the local folder or repository of a model.
@@ -497,5 +501,6 @@ def get_class_from_dynamic_module(
497501
token=token,
498502
revision=revision,
499503
local_files_only=local_files_only,
504+
local_dir=local_dir,
500505
)
501506
return get_class_in_module(class_name, final_module)

0 commit comments

Comments
 (0)