Skip to content

Commit c888aac

Browse files
committed
revert to use skip_keys
1 parent 09a7b0a commit c888aac

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ def apply_group_offloading(
535535
exclude_kwargs (`List[str]`, *optional*):
536536
List of kwarg keys that should not be processed by send_to_device. This is useful for mutable state like
537537
caching lists that need to maintain their object identity across forward passes. If not provided, will be
538-
inferred from the module's `_group_offload_exclude_kwargs` attribute if it exists.
538+
inferred from the module's `_skip_keys` attribute if it exists.
539539
540540
Example:
541541
```python
@@ -581,7 +581,7 @@ def apply_group_offloading(
581581
block_modules = getattr(module, "_group_offload_block_modules", None)
582582

583583
if exclude_kwargs is None:
584-
exclude_kwargs = getattr(module, "_group_offload_exclude_kwargs", None)
584+
exclude_kwargs = getattr(module, "_skip_keys", None)
585585

586586
config = GroupOffloadingConfig(
587587
onload_device=onload_device,

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -963,11 +963,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
963963

964964
_supports_gradient_checkpointing = False
965965
_group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]
966-
967-
# kwargs to ignore when send_to_device moves inputs/outputs between devices
968-
# these are shared mutable states that are modified in-place and
969-
# should not be subjected to copy operations
970-
_group_offload_exclude_kwargs = ["feat_cache", "feat_idx"]
966+
# keys toignore when AlignDeviceHook moves inputs/outputs between devices
967+
# these are shared mutable state modified in-place
968+
_skip_keys = ["feat_cache", "feat_idx"]
971969

972970
@register_to_config
973971
def __init__(

0 commit comments

Comments
 (0)