From 2354fda9c0d66ae1e1606463a767b7b8173ff73a Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 17:44:12 +0000 Subject: [PATCH 01/15] init --- ...convert_z_image_controlnet_to_diffusers.py | 103 +++ src/diffusers/models/controlnets/__init__.py | 1 + .../models/controlnets/controlnet_z_image.py | 528 ++++++++++++++ .../transformers/transformer_z_image.py | 11 +- .../z_image/pipeline_z_image_controlnet.py | 674 ++++++++++++++++++ 5 files changed, 1315 insertions(+), 2 deletions(-) create mode 100644 scripts/convert_z_image_controlnet_to_diffusers.py create mode 100644 src/diffusers/models/controlnets/controlnet_z_image.py create mode 100644 src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py diff --git a/scripts/convert_z_image_controlnet_to_diffusers.py b/scripts/convert_z_image_controlnet_to_diffusers.py new file mode 100644 index 000000000000..c4b96cda02af --- /dev/null +++ b/scripts/convert_z_image_controlnet_to_diffusers.py @@ -0,0 +1,103 @@ +import argparse +from contextlib import nullcontext + +import torch +import safetensors.torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download + +from diffusers.utils.import_utils import is_accelerate_available +from diffusers.models import ZImageTransformer2DModel +from diffusers.models.controlnets.controlnet_z_image import ZImageControlNetModel + +""" +python scripts/convert_z_image_controlnet_to_diffusers.py \ +--original_z_image_repo_id "Tongyi-MAI/Z-Image-Turbo" \ +--original_controlnet_repo_id "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union" \ +--filename "Z-Image-Turbo-Fun-Controlnet-Union.safetensors" +--output_path "z-image-controlnet-hf/" +""" + + +CTX = init_empty_weights if is_accelerate_available else nullcontext + +parser = argparse.ArgumentParser() +parser.add_argument("--original_z_image_repo_id", default="Tongyi-MAI/Z-Image-Turbo", type=str) +parser.add_argument("--original_controlnet_repo_id", default=None, type=str) +parser.add_argument("--filename", default="Z-Image-Turbo-Fun-Controlnet-Union.safetensors", type=str) +parser.add_argument("--checkpoint_path", default=None, type=str) +parser.add_argument("--output_path", type=str) + +args = parser.parse_args() + + +def load_original_checkpoint(args): + if args.original_controlnet_repo_id is not None: + ckpt_path = hf_hub_download(repo_id=args.original_controlnet_repo_id, filename=args.filename) + elif args.checkpoint_path is not None: + ckpt_path = args.checkpoint_path + else: + raise ValueError(" please provide either `original_controlnet_repo_id` or a local `checkpoint_path`") + + original_state_dict = safetensors.torch.load_file(ckpt_path) + return original_state_dict + +def load_z_image(args): + model = ZImageTransformer2DModel.from_pretrained(args.original_z_image_repo_id, subfolder="transformer", torch_dtype=torch.bfloat16) + return model.state_dict(), model.config + +def convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_state_dict): + converted_state_dict = {} + + converted_state_dict.update(original_state_dict) + + to_copy = {"all_x_embedder.", "noise_refiner.", "context_refiner.", "t_embedder.", "cap_embedder.", "x_pad_token", "cap_pad_token"} + + for key in z_image.keys(): + for copy_key in to_copy: + if key.startswith(copy_key): + converted_state_dict[key] = z_image[key] + + return converted_state_dict + + +def main(args): + original_ckpt = load_original_checkpoint(args) + z_image, config = load_z_image(args) + + control_in_dim = 16 + control_layers_places = [0, 5, 10, 15, 20, 25] + + 
converted_controlnet_state_dict = convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_ckpt) + + for key, tensor in converted_controlnet_state_dict.items(): + print(f"{key} - {tensor.dtype}") + + controlnet = ZImageControlNetModel( + all_patch_size=config["all_patch_size"], + all_f_patch_size=config["all_f_patch_size"], + in_channels=config["in_channels"], + dim=config["dim"], + n_layers=config["n_layers"], + n_refiner_layers=config["n_refiner_layers"], + n_heads=config["n_heads"], + n_kv_heads=config["n_kv_heads"], + norm_eps=config["norm_eps"], + qk_norm=config["qk_norm"], + cap_feat_dim=config["cap_feat_dim"], + rope_theta=config["rope_theta"], + t_scale=config["t_scale"], + axes_dims=config["axes_dims"], + axes_lens=config["axes_lens"], + control_layers_places=control_layers_places, + control_in_dim=control_in_dim, + ) + missing, unexpected = controlnet.load_state_dict(converted_controlnet_state_dict) + print(f"{missing=}") + print(f"{unexpected=}") + print("Saving Z-Image ControlNet in Diffusers format") + controlnet.save_pretrained(args.output_path, max_shard_size="5GB") + + +if __name__ == "__main__": + main(args) diff --git a/src/diffusers/models/controlnets/__init__.py b/src/diffusers/models/controlnets/__init__.py index 7ce352879daa..fee7f231e899 100644 --- a/src/diffusers/models/controlnets/__init__.py +++ b/src/diffusers/models/controlnets/__init__.py @@ -19,6 +19,7 @@ ) from .controlnet_union import ControlNetUnionModel from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel + from .controlnet_z_image import ZImageControlNetModel from .multicontrolnet import MultiControlNetModel from .multicontrolnet_union import MultiControlNetUnionModel diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py new file mode 100644 index 000000000000..d6cede86812d --- /dev/null +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -0,0 +1,528 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List + +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...models.normalization import RMSNorm +from ..controlnets.controlnet import zero_module +from ..modeling_utils import ModelMixin +from ..transformers.transformer_z_image import ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM + + +class ZImageControlTransformerBlock(ZImageTransformerBlock): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation=True, + block_id=0 + ): + super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation) + self.block_id = block_id + if block_id == 0: + self.before_proj = zero_module(nn.Linear(self.dim, self.dim)) + self.after_proj = zero_module(nn.Linear(self.dim, self.dim)) + + def forward(self, c: torch.Tensor, x: torch.Tensor, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + all_c += [c_skip, c] + c = torch.stack(all_c) + return c + +class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + all_patch_size=(2,), + all_f_patch_size=(1,), + in_channels=16, + dim=3840, + n_layers=30, + n_refiner_layers=2, + n_heads=30, + n_kv_heads=30, + norm_eps=1e-5, + qk_norm=True, + cap_feat_dim=2560, + rope_theta=256.0, + t_scale=1000.0, + axes_dims=[32, 48, 48], + axes_lens=[1024, 512, 512], + control_layers_places: List[int]=None, + control_in_dim=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.all_patch_size = all_patch_size + self.all_f_patch_size = all_f_patch_size + self.dim = dim + self.n_heads = n_heads + + self.rope_theta = rope_theta + self.t_scale = t_scale + self.gradient_checkpointing = False + self.n_layers = n_layers + + assert len(all_patch_size) == len(all_f_patch_size) + + all_x_embedder = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + self.all_x_embedder = nn.ModuleDict(all_x_embedder) + self.noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) + self.cap_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), + nn.Linear(cap_feat_dim, dim, bias=True), + ) + + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) + + self.axes_dims = axes_dims + self.axes_lens = axes_lens + + self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + + ## Original Control layers + + self.control_layers_places = control_layers_places + 
self.control_in_dim = control_in_dim + + assert 0 in self.control_layers_places + + # control blocks + self.control_layers = nn.ModuleList( + [ + ZImageControlTransformerBlock( + i, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + block_id=i + ) + for i in self.control_layers_places + ] + ) + + # control patch embeddings + all_x_embedder = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) + self.control_noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + @staticmethod + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + + axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + def patchify( + self, + all_image: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + cap_padding_len: int, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + + for i, image in enumerate(all_image): + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + all_image_pos_ids.append(image_padded_pos_ids) + # pad mask + all_image_pad_mask.append( + torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return ( + all_image_out, + all_image_size, + all_image_pos_ids, + all_image_pad_mask, + ) + + def patchify_and_embed( + self, + all_image: List[torch.Tensor], + all_cap_feats: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + all_cap_pos_ids = [] + all_cap_pad_mask = [] + all_cap_feats_out = [] + + for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): + ### Process Caption + cap_ori_len = len(cap_feat) + cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF + # padded position ids + cap_padded_pos_ids = 
self.create_coordinate_grid( + size=(cap_ori_len + cap_padding_len, 1, 1), + start=(1, 0, 0), + device=device, + ).flatten(0, 2) + all_cap_pos_ids.append(cap_padded_pos_ids) + # pad mask + all_cap_pad_mask.append( + torch.cat( + [ + torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), + torch.ones((cap_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + cap_padded_feat = torch.cat( + [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], + dim=0, + ) + all_cap_feats_out.append(cap_padded_feat) + + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_ori_len + cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + all_image_pos_ids.append(image_padded_pos_ids) + # pad mask + all_image_pad_mask.append( + torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return ( + all_image_out, + all_cap_feats_out, + all_image_size, + all_image_pos_ids, + all_cap_pos_ids, + all_image_pad_mask, + all_cap_pad_mask, + ) + + def forward( + self, + x: List[torch.Tensor], + cap_feats: List[torch.Tensor], + control_context: List[torch.Tensor], + t=None, + patch_size=2, + f_patch_size=1, + conditioning_scale: float = 1.0, + ): + assert patch_size in self.all_patch_size + assert f_patch_size in self.all_f_patch_size + + bsz = len(x) + device = x[0].device + t = t * self.t_scale + t = self.t_embedder(t) + + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_inner_pad_mask, + cap_inner_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + + # x embed & refine + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + x = torch.cat(x, dim=0) + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + + # Match t_embedder output dtype to x for layerwise casting compatibility + adaln_input = t.type_as(x) + x[torch.cat(x_inner_pad_mask)] = self.x_pad_token + x = list(x.split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) + + x = pad_sequence(x, batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.noise_refiner: + x = 
self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input) + else: + for layer in self.noise_refiner: + x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) + + # cap embed & refine + cap_item_seqlens = [len(_) for _ in cap_feats] + assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens) + cap_max_item_seqlen = max(cap_item_seqlens) + + cap_feats = torch.cat(cap_feats, dim=0) + cap_feats = self.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token + cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)) + + cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) + cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) + cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(cap_item_seqlens): + cap_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.context_refiner: + cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) + else: + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) + + # unified + unified = [] + unified_freqs_cis = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) + unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) + unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] + assert unified_item_seqlens == [len(_) for _ in unified] + unified_max_item_seqlen = max(unified_item_seqlens) + + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) + unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_item_seqlens): + unified_attn_mask[i, :seq_len] = 1 + + ## Original forward_control + + # embeddings + bsz = len(control_context) + device = control_context[0].device + ( + control_context, + x_size, + x_pos_ids, + x_inner_pad_mask, + ) = self.patchify(control_context, patch_size, f_patch_size, cap_feats[0].size(0)) + + # control_context embed & refine + x_item_seqlens = [len(_) for _ in control_context] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + control_context = torch.cat(control_context, dim=0) + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) + + # Match t_embedder output dtype to control_context for layerwise casting compatibility + adaln_input = t.type_as(control_context) + control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token + control_context = list(control_context.split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) + + control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in 
self.control_noise_refiner: + control_context = self._gradient_checkpointing_func(layer, control_context, x_attn_mask, x_freqs_cis, adaln_input) + else: + for layer in self.control_noise_refiner: + control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input) + + # unified + cap_item_seqlens = [len(_) for _ in cap_feats] + control_context_unified = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]])) + control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) + c = control_context_unified + + new_kwargs = dict(x=unified, attn_mask=unified_attn_mask, freqs_cis=unified_freqs_cis, adaln_input=adaln_input) + + for layer in self.control_layers: + if torch.is_grad_enabled() and self.gradient_checkpointing: + c = self._gradient_checkpointing_func(layer, c, **new_kwargs) + else: + c = layer(c, **new_kwargs) + + hints = torch.unbind(c)[:-1] * conditioning_scale + controlnet_block_samples = {} + for layer_idx in range(self.n_layers): + if layer_idx in self.control_layers_places: + hints_idx = self.control_layers_places.index(layer_idx) + controlnet_block_samples[layer_idx] = hints[hints_idx] + + return controlnet_block_samples diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 5c401b9d202b..2d332217d897 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -538,6 +538,7 @@ def forward( cap_feats: List[torch.Tensor], patch_size=2, f_patch_size=1, + controlnet_block_samples: Optional[dict[int, torch.Tensor]]=None, return_dict: bool = True, ): assert patch_size in self.all_patch_size @@ -635,13 +636,19 @@ def forward( unified_attn_mask[i, :seq_len] = 1 if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.layers: + for layer_idx, layer in enumerate(self.layers): unified = self._gradient_checkpointing_func( layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input ) + if controlnet_block_samples is not None: + if layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] else: - for layer in self.layers: + for layer_idx, layer in enumerate(self.layers): unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) + if controlnet_block_samples is not None: + if layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) unified = list(unified.unbind(dim=0)) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py new file mode 100644 index 000000000000..609b141be796 --- /dev/null +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -0,0 +1,674 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import AutoTokenizer, PreTrainedModel + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnets import ZImageControlNetModel +from ...models.transformers import ZImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import ZImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImagePipeline + + >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. + >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。" + >>> image = pipe( + ... prompt, + ... height=1024, + ... width=1024, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(42), + ... ).images[0] + >>> image.save("zimage.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. 
Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageControlNetPipeline(DiffusionPipeline, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + controlnet: ZImageControlNetModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + controlnet=controlnet, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def prepare_latents( + self, + batch_size, + 
num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + control_image: PipelineImageInput = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. 
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used.
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            cfg_normalization (`bool`, *optional*, defaults to False):
+                Whether to rescale the classifier-free guided prediction so that its norm does not exceed the norm of
+                the conditional (positive) prediction.
+            cfg_truncation (`float`, *optional*, defaults to 1.0):
+                Normalized denoising progress (0.0 at the first step, 1.0 at the last step) after which
+                classifier-free guidance is skipped. The default of 1.0 keeps guidance enabled for every step.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+                is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a
+                latents tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
+                not provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.z_image.ZImagePipelineOutput`] instead of a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. 
Prepare latent variables + num_channels_latents = self.transformer.in_channels + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + control_image = control_image.unsqueeze(2) + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + controlnet_block_samples = self.controlnet( + latent_model_input_list, + prompt_embeds_model_input, + control_image, + timestep_model_input, + conditioning_scale=controlnet_conditioning_scale, + ) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + controlnet_block_samples=controlnet_block_samples, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, 
return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) From 1e2009de435516caf7b6e67ab215f8f6299c375f Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 17:48:38 +0000 Subject: [PATCH 02/15] passed transformer --- .../models/controlnets/controlnet_z_image.py | 98 +++---------------- .../z_image/pipeline_z_image_controlnet.py | 1 + 2 files changed, 15 insertions(+), 84 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index d6cede86812d..6fe9d38ce3d1 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -23,7 +23,7 @@ from ...models.normalization import RMSNorm from ..controlnets.controlnet import zero_module from ..modeling_utils import ModelMixin -from ..transformers.transformer_z_image import ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM +from ..transformers.transformer_z_image import ZImageTransformer2DModel, ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM class ZImageControlTransformerBlock(ZImageTransformerBlock): @@ -66,87 +66,16 @@ def __init__( self, all_patch_size=(2,), all_f_patch_size=(1,), - in_channels=16, dim=3840, - n_layers=30, n_refiner_layers=2, n_heads=30, n_kv_heads=30, norm_eps=1e-5, qk_norm=True, - cap_feat_dim=2560, - rope_theta=256.0, - t_scale=1000.0, - axes_dims=[32, 48, 48], - axes_lens=[1024, 512, 512], control_layers_places: List[int]=None, control_in_dim=None, ): super().__init__() - self.in_channels = in_channels - self.out_channels = in_channels - self.all_patch_size = all_patch_size - self.all_f_patch_size = all_f_patch_size - self.dim = dim - self.n_heads = n_heads - - self.rope_theta = rope_theta - self.t_scale = t_scale - self.gradient_checkpointing = False - self.n_layers = n_layers - - assert len(all_patch_size) == len(all_f_patch_size) - - all_x_embedder = {} - for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): - x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) - all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder - - self.all_x_embedder = nn.ModuleDict(all_x_embedder) - self.noise_refiner = nn.ModuleList( - [ - ZImageTransformerBlock( - 1000 + layer_id, - dim, - n_heads, - n_kv_heads, - norm_eps, - qk_norm, - modulation=True, - ) - for layer_id in range(n_refiner_layers) - ] - ) - self.context_refiner = nn.ModuleList( - [ - ZImageTransformerBlock( - layer_id, - dim, - n_heads, - n_kv_heads, - norm_eps, - qk_norm, - modulation=False, - ) - for layer_id in range(n_refiner_layers) - ] - ) - self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) - self.cap_embedder = nn.Sequential( - RMSNorm(cap_feat_dim, eps=norm_eps), - nn.Linear(cap_feat_dim, dim, bias=True), - ) - - self.x_pad_token = nn.Parameter(torch.empty((1, dim))) - self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) - - self.axes_dims = axes_dims - self.axes_lens = axes_lens - - self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) - - ## Original Control layers - self.control_layers_places = control_layers_places self.control_in_dim = control_in_dim @@ -366,6 +295,7 @@ def patchify_and_embed( def 
forward( self, + transformer: ZImageTransformer2DModel, x: List[torch.Tensor], cap_feats: List[torch.Tensor], control_context: List[torch.Tensor], @@ -380,7 +310,7 @@ def forward( bsz = len(x) device = x[0].device t = t * self.t_scale - t = self.t_embedder(t) + t = transformer.t_embedder(t) ( x, @@ -398,13 +328,13 @@ def forward( x_max_item_seqlen = max(x_item_seqlens) x = torch.cat(x, dim=0) - x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + x = transformer.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) # Match t_embedder output dtype to x for layerwise casting compatibility adaln_input = t.type_as(x) - x[torch.cat(x_inner_pad_mask)] = self.x_pad_token + x[torch.cat(x_inner_pad_mask)] = transformer.x_pad_token x = list(x.split(x_item_seqlens, dim=0)) - x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) + x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) x = pad_sequence(x, batch_first=True, padding_value=0.0) x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) @@ -413,10 +343,10 @@ def forward( x_attn_mask[i, :seq_len] = 1 if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.noise_refiner: + for layer in transformer.noise_refiner: x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input) else: - for layer in self.noise_refiner: + for layer in transformer.noise_refiner: x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) # cap embed & refine @@ -425,10 +355,10 @@ def forward( cap_max_item_seqlen = max(cap_item_seqlens) cap_feats = torch.cat(cap_feats, dim=0) - cap_feats = self.cap_embedder(cap_feats) - cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token + cap_feats = transformer.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = transformer.cap_pad_token cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) - cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list(transformer.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)) cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) @@ -437,10 +367,10 @@ def forward( cap_attn_mask[i, :seq_len] = 1 if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.context_refiner: + for layer in transformer.context_refiner: cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) else: - for layer in self.context_refiner: + for layer in transformer.context_refiner: cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) # unified @@ -485,7 +415,7 @@ def forward( adaln_input = t.type_as(control_context) control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token control_context = list(control_context.split(x_item_seqlens, dim=0)) - x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) + x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index 609b141be796..d374b8032ea8 100644 --- 
a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -594,6 +594,7 @@ def __call__( latent_model_input_list = list(latent_model_input.unbind(dim=0)) controlnet_block_samples = self.controlnet( + self.transformer, latent_model_input_list, prompt_embeds_model_input, control_image, From 0c308394049f2c7a65c697cf88ea9c40d9ca4333 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 17:50:01 +0000 Subject: [PATCH 03/15] ruff --- ...convert_z_image_controlnet_to_diffusers.py | 21 ++++++++--- .../models/controlnets/controlnet_z_image.py | 36 +++++++++---------- .../transformers/transformer_z_image.py | 2 +- .../z_image/pipeline_z_image_controlnet.py | 3 +- 4 files changed, 36 insertions(+), 26 deletions(-) diff --git a/scripts/convert_z_image_controlnet_to_diffusers.py b/scripts/convert_z_image_controlnet_to_diffusers.py index c4b96cda02af..a9f97d81676d 100644 --- a/scripts/convert_z_image_controlnet_to_diffusers.py +++ b/scripts/convert_z_image_controlnet_to_diffusers.py @@ -1,14 +1,15 @@ import argparse from contextlib import nullcontext -import torch import safetensors.torch +import torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from diffusers.utils.import_utils import is_accelerate_available from diffusers.models import ZImageTransformer2DModel from diffusers.models.controlnets.controlnet_z_image import ZImageControlNetModel +from diffusers.utils.import_utils import is_accelerate_available + """ python scripts/convert_z_image_controlnet_to_diffusers.py \ @@ -42,16 +43,28 @@ def load_original_checkpoint(args): original_state_dict = safetensors.torch.load_file(ckpt_path) return original_state_dict + def load_z_image(args): - model = ZImageTransformer2DModel.from_pretrained(args.original_z_image_repo_id, subfolder="transformer", torch_dtype=torch.bfloat16) + model = ZImageTransformer2DModel.from_pretrained( + args.original_z_image_repo_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ) return model.state_dict(), model.config + def convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_state_dict): converted_state_dict = {} converted_state_dict.update(original_state_dict) - to_copy = {"all_x_embedder.", "noise_refiner.", "context_refiner.", "t_embedder.", "cap_embedder.", "x_pad_token", "cap_pad_token"} + to_copy = { + "all_x_embedder.", + "noise_refiner.", + "context_refiner.", + "t_embedder.", + "cap_embedder.", + "x_pad_token", + "cap_pad_token", + } for key in z_image.keys(): for copy_key in to_copy: diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 6fe9d38ce3d1..b76a2c54c3d8 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -20,15 +20,18 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...models.normalization import RMSNorm from ..controlnets.controlnet import zero_module from ..modeling_utils import ModelMixin -from ..transformers.transformer_z_image import ZImageTransformer2DModel, ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM +from ..transformers.transformer_z_image import ( + SEQ_MULTI_OF, + ZImageTransformer2DModel, + ZImageTransformerBlock, +) class ZImageControlTransformerBlock(ZImageTransformerBlock): def __init__( - self, + self, layer_id: int, dim: int, n_heads: 
int, @@ -36,7 +39,7 @@ def __init__( norm_eps: float, qk_norm: bool, modulation=True, - block_id=0 + block_id=0, ): super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation) self.block_id = block_id @@ -57,7 +60,8 @@ def forward(self, c: torch.Tensor, x: torch.Tensor, **kwargs): all_c += [c_skip, c] c = torch.stack(all_c) return c - + + class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): _supports_gradient_checkpointing = True @@ -72,7 +76,7 @@ def __init__( n_kv_heads=30, norm_eps=1e-5, qk_norm=True, - control_layers_places: List[int]=None, + control_layers_places: List[int] = None, control_in_dim=None, ): super().__init__() @@ -84,15 +88,7 @@ def __init__( # control blocks self.control_layers = nn.ModuleList( [ - ZImageControlTransformerBlock( - i, - dim, - n_heads, - n_kv_heads, - norm_eps, - qk_norm, - block_id=i - ) + ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i) for i in self.control_layers_places ] ) @@ -425,7 +421,9 @@ def forward( if torch.is_grad_enabled() and self.gradient_checkpointing: for layer in self.control_noise_refiner: - control_context = self._gradient_checkpointing_func(layer, control_context, x_attn_mask, x_freqs_cis, adaln_input) + control_context = self._gradient_checkpointing_func( + layer, control_context, x_attn_mask, x_freqs_cis, adaln_input + ) else: for layer in self.control_noise_refiner: control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input) @@ -440,14 +438,14 @@ def forward( control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) c = control_context_unified - new_kwargs = dict(x=unified, attn_mask=unified_attn_mask, freqs_cis=unified_freqs_cis, adaln_input=adaln_input) - + new_kwargs = {"x": unified, "attn_mask": unified_attn_mask, "freqs_cis": unified_freqs_cis, "adaln_input": adaln_input} + for layer in self.control_layers: if torch.is_grad_enabled() and self.gradient_checkpointing: c = self._gradient_checkpointing_func(layer, c, **new_kwargs) else: c = layer(c, **new_kwargs) - + hints = torch.unbind(c)[:-1] * conditioning_scale controlnet_block_samples = {} for layer_idx in range(self.n_layers): diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 2d332217d897..70ffced8b63a 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -538,7 +538,7 @@ def forward( cap_feats: List[torch.Tensor], patch_size=2, f_patch_size=1, - controlnet_block_samples: Optional[dict[int, torch.Tensor]]=None, + controlnet_block_samples: Optional[dict[int, torch.Tensor]] = None, return_dict: bool = True, ): assert patch_size in self.all_patch_size diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index d374b8032ea8..44906a0db519 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -89,7 +89,6 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -509,7 +508,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, device=device, dtype=self.vae.dtype, - ) + ) height, width = 
control_image.shape[-2:] control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor From 52f996e226dfd1e7f1a1b0d001c022dae71e24a8 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:03:23 +0000 Subject: [PATCH 04/15] convert passed --- ...convert_z_image_controlnet_to_diffusers.py | 58 ++----------------- 1 file changed, 6 insertions(+), 52 deletions(-) diff --git a/scripts/convert_z_image_controlnet_to_diffusers.py b/scripts/convert_z_image_controlnet_to_diffusers.py index a9f97d81676d..aed27c14f205 100644 --- a/scripts/convert_z_image_controlnet_to_diffusers.py +++ b/scripts/convert_z_image_controlnet_to_diffusers.py @@ -1,19 +1,17 @@ import argparse from contextlib import nullcontext -import safetensors.torch import torch +import safetensors.torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from diffusers.models import ZImageTransformer2DModel from diffusers.models.controlnets.controlnet_z_image import ZImageControlNetModel from diffusers.utils.import_utils import is_accelerate_available """ python scripts/convert_z_image_controlnet_to_diffusers.py \ ---original_z_image_repo_id "Tongyi-MAI/Z-Image-Turbo" \ --original_controlnet_repo_id "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union" \ --filename "Z-Image-Turbo-Fun-Controlnet-Union.safetensors" --output_path "z-image-controlnet-hf/" @@ -23,7 +21,6 @@ CTX = init_empty_weights if is_accelerate_available else nullcontext parser = argparse.ArgumentParser() -parser.add_argument("--original_z_image_repo_id", default="Tongyi-MAI/Z-Image-Turbo", type=str) parser.add_argument("--original_controlnet_repo_id", default=None, type=str) parser.add_argument("--filename", default="Z-Image-Turbo-Fun-Controlnet-Union.safetensors", type=str) parser.add_argument("--checkpoint_path", default=None, type=str) @@ -44,72 +41,29 @@ def load_original_checkpoint(args): return original_state_dict -def load_z_image(args): - model = ZImageTransformer2DModel.from_pretrained( - args.original_z_image_repo_id, subfolder="transformer", torch_dtype=torch.bfloat16 - ) - return model.state_dict(), model.config - - -def convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_state_dict): +def convert_z_image_controlnet_checkpoint_to_diffusers(original_state_dict): converted_state_dict = {} converted_state_dict.update(original_state_dict) - to_copy = { - "all_x_embedder.", - "noise_refiner.", - "context_refiner.", - "t_embedder.", - "cap_embedder.", - "x_pad_token", - "cap_pad_token", - } - - for key in z_image.keys(): - for copy_key in to_copy: - if key.startswith(copy_key): - converted_state_dict[key] = z_image[key] - return converted_state_dict def main(args): original_ckpt = load_original_checkpoint(args) - z_image, config = load_z_image(args) control_in_dim = 16 control_layers_places = [0, 5, 10, 15, 20, 25] - converted_controlnet_state_dict = convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_ckpt) - - for key, tensor in converted_controlnet_state_dict.items(): - print(f"{key} - {tensor.dtype}") + converted_controlnet_state_dict = convert_z_image_controlnet_checkpoint_to_diffusers(original_ckpt) controlnet = ZImageControlNetModel( - all_patch_size=config["all_patch_size"], - all_f_patch_size=config["all_f_patch_size"], - in_channels=config["in_channels"], - dim=config["dim"], - n_layers=config["n_layers"], - n_refiner_layers=config["n_refiner_layers"], - 
n_heads=config["n_heads"], - n_kv_heads=config["n_kv_heads"], - norm_eps=config["norm_eps"], - qk_norm=config["qk_norm"], - cap_feat_dim=config["cap_feat_dim"], - rope_theta=config["rope_theta"], - t_scale=config["t_scale"], - axes_dims=config["axes_dims"], - axes_lens=config["axes_lens"], control_layers_places=control_layers_places, control_in_dim=control_in_dim, - ) - missing, unexpected = controlnet.load_state_dict(converted_controlnet_state_dict) - print(f"{missing=}") - print(f"{unexpected=}") + ).to(torch.bfloat16) + controlnet.load_state_dict(converted_controlnet_state_dict) print("Saving Z-Image ControlNet in Diffusers format") - controlnet.save_pretrained(args.output_path, max_shard_size="5GB") + controlnet.save_pretrained(args.output_path) if __name__ == "__main__": From 4b446b394150575b322836b048acd3eeeb2072a3 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:03:30 +0000 Subject: [PATCH 05/15] __init__ --- src/diffusers/__init__.py | 4 ++++ src/diffusers/models/__init__.py | 2 ++ src/diffusers/pipelines/__init__.py | 4 ++-- src/diffusers/pipelines/z_image/__init__.py | 2 ++ 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index eb8e86c4c89d..f45be1560716 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -277,6 +277,7 @@ "WanTransformer3DModel", "WanVACETransformer3DModel", "ZImageTransformer2DModel", + "ZImageControlNetModel", "attention_backend", ] ) @@ -661,6 +662,7 @@ "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", "ZImagePipeline", + "ZImageControlNetPipeline", ] ) @@ -1004,6 +1006,7 @@ WanTransformer3DModel, WanVACETransformer3DModel, ZImageTransformer2DModel, + ZImageControlNetModel, attention_backend, ) from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks @@ -1357,6 +1360,7 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ZImagePipeline, + ZImageControlNetPipeline, ) try: diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 29d8b0b5a55d..7ea15ef2a215 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -66,6 +66,7 @@ _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"] _import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"] _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] + _import_structure["controlnets.controlnet_z_image"] = ["ZImageControlNetModel"] _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"] _import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"] _import_structure["embeddings"] = ["ImageProjection"] @@ -180,6 +181,7 @@ SD3MultiControlNetModel, SparseControlNetModel, UNetControlNetXSModel, + ZImageControlNetModel, ) from .embeddings import ImageProjection from .modeling_utils import ModelMixin diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3d669aecf556..fe6af5cd1e0b 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -404,7 +404,7 @@ "Kandinsky5T2IPipeline", "Kandinsky5I2IPipeline", ] - _import_structure["z_image"] = ["ZImagePipeline"] + _import_structure["z_image"] = ["ZImagePipeline", "ZImageControlNetPipeline"] _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", "SkyReelsV2DiffusionForcingImageToVideoPipeline", @@ -841,7 +841,7 @@ WuerstchenDecoderPipeline, 
WuerstchenPriorPipeline, ) - from .z_image import ZImagePipeline + from .z_image import ZImagePipeline, ZImageControlNetPipeline try: if not is_onnx_available(): diff --git a/src/diffusers/pipelines/z_image/__init__.py b/src/diffusers/pipelines/z_image/__init__.py index f95b3e5a0bed..842d5690e3d7 100644 --- a/src/diffusers/pipelines/z_image/__init__.py +++ b/src/diffusers/pipelines/z_image/__init__.py @@ -23,6 +23,7 @@ else: _import_structure["pipeline_output"] = ["ZImagePipelineOutput"] _import_structure["pipeline_z_image"] = ["ZImagePipeline"] + _import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -35,6 +36,7 @@ else: from .pipeline_output import ZImagePipelineOutput from .pipeline_z_image import ZImagePipeline + from .pipeline_z_image_controlnet import ZImageControlNetPipeline else: import sys From a1ff390ecebb5afb9f6282209526adfdaf31c5d5 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:03:37 +0000 Subject: [PATCH 06/15] pipeline example --- .../pipelines/z_image/pipeline_z_image_controlnet.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index 44906a0db519..ae81105eea27 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -36,9 +36,13 @@ Examples: ```py >>> import torch - >>> from diffusers import ZImagePipeline + >>> from diffusers import ZImageControlNetPipeline + >>> from diffusers import ZImageControlNetModel - >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> controlnet_model = "..." + >>> controlnet = ZImageControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) + + >>> pipe = ZImageControlNetPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. @@ -47,9 +51,11 @@ >>> # (2) Use flash attention 3 >>> # pipe.transformer.set_attention_backend("_flash_3") - >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。" + >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") + >>> prompt = "A girl in city, 25 years old, cool, futuristic" >>> image = pipe( ... prompt, + ... control_image=control_image, ... height=1024, ... width=1024, ... 
num_inference_steps=9, From 7ab347d812a5b78076f79541f4046d59948d0464 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:04:01 +0000 Subject: [PATCH 07/15] ruff --- scripts/convert_z_image_controlnet_to_diffusers.py | 2 +- src/diffusers/__init__.py | 4 ++-- src/diffusers/models/controlnets/controlnet_z_image.py | 7 ++++++- src/diffusers/pipelines/__init__.py | 2 +- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/scripts/convert_z_image_controlnet_to_diffusers.py b/scripts/convert_z_image_controlnet_to_diffusers.py index aed27c14f205..e5d5f34e36e8 100644 --- a/scripts/convert_z_image_controlnet_to_diffusers.py +++ b/scripts/convert_z_image_controlnet_to_diffusers.py @@ -1,8 +1,8 @@ import argparse from contextlib import nullcontext -import torch import safetensors.torch +import torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f45be1560716..398f72167ad3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1005,8 +1005,8 @@ WanAnimateTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, - ZImageTransformer2DModel, ZImageControlNetModel, + ZImageTransformer2DModel, attention_backend, ) from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks @@ -1359,8 +1359,8 @@ WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline, - ZImagePipeline, ZImageControlNetPipeline, + ZImagePipeline, ) try: diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index b76a2c54c3d8..ff148781f49a 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -438,7 +438,12 @@ def forward( control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) c = control_context_unified - new_kwargs = {"x": unified, "attn_mask": unified_attn_mask, "freqs_cis": unified_freqs_cis, "adaln_input": adaln_input} + new_kwargs = { + "x": unified, + "attn_mask": unified_attn_mask, + "freqs_cis": unified_freqs_cis, + "adaln_input": adaln_input, + } for layer in self.control_layers: if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index fe6af5cd1e0b..10ce49fe8111 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -841,7 +841,7 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) - from .z_image import ZImagePipeline, ZImageControlNetPipeline + from .z_image import ZImageControlNetPipeline, ZImagePipeline try: if not is_onnx_available(): From 8cab0c953c7b732324a92d4e5de067b8bb290a5d Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:05:36 +0000 Subject: [PATCH 08/15] pipeline load_image --- src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index ae81105eea27..67771dddabd7 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -38,6 +38,7 @@ >>> import torch >>> from diffusers import ZImageControlNetPipeline >>> from diffusers import ZImageControlNetModel + >>> from diffusers.utils import 
load_image >>> controlnet_model = "..." >>> controlnet = ZImageControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) From 8688fa66a110bd0d77df8d03299fc3a42130ce07 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:24:05 +0000 Subject: [PATCH 09/15] t_scale --- src/diffusers/models/controlnets/controlnet_z_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index ff148781f49a..070724a85883 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -305,7 +305,7 @@ def forward( bsz = len(x) device = x[0].device - t = t * self.t_scale + t = t * transformer.t_scale t = transformer.t_embedder(t) ( From 9051272d47082c5cf6bc409b6368332cdac16f97 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:28:26 +0000 Subject: [PATCH 10/15] x_pad_token --- src/diffusers/models/controlnets/controlnet_z_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 070724a85883..48b9a66a25a3 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -409,7 +409,7 @@ def forward( # Match t_embedder output dtype to control_context for layerwise casting compatibility adaln_input = t.type_as(control_context) - control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token + control_context[torch.cat(x_inner_pad_mask)] = transformer.x_pad_token control_context = list(control_context.split(x_item_seqlens, dim=0)) x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) From 0d8c3f1a28180fc85fc9a4e0696d5f4f11def56f Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:28:34 +0000 Subject: [PATCH 11/15] controlnet_block_samples --- src/diffusers/models/controlnets/controlnet_z_image.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 48b9a66a25a3..0127f7f9683f 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -452,10 +452,6 @@ def forward( c = layer(c, **new_kwargs) hints = torch.unbind(c)[:-1] * conditioning_scale - controlnet_block_samples = {} - for layer_idx in range(self.n_layers): - if layer_idx in self.control_layers_places: - hints_idx = self.control_layers_places.index(layer_idx) - controlnet_block_samples[layer_idx] = hints[hints_idx] + controlnet_block_samples = {layer_idx: hints[idx] for idx, layer_idx in enumerate(self.control_layers_places)} return controlnet_block_samples From f789325ccd8f3f6fb35dffdd4acea6f21f30084e Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:29:54 +0000 Subject: [PATCH 12/15] conditioning_scale --- src/diffusers/models/controlnets/controlnet_z_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 0127f7f9683f..3a200b252a01 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -451,7 +451,7 @@ def forward( else: c = layer(c, 
**new_kwargs) - hints = torch.unbind(c)[:-1] * conditioning_scale - controlnet_block_samples = {layer_idx: hints[idx] for idx, layer_idx in enumerate(self.control_layers_places)} + hints = torch.unbind(c)[:-1] + controlnet_block_samples = {layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in enumerate(self.control_layers_places)} return controlnet_block_samples From 5f8ab7bf98549ff6bdc63db500ad1433b7cf84e2 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:33:36 +0000 Subject: [PATCH 13/15] self.config --- src/diffusers/models/controlnets/controlnet_z_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 3a200b252a01..d0f8b861e0c9 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -300,8 +300,8 @@ def forward( f_patch_size=1, conditioning_scale: float = 1.0, ): - assert patch_size in self.all_patch_size - assert f_patch_size in self.all_f_patch_size + assert patch_size in self.config.all_patch_size + assert f_patch_size in self.config.all_f_patch_size bsz = len(x) device = x[0].device From bc72f9ce93ca691018fb8f1b420684b6e18a6d55 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 20:08:24 +0000 Subject: [PATCH 14/15] sample_mode, default controlnet_conditioning_scale --- .../pipelines/z_image/pipeline_z_image_controlnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index 67771dddabd7..d0460cf09244 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -360,7 +360,7 @@ def __call__( sigmas: Optional[List[float]] = None, guidance_scale: float = 5.0, control_image: PipelineImageInput = None, - controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + controlnet_conditioning_scale: Union[float, List[float]] = 0.75, cfg_normalization: bool = False, cfg_truncation: float = 1.0, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -517,7 +517,7 @@ def __call__( dtype=self.vae.dtype, ) height, width = control_image.shape[-2:] - control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator, sample_mode="argmax") control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor control_image = control_image.unsqueeze(2) From 13b706a99f209197352bcb8790260727d75b2b9b Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 20:16:49 +0000 Subject: [PATCH 15/15] ruff --- src/diffusers/models/controlnets/controlnet_z_image.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index d0f8b861e0c9..c121f42c1a78 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -452,6 +452,8 @@ def forward( c = layer(c, **new_kwargs) hints = torch.unbind(c)[:-1] - controlnet_block_samples = {layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in enumerate(self.control_layers_places)} + controlnet_block_samples = { + layer_idx: hints[idx] * conditioning_scale for 
idx, layer_idx in enumerate(self.control_layers_places)
+        }
 
         return controlnet_block_samples
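
Taken together, the series is meant to be used as follows: convert the original Fun-Controlnet-Union checkpoint with the script from [PATCH 01/15] (CLI flags as they stand after [PATCH 04/15]), then load the converted weights into `ZImageControlNetPipeline`. The sketch below is assembled from the conversion script docstring and the pipeline example docstring in [PATCH 06/15]; the local weights directory, the explicit `controlnet_conditioning_scale` (its default becomes 0.75 in [PATCH 14/15]), and the `.images` output attribute follow standard diffusers conventions and are assumptions, not something pinned down by these patches.

```py
# End-to-end usage sketch for this series (not part of any patch).
#
# Step 1: convert the original checkpoint to the diffusers layout:
#
#   python scripts/convert_z_image_controlnet_to_diffusers.py \
#       --original_controlnet_repo_id "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union" \
#       --filename "Z-Image-Turbo-Fun-Controlnet-Union.safetensors" \
#       --output_path "z-image-controlnet-hf/"
#
# Step 2: load the converted ControlNet next to the base Z-Image-Turbo pipeline and run it.
import torch

from diffusers import ZImageControlNetModel, ZImageControlNetPipeline
from diffusers.utils import load_image

controlnet = ZImageControlNetModel.from_pretrained("z-image-controlnet-hf/", torch_dtype=torch.bfloat16)
pipe = ZImageControlNetPipeline.from_pretrained(
    "Z-a-o/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
image = pipe(
    "A girl in city, 25 years old, cool, futuristic",
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=9,
    controlnet_conditioning_scale=0.75,  # default after [PATCH 14/15], passed explicitly for clarity
).images[0]
image.save("z_image_controlnet.png")
```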