diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 38f291f5203c..dafcf261a04e 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -15,7 +15,7 @@ import hashlib import os from contextlib import contextmanager, nullcontext -from dataclasses import dataclass +from dataclasses import dataclass, replace from enum import Enum from typing import Dict, List, Optional, Set, Tuple, Union @@ -59,6 +59,9 @@ class GroupOffloadingConfig: num_blocks_per_group: Optional[int] = None offload_to_disk_path: Optional[str] = None stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None + block_modules: Optional[List[str]] = None + exclude_kwargs: Optional[List[str]] = None + module_prefix: Optional[str] = "" class ModuleGroup: @@ -77,7 +80,7 @@ def __init__( low_cpu_mem_usage: bool = False, onload_self: bool = True, offload_to_disk_path: Optional[str] = None, - group_id: Optional[int] = None, + group_id: Optional[Union[int, str]] = None, ) -> None: self.modules = modules self.offload_device = offload_device @@ -320,7 +323,21 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs): self.group.stream.synchronize() args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking) - kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking) + + # Some Autoencoder models use a feature cache that is passed through submodules + # and modified in place. The `send_to_device` call returns a copy of this feature cache object + # which breaks the inplace updates. Use `exclude_kwargs` to mark these cache features + exclude_kwargs = self.config.exclude_kwargs or [] + if exclude_kwargs: + moved_kwargs = send_to_device( + {k: v for k, v in kwargs.items() if k not in exclude_kwargs}, + self.group.onload_device, + non_blocking=self.group.non_blocking, + ) + kwargs.update(moved_kwargs) + else: + kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking) + return args, kwargs def post_forward(self, module: torch.nn.Module, output): @@ -453,6 +470,8 @@ def apply_group_offloading( record_stream: bool = False, low_cpu_mem_usage: bool = False, offload_to_disk_path: Optional[str] = None, + block_modules: Optional[List[str]] = None, + exclude_kwargs: Optional[List[str]] = None, ) -> None: r""" Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and @@ -510,6 +529,13 @@ def apply_group_offloading( If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when the CPU memory is a bottleneck but may counteract the benefits of using streams. + block_modules (`List[str]`, *optional*): + List of module names that should be treated as blocks for offloading. If provided, only these modules will + be considered for block-level offloading. If not provided, the default block detection logic will be used. + exclude_kwargs (`List[str]`, *optional*): + List of kwarg keys that should not be processed by send_to_device. This is useful for mutable state like + caching lists that need to maintain their object identity across forward passes. If not provided, will be + inferred from the module's `_skip_keys` attribute if it exists. 
Example: ```python @@ -551,6 +577,12 @@ def apply_group_offloading( _raise_error_if_accelerate_model_or_sequential_hook_present(module) + if block_modules is None: + block_modules = getattr(module, "_group_offload_block_modules", None) + + if exclude_kwargs is None: + exclude_kwargs = getattr(module, "_skip_keys", None) + config = GroupOffloadingConfig( onload_device=onload_device, offload_device=offload_device, @@ -561,6 +593,8 @@ def apply_group_offloading( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, offload_to_disk_path=offload_to_disk_path, + block_modules=block_modules, + exclude_kwargs=exclude_kwargs, ) _apply_group_offloading(module, config) @@ -576,46 +610,66 @@ def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConf def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: r""" - This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to - the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks. - """ + This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, and explicitly + defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading is + done at the top-level blocks and modules specified in block_modules. + When block_modules is provided, only those modules will be treated as blocks for offloading. For each specified + module, recursively apply block offloading to it. + """ if config.stream is not None and config.num_blocks_per_group != 1: logger.warning( f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1." ) config.num_blocks_per_group = 1 - # Create module groups for ModuleList and Sequential blocks + block_modules = set(config.block_modules) if config.block_modules is not None else set() + + # Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules modules_with_group_offloading = set() unmatched_modules = [] matched_module_groups = [] + for name, submodule in module.named_children(): - if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): - unmatched_modules.append((name, submodule)) + # Check if this is an explicitly defined block module + if name in block_modules: + # Track submodule using a prefix to avoid filename collisions during disk offload. + # Without this, submodules sharing the same model class would be assigned identical + # filenames (derived from the class name). + prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}." 
+ submodule_config = replace(config, module_prefix=prefix) + + _apply_group_offloading_block_level(submodule, submodule_config) modules_with_group_offloading.add(name) - continue - for i in range(0, len(submodule), config.num_blocks_per_group): - current_modules = submodule[i : i + config.num_blocks_per_group] - group_id = f"{name}_{i}_{i + len(current_modules) - 1}" - group = ModuleGroup( - modules=current_modules, - offload_device=config.offload_device, - onload_device=config.onload_device, - offload_to_disk_path=config.offload_to_disk_path, - offload_leader=current_modules[-1], - onload_leader=current_modules[0], - non_blocking=config.non_blocking, - stream=config.stream, - record_stream=config.record_stream, - low_cpu_mem_usage=config.low_cpu_mem_usage, - onload_self=True, - group_id=group_id, - ) - matched_module_groups.append(group) - for j in range(i, i + len(current_modules)): - modules_with_group_offloading.add(f"{name}.{j}") + elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + # Handle ModuleList and Sequential blocks as before + for i in range(0, len(submodule), config.num_blocks_per_group): + current_modules = list(submodule[i : i + config.num_blocks_per_group]) + if len(current_modules) == 0: + continue + + group_id = f"{config.module_prefix}{name}_{i}_{i + len(current_modules) - 1}" + group = ModuleGroup( + modules=current_modules, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, + offload_leader=current_modules[-1], + onload_leader=current_modules[0], + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, + onload_self=True, + group_id=group_id, + ) + matched_module_groups.append(group) + for j in range(i, i + len(current_modules)): + modules_with_group_offloading.add(f"{name}.{j}") + else: + # This is an unmatched module + unmatched_modules.append((name, submodule)) # Apply group offloading hooks to the module groups for i, group in enumerate(matched_module_groups): @@ -630,28 +684,29 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf parameters = [param for _, param in parameters] buffers = [buffer for _, buffer in buffers] - # Create a group for the unmatched submodules of the top-level module so that they are on the correct - # device when the forward pass is called. + # Create a group for the remaining unmatched submodules of the top-level + # module so that they are on the correct device when the forward pass is called. 
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules] - unmatched_group = ModuleGroup( - modules=unmatched_modules, - offload_device=config.offload_device, - onload_device=config.onload_device, - offload_to_disk_path=config.offload_to_disk_path, - offload_leader=module, - onload_leader=module, - parameters=parameters, - buffers=buffers, - non_blocking=False, - stream=None, - record_stream=False, - onload_self=True, - group_id=f"{module.__class__.__name__}_unmatched_group", - ) - if config.stream is None: - _apply_group_offloading_hook(module, unmatched_group, config=config) - else: - _apply_lazy_group_offloading_hook(module, unmatched_group, config=config) + if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0: + unmatched_group = ModuleGroup( + modules=unmatched_modules, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, + offload_leader=module, + onload_leader=module, + parameters=parameters, + buffers=buffers, + non_blocking=False, + stream=None, + record_stream=False, + onload_self=True, + group_id=f"{config.module_prefix}{module.__class__.__name__}_unmatched_group", + ) + if config.stream is None: + _apply_group_offloading_hook(module, unmatched_group, config=config) + else: + _apply_lazy_group_offloading_hook(module, unmatched_group, config=config) def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index ffc8778e7aca..4096b7c07609 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -72,6 +72,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] + _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"] @register_to_config def __init__( diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index b0b2960aaf18..57284c487e2b 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -619,6 +619,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): feat_idx[0] += 1 else: x = self.conv_out(x) + return x @@ -961,6 +962,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo """ _supports_gradient_checkpointing = False + _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"] # keys to ignore when AlignDeviceHook moves inputs/outputs between devices # these are shared mutable state modified in-place _skip_keys = ["feat_cache", "feat_idx"] @@ -1408,6 +1410,7 @@ def forward( """ x = sample posterior = self.encode(x).latent_dist + if sample_posterior: z = posterior.sample(generator=generator) else: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index f06822c741ca..41da95d3a2a2 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -531,6 +531,8 @@ def enable_group_offload( record_stream: bool = False, low_cpu_mem_usage=False, offload_to_disk_path: Optional[str] = None, + block_modules: Optional[List[str]] = None, + exclude_kwargs: Optional[List[str]] = None, ) -> None: r"""
Activates group offloading for the current model. @@ -570,6 +572,7 @@ def enable_group_offload( f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please " f"open an issue at https://github.com/huggingface/diffusers/issues." ) + apply_group_offloading( module=self, onload_device=onload_device, @@ -581,6 +584,8 @@ def enable_group_offload( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, offload_to_disk_path=offload_to_disk_path, + block_modules=block_modules, + exclude_kwargs=exclude_kwargs, ) def set_attention_backend(self, backend: str) -> None: diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py index 96cbecfbf530..236094109d07 100644 --- a/tests/hooks/test_group_offloading.py +++ b/tests/hooks/test_group_offloading.py @@ -19,6 +19,7 @@ import torch from parameterized import parameterized +from diffusers import AutoencoderKL from diffusers.hooks import HookRegistry, ModelHook from diffusers.models import ModelMixin from diffusers.pipelines.pipeline_utils import DiffusionPipeline @@ -149,6 +150,74 @@ def post_forward(self, module, output): return output +# Model with only standalone computational layers at top level +class DummyModelWithStandaloneLayers(ModelMixin): + def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: + super().__init__() + + self.layer1 = torch.nn.Linear(in_features, hidden_features) + self.activation = torch.nn.ReLU() + self.layer2 = torch.nn.Linear(hidden_features, hidden_features) + self.layer3 = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.layer1(x) + x = self.activation(x) + x = self.layer2(x) + x = self.layer3(x) + return x + + +# Model with deeply nested structure +class DummyModelWithDeeplyNestedBlocks(ModelMixin): + def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: + super().__init__() + + self.input_layer = torch.nn.Linear(in_features, hidden_features) + self.container = ContainerWithNestedModuleList(hidden_features) + self.output_layer = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.input_layer(x) + x = self.container(x) + x = self.output_layer(x) + return x + + +class ContainerWithNestedModuleList(torch.nn.Module): + def __init__(self, features: int) -> None: + super().__init__() + + # Top-level computational layer + self.proj_in = torch.nn.Linear(features, features) + + # Nested container with ModuleList + self.nested_container = NestedContainer(features) + + # Another top-level computational layer + self.proj_out = torch.nn.Linear(features, features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj_in(x) + x = self.nested_container(x) + x = self.proj_out(x) + return x + + +class NestedContainer(torch.nn.Module): + def __init__(self, features: int) -> None: + super().__init__() + + self.blocks = torch.nn.ModuleList([torch.nn.Linear(features, features), torch.nn.Linear(features, features)]) + self.norm = torch.nn.LayerNorm(features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + x = self.norm(x) + return x + + @require_torch_accelerator class GroupOffloadTests(unittest.TestCase): in_features = 64 @@ -340,7 +409,7 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm): out = model(x) self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not 
match.") - num_repeats = 4 + num_repeats = 2 for i in range(num_repeats): out_ref = model_ref(x) out = model(x) @@ -362,3 +431,138 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm): self.assertLess( cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}" ) + + def test_vae_like_model_without_streams(self): + """Test VAE-like model with block-level offloading but without streams.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + config = self.get_autoencoder_kl_config() + model = AutoencoderKL(**config) + + model_ref = AutoencoderKL(**config) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=False) + + x = torch.randn(2, 3, 32, 32).to(torch_device) + + with torch.no_grad(): + out_ref = model_ref(x).sample + out = model(x).sample + + self.assertTrue( + torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams." + ) + + def test_model_with_only_standalone_layers(self): + """Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + for i in range(2): + out_ref = model_ref(x) + out = model(x) + self.assertTrue( + torch.allclose(out_ref, out, atol=1e-5), + f"Outputs do not match at iteration {i} for model with standalone layers.", + ) + + @parameterized.expand([("block_level",), ("leaf_level",)]) + def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str): + """Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + config = self.get_autoencoder_kl_config() + model = AutoencoderKL(**config) + + model_ref = AutoencoderKL(**config) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 3, 32, 32).to(torch_device) + + with torch.no_grad(): + out_ref = model_ref(x).sample + out = model(x).sample + + self.assertTrue( + torch.allclose(out_ref, out, atol=1e-5), + f"Outputs do not match for standalone Conv layers with {offload_type}.", + ) + + def test_multiple_invocations_with_vae_like_model(self): + """Test that multiple forward passes work correctly with VAE-like model.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + config = self.get_autoencoder_kl_config() + model = AutoencoderKL(**config) + + model_ref = AutoencoderKL(**config) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 3, 32, 32).to(torch_device) + + with torch.no_grad(): 
+ for i in range(2): + out_ref = model_ref(x).sample + out = model(x).sample + self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.") + + def test_nested_container_parameters_offloading(self): + """Test that parameters from non-computational layers in nested containers are handled correctly.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + for i in range(2): + out_ref = model_ref(x) + out = model(x) + self.assertTrue( + torch.allclose(out_ref, out, atol=1e-5), + f"Outputs do not match at iteration {i} for nested parameters.", + ) + + def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None): + block_out_channels = block_out_channels or [2, 4] + norm_num_groups = norm_num_groups or 2 + init_dict = { + "block_out_channels": block_out_channels, + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels), + "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels), + "latent_channels": 4, + "norm_num_groups": norm_num_groups, + "layers_per_block": 1, + } + return init_dict diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 6f4c3d544b45..508ea786f42d 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1735,7 +1735,6 @@ def run_forward(model): return model(**inputs_dict)[0] model = self.model_class(**init_dict) - model.to(torch_device) output_without_group_offloading = run_forward(model) @@ -1851,6 +1850,9 @@ def _run_forward(model, inputs_dict): offload_to_disk_path=tmpdir, offload_type=offload_type, num_blocks_per_group=num_blocks_per_group, + block_modules=model._group_offload_block_modules + if hasattr(model, "_group_offload_block_modules") + else None, ) if not is_correct: if extra_files: diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 6ed7e3467d7f..4550813259af 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -1424,6 +1424,8 @@ def _get_expected_safetensors_files( offload_to_disk_path: str, offload_type: str, num_blocks_per_group: Optional[int] = None, + block_modules: Optional[List[str]] = None, + module_prefix: str = "", ) -> Set[str]: expected_files = set() @@ -1435,23 +1437,36 @@ def get_hashed_filename(group_id: str) -> str: if num_blocks_per_group is None: raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.") - # Handle groups of ModuleList and Sequential blocks + block_modules_set = set(block_modules) if block_modules is not None else set() + + modules_with_group_offloading = set() unmatched_modules = [] for name, submodule in module.named_children(): - if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): - unmatched_modules.append(module) - continue + if name in block_modules_set: + new_prefix = f"{module_prefix}{name}." if module_prefix else f"{name}." 
+ submodule_files = _get_expected_safetensors_files( + submodule, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules, new_prefix + ) + expected_files.update(submodule_files) + modules_with_group_offloading.add(name) + + elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + for i in range(0, len(submodule), num_blocks_per_group): + current_modules = submodule[i : i + num_blocks_per_group] + if not current_modules: + continue + group_id = f"{module_prefix}{name}_{i}_{i + len(current_modules) - 1}" + expected_files.add(get_hashed_filename(group_id)) + for j in range(i, i + len(current_modules)): + modules_with_group_offloading.add(f"{name}.{j}") + else: + unmatched_modules.append(submodule) - for i in range(0, len(submodule), num_blocks_per_group): - current_modules = submodule[i : i + num_blocks_per_group] - if not current_modules: - continue - group_id = f"{name}_{i}_{i + len(current_modules) - 1}" - expected_files.add(get_hashed_filename(group_id)) + parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading) + buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading) - # Handle the group for unmatched top-level modules and parameters - for module in unmatched_modules: - expected_files.add(get_hashed_filename(f"{module.__class__.__name__}_unmatched_group")) + if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0: + expected_files.add(get_hashed_filename(f"{module_prefix}{module.__class__.__name__}_unmatched_group")) elif offload_type == "leaf_level": # Handle leaf-level module groups @@ -1492,12 +1507,13 @@ def _check_safetensors_serialization( offload_to_disk_path: str, offload_type: str, num_blocks_per_group: Optional[int] = None, + block_modules: Optional[List[str]] = None, ) -> bool: if not os.path.isdir(offload_to_disk_path): return False, None, None expected_files = _get_expected_safetensors_files( - module, offload_to_disk_path, offload_type, num_blocks_per_group + module, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules ) actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors"))) missing_files = expected_files - actual_files
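To make the motivation for `exclude_kwargs` concrete, here is a minimal sketch (not part of the diff) of why kwargs carrying shared mutable state, such as Wan's `feat_cache`/`feat_idx`, must be kept away from `send_to_device`: the helper rebuilds nested containers, so in-place updates made by submodules would land in a copy rather than in the caller's object. The tensors and key names below are illustrative only; `send_to_device` is assumed to be the same `accelerate` helper the hook uses.

```python
import torch
from accelerate.utils import send_to_device

# A toy "feature cache" in the style of the Wan autoencoder: a list that
# submodules mutate in place across the forward pass.
feat_cache = [torch.randn(1, 4), None]
feat_idx = [0]
kwargs = {"hidden_states": torch.randn(1, 4), "feat_cache": feat_cache, "feat_idx": feat_idx}

# send_to_device rebuilds nested containers, so the returned cache is a new object:
moved = send_to_device(kwargs, "cpu")
print(moved["feat_cache"] is feat_cache)  # False -> in-place updates would be lost

# With the exclude_kwargs mechanism, excluded keys are never passed through
# send_to_device and keep their identity, mirroring the updated pre_forward above:
exclude_kwargs = ["feat_cache", "feat_idx"]
moved_subset = send_to_device({k: v for k, v in kwargs.items() if k not in exclude_kwargs}, "cpu")
kwargs.update(moved_subset)
print(kwargs["feat_cache"] is feat_cache)  # True -> object identity preserved
```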
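A hedged sketch of the user-facing API these changes extend follows. The checkpoint id is a placeholder, and the explicit `block_modules`/`exclude_kwargs` values simply restate the class attributes added by the diff, so in practice they can be omitted.

```python
import torch
from diffusers import AutoencoderKLWan

# Placeholder checkpoint path; substitute any Wan VAE checkpoint.
vae = AutoencoderKLWan.from_pretrained("<repo-id>", subfolder="vae")

# Explicit form: offload the listed top-level blocks, and keep the mutable
# cache kwargs out of send_to_device so their identity is preserved.
vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
    block_modules=["quant_conv", "post_quant_conv", "encoder", "decoder"],
    exclude_kwargs=["feat_cache", "feat_idx"],
)

# Equivalent shorthand: block_modules and exclude_kwargs default to the model's
# _group_offload_block_modules and _skip_keys class attributes added in this diff.
# vae.enable_group_offload(torch.device("cuda"), offload_type="block_level", num_blocks_per_group=1, use_stream=True)
```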
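Finally, a sketch of how a custom model could opt in without passing any extra arguments, relying on the `getattr` fallbacks introduced in `apply_group_offloading`. The class below is hypothetical and only shows where the attributes live.

```python
import torch
from diffusers.models import ModelMixin


class MyTinyVAE(ModelMixin):
    # Picked up by apply_group_offloading when block_modules is not passed explicitly.
    _group_offload_block_modules = ["encoder", "decoder"]
    # Picked up when exclude_kwargs is not passed (mutable kwargs keep their identity).
    _skip_keys = ["feat_cache"]

    def __init__(self, channels: int = 4) -> None:
        super().__init__()
        self.encoder = torch.nn.Sequential(torch.nn.Conv2d(3, channels, 3, padding=1), torch.nn.SiLU())
        self.decoder = torch.nn.Sequential(torch.nn.Conv2d(channels, 3, 3, padding=1))

    def forward(self, x: torch.Tensor, feat_cache=None) -> torch.Tensor:
        return self.decoder(self.encoder(x))


model = MyTinyVAE()
# No block_modules / exclude_kwargs needed: the class attributes are used.
# Assumes a CUDA (or other accelerator) device is available for onloading.
model.enable_group_offload(torch.device("cuda"), offload_type="block_level", num_blocks_per_group=1)
```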