diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index d18c82df4f62..0ee32f820b56 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -1064,6 +1064,41 @@ def save_function(weights, filename):
         save_function(state_dict, save_path)
         logger.info(f"Model weights saved in {save_path}")
 
+    @classmethod
+    def _save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]],
+        lora_metadata: Dict[str, Optional[dict]],
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        """
+        Helper method to pack and save LoRA weights and metadata. This method centralizes the saving logic for all
+        pipeline types.
+        """
+        state_dict = {}
+        final_lora_adapter_metadata = {}
+
+        for prefix, layers in lora_layers.items():
+            state_dict.update(cls.pack_weights(layers, prefix))
+
+        for prefix, metadata in lora_metadata.items():
+            if metadata:
+                final_lora_adapter_metadata.update(_pack_dict_with_prefix(metadata, prefix))
+
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+            lora_adapter_metadata=final_lora_adapter_metadata if final_lora_adapter_metadata else None,
+        )
+
     @classmethod
     def _optionally_disable_offloading(cls, _pipeline):
         return _func_optionally_disable_offloading(_pipeline=_pipeline)
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 7e89066f1fc4..8060b519f147 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -510,35 +510,28 @@ def save_lora_weights(
             text_encoder_lora_adapter_metadata:
                 LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
""" - state_dict = {} - lora_adapter_metadata = {} - - if not (unet_lora_layers or text_encoder_lora_layers): - raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.") + lora_layers = {} + lora_metadata = {} if unet_lora_layers: - state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name)) + lora_layers[cls.unet_name] = unet_lora_layers + lora_metadata[cls.unet_name] = unet_lora_adapter_metadata if text_encoder_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + lora_layers[cls.text_encoder_name] = text_encoder_lora_layers + lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata - if unet_lora_adapter_metadata: - lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name)) - - if text_encoder_lora_adapter_metadata: - lora_adapter_metadata.update( - _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) - ) + if not lora_layers: + raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.") - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -1004,44 +997,34 @@ def save_lora_weights( text_encoder_2_lora_adapter_metadata: LoRA adapter metadata associated with the second text encoder to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} - - if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): - raise ValueError( - "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`." - ) + lora_layers = {} + lora_metadata = {} if unet_lora_layers: - state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name)) + lora_layers[cls.unet_name] = unet_lora_layers + lora_metadata[cls.unet_name] = unet_lora_adapter_metadata if text_encoder_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) + lora_layers["text_encoder"] = text_encoder_lora_layers + lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata if text_encoder_2_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + lora_layers["text_encoder_2"] = text_encoder_2_lora_layers + lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata - if unet_lora_adapter_metadata is not None: - lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name)) - - if text_encoder_lora_adapter_metadata: - lora_adapter_metadata.update( - _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) - ) - - if text_encoder_2_lora_adapter_metadata: - lora_adapter_metadata.update( - _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2") + if not lora_layers: + raise ValueError( + "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`." 
             )
 
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )
 
     def fuse_lora(
@@ -1467,46 +1450,34 @@ def save_lora_weights(
             text_encoder_2_lora_adapter_metadata:
                 LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
-            raise ValueError(
-                "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
-            )
+        lora_layers = {}
+        lora_metadata = {}
 
         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
 
         if text_encoder_lora_layers:
-            state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+            lora_layers["text_encoder"] = text_encoder_lora_layers
+            lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata
 
         if text_encoder_2_lora_layers:
-            state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
-
-        if transformer_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
-
-        if text_encoder_lora_adapter_metadata:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
-            )
+            lora_layers["text_encoder_2"] = text_encoder_2_lora_layers
+            lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata
 
-        if text_encoder_2_lora_adapter_metadata:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
+        if not lora_layers:
+            raise ValueError(
+                "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`."
             )
 
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )
 
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
@@ -1830,28 +1801,24 @@ def save_lora_weights(
             transformer_lora_adapter_metadata:
                 LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") - - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora @@ -2435,37 +2402,28 @@ def save_lora_weights( text_encoder_lora_adapter_metadata: LoRA adapter metadata associated with the text encoder to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} - - if not (transformer_lora_layers or text_encoder_lora_layers): - raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") + lora_layers = {} + lora_metadata = {} if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata if text_encoder_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + lora_layers[cls.text_encoder_name] = text_encoder_lora_layers + lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata - if transformer_lora_adapter_metadata: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) - - if text_encoder_lora_adapter_metadata: - lora_adapter_metadata.update( - _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) - ) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -3254,28 +3212,24 @@ def save_lora_weights( transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. 
""" - state_dict = {} - lora_adapter_metadata = {} - - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + lora_layers = {} + lora_metadata = {} - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -3594,28 +3548,24 @@ def save_lora_weights( transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") - - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -3938,28 +3888,24 @@ def save_lora_weights( transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. 
""" - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") - - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -4280,28 +4226,24 @@ def save_lora_weights( transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} - - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + lora_layers = {} + lora_metadata = {} - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -4624,28 +4566,24 @@ def save_lora_weights( transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. 
""" - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") - - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -4969,28 +4907,24 @@ def save_lora_weights( transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} - - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + lora_layers = {} + lora_metadata = {} - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora @@ -5384,28 +5318,24 @@ def save_lora_weights( transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. 
""" - state_dict = {} - lora_adapter_metadata = {} - - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + lora_layers = {} + lora_metadata = {} - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -5802,28 +5732,24 @@ def save_lora_weights( transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") - - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -6144,28 +6070,24 @@ def save_lora_weights( transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. 
""" - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") - - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -6488,28 +6410,24 @@ def save_lora_weights( transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} - - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + lora_layers = {} + lora_metadata = {} - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora @@ -6835,28 +6753,24 @@ def save_lora_weights( transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. 
""" - state_dict = {} - lora_adapter_metadata = {} - - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + lora_layers = {} + lora_metadata = {} - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora