From fce87b36675189534831e30a586c72ddd1f78f8f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 13 Nov 2025 20:18:20 +0100 Subject: [PATCH 01/63] add vae --- scripts/convert_flux2_to_diffusers.py | 226 ++++++++ src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/autoencoders/__init__.py | 1 + .../autoencoders/autoencoder_kl_flux2.py | 525 ++++++++++++++++++ 5 files changed, 756 insertions(+) create mode 100644 scripts/convert_flux2_to_diffusers.py create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_flux2.py diff --git a/scripts/convert_flux2_to_diffusers.py b/scripts/convert_flux2_to_diffusers.py new file mode 100644 index 000000000000..91616e5bfc68 --- /dev/null +++ b/scripts/convert_flux2_to_diffusers.py @@ -0,0 +1,226 @@ +import argparse +from contextlib import nullcontext + +import safetensors.torch +import torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download + +from diffusers import AutoencoderKLFlux2 +from diffusers.utils.import_utils import is_accelerate_available + + + +""" +# VAE + +python scripts/convert_flux2_to_diffusers.py \ +--original_state_dict_repo_id "diffusers-internal-dev/dummy-flux2" \ +--filename "ae.pt" \ +--output_path "/raid/yiyi/dummy-flux2-diffusers" \ +--dtype fp32 \ +--vae +""" + +CTX = init_empty_weights if is_accelerate_available() else nullcontext + +parser = argparse.ArgumentParser() +parser.add_argument("--original_state_dict_repo_id", default=None, type=str) +parser.add_argument("--filename", default="flux.safetensors", type=str) +parser.add_argument("--checkpoint_path", default=None, type=str) +parser.add_argument("--vae", action="store_true") +parser.add_argument("--output_path", type=str) +parser.add_argument("--dtype", type=str, default="bf16") + +args = parser.parse_args() +dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 + + +def load_original_checkpoint(args): + if args.original_state_dict_repo_id is not None: + ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename) + elif args.checkpoint_path is not None: + ckpt_path = args.checkpoint_path + else: + raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`") + + if ckpt_path.endswith(".pt"): + original_state_dict = torch.load(ckpt_path, map_location="cpu") + elif ckpt_path.endswith(".safetensors"): + original_state_dict = safetensors.torch.load_file(ckpt_path) + else: + raise ValueError(f"Unsupported file extension: {ckpt_path}") + return original_state_dict + + + +DIFFUSERS_VAE_TO_FLUX2_MAPPING = { + "encoder.conv_in.weight": "encoder.conv_in.weight", + "encoder.conv_in.bias": "encoder.conv_in.bias", + "encoder.conv_out.weight": "encoder.conv_out.weight", + "encoder.conv_out.bias": "encoder.conv_out.bias", + "encoder.conv_norm_out.weight": "encoder.norm_out.weight", + "encoder.conv_norm_out.bias": "encoder.norm_out.bias", + "decoder.conv_in.weight": "decoder.conv_in.weight", + "decoder.conv_in.bias": "decoder.conv_in.bias", + "decoder.conv_out.weight": "decoder.conv_out.weight", + "decoder.conv_out.bias": "decoder.conv_out.bias", + "decoder.conv_norm_out.weight": "decoder.norm_out.weight", + "decoder.conv_norm_out.bias": "decoder.norm_out.bias", + "quant_conv.weight": "encoder.quant_conv.weight", + "quant_conv.bias": "encoder.quant_conv.bias", + "post_quant_conv.weight": "decoder.post_quant_conv.weight", + "post_quant_conv.bias": "decoder.post_quant_conv.bias", + "bn.running_mean": "bn.running_mean", + 
"bn.running_var": "bn.running_var", + } + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + +def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): + for ldm_key in keys: + diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut") + new_checkpoint[diffusers_key] = checkpoint.get(ldm_key) + + +def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): + for ldm_key in keys: + diffusers_key = ( + ldm_key.replace(mapping["old"], mapping["new"]) + .replace("norm.weight", "group_norm.weight") + .replace("norm.bias", "group_norm.bias") + .replace("q.weight", "to_q.weight") + .replace("q.bias", "to_q.bias") + .replace("k.weight", "to_k.weight") + .replace("k.bias", "to_k.bias") + .replace("v.weight", "to_v.weight") + .replace("v.bias", "to_v.bias") + .replace("proj_out.weight", "to_out.0.weight") + .replace("proj_out.bias", "to_out.0.bias") + ) + new_checkpoint[diffusers_key] = checkpoint.get(ldm_key) + + # proj_attn.weight has to be converted from conv 1D to linear + shape = new_checkpoint[diffusers_key].shape + + if len(shape) == 3: + new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0] + elif len(shape) == 4: + new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0] + + +def convert_flux2_vae_checkpoint_to_diffusers(vae_state_dict, config): + new_checkpoint = {} + for diffusers_key, ldm_key in DIFFUSERS_VAE_TO_FLUX2_MAPPING.items(): + if ldm_key not in vae_state_dict: + continue + new_checkpoint[diffusers_key] = vae_state_dict[ldm_key] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len(config["down_block_types"]) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}, + ) + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get( + f"encoder.down.{i}.downsample.conv.bias" + ) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}, + ) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + update_vae_attentions_ldm_to_diffusers( + mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"} + ) + 
+ # Retrieves the keys for the decoder up blocks only + num_up_blocks = len(config["up_block_types"]) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}, + ) + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}, + ) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + update_vae_attentions_ldm_to_diffusers( + mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"} + ) + conv_attn_to_linear(new_checkpoint) + + return new_checkpoint + + +def main(args): + original_ckpt = load_original_checkpoint(args) + + if args.vae: + vae = AutoencoderKLFlux2() + if "model" in original_ckpt: + # YiYi Notes: remove this depends on if it has "model" key + original_ckpt = original_ckpt["model"] + converted_vae_state_dict = convert_flux2_vae_checkpoint_to_diffusers(original_ckpt, vae.config) + vae.load_state_dict(converted_vae_state_dict, strict=True) + vae.to(dtype).save_pretrained(f"{args.output_path}/vae") + + +if __name__ == "__main__": + main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a5040bd28394..b084e07c82a5 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -183,6 +183,7 @@ "AuraFlowTransformer2DModel", "AutoencoderDC", "AutoencoderKL", + "AutoencoderKLFlux2", "AutoencoderKLAllegro", "AutoencoderKLCogVideoX", "AutoencoderKLCosmos", @@ -894,6 +895,7 @@ AutoencoderDC, AutoencoderKL, AutoencoderKLAllegro, + AutoencoderKLFlux2, AutoencoderKLCogVideoX, AutoencoderKLCosmos, AutoencoderKLHunyuanImage, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e97ab8bd1d2a..fb1c10c1a0cb 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -44,6 +44,7 @@ _import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"] + _import_structure["autoencoders.autoencoder_kl_flux2"] = ["AutoencoderKLFlux2"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] @@ -148,6 +149,7 @@ AutoencoderKLQwenImage, AutoencoderKLTemporalDecoder, AutoencoderKLWan, + 
AutoencoderKLFlux2, AutoencoderOobleck, AutoencoderTiny, ConsistencyDecoderVAE, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index edfaabb070c5..58a203a00ee8 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -13,6 +13,7 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_kl_wan import AutoencoderKLWan +from .autoencoder_kl_flux2 import AutoencoderKLFlux2 from .autoencoder_oobleck import AutoencoderOobleck from .autoencoder_tiny import AutoencoderTiny from .consistency_decoder_vae import ConsistencyDecoderVAE diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py b/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py new file mode 100644 index 000000000000..4a2c9c064d74 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py @@ -0,0 +1,525 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional, Tuple, Union +import math + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils import deprecate +from ...utils.accelerate_utils import apply_forward_hook +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, + FusedAttnProcessor2_0, +) +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder + + +class AutoencoderKLFlux2(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. 
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast` + can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + mid_block_add_attention (`bool`, *optional*, default to `True`): + If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the + mid_block will only have resnet blocks + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"), + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"), + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512,), + layers_per_block: int = 2, + act_fn: str = "silu", + latent_channels: int = 32, + norm_num_groups: int = 32, + sample_size: int = 1024, # YiYi notes: not sure + force_upcast: bool = True, + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + mid_block_add_attention: bool = True, + batch_norm_eps: float = 1e-4, + batch_norm_momentum: float = 0.1, + patch_size: Tuple[int, int] = (2, 2), + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + mid_block_add_attention=mid_block_add_attention, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + + self.bn = nn.BatchNorm2d(math.prod(patch_size) * latent_channels, eps=batch_norm_eps, momentum=batch_norm_momentum, affine=False, track_running_stats=True) + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size): + return self._tiled_encode(x) + + enc = self.encoder(x) + if self.quant_conv is not None: + enc = self.quant_conv(enc) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. 
+ + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.tiled_decode(z, return_dict=return_dict) + + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. 
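+        # Note: the tile size used below is `self.tile_sample_min_size` (i.e. `config.sample_size`),
+        # so the 512x512 figure above only holds for configs with `sample_size=512`.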
+ rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + enc = torch.cat(result_rows, dim=2) + return enc + + def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. + """ + deprecation_message = ( + "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the " + "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able " + "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value." + ) + deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False) + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. 
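+        # As in `_tiled_encode`, the tile size is `self.tile_sample_min_size` rather than a fixed 512x512.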
+ rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[2], overlap_size): + row = [] + for j in range(0, z.shape[3], overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + if self.config.use_post_quant_conv: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 
+ """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + > [!WARNING] > This API is 🧪 experimental. + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + > [!WARNING] > This API is 🧪 experimental. + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) From a1f2ba1ab00f67e7791b2d5ddb13769eb6c6ec2c Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 14 Nov 2025 04:53:44 +0100 Subject: [PATCH 02/63] Initial commit for Flux 2 Transformer implementation --- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../models/transformers/transformer_flux2.py | 988 ++++++++++++++++++ .../test_models_transformer_flux2.py | 227 ++++ 5 files changed, 1220 insertions(+) create mode 100644 src/diffusers/models/transformers/transformer_flux2.py create mode 100644 tests/models/transformers/test_models_transformer_flux2.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a5040bd28394..7ba3f2793290 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -215,6 +215,7 @@ "CosmosTransformer3DModel", "DiTTransformer2DModel", "EasyAnimateTransformer3DModel", + "Flux2Transformer2DModel", "FluxControlNetModel", "FluxMultiControlNetModel", "FluxTransformer2DModel", @@ -925,6 +926,7 @@ CosmosTransformer3DModel, DiTTransformer2DModel, EasyAnimateTransformer3DModel, + Flux2Transformer2DModel, FluxControlNetModel, FluxMultiControlNetModel, FluxTransformer2DModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e97ab8bd1d2a..d9d657ffbbd2 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -92,6 +92,7 @@ _import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"] _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] + _import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"] _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_hunyuan_video_framepack"] = 
["HunyuanVideoFramepackTransformer3DModel"] @@ -189,6 +190,7 @@ DiTTransformer2DModel, DualTransformer2DModel, EasyAnimateTransformer3DModel, + Flux2Transformer2DModel, FluxTransformer2DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 66daf56e23b2..c00abda53da3 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -26,6 +26,7 @@ from .transformer_cosmos import CosmosTransformer3DModel from .transformer_easyanimate import EasyAnimateTransformer3DModel from .transformer_flux import FluxTransformer2DModel + from .transformer_flux2 import Flux2Transformer2DModel from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py new file mode 100644 index 000000000000..ea7e0b0ebd84 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -0,0 +1,988 @@ +# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput +from ..attention import AttentionMixin, AttentionModuleMixin +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import ( + TimestepEmbedding, + Timesteps, + apply_rotary_emb, + get_1d_rotary_pos_embed, +) +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_fused_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = (None,) + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): + if attn.fused_projections: + return _get_fused_projections(attn, hidden_states, encoder_hidden_states) + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +class Flux2SwiGLU(nn.Module): + """ + Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection + layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters. 
+ """ + + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + x = self.gate_fn(x1) * x2 + return x + + +class Flux2FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: float = 3.0, + inner_dim: Optional[int] = None, + bias: bool = False, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out or dim + + # Flux2SwiGLU will reduce the dimension by half + self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias) + self.act_fn = Flux2SwiGLU() + self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_in(x) + x = self.act_fn(x) + x = self.linear_out(x) + return x + + +class Flux2AttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") + + def __call__( + self, + attn: "Flux2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + mlp_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if attn.parallel_proj_in: + hidden_states = attn.to_qkv_mlp_proj(hidden_states) + qkv, mlp_hidden_states = torch.split( + hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor] + ) + query, key, value = qkv.chunk(3, dim=-1) + mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states) + + # Get encoder QKV, if available + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None: + if hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk( + 3, dim=-1 + ) + elif attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + else: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = 
hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if attn.parallel_proj_out: + hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1) + hidden_states = attn.to_out(hidden_states) + else: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +# TODO: support IP Adapter for Flux.2 as well +class FluxIPAdapterAttnProcessor(torch.nn.Module): + """Flux Attention processor for IP-Adapter.""" + + _attention_backend = None + _parallel_config = None + + def __init__( + self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + + if not isinstance(num_tokens, (tuple, list)): + num_tokens = [num_tokens] + + if not isinstance(scale, list): + scale = [scale] * len(num_tokens) + if len(scale) != len(num_tokens): + raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") + self.scale = scale + + self.to_k_ip = nn.ModuleList( + [ + nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) + for _ in range(len(num_tokens)) + ] + ) + self.to_v_ip = nn.ModuleList( + [ + nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) + for _ in range(len(num_tokens)) + ] + ) + + def __call__( + self, + attn: "Flux2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ip_hidden_states: Optional[List[torch.Tensor]] = None, + ip_adapter_masks: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + ip_query = query + + if encoder_hidden_states is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is 
not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + # IP-adapter + ip_attn_output = torch.zeros_like(hidden_states) + + for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip + ): + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim) + ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim) + + current_ip_hidden_states = dispatch_attention_fn( + ip_query, + ip_key, + ip_value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim) + current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) + ip_attn_output += scale * current_ip_hidden_states + + return hidden_states, encoder_hidden_states, ip_attn_output + else: + return hidden_states + + +class Flux2Attention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = Flux2AttnProcessor + _available_processors = [ + Flux2AttnProcessor, + FluxIPAdapterAttnProcessor, + ] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + parallel_proj_in: bool = False, + parallel_proj_out: bool = False, + mlp_ratio: float = 4.0, + mlp_mult_factor: int = 2, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.use_bias = bias + self.dropout = dropout + + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.parallel_proj_in = parallel_proj_in + self.parallel_proj_out = parallel_proj_out + self.mlp_ratio = mlp_ratio + self.mlp_hidden_dim = int(query_dim * self.mlp_ratio) + self.mlp_mult_factor = mlp_mult_factor + + if self.parallel_proj_in: + self.to_qkv_mlp_proj = torch.nn.Linear( + self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias + ) + self.mlp_act_fn = Flux2SwiGLU() + else: + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + # QK Norm + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + + if self.parallel_proj_out: + self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias) + else: + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if 
added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +@maybe_allow_in_graph +class Flux2SingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this + # is often called a "parallel" transformer block + self.attn = Flux2Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + out_bias=bias, + eps=eps, + parallel_proj_in=True, + parallel_proj_out=True, + mlp_ratio=mlp_ratio, + mlp_mult_factor=2, + processor=Flux2AttnProcessor(), + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + split_hidden_states: bool = False, + text_seq_len: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already + # concatenated + if encoder_hidden_states is not None: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + mod_shift, mod_scale, mod_gate = temb_mod_params + + norm_hidden_states = self.norm(hidden_states) + norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift + + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = hidden_states + mod_gate * attn_output + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + if split_hidden_states: + encoder_hidden_states, hidden_states = 
hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + else: + return hidden_states + + +@maybe_allow_in_graph +class Flux2TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + self.attn = Flux2Attention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + out_bias=bias, + eps=eps, + processor=Flux2AttnProcessor(), + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + joint_attention_kwargs = joint_attention_kwargs or {} + + (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img + (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt + + # Img stream + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa + + # Conditioning txt stream + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) + norm_encoder_hidden_states = (1 + c_scale_msa) * encoder_hidden_states + c_shift_msa + + # Attention on concatenated img + txt stream + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the image stream (`hidden_states`). + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the text stream (`encoder_hidden_states`). 
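+        # The text stream mirrors the image stream: a gated attention residual followed by a gated
+        # feed-forward residual on `encoder_hidden_states`.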
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class Flux2PosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1] + for i in range(len(self.axes_dim)): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class Flux2TimestepGuidanceEmbeddings(nn.Module): + def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False): + super().__init__() + + self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias + ) + + self.guidance_embedder = TimestepEmbedding( + in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias + ) + + def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D) + + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj) # (N, D) + + time_guidance_emb = timesteps_emb + guidance_emb + + return time_guidance_emb + + +class Flux2Modulation(nn.Module): + def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False): + super().__init__() + self.mod_param_sets = mod_param_sets + + self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias) + self.act_fn = nn.SiLU() + + def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]: + mod = self.act_fn(temb) + mod = self.linear(mod) + + if mod.ndim == 2: + mod = mod.unsqueeze(1) + mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1) + # Return tuple of 3-tuples of modulation params shift/scale/gate + return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets)) + + +class Flux2Transformer2DModel( + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FromOriginalModelMixin, + FluxTransformer2DLoadersMixin, + CacheMixin, + AttentionMixin, +): + """ + The Transformer model introduced in Flux 2. 
+ + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + patch_size (`int`, defaults to `1`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `128`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `8`): + The number of layers of dual stream DiT blocks to use. + num_single_layers (`int`, defaults to `48`): + The number of layers of single stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `48`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `15360`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + pooled_projection_dim (`int`, defaults to `768`): + The number of dimensions to use for the pooled projection. + guidance_embeds (`bool`, defaults to `True`): + Whether to use guidance embeddings for guidance-distilled variant of the model. + axes_dims_rope (`Tuple[int]`, defaults to `(32, 32, 32, 32)`): + The dimensions to use for the rotary positional embeddings. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] + _cp_plan = { + "": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), + "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 128, + out_channels: Optional[int] = None, + num_layers: int = 8, + num_single_layers: int = 48, + attention_head_dim: int = 128, + num_attention_heads: int = 48, + joint_attention_dim: int = 15360, + timestep_guidance_channels: int = 256, + mlp_ratio: float = 3.0, + axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32), + rope_theta: int = 2000, + eps: float = 1e-6, + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + # 1. Sinusoidal positional embedding for RoPE on image and text tokens + self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope) + + # 2. Combined timestep + guidance embedding + self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( + in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False + ) + + # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.) 
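+        # The modulation parameters are computed once per forward pass from the timestep/guidance embedding
+        # and reused by every block of the corresponding type, rather than each block owning its own adaLN projection.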
+ # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks + self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream + self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False) + + # 4. Input projections + self.x_embedder = nn.Linear(in_channels, self.inner_dim) + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + + # 5. Double Stream Transformer Blocks + self.transformer_blocks = nn.ModuleList( + [ + Flux2TransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_layers) + ] + ) + + # 6. Single Stream Transformer Blocks + self.single_transformer_blocks = nn.ModuleList( + [ + Flux2SingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_single_layers) + ] + ) + + # 7. Output layers + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 0. 
Handle input arguments + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # 1. Calculate timestep embedding and modulation parameters + timestep = timestep.to(hidden_states.dtype) * 1000 + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = self.time_guidance_embed(timestep, guidance) + + double_stream_mod_img = self.double_stream_modulation_img(temb) + double_stream_mod_txt = self.double_stream_modulation_txt(temb) + single_stream_mod = self.single_stream_modulation(temb)[0] + + # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states) + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + # 3. Calculate RoPE embeddings from image and text tokens + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + if is_torch_npu_available(): + freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu()) + image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu()) + freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu()) + text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu()) + else: + image_rotary_emb = self.pos_embed(img_ids) + text_rotary_emb = self.pos_embed(txt_ids) + concat_rotary_emb = ( + torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=2), + torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=2), + ) + + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + # 4. Double Stream Transformer Blocks + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + double_stream_mod_img, + double_stream_mod_txt, + concat_rotary_emb, + joint_attention_kwargs, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_params_img=double_stream_mod_img, + temb_mod_params_txt=double_stream_mod_txt, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + # For Xlabs ControlNet. 
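+                # (`controlnet_blocks_repeat=True`) the residuals are cycled with a modulo index; otherwise a
+                # residual is reused for `interval_control` consecutive blocks before the index advances.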
+ if controlnet_blocks_repeat: + hidden_states = ( + hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + ) + else: + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + + # 5. Single Stream Transformer Blocks + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + single_stream_mod, + concat_rotary_emb, + joint_attention_kwargs, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_params=single_stream_mod, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_single_block_samples is not None: + interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] + + # 6. Output layers + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/tests/models/transformers/test_models_transformer_flux2.py b/tests/models/transformers/test_models_transformer_flux2.py new file mode 100644 index 000000000000..cbf3f0fa4296 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_flux2.py @@ -0,0 +1,227 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import Flux2Transformer2DModel +from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0 +from diffusers.models.embeddings import ImageProjection + +from ...testing_utils import enable_full_determinism, is_peft_available, torch_device +from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin + + +enable_full_determinism() + + +def create_flux_ip_adapter_state_dict(model): + # "ip_adapter" (cross-attention weights) + ip_cross_attn_state_dict = {} + key_id = 0 + + for name in model.attn_processors.keys(): + if name.startswith("single_transformer_blocks"): + continue + + joint_attention_dim = model.config["joint_attention_dim"] + hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"] + sd = FluxIPAdapterJointAttnProcessor2_0( + hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0 + ).state_dict() + ip_cross_attn_state_dict.update( + { + f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"], + f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"], + f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"], + f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"], + } + ) + + key_id += 1 + + # "image_proj" (ImageProjection layer weights) + + image_projection = ImageProjection( + cross_attention_dim=model.config["joint_attention_dim"], + image_embed_dim=( + model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768 + ), + num_image_text_embeds=4, + ) + + ip_image_projection_state_dict = {} + sd = image_projection.state_dict() + ip_image_projection_state_dict.update( + { + "proj.weight": sd["image_embeds.weight"], + "proj.bias": sd["image_embeds.bias"], + "norm.weight": sd["norm.weight"], + "norm.bias": sd["norm.bias"], + } + ) + + del sd + ip_state_dict = {} + ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) + return ip_state_dict + + +class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = Flux2Transformer2DModel + main_input_name = "hidden_states" + # We override the items here because the transformer under consideration is small. 
+ model_split_percents = [0.7, 0.6, 0.6] + + # Skip setting testing with default: AttnProcessor + uses_custom_attn_processor = True + + @property + def dummy_input(self): + return self.prepare_dummy_input() + + @property + def input_shape(self): + return (16, 4) + + @property + def output_shape(self): + return (16, 4) + + def prepare_dummy_input(self, height=4, width=4): + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + sequence_length = 48 + embedding_dim = 32 + + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + # pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device) + text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device) + image_ids = torch.randn((height * width, num_image_channels)).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "img_ids": image_ids, + "txt_ids": text_ids, + # "pooled_projections": pooled_prompt_embeds, + "timestep": timestep, + "guidance": guidance, + } + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "patch_size": 1, + "in_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 16, + "num_attention_heads": 2, + "joint_attention_dim": 32, + # "pooled_projection_dim": 32, + "timestep_guidance_channels": 16, + "axes_dims_rope": [4, 4, 8], + } + + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_deprecated_inputs_img_txt_ids_3d(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output_1 = model(**inputs_dict).to_tuple()[0] + + # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated) + text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0) + image_ids_3d = inputs_dict["img_ids"].unsqueeze(0) + + assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor" + assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor" + + inputs_dict["txt_ids"] = text_ids_3d + inputs_dict["img_ids"] = image_ids_3d + + with torch.no_grad(): + output_2 = model(**inputs_dict).to_tuple()[0] + + self.assertEqual(output_1.shape, output_2.shape) + self.assertTrue( + torch.allclose(output_1, output_2, atol=1e-5), + msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs", + ) + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"FluxTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + # The test exists for cases like + # https://github.com/huggingface/diffusers/issues/11874 + @unittest.skipIf(not is_peft_available(), "Only with PEFT") + def test_lora_exclude_modules(self): + from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict + + lora_rank = 4 + target_module = "single_transformer_blocks.0.proj_out" + adapter_name = "foo" + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + state_dict = model.state_dict() + target_mod_shape = state_dict[f"{target_module}.weight"].shape + lora_state_dict = { + 
f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22, + f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33, + } + # Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter). + config = LoraConfig( + r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"] + ) + inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict) + set_peft_model_state_dict(model, lora_state_dict, adapter_name) + retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name) + assert len(retrieved_lora_state_dict) == len(lora_state_dict) + assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all() + assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all() + + +class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = Flux2Transformer2DModel + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] + + def prepare_init_args_and_inputs_for_common(self): + return Flux2TransformerTests().prepare_init_args_and_inputs_for_common() + + def prepare_dummy_input(self, height, width): + return Flux2TransformerTests().prepare_dummy_input(height=height, width=width) + + +class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): + model_class = Flux2Transformer2DModel + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] + + def prepare_init_args_and_inputs_for_common(self): + return Flux2TransformerTests().prepare_init_args_and_inputs_for_common() + + def prepare_dummy_input(self, height, width): + return Flux2TransformerTests().prepare_dummy_input(height=height, width=width) From e470643e1234acbc075985ac429eb6e8e0c3cc02 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 14 Nov 2025 12:12:23 +0100 Subject: [PATCH 03/63] add pipeline part --- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/flux2/__init__.py | 47 ++ .../pipelines/flux2/image_processor.py | 148 ++++ .../pipelines/flux2/pipeline_flux2.py | 744 ++++++++++++++++++ .../pipelines/flux2/pipeline_output.py | 24 + 6 files changed, 967 insertions(+) create mode 100644 src/diffusers/pipelines/flux2/__init__.py create mode 100644 src/diffusers/pipelines/flux2/image_processor.py create mode 100644 src/diffusers/pipelines/flux2/pipeline_flux2.py create mode 100644 src/diffusers/pipelines/flux2/pipeline_output.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b084e07c82a5..85e1b75cccb1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -470,6 +470,7 @@ "FluxKontextPipeline", "FluxPipeline", "FluxPriorReduxPipeline", + "Flux2Pipeline", "HiDreamImagePipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", @@ -1138,6 +1139,7 @@ EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, EasyAnimatePipeline, + Flux2Pipeline, FluxControlImg2ImgPipeline, FluxControlInpaintPipeline, FluxControlNetImg2ImgPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 495753041f10..7e0973be5bf1 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -129,6 +129,7 @@ ] _import_structure["bria"] = ["BriaPipeline"] _import_structure["bria_fibo"] = ["BriaFiboPipeline"] + _import_structure["flux2"] = ["Flux2Pipeline"] 
_import_structure["flux"] = [ "FluxControlPipeline", "FluxControlInpaintPipeline", @@ -629,6 +630,7 @@ EasyAnimateInpaintPipeline, EasyAnimatePipeline, ) + from .flux2 import Flux2Pipeline from .flux import ( FluxControlImg2ImgPipeline, FluxControlInpaintPipeline, diff --git a/src/diffusers/pipelines/flux2/__init__.py b/src/diffusers/pipelines/flux2/__init__.py new file mode 100644 index 000000000000..d986c9a63011 --- /dev/null +++ b/src/diffusers/pipelines/flux2/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["Flux2PipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_flux2"] = ["Flux2Pipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_flux2 import Flux2Pipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/flux2/image_processor.py b/src/diffusers/pipelines/flux2/image_processor.py new file mode 100644 index 000000000000..7ec1f8f77267 --- /dev/null +++ b/src/diffusers/pipelines/flux2/image_processor.py @@ -0,0 +1,148 @@ +# Copyright 2025 The Black Forest Labs Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import math + +from ...configuration_utils import register_to_config +from ...image_processor import VaeImageProcessor + + +class Flux2ImageProcessor(VaeImageProcessor): + r""" + Image processor to preprocess the reference (character) image for the Flux2 model. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE (spatial) scale factor. 
If `do_resize` is `True`, the image is automatically resized to multiples of + this factor. + vae_latent_channels (`int`, *optional*, defaults to `16`): + VAE latent channels. + spatial_patch_size (`Tuple[int, int]`, *optional*, defaults to `(2, 2)`): + The spatial patch size used by the diffusion transformer. For Wan models, this is typically (2, 2). + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + do_convert_rgb (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to RGB format. + """ + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 16, + vae_latent_channels: int = 32, + spatial_patch_size: Tuple[int, int] = (2, 2), + resample: str = "lanczos", + do_normalize: bool = True, + do_convert_rgb: bool = True, + ): + super().__init__() + + + @staticmethod + def check_image_input( + image: PIL.Image.Image, + max_aspect_ratio: int = 8, + min_side_length: int = 64, + max_area: int = 1024 * 1024 + ) -> PIL.Image.Image: + """ + Check if image meets minimum size and aspect ratio requirements. + + Args: + image: PIL Image to validate + max_aspect_ratio: Maximum allowed aspect ratio (width/height or height/width) + min_side_length: Minimum pixels required for width and height + max_area: Maximum allowed area in pixels² + + Returns: + The input image if valid + + Raises: + ValueError: If image is too small or aspect ratio is too extreme + """ + if not isinstance(image, PIL.Image.Image): + raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}") + + width, height = image.size + + # Check minimum dimensions + if width < min_side_length or height < min_side_length: + raise ValueError( + f"Image too small: {width}×{height}. " + f"Both dimensions must be at least {min_side_length}px" + ) + + # Check aspect ratio + aspect_ratio = max(width / height, height / width) + if aspect_ratio > max_aspect_ratio: + raise ValueError( + f"Aspect ratio too extreme: {width}×{height} (ratio: {aspect_ratio:.1f}:1). " + f"Maximum allowed ratio is {max_aspect_ratio}:1" + ) + + + return image + + + @staticmethod + def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> Tuple[int, int]: + image_width, image_height = image.size + + scale = math.sqrt(target_area/ (image_width * image_height)) + width = int(image_width * scale) + height = int(image_height * scale) + + return image.resize((width, height), PIL.Image.Resampling.LANCZOS) + + + def _resize_and_crop( + self, + image: PIL.Image.Image, + width: int, + height: int, + ) -> PIL.Image.Image: + r""" + center crop the image to the specified width and height. + + Args: + image (`PIL.Image.Image`): + The image to resize and crop. + width (`int`): + The width to resize the image to. + height (`int`): + The height to resize the image to. + + Returns: + `PIL.Image.Image`: + The resized and cropped image. 
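+
+        Note: despite the name, this override only performs a center crop; any resizing is expected to
+        have been done beforehand (e.g. via `_resize_to_target_area`).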
+ """ + image_width, image_height = image.size + + left = (image_width - width) // 2 + top = (image_height - height) // 2 + right = left + width + bottom = top + height + + return image.crop((left, top, right, bottom)) \ No newline at end of file diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py new file mode 100644 index 000000000000..3f2495529472 --- /dev/null +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -0,0 +1,744 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union +import PIL + +import numpy as np +import torch +from transformers import ( + Mistral3ForConditionalGeneration, AutoProcessor +) + +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import Flux2PipelineOutput +from .image_processor import Flux2ImageProcessor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import Flux2Pipeline + + >>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image.save("flux.png") + ``` +""" + + + +def format_text_input(prompts: List[str], system_message: str = None): + # Remove [IMG] tokens from prompts to avoid Pixtral validation issues + # when truncation is enabled. The processor counts [IMG] tokens and fails + # if the count changes after truncation. 
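+    # e.g. 'a watercolor of [IMG] next to a lake' -> 'a watercolor of  next to a lake' (only the token is
+    # removed, surrounding whitespace is left untouched) before the prompt is wrapped into the chat template below.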
+ cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +class Flux2Pipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, + FluxIPAdapterMixin, +): + r""" + The Flux pipeline for text-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 512 + self.default_sample_size = 128 + self.system_message = """You are an AI that reasons about image descriptions. 
You give structured responses focusing on object relationships, object +attribution and actions without speculation.""" + self.text_encoder_out_layers = (10, 20, 30) + + @staticmethod + def _get_mistral_3_small_prompt_embeds( + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + system_message: str = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object +attribution and actions without speculation.""", + hidden_states_layers: List[int] = (10, 20, 30), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Format input messages + messages_batch = format_text_input(prompts=prompt, system_message=system_message) + + # Process all messages at once + inputs = tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + # Move to device + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + + @staticmethod + def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: Optional[torch.Tensor] = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l).to(x.device) + out_ids.append(coords) + + return torch.stack(out_ids) + + + # YiYi TODO: can optimize a bit + @staticmethod + def _prepare_image_ids( + image_latents: List[torch.Tensor], # [(C, H, W), (C, H, W), ...] + scale: int = 10 + ): + + r""" + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + This function creates a unique coordinate for every pixel/patch across all + input latent with different dimensions. + + Args: + image_latents (List[torch.Tensor]): + A list of image latent feature tensors, typically of shape (C, H, W). + scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. + T-coordinate for the i-th latent is: 'scale + scale * i'. Defaults to 10. + + Returns: + torch.Tensor: + The combined coordinate tensor. + Shape: (1, N_total, 4) + Where N_total is the sum of (H * W) for all input latents. + + Coordinate Components (Dimension 4): + - T (Time): The unique index indicating which latent image the coordinate belongs to. + - H (Height): The row index within that latent image. + - W (Width): The column index within that latent image. + - L (Seq. 
Length): A sequence length dimension, which is always fixed at 0 (size 1) + """ + + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + + if x.ndim == 4: + x = x.squeeze(0) + if x.ndim != 3: + raise ValueError(f"Expected `image_latent` to be a list of tensors with dims 3 or 4, got {x.ndim}.") + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + + @staticmethod + def _patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2) , 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height *2 , width *2) + return latents + + @staticmethod + def _pack_latents(latents): + """ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) + """ + + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + ): + device = device or self._execution_device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_mistral_3_small_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + system_message=self.system_message, + hidden_states_layers=self.text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self._prepare_text_ids(prompt_embeds) + return prompt_embeds, text_ids + + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = ( + self.vae.bn.running_mean.view(1, -1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + image_latents = (image_latents - 
latents_bn_mean) / latents_bn_std + + return image_latents + + + def prepare_latents( + self, + batch_size, + num_latents_channels, + height, + width, + dtype, + device, + generator: torch.Generator, + latents: Optional[torch.Tensor] = None, + ): + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + + shape = (batch_size, num_latents_channels * 4, height // 16, width // 16) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latent_ids = self._prepare_image_ids(list(latents)) + latents = self._pack_latents(latents) + + return latents, latent_ids + + + def prepare_image_latents( + self, + images: List[torch.Tensor], + generator: torch.Generator, + ): + image_latents = [] + for image in images: + imagge_latent = self._encode_vae_image(image=image, generator=generator) + image_latents.append(imagge_latent) # (1, 128, 32, 32) + + image_latent_ids = self._prepare_image_ids(image_latents) + + # Pack each latent and concatenate + packed_latents = [] + for latent in image_latents: + # latent: (1, 128, 32, 32) + packed = self._pack_latents(latent) # (1, 1024, 128) + packed = packed.squeeze(0) # (1024, 128) - remove batch dim + packed_latents.append(packed) + + # Concatenate all reference tokens along sequence dimension + image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128) + image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128) + return image_latents, image_latent_ids + + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = 2.5, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + guidance_scale (`float`, *optional*, defaults to 1.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: + [`~pipelines.flux2.Flux2PipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. 
Raise error if not correct + # TODO: + + + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + prompt_embeds, txt_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if image is not None and not isinstance(image, list): + image = [image] + + condition_images = None + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + + condition_images = [] + condition_image_sizes = [] + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + condition_images.append(self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode = "crop")) + condition_image_sizes.append((image_width, image_height)) + + + num_channels_latents = 32 + latents, latent_ids = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_latents_channels=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + if condition_images is not None: + image_latents, image_latent_ids = self.prepare_image_latents( + images=condition_images, + generator=generator, + ) + + print(f"latents.shape = {latents.shape}, latent_ids.shape = {latent_ids.shape}") + print(f"image_latents.shape = {image_latents.shape}, image_latent_ids.shape = {image_latent_ids.shape}") + print(f"prompt_embeds.shape = {prompt_embeds.shape}, txt_ids.shape = {txt_ids.shape}") + + return latents, latent_ids, image_latents, image_latent_ids \ No newline at end of file diff --git a/src/diffusers/pipelines/flux2/pipeline_output.py b/src/diffusers/pipelines/flux2/pipeline_output.py new file mode 100644 index 000000000000..2183b8bcff41 --- /dev/null +++ b/src/diffusers/pipelines/flux2/pipeline_output.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image +import torch + +from ...utils import BaseOutput + + +@dataclass +class Flux2PipelineOutput(BaseOutput): + """ + Output class for Flux2 image generation pipelines. + + Args: + images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder. 
+ """ + + images: Union[List[PIL.Image.Image], np.ndarray] From 7456a4985cfbae85d23f218b957ed4b9b3b1d35b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Nov 2025 11:54:32 +0000 Subject: [PATCH 04/63] small edits to the pipeline and conversion --- scripts/convert_flux2_to_diffusers.py | 9 ++++- .../pipelines/flux2/pipeline_flux2.py | 39 ++++++++++--------- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/scripts/convert_flux2_to_diffusers.py b/scripts/convert_flux2_to_diffusers.py index 91616e5bfc68..b764543bd7b0 100644 --- a/scripts/convert_flux2_to_diffusers.py +++ b/scripts/convert_flux2_to_diffusers.py @@ -8,7 +8,7 @@ from diffusers import AutoencoderKLFlux2 from diffusers.utils.import_utils import is_accelerate_available - +from transformers import Mistral3ForConditionalGeneration, AutoProcessor """ @@ -221,6 +221,13 @@ def main(args): vae.load_state_dict(converted_vae_state_dict, strict=True) vae.to(dtype).save_pretrained(f"{args.output_path}/vae") + if args.full_pipe: + tokenizer_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" + text_encoder_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506" + text_encoder = Mistral3ForConditionalGeneration.from_pretrained(text_encoder_id, torch_dtype=torch.bfloat16) + tokenizer = AutoProcessor.from_pretrained(tokenizer_id) + + # TODO: collate denoiser, vae, text encoder, tokenizer here. if __name__ == "__main__": main(args) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 3f2495529472..8aedb868c1c5 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -18,9 +18,7 @@ import numpy as np import torch -from transformers import ( - Mistral3ForConditionalGeneration, AutoProcessor -) +from transformers import Mistral3ForConditionalGeneration, AutoProcessor from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, FluxTransformer2DModel @@ -56,12 +54,12 @@ >>> import torch >>> from diffusers import Flux2Pipeline - >>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-schnell", torch_dtype=torch.bfloat16) + >>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> prompt = "A cat holding a sign that says hello world" >>> # Depending on the variant being used, the pipeline call will slightly vary. >>> # Refer to the pipeline documentation for more details. - >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0] >>> image.save("flux.png") ``` """ @@ -178,10 +176,10 @@ class Flux2Pipeline( FluxIPAdapterMixin, ): r""" - The Flux pipeline for text-to-image generation. - - Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + The Flux2 pipeline for text-to-image generation. + Reference: TODO + Args: transformer ([`FluxTransformer2DModel`]): Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. @@ -189,18 +187,11 @@ class Flux2Pipeline( A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
- text_encoder ([`CLIPTextModel`]): - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - text_encoder_2 ([`T5EncoderModel`]): - [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically - the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. - tokenizer (`CLIPTokenizer`): + text_encoder ([`Mistral3ForConditionalGeneration`]): + [Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration) + tokenizer (`AutoProcessor`): Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). - tokenizer_2 (`T5TokenizerFast`): - Second Tokenizer of class - [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + [PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor). """ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" @@ -268,6 +259,16 @@ def _get_mistral_3_small_prompt_embeds( input_ids = inputs["input_ids"].to(device) attention_mask = inputs["attention_mask"].to(device) + text_input_ids = input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + # Forward pass through the model output = text_encoder( input_ids=input_ids, From b10e629645424db659a384f1096b03e653f4be9f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 14 Nov 2025 22:07:45 +0100 Subject: [PATCH 05/63] update conversion script --- scripts/convert_flux2_to_diffusers.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/scripts/convert_flux2_to_diffusers.py b/scripts/convert_flux2_to_diffusers.py index 91616e5bfc68..54924d9549f8 100644 --- a/scripts/convert_flux2_to_diffusers.py +++ b/scripts/convert_flux2_to_diffusers.py @@ -44,12 +44,7 @@ def load_original_checkpoint(args): else: raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`") - if ckpt_path.endswith(".pt"): - original_state_dict = torch.load(ckpt_path, map_location="cpu") - elif ckpt_path.endswith(".safetensors"): - original_state_dict = safetensors.torch.load_file(ckpt_path) - else: - raise ValueError(f"Unsupported file extension: {ckpt_path}") + original_state_dict = safetensors.torch.load_file(ckpt_path) return original_state_dict @@ -214,9 +209,6 @@ def main(args): if args.vae: vae = AutoencoderKLFlux2() - if "model" in original_ckpt: - # YiYi Notes: remove this depends on if it has "model" key - original_ckpt = original_ckpt["model"] converted_vae_state_dict = convert_flux2_vae_checkpoint_to_diffusers(original_ckpt, vae.config) vae.load_state_dict(converted_vae_state_dict, strict=True) vae.to(dtype).save_pretrained(f"{args.output_path}/vae") From 7382358495238645b83e408dbdacfe1f77c758b2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 14 Nov 2025 22:07:56 +0100 Subject: [PATCH 06/63] fix --- 
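Note (editorial, not part of this commit): the fix below moves conditioning images onto the VAE's device and dtype before encoding, and re-reads the PIL size after the optional resize. A minimal sketch of the updated call, assuming an already constructed `pipe` and a list of preprocessed `condition_images`:

    image_latents, image_latent_ids = pipe.prepare_image_latents(
        images=condition_images,
        generator=torch.Generator(device="cpu").manual_seed(0),
        device=pipe._execution_device,
        dtype=pipe.vae.dtype,
    )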
src/diffusers/pipelines/flux2/pipeline_flux2.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 3f2495529472..7fd508f1c966 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -483,9 +483,12 @@ def prepare_image_latents( self, images: List[torch.Tensor], generator: torch.Generator, + device, + dtype, ): image_latents = [] for image in images: + image = image.to(device=device, dtype=dtype) imagge_latent = self._encode_vae_image(image=image, generator=generator) image_latents.append(imagge_latent) # (1, 128, 32, 32) @@ -711,11 +714,13 @@ def __call__( image_width, image_height = img.size if image_width * image_height > 1024 * 1024: img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size multiple_of = self.vae_scale_factor * 2 image_width = (image_width // multiple_of) * multiple_of image_height = (image_height // multiple_of) * multiple_of - condition_images.append(self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode = "crop")) + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode = "crop") + condition_images.append(img) condition_image_sizes.append((image_width, image_height)) @@ -735,10 +740,8 @@ def __call__( image_latents, image_latent_ids = self.prepare_image_latents( images=condition_images, generator=generator, + device=device, + dtype=self.vae.dtype, ) - print(f"latents.shape = {latents.shape}, latent_ids.shape = {latent_ids.shape}") - print(f"image_latents.shape = {image_latents.shape}, image_latent_ids.shape = {image_latent_ids.shape}") - print(f"prompt_embeds.shape = {prompt_embeds.shape}, txt_ids.shape = {txt_ids.shape}") - - return latents, latent_ids, image_latents, image_latent_ids \ No newline at end of file + return image_latents, image_latent_ids \ No newline at end of file From 524b1238ace0424f361017ebd7016a8b31b01e69 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 14 Nov 2025 23:44:40 +0100 Subject: [PATCH 07/63] up up --- .../pipelines/flux2/pipeline_flux2.py | 50 +++++++++++++++++-- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index a81d2cf20afe..cb04152c46c3 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -307,6 +307,38 @@ def _prepare_text_ids( return torch.stack(out_ids) + @staticmethod + def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) + ): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. 
+ + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) + All batches share the same coordinate structure: T=0, H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, l).to(latents.device) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + # YiYi TODO: can optimize a bit @staticmethod def _prepare_image_ids( @@ -463,7 +495,7 @@ def prepare_latents( width = 2 * (int(width) // (self.vae_scale_factor * 2)) - shape = (batch_size, num_latents_channels * 4, height // 16, width // 16) + shape = (batch_size, num_latents_channels * 4, height//2, width//2) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -474,15 +506,15 @@ def prepare_latents( else: latents = latents.to(device=device, dtype=dtype) - latent_ids = self._prepare_image_ids(list(latents)) - latents = self._pack_latents(latents) - + latent_ids = self._prepare_latent_ids(latents) + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] return latents, latent_ids def prepare_image_latents( self, images: List[torch.Tensor], + batch_size, generator: torch.Generator, device, dtype, @@ -506,6 +538,10 @@ def prepare_image_latents( # Concatenate all reference tokens along sequence dimension image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128) image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + return image_latents, image_latent_ids @@ -740,9 +776,13 @@ def __call__( if condition_images is not None: image_latents, image_latent_ids = self.prepare_image_latents( images=condition_images, + batch_size=batch_size * num_images_per_prompt, generator=generator, device=device, dtype=self.vae.dtype, ) - return image_latents, image_latent_ids \ No newline at end of file + # YiYi Testing + # return image_latents, image_latent_ids, latents, latent_ids + # YiYi Testing + # return None, None, latents, latent_ids \ No newline at end of file From 8a48adcd2c9b6fc68f2408c4e3366c56acf30aa0 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 15 Nov 2025 01:25:32 +0100 Subject: [PATCH 08/63] finish pipeline --- .../pipelines/flux2/pipeline_flux2.py | 167 +++++++++++++++++- 1 file changed, 159 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index cb04152c46c3..612a18d1e040 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -64,6 +64,46 @@ ``` """ +# YiYi TODO: refactor later, remove rearrange and potentially compress_time is no-op here +def compress_time(t_ids: torch.Tensor) -> torch.Tensor: + assert t_ids.ndim == 1 + t_ids_max = torch.max(t_ids) + t_remap = torch.zeros((t_ids_max + 1,), device=t_ids.device, dtype=t_ids.dtype) + t_unique_sorted_ids = torch.unique(t_ids, sorted=True) + t_remap[t_unique_sorted_ids] = torch.arange( + len(t_unique_sorted_ids), device=t_ids.device, dtype=t_ids.dtype + ) + 
t_ids_compressed = t_remap[t_ids] + return t_ids_compressed + +from einops import rearrange +def scatter_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + t_coords = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + t_ids = pos[:, 0].to(torch.int64) + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + t_ids_cmpr = compress_time(t_ids) + + t = torch.max(t_ids_cmpr) + 1 + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = t_ids_cmpr * w * h + h_ids * w + w_ids + + out = torch.zeros((t * h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + x_list.append(rearrange(out, "(t h w) c -> 1 c t h w", t=t, h=h, w=w)) + t_coords.append(torch.unique(t_ids, sorted=True)) + return x_list + def format_text_input(prompts: List[str], system_message: str = None): @@ -703,9 +743,6 @@ def __call__( width = width or self.default_sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct - # TODO: - - self.check_inputs( prompt=prompt, height=height, @@ -728,15 +765,17 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - - prompt_embeds, txt_ids = self.encode_prompt( + + # 3. prepare text embeddings + prompt_embeds, text_ids = self.encode_prompt( prompt=prompt, prompt_embeds=prompt_embeds, device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) - + + # 4. process images if image is not None and not isinstance(image, list): image = [image] @@ -760,8 +799,9 @@ def __call__( condition_images.append(img) condition_image_sizes.append((image_width, image_height)) - + # 5. prepare latent variables num_channels_latents = 32 + # num_channels_latents = self.transformer.config.in_channels // 4 latents, latent_ids = self.prepare_latents( batch_size=batch_size * num_images_per_prompt, num_latents_channels=num_channels_latents, @@ -785,4 +825,115 @@ def __call__( # YiYi Testing # return image_latents, image_latent_ids, latents, latent_ids # YiYi Testing - # return None, None, latents, latent_ids \ No newline at end of file + # return None, None, latents, latent_ids + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + + + # 7. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. 
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents + latent_image_ids = latent_ids + + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + latent_image_ids = torch.cat([latent_ids, image_latent_ids],dim=1) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, # (B, L, C) + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1):] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = torch.cat(scatter_ids(latents, latent_ids), dim=1).squeeze(2) + + latents_bn_mean = ( + self.vae.bn.running_mean.view(1, -1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return Flux2PipelineOutput(images=image) \ No newline at end of file From 429d2cf5543b1046fe013035855fdec8ce1afb95 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 15 Nov 2025 03:34:52 +0100 Subject: [PATCH 09/63] Remove Flux IP Adapter logic for now --- .../models/transformers/transformer_flux2.py | 135 ------------------ 1 file changed, 135 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index ea7e0b0ebd84..de8c509efe05 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -208,145 +208,10 @@ def __call__( return hidden_states -# TODO: support IP Adapter for Flux.2 as well -class FluxIPAdapterAttnProcessor(torch.nn.Module): - """Flux Attention 
processor for IP-Adapter.""" - - _attention_backend = None - _parallel_config = None - - def __init__( - self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None - ): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - - if not isinstance(num_tokens, (tuple, list)): - num_tokens = [num_tokens] - - if not isinstance(scale, list): - scale = [scale] * len(num_tokens) - if len(scale) != len(num_tokens): - raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") - self.scale = scale - - self.to_k_ip = nn.ModuleList( - [ - nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) - for _ in range(len(num_tokens)) - ] - ) - self.to_v_ip = nn.ModuleList( - [ - nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) - for _ in range(len(num_tokens)) - ] - ) - - def __call__( - self, - attn: "Flux2Attention", - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor = None, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ip_hidden_states: Optional[List[torch.Tensor]] = None, - ip_adapter_masks: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - batch_size = hidden_states.shape[0] - - query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( - attn, hidden_states, encoder_hidden_states - ) - - query = query.unflatten(-1, (attn.heads, -1)) - key = key.unflatten(-1, (attn.heads, -1)) - value = value.unflatten(-1, (attn.heads, -1)) - - query = attn.norm_q(query) - key = attn.norm_k(key) - ip_query = query - - if encoder_hidden_states is not None: - encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) - encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) - encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) - - encoder_query = attn.norm_added_q(encoder_query) - encoder_key = attn.norm_added_k(encoder_key) - - query = torch.cat([encoder_query, query], dim=1) - key = torch.cat([encoder_key, key], dim=1) - value = torch.cat([encoder_value, value], dim=1) - - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) - key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) - - hidden_states = dispatch_attention_fn( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False, - backend=self._attention_backend, - parallel_config=self._parallel_config, - ) - hidden_states = hidden_states.flatten(2, 3) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( - [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 - ) - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - # IP-adapter - ip_attn_output = torch.zeros_like(hidden_states) - - for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip - ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, 
attn.head_dim) - ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim) - - current_ip_hidden_states = dispatch_attention_fn( - ip_query, - ip_key, - ip_value, - attn_mask=None, - dropout_p=0.0, - is_causal=False, - backend=self._attention_backend, - parallel_config=self._parallel_config, - ) - current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim) - current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) - ip_attn_output += scale * current_ip_hidden_states - - return hidden_states, encoder_hidden_states, ip_attn_output - else: - return hidden_states - - class Flux2Attention(torch.nn.Module, AttentionModuleMixin): _default_processor_cls = Flux2AttnProcessor _available_processors = [ Flux2AttnProcessor, - FluxIPAdapterAttnProcessor, ] def __init__( From 2d7bad73634c7e7edfc914af9d629ffdb73504af Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 15 Nov 2025 03:36:12 +0100 Subject: [PATCH 10/63] Remove deprecated 3D id logic --- .../models/transformers/transformer_flux2.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index de8c509efe05..1598af626ce1 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -748,19 +748,6 @@ def forward( encoder_hidden_states = self.context_embedder(encoder_hidden_states) # 3. Calculate RoPE embeddings from image and text tokens - if txt_ids.ndim == 3: - logger.warning( - "Passing `txt_ids` 3d torch.Tensor is deprecated." - "Please remove the batch dimension and pass it as a 2d torch Tensor" - ) - txt_ids = txt_ids[0] - if img_ids.ndim == 3: - logger.warning( - "Passing `img_ids` 3d torch.Tensor is deprecated." - "Please remove the batch dimension and pass it as a 2d torch Tensor" - ) - img_ids = img_ids[0] - if is_torch_npu_available(): freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu()) image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu()) From c67f582073826cf6b9bdc9694ed294677eeaf6ae Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 15 Nov 2025 03:37:29 +0100 Subject: [PATCH 11/63] Remove ControlNet logic for now --- .../models/transformers/transformer_flux2.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 1598af626ce1..bbacc7972f37 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -788,18 +788,6 @@ def forward( joint_attention_kwargs=joint_attention_kwargs, ) - # controlnet residual - if controlnet_block_samples is not None: - interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) - interval_control = int(np.ceil(interval_control)) - # For Xlabs ControlNet. - if controlnet_blocks_repeat: - hidden_states = ( - hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] - ) - else: - hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] - # 5. 
Single Stream Transformer Blocks for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -820,12 +808,6 @@ def forward( joint_attention_kwargs=joint_attention_kwargs, ) - # controlnet residual - if controlnet_single_block_samples is not None: - interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) - interval_control = int(np.ceil(interval_control)) - hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] - # 6. Output layers hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) From 7acd7dad550e6714b9f2802b18ded9da26a617d8 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 15 Nov 2025 03:48:46 +0100 Subject: [PATCH 12/63] Add link to ViT-22B paper as reference for parallel transformer blocks such as the Flux 2 single stream block --- src/diffusers/models/transformers/transformer_flux2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index bbacc7972f37..06d7f1596c65 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -321,7 +321,8 @@ def __init__( self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this - # is often called a "parallel" transformer block + # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442) + # for a visual depiction of this type of transformer block. self.attn = Flux2Attention( query_dim=dim, dim_head=attention_head_dim, From 66040866007cf9a6ba6bedeace25663e839b0f0c Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sat, 15 Nov 2025 06:26:46 +0000 Subject: [PATCH 13/63] update pipeline --- .../pipelines/flux2/pipeline_flux2.py | 75 ++++++++++--------- 1 file changed, 40 insertions(+), 35 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 612a18d1e040..5ab963bbf274 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -208,13 +208,7 @@ def retrieve_latents( else: raise AttributeError("Could not access latents of provided encoder_output") -class Flux2Pipeline( - DiffusionPipeline, - FluxLoraLoaderMixin, - FromSingleFileMixin, - TextualInversionLoaderMixin, - FluxIPAdapterMixin, -): +class Flux2Pipeline(DiffusionPipeline): r""" The Flux2 pipeline for text-to-image generation. @@ -244,6 +238,7 @@ def __init__( vae: AutoencoderKL, text_encoder: Mistral3ForConditionalGeneration, tokenizer: AutoProcessor, + transformer: FluxTransformer2DModel, ): super().__init__() @@ -252,6 +247,7 @@ def __init__( text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, + transformer=transformer, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 # Flux latents are turned into 2x2 patches and packed. 
This means the latent width and height has to be divisible @@ -341,7 +337,7 @@ def _prepare_text_ids( w = torch.arange(1) l = torch.arange(L) - coords = torch.cartesian_prod(t, h, w, l).to(x.device) + coords = torch.cartesian_prod(t, h, w, l) out_ids.append(coords) return torch.stack(out_ids) @@ -372,7 +368,7 @@ def _prepare_latent_ids( l = torch.arange(1) # [0] - layer dimension # Create position IDs: (H*W, 4) - latent_ids = torch.cartesian_prod(t, h, w, l).to(latents.device) + latent_ids = torch.cartesian_prod(t, h, w, l) # Expand to batch: (B, H*W, 4) latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) @@ -382,7 +378,7 @@ def _prepare_latent_ids( # YiYi TODO: can optimize a bit @staticmethod def _prepare_image_ids( - image_latents: List[torch.Tensor], # [(C, H, W), (C, H, W), ...] + image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] scale: int = 10 ): @@ -422,10 +418,7 @@ def _prepare_image_ids( image_latent_ids = [] for x, t in zip(image_latents, t_coords): - if x.ndim == 4: - x = x.squeeze(0) - if x.ndim != 3: - raise ValueError(f"Expected `image_latent` to be a list of tensors with dims 3 or 4, got {x.ndim}.") + x = x.squeeze(0) _, height, width = x.shape x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) @@ -496,6 +489,7 @@ def encode_prompt( prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) text_ids = self._prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) return prompt_embeds, text_ids @@ -547,6 +541,8 @@ def prepare_latents( latents = latents.to(device=device, dtype=dtype) latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(device) + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] return latents, latent_ids @@ -581,6 +577,7 @@ def prepare_image_latents( image_latents = image_latents.repeat(batch_size, 1, 1) image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) return image_latents, image_latent_ids @@ -812,7 +809,9 @@ def __call__( generator=generator, latents=latents, ) - + + image_latents = None + image_latent_ids = None if condition_images is not None: image_latents, image_latent_ids = self.prepare_image_latents( images=condition_images, @@ -822,11 +821,6 @@ def __call__( dtype=self.vae.dtype, ) - # YiYi Testing - # return image_latents, image_latent_ids, latents, latent_ids - # YiYi Testing - # return None, None, latents, latent_ids - # 6. 
Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: @@ -867,25 +861,36 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - latent_model_input = latents + latent_model_input = latents.to(self.transformer.dtype) latent_image_ids = latent_ids if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1) + latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) latent_image_ids = torch.cat([latent_ids, image_latent_ids],dim=1) - with self.transformer.cache_context("cond"): - noise_pred = self.transformer( - hidden_states=latent_model_input, # (B, L, C) - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_pred[:, : latents.size(1):] + + noise_pred = self.transformer( + hidden_states=latent_model_input, # (B, L, C) + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # # YiYi NOTES: uncomment this to use the tranformer from original repo for testing + # # YIYI TODO: remove this before merging + # noise_pred = self.transformer( + # x=latent_model_input, + # x_ids=latent_image_ids, + # timesteps=timestep / 1000, + # guidance=guidance.to(self.transformer.dtype), + # ctx=prompt_embeds.to(self.transformer.dtype), + # ctx_ids=text_ids, + # ) + noise_pred = noise_pred[:, : latents.size(1):] # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype @@ -921,7 +926,7 @@ def __call__( latents_bn_mean = ( self.vae.bn.running_mean.view(1, -1, 1, 1) - .to(image_latents.device, image_latents.dtype) + .to(latents.device, latents.dtype) ) latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) latents = latents * latents_bn_std + latents_bn_mean From a200780fa7e698c4c4b8d2b0747a82ae8b125ccf Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 15 Nov 2025 07:44:57 +0100 Subject: [PATCH 14/63] Don't use biases for input projs and output AdaNorm --- src/diffusers/models/transformers/transformer_flux2.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 06d7f1596c65..65981aa7f0d0 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -641,8 +641,8 @@ def __init__( self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False) # 4. Input projections - self.x_embedder = nn.Linear(in_channels, self.inner_dim) - self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False) + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False) # 5. Double Stream Transformer Blocks self.transformer_blocks = nn.ModuleList( @@ -675,7 +675,9 @@ def __init__( ) # 7. 
Output layers - self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps) + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False + ) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) self.gradient_checkpointing = False From f7148a02c32bd05484c941624ea546aaf14969bb Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sat, 15 Nov 2025 08:35:23 +0000 Subject: [PATCH 15/63] up --- src/diffusers/pipelines/flux2/pipeline_flux2.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 5ab963bbf274..57bfc138d8cc 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -876,20 +876,10 @@ def __call__( encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, + joint_attention_kwargs=self._attention_kwargs, return_dict=False, )[0] - # # YiYi NOTES: uncomment this to use the tranformer from original repo for testing - # # YIYI TODO: remove this before merging - # noise_pred = self.transformer( - # x=latent_model_input, - # x_ids=latent_image_ids, - # timesteps=timestep / 1000, - # guidance=guidance.to(self.transformer.dtype), - # ctx=prompt_embeds.to(self.transformer.dtype), - # ctx_ids=text_ids, - # ) noise_pred = noise_pred[:, : latents.size(1):] # compute the previous noisy sample x_t -> x_t-1 From 54c60803646ac02600293c9e9708f68eea6952ad Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 15 Nov 2025 10:44:40 +0100 Subject: [PATCH 16/63] Remove bias for double stream block text QKV projections --- src/diffusers/models/transformers/transformer_flux2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 65981aa7f0d0..6b9dc41d5aef 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -401,6 +401,7 @@ def __init__( heads=num_attention_heads, out_dim=dim, bias=bias, + added_proj_bias=bias, out_bias=bias, eps=eps, processor=Flux2AttnProcessor(), From dba4c1fe1900ac9a57fdcf0aced73652269a0704 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 15 Nov 2025 10:45:06 +0100 Subject: [PATCH 17/63] Add script to convert Flux 2 transformer to diffusers --- scripts/convert_flux2_to_diffusers.py | 282 ++++++++++++++++++++++++++ 1 file changed, 282 insertions(+) create mode 100644 scripts/convert_flux2_to_diffusers.py diff --git a/scripts/convert_flux2_to_diffusers.py b/scripts/convert_flux2_to_diffusers.py new file mode 100644 index 000000000000..b4a713baf8ce --- /dev/null +++ b/scripts/convert_flux2_to_diffusers.py @@ -0,0 +1,282 @@ +import argparse +import os +import pathlib +from contextlib import nullcontext +from typing import Any, Dict, Optional, Tuple + +import safetensors.torch +import torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download + +from diffusers import Flux2Transformer2DModel +from diffusers.utils.import_utils import is_accelerate_available +from transformers import Mistral3ForConditionalGeneration, AutoProcessor + + +""" +# Transformer +""" + + +CTX = init_empty_weights if is_accelerate_available() else nullcontext + + 
+FLUX2_TRANSFORMER_KEYS_RENAME_DICT ={ + # Image and text input projections + "img_in": "x_embedder", + "txt_in": "context_embedder", + # Timestep and guidance embeddings + "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1", + "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2", + "guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1", + "guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2", + # Modulation parameters + "double_stream_modulation_img.lin": "double_stream_modulation_img.linear", + "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear", + "single_stream_modulation.lin": "single_stream_modulation.linear", + # Final output layer + # "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params + "final_layer.linear": "proj_out", +} + + +FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = { + "final_layer.adaLN_modulation.1": "norm_out.linear", +} + + +FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = { + # Handle fused QKV projections separately as we need to break into Q, K, V projections + "img_attn.norm.query_norm": "attn.norm_q", + "img_attn.norm.key_norm": "attn.norm_k", + "img_attn.proj": "attn.to_out.0", + "img_mlp.0": "ff.linear_in", + "img_mlp.2": "ff.linear_out", + "txt_attn.norm.query_norm": "attn.norm_added_q", + "txt_attn.norm.key_norm": "attn.norm_added_k", + "txt_attn.proj": "attn.to_add_out", + "txt_mlp.0": "ff_context.linear_in", + "txt_mlp.2": "ff_context.linear_out", +} + + +FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = { + "linear1": "attn.to_qkv_mlp_proj", + "linear2": "attn.to_out", + "norm.query_norm": "attn.norm_q", + "norm.key_norm": "attn.norm_k", +} + + +# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; +# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use +# diffusers implementation +def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def convert_ada_layer_norm_weights(key: str, state_dict: Dict[str, Any]) -> None: + # Skip if not a weight + if ".weight" not in key: + return + + # If adaLN_modulation is in the key, swap scale and shift parameters + # Original implementation is (shift, scale); diffusers implementation is (scale, shift) + if "adaLN_modulation" in key: + key_without_param_type, param_type = key.rsplit(".", maxsplit=1) + # Assume all such keys are in the AdaLayerNorm key map + new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type] + new_key = ".".join([new_key_without_param_type, param_type]) + + swapped_weight = swap_scale_shift(state_dict.pop(key)) + state_dict[new_key] = swapped_weight + return + + +def convert_flux2_double_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None: + # Skip if not a weight, bias, or scale + if ".weight" not in key and ".bias" not in key and ".scale" not in key: + return + + new_prefix = "transformer_blocks" + if "double_blocks." 
in key: + parts = key.split(".") + block_idx = parts[1] + modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp + within_block_name = ".".join(parts[2:-1]) + param_type = parts[-1] + + if param_type == "scale": + param_type = "weight" + + if "qkv" in within_block_name: + fused_qkv_weight = state_dict.pop(key) + to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0) + if "img" in modality_block_name: + # double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v} + to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0) + new_q_name = "attn.to_q" + new_k_name = "attn.to_k" + new_v_name = "attn.to_v" + elif "txt" in modality_block_name: + # double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj} + to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0) + new_q_name = "attn.add_q_proj" + new_k_name = "attn.add_k_proj" + new_v_name = "attn.add_v_proj" + new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type]) + new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type]) + new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type]) + state_dict[new_q_key] = to_q_weight + state_dict[new_k_key] = to_k_weight + state_dict[new_v_key] = to_v_weight + else: + new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name] + new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type]) + + param = state_dict.pop(key) + state_dict[new_key] = param + return + + +def convert_flux2_single_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None: + # Skip if not a weight, bias, or scale + if ".weight" not in key and ".bias" not in key and ".scale" not in key: + return + + # Mapping: + # - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj + # - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out + # - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight + # - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight + new_prefix = "single_transformer_blocks" + if "single_blocks." 
in key: + parts = key.split(".") + block_idx = parts[1] + within_block_name = ".".join(parts[2:-1]) + param_type = parts[-1] + + if param_type == "scale": + param_type = "weight" + + new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name] + new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type]) + + param = state_dict.pop(key) + state_dict[new_key] = param + return + + +TRANSFORMER_SPECIAL_KEYS_REMAP = { + "adaLN_modulation": convert_ada_layer_norm_weights, + "double_blocks": convert_flux2_double_stream_blocks, + "single_blocks": convert_flux2_single_stream_blocks, +} + + +def load_original_checkpoint( + repo_id: Optional[str], model_file: Optional[str], checkpoint_path: Optional[str] = None +) -> Dict[str, torch.Tensor]: + if repo_id is not None: + ckpt_path = hf_hub_download(repo_id=repo_id, filename=model_file) + elif checkpoint_path is not None: + ckpt_path = checkpoint_path + else: + raise ValueError("Please provide either `repo_id` or a local `checkpoint_path`") + + if "safetensors" in model_file: + original_state_dict = safetensors.torch.load_file(ckpt_path) + else: + original_state_dict = torch.load(ckpt_path, map_location="cpu") + return original_state_dict + + +def update_state_dict(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None: + state_dict[new_key] = state_dict.pop(old_key) + + +def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: + if model_type == "test" or model_type == "dummy-flux2": + config = { + "model_id": "diffusers-internal-dev/dummy-flux2", + "diffusers_config": { + "patch_size": 1, + "in_channels": 128, + "num_layers": 8, + "num_single_layers": 48, + "attention_head_dim": 128, + "num_attention_heads": 48, + "joint_attention_dim": 15360, + "timestep_guidance_channels": 256, + "mlp_ratio": 3.0, + "axes_dims_rope": (32, 32, 32, 32), + "rope_theta": 2000, + "eps": 1e-6, + } + } + rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_flux2_transformer_to_diffusers(original_state_dict: Dict[str, torch.Tensor], model_type: str): + config, rename_dict, special_keys_remap = get_flux2_transformer_config(model_type) + + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + transformer = Flux2Transformer2DModel.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + transformer.load_state_dict(original_state_dict, strict=True, assign=True) + return transformer + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--original_state_dict_repo_id", default="diffusers-internal-dev/dummy-flux2", type=str) + parser.add_argument("--filename", default="flux.safetensors", type=str) + parser.add_argument("--checkpoint_path", default=None, type=str) + + parser.add_argument("--model_type", type=str, default="test") + 
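    # Illustrative invocation (editorial sketch, not part of the patch; the output path is a
    # placeholder, the flags and defaults are the ones defined in this parser):
    #   python scripts/convert_flux2_to_diffusers.py \
    #     --original_state_dict_repo_id "diffusers-internal-dev/dummy-flux2" \
    #     --filename "flux.safetensors" \
    #     --transformer \
    #     --model_type test \
    #     --dtype bf16 \
    #     --output_path "./flux2-diffusers"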
parser.add_argument("--vae", action="store_true") + parser.add_argument("--transformer", action="store_true") + + parser.add_argument("--dtype", type=str, default="bf16") + + parser.add_argument("--output_path", type=str) + + args = parser.parse_args() + args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 + + return args + + +def main(args): + original_ckpt = load_original_checkpoint(args.original_state_dict_repo_id, args.filename, args.checkpoint_path) + + if args.transformer: + transformer = convert_flux2_transformer_to_diffusers(original_ckpt, args.model_type) + transformer.to(args.dtype).save_pretrained(os.path.join(args.output_path, "transformer")) + + +if __name__ == "__main__": + args = parse_args() + main(args) From cceffc4b5c6978a0354ad2b8f9d764b52145f51c Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 15 Nov 2025 10:46:17 +0100 Subject: [PATCH 18/63] make style and make quality --- scripts/convert_flux2_to_diffusers.py | 14 ++++++-------- .../models/transformers/transformer_flux2.py | 1 - 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/scripts/convert_flux2_to_diffusers.py b/scripts/convert_flux2_to_diffusers.py index b4a713baf8ce..4001c8032cb6 100644 --- a/scripts/convert_flux2_to_diffusers.py +++ b/scripts/convert_flux2_to_diffusers.py @@ -1,6 +1,5 @@ import argparse import os -import pathlib from contextlib import nullcontext from typing import Any, Dict, Optional, Tuple @@ -11,7 +10,6 @@ from diffusers import Flux2Transformer2DModel from diffusers.utils.import_utils import is_accelerate_available -from transformers import Mistral3ForConditionalGeneration, AutoProcessor """ @@ -22,7 +20,7 @@ CTX = init_empty_weights if is_accelerate_available() else nullcontext -FLUX2_TRANSFORMER_KEYS_RENAME_DICT ={ +FLUX2_TRANSFORMER_KEYS_RENAME_DICT = { # Image and text input projections "img_in": "x_embedder", "txt_in": "context_embedder", @@ -82,7 +80,7 @@ def convert_ada_layer_norm_weights(key: str, state_dict: Dict[str, Any]) -> None # Skip if not a weight if ".weight" not in key: return - + # If adaLN_modulation is in the key, swap scale and shift parameters # Original implementation is (shift, scale); diffusers implementation is (scale, shift) if "adaLN_modulation" in key: @@ -100,7 +98,7 @@ def convert_flux2_double_stream_blocks(key: str, state_dict: Dict[str, Any]) -> # Skip if not a weight, bias, or scale if ".weight" not in key and ".bias" not in key and ".scale" not in key: return - + new_prefix = "transformer_blocks" if "double_blocks." 
in key: parts = key.split(".") @@ -111,7 +109,7 @@ def convert_flux2_double_stream_blocks(key: str, state_dict: Dict[str, Any]) -> if param_type == "scale": param_type = "weight" - + if "qkv" in within_block_name: fused_qkv_weight = state_dict.pop(key) to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0) @@ -146,7 +144,7 @@ def convert_flux2_single_stream_blocks(key: str, state_dict: Dict[str, Any]) -> # Skip if not a weight, bias, or scale if ".weight" not in key and ".bias" not in key and ".scale" not in key: return - + # Mapping: # - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj # - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out @@ -215,7 +213,7 @@ def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: "axes_dims_rope": (32, 32, 32, 32), "rope_theta": 2000, "eps": 1e-6, - } + }, } rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 6b9dc41d5aef..b009db2deced 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -15,7 +15,6 @@ import inspect from typing import Any, Dict, List, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F From e93a74617a06e3f73521681f9117961163c3cb50 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 15 Nov 2025 16:00:18 +0000 Subject: [PATCH 19/63] fix a few things. --- scripts/convert_flux2_to_diffusers.py | 1 + src/diffusers/pipelines/flux2/pipeline_flux2.py | 14 +++----------- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/scripts/convert_flux2_to_diffusers.py b/scripts/convert_flux2_to_diffusers.py index 60e660839563..8832dd382d61 100644 --- a/scripts/convert_flux2_to_diffusers.py +++ b/scripts/convert_flux2_to_diffusers.py @@ -29,6 +29,7 @@ parser.add_argument("--filename", default="flux.safetensors", type=str) parser.add_argument("--checkpoint_path", default=None, type=str) parser.add_argument("--vae", action="store_true") +parser.add_argument("--full_pipe", action="store_true") parser.add_argument("--output_path", type=str) parser.add_argument("--dtype", type=str, default="bf16") diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 57bfc138d8cc..8c4778964c9f 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -295,16 +295,6 @@ def _get_mistral_3_small_prompt_embeds( input_ids = inputs["input_ids"].to(device) attention_mask = inputs["attention_mask"].to(device) - text_input_ids = input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - # Forward pass through the model output = text_encoder( input_ids=input_ids, @@ -918,7 +908,9 @@ def __call__( self.vae.bn.running_mean.view(1, -1, 1, 1) .to(latents.device, latents.dtype) ) - latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 
1, 1) + self.vae.config.batch_norm_eps) + latents_bn_std = torch.sqrt( + self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps + ).to(latents.device, latents.dtype) latents = latents * latents_bn_std + latents_bn_mean latents = self._unpatchify_latents(latents) From cfdd0057d28433fd2b6a106240d9f0e051a7d47c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 15 Nov 2025 17:04:33 +0000 Subject: [PATCH 20/63] allow sft files to go. --- scripts/convert_flux2_to_diffusers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_flux2_to_diffusers.py b/scripts/convert_flux2_to_diffusers.py index 4001c8032cb6..714a4d02cca2 100644 --- a/scripts/convert_flux2_to_diffusers.py +++ b/scripts/convert_flux2_to_diffusers.py @@ -185,7 +185,7 @@ def load_original_checkpoint( else: raise ValueError("Please provide either `repo_id` or a local `checkpoint_path`") - if "safetensors" in model_file: + if "safetensors" in model_file or "sft" in model_file: original_state_dict = safetensors.torch.load_file(ckpt_path) else: original_state_dict = torch.load(ckpt_path, map_location="cpu") From 148398076ae0abc136739b21e6501b54242724c6 Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sat, 15 Nov 2025 22:24:51 +0000 Subject: [PATCH 21/63] fix image processor --- src/diffusers/pipelines/flux2/image_processor.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux2/image_processor.py b/src/diffusers/pipelines/flux2/image_processor.py index 7ec1f8f77267..3c9b376dc7ac 100644 --- a/src/diffusers/pipelines/flux2/image_processor.py +++ b/src/diffusers/pipelines/flux2/image_processor.py @@ -52,12 +52,16 @@ def __init__( do_resize: bool = True, vae_scale_factor: int = 16, vae_latent_channels: int = 32, - spatial_patch_size: Tuple[int, int] = (2, 2), - resample: str = "lanczos", do_normalize: bool = True, do_convert_rgb: bool = True, ): - super().__init__() + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + vae_latent_channels=vae_latent_channels, + do_normalize=do_normalize, + do_convert_rgb=do_convert_rgb, + ) @staticmethod From 62231ef99fa11caeef00497dc66a2a40e92d8d63 Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sun, 16 Nov 2025 02:17:44 +0000 Subject: [PATCH 22/63] fix batch --- .../pipelines/flux2/pipeline_flux2.py | 76 ++++++++----------- 1 file changed, 31 insertions(+), 45 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 8c4778964c9f..39d3a78b0854 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -64,47 +64,6 @@ ``` """ -# YiYi TODO: refactor later, remove rearrange and potentially compress_time is no-op here -def compress_time(t_ids: torch.Tensor) -> torch.Tensor: - assert t_ids.ndim == 1 - t_ids_max = torch.max(t_ids) - t_remap = torch.zeros((t_ids_max + 1,), device=t_ids.device, dtype=t_ids.dtype) - t_unique_sorted_ids = torch.unique(t_ids, sorted=True) - t_remap[t_unique_sorted_ids] = torch.arange( - len(t_unique_sorted_ids), device=t_ids.device, dtype=t_ids.dtype - ) - t_ids_compressed = t_remap[t_ids] - return t_ids_compressed - -from einops import rearrange -def scatter_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: - """ - using position ids to scatter tokens into place - """ - x_list = [] - t_coords = [] - for data, pos in zip(x, x_ids): - _, ch = data.shape # noqa: F841 - t_ids = pos[:, 
0].to(torch.int64) - h_ids = pos[:, 1].to(torch.int64) - w_ids = pos[:, 2].to(torch.int64) - - t_ids_cmpr = compress_time(t_ids) - - t = torch.max(t_ids_cmpr) + 1 - h = torch.max(h_ids) + 1 - w = torch.max(w_ids) + 1 - - flat_ids = t_ids_cmpr * w * h + h_ids * w + w_ids - - out = torch.zeros((t * h * w, ch), device=data.device, dtype=data.dtype) - out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) - - x_list.append(rearrange(out, "(t h w) c -> 1 c t h w", t=t, h=h, w=w)) - t_coords.append(torch.unique(t_ids, sorted=True)) - return x_list - - def format_text_input(prompts: List[str], system_message: str = None): # Remove [IMG] tokens from prompts to avoid Pixtral validation issues @@ -447,6 +406,33 @@ def _pack_latents(latents): return latents + + @staticmethod + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids =h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + def encode_prompt( self, @@ -860,12 +846,12 @@ def __call__( noise_pred = self.transformer( - hidden_states=latent_model_input, # (B, L, C) + hidden_states=latent_model_input, # (B, image_seq_len, C) timestep=timestep / 1000, guidance=guidance, encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, + txt_ids=text_ids, #B, text_seq_len, 4 + img_ids=latent_image_ids, #B, image_seq_len, 4 joint_attention_kwargs=self._attention_kwargs, return_dict=False, )[0] @@ -902,7 +888,7 @@ def __call__( if output_type == "latent": image = latents else: - latents = torch.cat(scatter_ids(latents, latent_ids), dim=1).squeeze(2) + latents = self._unpack_latents_with_ids(latents, latent_ids) latents_bn_mean = ( self.vae.bn.running_mean.view(1, -1, 1, 1) From 68db17826e35d76f1f00d50050e1256788521926 Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sun, 16 Nov 2025 02:28:48 +0000 Subject: [PATCH 23/63] style a bit --- scripts/convert_flux2_to_diffusers.py | 2 +- src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/autoencoders/__init__.py | 2 +- .../autoencoders/autoencoder_kl_flux2.py | 2 +- src/diffusers/pipelines/__init__.py | 2 +- .../pipelines/flux2/image_processor.py | 32 ++++--- .../pipelines/flux2/pipeline_flux2.py | 83 +++++++++---------- .../pipelines/flux2/pipeline_output.py | 1 - 9 files changed, 60 insertions(+), 68 deletions(-) diff --git a/scripts/convert_flux2_to_diffusers.py b/scripts/convert_flux2_to_diffusers.py index 8832dd382d61..f917ce356aaf 100644 --- a/scripts/convert_flux2_to_diffusers.py +++ b/scripts/convert_flux2_to_diffusers.py @@ -5,10 +5,10 @@ import torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download +from transformers import AutoProcessor, Mistral3ForConditionalGeneration from diffusers import AutoencoderKLFlux2 from diffusers.utils.import_utils import is_accelerate_available -from transformers import Mistral3ForConditionalGeneration, AutoProcessor """ diff --git a/src/diffusers/__init__.py 
b/src/diffusers/__init__.py index 85e1b75cccb1..32777f9b9328 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -896,9 +896,9 @@ AutoencoderDC, AutoencoderKL, AutoencoderKLAllegro, - AutoencoderKLFlux2, AutoencoderKLCogVideoX, AutoencoderKLCosmos, + AutoencoderKLFlux2, AutoencoderKLHunyuanImage, AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanVideo, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index fb1c10c1a0cb..70c89536fc92 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -140,6 +140,7 @@ AutoencoderKLAllegro, AutoencoderKLCogVideoX, AutoencoderKLCosmos, + AutoencoderKLFlux2, AutoencoderKLHunyuanImage, AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanVideo, @@ -149,7 +150,6 @@ AutoencoderKLQwenImage, AutoencoderKLTemporalDecoder, AutoencoderKLWan, - AutoencoderKLFlux2, AutoencoderOobleck, AutoencoderTiny, ConsistencyDecoderVAE, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 58a203a00ee8..470979ad33a7 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -4,6 +4,7 @@ from .autoencoder_kl_allegro import AutoencoderKLAllegro from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX from .autoencoder_kl_cosmos import AutoencoderKLCosmos +from .autoencoder_kl_flux2 import AutoencoderKLFlux2 from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner @@ -13,7 +14,6 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_kl_wan import AutoencoderKLWan -from .autoencoder_kl_flux2 import AutoencoderKLFlux2 from .autoencoder_oobleck import AutoencoderOobleck from .autoencoder_tiny import AutoencoderTiny from .consistency_decoder_vae import ConsistencyDecoderVAE diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py b/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py index 4a2c9c064d74..b800fbeeccc5 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, Optional, Tuple, Union import math +from typing import Dict, Optional, Tuple, Union import torch import torch.nn as nn diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7e0973be5bf1..1ce1061edf2b 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -630,7 +630,6 @@ EasyAnimateInpaintPipeline, EasyAnimatePipeline, ) - from .flux2 import Flux2Pipeline from .flux import ( FluxControlImg2ImgPipeline, FluxControlInpaintPipeline, @@ -647,6 +646,7 @@ FluxPriorReduxPipeline, ReduxImageEncoder, ) + from .flux2 import Flux2Pipeline from .hidream_image import HiDreamImagePipeline from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline from .hunyuan_video import ( diff --git a/src/diffusers/pipelines/flux2/image_processor.py b/src/diffusers/pipelines/flux2/image_processor.py index 3c9b376dc7ac..2d088f875b3f 100644 --- a/src/diffusers/pipelines/flux2/image_processor.py +++ b/src/diffusers/pipelines/flux2/image_processor.py @@ -12,12 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +import math +from typing import Tuple -import numpy as np import PIL.Image -import torch -import math from ...configuration_utils import register_to_config from ...image_processor import VaeImageProcessor @@ -63,41 +61,41 @@ def __init__( do_convert_rgb=do_convert_rgb, ) - + @staticmethod def check_image_input( - image: PIL.Image.Image, - max_aspect_ratio: int = 8, + image: PIL.Image.Image, + max_aspect_ratio: int = 8, min_side_length: int = 64, max_area: int = 1024 * 1024 ) -> PIL.Image.Image: """ Check if image meets minimum size and aspect ratio requirements. - + Args: image: PIL Image to validate max_aspect_ratio: Maximum allowed aspect ratio (width/height or height/width) min_side_length: Minimum pixels required for width and height max_area: Maximum allowed area in pixels² - + Returns: The input image if valid - + Raises: ValueError: If image is too small or aspect ratio is too extreme """ if not isinstance(image, PIL.Image.Image): raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}") - + width, height = image.size - + # Check minimum dimensions if width < min_side_length or height < min_side_length: raise ValueError( f"Image too small: {width}×{height}. 
" f"Both dimensions must be at least {min_side_length}px" ) - + # Check aspect ratio aspect_ratio = max(width / height, height / width) if aspect_ratio > max_aspect_ratio: @@ -113,14 +111,14 @@ def check_image_input( @staticmethod def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> Tuple[int, int]: image_width, image_height = image.size - + scale = math.sqrt(target_area/ (image_width * image_height)) width = int(image_width * scale) height = int(image_height * scale) return image.resize((width, height), PIL.Image.Resampling.LANCZOS) - - + + def _resize_and_crop( self, image: PIL.Image.Image, @@ -149,4 +147,4 @@ def _resize_and_crop( right = left + width bottom = top + height - return image.crop((left, top, right, bottom)) \ No newline at end of file + return image.crop((left, top, right, bottom)) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 39d3a78b0854..e053c511d93d 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -14,28 +14,23 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union -import PIL import numpy as np +import PIL import torch -from transformers import Mistral3ForConditionalGeneration, AutoProcessor +from transformers import AutoProcessor, Mistral3ForConditionalGeneration -from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - USE_PEFT_BACKEND, - deprecate, is_torch_xla_available, logging, replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .pipeline_output import Flux2PipelineOutput from .image_processor import Flux2ImageProcessor +from .pipeline_output import Flux2PipelineOutput if is_torch_xla_available(): @@ -172,7 +167,7 @@ class Flux2Pipeline(DiffusionPipeline): The Flux2 pipeline for text-to-image generation. Reference: TODO - + Args: transformer ([`FluxTransformer2DModel`]): Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. @@ -217,7 +212,7 @@ def __init__( self.system_message = """You are an AI that reasons about image descriptions. 
You give structured responses focusing on object relationships, object attribution and actions without speculation.""" self.text_encoder_out_layers = (10, 20, 30) - + @staticmethod def _get_mistral_3_small_prompt_embeds( text_encoder: Mistral3ForConditionalGeneration, @@ -230,11 +225,11 @@ def _get_mistral_3_small_prompt_embeds( attribution and actions without speculation.""", hidden_states_layers: List[int] = (10, 20, 30), ): - dtype = text_encoder.dtype if dtype is None else dtype - device = text_encoder.device if device is None else device + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device prompt = [prompt] if isinstance(prompt, str) else prompt - + # Format input messages messages_batch = format_text_input(prompts=prompt, system_message=system_message) @@ -268,7 +263,7 @@ def _get_mistral_3_small_prompt_embeds( batch_size, num_channels, seq_len, hidden_dim = out.shape prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) - + return prompt_embeds @@ -295,33 +290,33 @@ def _prepare_text_ids( @staticmethod def _prepare_latent_ids( latents: torch.Tensor, # (B, C, H, W) - ): + ): r""" Generates 4D position coordinates (T, H, W, L) for latent tensors. - + Args: - latents (torch.Tensor): + latents (torch.Tensor): Latent tensor of shape (B, C, H, W) - + Returns: - torch.Tensor: + torch.Tensor: Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, H=[0..H-1], W=[0..W-1], L=0 """ - + batch_size, _, height, width = latents.shape - + t = torch.arange(1) # [0] - time dimension h = torch.arange(height) w = torch.arange(width) l = torch.arange(1) # [0] - layer dimension - + # Create position IDs: (H*W, 4) latent_ids = torch.cartesian_prod(t, h, w, l) - + # Expand to batch: (B, H*W, 4) latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) - + return latent_ids # YiYi TODO: can optimize a bit @@ -329,7 +324,7 @@ def _prepare_latent_ids( def _prepare_image_ids( image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] scale: int = 10 - ): + ): r""" Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. @@ -338,14 +333,14 @@ def _prepare_image_ids( input latent with different dimensions. Args: - image_latents (List[torch.Tensor]): + image_latents (List[torch.Tensor]): A list of image latent feature tensors, typically of shape (C, H, W). - scale (int, optional): - A factor used to define the time separation (T-coordinate) between latents. + scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th latent is: 'scale + scale * i'. Defaults to 10. Returns: - torch.Tensor: + torch.Tensor: The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all input latents. 
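# A minimal sketch of the (T, H, W, L) position IDs described in the docstrings above,
# assuming a 2x2 generation latent, two 2x2 reference latents, and the default scale=10
# (L assumed to be 0 for all tokens in this sketch):
#
#   generation latent (_prepare_latent_ids): T=0  -> [0, h, w, 0]   for h, w in {0, 1}
#   reference image 0 (_prepare_image_ids):  T=10 -> [10, h, w, 0]
#   reference image 1 (_prepare_image_ids):  T=20 -> [20, h, w, 0]
#
# The distinct T values give each reference image its own slot along the time axis of the
# rotary position embedding, keeping it separated from the latent being denoised.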
@@ -363,7 +358,7 @@ def _prepare_image_ids( # create time offset for each reference image t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] t_coords = [t.view(-1) for t in t_coords] - + image_latent_ids = [] for x, t in zip(image_latents, t_coords): @@ -375,10 +370,10 @@ def _prepare_image_ids( image_latent_ids = torch.cat(image_latent_ids, dim=0) image_latent_ids = image_latent_ids.unsqueeze(0) - + return image_latent_ids - + @staticmethod def _patchify_latents(latents): batch_size, num_channels_latents, height, width = latents.shape @@ -406,7 +401,7 @@ def _pack_latents(latents): return latents - + @staticmethod def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: """ @@ -426,7 +421,7 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) - # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) out = out.view(h, w, ch).permute(2, 0, 1) x_list.append(out) @@ -459,7 +454,7 @@ def encode_prompt( system_message=self.system_message, hidden_states_layers=self.text_encoder_out_layers, ) - + batch_size, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -476,7 +471,7 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") image_latents = self._patchify_latents(image_latents) - + latents_bn_mean = ( self.vae.bn.running_mean.view(1, -1, 1, 1) .to(image_latents.device, image_latents.dtype) @@ -536,7 +531,7 @@ def prepare_image_latents( image = image.to(device=device, dtype=dtype) imagge_latent = self._encode_vae_image(image=image, generator=generator) image_latents.append(imagge_latent) # (1, 128, 32, 32) - + image_latent_ids = self._prepare_image_ids(image_latents) # Pack each latent and concatenate @@ -546,7 +541,7 @@ def prepare_image_latents( packed = self._pack_latents(latent) # (1, 1024, 128) packed = packed.squeeze(0) # (1024, 128) - remove batch dim packed_latents.append(packed) - + # Concatenate all reference tokens along sequence dimension image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128) image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128) @@ -557,7 +552,7 @@ def prepare_image_latents( return image_latents, image_latent_ids - + def check_inputs( self, prompt, @@ -738,7 +733,7 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - + # 3. prepare text embeddings prompt_embeds, text_ids = self.encode_prompt( prompt=prompt, @@ -747,11 +742,11 @@ def __call__( num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) - + # 4. 
process images if image is not None and not isinstance(image, list): image = [image] - + condition_images = None if image is not None: for img in image: @@ -785,7 +780,7 @@ def __call__( generator=generator, latents=latents, ) - + image_latents = None image_latent_ids = None if condition_images is not None: @@ -909,4 +904,4 @@ def __call__( if not return_dict: return (image,) - return Flux2PipelineOutput(images=image) \ No newline at end of file + return Flux2PipelineOutput(images=image) diff --git a/src/diffusers/pipelines/flux2/pipeline_output.py b/src/diffusers/pipelines/flux2/pipeline_output.py index 2183b8bcff41..58e8ad49c210 100644 --- a/src/diffusers/pipelines/flux2/pipeline_output.py +++ b/src/diffusers/pipelines/flux2/pipeline_output.py @@ -3,7 +3,6 @@ import numpy as np import PIL.Image -import torch from ...utils import BaseOutput From f6c82a3158bca5fd814dde63e3ec8e00401d10b8 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 17 Nov 2025 05:34:33 +0100 Subject: [PATCH 24/63] Fix some bugs in Flux 2 transformer implementation --- .../models/transformers/transformer_flux2.py | 45 ++++++++++++------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index b009db2deced..4b3a800a3155 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -133,10 +133,10 @@ def __call__( if attn.parallel_proj_in: hidden_states = attn.to_qkv_mlp_proj(hidden_states) qkv, mlp_hidden_states = torch.split( - hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor] + hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1 ) query, key, value = qkv.chunk(3, dim=-1) - mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states) + mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states) # Get encoder QKV, if available encoder_query = encoder_key = encoder_value = None @@ -423,6 +423,7 @@ def forward( ) -> Tuple[torch.Tensor, torch.Tensor]: joint_attention_kwargs = joint_attention_kwargs or {} + # Modulation parameters shape: [1, 1, self.dim] (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt @@ -448,27 +449,27 @@ def forward( attn_output, context_attn_output, ip_attn_output = attention_outputs # Process attention outputs for the image stream (`hidden_states`). - attn_output = gate_msa.unsqueeze(1) * attn_output + attn_output = gate_msa * attn_output hidden_states = hidden_states + attn_output norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp ff_output = self.ff(norm_hidden_states) - hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + hidden_states = hidden_states + gate_mlp * ff_output if len(attention_outputs) == 3: hidden_states = hidden_states + ip_attn_output # Process attention outputs for the text stream (`encoder_hidden_states`). 
- context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + context_attn_output = c_gate_msa * context_attn_output encoder_hidden_states = encoder_hidden_states + context_attn_output norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp context_ff_output = self.ff_context(norm_encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output if encoder_hidden_states.dtype == torch.float16: encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) @@ -483,6 +484,7 @@ def __init__(self, theta: int, axes_dim: List[int]): self.axes_dim = axes_dim def forward(self, ids: torch.Tensor) -> torch.Tensor: + # Expected ids shape: [S, len(self.axes_dim)] cos_out = [] sin_out = [] pos = ids.float() @@ -493,7 +495,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: for i in range(len(self.axes_dim)): cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i], - pos[:, i], + pos[..., i], theta=self.theta, repeat_interleave_real=True, use_real=True, @@ -736,6 +738,8 @@ def forward( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) + num_txt_tokens = encoder_hidden_states.shape[1] + # 1. Calculate timestep embedding and modulation parameters timestep = timestep.to(hidden_states.dtype) * 1000 guidance = guidance.to(hidden_states.dtype) * 1000 @@ -751,6 +755,13 @@ def forward( encoder_hidden_states = self.context_embedder(encoder_hidden_states) # 3. Calculate RoPE embeddings from image and text tokens + # NOTE: the below logic means that we can't support batched inference with images of different resolutions or + # text prompts of differents lengths. Is this a use case we want to support? + if img_ids.ndim == 3: + img_ids = img_ids[0] + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + if is_torch_npu_available(): freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu()) image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu()) @@ -760,8 +771,8 @@ def forward( image_rotary_emb = self.pos_embed(img_ids) text_rotary_emb = self.pos_embed(txt_ids) concat_rotary_emb = ( - torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=2), - torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=2), + torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), + torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), ) if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: @@ -790,26 +801,30 @@ def forward( image_rotary_emb=concat_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, ) + # Concatenate text and image streams for single-block inference + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) # 5. 
Single Stream Transformer Blocks for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + hidden_states = self._gradient_checkpointing_func( block, hidden_states, - encoder_hidden_states, + None, single_stream_mod, concat_rotary_emb, joint_attention_kwargs, ) else: - encoder_hidden_states, hidden_states = block( + hidden_states = block( hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states=None, temb_mod_params=single_stream_mod, image_rotary_emb=concat_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, ) + # Remove text tokens from concatenated stream + hidden_states = hidden_states[:, num_txt_tokens:, ...] # 6. Output layers hidden_states = self.norm_out(hidden_states, temb) From 4082c43002a972890f3d579bd8b7dffc95eabdce Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 17 Nov 2025 05:35:19 +0100 Subject: [PATCH 25/63] Fix dummy input preparation and fix some test bugs --- .../test_models_transformer_flux2.py | 69 ++++++++++++------- 1 file changed, 43 insertions(+), 26 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_flux2.py b/tests/models/transformers/test_models_transformer_flux2.py index cbf3f0fa4296..30254b8e915d 100644 --- a/tests/models/transformers/test_models_transformer_flux2.py +++ b/tests/models/transformers/test_models_transformer_flux2.py @@ -104,15 +104,26 @@ def output_shape(self): def prepare_dummy_input(self, height=4, width=4): batch_size = 1 num_latent_channels = 4 - num_image_channels = 3 sequence_length = 48 embedding_dim = 32 hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) - # pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device) - text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device) - image_ids = torch.randn((height * width, num_image_channels)).to(torch_device) + + t_coords = torch.arange(1) + h_coords = torch.arange(height) + w_coords = torch.arange(width) + l_coords = torch.arange(1) + image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords) # [height * width, 4] + image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device) + + text_t_coords = torch.arange(1) + text_h_coords = torch.arange(1) + text_w_coords = torch.arange(1) + text_l_coords = torch.arange(sequence_length) + text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords) + text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size) @@ -135,44 +146,50 @@ def prepare_init_args_and_inputs_for_common(self): "attention_head_dim": 16, "num_attention_heads": 2, "joint_attention_dim": 32, - # "pooled_projection_dim": 32, - "timestep_guidance_channels": 16, - "axes_dims_rope": [4, 4, 8], + "timestep_guidance_channels": 256, # Hardcoded in original code + "axes_dims_rope": [4, 4, 4, 4], } inputs_dict = self.dummy_input return init_dict, inputs_dict - def test_deprecated_inputs_img_txt_ids_3d(self): + def test_flux2_consistency(self, seed=0): + torch.manual_seed(seed) init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(seed) 
model = self.model_class(**init_dict) + # state_dict = model.state_dict() + # for key, param in state_dict.items(): + # print(f"{key} | {param.shape}") + # torch.save(state_dict, "/raid/daniel_gu/test_flux2_params/diffusers.pt") model.to(torch_device) model.eval() with torch.no_grad(): - output_1 = model(**inputs_dict).to_tuple()[0] + output = model(**inputs_dict) - # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated) - text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0) - image_ids_3d = inputs_dict["img_ids"].unsqueeze(0) + if isinstance(output, dict): + output = output.to_tuple()[0] - assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor" - assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor" + self.assertIsNotNone(output) - inputs_dict["txt_ids"] = text_ids_3d - inputs_dict["img_ids"] = image_ids_3d + # input & output have to have the same shape + input_tensor = inputs_dict[self.main_input_name] + expected_shape = input_tensor.shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - with torch.no_grad(): - output_2 = model(**inputs_dict).to_tuple()[0] + # Check against expected slice + # fmt: off + expected_slice = torch.tensor([-0.3180, 0.4818, 0.6621, -0.3386, 0.2313, 0.0688, 0.0985, -0.2686, -0.1480, -0.1607, -0.7245, 0.5385, -0.2842, 0.6575, -0.0697, 0.4951]) + # fmt: on - self.assertEqual(output_1.shape, output_2.shape) - self.assertTrue( - torch.allclose(output_1, output_2, atol=1e-5), - msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs", - ) + flat_output = output.cpu().flatten() + generated_slice = torch.cat([flat_output[:8], flat_output[-8:]]) + self.assertTrue(torch.allclose(expected_slice, generated_slice)) def test_gradient_checkpointing_is_applied(self): - expected_set = {"FluxTransformer2DModel"} + expected_set = {"Flux2Transformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) # The test exists for cases like @@ -205,7 +222,7 @@ def test_lora_exclude_modules(self): assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all() -class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): +class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = Flux2Transformer2DModel different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] @@ -216,7 +233,7 @@ def prepare_dummy_input(self, height, width): return Flux2TransformerTests().prepare_dummy_input(height=height, width=width) -class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): +class Flux2TransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): model_class = Flux2Transformer2DModel different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] From 2b5b2e33449ce7ed696fbea2e5e8d7cb3414678d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 17 Nov 2025 04:47:02 +0000 Subject: [PATCH 26/63] fix dtype casting in timestep guidance module. 
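Context for this change (a sketch, assuming the usual diffusers `Timesteps` behavior of
producing float32 sinusoids): the projections are cast back to the incoming timestep and
guidance dtype before the embedder MLPs, so the module also works when the transformer
runs in bf16/fp16.

    timesteps_proj = self.time_proj(timestep)                                   # float32 sinusoids
    timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype))   # cast before the MLP
    guidance_proj = self.time_proj(guidance)
    guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype))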
--- src/diffusers/models/transformers/transformer_flux2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 4b3a800a3155..e16652edd445 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -523,10 +523,10 @@ def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor: timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D) guidance_proj = self.time_proj(guidance) - guidance_emb = self.guidance_embedder(guidance_proj) # (N, D) + guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D) time_guidance_emb = timesteps_emb + guidance_emb From 3d022b8be2435b6948c3141d80c8beb104df9712 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 17 Nov 2025 06:00:43 +0000 Subject: [PATCH 27/63] resolve conflicts., --- scripts/convert_flux2_to_diffusers.py | 287 +++++- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../models/transformers/transformer_flux2.py | 840 ++++++++++++++++++ .../pipelines/flux2/pipeline_flux2.py | 1 + .../test_models_transformer_flux2.py | 244 +++++ 7 files changed, 1359 insertions(+), 18 deletions(-) create mode 100644 src/diffusers/models/transformers/transformer_flux2.py create mode 100644 tests/models/transformers/test_models_transformer_flux2.py diff --git a/scripts/convert_flux2_to_diffusers.py b/scripts/convert_flux2_to_diffusers.py index f917ce356aaf..5d9e3f68891c 100644 --- a/scripts/convert_flux2_to_diffusers.py +++ b/scripts/convert_flux2_to_diffusers.py @@ -1,13 +1,16 @@ import argparse from contextlib import nullcontext +from typing import Any, Dict, Tuple import safetensors.torch import torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from transformers import AutoProcessor, Mistral3ForConditionalGeneration +from transformers import AutoProcessor, Mistral3ForConditionalGeneration, GenerationConfig -from diffusers import AutoencoderKLFlux2 +from diffusers import ( + AutoencoderKLFlux2, Flux2Pipeline, Flux2Transformer2DModel, FlowMatchEulerDiscreteScheduler +) from diffusers.utils.import_utils import is_accelerate_available @@ -15,31 +18,49 @@ # VAE python scripts/convert_flux2_to_diffusers.py \ ---original_state_dict_repo_id "diffusers-internal-dev/dummy-flux2" \ ---filename "ae.pt" \ +--original_state_dict_repo_id "diffusers-internal-dev/new-model-image" \ +--vae_filename "flux2-vae.sft" \ --output_path "/raid/yiyi/dummy-flux2-diffusers" \ ---dtype fp32 \ --vae + +# DiT + +python scripts/convert_flux2_to_diffusers.py \ + --original_state_dict_repo_id diffusers-internal-dev/new-model-image \ + --dit_filename flux-dev-dummy.sft \ + --dit \ + --output_path . + +# Full pipe + +python scripts/convert_flux2_to_diffusers.py \ + --original_state_dict_repo_id diffusers-internal-dev/new-model-image \ + --dit_filename flux-dev-dummy.sft \ + --vae_filename "flux2-vae.sft" \ + --dit --vae --full_pipe \ + --output_path . 
""" CTX = init_empty_weights if is_accelerate_available() else nullcontext parser = argparse.ArgumentParser() parser.add_argument("--original_state_dict_repo_id", default=None, type=str) -parser.add_argument("--filename", default="flux.safetensors", type=str) -parser.add_argument("--checkpoint_path", default=None, type=str) +parser.add_argument("--vae_filename", default="flux2-vae.sft", type=str) +parser.add_argument("--dit_filename", default="flux-dev-dummy.sft", type=str) parser.add_argument("--vae", action="store_true") +parser.add_argument("--dit", action="store_true") +parser.add_argument("--vae_dtype", type=str, default="fp32") +parser.add_argument("--dit_dtype", type=str, default="bf16") +parser.add_argument("--checkpoint_path", default=None, type=str) parser.add_argument("--full_pipe", action="store_true") parser.add_argument("--output_path", type=str) -parser.add_argument("--dtype", type=str, default="bf16") args = parser.parse_args() -dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 -def load_original_checkpoint(args): +def load_original_checkpoint(args, filename): if args.original_state_dict_repo_id is not None: - ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename) + ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename) elif args.checkpoint_path is not None: ckpt_path = args.checkpoint_path else: @@ -205,22 +226,252 @@ def convert_flux2_vae_checkpoint_to_diffusers(vae_state_dict, config): return new_checkpoint -def main(args): - original_ckpt = load_original_checkpoint(args) +FLUX2_TRANSFORMER_KEYS_RENAME_DICT = { + # Image and text input projections + "img_in": "x_embedder", + "txt_in": "context_embedder", + # Timestep and guidance embeddings + "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1", + "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2", + "guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1", + "guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2", + # Modulation parameters + "double_stream_modulation_img.lin": "double_stream_modulation_img.linear", + "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear", + "single_stream_modulation.lin": "single_stream_modulation.linear", + # Final output layer + # "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params + "final_layer.linear": "proj_out", +} + + +FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = { + "final_layer.adaLN_modulation.1": "norm_out.linear", +} + + +FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = { + # Handle fused QKV projections separately as we need to break into Q, K, V projections + "img_attn.norm.query_norm": "attn.norm_q", + "img_attn.norm.key_norm": "attn.norm_k", + "img_attn.proj": "attn.to_out.0", + "img_mlp.0": "ff.linear_in", + "img_mlp.2": "ff.linear_out", + "txt_attn.norm.query_norm": "attn.norm_added_q", + "txt_attn.norm.key_norm": "attn.norm_added_k", + "txt_attn.proj": "attn.to_add_out", + "txt_mlp.0": "ff_context.linear_in", + "txt_mlp.2": "ff_context.linear_out", +} + + +FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = { + "linear1": "attn.to_qkv_mlp_proj", + "linear2": "attn.to_out", + "norm.query_norm": "attn.norm_q", + "norm.key_norm": "attn.norm_k", +} + + +# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; +# while in diffusers it split into scale, shift. 
Here we swap the linear projection weights in order to be able to use +# diffusers implementation +def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def convert_ada_layer_norm_weights(key: str, state_dict: Dict[str, Any]) -> None: + # Skip if not a weight + if ".weight" not in key: + return + + # If adaLN_modulation is in the key, swap scale and shift parameters + # Original implementation is (shift, scale); diffusers implementation is (scale, shift) + if "adaLN_modulation" in key: + key_without_param_type, param_type = key.rsplit(".", maxsplit=1) + # Assume all such keys are in the AdaLayerNorm key map + new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type] + new_key = ".".join([new_key_without_param_type, param_type]) + + swapped_weight = swap_scale_shift(state_dict.pop(key)) + state_dict[new_key] = swapped_weight + return + + +def convert_flux2_double_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None: + # Skip if not a weight, bias, or scale + if ".weight" not in key and ".bias" not in key and ".scale" not in key: + return + + new_prefix = "transformer_blocks" + if "double_blocks." in key: + parts = key.split(".") + block_idx = parts[1] + modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp + within_block_name = ".".join(parts[2:-1]) + param_type = parts[-1] + + if param_type == "scale": + param_type = "weight" + + if "qkv" in within_block_name: + fused_qkv_weight = state_dict.pop(key) + to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0) + if "img" in modality_block_name: + # double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v} + to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0) + new_q_name = "attn.to_q" + new_k_name = "attn.to_k" + new_v_name = "attn.to_v" + elif "txt" in modality_block_name: + # double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj} + to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0) + new_q_name = "attn.add_q_proj" + new_k_name = "attn.add_k_proj" + new_v_name = "attn.add_v_proj" + new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type]) + new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type]) + new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type]) + state_dict[new_q_key] = to_q_weight + state_dict[new_k_key] = to_k_weight + state_dict[new_v_key] = to_v_weight + else: + new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name] + new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type]) + + param = state_dict.pop(key) + state_dict[new_key] = param + return + + +def convert_flux2_single_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None: + # Skip if not a weight, bias, or scale + if ".weight" not in key and ".bias" not in key and ".scale" not in key: + return + + # Mapping: + # - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj + # - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out + # - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight + # - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight + new_prefix = "single_transformer_blocks" + if "single_blocks." 
in key: + parts = key.split(".") + block_idx = parts[1] + within_block_name = ".".join(parts[2:-1]) + param_type = parts[-1] + + if param_type == "scale": + param_type = "weight" + + new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name] + new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type]) + + param = state_dict.pop(key) + state_dict[new_key] = param + return + + +TRANSFORMER_SPECIAL_KEYS_REMAP = { + "adaLN_modulation": convert_ada_layer_norm_weights, + "double_blocks": convert_flux2_double_stream_blocks, + "single_blocks": convert_flux2_single_stream_blocks, +} + + +def update_state_dict(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None: + state_dict[new_key] = state_dict.pop(old_key) + + +def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: + if model_type == "test" or model_type == "dummy-flux2": + config = { + "model_id": "diffusers-internal-dev/dummy-flux2", + "diffusers_config": { + "patch_size": 1, + "in_channels": 128, + "num_layers": 8, + "num_single_layers": 48, + "attention_head_dim": 128, + "num_attention_heads": 48, + "joint_attention_dim": 15360, + "timestep_guidance_channels": 256, + "mlp_ratio": 3.0, + "axes_dims_rope": (32, 32, 32, 32), + "rope_theta": 2000, + "eps": 1e-6, + }, + } + rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_flux2_transformer_to_diffusers(original_state_dict: Dict[str, torch.Tensor], model_type: str): + config, rename_dict, special_keys_remap = get_flux2_transformer_config(model_type) + + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + transformer = Flux2Transformer2DModel.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + transformer.load_state_dict(original_state_dict, strict=True, assign=True) + return transformer + +def main(args): if args.vae: + original_vae_ckpt = load_original_checkpoint(args, filename=args.vae_filename) vae = AutoencoderKLFlux2() - converted_vae_state_dict = convert_flux2_vae_checkpoint_to_diffusers(original_ckpt, vae.config) + converted_vae_state_dict = convert_flux2_vae_checkpoint_to_diffusers(original_vae_ckpt, vae.config) vae.load_state_dict(converted_vae_state_dict, strict=True) - vae.to(dtype).save_pretrained(f"{args.output_path}/vae") + if not args.full_pipe: + vae_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32 + vae.to(vae_dtype).save_pretrained(f"{args.output_path}/vae") + + if args.dit: + original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename) + transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, "test") + if not args.full_pipe: + dit_dtype = torch.bfloat16 if args.dit_dtype == "bf16" else torch.float32 + transformer.to(dit_dtype).save_pretrained(f"{args.output_path}/transformer") if args.full_pipe: 
tokenizer_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" text_encoder_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506" - text_encoder = Mistral3ForConditionalGeneration.from_pretrained(text_encoder_id, torch_dtype=torch.bfloat16) + generate_config = GenerationConfig.from_pretrained(text_encoder_id) + generate_config.do_sample = True + text_encoder = Mistral3ForConditionalGeneration.from_pretrained( + text_encoder_id, generation_config=generate_config, torch_dtype=torch.bfloat16 + ) tokenizer = AutoProcessor.from_pretrained(tokenizer_id) - - # TODO: collate denoiser, vae, text encoder, tokenizer here. + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="scheduler") + + pipe = Flux2Pipeline( + vae=vae, + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler + ) + pipe.save_pretrained(args.output_path) if __name__ == "__main__": main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 32777f9b9328..f02d0852c972 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -216,6 +216,7 @@ "CosmosTransformer3DModel", "DiTTransformer2DModel", "EasyAnimateTransformer3DModel", + "Flux2Transformer2DModel", "FluxControlNetModel", "FluxMultiControlNetModel", "FluxTransformer2DModel", @@ -928,6 +929,7 @@ CosmosTransformer3DModel, DiTTransformer2DModel, EasyAnimateTransformer3DModel, + Flux2Transformer2DModel, FluxControlNetModel, FluxMultiControlNetModel, FluxTransformer2DModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 70c89536fc92..2ff00f614040 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -93,6 +93,7 @@ _import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"] _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] + _import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"] _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] @@ -191,6 +192,7 @@ DiTTransformer2DModel, DualTransformer2DModel, EasyAnimateTransformer3DModel, + Flux2Transformer2DModel, FluxTransformer2DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 66daf56e23b2..c00abda53da3 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -26,6 +26,7 @@ from .transformer_cosmos import CosmosTransformer3DModel from .transformer_easyanimate import EasyAnimateTransformer3DModel from .transformer_flux import FluxTransformer2DModel + from .transformer_flux2 import Flux2Transformer2DModel from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py new file mode 100644 index 000000000000..e16652edd445 --- /dev/null +++ 
b/src/diffusers/models/transformers/transformer_flux2.py @@ -0,0 +1,840 @@ +# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput +from ..attention import AttentionMixin, AttentionModuleMixin +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import ( + TimestepEmbedding, + Timesteps, + apply_rotary_emb, + get_1d_rotary_pos_embed, +) +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_fused_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = (None,) + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): + if attn.fused_projections: + return _get_fused_projections(attn, hidden_states, encoder_hidden_states) + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +class Flux2SwiGLU(nn.Module): + """ + Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection + layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters. 
+ """ + + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + x = self.gate_fn(x1) * x2 + return x + + +class Flux2FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: float = 3.0, + inner_dim: Optional[int] = None, + bias: bool = False, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out or dim + + # Flux2SwiGLU will reduce the dimension by half + self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias) + self.act_fn = Flux2SwiGLU() + self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_in(x) + x = self.act_fn(x) + x = self.linear_out(x) + return x + + +class Flux2AttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") + + def __call__( + self, + attn: "Flux2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + mlp_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if attn.parallel_proj_in: + hidden_states = attn.to_qkv_mlp_proj(hidden_states) + qkv, mlp_hidden_states = torch.split( + hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1 + ) + query, key, value = qkv.chunk(3, dim=-1) + mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states) + + # Get encoder QKV, if available + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None: + if hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk( + 3, dim=-1 + ) + elif attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + else: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states 
= hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if attn.parallel_proj_out: + hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1) + hidden_states = attn.to_out(hidden_states) + else: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class Flux2Attention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = Flux2AttnProcessor + _available_processors = [ + Flux2AttnProcessor, + ] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + parallel_proj_in: bool = False, + parallel_proj_out: bool = False, + mlp_ratio: float = 4.0, + mlp_mult_factor: int = 2, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.use_bias = bias + self.dropout = dropout + + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.parallel_proj_in = parallel_proj_in + self.parallel_proj_out = parallel_proj_out + self.mlp_ratio = mlp_ratio + self.mlp_hidden_dim = int(query_dim * self.mlp_ratio) + self.mlp_mult_factor = mlp_mult_factor + + if self.parallel_proj_in: + self.to_qkv_mlp_proj = torch.nn.Linear( + self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias + ) + self.mlp_act_fn = Flux2SwiGLU() + else: + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + # QK Norm + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + + if self.parallel_proj_out: + self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias) + else: + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: 
Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +@maybe_allow_in_graph +class Flux2SingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this + # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442) + # for a visual depiction of this type of transformer block. + self.attn = Flux2Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + out_bias=bias, + eps=eps, + parallel_proj_in=True, + parallel_proj_out=True, + mlp_ratio=mlp_ratio, + mlp_mult_factor=2, + processor=Flux2AttnProcessor(), + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + split_hidden_states: bool = False, + text_seq_len: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already + # concatenated + if encoder_hidden_states is not None: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + mod_shift, mod_scale, mod_gate = temb_mod_params + + norm_hidden_states = self.norm(hidden_states) + norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift + + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = hidden_states + mod_gate * attn_output + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + if split_hidden_states: + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + else: + return hidden_states + + +@maybe_allow_in_graph +class Flux2TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + self.attn = Flux2Attention( + query_dim=dim, + added_kv_proj_dim=dim, + 
dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + added_proj_bias=bias, + out_bias=bias, + eps=eps, + processor=Flux2AttnProcessor(), + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + joint_attention_kwargs = joint_attention_kwargs or {} + + # Modulation parameters shape: [1, 1, self.dim] + (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img + (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt + + # Img stream + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa + + # Conditioning txt stream + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) + norm_encoder_hidden_states = (1 + c_scale_msa) * encoder_hidden_states + c_shift_msa + + # Attention on concatenated img + txt stream + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the image stream (`hidden_states`). + attn_output = gate_msa * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + gate_mlp * ff_output + + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the text stream (`encoder_hidden_states`). 
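+        # The text stream mirrors the image stream above: an adaLN-style update of the form
+        # x = x + gate * sublayer((1 + scale) * norm(x) + shift), here using the txt modulation parameters.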
+ context_attn_output = c_gate_msa * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class Flux2PosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + # Expected ids shape: [S, len(self.axes_dim)] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1] + for i in range(len(self.axes_dim)): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[..., i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class Flux2TimestepGuidanceEmbeddings(nn.Module): + def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False): + super().__init__() + + self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias + ) + + self.guidance_embedder = TimestepEmbedding( + in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias + ) + + def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D) + + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D) + + time_guidance_emb = timesteps_emb + guidance_emb + + return time_guidance_emb + + +class Flux2Modulation(nn.Module): + def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False): + super().__init__() + self.mod_param_sets = mod_param_sets + + self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias) + self.act_fn = nn.SiLU() + + def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]: + mod = self.act_fn(temb) + mod = self.linear(mod) + + if mod.ndim == 2: + mod = mod.unsqueeze(1) + mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1) + # Return tuple of 3-tuples of modulation params shift/scale/gate + return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets)) + + +class Flux2Transformer2DModel( + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FromOriginalModelMixin, + FluxTransformer2DLoadersMixin, + CacheMixin, + AttentionMixin, +): + """ + The Transformer model introduced 
in Flux 2. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + patch_size (`int`, defaults to `1`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `128`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `8`): + The number of layers of dual stream DiT blocks to use. + num_single_layers (`int`, defaults to `48`): + The number of layers of single stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `48`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `15360`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + pooled_projection_dim (`int`, defaults to `768`): + The number of dimensions to use for the pooled projection. + guidance_embeds (`bool`, defaults to `True`): + Whether to use guidance embeddings for guidance-distilled variant of the model. + axes_dims_rope (`Tuple[int]`, defaults to `(32, 32, 32, 32)`): + The dimensions to use for the rotary positional embeddings. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] + _cp_plan = { + "": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), + "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 128, + out_channels: Optional[int] = None, + num_layers: int = 8, + num_single_layers: int = 48, + attention_head_dim: int = 128, + num_attention_heads: int = 48, + joint_attention_dim: int = 15360, + timestep_guidance_channels: int = 256, + mlp_ratio: float = 3.0, + axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32), + rope_theta: int = 2000, + eps: float = 1e-6, + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + # 1. Sinusoidal positional embedding for RoPE on image and text tokens + self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope) + + # 2. Combined timestep + guidance embedding + self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( + in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False + ) + + # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.) 
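+        # These modulation parameters are computed once per forward pass from the timestep/guidance embedding
+        # and shared across all blocks, rather than each block owning its own adaLN projection. For example,
+        # Flux2Modulation(dim, mod_param_sets=2) maps temb of shape [B, dim] to two (shift, scale, gate)
+        # triples, each of shape [B, 1, dim].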
+ # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks + self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream + self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False) + + # 4. Input projections + self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False) + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False) + + # 5. Double Stream Transformer Blocks + self.transformer_blocks = nn.ModuleList( + [ + Flux2TransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_layers) + ] + ) + + # 6. Single Stream Transformer Blocks + self.single_transformer_blocks = nn.ModuleList( + [ + Flux2SingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_single_layers) + ] + ) + + # 7. Output layers + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False + ) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 0. 
Handle input arguments + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + num_txt_tokens = encoder_hidden_states.shape[1] + + # 1. Calculate timestep embedding and modulation parameters + timestep = timestep.to(hidden_states.dtype) * 1000 + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = self.time_guidance_embed(timestep, guidance) + + double_stream_mod_img = self.double_stream_modulation_img(temb) + double_stream_mod_txt = self.double_stream_modulation_txt(temb) + single_stream_mod = self.single_stream_modulation(temb)[0] + + # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states) + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + # 3. Calculate RoPE embeddings from image and text tokens + # NOTE: the below logic means that we can't support batched inference with images of different resolutions or + # text prompts of differents lengths. Is this a use case we want to support? + if img_ids.ndim == 3: + img_ids = img_ids[0] + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + + if is_torch_npu_available(): + freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu()) + image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu()) + freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu()) + text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu()) + else: + image_rotary_emb = self.pos_embed(img_ids) + text_rotary_emb = self.pos_embed(txt_ids) + concat_rotary_emb = ( + torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), + torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), + ) + + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + # 4. Double Stream Transformer Blocks + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + double_stream_mod_img, + double_stream_mod_txt, + concat_rotary_emb, + joint_attention_kwargs, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_params_img=double_stream_mod_img, + temb_mod_params_txt=double_stream_mod_txt, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + # Concatenate text and image streams for single-block inference + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # 5. 
Single Stream Transformer Blocks + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + None, + single_stream_mod, + concat_rotary_emb, + joint_attention_kwargs, + ) + else: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=None, + temb_mod_params=single_stream_mod, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + # Remove text tokens from concatenated stream + hidden_states = hidden_states[:, num_txt_tokens:, ...] + + # 6. Output layers + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index e053c511d93d..3e4bfda07706 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -883,6 +883,7 @@ def __call__( if output_type == "latent": image = latents else: + torch.save({"pred": latents}, "pred_d.pt") latents = self._unpack_latents_with_ids(latents, latent_ids) latents_bn_mean = ( diff --git a/tests/models/transformers/test_models_transformer_flux2.py b/tests/models/transformers/test_models_transformer_flux2.py new file mode 100644 index 000000000000..30254b8e915d --- /dev/null +++ b/tests/models/transformers/test_models_transformer_flux2.py @@ -0,0 +1,244 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
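+
+# Model-level tests for the Flux 2 transformer: output consistency, gradient checkpointing,
+# torch.compile, and LoRA hot-swapping.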
+ +import unittest + +import torch + +from diffusers import Flux2Transformer2DModel +from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0 +from diffusers.models.embeddings import ImageProjection + +from ...testing_utils import enable_full_determinism, is_peft_available, torch_device +from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin + + +enable_full_determinism() + + +def create_flux_ip_adapter_state_dict(model): + # "ip_adapter" (cross-attention weights) + ip_cross_attn_state_dict = {} + key_id = 0 + + for name in model.attn_processors.keys(): + if name.startswith("single_transformer_blocks"): + continue + + joint_attention_dim = model.config["joint_attention_dim"] + hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"] + sd = FluxIPAdapterJointAttnProcessor2_0( + hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0 + ).state_dict() + ip_cross_attn_state_dict.update( + { + f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"], + f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"], + f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"], + f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"], + } + ) + + key_id += 1 + + # "image_proj" (ImageProjection layer weights) + + image_projection = ImageProjection( + cross_attention_dim=model.config["joint_attention_dim"], + image_embed_dim=( + model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768 + ), + num_image_text_embeds=4, + ) + + ip_image_projection_state_dict = {} + sd = image_projection.state_dict() + ip_image_projection_state_dict.update( + { + "proj.weight": sd["image_embeds.weight"], + "proj.bias": sd["image_embeds.bias"], + "norm.weight": sd["norm.weight"], + "norm.bias": sd["norm.bias"], + } + ) + + del sd + ip_state_dict = {} + ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) + return ip_state_dict + + +class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = Flux2Transformer2DModel + main_input_name = "hidden_states" + # We override the items here because the transformer under consideration is small. 
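+    # Fractions of the total model size used by the common parallelism/offloading tests when computing
+    # per-device memory limits for this deliberately tiny model.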
+ model_split_percents = [0.7, 0.6, 0.6] + + # Skip setting testing with default: AttnProcessor + uses_custom_attn_processor = True + + @property + def dummy_input(self): + return self.prepare_dummy_input() + + @property + def input_shape(self): + return (16, 4) + + @property + def output_shape(self): + return (16, 4) + + def prepare_dummy_input(self, height=4, width=4): + batch_size = 1 + num_latent_channels = 4 + sequence_length = 48 + embedding_dim = 32 + + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + + t_coords = torch.arange(1) + h_coords = torch.arange(height) + w_coords = torch.arange(width) + l_coords = torch.arange(1) + image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords) # [height * width, 4] + image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device) + + text_t_coords = torch.arange(1) + text_h_coords = torch.arange(1) + text_w_coords = torch.arange(1) + text_l_coords = torch.arange(sequence_length) + text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords) + text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device) + + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "img_ids": image_ids, + "txt_ids": text_ids, + # "pooled_projections": pooled_prompt_embeds, + "timestep": timestep, + "guidance": guidance, + } + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "patch_size": 1, + "in_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 16, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "timestep_guidance_channels": 256, # Hardcoded in original code + "axes_dims_rope": [4, 4, 4, 4], + } + + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_flux2_consistency(self, seed=0): + torch.manual_seed(seed) + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(seed) + model = self.model_class(**init_dict) + # state_dict = model.state_dict() + # for key, param in state_dict.items(): + # print(f"{key} | {param.shape}") + # torch.save(state_dict, "/raid/daniel_gu/test_flux2_params/diffusers.pt") + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.to_tuple()[0] + + self.assertIsNotNone(output) + + # input & output have to have the same shape + input_tensor = inputs_dict[self.main_input_name] + expected_shape = input_tensor.shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + # Check against expected slice + # fmt: off + expected_slice = torch.tensor([-0.3180, 0.4818, 0.6621, -0.3386, 0.2313, 0.0688, 0.0985, -0.2686, -0.1480, -0.1607, -0.7245, 0.5385, -0.2842, 0.6575, -0.0697, 0.4951]) + # fmt: on + + flat_output = output.cpu().flatten() + generated_slice = torch.cat([flat_output[:8], flat_output[-8:]]) + self.assertTrue(torch.allclose(expected_slice, generated_slice)) + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"Flux2Transformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + # The test exists for cases like + # 
https://github.com/huggingface/diffusers/issues/11874 + @unittest.skipIf(not is_peft_available(), "Only with PEFT") + def test_lora_exclude_modules(self): + from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict + + lora_rank = 4 + target_module = "single_transformer_blocks.0.proj_out" + adapter_name = "foo" + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + state_dict = model.state_dict() + target_mod_shape = state_dict[f"{target_module}.weight"].shape + lora_state_dict = { + f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22, + f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33, + } + # Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter). + config = LoraConfig( + r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"] + ) + inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict) + set_peft_model_state_dict(model, lora_state_dict, adapter_name) + retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name) + assert len(retrieved_lora_state_dict) == len(lora_state_dict) + assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all() + assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all() + + +class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = Flux2Transformer2DModel + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] + + def prepare_init_args_and_inputs_for_common(self): + return Flux2TransformerTests().prepare_init_args_and_inputs_for_common() + + def prepare_dummy_input(self, height, width): + return Flux2TransformerTests().prepare_dummy_input(height=height, width=width) + + +class Flux2TransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): + model_class = Flux2Transformer2DModel + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] + + def prepare_init_args_and_inputs_for_common(self): + return Flux2TransformerTests().prepare_init_args_and_inputs_for_common() + + def prepare_dummy_input(self, height, width): + return Flux2TransformerTests().prepare_dummy_input(height=height, width=width) From ffb006197ca474721f608074212e4a78d00fbb7e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 17 Nov 2025 06:30:56 +0000 Subject: [PATCH 28/63] remove ip adapter stuff. --- src/diffusers/models/transformers/transformer_flux2.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index e16652edd445..70b7b75fe334 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -775,11 +775,6 @@ def forward( torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), ) - if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: - ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") - ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) - joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) - # 4. 
Double Stream Transformer Blocks for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: From 6820d6ca39facca8d2b20067503a35516f98c02e Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 17 Nov 2025 08:06:21 +0100 Subject: [PATCH 29/63] Fix Flux 2 transformer consistency test --- .../transformers/test_models_transformer_flux2.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_flux2.py b/tests/models/transformers/test_models_transformer_flux2.py index 30254b8e915d..ccc3d42fcba9 100644 --- a/tests/models/transformers/test_models_transformer_flux2.py +++ b/tests/models/transformers/test_models_transformer_flux2.py @@ -17,7 +17,7 @@ import torch -from diffusers import Flux2Transformer2DModel +from diffusers import Flux2Transformer2DModel, attention_backend from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0 from diffusers.models.embeddings import ImageProjection @@ -166,11 +166,12 @@ def test_flux2_consistency(self, seed=0): model.to(torch_device) model.eval() - with torch.no_grad(): - output = model(**inputs_dict) + with attention_backend("native"): + with torch.no_grad(): + output = model(**inputs_dict) - if isinstance(output, dict): - output = output.to_tuple()[0] + if isinstance(output, dict): + output = output.to_tuple()[0] self.assertIsNotNone(output) @@ -181,12 +182,12 @@ def test_flux2_consistency(self, seed=0): # Check against expected slice # fmt: off - expected_slice = torch.tensor([-0.3180, 0.4818, 0.6621, -0.3386, 0.2313, 0.0688, 0.0985, -0.2686, -0.1480, -0.1607, -0.7245, 0.5385, -0.2842, 0.6575, -0.0697, 0.4951]) + expected_slice = torch.tensor([-0.3662, 0.4844, 0.6334, -0.3497, 0.2162, 0.0188, 0.0521, -0.2061, -0.2041, -0.0342, -0.7107, 0.4797, -0.3280, 0.7059, -0.0849, 0.4416]) # fmt: on flat_output = output.cpu().flatten() generated_slice = torch.cat([flat_output[:8], flat_output[-8:]]) - self.assertTrue(torch.allclose(expected_slice, generated_slice)) + self.assertTrue(torch.allclose(generated_slice, expected_slice)) def test_gradient_checkpointing_is_applied(self): expected_set = {"Flux2Transformer2DModel"} From f6059b758c671a58fc88b75e449fd2ecc882b9c2 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 17 Nov 2025 08:09:06 +0100 Subject: [PATCH 30/63] Fix bug in Flux2TransformerBlock (double stream block) --- src/diffusers/models/transformers/transformer_flux2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index e16652edd445..65734318e02f 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -433,7 +433,7 @@ def forward( # Conditioning txt stream norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) - norm_encoder_hidden_states = (1 + c_scale_msa) * encoder_hidden_states + c_shift_msa + norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa # Attention on concatenated img + txt stream attention_outputs = self.attn( From 05fcddbf152755a9a96ec2072718e8f7949b7f0a Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 17 Nov 2025 08:20:40 +0100 Subject: [PATCH 31/63] Get remaining Flux 2 transformer tests passing --- .../transformers/test_models_transformer_flux2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git 
a/tests/models/transformers/test_models_transformer_flux2.py b/tests/models/transformers/test_models_transformer_flux2.py index ccc3d42fcba9..b58b0542a0f2 100644 --- a/tests/models/transformers/test_models_transformer_flux2.py +++ b/tests/models/transformers/test_models_transformer_flux2.py @@ -187,7 +187,7 @@ def test_flux2_consistency(self, seed=0): flat_output = output.cpu().flatten() generated_slice = torch.cat([flat_output[:8], flat_output[-8:]]) - self.assertTrue(torch.allclose(generated_slice, expected_slice)) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-4)) def test_gradient_checkpointing_is_applied(self): expected_set = {"Flux2Transformer2DModel"} @@ -200,7 +200,7 @@ def test_lora_exclude_modules(self): from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict lora_rank = 4 - target_module = "single_transformer_blocks.0.proj_out" + target_module = "single_transformer_blocks.0.attn.to_out" adapter_name = "foo" init_dict, _ = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).to(torch_device) @@ -213,14 +213,14 @@ def test_lora_exclude_modules(self): } # Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter). config = LoraConfig( - r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"] + r=lora_rank, target_modules=[target_module], exclude_modules=["to_out"] ) inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict) set_peft_model_state_dict(model, lora_state_dict, adapter_name) retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name) assert len(retrieved_lora_state_dict) == len(lora_state_dict) - assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all() - assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all() + assert (retrieved_lora_state_dict[f"{target_module}.lora_A.weight"] == 22).all() + assert (retrieved_lora_state_dict[f"{target_module}.lora_B.weight"] == 33).all() class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): From ac4f61a1923033bb04de9674367742cad3ffa2b8 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 17 Nov 2025 08:23:41 +0100 Subject: [PATCH 32/63] make style; make quality; make fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ .../transformers/test_models_transformer_flux2.py | 4 +--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 81eb2569e303..b48b26942f9f 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -843,6 +843,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class Flux2Transformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class FluxControlNetModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/transformers/test_models_transformer_flux2.py b/tests/models/transformers/test_models_transformer_flux2.py index b58b0542a0f2..37d5130a8273 100644 --- 
a/tests/models/transformers/test_models_transformer_flux2.py +++ b/tests/models/transformers/test_models_transformer_flux2.py @@ -212,9 +212,7 @@ def test_lora_exclude_modules(self): f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33, } # Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter). - config = LoraConfig( - r=lora_rank, target_modules=[target_module], exclude_modules=["to_out"] - ) + config = LoraConfig(r=lora_rank, target_modules=[target_module], exclude_modules=["to_out"]) inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict) set_peft_model_state_dict(model, lora_state_dict, adapter_name) retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name) From 20f9b830149a882cb56dc74f85f914c6875657e1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 17 Nov 2025 07:26:08 +0000 Subject: [PATCH 33/63] remove stuff. --- src/diffusers/models/transformers/transformer_flux2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index aa25d856f5b2..519170d60d2d 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -294,8 +294,7 @@ def forward( **kwargs, ) -> torch.Tensor: attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} - unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] if len(unused_kwargs) > 0: logger.warning( f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." From 546e60da314aa296c08105c18f78baa5d1a38901 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 17 Nov 2025 07:43:32 +0000 Subject: [PATCH 34/63] fix type annotaton. --- src/diffusers/pipelines/flux2/pipeline_flux2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 3e4bfda07706..be68064870e4 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -20,7 +20,7 @@ import torch from transformers import AutoProcessor, Mistral3ForConditionalGeneration -from ...models import AutoencoderKL, FluxTransformer2DModel +from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( is_torch_xla_available, @@ -169,11 +169,11 @@ class Flux2Pipeline(DiffusionPipeline): Reference: TODO Args: - transformer ([`FluxTransformer2DModel`]): + transformer ([`Flux2Transformer2DModel`]): Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKL`]): + vae ([`AutoencoderKLFlux2`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
text_encoder ([`Mistral3ForConditionalGeneration`]): [Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration) @@ -189,10 +189,10 @@ class Flux2Pipeline(DiffusionPipeline): def __init__( self, scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKL, + vae: AutoencoderKLFlux2, text_encoder: Mistral3ForConditionalGeneration, tokenizer: AutoProcessor, - transformer: FluxTransformer2DModel, + transformer: Flux2Transformer2DModel, ): super().__init__() From eeb52c28ace4c915861b1bac0c51671e03fd84a1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 17 Nov 2025 09:42:30 +0000 Subject: [PATCH 35/63] remove unneeded stuff from tests --- .../test_models_transformer_flux2.py | 83 +------------------ 1 file changed, 2 insertions(+), 81 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_flux2.py b/tests/models/transformers/test_models_transformer_flux2.py index 37d5130a8273..822eae93b9ad 100644 --- a/tests/models/transformers/test_models_transformer_flux2.py +++ b/tests/models/transformers/test_models_transformer_flux2.py @@ -28,58 +28,6 @@ enable_full_determinism() -def create_flux_ip_adapter_state_dict(model): - # "ip_adapter" (cross-attention weights) - ip_cross_attn_state_dict = {} - key_id = 0 - - for name in model.attn_processors.keys(): - if name.startswith("single_transformer_blocks"): - continue - - joint_attention_dim = model.config["joint_attention_dim"] - hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"] - sd = FluxIPAdapterJointAttnProcessor2_0( - hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0 - ).state_dict() - ip_cross_attn_state_dict.update( - { - f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"], - f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"], - f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"], - f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"], - } - ) - - key_id += 1 - - # "image_proj" (ImageProjection layer weights) - - image_projection = ImageProjection( - cross_attention_dim=model.config["joint_attention_dim"], - image_embed_dim=( - model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768 - ), - num_image_text_embeds=4, - ) - - ip_image_projection_state_dict = {} - sd = image_projection.state_dict() - ip_image_projection_state_dict.update( - { - "proj.weight": sd["image_embeds.weight"], - "proj.bias": sd["image_embeds.bias"], - "norm.weight": sd["norm.weight"], - "norm.bias": sd["norm.bias"], - } - ) - - del sd - ip_state_dict = {} - ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) - return ip_state_dict - - class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase): model_class = Flux2Transformer2DModel main_input_name = "hidden_states" @@ -132,7 +80,6 @@ def prepare_dummy_input(self, height=4, width=4): "encoder_hidden_states": encoder_hidden_states, "img_ids": image_ids, "txt_ids": text_ids, - # "pooled_projections": pooled_prompt_embeds, "timestep": timestep, "guidance": guidance, } @@ -153,6 +100,7 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict + # TODO (Daniel, Sayak): We can remove this test. 
def test_flux2_consistency(self, seed=0): torch.manual_seed(seed) init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -192,34 +140,7 @@ def test_flux2_consistency(self, seed=0): def test_gradient_checkpointing_is_applied(self): expected_set = {"Flux2Transformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - - # The test exists for cases like - # https://github.com/huggingface/diffusers/issues/11874 - @unittest.skipIf(not is_peft_available(), "Only with PEFT") - def test_lora_exclude_modules(self): - from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict - - lora_rank = 4 - target_module = "single_transformer_blocks.0.attn.to_out" - adapter_name = "foo" - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict).to(torch_device) - - state_dict = model.state_dict() - target_mod_shape = state_dict[f"{target_module}.weight"].shape - lora_state_dict = { - f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22, - f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33, - } - # Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter). - config = LoraConfig(r=lora_rank, target_modules=[target_module], exclude_modules=["to_out"]) - inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict) - set_peft_model_state_dict(model, lora_state_dict, adapter_name) - retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name) - assert len(retrieved_lora_state_dict) == len(lora_state_dict) - assert (retrieved_lora_state_dict[f"{target_module}.lora_A.weight"] == 22).all() - assert (retrieved_lora_state_dict[f"{target_module}.lora_B.weight"] == 33).all() - + class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = Flux2Transformer2DModel From 771e17bcc879b52a5ec63555317c68474a5e113a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 17 Nov 2025 15:03:47 +0000 Subject: [PATCH 36/63] tests --- src/diffusers/models/embeddings.py | 2 +- .../pipelines/flux2/pipeline_flux2.py | 19 +- tests/pipelines/flux2/__init__.py | 0 tests/pipelines/flux2/test_pipeline_flux2.py | 200 ++++++++++++++++++ 4 files changed, 210 insertions(+), 11 deletions(-) create mode 100644 tests/pipelines/flux2/__init__.py create mode 100644 tests/pipelines/flux2/test_pipeline_flux2.py diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 37fc412adcc3..d630fd2c6c64 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1228,7 +1228,7 @@ def apply_rotary_emb( x_rotated = torch.cat([-x_imag, x_real], dim=-1) else: raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") - + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index be68064870e4..eb7961074193 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -13,7 +13,7 @@ # limitations under the License. 
import inspect -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union, Tuple import numpy as np import PIL @@ -183,7 +183,6 @@ class Flux2Pipeline(DiffusionPipeline): """ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" - _optional_components = ["image_encoder", "feature_extractor"] _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( @@ -211,8 +210,7 @@ def __init__( self.default_sample_size = 128 self.system_message = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.""" - self.text_encoder_out_layers = (10, 20, 30) - + @staticmethod def _get_mistral_3_small_prompt_embeds( text_encoder: Mistral3ForConditionalGeneration, @@ -248,7 +246,7 @@ def _get_mistral_3_small_prompt_embeds( # Move to device input_ids = inputs["input_ids"].to(device) attention_mask = inputs["attention_mask"].to(device) - + # Forward pass through the model output = text_encoder( input_ids=input_ids, @@ -436,6 +434,7 @@ def encode_prompt( num_images_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (10, 20, 30), ): device = device or self._execution_device @@ -452,7 +451,7 @@ def encode_prompt( device=device, max_sequence_length=max_sequence_length, system_message=self.system_message, - hidden_states_layers=self.text_encoder_out_layers, + hidden_states_layers=text_encoder_out_layers, ) batch_size, seq_len, _ = prompt_embeds.shape @@ -492,14 +491,13 @@ def prepare_latents( device, generator: torch.Generator, latents: Optional[torch.Tensor] = None, - ): + ): # VAE applies 8x compression on images but we must also account for packing which requires # latent height and width to be divisible by 2. height = 2 * (int(height) // (self.vae_scale_factor * 2)) width = 2 * (int(width) // (self.vae_scale_factor * 2)) - shape = (batch_size, num_latents_channels * 4, height//2, width//2) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -628,6 +626,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (10, 20, 30), ): r""" Function invoked when calling the pipeline for generation. @@ -741,6 +740,7 @@ def __call__( device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, ) # 4. process images @@ -768,8 +768,7 @@ def __call__( condition_image_sizes.append((image_width, image_height)) # 5. 
prepare latent variables - num_channels_latents = 32 - # num_channels_latents = self.transformer.config.in_channels // 4 + num_channels_latents = self.transformer.config.in_channels // 4 latents, latent_ids = self.prepare_latents( batch_size=batch_size * num_images_per_prompt, num_latents_channels=num_channels_latents, diff --git a/tests/pipelines/flux2/__init__.py b/tests/pipelines/flux2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/flux2/test_pipeline_flux2.py b/tests/pipelines/flux2/test_pipeline_flux2.py new file mode 100644 index 000000000000..3bda5ba579fd --- /dev/null +++ b/tests/pipelines/flux2/test_pipeline_flux2.py @@ -0,0 +1,200 @@ +import unittest +import pytest +import numpy as np +import torch +from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Mistral3Config + +from diffusers import ( + AutoencoderKLFlux2, + Flux2Pipeline, + FasterCacheConfig, + FlowMatchEulerDiscreteScheduler, + Flux2Transformer2DModel, +) + +from ...testing_utils import ( + Expectations, + backend_empty_cache, + nightly, + numpy_cosine_similarity_distance, + require_big_accelerator, + slow, + torch_device, +) +from ..test_pipelines_common import ( + FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, + FluxIPAdapterTesterMixin, + PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, + check_qkv_fused_layers_exist, +) + + +class Flux2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = Flux2Pipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"]) + batch_params = frozenset(["prompt"]) + + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + supports_dduf = False + + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): + torch.manual_seed(0) + transformer = Flux2Transformer2DModel( + patch_size=1, + in_channels=4, + num_layers=num_layers, + num_single_layers=num_single_layers, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=16, + timestep_guidance_channels=256, # Hardcoded in original code + axes_dims_rope=[4, 4, 4, 4], + ) + + config = Mistral3Config( + text_config={ + "model_type": "mistral", + "vocab_size": 32000, + "hidden_size": 16, + "intermediate_size": 37, + "max_position_embeddings": 512, + "num_attention_heads": 4, + "num_hidden_layers": 1, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000000.0, + "sliding_window": None, + "bos_token_id": 2, + "eos_token_id": 3, + "pad_token_id": 4, + }, + vision_config={ + "model_type": "pixtral", + "hidden_size": 16, + "num_hidden_layers": 1, + "num_attention_heads": 4, + "intermediate_size": 37, + "image_size": 30, + "patch_size": 6, + "num_channels": 3, + }, + bos_token_id=2, + eos_token_id=3, + pad_token_id=4, + model_dtype="mistral3", + image_seq_length=4, + vision_feature_layer=-1, + image_token_index=1, + ) + torch.manual_seed(0) + text_encoder = Mistral3ForConditionalGeneration(config) + tokenizer = AutoProcessor.from_pretrained("hf-internal-testing/Mistral-Small-3.1-24B-Instruct-2503-only-processor") + + torch.manual_seed(0) + vae = AutoencoderKLFlux2( + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + 
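+        # Note: these tiny configs must stay consistent with one another; in particular, `joint_attention_dim`
+        # matches the text encoder hidden size (16 here) because the prompt embeddings feed the transformer's
+        # `context_embedder` directly.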
return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "a dog is dancing", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 8, + "width": 8, + "max_sequence_length": 8, + "output_type": "np", + "text_encoder_out_layers": (1,) + } + return inputs + + @pytest.mark.xfail(condition=True, reason="Flux2 uses parallel projections which are incompatible here.") + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + original_image_slice = image[0, -3:, -3:, -1] + + # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added + # to the pipeline level. + pipe.transformer.fuse_qkv_projections() + self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), + ) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_fused = image[0, -3:, -3:, -1] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_disabled = image[0, -3:, -3:, -1] + + self.assertTrue( + np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), + ("Fusion of QKV projections shouldn't affect the outputs."), + ) + self.assertTrue( + np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), + ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."), + ) + self.assertTrue( + np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), + ("Original outputs should match when fused QKV projections are disabled."), + ) + + def test_flux_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + self.assertEqual( + (output_height, output_width), + (expected_height, expected_width), + f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}", + ) From 28da679969d64822d8f3588d1ba6a6fd47497ed8 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 17 Nov 2025 15:05:16 +0000 Subject: [PATCH 37/63] up --- src/diffusers/pipelines/flux2/pipeline_flux2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index eb7961074193..9b8e93d16a2d 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -182,7 +182,7 @@ class 
Flux2Pipeline(DiffusionPipeline): [PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor). """ - model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( From 14f986636da570d0b2e24449e808d60b584dfa98 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 17 Nov 2025 16:07:54 +0000 Subject: [PATCH 38/63] up --- src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/lora_pipeline.py | 200 ++++++++++++++++++ src/diffusers/loaders/peft.py | 1 + .../models/transformers/transformer_flux2.py | 4 +- .../pipelines/flux2/pipeline_flux2.py | 3 +- tests/lora/test_lora_layers_flux2.py | 132 ++++++++++++ 6 files changed, 338 insertions(+), 4 deletions(-) create mode 100644 tests/lora/test_lora_layers_flux2.py diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 48507aae038c..e7a7109b3f6c 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -81,6 +81,7 @@ def text_encoder_attn_modules(text_encoder): "HiDreamImageLoraLoaderMixin", "SkyReelsV2LoraLoaderMixin", "QwenImageLoraLoaderMixin", + "Flux2LoraLoaderMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["ip_adapter"] = [ @@ -128,6 +129,7 @@ def text_encoder_attn_modules(text_encoder): StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, WanLoraLoaderMixin, + Flux2LoraLoaderMixin, ) from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 25919a896af0..a807ddb5a0d2 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -5084,6 +5084,206 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) +class Flux2LoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`Flux2Transformer2DModel`]. Specific to [`Flux2Pipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. 
+ cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
+ kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel + def load_lora_into_transformer( + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information. + """ + lora_layers = {} + lora_metadata = {} + + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata + + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") + + cls._save_lora_weights( + save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details. + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details. 
+ """ + super().unfuse_lora(components=components, **kwargs) + + class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): def __init__(self, *args, **kwargs): deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 7d65b30659fb..b759e04cbf2d 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -62,6 +62,7 @@ "WanVACETransformer3DModel": lambda model_cls, weights: weights, "ChromaTransformer2DModel": lambda model_cls, weights: weights, "QwenImageTransformer2DModel": lambda model_cls, weights: weights, + "Flux2Transformer2DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 519170d60d2d..fc29d0a9db7c 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -209,9 +209,7 @@ def __call__( class Flux2Attention(torch.nn.Module, AttentionModuleMixin): _default_processor_cls = Flux2AttnProcessor - _available_processors = [ - Flux2AttnProcessor, - ] + _available_processors = [Flux2AttnProcessor] def __init__( self, diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 9b8e93d16a2d..3624d7e7d27a 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -20,6 +20,7 @@ import torch from transformers import AutoProcessor, Mistral3ForConditionalGeneration +from ...loaders import Flux2LoraLoaderMixin from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( @@ -162,7 +163,7 @@ def retrieve_latents( else: raise AttributeError("Could not access latents of provided encoder_output") -class Flux2Pipeline(DiffusionPipeline): +class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin): r""" The Flux2 pipeline for text-to-image generation. diff --git a/tests/lora/test_lora_layers_flux2.py b/tests/lora/test_lora_layers_flux2.py new file mode 100644 index 000000000000..67c3975ee177 --- /dev/null +++ b/tests/lora/test_lora_layers_flux2.py @@ -0,0 +1,132 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import sys +import unittest + +import torch +from transformers import AutoProcessor, Mistral3ForConditionalGeneration + +from diffusers import ( + Flux2Pipeline, + Flux2Transformer2DModel, + FlowMatchEulerDiscreteScheduler, + AutoencoderKLFlux2 +) + +from ..testing_utils import floats_tensor, require_peft_backend + + +sys.path.append(".") + +from .utils import PeftLoraLoaderMixinTests # noqa: E402 + + +@require_peft_backend +class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = Flux2Pipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler + scheduler_kwargs = {} + + transformer_kwargs = { + "patch_size": 1, + "in_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 16, + "num_attention_heads": 2, + "joint_attention_dim": 16, + "timestep_guidance_channels": 256, + "axes_dims_rope": [4, 4, 4, 4], + } + transformer_cls = Flux2Transformer2DModel + vae_kwargs = { + "sample_size": 32, + "in_channels": 3, + "out_channels": 3, + "down_block_types": ("DownEncoderBlock2D",), + "up_block_types": ("UpDecoderBlock2D",), + "block_out_channels": (4,), + "layers_per_block": 1, + "latent_channels": 1, + "norm_num_groups": 1, + "use_quant_conv": False, + "use_post_quant_conv": False, + } + vae_cls = AutoencoderKLFlux2 + + tokenizer_cls, tokenizer_id = AutoProcessor, "hf-internal-testing/tiny-mistral3-diffusers" + text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers" + denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"] + + @property + def output_shape(self): + return (1, 8, 8, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "a dog is dancing", + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 8, + "width": 8, + "max_sequence_length": 8, + "output_type": "np", + "text_encoder_out_layers": (1,) + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + @unittest.skip("Not supported in Flux2.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in Flux2.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in Flux2.") + def test_modify_padding_mode(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in Flux2.") + def test_simple_inference_with_partial_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in Flux2.") + def test_simple_inference_with_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in Flux2.") + def test_simple_inference_with_text_lora_and_scale(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in Flux2.") + def test_simple_inference_with_text_lora_fused(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in Flux2.") + def test_simple_inference_with_text_lora_save_load(self): + pass From 6d1697592dc33ee2526fd28bc29e0c2b332adfb7 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 17 Nov 2025 18:21:59 +0100 Subject: [PATCH 39/63] add sf support --- src/diffusers/loaders/single_file_model.py | 5 + 
src/diffusers/loaders/single_file_utils.py | 170 +++++++++++++++++++++ 2 files changed, 175 insertions(+) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index b53647d47630..7b581ac3eb9c 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -34,6 +34,7 @@ convert_chroma_transformer_checkpoint_to_diffusers, convert_controlnet_checkpoint, convert_cosmos_transformer_checkpoint_to_diffusers, + convert_flux2_transformer_checkpoint_to_diffusers, convert_flux_transformer_checkpoint_to_diffusers, convert_hidream_transformer_to_diffusers, convert_hunyuan_video_transformer_to_diffusers, @@ -162,6 +163,10 @@ "checkpoint_mapping_fn": lambda x: x, "default_subfolder": "transformer", }, + "Flux2Transformer2DModel": { + "checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, } diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index ef6c41e3ce97..d4676ba2526a 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -140,6 +140,7 @@ "net.blocks.0.self_attn.q_proj.weight", "net.pos_embedder.dim_spatial_range", ], + "flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"], } DIFFUSERS_DEFAULT_PIPELINE_PATHS = { @@ -189,6 +190,7 @@ "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"}, "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"}, "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, + "flux-2-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev"}, "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"}, "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"}, "ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"}, @@ -649,6 +651,9 @@ def infer_diffusers_model_type(checkpoint): else: model_type = "animatediff_v3" + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux2"]): + model_type = "flux-2-dev" + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]): if any( g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"] @@ -3647,3 +3652,168 @@ def rename_transformer_blocks_(key: str, state_dict): handler_fn_inplace(key, converted_state_dict) return converted_state_dict + + +def convert_flux2_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + FLUX2_TRANSFORMER_KEYS_RENAME_DICT = { + # Image and text input projections + "img_in": "x_embedder", + "txt_in": "context_embedder", + # Timestep and guidance embeddings + "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1", + "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2", + "guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1", + "guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2", + # Modulation parameters + "double_stream_modulation_img.lin": "double_stream_modulation_img.linear", + "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear", + "single_stream_modulation.lin": "single_stream_modulation.linear", + # Final output layer + # "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params + "final_layer.linear": "proj_out", + } + 
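    # Note: `swap_scale_shift` used by the adaLN handler below is assumed to be the helper
    # already defined earlier in this module; it reorders the original (shift, scale)
    # parameter layout into the (scale, shift) layout diffusers expects, roughly:
    #
    #     def swap_scale_shift(weight, dim):
    #         shift, scale = weight.chunk(2, dim=0)
    #         return torch.cat([scale, shift], dim=0)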
+ FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = { + "final_layer.adaLN_modulation.1": "norm_out.linear", + } + + FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = { + # Handle fused QKV projections separately as we need to break into Q, K, V projections + "img_attn.norm.query_norm": "attn.norm_q", + "img_attn.norm.key_norm": "attn.norm_k", + "img_attn.proj": "attn.to_out.0", + "img_mlp.0": "ff.linear_in", + "img_mlp.2": "ff.linear_out", + "txt_attn.norm.query_norm": "attn.norm_added_q", + "txt_attn.norm.key_norm": "attn.norm_added_k", + "txt_attn.proj": "attn.to_add_out", + "txt_mlp.0": "ff_context.linear_in", + "txt_mlp.2": "ff_context.linear_out", + } + + FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = { + "linear1": "attn.to_qkv_mlp_proj", + "linear2": "attn.to_out", + "norm.query_norm": "attn.norm_q", + "norm.key_norm": "attn.norm_k", + } + + def convert_flux2_single_stream_blocks(key: str, state_dict: dict[str, object]) -> None: + # Skip if not a weight, bias, or scale + if ".weight" not in key and ".bias" not in key and ".scale" not in key: + return + + # Mapping: + # - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj + # - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out + # - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight + # - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight + new_prefix = "single_transformer_blocks" + if "single_blocks." in key: + parts = key.split(".") + block_idx = parts[1] + within_block_name = ".".join(parts[2:-1]) + param_type = parts[-1] + + if param_type == "scale": + param_type = "weight" + + new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name] + new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type]) + + param = state_dict.pop(key) + state_dict[new_key] = param + + return + + def convert_ada_layer_norm_weights(key: str, state_dict: dict[str, object]) -> None: + # Skip if not a weight + if ".weight" not in key: + return + + # If adaLN_modulation is in the key, swap scale and shift parameters + # Original implementation is (shift, scale); diffusers implementation is (scale, shift) + if "adaLN_modulation" in key: + key_without_param_type, param_type = key.rsplit(".", maxsplit=1) + # Assume all such keys are in the AdaLayerNorm key map + new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type] + new_key = ".".join([new_key_without_param_type, param_type]) + + swapped_weight = swap_scale_shift(state_dict.pop(key), 0) + state_dict[new_key] = swapped_weight + + return + + def convert_flux2_double_stream_blocks(key: str, state_dict: dict[str, object]) -> None: + # Skip if not a weight, bias, or scale + if ".weight" not in key and ".bias" not in key and ".scale" not in key: + return + + new_prefix = "transformer_blocks" + if "double_blocks." 
in key: + parts = key.split(".") + block_idx = parts[1] + modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp + within_block_name = ".".join(parts[2:-1]) + param_type = parts[-1] + + if param_type == "scale": + param_type = "weight" + + if "qkv" in within_block_name: + fused_qkv_weight = state_dict.pop(key) + to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0) + if "img" in modality_block_name: + # double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v} + to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0) + new_q_name = "attn.to_q" + new_k_name = "attn.to_k" + new_v_name = "attn.to_v" + elif "txt" in modality_block_name: + # double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj} + to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0) + new_q_name = "attn.add_q_proj" + new_k_name = "attn.add_k_proj" + new_v_name = "attn.add_v_proj" + new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type]) + new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type]) + new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type]) + state_dict[new_q_key] = to_q_weight + state_dict[new_k_key] = to_k_weight + state_dict[new_v_key] = to_v_weight + else: + new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name] + new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type]) + + param = state_dict.pop(key) + state_dict[new_key] = param + return + + def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None: + state_dict[new_key] = state_dict.pop(old_key) + + TRANSFORMER_SPECIAL_KEYS_REMAP = { + "adaLN_modulation": convert_ada_layer_norm_weights, + "double_blocks": convert_flux2_double_stream_blocks, + "single_blocks": convert_flux2_single_stream_blocks, + } + + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(converted_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in FLUX2_TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + + update_state_dict(converted_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + return converted_state_dict From be6604bbf55a83c31ddccfc9770e14831208182b Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Mon, 17 Nov 2025 20:36:57 -0800 Subject: [PATCH 40/63] Remove unused IP Adapter and ControlNet logic from transformer (#9) --- .../models/transformers/transformer_flux2.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index fc29d0a9db7c..65ce4921b1b9 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -440,10 +440,7 @@ def forward( **joint_attention_kwargs, ) - if len(attention_outputs) == 2: - attn_output, context_attn_output = attention_outputs - elif 
len(attention_outputs) == 3: - attn_output, context_attn_output, ip_attn_output = attention_outputs + attn_output, context_attn_output = attention_outputs # Process attention outputs for the image stream (`hidden_states`). attn_output = gate_msa * attn_output @@ -455,9 +452,6 @@ def forward( ff_output = self.ff(norm_hidden_states) hidden_states = hidden_states + gate_mlp * ff_output - if len(attention_outputs) == 3: - hidden_states = hidden_states + ip_attn_output - # Process attention outputs for the text stream (`encoder_hidden_states`). context_attn_output = c_gate_msa * context_attn_output encoder_hidden_states = encoder_hidden_states + context_attn_output @@ -690,10 +684,7 @@ def forward( txt_ids: torch.Tensor = None, guidance: torch.Tensor = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_block_samples=None, - controlnet_single_block_samples=None, return_dict: bool = True, - controlnet_blocks_repeat: bool = False, ) -> Union[torch.Tensor, Transformer2DModelOutput]: """ The [`FluxTransformer2DModel`] forward method. From ec0a2addb5f215741ab2b66aa7c0121c933f2a23 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Nov 2025 05:12:03 +0000 Subject: [PATCH 41/63] copied from --- src/diffusers/pipelines/flux2/pipeline_flux2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 3624d7e7d27a..d38590a52444 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -78,6 +78,8 @@ def format_text_input(prompts: List[str], system_message: str = None): for prompt in cleaned_txt ] + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, base_seq_len: int = 256, From 980bcc866c2cd3e9601572a0e5c0b822640b87b0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Nov 2025 10:46:43 +0530 Subject: [PATCH 42/63] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: YiYi Xu Co-authored-by: apolinário --- .../pipelines/flux2/image_processor.py | 10 +++------- src/diffusers/pipelines/flux2/pipeline_flux2.py | 17 ++++++++--------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/flux2/image_processor.py b/src/diffusers/pipelines/flux2/image_processor.py index 2d088f875b3f..e7adda5d53db 100644 --- a/src/diffusers/pipelines/flux2/image_processor.py +++ b/src/diffusers/pipelines/flux2/image_processor.py @@ -29,18 +29,14 @@ class Flux2ImageProcessor(VaeImageProcessor): do_resize (`bool`, *optional*, defaults to `True`): Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. - vae_scale_factor (`int`, *optional*, defaults to `8`): + vae_scale_factor (`int`, *optional*, defaults to `16`): VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. - vae_latent_channels (`int`, *optional*, defaults to `16`): + vae_latent_channels (`int`, *optional*, defaults to `32`): VAE latent channels. - spatial_patch_size (`Tuple[int, int]`, *optional*, defaults to `(2, 2)`): - The spatial patch size used by the diffusion transformer. For Wan models, this is typically (2, 2). 
- resample (`str`, *optional*, defaults to `lanczos`): - Resampling filter to use when resizing the image. do_normalize (`bool`, *optional*, defaults to `True`): Whether to normalize the image to [-1,1]. - do_convert_rgb (`bool`, *optional*, defaults to be `False`): + do_convert_rgb (`bool`, *optional*, defaults to be `True`): Whether to convert the images to RGB format. """ diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index d38590a52444..3b1a1091243e 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -211,8 +211,7 @@ def __init__( self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) self.tokenizer_max_length = 512 self.default_sample_size = 128 - self.system_message = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object -attribution and actions without speculation.""" + self.system_message = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.""" @staticmethod def _get_mistral_3_small_prompt_embeds( @@ -222,8 +221,7 @@ def _get_mistral_3_small_prompt_embeds( dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, max_sequence_length: int = 512, - system_message: str = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object -attribution and actions without speculation.""", + system_message: str = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.""", hidden_states_layers: List[int] = (10, 20, 30), ): dtype = text_encoder.dtype if dtype is None else dtype @@ -320,7 +318,6 @@ def _prepare_latent_ids( return latent_ids - # YiYi TODO: can optimize a bit @staticmethod def _prepare_image_ids( image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] @@ -709,8 +706,6 @@ def __call__( returning a tuple, the first element is a list with the generated images. """ - height = height or self.default_sample_size * self.vae_scale_factor - width = width or self.default_sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -756,7 +751,6 @@ def __call__( self.image_processor.check_image_input(img) condition_images = [] - condition_image_sizes = [] for img in image: image_width, image_height = img.size if image_width * image_height > 1024 * 1024: @@ -768,7 +762,12 @@ def __call__( image_height = (image_height // multiple_of) * multiple_of img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode = "crop") condition_images.append(img) - condition_image_sizes.append((image_width, image_height)) + height = height or image_height + width = width or image_width + + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor # 5. 
prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 From e1d46ce4586a1d63a5c9121b70bbb5b2cc896397 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 18 Nov 2025 11:01:38 +0530 Subject: [PATCH 43/63] up --- scripts/convert_flux2_to_diffusers.py | 22 ++-- src/diffusers/__init__.py | 4 +- src/diffusers/loaders/__init__.py | 2 +- src/diffusers/models/__init__.py | 2 +- .../autoencoders/autoencoder_kl_flux2.py | 31 ++++- src/diffusers/models/embeddings.py | 2 +- .../pipelines/flux2/image_processor.py | 14 +-- .../pipelines/flux2/pipeline_flux2.py | 115 +++++++----------- tests/lora/test_lora_layers_flux2.py | 11 +- .../test_models_transformer_flux2.py | 6 +- tests/pipelines/flux2/test_pipeline_flux2.py | 26 ++-- 11 files changed, 103 insertions(+), 132 deletions(-) diff --git a/scripts/convert_flux2_to_diffusers.py b/scripts/convert_flux2_to_diffusers.py index 5d9e3f68891c..2973913fa215 100644 --- a/scripts/convert_flux2_to_diffusers.py +++ b/scripts/convert_flux2_to_diffusers.py @@ -6,11 +6,9 @@ import torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from transformers import AutoProcessor, Mistral3ForConditionalGeneration, GenerationConfig +from transformers import AutoProcessor, GenerationConfig, Mistral3ForConditionalGeneration -from diffusers import ( - AutoencoderKLFlux2, Flux2Pipeline, Flux2Transformer2DModel, FlowMatchEulerDiscreteScheduler -) +from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2Pipeline, Flux2Transformer2DModel from diffusers.utils.import_utils import is_accelerate_available @@ -70,7 +68,6 @@ def load_original_checkpoint(args, filename): return original_state_dict - DIFFUSERS_VAE_TO_FLUX2_MAPPING = { "encoder.conv_in.weight": "encoder.conv_in.weight", "encoder.conv_in.bias": "encoder.conv_in.bias", @@ -90,7 +87,8 @@ def load_original_checkpoint(args, filename): "post_quant_conv.bias": "decoder.post_quant_conv.bias", "bn.running_mean": "bn.running_mean", "bn.running_var": "bn.running_var", - } +} + # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear def conv_attn_to_linear(checkpoint): @@ -104,6 +102,7 @@ def conv_attn_to_linear(checkpoint): if checkpoint[key].ndim > 2: checkpoint[key] = checkpoint[key][:, :, 0] + def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): for ldm_key in keys: diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut") @@ -462,16 +461,15 @@ def main(args): text_encoder_id, generation_config=generate_config, torch_dtype=torch.bfloat16 ) tokenizer = AutoProcessor.from_pretrained(tokenizer_id) - scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="scheduler") + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + "black-forest-labs/FLUX.1-dev", subfolder="scheduler" + ) pipe = Flux2Pipeline( - vae=vae, - transformer=transformer, - text_encoder=text_encoder, - tokenizer=tokenizer, - scheduler=scheduler + vae=vae, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler ) pipe.save_pretrained(args.output_path) + if __name__ == "__main__": main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f02d0852c972..25e9000ee8bb 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -183,10 +183,10 @@ "AuraFlowTransformer2DModel", "AutoencoderDC", "AutoencoderKL", - "AutoencoderKLFlux2", 
"AutoencoderKLAllegro", "AutoencoderKLCogVideoX", "AutoencoderKLCosmos", + "AutoencoderKLFlux2", "AutoencoderKLHunyuanImage", "AutoencoderKLHunyuanImageRefiner", "AutoencoderKLHunyuanVideo", @@ -458,6 +458,7 @@ "EasyAnimateControlPipeline", "EasyAnimateInpaintPipeline", "EasyAnimatePipeline", + "Flux2Pipeline", "FluxControlImg2ImgPipeline", "FluxControlInpaintPipeline", "FluxControlNetImg2ImgPipeline", @@ -471,7 +472,6 @@ "FluxKontextPipeline", "FluxPipeline", "FluxPriorReduxPipeline", - "Flux2Pipeline", "HiDreamImagePipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index e7a7109b3f6c..4e3eb009533a 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -114,6 +114,7 @@ def text_encoder_attn_modules(text_encoder): AuraFlowLoraLoaderMixin, CogVideoXLoraLoaderMixin, CogView4LoraLoaderMixin, + Flux2LoraLoaderMixin, FluxLoraLoaderMixin, HiDreamImageLoraLoaderMixin, HunyuanVideoLoraLoaderMixin, @@ -129,7 +130,6 @@ def text_encoder_attn_modules(text_encoder): StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, WanLoraLoaderMixin, - Flux2LoraLoaderMixin, ) from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 2ff00f614040..dd3104d4501a 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -35,6 +35,7 @@ _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"] _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] _import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"] + _import_structure["autoencoders.autoencoder_kl_flux2"] = ["AutoencoderKLFlux2"] _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"] _import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"] _import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"] @@ -44,7 +45,6 @@ _import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"] - _import_structure["autoencoders.autoencoder_kl_flux2"] = ["AutoencoderKLFlux2"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py b/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py index b800fbeeccc5..7b572f82ad67 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py @@ -72,14 +72,29 @@ def __init__( self, in_channels: int = 3, out_channels: int = 3, - down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"), - up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"), - block_out_channels: Tuple[int, ...] = (128, 256, 512, 512,), + down_block_types: Tuple[str, ...] 
= ( + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ), + up_block_types: Tuple[str, ...] = ( + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + ), + block_out_channels: Tuple[int, ...] = ( + 128, + 256, + 512, + 512, + ), layers_per_block: int = 2, act_fn: str = "silu", latent_channels: int = 32, norm_num_groups: int = 32, - sample_size: int = 1024, # YiYi notes: not sure + sample_size: int = 1024, # YiYi notes: not sure force_upcast: bool = True, use_quant_conv: bool = True, use_post_quant_conv: bool = True, @@ -118,7 +133,13 @@ def __init__( self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None - self.bn = nn.BatchNorm2d(math.prod(patch_size) * latent_channels, eps=batch_norm_eps, momentum=batch_norm_momentum, affine=False, track_running_stats=True) + self.bn = nn.BatchNorm2d( + math.prod(patch_size) * latent_channels, + eps=batch_norm_eps, + momentum=batch_norm_momentum, + affine=False, + track_running_stats=True, + ) self.use_slicing = False self.use_tiling = False diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index d630fd2c6c64..37fc412adcc3 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1228,7 +1228,7 @@ def apply_rotary_emb( x_rotated = torch.cat([-x_imag, x_real], dim=-1) else: raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") - + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out diff --git a/src/diffusers/pipelines/flux2/image_processor.py b/src/diffusers/pipelines/flux2/image_processor.py index e7adda5d53db..91c8f875dd1d 100644 --- a/src/diffusers/pipelines/flux2/image_processor.py +++ b/src/diffusers/pipelines/flux2/image_processor.py @@ -57,13 +57,9 @@ def __init__( do_convert_rgb=do_convert_rgb, ) - @staticmethod def check_image_input( - image: PIL.Image.Image, - max_aspect_ratio: int = 8, - min_side_length: int = 64, - max_area: int = 1024 * 1024 + image: PIL.Image.Image, max_aspect_ratio: int = 8, min_side_length: int = 64, max_area: int = 1024 * 1024 ) -> PIL.Image.Image: """ Check if image meets minimum size and aspect ratio requirements. @@ -88,8 +84,7 @@ def check_image_input( # Check minimum dimensions if width < min_side_length or height < min_side_length: raise ValueError( - f"Image too small: {width}×{height}. " - f"Both dimensions must be at least {min_side_length}px" + f"Image too small: {width}×{height}. 
Both dimensions must be at least {min_side_length}px" ) # Check aspect ratio @@ -100,21 +95,18 @@ def check_image_input( f"Maximum allowed ratio is {max_aspect_ratio}:1" ) - return image - @staticmethod def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> Tuple[int, int]: image_width, image_height = image.size - scale = math.sqrt(target_area/ (image_width * image_height)) + scale = math.sqrt(target_area / (image_width * image_height)) width = int(image_width * scale) height = int(image_height * scale) return image.resize((width, height), PIL.Image.Resampling.LANCZOS) - def _resize_and_crop( self, image: PIL.Image.Image, diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 3b1a1091243e..1ddd8d9c7d82 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect -from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL @@ -152,6 +152,7 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" @@ -165,6 +166,7 @@ def retrieve_latents( else: raise AttributeError("Could not access latents of provided encoder_output") + class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin): r""" The Flux2 pipeline for text-to-image generation. @@ -211,8 +213,11 @@ def __init__( self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) self.tokenizer_max_length = 512 self.default_sample_size = 128 - self.system_message = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.""" - + + # fmt: off + self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." + # fmt: on + @staticmethod def _get_mistral_3_small_prompt_embeds( text_encoder: Mistral3ForConditionalGeneration, @@ -221,7 +226,9 @@ def _get_mistral_3_small_prompt_embeds( dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, max_sequence_length: int = 512, - system_message: str = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.""", + # fmt: off + system_message: str = "You are an AI that reasons about image descriptions. 
You give structured responses focusing on object relationships, object attribution and actions without speculation.", + # fmt: on hidden_states_layers: List[int] = (10, 20, 30), ): dtype = text_encoder.dtype if dtype is None else dtype @@ -247,7 +254,7 @@ def _get_mistral_3_small_prompt_embeds( # Move to device input_ids = inputs["input_ids"].to(device) attention_mask = inputs["attention_mask"].to(device) - + # Forward pass through the model output = text_encoder( input_ids=input_ids, @@ -265,10 +272,9 @@ def _get_mistral_3_small_prompt_embeds( return prompt_embeds - @staticmethod def _prepare_text_ids( - x: torch.Tensor, # (B, L, D) or (L, D) + x: torch.Tensor, # (B, L, D) or (L, D) t_coord: Optional[torch.Tensor] = None, ): B, L, _ = x.shape @@ -285,7 +291,6 @@ def _prepare_text_ids( return torch.stack(out_ids) - @staticmethod def _prepare_latent_ids( latents: torch.Tensor, # (B, C, H, W) @@ -299,8 +304,8 @@ def _prepare_latent_ids( Returns: torch.Tensor: - Position IDs tensor of shape (B, H*W, 4) - All batches share the same coordinate structure: T=0, H=[0..H-1], W=[0..W-1], L=0 + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 """ batch_size, _, height, width = latents.shape @@ -320,28 +325,26 @@ def _prepare_latent_ids( @staticmethod def _prepare_image_ids( - image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] - scale: int = 10 + image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, ): - r""" Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. - This function creates a unique coordinate for every pixel/patch across all - input latent with different dimensions. + This function creates a unique coordinate for every pixel/patch across all input latent with different + dimensions. Args: image_latents (List[torch.Tensor]): A list of image latent feature tensors, typically of shape (C, H, W). scale (int, optional): - A factor used to define the time separation (T-coordinate) between latents. - T-coordinate for the i-th latent is: 'scale + scale * i'. Defaults to 10. + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th + latent is: 'scale + scale * i'. Defaults to 10. Returns: torch.Tensor: - The combined coordinate tensor. - Shape: (1, N_total, 4) - Where N_total is the sum of (H * W) for all input latents. + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all + input latents. Coordinate Components (Dimension 4): - T (Time): The unique index indicating which latent image the coordinate belongs to. 
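As a concrete illustration of the coordinate scheme documented above (using the default scale of 10): two reference latents with spatial grids of 32×32 and 16×16 receive T coordinates 10 and 20 respectively, and the combined ids tensor described by the docstring has shape (1, 32*32 + 16*16, 4) = (1, 1280, 4), with H and W enumerating each latent's own grid.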
@@ -359,7 +362,6 @@ def _prepare_image_ids( image_latent_ids = [] for x, t in zip(image_latents, t_coords): - x = x.squeeze(0) _, height, width = x.shape @@ -371,7 +373,6 @@ def _prepare_image_ids( return image_latent_ids - @staticmethod def _patchify_latents(latents): batch_size, num_channels_latents, height, width = latents.shape @@ -383,9 +384,9 @@ def _patchify_latents(latents): @staticmethod def _unpatchify_latents(latents): batch_size, num_channels_latents, height, width = latents.shape - latents = latents.reshape(batch_size, num_channels_latents // (2 * 2) , 2, 2, height, width) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) latents = latents.permute(0, 1, 4, 2, 5, 3) - latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height *2 , width *2) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) return latents @staticmethod @@ -399,7 +400,6 @@ def _pack_latents(latents): return latents - @staticmethod def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: """ @@ -414,7 +414,7 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch h = torch.max(h_ids) + 1 w = torch.max(w_ids) + 1 - flat_ids =h_ids * w + w_ids + flat_ids = h_ids * w + w_ids out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) @@ -426,7 +426,6 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch return torch.stack(x_list, dim=0) - def encode_prompt( self, prompt: Union[str, List[str]], @@ -462,25 +461,19 @@ def encode_prompt( text_ids = text_ids.to(device) return prompt_embeds, text_ids - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): - if image.ndim != 4: raise ValueError(f"Expected image dims 4, got {image.ndim}.") image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") image_latents = self._patchify_latents(image_latents) - latents_bn_mean = ( - self.vae.bn.running_mean.view(1, -1, 1, 1) - .to(image_latents.device, image_latents.dtype) - ) + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) image_latents = (image_latents - latents_bn_mean) / latents_bn_std return image_latents - def prepare_latents( self, batch_size, @@ -492,13 +485,12 @@ def prepare_latents( generator: torch.Generator, latents: Optional[torch.Tensor] = None, ): - # VAE applies 8x compression on images but we must also account for packing which requires # latent height and width to be divisible by 2. 
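        # (As a worked example, assuming the default vae_scale_factor of 8: a requested height of
        # 1024 yields a latent height of 2 * (1024 // 16) = 128, and because the latents are packed
        # into 2x2 patches below, the per-sample latent grid becomes 64 x 64 with 4x the channels.)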
height = 2 * (int(height) // (self.vae_scale_factor * 2)) width = 2 * (int(width) // (self.vae_scale_factor * 2)) - shape = (batch_size, num_latents_channels * 4, height//2, width//2) + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -512,10 +504,9 @@ def prepare_latents( latent_ids = self._prepare_latent_ids(latents) latent_ids = latent_ids.to(device) - latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] return latents, latent_ids - def prepare_image_latents( self, images: List[torch.Tensor], @@ -528,7 +519,7 @@ def prepare_image_latents( for image in images: image = image.to(device=device, dtype=dtype) imagge_latent = self._encode_vae_image(image=image, generator=generator) - image_latents.append(imagge_latent) # (1, 128, 32, 32) + image_latents.append(imagge_latent) # (1, 128, 32, 32) image_latent_ids = self._prepare_image_ids(image_latents) @@ -550,7 +541,6 @@ def prepare_image_latents( return image_latents, image_latent_ids - def check_inputs( self, prompt, @@ -583,7 +573,6 @@ def check_inputs( elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - @property def guidance_scale(self): return self._guidance_scale @@ -604,7 +593,6 @@ def current_timestep(self): def interrupt(self): return self._interrupt - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -641,10 +629,6 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is - not greater than `1`). guidance_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. @@ -674,10 +658,6 @@ def __call__( prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -697,16 +677,17 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + text_encoder_out_layers (`Tuple[int]`): + Layer indices to use in the `text_encoder` to derive the final prompt embeddings. 
Examples: Returns: - [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: - [`~pipelines.flux2.Flux2PipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is a list with the generated images. + [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. """ - # 1. Check inputs. Raise error if not correct self.check_inputs( prompt=prompt, @@ -741,7 +722,7 @@ def __call__( text_encoder_out_layers=text_encoder_out_layers, ) - # 4. process images + # 4. process images if image is not None and not isinstance(image, list): image = [image] @@ -760,12 +741,11 @@ def __call__( multiple_of = self.vae_scale_factor * 2 image_width = (image_width // multiple_of) * multiple_of image_height = (image_height // multiple_of) * multiple_of - img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode = "crop") + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") condition_images.append(img) height = height or image_height width = width or image_width - height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor @@ -819,7 +799,6 @@ def __call__( guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) guidance = guidance.expand(latents.shape[0]) - # 7. Denoising loop # We set the index here to remove DtoH sync, helpful especially during compilation. # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 @@ -838,21 +817,20 @@ def __call__( if image_latents is not None: latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) - latent_image_ids = torch.cat([latent_ids, image_latent_ids],dim=1) - + latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) noise_pred = self.transformer( - hidden_states=latent_model_input, # (B, image_seq_len, C) + hidden_states=latent_model_input, # (B, image_seq_len, C) timestep=timestep / 1000, guidance=guidance, encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, #B, text_seq_len, 4 - img_ids=latent_image_ids, #B, image_seq_len, 4 + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=latent_image_ids, # B, image_seq_len, 4 joint_attention_kwargs=self._attention_kwargs, return_dict=False, )[0] - noise_pred = noise_pred[:, : latents.size(1):] + noise_pred = noise_pred[:, : latents.size(1) :] # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype @@ -887,13 +865,10 @@ def __call__( torch.save({"pred": latents}, "pred_d.pt") latents = self._unpack_latents_with_ids(latents, latent_ids) - latents_bn_mean = ( - self.vae.bn.running_mean.view(1, -1, 1, 1) - .to(latents.device, latents.dtype) + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype ) - latents_bn_std = torch.sqrt( - self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps - ).to(latents.device, latents.dtype) latents = latents * latents_bn_std + latents_bn_mean latents = self._unpatchify_latents(latents) diff --git a/tests/lora/test_lora_layers_flux2.py b/tests/lora/test_lora_layers_flux2.py index 67c3975ee177..768d10fec72e 100644 --- 
a/tests/lora/test_lora_layers_flux2.py +++ b/tests/lora/test_lora_layers_flux2.py @@ -18,15 +18,10 @@ import torch from transformers import AutoProcessor, Mistral3ForConditionalGeneration -from diffusers import ( - Flux2Pipeline, - Flux2Transformer2DModel, - FlowMatchEulerDiscreteScheduler, - AutoencoderKLFlux2 -) +from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2Pipeline, Flux2Transformer2DModel from ..testing_utils import floats_tensor, require_peft_backend - + sys.path.append(".") @@ -92,7 +87,7 @@ def get_dummy_inputs(self, with_generator=True): "width": 8, "max_sequence_length": 8, "output_type": "np", - "text_encoder_out_layers": (1,) + "text_encoder_out_layers": (1,), } if with_generator: pipeline_inputs.update({"generator": generator}) diff --git a/tests/models/transformers/test_models_transformer_flux2.py b/tests/models/transformers/test_models_transformer_flux2.py index 822eae93b9ad..316d5fa770bb 100644 --- a/tests/models/transformers/test_models_transformer_flux2.py +++ b/tests/models/transformers/test_models_transformer_flux2.py @@ -18,10 +18,8 @@ import torch from diffusers import Flux2Transformer2DModel, attention_backend -from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0 -from diffusers.models.embeddings import ImageProjection -from ...testing_utils import enable_full_determinism, is_peft_available, torch_device +from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin @@ -140,7 +138,7 @@ def test_flux2_consistency(self, seed=0): def test_gradient_checkpointing_is_applied(self): expected_set = {"Flux2Transformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - + class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = Flux2Transformer2DModel diff --git a/tests/pipelines/flux2/test_pipeline_flux2.py b/tests/pipelines/flux2/test_pipeline_flux2.py index 3bda5ba579fd..f6dc52dc11ac 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2.py +++ b/tests/pipelines/flux2/test_pipeline_flux2.py @@ -1,32 +1,22 @@ import unittest -import pytest + import numpy as np +import pytest import torch -from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Mistral3Config +from transformers import AutoProcessor, Mistral3Config, Mistral3ForConditionalGeneration from diffusers import ( AutoencoderKLFlux2, - Flux2Pipeline, - FasterCacheConfig, FlowMatchEulerDiscreteScheduler, + Flux2Pipeline, Flux2Transformer2DModel, ) from ...testing_utils import ( - Expectations, - backend_empty_cache, - nightly, - numpy_cosine_similarity_distance, - require_big_accelerator, - slow, torch_device, ) from ..test_pipelines_common import ( - FasterCacheTesterMixin, - FirstBlockCacheTesterMixin, - FluxIPAdapterTesterMixin, PipelineTesterMixin, - PyramidAttentionBroadcastTesterMixin, check_qkv_fused_layers_exist, ) @@ -93,11 +83,13 @@ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): ) torch.manual_seed(0) text_encoder = Mistral3ForConditionalGeneration(config) - tokenizer = AutoProcessor.from_pretrained("hf-internal-testing/Mistral-Small-3.1-24B-Instruct-2503-only-processor") + tokenizer = AutoProcessor.from_pretrained( + "hf-internal-testing/Mistral-Small-3.1-24B-Instruct-2503-only-processor" + ) torch.manual_seed(0) vae = AutoencoderKLFlux2( - sample_size=32, + sample_size=32, in_channels=3, out_channels=3, 
down_block_types=("DownEncoderBlock2D",), @@ -135,7 +127,7 @@ def get_dummy_inputs(self, device, seed=0): "width": 8, "max_sequence_length": 8, "output_type": "np", - "text_encoder_out_layers": (1,) + "text_encoder_out_layers": (1,), } return inputs From de993441916bdc8ffe878a5e3f0d7f3cc8ad20f7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 18 Nov 2025 11:08:01 +0530 Subject: [PATCH 44/63] up --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index b48b26942f9f..9ccf5d0edcca 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -408,6 +408,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLFlux2(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLHunyuanImage(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 19f6c0f58440..d360a0e34f3e 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -827,6 +827,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Flux2Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class FluxControlImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 739c31b85205990adc5f7c61ea960589a209d6a6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 18 Nov 2025 11:16:48 +0530 Subject: [PATCH 45/63] up --- docs/source/en/_toctree.yml | 4 +++ docs/source/en/api/loaders/lora.md | 3 +- .../source/en/api/models/flux2_transformer.md | 19 ++++++++++ docs/source/en/api/pipelines/flux2.md | 35 +++++++++++++++++++ 4 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 docs/source/en/api/models/flux2_transformer.md create mode 100644 docs/source/en/api/pipelines/flux2.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 55fe2a9a379f..95f109e76bea 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -345,6 +345,8 @@ title: DiTTransformer2DModel - local: api/models/easyanimate_transformer3d title: EasyAnimateTransformer3DModel + - local: api/models/flux2_transformer + title: Flux2Transformer2DModel - local: api/models/flux_transformer title: FluxTransformer2DModel - local: api/models/hidream_image_transformer @@ -519,6 +521,8 @@ title: EasyAnimate - local: api/pipelines/flux title: Flux + - local: api/pipelines/flux2 + title: Flux2 - local: api/pipelines/control_flux_inpaint title: FluxControlInpaint - local: api/pipelines/hidream diff --git a/docs/source/en/api/loaders/lora.md 
b/docs/source/en/api/loaders/lora.md index 8e0326e0c334..e96486fa8081 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -30,7 +30,8 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi - [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4). - [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`]. - [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream) -- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen) +- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen). +- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2). - [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more. > [!TIP] diff --git a/docs/source/en/api/models/flux2_transformer.md b/docs/source/en/api/models/flux2_transformer.md new file mode 100644 index 000000000000..55c7336d1059 --- /dev/null +++ b/docs/source/en/api/models/flux2_transformer.md @@ -0,0 +1,19 @@ + + +# Flux2Transformer2DModel + +A Transformer model for image-like data from [Flux2] (TODO). + +## Flux2Transformer2DModel + +[[autodoc]] Flux2Transformer2DModel diff --git a/docs/source/en/api/pipelines/flux2.md b/docs/source/en/api/pipelines/flux2.md new file mode 100644 index 000000000000..903a1d36d2a2 --- /dev/null +++ b/docs/source/en/api/pipelines/flux2.md @@ -0,0 +1,35 @@ + + +# Flux2 + +
+ LoRA + MPS +
+ +Flux2 TODO + +Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](TODO). + +> [!TIP] +> Flux2 can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. +> +> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs. + +TODO checkpoints + +## Flux2Pipeline + +[[autodoc]] Flux2Pipeline + - all + - __call__ \ No newline at end of file From ded81a9c528374fdf54aca72656f98dc833ede39 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 18 Nov 2025 11:19:18 +0530 Subject: [PATCH 46/63] up --- docs/source/en/api/pipelines/flux2.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/pipelines/flux2.md b/docs/source/en/api/pipelines/flux2.md index 903a1d36d2a2..4ed5393846a7 100644 --- a/docs/source/en/api/pipelines/flux2.md +++ b/docs/source/en/api/pipelines/flux2.md @@ -17,9 +17,9 @@ specific language governing permissions and limitations under the License. MPS -Flux2 TODO +TODO -Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](TODO). +Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here] (TODO). > [!TIP] > Flux2 can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. 
From c0086ea14bdbb481bd8867896dcd62345fe8fe4e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 18 Nov 2025 11:24:03 +0530 Subject: [PATCH 47/63] up --- docs/source/en/api/loaders/lora.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index e96486fa8081..9f6ee224e4dd 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -57,6 +57,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi [[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin +## Flux2LoraLoaderMixin + +[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin + ## CogVideoXLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin From 4f1f67a89972efcd440268819485b65563884ee5 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 18 Nov 2025 09:27:02 +0100 Subject: [PATCH 48/63] Refactor Flux2Attention into separate classes for double stream and single stream attention --- .../models/transformers/transformer_flux2.py | 253 ++++++++++++++---- tests/pipelines/flux2/test_pipeline_flux2.py | 1 - 2 files changed, 194 insertions(+), 60 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 65ce4921b1b9..d92c23a66327 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -25,6 +25,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin +from ..attention_processor import AttentionProcessor from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import ( @@ -128,31 +129,10 @@ def __call__( encoder_hidden_states: torch.Tensor = None, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, - mlp_hidden_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if attn.parallel_proj_in: - hidden_states = attn.to_qkv_mlp_proj(hidden_states) - qkv, mlp_hidden_states = torch.split( - hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1 - ) - query, key, value = qkv.chunk(3, dim=-1) - mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states) - - # Get encoder QKV, if available - encoder_query = encoder_key = encoder_value = None - if encoder_hidden_states is not None: - if hasattr(attn, "to_added_qkv"): - encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk( - 3, dim=-1 - ) - elif attn.added_kv_proj_dim is not None: - encoder_query = attn.add_q_proj(encoder_hidden_states) - encoder_key = attn.add_k_proj(encoder_hidden_states) - encoder_value = attn.add_v_proj(encoder_hidden_states) - else: - query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( - attn, hidden_states, encoder_hidden_states - ) + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) query = query.unflatten(-1, (attn.heads, -1)) key = key.unflatten(-1, (attn.heads, -1)) @@ -194,12 +174,8 @@ def __call__( ) encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - if attn.parallel_proj_out: - hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1) - hidden_states = attn.to_out(hidden_states) - else: - hidden_states = 
attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) if encoder_hidden_states is not None: return hidden_states, encoder_hidden_states @@ -224,10 +200,6 @@ def __init__( eps: float = 1e-5, out_dim: int = None, elementwise_affine: bool = True, - parallel_proj_in: bool = False, - parallel_proj_out: bool = False, - mlp_ratio: float = 4.0, - mlp_mult_factor: int = 2, processor=None, ): super().__init__() @@ -244,32 +216,17 @@ def __init__( self.added_kv_proj_dim = added_kv_proj_dim self.added_proj_bias = added_proj_bias - self.parallel_proj_in = parallel_proj_in - self.parallel_proj_out = parallel_proj_out - self.mlp_ratio = mlp_ratio - self.mlp_hidden_dim = int(query_dim * self.mlp_ratio) - self.mlp_mult_factor = mlp_mult_factor - - if self.parallel_proj_in: - self.to_qkv_mlp_proj = torch.nn.Linear( - self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias - ) - self.mlp_act_fn = Flux2SwiGLU() - else: - self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) # QK Norm self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) - if self.parallel_proj_out: - self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias) - else: - self.to_out = torch.nn.ModuleList([]) - self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) - self.to_out.append(torch.nn.Dropout(dropout)) + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) if added_kv_proj_dim is not None: self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) @@ -301,6 +258,186 @@ def forward( return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) +class Flux2ParallelSelfAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") + + def __call__( + self, + attn: "Flux2ParallelSelfAttention", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Parallel in (QKV + MLP in) projection + hidden_states = attn.to_qkv_mlp_proj(hidden_states) + qkv, mlp_hidden_states = torch.split( + hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1 + ) + + # Handle the attention logic + query, key, value = qkv.chunk(3, dim=-1) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # Handle the feedforward (FF) logic + mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states) + + # Concatenate and parallel output projection + hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1) + hidden_states = attn.to_out(hidden_states) + + return hidden_states + + +# NOTE: we don't inherit from AttentionModuleMixin because fuse_projections doesn't make sense for this Attention +# subclass (as the QKV projections are always fused). This means that we end up copying some useful methods in that +# mixin to this class. Is there a cleaner way (e.g. modifying AttentionModuleMixin to allow child classes to turn off +# fuse_projections)? +class Flux2ParallelSelfAttention(torch.nn.Module): + """ + Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks. + + This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF) + input projections, and the attention output projections are fused to the FF output projections. See the + [ViT-22B paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block. 
+ """ + + _default_processor_cls = Flux2ParallelSelfAttnProcessor + _available_processors = [Flux2ParallelSelfAttnProcessor] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + mlp_ratio: float = 4.0, + mlp_mult_factor: int = 2, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.use_bias = bias + self.dropout = dropout + + self.mlp_ratio = mlp_ratio + self.mlp_hidden_dim = int(query_dim * self.mlp_ratio) + self.mlp_mult_factor = mlp_mult_factor + + # Fused QKV projections + MLP input projection + self.to_qkv_mlp_proj = torch.nn.Linear( + self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias + ) + self.mlp_act_fn = Flux2SwiGLU() + + # QK Norm + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + + # Fused attention output projection + MLP output projection + self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs) + + def set_processor(self, processor: AttentionProcessor) -> None: + """ + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + """ + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. 
+ """ + if not return_deprecated_lora: + return self.processor + + def set_attention_backend(self, backend: str): + from ..attention_dispatch import AttentionBackendName + + available_backends = {x.value for x in AttentionBackendName.__members__.values()} + if backend not in available_backends: + raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) + + backend = AttentionBackendName(backend.lower()) + self.processor._attention_backend = backend + + @maybe_allow_in_graph class Flux2SingleTransformerBlock(nn.Module): def __init__( @@ -319,7 +456,7 @@ def __init__( # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442) # for a visual depiction of this type of transformer block. - self.attn = Flux2Attention( + self.attn = Flux2ParallelSelfAttention( query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, @@ -327,11 +464,9 @@ def __init__( bias=bias, out_bias=bias, eps=eps, - parallel_proj_in=True, - parallel_proj_out=True, mlp_ratio=mlp_ratio, mlp_mult_factor=2, - processor=Flux2AttnProcessor(), + processor=Flux2ParallelSelfAttnProcessor(), ) def forward( diff --git a/tests/pipelines/flux2/test_pipeline_flux2.py b/tests/pipelines/flux2/test_pipeline_flux2.py index f6dc52dc11ac..c1e024fed3d6 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2.py +++ b/tests/pipelines/flux2/test_pipeline_flux2.py @@ -131,7 +131,6 @@ def get_dummy_inputs(self, device, seed=0): } return inputs - @pytest.mark.xfail(condition=True, reason="Flux2 uses parallel projections which are incompatible here.") def test_fused_qkv_projections(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() From b8e3760005b8ec6094d013beb4fe14ed3176f3cb Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 19 Nov 2025 03:11:43 +0100 Subject: [PATCH 49/63] Add _supports_qkv_fusion to AttentionModuleMixin to allow subclasses to disable QKV fusion --- src/diffusers/models/attention.py | 15 +++++++++++++-- tests/pipelines/test_pipelines_common.py | 2 +- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 5164cf311d3c..91227be9c71c 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -105,7 +105,7 @@ def fuse_qkv_projections(self): raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") for module in self.modules(): - if isinstance(module, AttentionModuleMixin): + if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion: module.fuse_projections() def unfuse_qkv_projections(self): @@ -114,13 +114,14 @@ def unfuse_qkv_projections(self): > [!WARNING] > This API is 🧪 experimental. """ for module in self.modules(): - if isinstance(module, AttentionModuleMixin): + if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion: module.unfuse_projections() class AttentionModuleMixin: _default_processor_cls = None _available_processors = [] + _supports_qkv_fusion = True fused_projections = False def set_processor(self, processor: AttentionProcessor) -> None: @@ -248,6 +249,11 @@ def fuse_projections(self): """ Fuse the query, key, and value projections into a single projection for efficiency. 
""" + # Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections are + # always fused) + if not self._supports_qkv_fusion: + return + # Skip if already fused if getattr(self, "fused_projections", False): return @@ -307,6 +313,11 @@ def unfuse_projections(self): """ Unfuse the query, key, and value projections back to separate projections. """ + # Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections are + # always fused) + if not self._supports_qkv_fusion: + return + # Skip if not fused if not getattr(self, "fused_projections", False): return diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 2af4ad0314c3..5757e34ad366 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -103,7 +103,7 @@ def check_qkv_fusion_processors_exist(model): def check_qkv_fused_layers_exist(model, layer_names): is_fused_submodules = [] for submodule in model.modules(): - if not isinstance(submodule, AttentionModuleMixin): + if not isinstance(submodule, AttentionModuleMixin) or not submodule._supports_qkv_fusion: continue is_fused_attribute_set = submodule.fused_projections is_fused_layer = True From 24159acc18b6a185502ff67c4d54976b2fdeb86f Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 19 Nov 2025 03:13:05 +0100 Subject: [PATCH 50/63] Have Flux2ParallelSelfAttention inherit from AttentionModuleMixin with _supports_qkv_fusion=False --- .../models/transformers/transformer_flux2.py | 53 ++----------------- 1 file changed, 3 insertions(+), 50 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index d92c23a66327..93c8fa9bdae2 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -25,7 +25,6 @@ from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin -from ..attention_processor import AttentionProcessor from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import ( @@ -314,11 +313,7 @@ def __call__( return hidden_states -# NOTE: we don't inherit from AttentionModuleMixin because fuse_projections doesn't make sense for this Attention -# subclass (as the QKV projections are always fused). This means that we end up copying some useful methods in that -# mixin to this class. Is there a cleaner way (e.g. modifying AttentionModuleMixin to allow child classes to turn off -# fuse_projections)? -class Flux2ParallelSelfAttention(torch.nn.Module): +class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin): """ Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks. 
@@ -329,6 +324,8 @@ class Flux2ParallelSelfAttention(torch.nn.Module): _default_processor_cls = Flux2ParallelSelfAttnProcessor _available_processors = [Flux2ParallelSelfAttnProcessor] + # Does not support QKV fusion as the QKV projections are always fused + _supports_qkv_fusion = False def __init__( self, @@ -393,50 +390,6 @@ def forward( kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs) - def set_processor(self, processor: AttentionProcessor) -> None: - """ - Set the attention processor to use. - - Args: - processor (`AttnProcessor`): - The attention processor to use. - """ - # if current processor is in `self._modules` and if passed `processor` is not, we need to - # pop `processor` from `self._modules` - if ( - hasattr(self, "processor") - and isinstance(self.processor, torch.nn.Module) - and not isinstance(processor, torch.nn.Module) - ): - logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") - self._modules.pop("processor") - - self.processor = processor - - def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": - """ - Get the attention processor in use. - - Args: - return_deprecated_lora (`bool`, *optional*, defaults to `False`): - Set to `True` to return the deprecated LoRA attention processor. - - Returns: - "AttentionProcessor": The attention processor in use. - """ - if not return_deprecated_lora: - return self.processor - - def set_attention_backend(self, backend: str): - from ..attention_dispatch import AttentionBackendName - - available_backends = {x.value for x in AttentionBackendName.__members__.values()} - if backend not in available_backends: - raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) - - backend = AttentionBackendName(backend.lower()) - self.processor._attention_backend = backend - @maybe_allow_in_graph class Flux2SingleTransformerBlock(nn.Module): From cb21400cc167724164b7a232731e45e6ade82877 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 19 Nov 2025 05:56:40 +0100 Subject: [PATCH 51/63] Log debug message when calling fuse_projections on a AttentionModuleMixin subclass that does not support QKV fusion --- src/diffusers/models/attention.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 91227be9c71c..186bb60ea28e 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -252,6 +252,9 @@ def fuse_projections(self): # Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections are # always fused) if not self._supports_qkv_fusion: + logger.debug( + f"{self.__class__.__name__} does not support fusing QKV projections, so `fuse_projections` will no-op." + ) return # Skip if already fused From ba52a591dbe358123747a0c9ea84ee973a819a80 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 19 Nov 2025 06:02:15 +0100 Subject: [PATCH 52/63] Address review comments --- src/diffusers/models/attention.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 186bb60ea28e..8b583d1a1cce 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -249,8 +249,8 @@ def fuse_projections(self): """ Fuse the query, key, and value projections into a single projection for efficiency. 
""" - # Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections are - # always fused) + # Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2 + # single stream blocks are always fused) if not self._supports_qkv_fusion: logger.debug( f"{self.__class__.__name__} does not support fusing QKV projections, so `fuse_projections` will no-op." @@ -316,8 +316,8 @@ def unfuse_projections(self): """ Unfuse the query, key, and value projections back to separate projections. """ - # Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections are - # always fused) + # Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2 + # single stream blocks are always fused) if not self._supports_qkv_fusion: return From 697a43c9257eff0f065c3b1c8dc6097c9bd73b67 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 19 Nov 2025 20:00:44 +0530 Subject: [PATCH 53/63] Update src/diffusers/pipelines/flux2/pipeline_flux2.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/flux2/pipeline_flux2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 1ddd8d9c7d82..52823c8c9523 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -549,7 +549,7 @@ def check_inputs( prompt_embeds=None, callback_on_step_end_tensor_inputs=None, ): - if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + if height is not None and height % (self.vae_scale_factor * 2) != 0 or width is not None and width % (self.vae_scale_factor * 2) != 0: logger.warning( f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" ) From aba48ddeb6e70f476f2d443fa79338856e2c9f71 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 19 Nov 2025 20:02:10 +0530 Subject: [PATCH 54/63] up --- src/diffusers/models/transformers/transformer_flux2.py | 8 ++++---- src/diffusers/pipelines/flux2/pipeline_flux2.py | 7 ++++++- tests/pipelines/flux2/test_pipeline_flux2.py | 1 - 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 93c8fa9bdae2..18c7b1de294c 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -275,8 +275,8 @@ def __call__( # Parallel in (QKV + MLP in) projection hidden_states = attn.to_qkv_mlp_proj(hidden_states) qkv, mlp_hidden_states = torch.split( - hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1 - ) + hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1 + ) # Handle the attention logic query, key, value = qkv.chunk(3, dim=-1) @@ -318,8 +318,8 @@ class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin): Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks. This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF) - input projections, and the attention output projections are fused to the FF output projections. 
See the - [ViT-22B paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block. + input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B + paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block. """ _default_processor_cls = Flux2ParallelSelfAttnProcessor diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 52823c8c9523..d4196f3ecf78 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -549,7 +549,12 @@ def check_inputs( prompt_embeds=None, callback_on_step_end_tensor_inputs=None, ): - if height is not None and height % (self.vae_scale_factor * 2) != 0 or width is not None and width % (self.vae_scale_factor * 2) != 0: + if ( + height is not None + and height % (self.vae_scale_factor * 2) != 0 + or width is not None + and width % (self.vae_scale_factor * 2) != 0 + ): logger.warning( f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" ) diff --git a/tests/pipelines/flux2/test_pipeline_flux2.py b/tests/pipelines/flux2/test_pipeline_flux2.py index c1e024fed3d6..4404dbc51047 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2.py +++ b/tests/pipelines/flux2/test_pipeline_flux2.py @@ -1,7 +1,6 @@ import unittest import numpy as np -import pytest import torch from transformers import AutoProcessor, Mistral3Config, Mistral3ForConditionalGeneration From 963ed57b4c0c48030826a7996025bba785c1f847 Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Wed, 19 Nov 2025 06:37:53 -0800 Subject: [PATCH 55/63] Remove maybe_allow_in_graph decorators for Flux 2 transformer blocks (#12) --- src/diffusers/models/transformers/transformer_flux2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 18c7b1de294c..f25700a2e0e9 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -391,7 +391,6 @@ def forward( return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs) -@maybe_allow_in_graph class Flux2SingleTransformerBlock(nn.Module): def __init__( self, @@ -461,7 +460,6 @@ def forward( return hidden_states -@maybe_allow_in_graph class Flux2TransformerBlock(nn.Module): def __init__( self, From 997dfc24b0ed4fde4251e50378d7ae42ae6ac794 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 19 Nov 2025 20:09:42 +0530 Subject: [PATCH 56/63] up --- src/diffusers/models/transformers/transformer_flux2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index f25700a2e0e9..d2b3d8a733f3 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -22,7 +22,6 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers -from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, 
ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn From 8d8bb3dad791d64b5562c1432492dfa16eb6ff1d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 20 Nov 2025 17:21:56 +0530 Subject: [PATCH 57/63] support ostris loras. (#13) --- .../loaders/lora_conversion_utils.py | 86 +++++++++++++++++++ src/diffusers/loaders/lora_pipeline.py | 5 ++ 2 files changed, 91 insertions(+) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 2807416f97ae..dc7487e302c7 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2265,3 +2265,89 @@ def get_alpha_scales(down_weight, alpha_key): converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()} return converted_state_dict + + +def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict): + converted_state_dict = {} + + prefix = "diffusion_model." + original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()} + + num_double_layers = 8 + num_single_layers = 48 + lora_keys = ("lora_A", "lora_B") + attn_types = ("img_attn", "txt_attn") + + for sl in range(num_single_layers): + single_block_prefix = f"single_blocks.{sl}" + attn_prefix = f"single_transformer_blocks.{sl}.attn" + + for lora_key in lora_keys: + converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop( + f"{single_block_prefix}.linear1.{lora_key}.weight" + ) + + converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop( + f"{single_block_prefix}.linear2.{lora_key}.weight" + ) + + for dl in range(num_double_layers): + transformer_block_prefix = f"transformer_blocks.{dl}" + + for lora_key in lora_keys: + for attn_type in attn_types: + attn_prefix = f"{transformer_block_prefix}.attn" + qkv_key = f"double_blocks.{dl}.{attn_type}.qkv.{lora_key}.weight" + fused_qkv_weight = original_state_dict.pop(qkv_key) + + if lora_key == "lora_A": + diff_attn_proj_keys = ( + ["to_q", "to_k", "to_v"] + if attn_type == "img_attn" + else ["add_q_proj", "add_k_proj", "add_v_proj"] + ) + for proj_key in diff_attn_proj_keys: + converted_state_dict[f"{attn_prefix}.{proj_key}.{lora_key}.weight"] = torch.cat( + [fused_qkv_weight] + ) + else: + sample_q, sample_k, sample_v = torch.chunk(fused_qkv_weight, 3, dim=0) + + if attn_type == "img_attn": + converted_state_dict[f"{attn_prefix}.to_q.{lora_key}.weight"] = torch.cat([sample_q]) + converted_state_dict[f"{attn_prefix}.to_k.{lora_key}.weight"] = torch.cat([sample_k]) + converted_state_dict[f"{attn_prefix}.to_v.{lora_key}.weight"] = torch.cat([sample_v]) + else: + converted_state_dict[f"{attn_prefix}.add_q_proj.{lora_key}.weight"] = torch.cat([sample_q]) + converted_state_dict[f"{attn_prefix}.add_k_proj.{lora_key}.weight"] = torch.cat([sample_k]) + converted_state_dict[f"{attn_prefix}.add_v_proj.{lora_key}.weight"] = torch.cat([sample_v]) + + proj_mappings = [ + ("img_attn.proj", "attn.to_out.0"), + ("txt_attn.proj", "attn.to_add_out"), + ] + for org_proj, diff_proj in proj_mappings: + for lora_key in lora_keys: + original_key = f"double_blocks.{dl}.{org_proj}.{lora_key}.weight" + diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight" + converted_state_dict[diffusers_key] = original_state_dict.pop(original_key) + + mlp_mappings = [ + ("img_mlp.0", "ff.linear_in"), + ("img_mlp.2", "ff.linear_out"), + ("txt_mlp.0", 
"ff_context.linear_in"), + ("txt_mlp.2", "ff_context.linear_out"), + ] + for org_mlp, diff_mlp in mlp_mappings: + for lora_key in lora_keys: + original_key = f"double_blocks.{dl}.{org_mlp}.{lora_key}.weight" + diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight" + converted_state_dict[diffusers_key] = original_state_dict.pop(original_key) + + if len(original_state_dict) > 0: + raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.") + + for key in list(converted_state_dict.keys()): + converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) + + return converted_state_dict diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index a807ddb5a0d2..a1bb704b0626 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -45,6 +45,7 @@ _convert_hunyuan_video_lora_to_diffusers, _convert_kohya_flux_lora_to_diffusers, _convert_musubi_wan_lora_to_diffusers, + _convert_non_diffusers_flux2_lora_to_diffusers, _convert_non_diffusers_hidream_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, _convert_non_diffusers_ltxv_lora_to_diffusers, @@ -5144,6 +5145,10 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict) + if is_ai_toolkit: + state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict) + out = (state_dict, metadata) if return_lora_metadata else state_dict return out From 454cef8d9cd199af65fc2ee71d4ee050011cbb18 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 24 Nov 2025 07:47:59 +0530 Subject: [PATCH 58/63] up --- docs/source/en/api/models/flux2_transformer.md | 2 +- docs/source/en/api/pipelines/flux2.md | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/source/en/api/models/flux2_transformer.md b/docs/source/en/api/models/flux2_transformer.md index 55c7336d1059..c85681d2b011 100644 --- a/docs/source/en/api/models/flux2_transformer.md +++ b/docs/source/en/api/models/flux2_transformer.md @@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License. # Flux2Transformer2DModel -A Transformer model for image-like data from [Flux2] (TODO). +A Transformer model for image-like data from [Flux2](https://hf.co/black-forest-labs/FLUX.2-dev). ## Flux2Transformer2DModel diff --git a/docs/source/en/api/pipelines/flux2.md b/docs/source/en/api/pipelines/flux2.md index 4ed5393846a7..87f0fd92658e 100644 --- a/docs/source/en/api/pipelines/flux2.md +++ b/docs/source/en/api/pipelines/flux2.md @@ -17,17 +17,15 @@ specific language governing permissions and limitations under the License. MPS -TODO +Flux.2 is the recent series of image generation models from Black Forest Labs, preceded by the [Flux.1](./flux.md) series. It is an entirely new model with a new architecture and pre-training done from scratch! -Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here] (TODO). +Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux2). > [!TIP] > Flux2 can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. 
Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. > > [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs. -TODO checkpoints - ## Flux2Pipeline [[autodoc]] Flux2Pipeline From fc1bd8919415eaa552fbaaa6096e98b2e058b3c8 Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Tue, 25 Nov 2025 02:31:14 +0000 Subject: [PATCH 59/63] update schdule --- .../pipelines/flux2/pipeline_flux2.py | 37 +++++++++---------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index d4196f3ecf78..8335b6ec6878 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -79,18 +79,18 @@ def format_text_input(prompts: List[str], system_message: str = None): ] -# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift -def calculate_shift( - image_seq_len, - base_seq_len: int = 256, - max_seq_len: int = 4096, - base_shift: float = 0.5, - max_shift: float = 1.15, -): - m = (max_shift - base_shift) / (max_seq_len - base_seq_len) - b = base_shift - m * base_seq_len - mu = image_seq_len * m + b - return mu + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 0.00020573, 1.85733333 + a2, b2 = 0.00016927, 0.45666666 + + m_200 = a2 * image_seq_len + b2 + m_30 = a1 * image_seq_len + b1 + + a = (m_200 - m_30) / 170.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + return float(mu) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps @@ -608,7 +608,7 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 50, sigmas: Optional[List[float]] = None, - guidance_scale: Optional[float] = 2.5, + guidance_scale: Optional[float] = 4.0, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -783,13 +783,10 @@ def __call__( if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: sigmas = None image_seq_len = latents.shape[1] - mu = calculate_shift( - image_seq_len, - self.scheduler.config.get("base_image_seq_len", 256), - self.scheduler.config.get("max_image_seq_len", 4096), - self.scheduler.config.get("base_shift", 0.5), - self.scheduler.config.get("max_shift", 1.15), - ) + mu = compute_empirical_mu( + image_seq_len=image_seq_len, + num_steps= num_inference_steps, + ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, From 29b02b80f030607d6c2bc1965852211ec2617923 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 25 Nov 2025 08:10:53 +0530 Subject: [PATCH 60/63] up --- docs/source/en/api/pipelines/flux2.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/flux2.md b/docs/source/en/api/pipelines/flux2.md index 87f0fd92658e..90eaedc245b7 100644 --- a/docs/source/en/api/pipelines/flux2.md +++ b/docs/source/en/api/pipelines/flux2.md @@ -19,7 +19,7 @@ specific language governing permissions and limitations under the License. Flux.2 is the recent series of image generation models from Black Forest Labs, preceded by the [Flux.1](./flux.md) series. 
It is an entirely new model with a new architecture and pre-training done from scratch! -Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux2). +Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux2-dev). > [!TIP] > Flux2 can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. From 823f4c3fe2e9dba4c2cd33863fabdbf3e9c9f56f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 25 Nov 2025 20:12:17 +0530 Subject: [PATCH 61/63] up (#17) --- .../pipelines/flux2/pipeline_flux2.py | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 8335b6ec6878..676bf6d98429 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -23,11 +23,7 @@ from ...loaders import Flux2LoraLoaderMixin from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import ( - is_torch_xla_available, - logging, - replace_example_docstring, -) +from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .image_processor import Flux2ImageProcessor @@ -79,17 +75,21 @@ def format_text_input(prompts: List[str], system_message: str = None): ] - def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: - a1, b1 = 0.00020573, 1.85733333 + a1, b1 = 8.73809524e-05, 1.89833333 a2, b2 = 0.00016927, 0.45666666 + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + m_200 = a2 * image_seq_len + b2 - m_30 = a1 * image_seq_len + b1 + m_10 = a1 * image_seq_len + b1 - a = (m_200 - m_30) / 170.0 + a = (m_200 - m_10) / 190.0 b = m_200 - 200.0 * a mu = a * num_steps + b + return float(mu) @@ -171,7 +171,7 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin): r""" The Flux2 pipeline for text-to-image generation. - Reference: TODO + Reference: [https://bfl.ai/blog/flux-2](https://bfl.ai/blog/flux-2) Args: transformer ([`Flux2Transformer2DModel`]): @@ -783,10 +783,7 @@ def __call__( if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: sigmas = None image_seq_len = latents.shape[1] - mu = compute_empirical_mu( - image_seq_len=image_seq_len, - num_steps= num_inference_steps, - ) + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, From 5419877cf60e85e04f58d2d7624d017fa26a59f1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 25 Nov 2025 20:22:42 +0530 Subject: [PATCH 62/63] add training scripts (#16) * add training scripts Co-authored-by: Linoy Tsaban * model cpu offload in validation. 
* add flux.2 readme * add img2img and tests * cpu offload in log validation * Apply suggestions from code review * fix * up * fixes * remove i2i training tests for now. --------- Co-authored-by: Linoy Tsaban Co-authored-by: linoytsaban --- examples/dreambooth/README_flux2.md | 315 +++ .../dreambooth/test_dreambooth_lora_flux2.py | 262 +++ .../dreambooth/train_dreambooth_lora_flux2.py | 1914 +++++++++++++++++ .../train_dreambooth_lora_flux2_img2img.py | 1831 ++++++++++++++++ 4 files changed, 4322 insertions(+) create mode 100644 examples/dreambooth/README_flux2.md create mode 100644 examples/dreambooth/test_dreambooth_lora_flux2.py create mode 100644 examples/dreambooth/train_dreambooth_lora_flux2.py create mode 100644 examples/dreambooth/train_dreambooth_lora_flux2_img2img.py diff --git a/examples/dreambooth/README_flux2.md b/examples/dreambooth/README_flux2.md new file mode 100644 index 000000000000..1a56196da5d7 --- /dev/null +++ b/examples/dreambooth/README_flux2.md @@ -0,0 +1,315 @@ +# DreamBooth training example for FLUX.2 [dev] + +[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept. + +The `train_dreambooth_lora_flux2.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://github.com/black-forest-labs/flux2-dev). + +> [!NOTE] +> **Memory consumption** +> +> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements - +> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. below we provide some tips and tricks to reduce memory consumption during training. + +> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX: +> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX2.md) +> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux2-training) + +> [!NOTE] +> **Gated model** +> +> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in: + +```bash +hf auth login +``` + +This will also allow us to push the trained model parameters to the Hugging Face Hub platform. + +## Running locally with PyTorch + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: + +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e . 
+``` + +Then cd in the `examples/dreambooth` folder and run +```bash +pip install -r requirements_flux.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +Or for a default accelerate configuration without answering questions about your environment + +```bash +accelerate config default +``` + +Or if your environment doesn't support an interactive shell (e.g., a notebook) + +```python +from accelerate.utils import write_basic_config +write_basic_config() +``` + +When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. +Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. + + +### Dog toy example + +Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example. + +Let's first download it locally: + +```python +from huggingface_hub import snapshot_download + +local_dir = "./dog" +snapshot_download( + "diffusers/dog-example", + local_dir=local_dir, repo_type="dataset", + ignore_patterns=".gitattributes", +) +``` + +This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform. + +As mentioned, Flux2 LoRA training is *very* memory intensive. Here are memory optimizations we can use (some still experimental) for a more memory efficient training: + +## Memory Optimizations +> [!NOTE] many of these techniques complement each other and can be used together to further reduce memory consumption. +> However some techniques may be mutually exclusive so be sure to check before launching a training run. +### Remote Text Encoder +Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API. +This way, the text encoder model is not loaded into memory during training. +> [!NOTE] +> to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`. +### CPU Offloading +To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed. +### Latent Caching +Pre-encode the training images with the vae, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`. +### QLoRA: Low Precision Training with Quantization +Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags: +- **FP8 training** with `torchao`: +enable FP8 training by passing `--do_fp8_training`. +> [!IMPORTANT] Since we are utilizing FP8 tensor cores we need CUDA GPUs with compute capability at least 8.9 or greater. +> If you're looking for memory-efficient training on relatively older cards, we encourage you to check out other trainers like SimpleTuner, ai-toolkit, etc. +- **NF4 training** with `bitsandbytes`: +Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing: +`--bnb_quantization_config_path` to enable 4-bit NF4 quantization. +### Gradient Checkpointing and Accumulation +* `--gradient accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass. 
+by passing a value > 1 you can reduce the number of backward/update passes and hence also the memory requirements.
+* with `--gradient_checkpointing` we can save memory by not storing all intermediate activations during the forward pass.
+Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expense of a slower backward pass.
+### 8-bit-Adam Optimizer
+When training with `AdamW` (doesn't apply to `prodigy`), you can pass `--use_8bit_adam` to reduce the memory requirements of training.
+Make sure to install `bitsandbytes` if you want to do so.
+### Image Resolution
+An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution of the input images; all the images in the train/validation dataset are resized to this resolution.
+Note that by default, images are resized to a resolution of 512, which is good to keep in mind in case you're accustomed to training on higher resolutions.
+### Precision of saved LoRA layers
+By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
+This reduces memory requirements significantly without a significant quality loss. Note that if you do wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.
+
+
+```bash
+export MODEL_NAME="black-forest-labs/FLUX.2-dev"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="trained-flux2"
+
+accelerate launch train_dreambooth_lora_flux2.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --instance_data_dir=$INSTANCE_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --do_fp8_training \
+  --gradient_checkpointing \
+  --remote_text_encoder \
+  --cache_latents \
+  --instance_prompt="a photo of sks dog" \
+  --resolution=1024 \
+  --train_batch_size=1 \
+  --guidance_scale=1 \
+  --use_8bit_adam \
+  --gradient_accumulation_steps=4 \
+  --optimizer="adamW" \
+  --learning_rate=1e-4 \
+  --report_to="wandb" \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=100 \
+  --max_train_steps=500 \
+  --validation_prompt="A photo of sks dog in a bucket" \
+  --validation_epochs=25 \
+  --seed="0" \
+  --push_to_hub
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login` before training if you haven't done it before.
+* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+> [!NOTE]
+> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit. The default is 512. Note that longer sequences will use more resources and may slow down the training in some cases.
+
+## LoRA + DreamBooth
+
+[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning-like performance but with a fraction of learnable parameters.
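+To get a feel for what "a fraction of learnable parameters" means in practice, here is a small, illustrative calculation for a single linear layer; the dimensions and rank below are example values only, not Flux2's exact configuration:
+
+```python
+# Rough parameter count for LoRA on one d_in x d_out linear layer (example numbers only).
+d_in, d_out, rank = 3072, 3072, 16
+
+full_finetune_params = d_in * d_out    # ~9.4M params if the whole weight matrix were trained
+lora_params = rank * (d_in + d_out)    # ~98K params for the low-rank factors A (rank x d_in) and B (d_out x rank)
+
+print(full_finetune_params, lora_params, full_finetune_params / lora_params)  # ~96x fewer trainable params for this layer
+```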
+
+Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
+
+### Prodigy Optimizer
+Prodigy is an adaptive optimizer that dynamically adjusts the learning rate of the learned parameters based on past gradients, allowing for more efficient convergence.
+By using prodigy we can "eliminate" the need for manual learning rate tuning. Read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
+
+To use prodigy, first make sure to install the prodigyopt library: `pip install prodigyopt`, and then specify:
+```bash
+--optimizer="prodigy"
+```
+> [!TIP]
+> When using prodigy it's generally good practice to set `--learning_rate=1.0`
+
+To perform DreamBooth with LoRA, run:
+
+```bash
+export MODEL_NAME="black-forest-labs/FLUX.2-dev"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="trained-flux2-lora"
+
+accelerate launch train_dreambooth_lora_flux2.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --instance_data_dir=$INSTANCE_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --do_fp8_training \
+  --gradient_checkpointing \
+  --remote_text_encoder \
+  --cache_latents \
+  --instance_prompt="a photo of sks dog" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --guidance_scale=1 \
+  --gradient_accumulation_steps=4 \
+  --optimizer="prodigy" \
+  --learning_rate=1. \
+  --report_to="wandb" \
+  --lr_scheduler="constant_with_warmup" \
+  --lr_warmup_steps=100 \
+  --max_train_steps=500 \
+  --validation_prompt="A photo of sks dog in a bucket" \
+  --validation_epochs=25 \
+  --seed="0" \
+  --push_to_hub
+```
+
+### LoRA Rank and Alpha
+Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
+- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
+- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / rank`.
+- `lora_alpha` vs. `rank`: this ratio dictates the LoRA's effective strength:
+  - `lora_alpha == rank`: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
+  - `lora_alpha < rank`: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
+  - `lora_alpha > rank`: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower-rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
+
+> [!TIP]
+> A common starting point is to set `lora_alpha` equal to `rank`.
+> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
+> to give the LoRA updates more influence without increasing parameter count.
+> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
+> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
+
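+Under the hood, these flags are wired into a PEFT `LoraConfig` that gets attached to the transformer. A minimal, illustrative sketch is shown below; the rank, alpha, and target modules are example values (target modules are discussed in the next section), not the script's exact defaults:
+
+```python
+from peft import LoraConfig
+
+# Example values only: `r` mirrors `--rank`, `lora_alpha` mirrors `--lora_alpha`,
+# and `target_modules` mirrors what you would pass via `--lora_layers`.
+lora_config = LoraConfig(
+    r=16,
+    lora_alpha=16,  # scaling = lora_alpha / r = 1.0 in this example
+    init_lora_weights="gaussian",
+    target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
+)
+# transformer.add_adapter(lora_config)  # roughly how the adapter is attached to the DiT
+```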
+### Target Modules
+When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
+More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer (DiT). With this change, we may also want to explore
+applying LoRA training to different types of layers and blocks. To allow more flexibility and control over the targeted modules, we added `--lora_layers`, in which you can specify, as a comma-separated string,
+the exact modules for LoRA training. Here are some examples of target modules you can provide:
+- for attention-only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
+- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
+- to train the same modules as in the ostris ai-toolkit / replicate trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear,norm1.linear,norm.linear,proj_mlp,proj_out"`
+> [!NOTE]
+> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma-separated string:
+> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
+> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
+> [!NOTE]
+> Keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
+
+
+## Training Image-to-Image
+
+Flux.2 lets us perform image editing as well as image generation. We provide a simple script for image-to-image (I2I) LoRA fine-tuning in [train_dreambooth_lora_flux2_img2img.py](./train_dreambooth_lora_flux2_img2img.py), which supports both T2I and I2I. The optimizations discussed above apply to this script, too.
+
+**Important**
+
+To make sure you can successfully run the latest version of the image-to-image example script, we highly recommend installing from source. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+To start, you must have a dataset containing triplets:
+
+* Condition image - the input image to be transformed.
+* Target image - the desired output image after transformation.
+* Instruction - a text prompt describing the transformation from the condition image to the target image.
+
+[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset.
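+Before launching a run, it can help to sanity-check that your dataset exposes the three expected columns. Here is a small, illustrative snippet (it assumes the dataset has a `train` split; the column names match the relighting dataset and the launch command that follows, but yours may differ):
+
+```python
+from datasets import load_dataset
+
+dataset = load_dataset("kontext-community/relighting", split="train")
+print(dataset.column_names)
+# For this dataset you would pass: --cond_image_column="file_name", --image_column="output",
+# and --caption_column="instruction" (see the launch command below).
+```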
+If you are using such a dataset, you can use the command below to launch training:
+
+```bash
+accelerate launch train_dreambooth_lora_flux2_img2img.py \
+  --pretrained_model_name_or_path=black-forest-labs/FLUX.2-dev \
+  --output_dir="flux2-i2i" \
+  --dataset_name="kontext-community/relighting" \
+  --image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
+  --do_fp8_training \
+  --gradient_checkpointing \
+  --remote_text_encoder \
+  --cache_latents \
+  --resolution=1024 \
+  --train_batch_size=1 \
+  --guidance_scale=1 \
+  --gradient_accumulation_steps=4 \
+  --optimizer="adamw" \
+  --use_8bit_adam \
+  --learning_rate=1e-4 \
+  --lr_scheduler="constant_with_warmup" \
+  --lr_warmup_steps=200 \
+  --max_train_steps=1000 \
+  --rank=16 \
+  --seed="0"
+```
+
+More generally, when performing I2I fine-tuning, we expect you to:
+
+* Have a dataset structured like `kontext-community/relighting`
+* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training
+
+### Misc notes
+
+* By default, we use `mode` as the value of the `--vae_encode_mode` argument. This is because Kontext uses the `mode()` of the distribution predicted by the VAE instead of sampling from it.
+
+### Aspect Ratio Bucketing
+We've added aspect ratio bucketing support, which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.
+
+To enable aspect ratio bucketing, pass the `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
+
+`--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"`
+
+Since Flux.2 finetuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗
diff --git a/examples/dreambooth/test_dreambooth_lora_flux2.py b/examples/dreambooth/test_dreambooth_lora_flux2.py
new file mode 100644
index 000000000000..80a0b502f9a2
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_lora_flux2.py
@@ -0,0 +1,262 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
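+
+# Smoke tests for the Flux2 DreamBooth LoRA training script: each test launches a short training run
+# on a tiny test checkpoint and checks that LoRA weights, checkpoints, and adapter metadata are saved as expected.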
+ +import json +import logging +import os +import sys +import tempfile + +import safetensors + +from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY + + +sys.path.append("..") +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +class DreamBoothLoRAFlux2(ExamplesTestsAccelerate): + instance_data_dir = "docs/source/en/imgs" + instance_prompt = "dog" + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2" + script_path = "examples/dreambooth/train_dreambooth_lora_flux2.py" + transformer_layer_type = "single_transformer_blocks.0.attn.to_qkv_mlp_proj" + + def test_dreambooth_lora_flux2(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --max_sequence_length 8 + --text_encoder_out_layers 1 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_latent_caching(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --max_sequence_length 8 + --text_encoder_out_layers 1 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. 
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_layers(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lora_layers {self.transformer_layer_type} + --lr_scheduler constant + --lr_warmup_steps 0 + --max_sequence_length 8 + --text_encoder_out_layers 1 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. In this test, we only params of + # transformer.single_transformer_blocks.0.attn.to_k should be in the state dict + starts_with_transformer = all( + key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys() + ) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=6 + --checkpoints_total_limit=2 + --max_sequence_length 8 + --checkpointing_steps=2 + --text_encoder_out_layers 1 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=4 + --checkpointing_steps=2 + --max_sequence_length 8 + --text_encoder_out_layers 1 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"}) + + resume_run_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=8 + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-4 + --checkpoints_total_limit=2 + 
--max_sequence_length 8 + --text_encoder_out_layers 1 + """.split() + + run_command(self._launch_args + resume_run_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) + + def test_dreambooth_lora_with_metadata(self): + # Use a `lora_alpha` that is different from `rank`. + lora_alpha = 8 + rank = 4 + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --lora_alpha={lora_alpha} + --rank={rank} + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --max_sequence_length 8 + --text_encoder_out_layers 1 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") + self.assertTrue(os.path.isfile(state_dict_file)) + + # Check if the metadata was properly serialized. + with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f: + metadata = f.metadata() or {} + + metadata.pop("format", None) + raw = metadata.get(LORA_ADAPTER_METADATA_KEY) + if raw: + raw = json.loads(raw) + + loaded_lora_alpha = raw["transformer.lora_alpha"] + self.assertTrue(loaded_lora_alpha == lora_alpha) + loaded_lora_rank = raw["transformer.r"] + self.assertTrue(loaded_lora_rank == rank) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py new file mode 100644 index 000000000000..733abe16d2eb --- /dev/null +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -0,0 +1,1914 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
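+
+# DreamBooth LoRA fine-tuning script for FLUX.2 [dev]: trains LoRA adapters on the Flux2 transformer,
+# with optional prior preservation, latent caching, aspect-ratio bucketing, and quantized (FP8 / NF4) training.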
+ +# /// script +# dependencies = [ +# "diffusers @ git+https://github.com/huggingface/diffusers.git", +# "torch>=2.0.0", +# "accelerate>=0.31.0", +# "transformers>=4.41.2", +# "ftfy", +# "tensorboard", +# "Jinja2", +# "peft>=0.11.1", +# "sentencepiece", +# "torchvision", +# "datasets", +# "bitsandbytes", +# "prodigyopt", +# ] +# /// + +import argparse +import copy +import itertools +import json +import logging +import math +import os +import random +import shutil +import warnings +from contextlib import nullcontext +from pathlib import Path + +import numpy as np +import torch +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torch.utils.data.sampler import BatchSampler +from torchvision import transforms +from torchvision.transforms import functional as TF +from tqdm.auto import tqdm +from transformers import Mistral3ForConditionalGeneration, PixtralProcessor + +import diffusers +from diffusers import ( + AutoencoderKLFlux2, + BitsAndBytesConfig, + FlowMatchEulerDiscreteScheduler, + Flux2Pipeline, + Flux2Transformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + _collate_lora_metadata, + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + find_nearest_bucket, + free_memory, + offload_models, + parse_buckets_string, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_torch_npu_available +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.36.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model: str = None, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, + quant_training=None, +): + widget_dict = [] + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + + model_description = f""" +# Flux2 DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux2.md). + +Quant training? {quant_training} + +## Trigger words + +You should use `{instance_prompt}` to trigger the image generation. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. 
+ +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import AutoPipelineForText2Image +import torch +pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.2", torch_dtype=torch.bfloat16).to('cuda') +pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') +image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.2/blob/main/LICENSE.md). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "flux2", + "flux2-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + torch_dtype, + is_final_validation=False, +): + args.num_validation_images = args.num_validation_images if args.num_validation_images else 1 + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + pipeline = pipeline.to(dtype=torch_dtype) + pipeline.enable_model_cpu_offload() + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None + autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() + + images = [] + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds=pipeline_args["prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + free_memory() + + return images + + +def module_filter_fn(mod: torch.nn.Module, fqn: str): + # don't convert the output module + if fqn == "proj_out": + return False + # don't convert linear modules with weight dimensions not divisible by 16 + if isinstance(mod, torch.nn.Linear): + if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: + return False + return True + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from 
huggingface.co/models.", + ) + parser.add_argument( + "--bnb_quantization_config_path", + type=str, + default=None, + help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.", + ) + parser.add_argument( + "--do_fp8_training", + action="store_true", + help="if we are doing FP8 training.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=512, + help="Maximum sequence length to use with with the T5 text encoder", + ) + parser.add_argument( + "--text_encoder_out_layers", + type=int, + nargs="+", + default=[10, 20, 30], + help="Text encoder hidden layers to compute the final text embeddings.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--skip_final_inference", + default=False, + action="store_true", + help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.", + ) + parser.add_argument( + "--final_validation_prompt", + type=str, + default=None, + help="A prompt that is used during a final validation to verify that the model is learning. 
Ignored if `--validation_prompt` is provided.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=4, + help="LoRA alpha to be used for additional scaling.", + ) + parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers") + + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="flux-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--aspect_ratio_buckets", + type=str, + default=None, + help=( + "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. " + "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'" + "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored." + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." 
+ ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." 
+ ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--lora_layers", + type=str, + default=None, + help=( + 'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only' + ), + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--upcast_before_saving", + action="store_true", + default=False, + help=( + "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). 
" + "Defaults to precision dtype used for training to save memory" + ), + ) + parser.add_argument( + "--offload", + action="store_true", + help="Whether to offload the VAE and the text encoder to CPU when they are not used.", + ) + parser.add_argument( + "--remote_text_encoder", + action="store_true", + help="Whether to use a remote text encoder. This means the text encoder will not be loaded locally and instead, the prompt embeddings will be computed remotely using the HuggingFace Inference API.", + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + if args.do_fp8_training and args.bnb_quantization_config_path: + raise ValueError("Both `do_fp8_training` and `bnb_quantization_config_path` cannot be passed.") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, + buckets=None, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + + self.buckets = buckets + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. 
+ # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + + self.pixel_values = [] + for i, image in enumerate(self.instance_images): + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + + width, height = image.size + + # Find the closest bucket + bucket_idx = find_nearest_bucket(height, width, self.buckets) + target_height, target_width = self.buckets[bucket_idx] + self.size = (target_height, target_width) + + # based on the bucket assignment, define the transformations + image = self.train_transform( + image, + size=self.size, + center_crop=args.center_crop, + random_flip=args.random_flip, + ) + self.pixel_values.append((image, bucket_idx)) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image, 
bucket_idx = self.pixel_values[index % self.num_instance_images] + example["instance_images"] = instance_image + example["bucket_idx"] = bucket_idx + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # custom prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + def train_transform(self, image, size=(224, 224), center_crop=False, random_flip=False): + # 1. Resize (deterministic) + resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + image = resize(image) + + # 2. Crop: either center or SAME random crop + if center_crop: + crop = transforms.CenterCrop(size) + image = crop(image) + else: + # get_params returns (i, j, h, w) + i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size) + image = TF.crop(image, i, j, h, w) + + # 3. Random horizontal flip with the SAME coin flip + if random_flip: + do_flip = random.random() < 0.5 + if do_flip: + image = TF.hflip(image) + + # 4. ToTensor + Normalize (deterministic) + to_tensor = transforms.ToTensor() + normalize = transforms.Normalize([0.5], [0.5]) + image = normalize(to_tensor(image)) + + return image + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. 
+ if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class BucketBatchSampler(BatchSampler): + def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) + if not isinstance(drop_last, bool): + raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last)) + + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + + # Group indices by bucket + self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] + for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values): + self.bucket_indices[bucket_idx].append(idx) + + self.sampler_len = 0 + self.batches = [] + + # Pre-generate batches for each bucket + for indices_in_bucket in self.bucket_indices: + # Shuffle indices within the bucket + random.shuffle(indices_in_bucket) + # Create batches + for i in range(0, len(indices_in_bucket), self.batch_size): + batch = indices_in_bucket[i : i + self.batch_size] + if len(batch) < self.batch_size and self.drop_last: + continue # Skip partial batch if drop_last is True + self.batches.append(batch) + self.sampler_len += 1 # Count the number of batches + + def __iter__(self): + # Shuffle the order of the batches each epoch + random.shuffle(self.batches) + for batch in self.batches: + yield batch + + def __len__(self): + return self.sampler_len + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `hf auth login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + if args.do_fp8_training: + from torchao.float8 import Float8LinearConfig, convert_to_float8_training + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. 
+ if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + + pipeline = Flux2Pipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype): + images = pipeline(prompt=example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + free_memory() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Load the tokenizers + tokenizer = PixtralProcessor.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. 
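+    # e.g. --mixed_precision=bf16 maps to torch.bfloat16 for the frozen VAE/text encoder/transformer,
+    # while the trainable LoRA parameters are upcast back to float32 when training in fp16
+    # (see cast_training_params below).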
+ weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Load scheduler and models + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler", + revision=args.revision, + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + vae = AutoencoderKLFlux2.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(accelerator.device) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( + accelerator.device + ) + + quantization_config = None + if args.bnb_quantization_config_path is not None: + with open(args.bnb_quantization_config_path, "r") as f: + config_kwargs = json.load(f) + if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]: + config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype + quantization_config = BitsAndBytesConfig(**config_kwargs) + + transformer = Flux2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + quantization_config=quantization_config, + torch_dtype=weight_dtype, + ) + if args.bnb_quantization_config_path is not None: + transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False) + + if not args.remote_text_encoder: + text_encoder = Mistral3ForConditionalGeneration.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder.requires_grad_(False) + + # We only train the additional adapter LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + + if args.enable_npu_flash_attention: + if is_torch_npu_available(): + logger.info("npu flash attention enabled.") + transformer.set_attention_backend("_native_npu") + else: + raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ") + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + to_kwargs = {"dtype": weight_dtype, "device": accelerator.device} if not args.offload else {"dtype": weight_dtype} + # flux vae is stable in bf16 so load it in weight_dtype to reduce memory + vae.to(**to_kwargs) + # we never offload the transformer to CPU, so we can just use the accelerator device + transformer_to_kwargs = ( + {"device": accelerator.device} + if args.bnb_quantization_config_path is not None + else {"device": accelerator.device, "dtype": weight_dtype} + ) + transformer.to(**transformer_to_kwargs) + if args.do_fp8_training: + convert_to_float8_training( + transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True) + ) + + if not args.remote_text_encoder: + text_encoder.to(**to_kwargs) + # Initialize a text encoding pipeline and keep it to CPU for now. 
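+        # The pipeline below is created with vae=None and transformer=None, so only the tokenizer and
+        # text encoder are wired up; it is used purely for encode_prompt and is offloaded to CPU when
+        # --offload is set.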
+ text_encoding_pipeline = Flux2Pipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=None, + transformer=None, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=None, + revision=args.revision, + ) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + if args.lora_layers is not None: + target_modules = [layer.strip() for layer in args.lora_layers.split(",")] + else: + target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + + # now we will add new LoRA weights the transformer layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + transformer.add_adapter(transformer_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + modules_to_save = {} + for model in models: + if isinstance(model, type(unwrap_model(transformer))): + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + modules_to_save["transformer"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + Flux2Pipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + **_collate_lora_metadata(modules_to_save), + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(transformer))): + transformer_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. 
More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [transformer] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + params_to_optimize = [transformer_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + if args.aspect_ratio_buckets is not None: + buckets = parse_buckets_string(args.aspect_ratio_buckets) + else: + buckets = [(args.resolution, args.resolution)] + logger.info(f"Using parsed aspect ratio buckets: {buckets}") + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + buckets=buckets, + ) + batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + def compute_text_embeddings(prompt, text_encoding_pipeline): + with torch.no_grad(): + prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt( + prompt=prompt, + max_sequence_length=args.max_sequence_length, + text_encoder_out_layers=args.text_encoder_out_layers, + ) + return prompt_embeds, text_ids + + def compute_remote_text_embeddings(prompts): + import io + + import requests + + if args.hub_token is not None: + hf_token = args.hub_token + else: + from huggingface_hub import get_token + + hf_token = get_token() + if hf_token is None: + raise ValueError( + "No HuggingFace token found. To use the remote text encoder please login using `hf auth login` or provide a token using --hub_token" + ) + + def _encode_single(prompt: str): + response = requests.post( + "https://remote-text-encoder-flux-2.huggingface.co/predict", + json={"prompt": prompt}, + headers={"Authorization": f"Bearer {hf_token}", "Content-Type": "application/json"}, + ) + assert response.status_code == 200, f"{response.status_code=}" + return torch.load(io.BytesIO(response.content)) + + try: + if isinstance(prompts, (list, tuple)): + embeds = [_encode_single(p) for p in prompts] + prompt_embeds = torch.cat(embeds, dim=0) + else: + prompt_embeds = _encode_single(prompts) + + text_ids = Flux2Pipeline._prepare_text_ids(prompt_embeds).to(accelerator.device) + prompt_embeds = prompt_embeds.to(accelerator.device) + return prompt_embeds, text_ids + + except Exception as e: + raise RuntimeError("Remote text encoder inference failed.") from e + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. 
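+    # In that case prompt_embeds/text_ids are computed a single time here and simply repeated to the
+    # batch size inside the training loop, instead of re-running the text encoder every step.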
+ if not train_dataset.custom_instance_prompts: + if args.remote_text_encoder: + instance_prompt_hidden_states, instance_text_ids = compute_remote_text_embeddings(args.instance_prompt) + else: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings( + args.instance_prompt, text_encoding_pipeline + ) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + if args.remote_text_encoder: + class_prompt_hidden_states, class_text_ids = compute_remote_text_embeddings(args.class_prompt) + else: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + class_prompt_hidden_states, class_text_ids = compute_text_embeddings( + args.class_prompt, text_encoding_pipeline + ) + validation_embeddings = {} + if args.validation_prompt is not None: + if args.remote_text_encoder: + (validation_embeddings["prompt_embeds"], validation_embeddings["text_ids"]) = ( + compute_remote_text_embeddings(args.validation_prompt) + ) + else: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + (validation_embeddings["prompt_embeds"], validation_embeddings["text_ids"]) = compute_text_embeddings( + args.validation_prompt, text_encoding_pipeline + ) + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + if not train_dataset.custom_instance_prompts: + prompt_embeds = instance_prompt_hidden_states + text_ids = instance_text_ids + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + text_ids = torch.cat([text_ids, class_text_ids], dim=0) + + # if cache_latents is set to True, we encode images to latents and store them. + # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided + # we encode them in advance as well. 
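+    # Both caches are filled once here, in dataloader order, and read back per step inside the training
+    # loop as prompt_embeds_cache[step] / latents_cache[step], which relies on the batch order staying stable.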
+ precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts + if precompute_latents: + prompt_embeds_cache = [] + text_ids_cache = [] + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + if args.cache_latents: + with offload_models(vae, device=accelerator.device, offload=args.offload): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=vae.dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + if train_dataset.custom_instance_prompts: + if args.remote_text_encoder: + prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"]) + else: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) + prompt_embeds_cache.append(prompt_embeds) + text_ids_cache.append(text_ids) + + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + if args.cache_latents: + vae = vae.to("cpu") + del vae + + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + if not args.remote_text_encoder: + text_encoding_pipeline = text_encoding_pipeline.to("cpu") + del text_encoder, tokenizer + free_memory() + + # Scheduler and math around the number of training steps. + # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. + num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes + if args.max_train_steps is None: + len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) + num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) + num_training_steps_for_scheduler = ( + args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch + ) + else: + num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=num_warmup_steps_for_scheduler, + num_training_steps=num_training_steps_for_scheduler, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + if num_training_steps_for_scheduler != args.max_train_steps: + logger.warning( + f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " + f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " + f"This inconsistency may result in the learning rate scheduler not functioning properly." + ) + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. 
+ # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-flux2-lora" + args_cp = vars(args).copy() + args_cp["text_encoder_out_layers"] = str(args_cp["text_encoder_out_layers"]) + accelerator.init_trackers(tracker_name, config=args_cp) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + prompts = batch["prompts"] + + with accelerator.accumulate(models_to_accumulate): + if train_dataset.custom_instance_prompts: + prompt_embeds = prompt_embeds_cache[step] + text_ids = text_ids_cache[step] + else: + num_repeat_elements = len(prompts) + prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1) + text_ids = text_ids.repeat(num_repeat_elements, 1, 1) + + # Convert images to latent space + if args.cache_latents: + model_input = latents_cache[step].mode() + else: + with offload_models(vae, device=accelerator.device, offload=args.offload): + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.mode() + + model_input = Flux2Pipeline._patchify_latents(model_input) + model_input = (model_input - latents_bn_mean) / latents_bn_std + + model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device) + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # [B, C, H, W] -> [B, H*W, C] + packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input) + + # handle guidance + guidance = torch.full([1], args.guidance_scale, device=accelerator.device) + guidance = guidance.expand(model_input.shape[0]) + + # Predict the noise residual + model_pred = transformer( + hidden_states=packed_noisy_model_input, # (B, image_seq_len, C) + timestep=timesteps / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=model_input_ids, # B, image_seq_len, 4 + return_dict=False, + )[0] + model_pred = model_pred[:, : packed_noisy_model_input.size(1) :] + + model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids) + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - model_input + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 
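+                    # collate_fn stacked the instance examples first and the class examples second, so the
+                    # first half of the batch is the instance prediction and the second half is the
+                    # prior-preservation prediction.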
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute prior loss + prior_loss = torch.mean( + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), + 1, + ) + prior_loss = prior_loss.mean() + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + # create pipeline + pipeline = Flux2Pipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=None, + tokenizer=None, + transformer=unwrap_model(transformer), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=validation_embeddings, + epoch=epoch, + torch_dtype=weight_dtype, + ) + + del pipeline + free_memory() + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + modules_to_save = {} + transformer = unwrap_model(transformer) + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + modules_to_save["transformer"] = transformer + + Flux2Pipeline.save_lora_weights( + 
save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + **_collate_lora_metadata(modules_to_save), + ) + + images = [] + run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt) + should_run_final_inference = not args.skip_final_inference and run_validation + if should_run_final_inference: + pipeline = Flux2Pipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=validation_embeddings, + epoch=epoch, + is_final_validation=True, + torch_dtype=weight_dtype, + ) + images = None + del pipeline + free_memory() + + validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt + quant_training = None + if args.do_fp8_training: + quant_training = "FP8 TorchAO" + elif args.bnb_quantization_config_path: + quant_training = "BitsandBytes" + save_model_card( + (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + instance_prompt=args.instance_prompt, + validation_prompt=validation_prompt, + repo_folder=args.output_dir, + quant_training=quant_training, + ) + + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py new file mode 100644 index 000000000000..32bce9531b71 --- /dev/null +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -0,0 +1,1831 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# /// script +# dependencies = [ +# "diffusers @ git+https://github.com/huggingface/diffusers.git", +# "torch>=2.0.0", +# "accelerate>=0.31.0", +# "transformers>=4.41.2", +# "ftfy", +# "tensorboard", +# "Jinja2", +# "peft>=0.11.1", +# "sentencepiece", +# "torchvision", +# "datasets", +# "bitsandbytes", +# "prodigyopt", +# ] +# /// + +import argparse +import copy +import itertools +import json +import logging +import math +import os +import random +import shutil +from contextlib import nullcontext +from pathlib import Path + +import numpy as np +import torch +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torch.utils.data.sampler import BatchSampler +from torchvision import transforms +from torchvision.transforms import functional as TF +from tqdm.auto import tqdm +from transformers import Mistral3ForConditionalGeneration, PixtralProcessor + +import diffusers +from diffusers import ( + AutoencoderKLFlux2, + BitsAndBytesConfig, + FlowMatchEulerDiscreteScheduler, + Flux2Pipeline, + Flux2Transformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor +from diffusers.training_utils import ( + _collate_lora_metadata, + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + find_nearest_bucket, + free_memory, + offload_models, + parse_buckets_string, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, + load_image, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_torch_npu_available +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.36.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model: str = None, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, + fp8_training=False, +): + widget_dict = [] + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + + model_description = f""" +# Flux DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux2.md). + +FP8 training? {fp8_training} + +## Trigger words + +You should use `{instance_prompt}` to trigger the image generation. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. 
+ +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import AutoPipelineForText2Image +import torch +pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.2", torch_dtype=torch.bfloat16).to('cuda') +pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') +image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.2/blob/main/LICENSE.md). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "flux2", + "flux2-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + torch_dtype, + is_final_validation=False, +): + args.num_validation_images = args.num_validation_images if args.num_validation_images else 1 + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + pipeline = pipeline.to(dtype=torch_dtype) + pipeline.enable_model_cpu_offload() + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None + autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() + + images = [] + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + image=pipeline_args["image"], + prompt_embeds=pipeline_args["prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + free_memory() + + return images + + +def module_filter_fn(mod: torch.nn.Module, fqn: str): + # don't convert the output module + if fqn == "proj_out": + return False + # don't convert linear modules with weight dimensions not divisible by 16 + if isinstance(mod, torch.nn.Linear): + if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: + return False + return True + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of 
pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--bnb_quantization_config_path", + type=str, + default=None, + help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.", + ) + parser.add_argument( + "--do_fp8_training", + action="store_true", + help="if we are doing FP8 training.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--cond_image_column", + type=str, + default=None, + help="Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=512, + help="Maximum sequence length to use with with the T5 text encoder", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + help="path to an image that is used during validation as the condition image to verify that the model is learning.", + ) + parser.add_argument( + "--skip_final_inference", + default=False, + action="store_true", + help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.", + ) + parser.add_argument( + "--final_validation_prompt", + type=str, + default=None, + help="A prompt that is used during a final validation to verify that the model is learning. 
Ignored if `--validation_prompt` is provided.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=4, + help="LoRA alpha to be used for additional scaling.", + ) + parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers") + + parser.add_argument( + "--output_dir", + type=str, + default="flux-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--aspect_ratio_buckets", + type=str, + default=None, + help=( + "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. " + "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'" + "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored." + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
+ ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. 
Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--lora_layers", + type=str, + default=None, + help=( + 'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only' + ), + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--upcast_before_saving", + action="store_true", + default=False, + help=( + "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). 
" + "Defaults to precision dtype used for training to save memory" + ), + ) + parser.add_argument( + "--offload", + action="store_true", + help="Whether to offload the VAE and the text encoder to CPU when they are not used.", + ) + parser.add_argument( + "--remote_text_encoder", + action="store_true", + help="Whether to use a remote text encoder. This means the text encoder will not be loaded locally and instead, the prompt embeddings will be computed remotely using the HuggingFace Inference API.", + ) + + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.cond_image_column is None: + raise ValueError( + "you must provide --cond_image_column for image-to-image training. Otherwise please see Flux2 text-to-image training example." + ) + else: + assert args.image_column is not None + assert args.caption_column is not None + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + size=1024, + repeats=1, + center_crop=False, + buckets=None, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + + self.buckets = buckets + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.cond_image_column is not None and args.cond_image_column not in column_names: + raise ValueError( + f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}" + ) + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + cond_images = None + cond_image_column = args.cond_image_column + if cond_image_column is not None: + cond_images = [dataset["train"][i][cond_image_column] for i in range(len(dataset["train"]))] + assert len(instance_images) == len(cond_images) + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + self.cond_images = [] + for i, img in enumerate(instance_images): + self.instance_images.extend(itertools.repeat(img, repeats)) + if args.dataset_name is not None and cond_images is not None: + self.cond_images.extend(itertools.repeat(cond_images[i], repeats)) + + self.pixel_values = [] + self.cond_pixel_values = [] + for i, image in enumerate(self.instance_images): + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + dest_image = None + if self.cond_images: # todo: take care of max area for buckets + dest_image = self.cond_images[i] + image_width, image_height = dest_image.size + if image_width * image_height > 1024 * 1024: + dest_image = Flux2ImageProcessor.image_processor._resize_to_target_area(dest_image, 1024 * 1024) + image_width, image_height = dest_image.size + + multiple_of = 2 ** (4 - 1) # 2 ** (len(vae.config.block_out_channels) - 1), temp! 
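+                    # Snap the condition image size down to a multiple of the VAE's spatial downsampling
+                    # factor (8 here) so encoding yields whole latent dimensions; e.g. 1022 -> 1016.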
+ image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + dest_image = Flux2ImageProcessor.image_processor.preprocess( + dest_image, height=image_height, width=image_width, resize_mode="crop" + ) + + dest_image = exif_transpose(dest_image) + if not dest_image.mode == "RGB": + dest_image = dest_image.convert("RGB") + + width, height = image.size + + # Find the closest bucket + bucket_idx = find_nearest_bucket(height, width, self.buckets) + target_height, target_width = self.buckets[bucket_idx] + self.size = (target_height, target_width) + + # based on the bucket assignment, define the transformations + image, dest_image = self.paired_transform( + image, + dest_image=dest_image, + size=self.size, + center_crop=args.center_crop, + random_flip=args.random_flip, + ) + self.pixel_values.append((image, bucket_idx)) + if dest_image is not None: + self.cond_pixel_values.append((dest_image, bucket_idx)) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images] + example["instance_images"] = instance_image + example["bucket_idx"] = bucket_idx + if self.cond_pixel_values: + dest_image, _ = self.cond_pixel_values[index % self.num_instance_images] + example["cond_images"] = dest_image + + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # custom prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + return example + + def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False): + # 1. Resize (deterministic) + resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + image = resize(image) + if dest_image is not None: + dest_image = resize(dest_image) + + # 2. Crop: either center or SAME random crop + if center_crop: + crop = transforms.CenterCrop(size) + image = crop(image) + if dest_image is not None: + dest_image = crop(dest_image) + else: + # get_params returns (i, j, h, w) + i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size) + image = TF.crop(image, i, j, h, w) + if dest_image is not None: + dest_image = TF.crop(dest_image, i, j, h, w) + + # 3. Random horizontal flip with the SAME coin flip + if random_flip: + do_flip = random.random() < 0.5 + if do_flip: + image = TF.hflip(image) + if dest_image is not None: + dest_image = TF.hflip(dest_image) + + # 4. 
ToTensor + Normalize (deterministic) + to_tensor = transforms.ToTensor() + normalize = transforms.Normalize([0.5], [0.5]) + image = normalize(to_tensor(image)) + if dest_image is not None: + dest_image = normalize(to_tensor(dest_image)) + + return (image, dest_image) if dest_image is not None else (image, None) + + +def collate_fn(examples): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + if any("cond_images" in example for example in examples): + cond_pixel_values = [example["cond_images"] for example in examples] + cond_pixel_values = torch.stack(cond_pixel_values) + cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float() + batch.update({"cond_pixel_values": cond_pixel_values}) + return batch + + +class BucketBatchSampler(BatchSampler): + def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) + if not isinstance(drop_last, bool): + raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last)) + + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + + # Group indices by bucket + self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] + for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values): + self.bucket_indices[bucket_idx].append(idx) + + self.sampler_len = 0 + self.batches = [] + + # Pre-generate batches for each bucket + for indices_in_bucket in self.bucket_indices: + # Shuffle indices within the bucket + random.shuffle(indices_in_bucket) + # Create batches + for i in range(0, len(indices_in_bucket), self.batch_size): + batch = indices_in_bucket[i : i + self.batch_size] + if len(batch) < self.batch_size and self.drop_last: + continue # Skip partial batch if drop_last is True + self.batches.append(batch) + self.sampler_len += 1 # Count the number of batches + + def __iter__(self): + # Shuffle the order of the batches each epoch + random.shuffle(self.batches) + for batch in self.batches: + yield batch + + def __len__(self): + return self.sampler_len + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `hf auth login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." 
+ ) + if args.do_fp8_training: + from torchao.float8 import Float8LinearConfig, convert_to_float8_training + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Load the tokenizers + tokenizer = PixtralProcessor.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. 
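+    # The LoRA parameters that are actually trained are upcast back to float32 further
+    # below when fp16 mixed precision is used.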
+ weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Load scheduler and models + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler", + revision=args.revision, + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + vae = AutoencoderKLFlux2.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(accelerator.device) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( + accelerator.device + ) + + quantization_config = None + if args.bnb_quantization_config_path is not None: + with open(args.bnb_quantization_config_path, "r") as f: + config_kwargs = json.load(f) + if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]: + config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype + quantization_config = BitsAndBytesConfig(**config_kwargs) + + transformer = Flux2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + quantization_config=quantization_config, + torch_dtype=weight_dtype, + ) + if args.bnb_quantization_config_path is not None: + transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False) + + if not args.remote_text_encoder: + text_encoder = Mistral3ForConditionalGeneration.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder.requires_grad_(False) + + # We only train the additional adapter LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + + if args.enable_npu_flash_attention: + if is_torch_npu_available(): + logger.info("npu flash attention enabled.") + transformer.set_attention_backend("_native_npu") + else: + raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ") + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + to_kwargs = {"dtype": weight_dtype, "device": accelerator.device} if not args.offload else {"dtype": weight_dtype} + # flux vae is stable in bf16 so load it in weight_dtype to reduce memory + vae.to(**to_kwargs) + # we never offload the transformer to CPU, so we can just use the accelerator device + transformer_to_kwargs = ( + {"device": accelerator.device} + if args.bnb_quantization_config_path is not None + else {"device": accelerator.device, "dtype": weight_dtype} + ) + transformer.to(**transformer_to_kwargs) + if args.do_fp8_training: + convert_to_float8_training( + transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True) + ) + + if not args.remote_text_encoder: + text_encoder.to(**to_kwargs) + # Initialize a text encoding pipeline and keep it to CPU for now. 
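+        # Only the tokenizer and text encoder are populated; vae, transformer and scheduler
+        # are passed as None so no extra weights are loaded into this pipeline.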
+ text_encoding_pipeline = Flux2Pipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=None, + transformer=None, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=None, + revision=args.revision, + ) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + if args.lora_layers is not None: + target_modules = [layer.strip() for layer in args.lora_layers.split(",")] + else: + target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + + # now we will add new LoRA weights the transformer layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + transformer.add_adapter(transformer_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + modules_to_save = {} + for model in models: + if isinstance(model, type(unwrap_model(transformer))): + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + modules_to_save["transformer"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + Flux2Pipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + **_collate_lora_metadata(modules_to_save), + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(transformer))): + transformer_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. 
More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [transformer] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + params_to_optimize = [transformer_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + if args.aspect_ratio_buckets is not None: + buckets = parse_buckets_string(args.aspect_ratio_buckets) + else: + buckets = [(args.resolution, args.resolution)] + logger.info(f"Using parsed aspect ratio buckets: {buckets}") + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + buckets=buckets, + ) + batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=lambda examples: collate_fn(examples), + num_workers=args.dataloader_num_workers, + ) + + def compute_text_embeddings(prompt, text_encoding_pipeline): + with torch.no_grad(): + prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt( + prompt=prompt, max_sequence_length=args.max_sequence_length + ) + # prompt_embeds = prompt_embeds.to(accelerator.device) + # text_ids = text_ids.to(accelerator.device) + return prompt_embeds, text_ids + + def compute_remote_text_embeddings(prompts: str | list[str]): + import io + + import requests + + if args.hub_token is not None: + hf_token = args.hub_token + else: + from huggingface_hub import get_token + + hf_token = get_token() + if hf_token is None: + raise ValueError( + "No HuggingFace token found. To use the remote text encoder please login using `hf auth login` or provide a token using --hub_token" + ) + + def _encode_single(prompt: str): + response = requests.post( + "https://remote-text-encoder-flux-2.huggingface.co/predict", + json={"prompt": prompt}, + headers={"Authorization": f"Bearer {hf_token}", "Content-Type": "application/json"}, + ) + assert response.status_code == 200, f"{response.status_code=}" + return torch.load(io.BytesIO(response.content)) + + try: + if isinstance(prompts, (list, tuple)): + embeds = [_encode_single(p) for p in prompts] + prompt_embeds = torch.cat(embeds, dim=0).to(accelerator.device) + else: + prompt_embeds = _encode_single(prompts).to(accelerator.device) + + text_ids = Flux2Pipeline._prepare_text_ids(prompt_embeds).to(accelerator.device) + return prompt_embeds, text_ids + + except Exception as e: + raise RuntimeError("Remote text encoder inference failed.") from e + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. 
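+    # The embeddings computed once here are repeated along the batch dimension in the
+    # training loop instead of being re-encoded at every step.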
+ if not train_dataset.custom_instance_prompts: + if args.remote_text_encoder: + instance_prompt_hidden_states, instance_text_ids = compute_remote_text_embeddings(args.instance_prompt) + else: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings( + args.instance_prompt, text_encoding_pipeline + ) + + validation_image = load_image(args.validation_image_path).convert("RGB") + validation_kwargs = {"image": validation_image} + if args.validation_prompt is not None: + if args.remote_text_encoder: + validation_kwargs["prompt_embeds"] = compute_remote_text_embeddings(args.validation_prompt) + else: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + validation_kwargs["prompt_embeds"] = compute_text_embeddings( + args.validation_prompt, text_encoding_pipeline + ) + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + if not train_dataset.custom_instance_prompts: + prompt_embeds = instance_prompt_hidden_states + text_ids = instance_text_ids + + # if cache_latents is set to True, we encode images to latents and store them. + # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided + # we encode them in advance as well. + precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts + if precompute_latents: + prompt_embeds_cache = [] + text_ids_cache = [] + latents_cache = [] + cond_latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + if args.cache_latents: + with offload_models(vae, device=accelerator.device, offload=args.offload): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=vae.dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + batch["cond_pixel_values"] = batch["cond_pixel_values"].to( + accelerator.device, non_blocking=True, dtype=vae.dtype + ) + cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist) + if train_dataset.custom_instance_prompts: + if args.remote_text_encoder: + prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"]) + else: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) + prompt_embeds_cache.append(prompt_embeds) + text_ids_cache.append(text_ids) + + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + if args.cache_latents: + vae = vae.to("cpu") + del vae + + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + if not args.remote_text_encoder: + text_encoding_pipeline = text_encoding_pipeline.to("cpu") + del text_encoder, tokenizer + free_memory() + + # Scheduler and math around the number of training steps. + # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. 
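+    # Both values passed to get_scheduler are multiplied by accelerator.num_processes so the
+    # learning-rate schedule stays consistent across distributed processes; see the PR linked
+    # above for the detailed reasoning.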
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes + if args.max_train_steps is None: + len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) + num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) + num_training_steps_for_scheduler = ( + args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch + ) + else: + num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=num_warmup_steps_for_scheduler, + num_training_steps=num_training_steps_for_scheduler, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + if num_training_steps_for_scheduler != args.max_train_steps: + logger.warning( + f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " + f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " + f"This inconsistency may result in the learning rate scheduler not functioning properly." + ) + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-flux2-image2img-lora" + accelerator.init_trackers(tracker_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
+ ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + prompts = batch["prompts"] + + with accelerator.accumulate(models_to_accumulate): + if train_dataset.custom_instance_prompts: + prompt_embeds = prompt_embeds_cache[step] + text_ids = text_ids_cache[step] + else: + num_repeat_elements = len(prompts) + prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1) + text_ids = text_ids.repeat(num_repeat_elements, 1, 1) + + # Convert images to latent space + if args.cache_latents: + model_input = latents_cache[step].mode() + cond_model_input = cond_latents_cache[step].mode() + else: + with offload_models(vae, device=accelerator.device, offload=args.offload): + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype) + + model_input = vae.encode(pixel_values).latent_dist.mode() + cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode() + + # model_input = Flux2Pipeline._encode_vae_image(pixel_values) + + model_input = Flux2Pipeline._patchify_latents(model_input) + model_input = (model_input - latents_bn_mean) / latents_bn_std + + cond_model_input = Flux2Pipeline._patchify_latents(cond_model_input) + cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std + + model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device) + cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input).to( + device=cond_model_input.device + ) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. 
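+                # For a per-sample sigma in [0, 1], the noisy latent is a linear interpolation
+                # between the clean latent x and the sampled noise z1: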
+ # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # [B, C, H, W] -> [B, H*W, C] + packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input) + packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input) + + # concatenate the model inputs with the cond inputs + packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1) + model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1) + + # handle guidance + guidance = torch.full([1], args.guidance_scale, device=accelerator.device) + guidance = guidance.expand(model_input.shape[0]) + + # Predict the noise residual + model_pred = transformer( + hidden_states=packed_noisy_model_input, # (B, image_seq_len, C) + timestep=timesteps / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=model_input_ids, # B, image_seq_len, 4 + return_dict=False, + )[0] + model_pred = model_pred[:, : packed_noisy_model_input.size(1) :] + + model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids) + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - model_input + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if 
args.validation_prompt is not None and epoch % args.validation_epochs == 0: + # create pipeline + pipeline = Flux2Pipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=None, + tokenizer=None, + transformer=unwrap_model(transformer), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=validation_kwargs, + epoch=epoch, + torch_dtype=weight_dtype, + ) + + del pipeline + free_memory() + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + modules_to_save = {} + transformer = unwrap_model(transformer) + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + modules_to_save["transformer"] = transformer + + Flux2Pipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + **_collate_lora_metadata(modules_to_save), + ) + + images = [] + run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt) + should_run_final_inference = not args.skip_final_inference and run_validation + if should_run_final_inference: + pipeline = Flux2Pipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=validation_kwargs, + epoch=epoch, + is_final_validation=True, + torch_dtype=weight_dtype, + ) + del pipeline + free_memory() + + validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt + save_model_card( + (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + instance_prompt=args.instance_prompt, + validation_prompt=validation_prompt, + repo_folder=args.output_dir, + fp8_training=args.do_fp8_training, + ) + + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) From 0a5a4773807a4548959a7803e7959eee2bfdd7c0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 25 Nov 2025 20:35:04 +0530 Subject: [PATCH 63/63] up --- src/diffusers/loaders/lora_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index a1bb704b0626..4302d145a6c5 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -5095,7 +5095,6 @@ class Flux2LoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],