From ad1fc3747324da45d499838a341ceb89e61f31af Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Fri, 21 Nov 2025 11:22:10 +0530 Subject: [PATCH 01/13] fix: group offloading to support standalone computational layers in block-level offloading --- src/diffusers/hooks/group_offloading.py | 204 +++++++++++++++++++++--- 1 file changed, 181 insertions(+), 23 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 38f291f5203c..4978e48d2d0f 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -578,6 +578,10 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf r""" This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks. + + Standalone computational layers (Conv2d, Linear, etc.) that are not part of ModuleList/Sequential are treated + individually with leaf-level logic to ensure proper device management. This includes computational layers nested + within container modules. """ if config.stream is not None and config.num_blocks_per_group != 1: @@ -589,11 +593,20 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf # Create module groups for ModuleList and Sequential blocks modules_with_group_offloading = set() unmatched_modules = [] + unmatched_computational_layers = [] matched_module_groups = [] for name, submodule in module.named_children(): if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): - unmatched_modules.append((name, submodule)) - modules_with_group_offloading.add(name) + # Check if this is a computational layer that should be handled individually + if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS): + unmatched_computational_layers.append((name, submodule)) + modules_with_group_offloading.add(name) + else: + # This is a container module - recursively find computational layers within it + _find_and_apply_computational_layer_hooks(submodule, name, config, modules_with_group_offloading) + unmatched_modules.append((name, submodule)) + # Do NOT add the container name to modules_with_group_offloading here, because we need + # parameters from non-computational sublayers (like GroupNorm) to be gathered continue for i in range(0, len(submodule), config.num_blocks_per_group): @@ -622,6 +635,25 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf for group_module in group.modules: _apply_group_offloading_hook(group_module, group, config=config) + # Apply leaf-level treatment to standalone computational layers at the top level + # Each computational layer gets its own ModuleGroup with hooks registered directly on it + for name, comp_layer in unmatched_computational_layers: + group = ModuleGroup( + modules=[comp_layer], + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, + offload_leader=comp_layer, + onload_leader=comp_layer, + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, + onload_self=True, + group_id=name, + ) + _apply_group_offloading_hook(comp_layer, group, config=config) + # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately # when the forward pass of this module is called. This is because the top-level module is not # part of any group (as doing so would lead to no VRAM savings). @@ -630,28 +662,154 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf parameters = [param for _, param in parameters] buffers = [buffer for _, buffer in buffers] - # Create a group for the unmatched submodules of the top-level module so that they are on the correct - # device when the forward pass is called. + # Create a group for the remaining unmatched submodules (non-computational containers) of the top-level + # module so that they are on the correct device when the forward pass is called. unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules] - unmatched_group = ModuleGroup( - modules=unmatched_modules, - offload_device=config.offload_device, - onload_device=config.onload_device, - offload_to_disk_path=config.offload_to_disk_path, - offload_leader=module, - onload_leader=module, - parameters=parameters, - buffers=buffers, - non_blocking=False, - stream=None, - record_stream=False, - onload_self=True, - group_id=f"{module.__class__.__name__}_unmatched_group", - ) - if config.stream is None: - _apply_group_offloading_hook(module, unmatched_group, config=config) - else: - _apply_lazy_group_offloading_hook(module, unmatched_group, config=config) + if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0: + unmatched_group = ModuleGroup( + modules=unmatched_modules, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, + offload_leader=module, + onload_leader=module, + parameters=parameters, + buffers=buffers, + non_blocking=False, + stream=None, + record_stream=False, + onload_self=True, + group_id=f"{module.__class__.__name__}_unmatched_group", + ) + if config.stream is None: + _apply_group_offloading_hook(module, unmatched_group, config=config) + else: + _apply_lazy_group_offloading_hook(module, unmatched_group, config=config) + + +def _find_and_apply_computational_layer_hooks( + container_module: torch.nn.Module, + container_name: str, + config: GroupOffloadingConfig, + modules_with_group_offloading: Set[str], +) -> None: + r""" + Recursively finds all computational layers within a container module and applies individual hooks to them. + This ensures that standalone Conv2d, Linear, etc. layers nested inside container modules (like Encoder/Decoder) + get proper device management. + """ + for name, submodule in container_module.named_modules(): + if name == "": # Skip the container itself + continue + + # Only apply hooks to supported computational layers + if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS): + full_name = f"{container_name}.{name}" + group = ModuleGroup( + modules=[submodule], + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, + offload_leader=submodule, + onload_leader=submodule, + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, + onload_self=True, + group_id=full_name, + ) + _apply_group_offloading_hook(submodule, group, config=config) + modules_with_group_offloading.add(full_name) + + # Also handle parameters and buffers at non-leaf levels within the container + # This is similar to what leaf-level offloading does + module_dict = dict(container_module.named_modules()) + parameters = [] + buffers = [] + + for name, param in container_module.named_parameters(): + # Check if this parameter has a parent that already got a hook + has_parent_with_hook = False + atoms = name.split(".") + while len(atoms) > 0: + parent_name = ".".join(atoms) + full_parent_name = f"{container_name}.{parent_name}" + if full_parent_name in modules_with_group_offloading: + has_parent_with_hook = True + break + atoms.pop() + + if not has_parent_with_hook: + parameters.append((name, param)) + + for name, buffer in container_module.named_buffers(): + # Check if this buffer has a parent that already got a hook + has_parent_with_hook = False + atoms = name.split(".") + while len(atoms) > 0: + parent_name = ".".join(atoms) + full_parent_name = f"{container_name}.{parent_name}" + if full_parent_name in modules_with_group_offloading: + has_parent_with_hook = True + break + atoms.pop() + + if not has_parent_with_hook: + buffers.append((name, buffer)) + + # Group parameters and buffers by their immediate parent module and apply hooks + parent_to_parameters = {} + for name, param in parameters: + atoms = name.split(".") + while len(atoms) > 0: + parent_name = ".".join(atoms) + if parent_name in module_dict: + if parent_name in parent_to_parameters: + parent_to_parameters[parent_name].append(param) + else: + parent_to_parameters[parent_name] = [param] + break + atoms.pop() + + parent_to_buffers = {} + for name, buffer in buffers: + atoms = name.split(".") + while len(atoms) > 0: + parent_name = ".".join(atoms) + if parent_name in module_dict: + if parent_name in parent_to_buffers: + parent_to_buffers[parent_name].append(buffer) + else: + parent_to_buffers[parent_name] = [buffer] + break + atoms.pop() + + parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys()) + for name in parent_names: + params = parent_to_parameters.get(name, []) + bufs = parent_to_buffers.get(name, []) + parent_module = module_dict[name] + full_parent_name = f"{container_name}.{name}" + + group = ModuleGroup( + modules=[], + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_leader=parent_module, + onload_leader=parent_module, + offload_to_disk_path=config.offload_to_disk_path, + parameters=params, + buffers=bufs, + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, + onload_self=True, + group_id=full_parent_name, + ) + _apply_group_offloading_hook(parent_module, group, config=config) + modules_with_group_offloading.add(full_parent_name) def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: From 59b6b678295214b70f6ecaa3f95129b76baf50d8 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Fri, 21 Nov 2025 11:37:10 +0530 Subject: [PATCH 02/13] test: for models with standalone and deeply nested layers in block-level offloading --- tests/hooks/test_group_offloading.py | 298 +++++++++++++++++++++++++++ 1 file changed, 298 insertions(+) diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py index 96cbecfbf530..9099fb49afcb 100644 --- a/tests/hooks/test_group_offloading.py +++ b/tests/hooks/test_group_offloading.py @@ -149,6 +149,146 @@ def post_forward(self, module, output): return output +# Model simulating VAE structure with standalone computational layers +class DummyVAELikeModel(ModelMixin): + def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: + super().__init__() + + # Encoder container (not ModuleList/Sequential at top level) + self.encoder = torch.nn.Sequential( + torch.nn.Linear(in_features, hidden_features), + torch.nn.ReLU(), + ) + + # Standalone Conv2d layer (simulates quant_conv) + self.quant_conv = torch.nn.Conv2d(1, 1, kernel_size=1) + + # Decoder container with nested ModuleList + self.decoder = DecoderWithNestedBlocks(hidden_features, hidden_features) + + # Standalone Conv2d layer (simulates post_quant_conv) + self.post_quant_conv = torch.nn.Conv2d(1, 1, kernel_size=1) + + # Output projection + self.linear_out = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Encode + x = self.encoder(x) + + # Reshape for conv operations + batch_size = x.shape[0] + x_reshaped = x.view(batch_size, 1, -1, 1) + + # Apply standalone conv layers + x_reshaped = self.quant_conv(x_reshaped) + x_reshaped = self.post_quant_conv(x_reshaped) + + # Reshape back + x = x_reshaped.view(batch_size, -1) + + # Decode + x = self.decoder(x) + + # Output + x = self.linear_out(x) + return x + + +class DecoderWithNestedBlocks(torch.nn.Module): + def __init__(self, in_features: int, out_features: int) -> None: + super().__init__() + + # Container modules (not ModuleList/Sequential) + self.conv_in = torch.nn.Linear(in_features, in_features) + + # Nested ModuleList (like VAE's decoder.up_blocks) + self.up_blocks = torch.nn.ModuleList( + [torch.nn.Linear(in_features, in_features), torch.nn.Linear(in_features, in_features)] + ) + + # Non-computational layer + self.norm = torch.nn.LayerNorm(in_features) + + self.conv_out = torch.nn.Linear(in_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv_in(x) + for block in self.up_blocks: + x = block(x) + x = self.norm(x) + x = self.conv_out(x) + return x + + +# Model with only standalone computational layers at top level +class DummyModelWithStandaloneLayers(ModelMixin): + def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: + super().__init__() + + self.layer1 = torch.nn.Linear(in_features, hidden_features) + self.activation = torch.nn.ReLU() + self.layer2 = torch.nn.Linear(hidden_features, hidden_features) + self.layer3 = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.layer1(x) + x = self.activation(x) + x = self.layer2(x) + x = self.layer3(x) + return x + + +# Model with deeply nested structure +class DummyModelWithDeeplyNestedBlocks(ModelMixin): + def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: + super().__init__() + + self.input_layer = torch.nn.Linear(in_features, hidden_features) + self.container = ContainerWithNestedModuleList(hidden_features) + self.output_layer = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.input_layer(x) + x = self.container(x) + x = self.output_layer(x) + return x + + +class ContainerWithNestedModuleList(torch.nn.Module): + def __init__(self, features: int) -> None: + super().__init__() + + # Top-level computational layer + self.proj_in = torch.nn.Linear(features, features) + + # Nested container with ModuleList + self.nested_container = NestedContainer(features) + + # Another top-level computational layer + self.proj_out = torch.nn.Linear(features, features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj_in(x) + x = self.nested_container(x) + x = self.proj_out(x) + return x + + +class NestedContainer(torch.nn.Module): + def __init__(self, features: int) -> None: + super().__init__() + + self.blocks = torch.nn.ModuleList([torch.nn.Linear(features, features), torch.nn.Linear(features, features)]) + self.norm = torch.nn.LayerNorm(features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + x = self.norm(x) + return x + + @require_torch_accelerator class GroupOffloadTests(unittest.TestCase): in_features = 64 @@ -362,3 +502,161 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm): self.assertLess( cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}" ) + + def test_vae_like_model_with_standalone_conv_layers(self): + """Test that models with standalone Conv2d layers (like VAE) work with block-level offloading.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + out_ref = model_ref(x) + out = model(x) + + self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model.") + + def test_vae_like_model_without_streams(self): + """Test VAE-like model with block-level offloading but without streams.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=False) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + out_ref = model_ref(x) + out = model(x) + + self.assertTrue( + torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams." + ) + + def test_model_with_only_standalone_layers(self): + """Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + out_ref = model_ref(x) + out = model(x) + + self.assertTrue( + torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for model with standalone layers." + ) + + def test_model_with_deeply_nested_blocks(self): + """Test models with deeply nested structure where ModuleList is not at top level.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + out_ref = model_ref(x) + out = model(x) + + self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for deeply nested model.") + + @parameterized.expand([("block_level",), ("leaf_level",)]) + def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str): + """Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + out_ref = model_ref(x) + out = model(x) + + self.assertTrue( + torch.allclose(out_ref, out, atol=1e-5), + f"Outputs do not match for standalone Conv layers with {offload_type}.", + ) + + def test_multiple_invocations_with_vae_like_model(self): + """Test that multiple forward passes work correctly with VAE-like model.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + for i in range(5): + out_ref = model_ref(x) + out = model(x) + self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.") + + def test_nested_container_parameters_offloading(self): + """Test that parameters from non-computational layers in nested containers are handled correctly.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + for i in range(3): + out_ref = model_ref(x) + out = model(x) + self.assertTrue( + torch.allclose(out_ref, out, atol=1e-5), + f"Outputs do not match at iteration {i} for nested parameters.", + ) From fa94f37f441de7494b0fc726644fa67ef358b90d Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Mon, 24 Nov 2025 20:39:53 +0530 Subject: [PATCH 03/13] feat: support for block-level offloading in group offloading config --- src/diffusers/hooks/group_offloading.py | 241 ++++++++---------------- src/diffusers/models/modeling_utils.py | 5 + 2 files changed, 88 insertions(+), 158 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 4978e48d2d0f..f9189443ee0f 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -59,6 +59,7 @@ class GroupOffloadingConfig: num_blocks_per_group: Optional[int] = None offload_to_disk_path: Optional[str] = None stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None + block_modules: Optional[List[str]] = None class ModuleGroup: @@ -77,7 +78,7 @@ def __init__( low_cpu_mem_usage: bool = False, onload_self: bool = True, offload_to_disk_path: Optional[str] = None, - group_id: Optional[int] = None, + group_id: Optional[Union[int, str]] = None, ) -> None: self.modules = modules self.offload_device = offload_device @@ -453,6 +454,7 @@ def apply_group_offloading( record_stream: bool = False, low_cpu_mem_usage: bool = False, offload_to_disk_path: Optional[str] = None, + block_modules: Optional[List[str]] = None, ) -> None: r""" Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and @@ -510,6 +512,9 @@ def apply_group_offloading( If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when the CPU memory is a bottleneck but may counteract the benefits of using streams. + block_modules (`List[str]`, *optional*): + List of module names that should be treated as blocks for offloading. If provided, only these modules + will be considered for block-level offloading. If not provided, the default block detection logic will be used. Example: ```python @@ -561,6 +566,7 @@ def apply_group_offloading( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, offload_to_disk_path=offload_to_disk_path, + block_modules=block_modules, ) _apply_group_offloading(module, config) @@ -576,84 +582,67 @@ def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConf def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: r""" - This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to - the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks. + This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, and explicitly + defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading + is done at the top-level blocks and modules specified in block_modules. - Standalone computational layers (Conv2d, Linear, etc.) that are not part of ModuleList/Sequential are treated - individually with leaf-level logic to ensure proper device management. This includes computational layers nested - within container modules. + When block_modules is provided, only those modules will be treated as blocks for offloading. For each specified + module, we either offload the entire submodule or recursively apply block offloading to it. """ - if config.stream is not None and config.num_blocks_per_group != 1: logger.warning( f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1." ) config.num_blocks_per_group = 1 - # Create module groups for ModuleList and Sequential blocks + block_modules = set(config.block_modules) if config.block_modules is not None else set() + + # Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules modules_with_group_offloading = set() unmatched_modules = [] - unmatched_computational_layers = [] matched_module_groups = [] - for name, submodule in module.named_children(): - if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): - # Check if this is a computational layer that should be handled individually - if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS): - unmatched_computational_layers.append((name, submodule)) - modules_with_group_offloading.add(name) - else: - # This is a container module - recursively find computational layers within it - _find_and_apply_computational_layer_hooks(submodule, name, config, modules_with_group_offloading) - unmatched_modules.append((name, submodule)) - # Do NOT add the container name to modules_with_group_offloading here, because we need - # parameters from non-computational sublayers (like GroupNorm) to be gathered - continue - for i in range(0, len(submodule), config.num_blocks_per_group): - current_modules = submodule[i : i + config.num_blocks_per_group] - group_id = f"{name}_{i}_{i + len(current_modules) - 1}" - group = ModuleGroup( - modules=current_modules, - offload_device=config.offload_device, - onload_device=config.onload_device, - offload_to_disk_path=config.offload_to_disk_path, - offload_leader=current_modules[-1], - onload_leader=current_modules[0], - non_blocking=config.non_blocking, - stream=config.stream, - record_stream=config.record_stream, - low_cpu_mem_usage=config.low_cpu_mem_usage, - onload_self=True, - group_id=group_id, + for name, submodule in module.named_children(): + # Check if this is an explicitly defined block module + if name in block_modules: + # Apply block offloading to the specified submodule + _apply_block_offloading_to_submodule( + submodule, name, config, modules_with_group_offloading, matched_module_groups ) - matched_module_groups.append(group) - for j in range(i, i + len(current_modules)): - modules_with_group_offloading.add(f"{name}.{j}") + elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + # Handle ModuleList and Sequential blocks as before + for i in range(0, len(submodule), config.num_blocks_per_group): + current_modules = list(submodule[i : i + config.num_blocks_per_group]) + if len(current_modules) == 0: + continue + + group_id = f"{name}_{i}_{i + len(current_modules) - 1}" + group = ModuleGroup( + modules=current_modules, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, + offload_leader=current_modules[-1], + onload_leader=current_modules[0], + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, + onload_self=True, + group_id=group_id, + ) + matched_module_groups.append(group) + for j in range(i, i + len(current_modules)): + modules_with_group_offloading.add(f"{name}.{j}") + else: + # This is an unmatched module + unmatched_modules.append((name, submodule)) # Apply group offloading hooks to the module groups for i, group in enumerate(matched_module_groups): for group_module in group.modules: _apply_group_offloading_hook(group_module, group, config=config) - # Apply leaf-level treatment to standalone computational layers at the top level - # Each computational layer gets its own ModuleGroup with hooks registered directly on it - for name, comp_layer in unmatched_computational_layers: - group = ModuleGroup( - modules=[comp_layer], - offload_device=config.offload_device, - onload_device=config.onload_device, - offload_to_disk_path=config.offload_to_disk_path, - offload_leader=comp_layer, - onload_leader=comp_layer, - non_blocking=config.non_blocking, - stream=config.stream, - record_stream=config.record_stream, - low_cpu_mem_usage=config.low_cpu_mem_usage, - onload_self=True, - group_id=name, - ) - _apply_group_offloading_hook(comp_layer, group, config=config) - # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately # when the forward pass of this module is called. This is because the top-level module is not # part of any group (as doing so would lead to no VRAM savings). @@ -662,7 +651,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf parameters = [param for _, param in parameters] buffers = [buffer for _, buffer in buffers] - # Create a group for the remaining unmatched submodules (non-computational containers) of the top-level + # Create a group for the remaining unmatched submodules of the top-level # module so that they are on the correct device when the forward pass is called. unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules] if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0: @@ -687,129 +676,65 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf _apply_lazy_group_offloading_hook(module, unmatched_group, config=config) -def _find_and_apply_computational_layer_hooks( - container_module: torch.nn.Module, - container_name: str, +def _apply_block_offloading_to_submodule( + submodule: torch.nn.Module, + name: str, config: GroupOffloadingConfig, modules_with_group_offloading: Set[str], + matched_module_groups: List[ModuleGroup], ) -> None: r""" - Recursively finds all computational layers within a container module and applies individual hooks to them. - This ensures that standalone Conv2d, Linear, etc. layers nested inside container modules (like Encoder/Decoder) - get proper device management. + Apply block offloading to a explicitly defined submodule. This function either: + 1. Offloads the entire submodule as a single group ( SIMPLE APPROACH) + 2. Recursively applies block offloading to the submodule + + For now, we use the simple approach - offload the entire submodule as a single group. """ - for name, submodule in container_module.named_modules(): - if name == "": # Skip the container itself - continue + # Simple approach: offload the entire submodule as a single group + # Since AEs are typically small, this is usually okay + if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + # If it's a ModuleList or Sequential, apply the normal block-level logic + for i in range(0, len(submodule), config.num_blocks_per_group): + current_modules = list(submodule[i : i + config.num_blocks_per_group]) + if len(current_modules) == 0: + continue - # Only apply hooks to supported computational layers - if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS): - full_name = f"{container_name}.{name}" + group_id = f"{name}_{i}_{i + len(current_modules) - 1}" group = ModuleGroup( - modules=[submodule], + modules=current_modules, offload_device=config.offload_device, onload_device=config.onload_device, offload_to_disk_path=config.offload_to_disk_path, - offload_leader=submodule, - onload_leader=submodule, + offload_leader=current_modules[-1], + onload_leader=current_modules[0], non_blocking=config.non_blocking, stream=config.stream, record_stream=config.record_stream, low_cpu_mem_usage=config.low_cpu_mem_usage, onload_self=True, - group_id=full_name, + group_id=group_id, ) - _apply_group_offloading_hook(submodule, group, config=config) - modules_with_group_offloading.add(full_name) - - # Also handle parameters and buffers at non-leaf levels within the container - # This is similar to what leaf-level offloading does - module_dict = dict(container_module.named_modules()) - parameters = [] - buffers = [] - - for name, param in container_module.named_parameters(): - # Check if this parameter has a parent that already got a hook - has_parent_with_hook = False - atoms = name.split(".") - while len(atoms) > 0: - parent_name = ".".join(atoms) - full_parent_name = f"{container_name}.{parent_name}" - if full_parent_name in modules_with_group_offloading: - has_parent_with_hook = True - break - atoms.pop() - - if not has_parent_with_hook: - parameters.append((name, param)) - - for name, buffer in container_module.named_buffers(): - # Check if this buffer has a parent that already got a hook - has_parent_with_hook = False - atoms = name.split(".") - while len(atoms) > 0: - parent_name = ".".join(atoms) - full_parent_name = f"{container_name}.{parent_name}" - if full_parent_name in modules_with_group_offloading: - has_parent_with_hook = True - break - atoms.pop() - - if not has_parent_with_hook: - buffers.append((name, buffer)) - - # Group parameters and buffers by their immediate parent module and apply hooks - parent_to_parameters = {} - for name, param in parameters: - atoms = name.split(".") - while len(atoms) > 0: - parent_name = ".".join(atoms) - if parent_name in module_dict: - if parent_name in parent_to_parameters: - parent_to_parameters[parent_name].append(param) - else: - parent_to_parameters[parent_name] = [param] - break - atoms.pop() - - parent_to_buffers = {} - for name, buffer in buffers: - atoms = name.split(".") - while len(atoms) > 0: - parent_name = ".".join(atoms) - if parent_name in module_dict: - if parent_name in parent_to_buffers: - parent_to_buffers[parent_name].append(buffer) - else: - parent_to_buffers[parent_name] = [buffer] - break - atoms.pop() - - parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys()) - for name in parent_names: - params = parent_to_parameters.get(name, []) - bufs = parent_to_buffers.get(name, []) - parent_module = module_dict[name] - full_parent_name = f"{container_name}.{name}" - + matched_module_groups.append(group) + for j in range(i, i + len(current_modules)): + modules_with_group_offloading.add(f"{name}.{j}") + else: + # For other modules, treat the entire submodule as a single group group = ModuleGroup( - modules=[], + modules=[submodule], offload_device=config.offload_device, onload_device=config.onload_device, - offload_leader=parent_module, - onload_leader=parent_module, offload_to_disk_path=config.offload_to_disk_path, - parameters=params, - buffers=bufs, + offload_leader=submodule, + onload_leader=submodule, non_blocking=config.non_blocking, stream=config.stream, record_stream=config.record_stream, low_cpu_mem_usage=config.low_cpu_mem_usage, onload_self=True, - group_id=full_parent_name, + group_id=name, ) - _apply_group_offloading_hook(parent_module, group, config=config) - modules_with_group_offloading.add(full_parent_name) + matched_module_groups.append(group) + modules_with_group_offloading.add(name) def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index f06822c741ca..881225989269 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -570,6 +570,10 @@ def enable_group_offload( f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please " f"open an issue at https://github.com/huggingface/diffusers/issues." ) + + # Get block modules from the model if available + block_modules = getattr(self, "_group_offload_block_modules", None) + apply_group_offloading( module=self, onload_device=onload_device, @@ -581,6 +585,7 @@ def enable_group_offload( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, offload_to_disk_path=offload_to_disk_path, + block_modules=block_modules, ) def set_attention_backend(self, backend: str) -> None: From fb8a74195dbb9fb2b0677f906450cb67f601e8bf Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Mon, 24 Nov 2025 20:55:27 +0530 Subject: [PATCH 04/13] fix: group offload block modules to AutoencoderKL and AutoencoderKLWan --- src/diffusers/models/autoencoders/autoencoder_kl.py | 1 + src/diffusers/models/autoencoders/autoencoder_kl_wan.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index ffc8778e7aca..4096b7c07609 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -72,6 +72,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] + _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"] @register_to_config def __init__( diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index b0b2960aaf18..6b29a6273cd9 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -964,6 +964,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo # keys toignore when AlignDeviceHook moves inputs/outputs between devices # these are shared mutable state modified in-place _skip_keys = ["feat_cache", "feat_idx"] + _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"] @register_to_config def __init__( From e71d91edd8b4784f977a92d92eb001f85da38bce Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Tue, 25 Nov 2025 01:08:19 +0530 Subject: [PATCH 05/13] fix: update group offloading tests to use AutoencoderKL and adjust input dimensions --- tests/hooks/test_group_offloading.py | 144 +++++++-------------------- 1 file changed, 35 insertions(+), 109 deletions(-) diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py index 9099fb49afcb..565b846025ef 100644 --- a/tests/hooks/test_group_offloading.py +++ b/tests/hooks/test_group_offloading.py @@ -19,6 +19,7 @@ import torch from parameterized import parameterized +from diffusers import AutoencoderKL from diffusers.hooks import HookRegistry, ModelHook from diffusers.models import ModelMixin from diffusers.pipelines.pipeline_utils import DiffusionPipeline @@ -149,78 +150,6 @@ def post_forward(self, module, output): return output -# Model simulating VAE structure with standalone computational layers -class DummyVAELikeModel(ModelMixin): - def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: - super().__init__() - - # Encoder container (not ModuleList/Sequential at top level) - self.encoder = torch.nn.Sequential( - torch.nn.Linear(in_features, hidden_features), - torch.nn.ReLU(), - ) - - # Standalone Conv2d layer (simulates quant_conv) - self.quant_conv = torch.nn.Conv2d(1, 1, kernel_size=1) - - # Decoder container with nested ModuleList - self.decoder = DecoderWithNestedBlocks(hidden_features, hidden_features) - - # Standalone Conv2d layer (simulates post_quant_conv) - self.post_quant_conv = torch.nn.Conv2d(1, 1, kernel_size=1) - - # Output projection - self.linear_out = torch.nn.Linear(hidden_features, out_features) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # Encode - x = self.encoder(x) - - # Reshape for conv operations - batch_size = x.shape[0] - x_reshaped = x.view(batch_size, 1, -1, 1) - - # Apply standalone conv layers - x_reshaped = self.quant_conv(x_reshaped) - x_reshaped = self.post_quant_conv(x_reshaped) - - # Reshape back - x = x_reshaped.view(batch_size, -1) - - # Decode - x = self.decoder(x) - - # Output - x = self.linear_out(x) - return x - - -class DecoderWithNestedBlocks(torch.nn.Module): - def __init__(self, in_features: int, out_features: int) -> None: - super().__init__() - - # Container modules (not ModuleList/Sequential) - self.conv_in = torch.nn.Linear(in_features, in_features) - - # Nested ModuleList (like VAE's decoder.up_blocks) - self.up_blocks = torch.nn.ModuleList( - [torch.nn.Linear(in_features, in_features), torch.nn.Linear(in_features, in_features)] - ) - - # Non-computational layer - self.norm = torch.nn.LayerNorm(in_features) - - self.conv_out = torch.nn.Linear(in_features, out_features) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.conv_in(x) - for block in self.up_blocks: - x = block(x) - x = self.norm(x) - x = self.conv_out(x) - return x - - # Model with only standalone computational layers at top level class DummyModelWithStandaloneLayers(ModelMixin): def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: @@ -503,45 +432,25 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm): cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}" ) - def test_vae_like_model_with_standalone_conv_layers(self): - """Test that models with standalone Conv2d layers (like VAE) work with block-level offloading.""" - if torch.device(torch_device).type not in ["cuda", "xpu"]: - return - - model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) - - model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) - model_ref.load_state_dict(model.state_dict(), strict=True) - model_ref.to(torch_device) - - model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) - - x = torch.randn(2, 64).to(torch_device) - - with torch.no_grad(): - out_ref = model_ref(x) - out = model(x) - - self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model.") - def test_vae_like_model_without_streams(self): """Test VAE-like model with block-level offloading but without streams.""" if torch.device(torch_device).type not in ["cuda", "xpu"]: return - model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + config = self.get_autoencoder_kl_config() + model = AutoencoderKL(**config) - model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + model_ref = AutoencoderKL(**config) model_ref.load_state_dict(model.state_dict(), strict=True) model_ref.to(torch_device) model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=False) - x = torch.randn(2, 64).to(torch_device) + x = torch.randn(2, 3, 32, 32).to(torch_device) with torch.no_grad(): - out_ref = model_ref(x) - out = model(x) + out_ref = model_ref(x).sample + out = model(x).sample self.assertTrue( torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams." @@ -597,19 +506,20 @@ def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str) if torch.device(torch_device).type not in ["cuda", "xpu"]: return - model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + config = self.get_autoencoder_kl_config() + model = AutoencoderKL(**config) - model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + model_ref = AutoencoderKL(**config) model_ref.load_state_dict(model.state_dict(), strict=True) model_ref.to(torch_device) model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True) - x = torch.randn(2, 64).to(torch_device) + x = torch.randn(2, 3, 32, 32).to(torch_device) with torch.no_grad(): - out_ref = model_ref(x) - out = model(x) + out_ref = model_ref(x).sample + out = model(x).sample self.assertTrue( torch.allclose(out_ref, out, atol=1e-5), @@ -621,20 +531,21 @@ def test_multiple_invocations_with_vae_like_model(self): if torch.device(torch_device).type not in ["cuda", "xpu"]: return - model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + config = self.get_autoencoder_kl_config() + model = AutoencoderKL(**config) - model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + model_ref = AutoencoderKL(**config) model_ref.load_state_dict(model.state_dict(), strict=True) model_ref.to(torch_device) model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) - x = torch.randn(2, 64).to(torch_device) + x = torch.randn(2, 3, 32, 32).to(torch_device) with torch.no_grad(): - for i in range(5): - out_ref = model_ref(x) - out = model(x) + for i in range(2): + out_ref = model_ref(x).sample + out = model(x).sample self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.") def test_nested_container_parameters_offloading(self): @@ -660,3 +571,18 @@ def test_nested_container_parameters_offloading(self): torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i} for nested parameters.", ) + + def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None): + block_out_channels = block_out_channels or [2, 4] + norm_num_groups = norm_num_groups or 2 + init_dict = { + "block_out_channels": block_out_channels, + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels), + "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels), + "latent_channels": 4, + "norm_num_groups": norm_num_groups, + "layers_per_block": 1, + } + return init_dict From e7711435bd5197cdb61e836f32ddca2bd8b5c12d Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Wed, 3 Dec 2025 23:08:27 +0530 Subject: [PATCH 06/13] refactor: streamline block offloading logic --- src/diffusers/hooks/group_offloading.py | 66 +------------------------ 1 file changed, 2 insertions(+), 64 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index f9189443ee0f..a91a93dd7818 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -606,9 +606,8 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf # Check if this is an explicitly defined block module if name in block_modules: # Apply block offloading to the specified submodule - _apply_block_offloading_to_submodule( - submodule, name, config, modules_with_group_offloading, matched_module_groups - ) + _apply_group_offloading_block_level(submodule, config) + modules_with_group_offloading.add(name) elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): # Handle ModuleList and Sequential blocks as before for i in range(0, len(submodule), config.num_blocks_per_group): @@ -676,67 +675,6 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf _apply_lazy_group_offloading_hook(module, unmatched_group, config=config) -def _apply_block_offloading_to_submodule( - submodule: torch.nn.Module, - name: str, - config: GroupOffloadingConfig, - modules_with_group_offloading: Set[str], - matched_module_groups: List[ModuleGroup], -) -> None: - r""" - Apply block offloading to a explicitly defined submodule. This function either: - 1. Offloads the entire submodule as a single group ( SIMPLE APPROACH) - 2. Recursively applies block offloading to the submodule - - For now, we use the simple approach - offload the entire submodule as a single group. - """ - # Simple approach: offload the entire submodule as a single group - # Since AEs are typically small, this is usually okay - if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): - # If it's a ModuleList or Sequential, apply the normal block-level logic - for i in range(0, len(submodule), config.num_blocks_per_group): - current_modules = list(submodule[i : i + config.num_blocks_per_group]) - if len(current_modules) == 0: - continue - - group_id = f"{name}_{i}_{i + len(current_modules) - 1}" - group = ModuleGroup( - modules=current_modules, - offload_device=config.offload_device, - onload_device=config.onload_device, - offload_to_disk_path=config.offload_to_disk_path, - offload_leader=current_modules[-1], - onload_leader=current_modules[0], - non_blocking=config.non_blocking, - stream=config.stream, - record_stream=config.record_stream, - low_cpu_mem_usage=config.low_cpu_mem_usage, - onload_self=True, - group_id=group_id, - ) - matched_module_groups.append(group) - for j in range(i, i + len(current_modules)): - modules_with_group_offloading.add(f"{name}.{j}") - else: - # For other modules, treat the entire submodule as a single group - group = ModuleGroup( - modules=[submodule], - offload_device=config.offload_device, - onload_device=config.onload_device, - offload_to_disk_path=config.offload_to_disk_path, - offload_leader=submodule, - onload_leader=submodule, - non_blocking=config.non_blocking, - stream=config.stream, - record_stream=config.record_stream, - low_cpu_mem_usage=config.low_cpu_mem_usage, - onload_self=True, - group_id=name, - ) - matched_module_groups.append(group) - modules_with_group_offloading.add(name) - - def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: r""" This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory From ab9b249d371fae90dccb8074a205918686b88bf0 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 3 Dec 2025 23:58:18 +0000 Subject: [PATCH 07/13] Apply style fixes --- src/diffusers/hooks/group_offloading.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index a91a93dd7818..661951d8df9e 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -513,8 +513,8 @@ def apply_group_offloading( option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when the CPU memory is a bottleneck but may counteract the benefits of using streams. block_modules (`List[str]`, *optional*): - List of module names that should be treated as blocks for offloading. If provided, only these modules - will be considered for block-level offloading. If not provided, the default block detection logic will be used. + List of module names that should be treated as blocks for offloading. If provided, only these modules will + be considered for block-level offloading. If not provided, the default block detection logic will be used. Example: ```python @@ -583,8 +583,8 @@ def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConf def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: r""" This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, and explicitly - defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading - is done at the top-level blocks and modules specified in block_modules. + defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading is + done at the top-level blocks and modules specified in block_modules. When block_modules is provided, only those modules will be treated as blocks for offloading. For each specified module, we either offload the entire submodule or recursively apply block offloading to it. From 26bccde7617b257a31d9b04467869d02f1b26e21 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 4 Dec 2025 09:52:24 +0100 Subject: [PATCH 08/13] update tests --- tests/models/test_modeling_common.py | 1 + tests/testing_utils.py | 36 ++++++++++++++++++---------- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 6f4c3d544b45..f2338c4a8f70 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1851,6 +1851,7 @@ def _run_forward(model, inputs_dict): offload_to_disk_path=tmpdir, offload_type=offload_type, num_blocks_per_group=num_blocks_per_group, + block_modules=model._group_offload_block_modules if hasattr(model, "_group_offload_block_modules") else None ) if not is_correct: if extra_files: diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 6ed7e3467d7f..3ece3ae69538 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -1424,6 +1424,7 @@ def _get_expected_safetensors_files( offload_to_disk_path: str, offload_type: str, num_blocks_per_group: Optional[int] = None, + block_modules: Optional[List[str]] = None ) -> Set[str]: expected_files = set() @@ -1435,22 +1436,32 @@ def get_hashed_filename(group_id: str) -> str: if num_blocks_per_group is None: raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.") - # Handle groups of ModuleList and Sequential blocks + block_modules_set = set(block_modules) if block_modules is not None else set() + + # Handle groups of ModuleList and Sequential blocks, and explicitly defined block modules unmatched_modules = [] for name, submodule in module.named_children(): - if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + # Check if this is an explicitly defined block module + if name in block_modules_set: + # Recursively get expected files for the specified submodule + submodule_files = _get_expected_safetensors_files( + submodule, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules + ) + expected_files.update(submodule_files) + elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + # Handle ModuleList and Sequential blocks as before + for i in range(0, len(submodule), num_blocks_per_group): + current_modules = submodule[i : i + num_blocks_per_group] + if not current_modules: + continue + group_id = f"{name}_{i}_{i + len(current_modules) - 1}" + expected_files.add(get_hashed_filename(group_id)) + else: + # This is an unmatched module unmatched_modules.append(module) - continue - - for i in range(0, len(submodule), num_blocks_per_group): - current_modules = submodule[i : i + num_blocks_per_group] - if not current_modules: - continue - group_id = f"{name}_{i}_{i + len(current_modules) - 1}" - expected_files.add(get_hashed_filename(group_id)) # Handle the group for unmatched top-level modules and parameters - for module in unmatched_modules: + if len(unmatched_modules) > 0: expected_files.add(get_hashed_filename(f"{module.__class__.__name__}_unmatched_group")) elif offload_type == "leaf_level": @@ -1492,12 +1503,13 @@ def _check_safetensors_serialization( offload_to_disk_path: str, offload_type: str, num_blocks_per_group: Optional[int] = None, + block_modules: Optional[List[str]] = None, ) -> bool: if not os.path.isdir(offload_to_disk_path): return False, None, None expected_files = _get_expected_safetensors_files( - module, offload_to_disk_path, offload_type, num_blocks_per_group + module, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules ) actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors"))) missing_files = expected_files - actual_files From f305934de636df8d4220a87fe2234876df2122f2 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 4 Dec 2025 10:56:36 +0100 Subject: [PATCH 09/13] update --- src/diffusers/hooks/group_offloading.py | 6 +++++- src/diffusers/models/autoencoders/autoencoder_kl_wan.py | 4 ++++ tests/models/test_modeling_common.py | 4 +++- tests/testing_utils.py | 2 +- 4 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 661951d8df9e..288d70e28f49 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -322,6 +322,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs): args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking) kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking) + return args, kwargs def post_forward(self, module: torch.nn.Module, output): @@ -608,6 +609,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf # Apply block offloading to the specified submodule _apply_group_offloading_block_level(submodule, config) modules_with_group_offloading.add(name) + elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): # Handle ModuleList and Sequential blocks as before for i in range(0, len(submodule), config.num_blocks_per_group): @@ -653,7 +655,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf # Create a group for the remaining unmatched submodules of the top-level # module so that they are on the correct device when the forward pass is called. unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules] - if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0: + has_unmatched = len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0 + + if has_unmatched or len(block_modules) > 0: unmatched_group = ModuleGroup( modules=unmatched_modules, offload_device=config.offload_device, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 6b29a6273cd9..b332a964c310 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -1146,6 +1146,9 @@ def _encode(self, x: torch.Tensor): feat_idx=self._enc_conv_idx, ) out = torch.cat([out, out_], 2) + __import__("ipdb").set_trace() + # cache_devices = [i.device.type for i in self._enc_feat_map] + # any((d != "cuda" for d in cache_devices)) enc = self.quant_conv(out) self.clear_cache() @@ -1409,6 +1412,7 @@ def forward( """ x = sample posterior = self.encode(x).latent_dist + if sample_posterior: z = posterior.sample(generator=generator) else: diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index f2338c4a8f70..988e85c6bbb3 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1851,7 +1851,9 @@ def _run_forward(model, inputs_dict): offload_to_disk_path=tmpdir, offload_type=offload_type, num_blocks_per_group=num_blocks_per_group, - block_modules=model._group_offload_block_modules if hasattr(model, "_group_offload_block_modules") else None + block_modules=model._group_offload_block_modules + if hasattr(model, "_group_offload_block_modules") + else None, ) if not is_correct: if extra_files: diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 3ece3ae69538..69135e16cee0 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -1424,7 +1424,7 @@ def _get_expected_safetensors_files( offload_to_disk_path: str, offload_type: str, num_blocks_per_group: Optional[int] = None, - block_modules: Optional[List[str]] = None + block_modules: Optional[List[str]] = None, ) -> Set[str]: expected_files = set() From cf65ae34493fbe685c0ace548204705eee3a37d9 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 5 Dec 2025 04:31:30 +0100 Subject: [PATCH 10/13] fix for failing tests --- src/diffusers/hooks/group_offloading.py | 46 +++++++++++++++---- .../models/autoencoders/autoencoder_kl_wan.py | 12 ++--- src/diffusers/models/modeling_utils.py | 6 +-- tests/models/test_modeling_common.py | 1 - tests/testing_utils.py | 25 +++++++--- 5 files changed, 64 insertions(+), 26 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 288d70e28f49..5844bb647a09 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -15,7 +15,7 @@ import hashlib import os from contextlib import contextmanager, nullcontext -from dataclasses import dataclass +from dataclasses import dataclass, replace from enum import Enum from typing import Dict, List, Optional, Set, Tuple, Union @@ -60,6 +60,8 @@ class GroupOffloadingConfig: offload_to_disk_path: Optional[str] = None stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None block_modules: Optional[List[str]] = None + exclude_kwargs: Optional[List[str]] = None + module_prefix: Optional[str] = "" class ModuleGroup: @@ -321,7 +323,20 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs): self.group.stream.synchronize() args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking) - kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking) + + # Some Autoencoder models use a feature cache that is passed through submodules + # and modified in place. The `send_to_device` call returns a copy of this feature cache object + # which causes issues with inplace updates. Use `exclude_kwargs` to mark these cache features + exclude_kwargs = self.config.exclude_kwargs or [] + if exclude_kwargs: + moved_kwargs = send_to_device( + {k: v for k, v in kwargs.items() if k not in exclude_kwargs}, + self.group.onload_device, + non_blocking=self.group.non_blocking, + ) + kwargs.update(moved_kwargs) + else: + kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking) return args, kwargs @@ -456,6 +471,7 @@ def apply_group_offloading( low_cpu_mem_usage: bool = False, offload_to_disk_path: Optional[str] = None, block_modules: Optional[List[str]] = None, + exclude_kwargs: Optional[List[str]] = None, ) -> None: r""" Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and @@ -516,6 +532,10 @@ def apply_group_offloading( block_modules (`List[str]`, *optional*): List of module names that should be treated as blocks for offloading. If provided, only these modules will be considered for block-level offloading. If not provided, the default block detection logic will be used. + exclude_kwargs (`List[str]`, *optional*): + List of kwarg keys that should not be processed by send_to_device. This is useful for mutable state like + caching lists that need to maintain their object identity across forward passes. If not provided, will be + inferred from the module's `_group_offload_exclude_kwargs` attribute if it exists. Example: ```python @@ -557,6 +577,12 @@ def apply_group_offloading( _raise_error_if_accelerate_model_or_sequential_hook_present(module) + if block_modules is None: + block_modules = getattr(module, "_group_offload_block_modules", None) + + if exclude_kwargs is None: + exclude_kwargs = getattr(module, "_group_offload_exclude_kwargs", None) + config = GroupOffloadingConfig( onload_device=onload_device, offload_device=offload_device, @@ -568,6 +594,7 @@ def apply_group_offloading( low_cpu_mem_usage=low_cpu_mem_usage, offload_to_disk_path=offload_to_disk_path, block_modules=block_modules, + exclude_kwargs=exclude_kwargs, ) _apply_group_offloading(module, config) @@ -606,8 +633,11 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf for name, submodule in module.named_children(): # Check if this is an explicitly defined block module if name in block_modules: - # Apply block offloading to the specified submodule - _apply_group_offloading_block_level(submodule, config) + # track submodule using a prefix + prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}." + submodule_config = replace(config, module_prefix=prefix) + + _apply_group_offloading_block_level(submodule, submodule_config) modules_with_group_offloading.add(name) elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): @@ -617,7 +647,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf if len(current_modules) == 0: continue - group_id = f"{name}_{i}_{i + len(current_modules) - 1}" + group_id = f"{config.module_prefix}{name}_{i}_{i + len(current_modules) - 1}" group = ModuleGroup( modules=current_modules, offload_device=config.offload_device, @@ -655,9 +685,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf # Create a group for the remaining unmatched submodules of the top-level # module so that they are on the correct device when the forward pass is called. unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules] - has_unmatched = len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0 - - if has_unmatched or len(block_modules) > 0: + if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0: unmatched_group = ModuleGroup( modules=unmatched_modules, offload_device=config.offload_device, @@ -671,7 +699,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf stream=None, record_stream=False, onload_self=True, - group_id=f"{module.__class__.__name__}_unmatched_group", + group_id=f"{config.module_prefix}{module.__class__.__name__}_unmatched_group", ) if config.stream is None: _apply_group_offloading_hook(module, unmatched_group, config=config) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index b332a964c310..8ca03df0e4c1 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -619,6 +619,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): feat_idx[0] += 1 else: x = self.conv_out(x) + return x @@ -961,11 +962,13 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo """ _supports_gradient_checkpointing = False - # keys toignore when AlignDeviceHook moves inputs/outputs between devices - # these are shared mutable state modified in-place - _skip_keys = ["feat_cache", "feat_idx"] _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"] + # kwargs to ignore when send_to_device moves inputs/outputs between devices + # these are shared mutable states that are modified in-place and + # should not be subjected to copy operations + _group_offload_exclude_kwargs = ["feat_cache", "feat_idx"] + @register_to_config def __init__( self, @@ -1146,9 +1149,6 @@ def _encode(self, x: torch.Tensor): feat_idx=self._enc_conv_idx, ) out = torch.cat([out, out_], 2) - __import__("ipdb").set_trace() - # cache_devices = [i.device.type for i in self._enc_feat_map] - # any((d != "cuda" for d in cache_devices)) enc = self.quant_conv(out) self.clear_cache() diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 881225989269..41da95d3a2a2 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -531,6 +531,8 @@ def enable_group_offload( record_stream: bool = False, low_cpu_mem_usage=False, offload_to_disk_path: Optional[str] = None, + block_modules: Optional[str] = None, + exclude_kwargs: Optional[str] = None, ) -> None: r""" Activates group offloading for the current model. @@ -571,9 +573,6 @@ def enable_group_offload( f"open an issue at https://github.com/huggingface/diffusers/issues." ) - # Get block modules from the model if available - block_modules = getattr(self, "_group_offload_block_modules", None) - apply_group_offloading( module=self, onload_device=onload_device, @@ -586,6 +585,7 @@ def enable_group_offload( low_cpu_mem_usage=low_cpu_mem_usage, offload_to_disk_path=offload_to_disk_path, block_modules=block_modules, + exclude_kwargs=exclude_kwargs, ) def set_attention_backend(self, backend: str) -> None: diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 988e85c6bbb3..508ea786f42d 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1735,7 +1735,6 @@ def run_forward(model): return model(**inputs_dict)[0] model = self.model_class(**init_dict) - model.to(torch_device) output_without_group_offloading = run_forward(model) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 69135e16cee0..680c1feee7eb 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -1425,6 +1425,7 @@ def _get_expected_safetensors_files( offload_type: str, num_blocks_per_group: Optional[int] = None, block_modules: Optional[List[str]] = None, + module_prefix: str = "", ) -> Set[str]: expected_files = set() @@ -1439,30 +1440,40 @@ def get_hashed_filename(group_id: str) -> str: block_modules_set = set(block_modules) if block_modules is not None else set() # Handle groups of ModuleList and Sequential blocks, and explicitly defined block modules + modules_with_group_offloading = set() unmatched_modules = [] for name, submodule in module.named_children(): # Check if this is an explicitly defined block module if name in block_modules_set: - # Recursively get expected files for the specified submodule + # Recursively get expected files for the specified submodule with updated prefix + new_prefix = f"{module_prefix}{name}." if module_prefix else f"{name}." submodule_files = _get_expected_safetensors_files( - submodule, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules + submodule, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules, new_prefix ) expected_files.update(submodule_files) + modules_with_group_offloading.add(name) + elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): # Handle ModuleList and Sequential blocks as before for i in range(0, len(submodule), num_blocks_per_group): current_modules = submodule[i : i + num_blocks_per_group] if not current_modules: continue - group_id = f"{name}_{i}_{i + len(current_modules) - 1}" + group_id = f"{module_prefix}{name}_{i}_{i + len(current_modules) - 1}" expected_files.add(get_hashed_filename(group_id)) + for j in range(i, i + len(current_modules)): + modules_with_group_offloading.add(f"{name}.{j}") else: # This is an unmatched module - unmatched_modules.append(module) + unmatched_modules.append(submodule) + + # Handle the group for unmatched top-level modules and parameters/buffers + # We need to check if there are any parameters/buffers that don't belong to modules with group offloading + parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading) + buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading) - # Handle the group for unmatched top-level modules and parameters - if len(unmatched_modules) > 0: - expected_files.add(get_hashed_filename(f"{module.__class__.__name__}_unmatched_group")) + if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0: + expected_files.add(get_hashed_filename(f"{module_prefix}{module.__class__.__name__}_unmatched_group")) elif offload_type == "leaf_level": # Handle leaf-level module groups From 09a7b0ad25bb49017040dd1061e328a76f133b5b Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 5 Dec 2025 04:37:34 +0100 Subject: [PATCH 11/13] clean up --- src/diffusers/hooks/group_offloading.py | 6 ++++-- tests/testing_utils.py | 7 ------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 5844bb647a09..f4852e8e1581 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -326,7 +326,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs): # Some Autoencoder models use a feature cache that is passed through submodules # and modified in place. The `send_to_device` call returns a copy of this feature cache object - # which causes issues with inplace updates. Use `exclude_kwargs` to mark these cache features + # which breaks the inplace updates. Use `exclude_kwargs` to mark these cache features exclude_kwargs = self.config.exclude_kwargs or [] if exclude_kwargs: moved_kwargs = send_to_device( @@ -633,7 +633,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf for name, submodule in module.named_children(): # Check if this is an explicitly defined block module if name in block_modules: - # track submodule using a prefix + # Track submodule using a prefix to avoid filename collisions during disk offload. + # Without this, submodules sharing the same model class would be assigned identical + # filenames (derived from the class name). prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}." submodule_config = replace(config, module_prefix=prefix) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 680c1feee7eb..4550813259af 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -1439,13 +1439,10 @@ def get_hashed_filename(group_id: str) -> str: block_modules_set = set(block_modules) if block_modules is not None else set() - # Handle groups of ModuleList and Sequential blocks, and explicitly defined block modules modules_with_group_offloading = set() unmatched_modules = [] for name, submodule in module.named_children(): - # Check if this is an explicitly defined block module if name in block_modules_set: - # Recursively get expected files for the specified submodule with updated prefix new_prefix = f"{module_prefix}{name}." if module_prefix else f"{name}." submodule_files = _get_expected_safetensors_files( submodule, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules, new_prefix @@ -1454,7 +1451,6 @@ def get_hashed_filename(group_id: str) -> str: modules_with_group_offloading.add(name) elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): - # Handle ModuleList and Sequential blocks as before for i in range(0, len(submodule), num_blocks_per_group): current_modules = submodule[i : i + num_blocks_per_group] if not current_modules: @@ -1464,11 +1460,8 @@ def get_hashed_filename(group_id: str) -> str: for j in range(i, i + len(current_modules)): modules_with_group_offloading.add(f"{name}.{j}") else: - # This is an unmatched module unmatched_modules.append(submodule) - # Handle the group for unmatched top-level modules and parameters/buffers - # We need to check if there are any parameters/buffers that don't belong to modules with group offloading parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading) buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading) From c888aac570dfb7e2430b835a7481ad15e84df8b4 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 5 Dec 2025 04:51:50 +0100 Subject: [PATCH 12/13] revert to use skip_keys --- src/diffusers/hooks/group_offloading.py | 4 ++-- src/diffusers/models/autoencoders/autoencoder_kl_wan.py | 8 +++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index f4852e8e1581..d2989238bf78 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -535,7 +535,7 @@ def apply_group_offloading( exclude_kwargs (`List[str]`, *optional*): List of kwarg keys that should not be processed by send_to_device. This is useful for mutable state like caching lists that need to maintain their object identity across forward passes. If not provided, will be - inferred from the module's `_group_offload_exclude_kwargs` attribute if it exists. + inferred from the module's `_skip_keys` attribute if it exists. Example: ```python @@ -581,7 +581,7 @@ def apply_group_offloading( block_modules = getattr(module, "_group_offload_block_modules", None) if exclude_kwargs is None: - exclude_kwargs = getattr(module, "_group_offload_exclude_kwargs", None) + exclude_kwargs = getattr(module, "_skip_keys", None) config = GroupOffloadingConfig( onload_device=onload_device, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 8ca03df0e4c1..57284c487e2b 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -963,11 +963,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo _supports_gradient_checkpointing = False _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"] - - # kwargs to ignore when send_to_device moves inputs/outputs between devices - # these are shared mutable states that are modified in-place and - # should not be subjected to copy operations - _group_offload_exclude_kwargs = ["feat_cache", "feat_idx"] + # keys toignore when AlignDeviceHook moves inputs/outputs between devices + # these are shared mutable state modified in-place + _skip_keys = ["feat_cache", "feat_idx"] @register_to_config def __init__( From 4bd3384b05da88b9b94866279276be095e425e70 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 5 Dec 2025 05:03:22 +0100 Subject: [PATCH 13/13] clean up --- src/diffusers/hooks/group_offloading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index d2989238bf78..dafcf261a04e 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -615,7 +615,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf done at the top-level blocks and modules specified in block_modules. When block_modules is provided, only those modules will be treated as blocks for offloading. For each specified - module, we either offload the entire submodule or recursively apply block offloading to it. + module, recursively apply block offloading to it. """ if config.stream is not None and config.num_blocks_per_group != 1: logger.warning(