Skip to content

Commit 3e472ac

Browse files
committed
transformer weights
1 parent f9540cb commit 3e472ac

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

src/diffusers/models/transformers/transformer_z_image_control.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,29 @@ def from_controlnet(
468468
if not load_weights:
469469
return model
470470

471+
for key, all_x_embedder in transformer.all_x_embedder.items():
472+
model.all_x_embedder[key].load_state_dict(all_x_embedder.state_dict())
473+
474+
for key, all_final_layer in transformer.all_final_layer.items():
475+
model.all_final_layer[key].load_state_dict(all_final_layer.state_dict())
476+
477+
for i, noise_refiner in enumerate(transformer.noise_refiner):
478+
model.noise_refiner[i].load_state_dict(noise_refiner.state_dict())
479+
480+
for i, context_refiner in enumerate(transformer.context_refiner):
481+
model.context_refiner[i].load_state_dict(context_refiner.state_dict())
482+
483+
model.t_embedder.load_state_dict(transformer.t_embedder.state_dict())
484+
485+
model.cap_embedder.load_state_dict(transformer.cap_embedder.state_dict())
486+
487+
model.x_pad_token = transformer.x_pad_token
488+
489+
model.cap_pad_token = transformer.cap_pad_token
490+
491+
for i, layer in enumerate(transformer.layers):
492+
model.layers[i].load_state_dict(layer.state_dict())
493+
471494
for i, control_layer in enumerate(controlnet.control_layers):
472495
model.control_layers[i].load_state_dict(control_layer.state_dict())
473496

0 commit comments

Comments
 (0)