120 | 120 | "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias", |
121 | 121 | "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight", |
122 | 122 | "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"], |
| 123 | + "z-image-turbo": "cap_embedder.0.weight", |
123 | 124 | "sana": [ |
124 | 125 | "blocks.0.cross_attn.q_linear.weight", |
125 | 126 | "blocks.0.cross_attn.q_linear.bias", |
218 | 219 | "cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"}, |
219 | 220 | "cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"}, |
220 | 221 | "cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"}, |
| 222 | + "z-image-turbo": {"pretrained_model_name_or_path": "Tongyi-MAI/Z-Image-Turbo"}, |
221 | 223 | } |
222 | 224 |
223 | 225 | # Use to configure model sample size when original config is provided |
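A hedged sketch of how the two tables touched above cooperate at load time. `resolve_default_pipeline` below is a hypothetical helper, not part of this PR; it only illustrates the flow from key-based detection to the default Hub repo.

```python
# Illustrative only. `infer_diffusers_model_type` probes the raw state dict
# for the model-specific keys registered in CHECKPOINT_KEY_NAMES; the
# inferred type then indexes DIFFUSERS_DEFAULT_PIPELINE_PATHS to find the
# Hub repo that supplies the diffusers-format config.
def resolve_default_pipeline(checkpoint: dict) -> dict:  # hypothetical helper
    model_type = infer_diffusers_model_type(checkpoint)  # e.g. "z-image-turbo"
    return DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]
    # -> {"pretrained_model_name_or_path": "Tongyi-MAI/Z-Image-Turbo"}
```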
@@ -721,6 +723,12 @@ def infer_diffusers_model_type(checkpoint):
721 | 723 | ): |
722 | 724 | model_type = "instruct-pix2pix" |
723 | 725 |
| 726 | + elif ( |
| 727 | + CHECKPOINT_KEY_NAMES["z-image-turbo"] in checkpoint |
| 728 | + and checkpoint[CHECKPOINT_KEY_NAMES["z-image-turbo"]].shape[0] == 2560 |
| 729 | + ): |
| 730 | + model_type = "z-image-turbo" |
| 731 | + |
724 | 732 | elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]): |
725 | 733 | model_type = "lumina2" |
726 | 734 |
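Z-Image-Turbo shares the `cap_embedder.0.weight` key with Lumina2, which is why the new branch runs before the `lumina2` check and disambiguates on the tensor's leading dimension (2560, presumably the caption-feature width of Z-Image's text encoder). A minimal standalone sketch of the same test, using a dummy tensor whose shape is illustrative only:

```python
import torch

# Dummy state dict; the leading dim of 2560 is what the branch keys on.
# The 1-D norm-weight shape here is an assumption for the example.
checkpoint = {"cap_embedder.0.weight": torch.zeros(2560)}

key = "cap_embedder.0.weight"
if key in checkpoint and checkpoint[key].shape[0] == 2560:
    model_type = "z-image-turbo"  # Lumina2 checkpoints have a different width here
```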
@@ -3824,3 +3832,56 @@ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None:
3824 | 3832 | handler_fn_inplace(key, converted_state_dict) |
3825 | 3833 |
3826 | 3834 | return converted_state_dict |
| 3835 | + |
| 3836 | + |
| 3837 | +def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): |
| 3838 | + Z_IMAGE_KEYS_RENAME_DICT = { |
| 3839 | + "final_layer.": "all_final_layer.2-1.", |
| 3840 | + "x_embedder.": "all_x_embedder.2-1.", |
| 3841 | + ".attention.out.bias": ".attention.to_out.0.bias", |
| 3842 | + ".attention.k_norm.weight": ".attention.norm_k.weight", |
| 3843 | + ".attention.q_norm.weight": ".attention.norm_q.weight", |
| 3844 | + ".attention.out.weight": ".attention.to_out.0.weight", |
| 3845 | + } |
| 3846 | + |
| 3847 | + def convert_z_image_fused_attention(key: str, state_dict: dict[str, object]) -> None: |
| 3848 | + if ".attention.qkv.weight" not in key: |
| 3849 | + return |
| 3850 | + |
| 3851 | + fused_qkv_weight = state_dict.pop(key) |
| 3852 | + to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0) |
| 3853 | + new_q_name = key.replace(".attention.qkv.weight", ".attention.to_q.weight") |
| 3854 | + new_k_name = key.replace(".attention.qkv.weight", ".attention.to_k.weight") |
| 3855 | + new_v_name = key.replace(".attention.qkv.weight", ".attention.to_v.weight") |
| 3856 | + |
| 3857 | + state_dict[new_q_name] = to_q_weight |
| 3858 | + state_dict[new_k_name] = to_k_weight |
| 3859 | + state_dict[new_v_name] = to_v_weight |
| 3860 | + return |
| 3861 | + |
| 3862 | + TRANSFORMER_SPECIAL_KEYS_REMAP = { |
| 3863 | + ".attention.qkv.weight": convert_z_image_fused_attention, |
| 3864 | + } |
| 3865 | + |
| 3866 | + def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None: |
| 3867 | + state_dict[new_key] = state_dict.pop(old_key) |
| 3868 | + |
| 3869 | + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} |
| 3870 | + |
| 3871 | + # Handle single file --> diffusers key remapping via the remap dict |
| 3872 | + for key in list(converted_state_dict.keys()): |
| 3873 | + new_key = key[:] |
| 3874 | + for replace_key, rename_key in Z_IMAGE_KEYS_RENAME_DICT.items(): |
| 3875 | + new_key = new_key.replace(replace_key, rename_key) |
| 3876 | + |
| 3877 | + update_state_dict(converted_state_dict, key, new_key) |
| 3878 | + |
| 3879 | + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in |
| 3880 | + # special_keys_remap |
| 3881 | + for key in list(converted_state_dict.keys()): |
| 3882 | + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): |
| 3883 | + if special_key not in key: |
| 3884 | + continue |
| 3885 | + handler_fn_inplace(key, converted_state_dict) |
| 3886 | + |
| 3887 | + return converted_state_dict |
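A hedged usage sketch for the converter above, run on a tiny fabricated checkpoint: it exercises the `final_layer.` and `attention.out.` renames plus the fused-QKV split performed by the special-key handler. All key prefixes and shapes below are made up for the example; real Z-Image checkpoints are far larger.

```python
import torch

dim = 8  # toy hidden size, illustration only
fake_checkpoint = {
    "layers.0.attention.qkv.weight": torch.randn(3 * dim, dim),  # fused q/k/v
    "layers.0.attention.out.weight": torch.randn(dim, dim),
    "final_layer.linear.weight": torch.randn(dim, dim),
}

converted = convert_z_image_transformer_checkpoint_to_diffusers(fake_checkpoint)

# torch.chunk split the fused projection into three equal (dim, dim) tensors
assert converted["layers.0.attention.to_q.weight"].shape == (dim, dim)
assert "layers.0.attention.to_k.weight" in converted
assert "layers.0.attention.to_v.weight" in converted
# simple 1:1 renames from Z_IMAGE_KEYS_RENAME_DICT
assert "layers.0.attention.to_out.0.weight" in converted
assert "all_final_layer.2-1.linear.weight" in converted
```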