Merged
52 changes: 52 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py
@@ -2283,6 +2283,49 @@ def unload_lora_weights(self):
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
transformer._transformer_norm_layers = None

if getattr(transformer, "_overwritten_params", None) is not None:
overwritten_params = transformer._overwritten_params
module_names = set()

for param_name in overwritten_params:
if param_name.endswith(".weight"):
module_names.add(param_name.replace(".weight", ""))

for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear) and name in module_names:
module_weight = module.weight.data
module_bias = module.bias.data if module.bias is not None else None
bias = module_bias is not None

parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)

current_param_weight = overwritten_params[f"{name}.weight"]
in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
with torch.device("meta"):
Member Author: Since we already pin the torch version, this is safe enough. (A standalone sketch of the torch.device("meta") pattern appears below, after this method's additions.)

original_module = torch.nn.Linear(
in_features,
out_features,
bias=bias,
device=module_weight.device,
dtype=module_weight.dtype,
)

original_module.weight.data.copy_(current_param_weight)
if module_bias is not None:
original_module.bias.data.copy_(overwritten_params[f"{name}.bias"])

setattr(parent_module, current_module_name, original_module)

if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
new_value = int(current_param_weight.shape[1])
old_value = getattr(transformer.config, attribute_name)
setattr(transformer.config, attribute_name, new_value)
logger.info(
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
)
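For context on the comment above about the pinned torch version: the snippet below is a minimal, standalone sketch, not part of this PR, of the torch.device("meta") pattern. A module built under that context carries only shapes and dtypes until real tensors are assigned; the PR code above instead passes an explicit device= and copies into .data, which restores the weights directly. The load_state_dict(..., assign=True) call in the sketch assumes a torch version recent enough to support it.

import torch

# Build a replacement Linear under the meta device: parameters get shape/dtype but no storage.
with torch.device("meta"):
    restored = torch.nn.Linear(64, 128, bias=True, dtype=torch.float32)
assert restored.weight.device.type == "meta"

# Materialize it by assigning cached pre-expansion tensors; assign=True swaps the tensors in
# instead of copying into the (non-existent) meta storage.
cached = {"weight": torch.randn(128, 64), "bias": torch.randn(128)}
restored.load_state_dict(cached, assign=True)
assert restored.weight.device.type == "cpu"

Conceptually this mirrors the new unload path: the pre-expansion weight (and bias) kept in _overwritten_params are written back into a freshly constructed Linear, and the matching config attribute (for example in_channels) is reset afterwards.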

@classmethod
def _maybe_expand_transformer_param_shape_or_error_(
cls,
@@ -2309,6 +2352,7 @@ def _maybe_expand_transformer_param_shape_or_error_(

# Expand transformer parameter shapes if they don't match lora
has_param_with_shape_update = False
overwritten_params = {}

for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear):
@@ -2378,6 +2422,14 @@ def _maybe_expand_transformer_param_shape_or_error_(
setattr(transformer.config, attribute_name, new_value)
logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.")

# For `unload_lora_weights()`.
overwritten_params[f"{current_module_name}.weight"] = module_weight
Contributor: I think this would have a small but significant memory overhead. For inference-only use with LoRAs, maybe this could be made opt-out if we know we will never call unload_lora_weights(). Not a blocker, though; it can be tackled in a different PR, but let me know your thoughts.

Member Author (@sayakpaul, Dec 17, 2024): Yeah, this could be tackled with discard_original_layers. For now, I have added a note as a comment about it. (A hypothetical sketch of such an opt-out appears below, after this method.)

if module_bias is not None:
overwritten_params[f"{current_module_name}.bias"] = module_bias

if len(overwritten_params) > 0:
transformer._overwritten_params = overwritten_params

return has_param_with_shape_update
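On the memory-overhead discussion above: the tensors cached in _overwritten_params keep the pre-expansion weights alive for the lifetime of the model, which is the cost the reviewer points out. The snippet below is a hypothetical sketch, not part of this PR, of what an opt-out along the lines of the discard_original_layers idea could look like; the flag name and the helper function are assumptions for illustration only.

import torch

def maybe_cache_original_params(
    name: str,
    module: torch.nn.Linear,
    overwritten_params: dict,
    discard_original_layers: bool = False,
) -> None:
    # Hypothetical opt-out: if the caller knows unload_lora_weights() will never be needed,
    # skip caching and avoid keeping one extra copy of every expanded Linear's original weight.
    if discard_original_layers:
        return
    overwritten_params[f"{name}.weight"] = module.weight.data
    if module.bias is not None:
        overwritten_params[f"{name}.bias"] = module.bias.data

The trade-off is that, with such a flag enabled, the expanded shapes and the updated config attributes (such as in_channels) could no longer be rolled back by unload_lora_weights().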


66 changes: 66 additions & 0 deletions tests/lora/test_lora_layers_flux.py
@@ -546,6 +546,72 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self):
"adapter-2",
)

def test_lora_unload_with_parameter_expanded_shapes(self):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)

logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.DEBUG)

# Change the transformer config to mimic a real use case.
num_channels_without_control = 4
transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device)
self.assertTrue(
transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
)

# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
components["transformer"] = transformer
pipe = FluxPipeline(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

_, _, inputs = self.get_dummy_inputs(with_generator=False)
control_image = inputs.pop("control_image")
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]

control_pipe = self.pipeline_class(**components)
out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
rank = 4

dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
}
with CaptureLogger(logger) as cap_logger:
control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

inputs["control_image"] = control_image
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))

control_pipe.unload_lora_weights()
self.assertTrue(
control_pipe.transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
)
loaded_pipe = FluxPipeline.from_pipe(control_pipe)
self.assertTrue(
loaded_pipe.transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
)
inputs.pop("control_image")
unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
self.assertTrue(pipe.transformer.config.in_channels == in_features)

@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass