
Commit 6028613

Z-Image-Turbo from_single_file (#12756)
* Z-Image-Turbo `from_single_file`
* compute_dtype
* -device cast
1 parent a1f36ee commit 6028613

3 files changed, +69 -0 lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions
@@ -49,6 +49,7 @@
     convert_stable_cascade_unet_single_file_to_diffusers,
     convert_wan_transformer_to_diffusers,
     convert_wan_vae_to_diffusers,
+    convert_z_image_transformer_checkpoint_to_diffusers,
     create_controlnet_diffusers_config_from_ldm,
     create_unet_diffusers_config_from_ldm,
     create_vae_diffusers_config_from_ldm,
@@ -167,6 +168,10 @@
         "checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "ZImageTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }
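With `ZImageTransformer2DModel` registered in `SINGLE_FILE_LOADABLE_CLASSES`, the original Z-Image-Turbo transformer weights can be loaded directly via `from_single_file`. A minimal sketch, assuming the class is exported from the top-level `diffusers` namespace; the checkpoint URL/filename is illustrative:

import torch
from diffusers import ZImageTransformer2DModel

# Illustrative single-file checkpoint location; substitute the actual file.
ckpt_path = "https://huggingface.co/Tongyi-MAI/Z-Image-Turbo/blob/main/z_image_turbo.safetensors"
transformer = ZImageTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)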

src/diffusers/loaders/single_file_utils.py

Lines changed: 61 additions & 0 deletions
@@ -120,6 +120,7 @@
     "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
     "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
     "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
+    "z-image-turbo": "cap_embedder.0.weight",
     "sana": [
         "blocks.0.cross_attn.q_linear.weight",
         "blocks.0.cross_attn.q_linear.bias",
@@ -218,6 +219,7 @@
     "cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"},
     "cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
     "cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
+    "z-image-turbo": {"pretrained_model_name_or_path": "Tongyi-MAI/Z-Image-Turbo"},
 }

 # Use to configure model sample size when original config is provided
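When no `config` is passed to `from_single_file`, the inferred model type is used to fetch a diffusers config from the Hub through this mapping. A minimal sketch of what the new entry resolves to, using the internal module path shown in this diff:

from diffusers.loaders.single_file_utils import DIFFUSERS_DEFAULT_PIPELINE_PATHS

# The "z-image-turbo" entry points at the official repo, so the config is
# pulled from Tongyi-MAI/Z-Image-Turbo when none is supplied explicitly.
print(DIFFUSERS_DEFAULT_PIPELINE_PATHS["z-image-turbo"]["pretrained_model_name_or_path"])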
@@ -721,6 +723,12 @@ def infer_diffusers_model_type(checkpoint):
     ):
         model_type = "instruct-pix2pix"

+    elif (
+        CHECKPOINT_KEY_NAMES["z-image-turbo"] in checkpoint
+        and checkpoint[CHECKPOINT_KEY_NAMES["z-image-turbo"]].shape[0] == 2560
+    ):
+        model_type = "z-image-turbo"
+
     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
         model_type = "lumina2"
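Both Lumina 2 and Z-Image-Turbo checkpoints expose a `cap_embedder.0.weight` key, so the new branch is placed before the `lumina2` check and additionally requires the first dimension of that tensor to be 2560. A minimal sketch of the check in isolation, using an illustrative dummy tensor:

import torch

# Dummy checkpoint containing only the key the detection branch looks at;
# the second dimension is illustrative.
checkpoint = {"cap_embedder.0.weight": torch.zeros(2560, 1024)}

is_z_image_turbo = (
    "cap_embedder.0.weight" in checkpoint
    and checkpoint["cap_embedder.0.weight"].shape[0] == 2560
)
print(is_z_image_turbo)  # True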

@@ -3824,3 +3832,56 @@ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str)
             handler_fn_inplace(key, converted_state_dict)

     return converted_state_dict
+
+
+def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    Z_IMAGE_KEYS_RENAME_DICT = {
+        "final_layer.": "all_final_layer.2-1.",
+        "x_embedder.": "all_x_embedder.2-1.",
+        ".attention.out.bias": ".attention.to_out.0.bias",
+        ".attention.k_norm.weight": ".attention.norm_k.weight",
+        ".attention.q_norm.weight": ".attention.norm_q.weight",
+        ".attention.out.weight": ".attention.to_out.0.weight",
+    }
+
+    def convert_z_image_fused_attention(key: str, state_dict: dict[str, object]) -> None:
+        if ".attention.qkv.weight" not in key:
+            return
+
+        fused_qkv_weight = state_dict.pop(key)
+        to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
+        new_q_name = key.replace(".attention.qkv.weight", ".attention.to_q.weight")
+        new_k_name = key.replace(".attention.qkv.weight", ".attention.to_k.weight")
+        new_v_name = key.replace(".attention.qkv.weight", ".attention.to_v.weight")
+
+        state_dict[new_q_name] = to_q_weight
+        state_dict[new_k_name] = to_k_weight
+        state_dict[new_v_name] = to_v_weight
+        return
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP = {
+        ".attention.qkv.weight": convert_z_image_fused_attention,
+    }
+
+    def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None:
+        state_dict[new_key] = state_dict.pop(old_key)
+
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+    # Handle single file --> diffusers key remapping via the remap dict
+    for key in list(converted_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in Z_IMAGE_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+
+        update_state_dict(converted_state_dict, key, new_key)
+
+    # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
+    # special_keys_remap
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
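A standalone usage sketch of the converter, loading an original checkpoint with safetensors; the local filename is illustrative:

from safetensors.torch import load_file
from diffusers.loaders.single_file_utils import convert_z_image_transformer_checkpoint_to_diffusers

# Illustrative local path to an original Z-Image-Turbo transformer checkpoint.
state_dict = load_file("z_image_turbo.safetensors")

# Note: the converter pops every key, so it empties the input dict.
converted = convert_z_image_transformer_checkpoint_to_diffusers(state_dict)

# Fused attention.qkv weights are split into separate to_q/to_k/to_v entries.
print(any(key.endswith(".attention.to_q.weight") for key in converted))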

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 3 additions & 0 deletions
@@ -63,8 +63,11 @@ def timestep_embedding(t, dim, max_period=10000):
     def forward(self, t):
         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
         weight_dtype = self.mlp[0].weight.dtype
+        compute_dtype = getattr(self.mlp[0], "compute_dtype", None)
         if weight_dtype.is_floating_point:
             t_freq = t_freq.to(weight_dtype)
+        elif compute_dtype is not None:
+            t_freq = t_freq.to(compute_dtype)
         t_emb = self.mlp(t_freq)
         return t_emb
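The fallback matters for quantized checkpoints: a packed integer weight dtype is not floating point, so `t_freq` would otherwise keep its original dtype and mismatch the layer's compute precision. A minimal sketch of the dtype selection, assuming a stand-in layer that mimics a quantized linear exposing `compute_dtype` (as bitsandbytes layers do):

import torch

class FakeQuantizedLinear(torch.nn.Module):
    # Stand-in for a quantized layer: packed integer weights plus a compute_dtype attribute.
    def __init__(self):
        super().__init__()
        self.weight = torch.zeros(8, 8, dtype=torch.uint8)
        self.compute_dtype = torch.bfloat16

layer = FakeQuantizedLinear()
t_freq = torch.randn(4, 8)  # sinusoidal timestep embedding, float32

weight_dtype = layer.weight.dtype
compute_dtype = getattr(layer, "compute_dtype", None)
if weight_dtype.is_floating_point:
    t_freq = t_freq.to(weight_dtype)
elif compute_dtype is not None:
    t_freq = t_freq.to(compute_dtype)

print(t_freq.dtype)  # torch.bfloat16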
