diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 24420af8e490..a71bc7d864a1 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -349,6 +349,8 @@
title: DiTTransformer2DModel
- local: api/models/easyanimate_transformer3d
title: EasyAnimateTransformer3DModel
+ - local: api/models/flux2_transformer
+ title: Flux2Transformer2DModel
- local: api/models/flux_transformer
title: FluxTransformer2DModel
- local: api/models/hidream_image_transformer
@@ -525,6 +527,8 @@
title: EasyAnimate
- local: api/pipelines/flux
title: Flux
+ - local: api/pipelines/flux2
+ title: Flux2
- local: api/pipelines/control_flux_inpaint
title: FluxControlInpaint
- local: api/pipelines/hidream
diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md
index 8e0326e0c334..9f6ee224e4dd 100644
--- a/docs/source/en/api/loaders/lora.md
+++ b/docs/source/en/api/loaders/lora.md
@@ -30,7 +30,8 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
-- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen)
+- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
+- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
> [!TIP]
@@ -56,6 +57,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin
+## Flux2LoraLoaderMixin
+
+[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin
+
## CogVideoXLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
diff --git a/docs/source/en/api/models/flux2_transformer.md b/docs/source/en/api/models/flux2_transformer.md
new file mode 100644
index 000000000000..c85681d2b011
--- /dev/null
+++ b/docs/source/en/api/models/flux2_transformer.md
@@ -0,0 +1,19 @@
+
+
+# Flux2Transformer2DModel
+
+A Transformer model for image-like data from [Flux2](https://hf.co/black-forest-labs/FLUX.2-dev).
+
+## Flux2Transformer2DModel
+
+[[autodoc]] Flux2Transformer2DModel
diff --git a/docs/source/en/api/pipelines/flux2.md b/docs/source/en/api/pipelines/flux2.md
new file mode 100644
index 000000000000..90eaedc245b7
--- /dev/null
+++ b/docs/source/en/api/pipelines/flux2.md
@@ -0,0 +1,33 @@
+
+
+# Flux2
+
+
+Flux.2 is the latest series of image generation models from Black Forest Labs, succeeding the [Flux.1](./flux.md) series. It is an entirely new model with a new architecture, pretrained from scratch!
+
+Original model checkpoints for Flux.2 can be found [here](https://huggingface.co/black-forest-labs). The original inference code can be found [here](https://github.com/black-forest-labs/flux2-dev).
+
+> [!TIP]
+> Flux.2 can be quite expensive to run on consumer hardware. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux.2 can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more.
+>
+> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
+
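+A minimal text-to-image sketch is shown below. The checkpoint name and call arguments are assumptions based on the gated [FLUX.2-dev repository](https://huggingface.co/black-forest-labs/FLUX.2-dev); refer to the pipeline reference below for the full signature.
+
+```python
+import torch
+from diffusers import Flux2Pipeline
+
+# load in bfloat16 to keep memory usage manageable
+pipe = Flux2Pipeline.from_pretrained(
+    "black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16
+)
+pipe.enable_model_cpu_offload()  # optional: lowers VRAM usage at some speed cost
+
+image = pipe(
+    prompt="a photo of a corgi wearing a tiny wizard hat",
+    generator=torch.Generator("cpu").manual_seed(0),
+).images[0]
+image.save("flux2_out.png")
+```
+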
+## Flux2Pipeline
+
+[[autodoc]] Flux2Pipeline
+ - all
+ - __call__
\ No newline at end of file
diff --git a/examples/dreambooth/README_flux2.md b/examples/dreambooth/README_flux2.md
new file mode 100644
index 000000000000..1a56196da5d7
--- /dev/null
+++ b/examples/dreambooth/README_flux2.md
@@ -0,0 +1,315 @@
+# DreamBooth training example for FLUX.2 [dev]
+
+[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
+
+The `train_dreambooth_lora_flux2.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://github.com/black-forest-labs/flux2-dev).
+
+> [!NOTE]
+> **Memory consumption**
+>
+> Flux.2 can be quite expensive to run on consumer hardware and, as a result, fine-tuning it comes with high memory requirements -
+> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. Below we provide some tips and tricks to reduce memory consumption during training.
+
+> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX:
+> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX2.md)
+> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux2-training)
+
+> [!NOTE]
+> **Gated model**
+>
+> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
+
+```bash
+hf auth login
+```
+
+This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the `examples/dreambooth` folder and run:
+```bash
+pip install -r requirements_flux.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting torch compile mode to True can give dramatic speedups.
+Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
+
+
+### Dog toy example
+
+Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
+
+Let's first download it locally:
+
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./dog"
+snapshot_download(
+ "diffusers/dog-example",
+ local_dir=local_dir, repo_type="dataset",
+ ignore_patterns=".gitattributes",
+)
+```
+
+As mentioned, Flux.2 LoRA training is *very* memory intensive. Here are some memory optimizations (some still experimental) we can use for more memory-efficient training:
+
+## Memory Optimizations
+
+> [!NOTE]
+> Many of these techniques complement each other and can be used together to further reduce memory consumption.
+> However, some techniques may be mutually exclusive, so be sure to check before launching a training run.
+
+### Remote Text Encoder
+Flux.2 uses Mistral Small 3.1 as its text encoder, which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the Hugging Face Inference API.
+This way, the text encoder model is not loaded into memory during training.
+> [!NOTE]
+> To enable remote text encoding you must either be logged in to your Hugging Face account (`hf auth login`) OR pass a token with `--hub_token`.
+### CPU Offloading
+To offload parts of the model to CPU memory, you can use the `--offload` flag. This will offload the VAE and text encoder to CPU memory and only move them to the GPU when needed.
+### Latent Caching
+Pre-encode the training images with the VAE, and then delete it to free up some memory. To enable latent caching, simply pass `--cache_latents`.
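+
+A minimal sketch of what latent caching does conceptually (the training script handles this for you when `--cache_latents` is passed; the `vae.encode(...).latent_dist` API shown here follows `AutoencoderKL`-style VAEs and may differ slightly for Flux.2's VAE):
+
+```python
+import torch
+
+
+@torch.no_grad()
+def cache_latents(vae, dataloader, device, dtype=torch.bfloat16):
+    """Pre-encode all training images once so the VAE can be freed afterwards."""
+    cached = []
+    for batch in dataloader:
+        pixel_values = batch["pixel_values"].to(device, dtype=dtype)
+        # encode to the latent distribution and keep a sample
+        latents = vae.encode(pixel_values).latent_dist.sample()
+        cached.append(latents.cpu())
+    return cached
+
+
+# after caching, the VAE can be deleted to free GPU memory:
+# del vae; torch.cuda.empty_cache()
+```
+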
+### QLoRA: Low Precision Training with Quantization
+Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags:
+- **FP8 training** with `torchao`:
+Enable FP8 training by passing `--do_fp8_training`.
+> [!IMPORTANT]
+> Since we are utilizing FP8 tensor cores, we need CUDA GPUs with compute capability 8.9 or greater.
+> If you're looking for memory-efficient training on relatively older cards, we encourage you to check out other trainers like SimpleTuner, ai-toolkit, etc.
+- **NF4 training** with `bitsandbytes`:
+Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing a quantization config JSON file via `--bnb_quantization_config_path` (e.g. to enable 4-bit NF4 quantization). A sketch of such a config file is shown below.
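+
+As a rough sketch, the JSON file passed via `--bnb_quantization_config_path` would typically contain keyword arguments for the `BitsAndBytesConfig` used to quantize the transformer (the exact keys below are assumptions mirroring the standard `bitsandbytes` 4-bit options):
+
+```python
+import json
+
+# hypothetical config file; keys mirror BitsAndBytesConfig arguments
+bnb_config = {
+    "load_in_4bit": True,
+    "bnb_4bit_quant_type": "nf4",
+    "bnb_4bit_compute_dtype": "bfloat16",
+    "bnb_4bit_use_double_quant": True,
+}
+with open("bnb_nf4_config.json", "w") as f:
+    json.dump(bnb_config, f, indent=2)
+
+# then launch training with --bnb_quantization_config_path="bnb_nf4_config.json"
+```
+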
+### Gradient Checkpointing and Accumulation
+* `--gradient_accumulation_steps` refers to the number of update steps to accumulate before performing a backward/update pass.
+By passing a value > 1 you can reduce the number of backward/update passes and hence also memory requirements. For example, with `--train_batch_size=1` and `--gradient_accumulation_steps=4`, the effective batch size per device is 4 while only one sample is held in memory at a time.
+* With `--gradient_checkpointing` we can save memory by not storing all intermediate activations during the forward pass.
+Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expense of a slower backward pass.
+### 8-bit-Adam Optimizer
+When training with `AdamW` (doesn't apply to `prodigy`), you can pass `--use_8bit_adam` to reduce the memory requirements of training.
+Make sure to install `bitsandbytes` if you want to do so.
+### Image Resolution
+An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution of the input images; all images in the train/validation dataset are resized to this resolution.
+Note that by default images are resized to a resolution of 512, which is worth keeping in mind if you're accustomed to training at higher resolutions.
+### Precision of saved LoRA layers
+By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g., when mixed-precision training is enabled with `--mixed_precision="bf16"`, the final finetuned layers will be saved in `torch.bfloat16` as well.
+This reduces memory requirements significantly without a significant quality loss. Note that if you do wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.
+
+
+```bash
+export MODEL_NAME="black-forest-labs/FLUX.2-dev"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="trained-flux2"
+
+accelerate launch train_dreambooth_lora_flux2.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --do_fp8_training \
+ --gradient_checkpointing \
+ --remote_text_encoder \
+ --cache_latents \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --guidance_scale=1 \
+ --use_8bit_adam \
+ --gradient_accumulation_steps=4 \
+ --optimizer="adamW" \
+ --learning_rate=1e-4 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=100 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed="0" \
+ --push_to_hub
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login` before training if you haven't done it before.
+* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+> [!NOTE]
+> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit for the text encoder. Note that a higher limit will use more resources and may slow down training in some cases.
+
+## LoRA + DreamBooth
+
+[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning-like performance with only a fraction of learnable parameters.
+
+Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
+
+### Prodigy Optimizer
+Prodigy is an adaptive optimizer that dynamically adjusts the learning rate of the learned parameters based on past gradients, allowing for more efficient convergence.
+By using Prodigy we can "eliminate" the need for manual learning rate tuning. Read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
+
+To use Prodigy, first make sure to install the `prodigyopt` library (`pip install prodigyopt`), and then specify:
+```bash
+--optimizer="prodigy"
+```
+> [!TIP]
+> When using Prodigy it's generally good practice to set `--learning_rate=1.0`.
+
+To perform DreamBooth with LoRA, run:
+
+```bash
+export MODEL_NAME="black-forest-labs/FLUX.2-dev"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="trained-flux2-lora"
+
+accelerate launch train_dreambooth_lora_flux2.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --do_fp8_training \
+ --gradient_checkpointing \
+ --remote_text_encoder \
+ --cache_latents \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --guidance_scale=1 \
+ --gradient_accumulation_steps=4 \
+ --optimizer="prodigy" \
+ --learning_rate=1. \
+ --report_to="wandb" \
+ --lr_scheduler="constant_with_warmup" \
+ --lr_warmup_steps=100 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed="0" \
+ --push_to_hub
+```
+
+### LoRA Rank and Alpha
+Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
+- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
+- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / rank` (see the short sketch after the tip below).
+- `lora_alpha` vs. `rank`:
+This ratio dictates the LoRA's effective strength:
+  - `lora_alpha == rank`: scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
+  - `lora_alpha < rank`: scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
+  - `lora_alpha > rank`: scaling factor > 1. Amplifies the LoRA's impact. Allows a lower-rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
+
+> [!TIP]
+> A common starting point is to set `lora_alpha` equal to `rank`.
+> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
+> to give the LoRA updates more influence without increasing parameter count.
+> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
+> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
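+
+As a quick illustration with PEFT (a sketch only; the training script builds its own `LoraConfig` internally, and the `target_modules` names here are just examples):
+
+```python
+from peft import LoraConfig
+
+rank = 16
+lora_alpha = 32  # scaling factor = lora_alpha / rank = 2.0, i.e. the LoRA update is amplified
+
+lora_config = LoraConfig(
+    r=rank,
+    lora_alpha=lora_alpha,
+    init_lora_weights="gaussian",
+    target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
+)
+```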
+
+### Target Modules
+When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations with the prompts that describe them.
+More recently, SOTA text-to-image diffusion models replaced the UNet with a diffusion Transformer (DiT). With this change, we may also want to explore
+applying LoRA training to different types of layers and blocks. To allow more flexibility and control over the targeted modules, we added `--lora_layers`, in which you can specify the exact modules for LoRA training as a comma-separated string (see the short sketch after the notes below). Here are some examples of target modules you can provide:
+- for attention-only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
+- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
+- to train the same modules as in the ostris ai-toolkit / replicate trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear,norm1.linear,norm.linear,proj_mlp,proj_out"`
+> [!NOTE]
+> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma-separated string:
+> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
+> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
+
+> [!NOTE]
+> Keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
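+
+To make the mapping concrete, a comma-separated `--lora_layers` value is essentially split into a list of PEFT target modules, along these lines (a sketch; the training script does the equivalent internally):
+
+```python
+lora_layers = "attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"
+
+# split the comma-separated string into individual module names
+target_modules = [layer.strip() for layer in lora_layers.split(",")]
+print(target_modules)
+# ['attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0']
+```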
+
+
+
+## Training Image-to-Image
+
+Flux.2 lets us perform image editing as well as image generation. We provide a simple script for image-to-image (I2I) LoRA fine-tuning in [train_dreambooth_lora_flux2_img2img.py](./train_dreambooth_lora_flux2_img2img.py), which supports both T2I and I2I. The optimizations discussed above apply to this script, too.
+
+**Important**
+
+To make sure you can successfully run the latest version of the image-to-image example script, we highly recommend installing from source. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+To start, you must have a dataset containing triplets:
+
+* Condition image - the input image to be transformed.
+* Target image - the desired output image after transformation.
+* Instruction - a text prompt describing the transformation from the condition image to the target image.
+
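+As a sketch, you can inspect such a dataset with 🤗 Datasets to confirm it exposes these three columns (the dataset and column names here are taken from the example command below):
+
+```python
+from datasets import load_dataset
+
+ds = load_dataset("kontext-community/relighting", split="train")
+print(ds.column_names)
+# expect something like: ['file_name', 'output', 'instruction']
+print(ds[0]["instruction"])  # the text prompt describing the transformation
+```
+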
+[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using such a dataset, you can use the command below to launch training:
+
+```bash
+accelerate launch train_dreambooth_lora_flux2_img2img.py \
+ --pretrained_model_name_or_path=black-forest-labs/FLUX.2-dev \
+ --output_dir="flux2-i2i" \
+ --dataset_name="kontext-community/relighting" \
+ --image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
+ --do_fp8_training \
+ --gradient_checkpointing \
+ --remote_text_encoder \
+ --cache_latents \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --guidance_scale=1 \
+ --gradient_accumulation_steps=4 \
+ --optimizer="adamw" \
+ --use_8bit_adam \
+ --learning_rate=1e-4 \
+ --lr_scheduler="constant_with_warmup" \
+ --lr_warmup_steps=200 \
+ --max_train_steps=1000 \
+  --rank=16 \
+ --seed="0"
+```
+
+More generally, when performing I2I fine-tuning, we expect you to:
+
+* Have a dataset structured like `kontext-community/relighting`.
+* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training
+
+### Misc notes
+
+* By default, we use `mode` as the value of the `--vae_encode_mode` argument. This is because Kontext uses the `mode()` of the distribution predicted by the VAE instead of sampling from it.
+### Aspect Ratio Bucketing
+We've added aspect ratio bucketing support, which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.
+
+To enable aspect ratio bucketing, pass the `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
+
+`--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"`
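+
+Under the hood, each training image is assigned to the bucket whose aspect ratio is closest to its own, roughly like this (a simplified sketch; the training script uses its own bucketing helpers):
+
+```python
+def parse_buckets(buckets_str):
+    """Parse 'h1,w1;h2,w2;...' into a list of (height, width) tuples."""
+    return [tuple(int(v) for v in pair.split(",")) for pair in buckets_str.split(";")]
+
+
+def nearest_bucket(height, width, buckets):
+    """Pick the bucket whose aspect ratio is closest to the image's aspect ratio."""
+    aspect = height / width
+    return min(buckets, key=lambda hw: abs(hw[0] / hw[1] - aspect))
+
+
+buckets = parse_buckets("672,1568;1024,1024;1568,672")
+print(nearest_bucket(900, 1400, buckets))  # -> (672, 1568), a wide landscape bucket
+```
+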
+Since Flux.2 fine-tuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗
diff --git a/examples/dreambooth/test_dreambooth_lora_flux2.py b/examples/dreambooth/test_dreambooth_lora_flux2.py
new file mode 100644
index 000000000000..80a0b502f9a2
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_lora_flux2.py
@@ -0,0 +1,262 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRAFlux2(ExamplesTestsAccelerate):
+ instance_data_dir = "docs/source/en/imgs"
+ instance_prompt = "dog"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2"
+ script_path = "examples/dreambooth/train_dreambooth_lora_flux2.py"
+ transformer_layer_type = "single_transformer_blocks.0.attn.to_qkv_mlp_proj"
+
+ def test_dreambooth_lora_flux2(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --max_sequence_length 8
+ --text_encoder_out_layers 1
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_latent_caching(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --cache_latents
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --max_sequence_length 8
+ --text_encoder_out_layers 1
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_layers(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --cache_latents
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lora_layers {self.transformer_layer_type}
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --max_sequence_length 8
+ --text_encoder_out_layers 1
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"transformer"` in their names. In this test, only params of
+            # transformer.single_transformer_blocks.0.attn.to_qkv_mlp_proj should be in the state dict
+ starts_with_transformer = all(
+ key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys()
+ )
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --max_sequence_length 8
+ --checkpointing_steps=2
+ --text_encoder_out_layers 1
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=4
+ --checkpointing_steps=2
+ --max_sequence_length 8
+ --text_encoder_out_layers 1
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
+
+ resume_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ --max_sequence_length 8
+ --text_encoder_out_layers 1
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+
+ def test_dreambooth_lora_with_metadata(self):
+ # Use a `lora_alpha` that is different from `rank`.
+ lora_alpha = 8
+ rank = 4
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --lora_alpha={lora_alpha}
+ --rank={rank}
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --max_sequence_length 8
+ --text_encoder_out_layers 1
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+ self.assertTrue(os.path.isfile(state_dict_file))
+
+ # Check if the metadata was properly serialized.
+ with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
+ metadata = f.metadata() or {}
+
+ metadata.pop("format", None)
+ raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+ if raw:
+ raw = json.loads(raw)
+
+ loaded_lora_alpha = raw["transformer.lora_alpha"]
+ self.assertTrue(loaded_lora_alpha == lora_alpha)
+ loaded_lora_rank = raw["transformer.r"]
+ self.assertTrue(loaded_lora_rank == rank)
diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py
new file mode 100644
index 000000000000..733abe16d2eb
--- /dev/null
+++ b/examples/dreambooth/train_dreambooth_lora_flux2.py
@@ -0,0 +1,1914 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# /// script
+# dependencies = [
+# "diffusers @ git+https://github.com/huggingface/diffusers.git",
+# "torch>=2.0.0",
+# "accelerate>=0.31.0",
+# "transformers>=4.41.2",
+# "ftfy",
+# "tensorboard",
+# "Jinja2",
+# "peft>=0.11.1",
+# "sentencepiece",
+# "torchvision",
+# "datasets",
+# "bitsandbytes",
+# "prodigyopt",
+# ]
+# ///
+
+import argparse
+import copy
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from contextlib import nullcontext
+from pathlib import Path
+
+import numpy as np
+import torch
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torch.utils.data.sampler import BatchSampler
+from torchvision import transforms
+from torchvision.transforms import functional as TF
+from tqdm.auto import tqdm
+from transformers import Mistral3ForConditionalGeneration, PixtralProcessor
+
+import diffusers
+from diffusers import (
+ AutoencoderKLFlux2,
+ BitsAndBytesConfig,
+ FlowMatchEulerDiscreteScheduler,
+ Flux2Pipeline,
+ Flux2Transformer2DModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ _collate_lora_metadata,
+ cast_training_params,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+ find_nearest_bucket,
+ free_memory,
+ offload_models,
+ parse_buckets_string,
+)
+from diffusers.utils import (
+ check_min_version,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.36.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model: str = None,
+ instance_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+ quant_training=None,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# Flux2 DreamBooth LoRA - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} DreamBooth LoRA weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux2.md).
+
+Quant training? {quant_training}
+
+## Trigger words
+
+You should use `{instance_prompt}` to trigger the image generation.
+
+## Download model
+
+[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16).to('cuda')
+pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.2/blob/main/LICENSE.md).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "flux2",
+ "flux2-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ torch_dtype,
+ is_final_validation=False,
+):
+ args.num_validation_images = args.num_validation_images if args.num_validation_images else 1
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline = pipeline.to(dtype=torch_dtype)
+ pipeline.enable_model_cpu_offload()
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+ autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
+
+ images = []
+ for _ in range(args.num_validation_images):
+ with autocast_ctx:
+ image = pipeline(
+ prompt_embeds=pipeline_args["prompt_embeds"],
+ generator=generator,
+ ).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ free_memory()
+
+ return images
+
+
+def module_filter_fn(mod: torch.nn.Module, fqn: str):
+ # don't convert the output module
+ if fqn == "proj_out":
+ return False
+ # don't convert linear modules with weight dimensions not divisible by 16
+ if isinstance(mod, torch.nn.Linear):
+ if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
+ return False
+ return True
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--bnb_quantization_config_path",
+ type=str,
+ default=None,
+ help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.",
+ )
+ parser.add_argument(
+ "--do_fp8_training",
+ action="store_true",
+ help="if we are doing FP8 training.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=512,
+        help="Maximum sequence length to use with the text encoder",
+ )
+ parser.add_argument(
+ "--text_encoder_out_layers",
+ type=int,
+ nargs="+",
+ default=[10, 20, 30],
+ help="Text encoder hidden layers to compute the final text embeddings.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--skip_final_inference",
+ default=False,
+ action="store_true",
+ help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.",
+ )
+ parser.add_argument(
+ "--final_validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ default=4,
+ help="LoRA alpha to be used for additional scaling.",
+ )
+ parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
+
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="flux-dreambooth-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--aspect_ratio_buckets",
+ type=str,
+ default=None,
+ help=(
+ "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. "
+ "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'"
+ "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored."
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=3.5,
+ help="the FLUX.1 dev variant is a guidance distilled model",
+ )
+
+ parser.add_argument(
+ "--text_encoder_lr",
+ type=float,
+ default=5e-6,
+ help="Text encoder learning rate to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--lora_layers",
+ type=str,
+ default=None,
+ help=(
+ 'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only'
+ ),
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cache_latents",
+ action="store_true",
+ default=False,
+ help="Cache the VAE latents",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--upcast_before_saving",
+ action="store_true",
+ default=False,
+ help=(
+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+ "Defaults to precision dtype used for training to save memory"
+ ),
+ )
+ parser.add_argument(
+ "--offload",
+ action="store_true",
+ help="Whether to offload the VAE and the text encoder to CPU when they are not used.",
+ )
+ parser.add_argument(
+ "--remote_text_encoder",
+ action="store_true",
+ help="Whether to use a remote text encoder. This means the text encoder will not be loaded locally and instead, the prompt embeddings will be computed remotely using the HuggingFace Inference API.",
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+ if args.do_fp8_training and args.bnb_quantization_config_path:
+ raise ValueError("Both `do_fp8_training` and `bnb_quantization_config_path` cannot be passed.")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ class_data_root=None,
+ class_num=None,
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ buckets=None,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+
+ self.buckets = buckets
+
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+                raise ValueError("Instance images root doesn't exist.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ for img in instance_images:
+ self.instance_images.extend(itertools.repeat(img, repeats))
+
+ self.pixel_values = []
+ for i, image in enumerate(self.instance_images):
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ width, height = image.size
+
+ # Find the closest bucket
+ bucket_idx = find_nearest_bucket(height, width, self.buckets)
+ target_height, target_width = self.buckets[bucket_idx]
+ self.size = (target_height, target_width)
+
+ # based on the bucket assignment, define the transformations
+ image = self.train_transform(
+ image,
+ size=self.size,
+ center_crop=args.center_crop,
+ random_flip=args.random_flip,
+ )
+ self.pixel_values.append((image, bucket_idx))
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
+ example["instance_images"] = instance_image
+ example["bucket_idx"] = bucket_idx
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+        else:  # no custom prompts were provided, fall back to the shared instance prompt
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt"] = self.class_prompt
+
+ return example
+
+ def train_transform(self, image, size=(224, 224), center_crop=False, random_flip=False):
+ # 1. Resize (deterministic)
+ resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ image = resize(image)
+
+ # 2. Crop: either center or SAME random crop
+ if center_crop:
+ crop = transforms.CenterCrop(size)
+ image = crop(image)
+ else:
+ # get_params returns (i, j, h, w)
+ i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)
+ image = TF.crop(image, i, j, h, w)
+
+ # 3. Random horizontal flip with the SAME coin flip
+ if random_flip:
+ do_flip = random.random() < 0.5
+ if do_flip:
+ image = TF.hflip(image)
+
+ # 4. ToTensor + Normalize (deterministic)
+ to_tensor = transforms.ToTensor()
+ normalize = transforms.Normalize([0.5], [0.5])
+ image = normalize(to_tensor(image))
+
+ return image
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
+ return batch
+
+
+class BucketBatchSampler(BatchSampler):
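+    """Groups dataset indices by resolution bucket and pre-builds batches so each batch shares a single spatial size."""
+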
+ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
+ if not isinstance(batch_size, int) or batch_size <= 0:
+ raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
+ if not isinstance(drop_last, bool):
+ raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last))
+
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+
+ # Group indices by bucket
+ self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
+ for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):
+ self.bucket_indices[bucket_idx].append(idx)
+
+ self.sampler_len = 0
+ self.batches = []
+
+ # Pre-generate batches for each bucket
+ for indices_in_bucket in self.bucket_indices:
+ # Shuffle indices within the bucket
+ random.shuffle(indices_in_bucket)
+ # Create batches
+ for i in range(0, len(indices_in_bucket), self.batch_size):
+ batch = indices_in_bucket[i : i + self.batch_size]
+ if len(batch) < self.batch_size and self.drop_last:
+ continue # Skip partial batch if drop_last is True
+ self.batches.append(batch)
+ self.sampler_len += 1 # Count the number of batches
+
+ def __iter__(self):
+ # Shuffle the order of the batches each epoch
+ random.shuffle(self.batches)
+ for batch in self.batches:
+ yield batch
+
+ def __len__(self):
+ return self.sampler_len
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `hf auth login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+ if args.do_fp8_training:
+ from torchao.float8 import Float8LinearConfig, convert_to_float8_training
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+ torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+
+ pipeline = Flux2Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
+ images = pipeline(prompt=example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ free_memory()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+    # Load the tokenizer
+ tokenizer = PixtralProcessor.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="scheduler",
+ revision=args.revision,
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ vae = AutoencoderKLFlux2.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
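+    # Cache the VAE's BatchNorm running statistics; latents are normalized with them before being passed to the transformer.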
+ latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(accelerator.device)
+ latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
+ accelerator.device
+ )
+
+ quantization_config = None
+ if args.bnb_quantization_config_path is not None:
+ with open(args.bnb_quantization_config_path, "r") as f:
+ config_kwargs = json.load(f)
+ if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]:
+ config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype
+ quantization_config = BitsAndBytesConfig(**config_kwargs)
+
+ transformer = Flux2Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ quantization_config=quantization_config,
+ torch_dtype=weight_dtype,
+ )
+ if args.bnb_quantization_config_path is not None:
+ transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
+
+ if not args.remote_text_encoder:
+ text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder.requires_grad_(False)
+
+ # We only train the additional adapter LoRA layers
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+
+ if args.enable_npu_flash_attention:
+ if is_torch_npu_available():
+ logger.info("npu flash attention enabled.")
+ transformer.set_attention_backend("_native_npu")
+ else:
+            raise ValueError("NPU flash attention requires the torch_npu extension and is supported only on NPU devices.")
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ to_kwargs = {"dtype": weight_dtype, "device": accelerator.device} if not args.offload else {"dtype": weight_dtype}
+ # flux vae is stable in bf16 so load it in weight_dtype to reduce memory
+ vae.to(**to_kwargs)
+ # we never offload the transformer to CPU, so we can just use the accelerator device
+ transformer_to_kwargs = (
+ {"device": accelerator.device}
+ if args.bnb_quantization_config_path is not None
+ else {"device": accelerator.device, "dtype": weight_dtype}
+ )
+ transformer.to(**transformer_to_kwargs)
+ if args.do_fp8_training:
+ convert_to_float8_training(
+ transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
+ )
+
+ if not args.remote_text_encoder:
+ text_encoder.to(**to_kwargs)
+ # Initialize a text encoding pipeline and keep it to CPU for now.
+ text_encoding_pipeline = Flux2Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=None,
+ transformer=None,
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ scheduler=None,
+ revision=args.revision,
+ )
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+
+ if args.lora_layers is not None:
+ target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+ else:
+ target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+
+    # now we will add new LoRA weights to the transformer layers
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout,
+ init_lora_weights="gaussian",
+ target_modules=target_modules,
+ )
+ transformer.add_adapter(transformer_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+ modules_to_save = {}
+ for model in models:
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+ modules_to_save["transformer"] = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ Flux2Pipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ **_collate_lora_metadata(modules_to_save),
+ )
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_ = model
+ else:
+                raise ValueError(f"unexpected load model: {model.__class__}")
+
+ lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
+
+ transformer_state_dict = {
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ }
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [transformer_]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [transformer]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+
+ # Optimization parameters
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+ params_to_optimize = [transformer_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. "
+            "Defaulting to adamW."
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ if args.aspect_ratio_buckets is not None:
+ buckets = parse_buckets_string(args.aspect_ratio_buckets)
+ else:
+ buckets = [(args.resolution, args.resolution)]
+    logger.info(f"Using aspect ratio buckets: {buckets}")
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_prompt=args.class_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ buckets=buckets,
+ )
+ batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_sampler=batch_sampler,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ def compute_text_embeddings(prompt, text_encoding_pipeline):
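+        # Encode the prompt with the local text encoding pipeline; the embeddings are reused, so no gradients are needed.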
+ with torch.no_grad():
+ prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+ prompt=prompt,
+ max_sequence_length=args.max_sequence_length,
+ text_encoder_out_layers=args.text_encoder_out_layers,
+ )
+ return prompt_embeds, text_ids
+
+ def compute_remote_text_embeddings(prompts):
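+        # Encode the prompt via the hosted text-encoder endpoint instead of loading the text encoder locally.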
+ import io
+
+ import requests
+
+ if args.hub_token is not None:
+ hf_token = args.hub_token
+ else:
+ from huggingface_hub import get_token
+
+ hf_token = get_token()
+ if hf_token is None:
+ raise ValueError(
+                    "No Hugging Face token found. To use the remote text encoder, please log in using `hf auth login` or provide a token using --hub_token"
+ )
+
+ def _encode_single(prompt: str):
+ response = requests.post(
+ "https://remote-text-encoder-flux-2.huggingface.co/predict",
+ json={"prompt": prompt},
+ headers={"Authorization": f"Bearer {hf_token}", "Content-Type": "application/json"},
+ )
+ assert response.status_code == 200, f"{response.status_code=}"
+ return torch.load(io.BytesIO(response.content))
+
+ try:
+ if isinstance(prompts, (list, tuple)):
+ embeds = [_encode_single(p) for p in prompts]
+ prompt_embeds = torch.cat(embeds, dim=0)
+ else:
+ prompt_embeds = _encode_single(prompts)
+
+ text_ids = Flux2Pipeline._prepare_text_ids(prompt_embeds).to(accelerator.device)
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ return prompt_embeds, text_ids
+
+ except Exception as e:
+ raise RuntimeError("Remote text encoder inference failed.") from e
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if not train_dataset.custom_instance_prompts:
+ if args.remote_text_encoder:
+ instance_prompt_hidden_states, instance_text_ids = compute_remote_text_embeddings(args.instance_prompt)
+ else:
+ with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+ instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings(
+ args.instance_prompt, text_encoding_pipeline
+ )
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ if args.remote_text_encoder:
+ class_prompt_hidden_states, class_text_ids = compute_remote_text_embeddings(args.class_prompt)
+ else:
+ with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+ class_prompt_hidden_states, class_text_ids = compute_text_embeddings(
+ args.class_prompt, text_encoding_pipeline
+ )
+ validation_embeddings = {}
+ if args.validation_prompt is not None:
+ if args.remote_text_encoder:
+ (validation_embeddings["prompt_embeds"], validation_embeddings["text_ids"]) = (
+ compute_remote_text_embeddings(args.validation_prompt)
+ )
+ else:
+ with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+ (validation_embeddings["prompt_embeds"], validation_embeddings["text_ids"]) = compute_text_embeddings(
+ args.validation_prompt, text_encoding_pipeline
+ )
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+ if not train_dataset.custom_instance_prompts:
+ prompt_embeds = instance_prompt_hidden_states
+ text_ids = instance_text_ids
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ text_ids = torch.cat([text_ids, class_text_ids], dim=0)
+
+ # if cache_latents is set to True, we encode images to latents and store them.
+ # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
+ # we encode them in advance as well.
+ precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
+ if precompute_latents:
+ prompt_embeds_cache = []
+ text_ids_cache = []
+ latents_cache = []
+ for batch in tqdm(train_dataloader, desc="Caching latents"):
+ with torch.no_grad():
+ if args.cache_latents:
+ with offload_models(vae, device=accelerator.device, offload=args.offload):
+ batch["pixel_values"] = batch["pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=vae.dtype
+ )
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+ if train_dataset.custom_instance_prompts:
+ if args.remote_text_encoder:
+ prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
+ else:
+ with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+ prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
+ prompt_embeds_cache.append(prompt_embeds)
+ text_ids_cache.append(text_ids)
+
+ # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
+ if args.cache_latents:
+ vae = vae.to("cpu")
+ del vae
+
+ # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
+ if not args.remote_text_encoder:
+ text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+ del text_encoder, tokenizer
+ free_memory()
+
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "dreambooth-flux2-lora"
+ args_cp = vars(args).copy()
+ args_cp["text_encoder_out_layers"] = str(args_cp["text_encoder_out_layers"])
+ accelerator.init_trackers(tracker_name, config=args_cp)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+            # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
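+        # Look up the sigma for each sampled timestep and reshape it to broadcast over the latent dimensions.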
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+ prompts = batch["prompts"]
+
+ with accelerator.accumulate(models_to_accumulate):
+ if train_dataset.custom_instance_prompts:
+ prompt_embeds = prompt_embeds_cache[step]
+ text_ids = text_ids_cache[step]
+ else:
+ num_repeat_elements = len(prompts)
+ prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
+ text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
+
+ # Convert images to latent space
+ if args.cache_latents:
+ model_input = latents_cache[step].mode()
+ else:
+ with offload_models(vae, device=accelerator.device, offload=args.offload):
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent_dist.mode()
+
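+                # Patchify latents into the transformer's sequence layout and normalize channels with the VAE's BatchNorm statistics.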
+ model_input = Flux2Pipeline._patchify_latents(model_input)
+ model_input = (model_input - latents_bn_mean) / latents_bn_std
+
+ model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device)
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+ # Add noise according to flow matching.
+ # zt = (1 - texp) * x + texp * z1
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+
+ # [B, C, H, W] -> [B, H*W, C]
+ packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input)
+
+ # handle guidance
+ guidance = torch.full([1], args.guidance_scale, device=accelerator.device)
+ guidance = guidance.expand(model_input.shape[0])
+
+ # Predict the noise residual
+ model_pred = transformer(
+ hidden_states=packed_noisy_model_input, # (B, image_seq_len, C)
+ timestep=timesteps / 1000,
+ guidance=guidance,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids, # B, text_seq_len, 4
+ img_ids=model_input_ids, # B, image_seq_len, 4
+ return_dict=False,
+ )[0]
+                model_pred = model_pred[:, : packed_noisy_model_input.size(1)]
+
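+                # Unpack the predicted sequence back into image-latent space using the latent position ids.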
+ model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids)
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow matching loss
+ target = noise - model_input
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ prior_loss = torch.mean(
+ (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+ target_prior.shape[0], -1
+ ),
+ 1,
+ )
+ prior_loss = prior_loss.mean()
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = transformer.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ pipeline = Flux2Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=None,
+ tokenizer=None,
+ transformer=unwrap_model(transformer),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=validation_embeddings,
+ epoch=epoch,
+ torch_dtype=weight_dtype,
+ )
+
+ del pipeline
+ free_memory()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ modules_to_save = {}
+ transformer = unwrap_model(transformer)
+ if args.bnb_quantization_config_path is None:
+ if args.upcast_before_saving:
+ transformer.to(torch.float32)
+ else:
+ transformer = transformer.to(weight_dtype)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+ modules_to_save["transformer"] = transformer
+
+ Flux2Pipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ **_collate_lora_metadata(modules_to_save),
+ )
+
+ images = []
+ run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt)
+ should_run_final_inference = not args.skip_final_inference and run_validation
+ if should_run_final_inference:
+ pipeline = Flux2Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=validation_embeddings,
+ epoch=epoch,
+ is_final_validation=True,
+ torch_dtype=weight_dtype,
+ )
+ images = None
+ del pipeline
+ free_memory()
+
+ validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
+ quant_training = None
+ if args.do_fp8_training:
+ quant_training = "FP8 TorchAO"
+ elif args.bnb_quantization_config_path:
+ quant_training = "BitsandBytes"
+ save_model_card(
+ (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=validation_prompt,
+ repo_folder=args.output_dir,
+ quant_training=quant_training,
+ )
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
new file mode 100644
index 000000000000..32bce9531b71
--- /dev/null
+++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
@@ -0,0 +1,1831 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# /// script
+# dependencies = [
+# "diffusers @ git+https://github.com/huggingface/diffusers.git",
+# "torch>=2.0.0",
+# "accelerate>=0.31.0",
+# "transformers>=4.41.2",
+# "ftfy",
+# "tensorboard",
+# "Jinja2",
+# "peft>=0.11.1",
+# "sentencepiece",
+# "torchvision",
+# "datasets",
+# "bitsandbytes",
+# "prodigyopt",
+# ]
+# ///
+
+import argparse
+import copy
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+
+import numpy as np
+import torch
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torch.utils.data.sampler import BatchSampler
+from torchvision import transforms
+from torchvision.transforms import functional as TF
+from tqdm.auto import tqdm
+from transformers import Mistral3ForConditionalGeneration, PixtralProcessor
+
+import diffusers
+from diffusers import (
+ AutoencoderKLFlux2,
+ BitsAndBytesConfig,
+ FlowMatchEulerDiscreteScheduler,
+ Flux2Pipeline,
+ Flux2Transformer2DModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
+from diffusers.training_utils import (
+ _collate_lora_metadata,
+ cast_training_params,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+ find_nearest_bucket,
+ free_memory,
+ offload_models,
+ parse_buckets_string,
+)
+from diffusers.utils import (
+ check_min_version,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+ load_image,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.36.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model: str = None,
+ instance_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+ fp8_training=False,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# Flux2 DreamBooth LoRA - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} DreamBooth LoRA weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux2.md).
+
+FP8 training? {fp8_training}
+
+## Trigger words
+
+You should use `{instance_prompt}` to trigger the image generation.
+
+## Download model
+
+[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.2", torch_dtype=torch.bfloat16).to('cuda')
+pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.2/blob/main/LICENSE.md).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "flux2",
+ "flux2-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ torch_dtype,
+ is_final_validation=False,
+):
+ args.num_validation_images = args.num_validation_images if args.num_validation_images else 1
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline = pipeline.to(dtype=torch_dtype)
+ pipeline.enable_model_cpu_offload()
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+ autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
+
+ images = []
+ for _ in range(args.num_validation_images):
+ with autocast_ctx:
+ image = pipeline(
+ image=pipeline_args["image"],
+ prompt_embeds=pipeline_args["prompt_embeds"],
+ generator=generator,
+ ).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ free_memory()
+
+ return images
+
+
+def module_filter_fn(mod: torch.nn.Module, fqn: str):
+ # don't convert the output module
+ if fqn == "proj_out":
+ return False
+ # don't convert linear modules with weight dimensions not divisible by 16
+ if isinstance(mod, torch.nn.Linear):
+ if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
+ return False
+ return True
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--bnb_quantization_config_path",
+ type=str,
+ default=None,
+ help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.",
+ )
+ parser.add_argument(
+ "--do_fp8_training",
+ action="store_true",
+ help="if we are doing FP8 training.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--cond_image_column",
+ type=str,
+ default=None,
+ help="Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=512,
+        help="Maximum sequence length to use with the text encoder",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ help="path to an image that is used during validation as the condition image to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--skip_final_inference",
+ default=False,
+ action="store_true",
+ help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.",
+ )
+ parser.add_argument(
+ "--final_validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ default=4,
+ help="LoRA alpha to be used for additional scaling.",
+ )
+ parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
+
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="flux-dreambooth-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--aspect_ratio_buckets",
+ type=str,
+ default=None,
+ help=(
+ "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. "
+ "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'"
+ "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored."
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=3.5,
+        help="the FLUX.2 dev variant is a guidance distilled model",
+ )
+
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--lora_layers",
+ type=str,
+ default=None,
+ help=(
+ 'The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated string. E.g., "to_k,to_q,to_v,to_out.0" will result in LoRA training of the attention layers only.'
+ ),
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cache_latents",
+ action="store_true",
+ default=False,
+ help="Cache the VAE latents",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--upcast_before_saving",
+ action="store_true",
+ default=False,
+ help=(
+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+ "Defaults to precision dtype used for training to save memory"
+ ),
+ )
+ parser.add_argument(
+ "--offload",
+ action="store_true",
+ help="Whether to offload the VAE and the text encoder to CPU when they are not used.",
+ )
+ parser.add_argument(
+ "--remote_text_encoder",
+ action="store_true",
+ help="Whether to use a remote text encoder. This means the text encoder will not be loaded locally and instead, the prompt embeddings will be computed remotely using the HuggingFace Inference API.",
+ )
+
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.cond_image_column is None:
+ raise ValueError(
+ "you must provide --cond_image_column for image-to-image training. Otherwise please see Flux2 text-to-image training example."
+ )
+ else:
+ assert args.image_column is not None
+ assert args.caption_column is not None
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ buckets=None,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+
+ self.buckets = buckets
+
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # Get the column names for input/target.
+ if args.cond_image_column is not None and args.cond_image_column not in column_names:
+ raise ValueError(
+ f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+ cond_images = None
+ cond_image_column = args.cond_image_column
+ if cond_image_column is not None:
+ cond_images = [dataset["train"][i][cond_image_column] for i in range(len(dataset["train"]))]
+ assert len(instance_images) == len(cond_images)
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ self.cond_images = []
+ for i, img in enumerate(instance_images):
+ self.instance_images.extend(itertools.repeat(img, repeats))
+ if args.dataset_name is not None and cond_images is not None:
+ self.cond_images.extend(itertools.repeat(cond_images[i], repeats))
+
+ self.pixel_values = []
+ self.cond_pixel_values = []
+ for i, image in enumerate(self.instance_images):
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ dest_image = None
+ if self.cond_images: # todo: take care of max area for buckets
+ dest_image = self.cond_images[i]
+ # apply EXIF orientation and RGB conversion while dest_image is still a PIL image,
+ # before the numeric preprocessing below
+ dest_image = exif_transpose(dest_image)
+ if not dest_image.mode == "RGB":
+ dest_image = dest_image.convert("RGB")
+ image_width, image_height = dest_image.size
+ if image_width * image_height > 1024 * 1024:
+ dest_image = Flux2ImageProcessor.image_processor._resize_to_target_area(dest_image, 1024 * 1024)
+ image_width, image_height = dest_image.size
+
+ multiple_of = 2 ** (4 - 1) # should be 2 ** (len(vae.config.block_out_channels) - 1); hardcoded for now
+ image_width = (image_width // multiple_of) * multiple_of
+ image_height = (image_height // multiple_of) * multiple_of
+ dest_image = Flux2ImageProcessor.image_processor.preprocess(
+ dest_image, height=image_height, width=image_width, resize_mode="crop"
+ )
+
+ width, height = image.size
+
+ # Find the closest bucket
+ bucket_idx = find_nearest_bucket(height, width, self.buckets)
+ target_height, target_width = self.buckets[bucket_idx]
+ self.size = (target_height, target_width)
+
+ # based on the bucket assignment, define the transformations
+ image, dest_image = self.paired_transform(
+ image,
+ dest_image=dest_image,
+ size=self.size,
+ center_crop=args.center_crop,
+ random_flip=args.random_flip,
+ )
+ self.pixel_values.append((image, bucket_idx))
+ if dest_image is not None:
+ self.cond_pixel_values.append((dest_image, bucket_idx))
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
+ example["instance_images"] = instance_image
+ example["bucket_idx"] = bucket_idx
+ if self.cond_pixel_values:
+ dest_image, _ = self.cond_pixel_values[index % self.num_instance_images]
+ example["cond_images"] = dest_image
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ else: # no custom prompts were provided; fall back to the shared instance prompt
+ example["instance_prompt"] = self.instance_prompt
+
+ return example
+
+ def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False):
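+ """
+ Apply identical resize/crop/flip transforms to the instance image and, when provided,
+ its conditioning image, so that the pair stays spatially aligned after augmentation.
+ """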
+ # 1. Resize (deterministic)
+ resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ image = resize(image)
+ if dest_image is not None:
+ dest_image = resize(dest_image)
+
+ # 2. Crop: either center or SAME random crop
+ if center_crop:
+ crop = transforms.CenterCrop(size)
+ image = crop(image)
+ if dest_image is not None:
+ dest_image = crop(dest_image)
+ else:
+ # get_params returns (i, j, h, w)
+ i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)
+ image = TF.crop(image, i, j, h, w)
+ if dest_image is not None:
+ dest_image = TF.crop(dest_image, i, j, h, w)
+
+ # 3. Random horizontal flip with the SAME coin flip
+ if random_flip:
+ do_flip = random.random() < 0.5
+ if do_flip:
+ image = TF.hflip(image)
+ if dest_image is not None:
+ dest_image = TF.hflip(dest_image)
+
+ # 4. ToTensor + Normalize (deterministic)
+ to_tensor = transforms.ToTensor()
+ normalize = transforms.Normalize([0.5], [0.5])
+ image = normalize(to_tensor(image))
+ if dest_image is not None:
+ dest_image = normalize(to_tensor(dest_image))
+
+ return (image, dest_image) if dest_image is not None else (image, None)
+
+
+def collate_fn(examples):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
+ if any("cond_images" in example for example in examples):
+ cond_pixel_values = [example["cond_images"] for example in examples]
+ cond_pixel_values = torch.stack(cond_pixel_values)
+ cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
+ batch.update({"cond_pixel_values": cond_pixel_values})
+ return batch
+
+
+class BucketBatchSampler(BatchSampler):
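+ """
+ A batch sampler that groups dataset indices by aspect-ratio bucket so that every sample
+ in a batch shares the same spatial size and can be stacked without padding.
+ """
+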
+ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
+ if not isinstance(batch_size, int) or batch_size <= 0:
+ raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
+ if not isinstance(drop_last, bool):
+ raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last))
+
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+
+ # Group indices by bucket
+ self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
+ for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):
+ self.bucket_indices[bucket_idx].append(idx)
+
+ self.sampler_len = 0
+ self.batches = []
+
+ # Pre-generate batches for each bucket
+ for indices_in_bucket in self.bucket_indices:
+ # Shuffle indices within the bucket
+ random.shuffle(indices_in_bucket)
+ # Create batches
+ for i in range(0, len(indices_in_bucket), self.batch_size):
+ batch = indices_in_bucket[i : i + self.batch_size]
+ if len(batch) < self.batch_size and self.drop_last:
+ continue # Skip partial batch if drop_last is True
+ self.batches.append(batch)
+ self.sampler_len += 1 # Count the number of batches
+
+ def __iter__(self):
+ # Shuffle the order of the batches each epoch
+ random.shuffle(self.batches)
+ for batch in self.batches:
+ yield batch
+
+ def __len__(self):
+ return self.sampler_len
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `hf auth login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+ if args.do_fp8_training:
+ from torchao.float8 import Float8LinearConfig, convert_to_float8_training
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer = PixtralProcessor.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision,
+ # as these weights are only used for inference; keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="scheduler",
+ revision=args.revision,
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ vae = AutoencoderKLFlux2.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
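+ # Cache the VAE's batch-norm running statistics; they are used later to standardize
+ # the patchified latents before they are fed to the transformer.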
+ latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(accelerator.device)
+ latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
+ accelerator.device
+ )
+
+ quantization_config = None
+ if args.bnb_quantization_config_path is not None:
+ with open(args.bnb_quantization_config_path, "r") as f:
+ config_kwargs = json.load(f)
+ if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]:
+ config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype
+ quantization_config = BitsAndBytesConfig(**config_kwargs)
+
+ transformer = Flux2Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ quantization_config=quantization_config,
+ torch_dtype=weight_dtype,
+ )
+ if args.bnb_quantization_config_path is not None:
+ transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
+
+ if not args.remote_text_encoder:
+ text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder.requires_grad_(False)
+
+ # We only train the additional adapter LoRA layers
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+
+ if args.enable_npu_flash_attention:
+ if is_torch_npu_available():
+ logger.info("npu flash attention enabled.")
+ transformer.set_attention_backend("_native_npu")
+ else:
+ raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ to_kwargs = {"dtype": weight_dtype, "device": accelerator.device} if not args.offload else {"dtype": weight_dtype}
+ # flux vae is stable in bf16 so load it in weight_dtype to reduce memory
+ vae.to(**to_kwargs)
+ # we never offload the transformer to CPU, so we can just use the accelerator device
+ transformer_to_kwargs = (
+ {"device": accelerator.device}
+ if args.bnb_quantization_config_path is not None
+ else {"device": accelerator.device, "dtype": weight_dtype}
+ )
+ transformer.to(**transformer_to_kwargs)
+ if args.do_fp8_training:
+ convert_to_float8_training(
+ transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
+ )
+
+ if not args.remote_text_encoder:
+ text_encoder.to(**to_kwargs)
+ # Initialize a text encoding pipeline and keep it to CPU for now.
+ text_encoding_pipeline = Flux2Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=None,
+ transformer=None,
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ scheduler=None,
+ revision=args.revision,
+ )
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+
+ if args.lora_layers is not None:
+ target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+ else:
+ target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+
+ # now we will add new LoRA weights to the transformer layers
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout,
+ init_lora_weights="gaussian",
+ target_modules=target_modules,
+ )
+ transformer.add_adapter(transformer_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+ modules_to_save = {}
+ for model in models:
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+ modules_to_save["transformer"] = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ Flux2Pipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ **_collate_lora_metadata(modules_to_save),
+ )
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
+
+ transformer_state_dict = {
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ }
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [transformer_]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [transformer]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+
+ # Optimization parameters
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+ params_to_optimize = [transformer_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
+ "Defaulting to adamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ if args.aspect_ratio_buckets is not None:
+ buckets = parse_buckets_string(args.aspect_ratio_buckets)
+ else:
+ buckets = [(args.resolution, args.resolution)]
+ logger.info(f"Using parsed aspect ratio buckets: {buckets}")
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ buckets=buckets,
+ )
+ batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_sampler=batch_sampler,
+ collate_fn=lambda examples: collate_fn(examples),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ def compute_text_embeddings(prompt, text_encoding_pipeline):
+ with torch.no_grad():
+ prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+ prompt=prompt, max_sequence_length=args.max_sequence_length
+ )
+ return prompt_embeds, text_ids
+
+ def compute_remote_text_embeddings(prompts: str | list[str]):
+ import io
+
+ import requests
+
+ if args.hub_token is not None:
+ hf_token = args.hub_token
+ else:
+ from huggingface_hub import get_token
+
+ hf_token = get_token()
+ if hf_token is None:
+ raise ValueError(
+ "No HuggingFace token found. To use the remote text encoder please login using `hf auth login` or provide a token using --hub_token"
+ )
+
+ def _encode_single(prompt: str):
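+ # The endpoint returns a torch-serialized embedding tensor, deserialized below
+ # from the raw response bytes with torch.load.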
+ response = requests.post(
+ "https://remote-text-encoder-flux-2.huggingface.co/predict",
+ json={"prompt": prompt},
+ headers={"Authorization": f"Bearer {hf_token}", "Content-Type": "application/json"},
+ )
+ assert response.status_code == 200, f"{response.status_code=}"
+ return torch.load(io.BytesIO(response.content))
+
+ try:
+ if isinstance(prompts, (list, tuple)):
+ embeds = [_encode_single(p) for p in prompts]
+ prompt_embeds = torch.cat(embeds, dim=0).to(accelerator.device)
+ else:
+ prompt_embeds = _encode_single(prompts).to(accelerator.device)
+
+ text_ids = Flux2Pipeline._prepare_text_ids(prompt_embeds).to(accelerator.device)
+ return prompt_embeds, text_ids
+
+ except Exception as e:
+ raise RuntimeError("Remote text encoder inference failed.") from e
+
+ # Since the text_encoder is not tuned, if custom instance prompts are NOT provided
+ # (i.e. the --instance_prompt is used for all images), we can encode the instance prompt once
+ # to avoid redundant encoding.
+ if not train_dataset.custom_instance_prompts:
+ if args.remote_text_encoder:
+ instance_prompt_hidden_states, instance_text_ids = compute_remote_text_embeddings(args.instance_prompt)
+ else:
+ with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+ instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings(
+ args.instance_prompt, text_encoding_pipeline
+ )
+
+ validation_image = load_image(args.validation_image_path).convert("RGB")
+ validation_kwargs = {"image": validation_image}
+ if args.validation_prompt is not None:
+ if args.remote_text_encoder:
+ validation_kwargs["prompt_embeds"] = compute_remote_text_embeddings(args.validation_prompt)
+ else:
+ with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+ validation_kwargs["prompt_embeds"] = compute_text_embeddings(
+ args.validation_prompt, text_encoding_pipeline
+ )
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+ if not train_dataset.custom_instance_prompts:
+ prompt_embeds = instance_prompt_hidden_states
+ text_ids = instance_text_ids
+
+ # if cache_latents is set to True, we encode images to latents and store them.
+ # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
+ # we encode them in advance as well.
+ precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
+ if precompute_latents:
+ prompt_embeds_cache = []
+ text_ids_cache = []
+ latents_cache = []
+ cond_latents_cache = []
+ for batch in tqdm(train_dataloader, desc="Caching latents"):
+ with torch.no_grad():
+ if args.cache_latents:
+ with offload_models(vae, device=accelerator.device, offload=args.offload):
+ batch["pixel_values"] = batch["pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=vae.dtype
+ )
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+ batch["cond_pixel_values"] = batch["cond_pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=vae.dtype
+ )
+ cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist)
+ if train_dataset.custom_instance_prompts:
+ if args.remote_text_encoder:
+ prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
+ else:
+ with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+ prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
+ prompt_embeds_cache.append(prompt_embeds)
+ text_ids_cache.append(text_ids)
+
+ # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
+ if args.cache_latents:
+ vae = vae.to("cpu")
+ del vae
+
+ # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
+ if not args.remote_text_encoder:
+ text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+ del text_encoder, tokenizer
+ free_memory()
+
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "dreambooth-flux2-image2img-lora"
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
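+ # Look up the flow-matching sigma for each sampled timestep and unsqueeze it
+ # until it broadcasts against the latent tensor.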
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+ prompts = batch["prompts"]
+
+ with accelerator.accumulate(models_to_accumulate):
+ if train_dataset.custom_instance_prompts:
+ prompt_embeds = prompt_embeds_cache[step]
+ text_ids = text_ids_cache[step]
+ else:
+ num_repeat_elements = len(prompts)
+ # repeat from the cached single-prompt embeddings so the batch dimension
+ # does not compound across steps
+ prompt_embeds = instance_prompt_hidden_states.repeat(num_repeat_elements, 1, 1)
+ text_ids = instance_text_ids.repeat(num_repeat_elements, 1, 1)
+
+ # Convert images to latent space
+ if args.cache_latents:
+ model_input = latents_cache[step].mode()
+ cond_model_input = cond_latents_cache[step].mode()
+ else:
+ with offload_models(vae, device=accelerator.device, offload=args.offload):
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
+
+ model_input = vae.encode(pixel_values).latent_dist.mode()
+ cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
+
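+ # Patchify the latents and standardize them with the VAE's batch-norm running statistics.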
+ model_input = Flux2Pipeline._patchify_latents(model_input)
+ model_input = (model_input - latents_bn_mean) / latents_bn_std
+
+ cond_model_input = Flux2Pipeline._patchify_latents(cond_model_input)
+ cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std
+
+ model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device)
+ cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input).to(
+ device=cond_model_input.device
+ )
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+ # Add noise according to flow matching.
+ # zt = (1 - texp) * x + texp * z1
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+
+ # [B, C, H, W] -> [B, H*W, C]
+ packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input)
+ packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input)
+
+ # concatenate the model inputs with the cond inputs
+ packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1)
+ model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)
+
+ # handle guidance
+ guidance = torch.full([1], args.guidance_scale, device=accelerator.device)
+ guidance = guidance.expand(model_input.shape[0])
+
+ # Predict the noise residual
+ model_pred = transformer(
+ hidden_states=packed_noisy_model_input, # (B, image_seq_len, C)
+ timestep=timesteps / 1000,
+ guidance=guidance,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids, # B, text_seq_len, 4
+ img_ids=model_input_ids, # B, image_seq_len, 4
+ return_dict=False,
+ )[0]
+ model_pred = model_pred[:, : packed_noisy_model_input.size(1)]
+
+ model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids)
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow matching loss
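+ # for rectified flow, z_t = (1 - sigma) * x + sigma * noise, so the velocity
+ # (noise - x) is the regression target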
+ target = noise - model_input
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = transformer.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ pipeline = Flux2Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=None,
+ tokenizer=None,
+ transformer=unwrap_model(transformer),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=validation_kwargs,
+ epoch=epoch,
+ torch_dtype=weight_dtype,
+ )
+
+ del pipeline
+ free_memory()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ modules_to_save = {}
+ transformer = unwrap_model(transformer)
+ if args.bnb_quantization_config_path is None:
+ if args.upcast_before_saving:
+ transformer.to(torch.float32)
+ else:
+ transformer = transformer.to(weight_dtype)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+ modules_to_save["transformer"] = transformer
+
+ Flux2Pipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ **_collate_lora_metadata(modules_to_save),
+ )
+
+ images = []
+ run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt)
+ should_run_final_inference = not args.skip_final_inference and run_validation
+ if should_run_final_inference:
+ pipeline = Flux2Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=validation_kwargs,
+ epoch=epoch,
+ is_final_validation=True,
+ torch_dtype=weight_dtype,
+ )
+ del pipeline
+ free_memory()
+
+ validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
+ save_model_card(
+ (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=validation_prompt,
+ repo_folder=args.output_dir,
+ fp8_training=args.do_fp8_training,
+ )
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/scripts/convert_flux2_to_diffusers.py b/scripts/convert_flux2_to_diffusers.py
new file mode 100644
index 000000000000..2973913fa215
--- /dev/null
+++ b/scripts/convert_flux2_to_diffusers.py
@@ -0,0 +1,475 @@
+import argparse
+from contextlib import nullcontext
+from typing import Any, Dict, Tuple
+
+import safetensors.torch
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download
+from transformers import AutoProcessor, GenerationConfig, Mistral3ForConditionalGeneration
+
+from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2Pipeline, Flux2Transformer2DModel
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+"""
+# VAE
+
+python scripts/convert_flux2_to_diffusers.py \
+--original_state_dict_repo_id "diffusers-internal-dev/new-model-image" \
+--vae_filename "flux2-vae.sft" \
+--output_path "/raid/yiyi/dummy-flux2-diffusers" \
+--vae
+
+# DiT
+
+python scripts/convert_flux2_to_diffusers.py \
+ --original_state_dict_repo_id diffusers-internal-dev/new-model-image \
+ --dit_filename flux-dev-dummy.sft \
+ --dit \
+ --output_path .
+
+# Full pipe
+
+python scripts/convert_flux2_to_diffusers.py \
+ --original_state_dict_repo_id diffusers-internal-dev/new-model-image \
+ --dit_filename flux-dev-dummy.sft \
+ --vae_filename "flux2-vae.sft" \
+ --dit --vae --full_pipe \
+ --output_path .
+"""
+
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
+parser.add_argument("--vae_filename", default="flux2-vae.sft", type=str)
+parser.add_argument("--dit_filename", default="flux-dev-dummy.sft", type=str)
+parser.add_argument("--vae", action="store_true")
+parser.add_argument("--dit", action="store_true")
+parser.add_argument("--vae_dtype", type=str, default="fp32")
+parser.add_argument("--dit_dtype", type=str, default="bf16")
+parser.add_argument("--checkpoint_path", default=None, type=str)
+parser.add_argument("--full_pipe", action="store_true")
+parser.add_argument("--output_path", type=str)
+
+args = parser.parse_args()
+
+
+def load_original_checkpoint(args, filename):
+ if args.original_state_dict_repo_id is not None:
+ ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
+ elif args.checkpoint_path is not None:
+ ckpt_path = args.checkpoint_path
+ else:
+ raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
+
+ original_state_dict = safetensors.torch.load_file(ckpt_path)
+ return original_state_dict
+
+
+DIFFUSERS_VAE_TO_FLUX2_MAPPING = {
+ "encoder.conv_in.weight": "encoder.conv_in.weight",
+ "encoder.conv_in.bias": "encoder.conv_in.bias",
+ "encoder.conv_out.weight": "encoder.conv_out.weight",
+ "encoder.conv_out.bias": "encoder.conv_out.bias",
+ "encoder.conv_norm_out.weight": "encoder.norm_out.weight",
+ "encoder.conv_norm_out.bias": "encoder.norm_out.bias",
+ "decoder.conv_in.weight": "decoder.conv_in.weight",
+ "decoder.conv_in.bias": "decoder.conv_in.bias",
+ "decoder.conv_out.weight": "decoder.conv_out.weight",
+ "decoder.conv_out.bias": "decoder.conv_out.bias",
+ "decoder.conv_norm_out.weight": "decoder.norm_out.weight",
+ "decoder.conv_norm_out.bias": "decoder.norm_out.bias",
+ "quant_conv.weight": "encoder.quant_conv.weight",
+ "quant_conv.bias": "encoder.quant_conv.bias",
+ "post_quant_conv.weight": "decoder.post_quant_conv.weight",
+ "post_quant_conv.bias": "decoder.post_quant_conv.bias",
+ "bn.running_mean": "bn.running_mean",
+ "bn.running_var": "bn.running_var",
+}
+
+
+# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
+def conv_attn_to_linear(checkpoint):
+ keys = list(checkpoint.keys())
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
+ for key in keys:
+ if ".".join(key.split(".")[-2:]) in attn_keys:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
+ elif "proj_attn.weight" in key:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0]
+
+
+def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
+ for ldm_key in keys:
+ diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
+ new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
+
+
+def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
+ for ldm_key in keys:
+ diffusers_key = (
+ ldm_key.replace(mapping["old"], mapping["new"])
+ .replace("norm.weight", "group_norm.weight")
+ .replace("norm.bias", "group_norm.bias")
+ .replace("q.weight", "to_q.weight")
+ .replace("q.bias", "to_q.bias")
+ .replace("k.weight", "to_k.weight")
+ .replace("k.bias", "to_k.bias")
+ .replace("v.weight", "to_v.weight")
+ .replace("v.bias", "to_v.bias")
+ .replace("proj_out.weight", "to_out.0.weight")
+ .replace("proj_out.bias", "to_out.0.bias")
+ )
+ new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
+
+ # proj_attn.weight has to be converted from conv 1D to linear
+ shape = new_checkpoint[diffusers_key].shape
+
+ if len(shape) == 3:
+ new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
+ elif len(shape) == 4:
+ new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
+
+
+def convert_flux2_vae_checkpoint_to_diffusers(vae_state_dict, config):
+ new_checkpoint = {}
+ for diffusers_key, ldm_key in DIFFUSERS_VAE_TO_FLUX2_MAPPING.items():
+ if ldm_key not in vae_state_dict:
+ continue
+ new_checkpoint[diffusers_key] = vae_state_dict[ldm_key]
+
+ # Retrieves the keys for the encoder down blocks only
+ num_down_blocks = len(config["down_block_types"])
+ down_blocks = {
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ for i in range(num_down_blocks):
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+ update_vae_resnet_ldm_to_diffusers(
+ resnets,
+ new_checkpoint,
+ vae_state_dict,
+ mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"},
+ )
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get(
+ f"encoder.down.{i}.downsample.conv.weight"
+ )
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get(
+ f"encoder.down.{i}.downsample.conv.bias"
+ )
+
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+ update_vae_resnet_ldm_to_diffusers(
+ resnets,
+ new_checkpoint,
+ vae_state_dict,
+ mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
+ )
+
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+ update_vae_attentions_ldm_to_diffusers(
+ mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ )
+
+ # Retrieves the keys for the decoder up blocks only
+ num_up_blocks = len(config["up_block_types"])
+ up_blocks = {
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
+ }
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+ resnets = [
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+ ]
+ update_vae_resnet_ldm_to_diffusers(
+ resnets,
+ new_checkpoint,
+ vae_state_dict,
+ mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"},
+ )
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.bias"
+ ]
+
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+ update_vae_resnet_ldm_to_diffusers(
+ resnets,
+ new_checkpoint,
+ vae_state_dict,
+ mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
+ )
+
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+ update_vae_attentions_ldm_to_diffusers(
+ mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ )
+ conv_attn_to_linear(new_checkpoint)
+
+ return new_checkpoint
+
+
+FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
+ # Image and text input projections
+ "img_in": "x_embedder",
+ "txt_in": "context_embedder",
+ # Timestep and guidance embeddings
+ "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
+ "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
+ "guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
+ "guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
+ # Modulation parameters
+ "double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
+ "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
+ "single_stream_modulation.lin": "single_stream_modulation.linear",
+ # Final output layer
+ # "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params
+ "final_layer.linear": "proj_out",
+}
+
+
+FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = {
+ "final_layer.adaLN_modulation.1": "norm_out.linear",
+}
+
+
+FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
+ # Handle fused QKV projections separately as we need to break into Q, K, V projections
+ "img_attn.norm.query_norm": "attn.norm_q",
+ "img_attn.norm.key_norm": "attn.norm_k",
+ "img_attn.proj": "attn.to_out.0",
+ "img_mlp.0": "ff.linear_in",
+ "img_mlp.2": "ff.linear_out",
+ "txt_attn.norm.query_norm": "attn.norm_added_q",
+ "txt_attn.norm.key_norm": "attn.norm_added_k",
+ "txt_attn.proj": "attn.to_add_out",
+ "txt_mlp.0": "ff_context.linear_in",
+ "txt_mlp.2": "ff_context.linear_out",
+}
+
+
+FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = {
+ "linear1": "attn.to_qkv_mlp_proj",
+ "linear2": "attn.to_out",
+ "norm.query_norm": "attn.norm_q",
+ "norm.key_norm": "attn.norm_k",
+}
+
+
+# The original implementation of AdaLayerNormContinuous splits the linear projection output into (shift, scale),
+# while diffusers splits it into (scale, shift). Swapping the two halves of the projection weight (a weight laid
+# out as [shift_rows; scale_rows] along dim 0 becomes [scale_rows; shift_rows]) lets us reuse the diffusers
+# implementation as-is.
+def swap_scale_shift(weight):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+
+def convert_ada_layer_norm_weights(key: str, state_dict: Dict[str, Any]) -> None:
+ # Skip if not a weight
+ if ".weight" not in key:
+ return
+
+ # If adaLN_modulation is in the key, swap scale and shift parameters
+ # Original implementation is (shift, scale); diffusers implementation is (scale, shift)
+ if "adaLN_modulation" in key:
+ key_without_param_type, param_type = key.rsplit(".", maxsplit=1)
+ # Assume all such keys are in the AdaLayerNorm key map
+ new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type]
+ new_key = ".".join([new_key_without_param_type, param_type])
+
+ swapped_weight = swap_scale_shift(state_dict.pop(key))
+ state_dict[new_key] = swapped_weight
+ return
+
+
+def convert_flux2_double_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
+ # Skip if not a weight, bias, or scale
+ if ".weight" not in key and ".bias" not in key and ".scale" not in key:
+ return
+
+ new_prefix = "transformer_blocks"
+ if "double_blocks." in key:
+ parts = key.split(".")
+ block_idx = parts[1]
+ modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp
+ within_block_name = ".".join(parts[2:-1])
+ param_type = parts[-1]
+
+ if param_type == "scale":
+ param_type = "weight"
+
+ if "qkv" in within_block_name:
+ fused_qkv_weight = state_dict.pop(key)
+ # Split the fused QKV projection once into separate Q, K, V weights
+ to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
+ if "img" in modality_block_name:
+ # double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v}
+ new_q_name = "attn.to_q"
+ new_k_name = "attn.to_k"
+ new_v_name = "attn.to_v"
+ elif "txt" in modality_block_name:
+ # double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj}
+ new_q_name = "attn.add_q_proj"
+ new_k_name = "attn.add_k_proj"
+ new_v_name = "attn.add_v_proj"
+ new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type])
+ new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type])
+ new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type])
+ state_dict[new_q_key] = to_q_weight
+ state_dict[new_k_key] = to_k_weight
+ state_dict[new_v_key] = to_v_weight
+ else:
+ new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name]
+ new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
+
+ param = state_dict.pop(key)
+ state_dict[new_key] = param
+ return
+
+
+def convert_flux2_single_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
+ # Skip if not a weight, bias, or scale
+ if ".weight" not in key and ".bias" not in key and ".scale" not in key:
+ return
+
+ # Mapping:
+ # - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
+ # - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
+ # - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight
+ # - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight
+ new_prefix = "single_transformer_blocks"
+ if "single_blocks." in key:
+ parts = key.split(".")
+ block_idx = parts[1]
+ within_block_name = ".".join(parts[2:-1])
+ param_type = parts[-1]
+
+ if param_type == "scale":
+ param_type = "weight"
+
+ new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name]
+ new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
+
+ param = state_dict.pop(key)
+ state_dict[new_key] = param
+ return
+
+
+TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ "adaLN_modulation": convert_ada_layer_norm_weights,
+ "double_blocks": convert_flux2_double_stream_blocks,
+ "single_blocks": convert_flux2_single_stream_blocks,
+}
+
+
+def update_state_dict(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+
+def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
+ if model_type == "test" or model_type == "dummy-flux2":
+ config = {
+ "model_id": "diffusers-internal-dev/dummy-flux2",
+ "diffusers_config": {
+ "patch_size": 1,
+ "in_channels": 128,
+ "num_layers": 8,
+ "num_single_layers": 48,
+ "attention_head_dim": 128,
+ "num_attention_heads": 48,
+ "joint_attention_dim": 15360,
+ "timestep_guidance_channels": 256,
+ "mlp_ratio": 3.0,
+ "axes_dims_rope": (32, 32, 32, 32),
+ "rope_theta": 2000,
+ "eps": 1e-6,
+ },
+ }
+ rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
+ special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
+ return config, rename_dict, special_keys_remap
+
+ raise ValueError(f"Unknown model_type: {model_type}")
+
+
+def convert_flux2_transformer_to_diffusers(original_state_dict: Dict[str, torch.Tensor], model_type: str):
+ config, rename_dict, special_keys_remap = get_flux2_transformer_config(model_type)
+
+ diffusers_config = config["diffusers_config"]
+
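+ # `init_empty_weights` (from accelerate) builds the model with meta tensors, so no memory is allocated here;
+ # the real parameters are attached below via `load_state_dict(..., assign=True)`.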
+ with init_empty_weights():
+ transformer = Flux2Transformer2DModel.from_config(diffusers_config)
+
+ # Handle official code --> diffusers key remapping via the remap dict
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in rename_dict.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict(original_state_dict, key, new_key)
+
+ # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
+ # special_keys_remap
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in special_keys_remap.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ transformer.load_state_dict(original_state_dict, strict=True, assign=True)
+ return transformer
+
+
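+# Example invocation (illustrative; the exact flag names are defined by the argparse setup earlier in this
+# script and may differ):
+#   python convert_flux2_to_diffusers.py --dit --vae --full_pipe --output_path ./flux2-diffusers
+# Note that `--full_pipe` assembles a `Flux2Pipeline` from the converted VAE and transformer, so it requires
+# both `--vae` and `--dit` to be passed as well.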
+def main(args):
+ if args.vae:
+ original_vae_ckpt = load_original_checkpoint(args, filename=args.vae_filename)
+ vae = AutoencoderKLFlux2()
+ converted_vae_state_dict = convert_flux2_vae_checkpoint_to_diffusers(original_vae_ckpt, vae.config)
+ vae.load_state_dict(converted_vae_state_dict, strict=True)
+ if not args.full_pipe:
+ vae_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32
+ vae.to(vae_dtype).save_pretrained(f"{args.output_path}/vae")
+
+ if args.dit:
+ original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename)
+ transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, "test")
+ if not args.full_pipe:
+ dit_dtype = torch.bfloat16 if args.dit_dtype == "bf16" else torch.float32
+ transformer.to(dit_dtype).save_pretrained(f"{args.output_path}/transformer")
+
+ if args.full_pipe:
+ tokenizer_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
+ text_encoder_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
+ generate_config = GenerationConfig.from_pretrained(text_encoder_id)
+ generate_config.do_sample = True
+ text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
+ text_encoder_id, generation_config=generate_config, torch_dtype=torch.bfloat16
+ )
+ tokenizer = AutoProcessor.from_pretrained(tokenizer_id)
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", subfolder="scheduler"
+ )
+
+ pipe = Flux2Pipeline(
+ vae=vae, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
+ )
+ pipe.save_pretrained(args.output_path)
+
+
+if __name__ == "__main__":
+ main(args)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index cd7a2cb581b7..f4ceb06882f2 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -186,6 +186,7 @@
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
"AutoencoderKLCosmos",
+ "AutoencoderKLFlux2",
"AutoencoderKLHunyuanImage",
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
@@ -215,6 +216,7 @@
"CosmosTransformer3DModel",
"DiTTransformer2DModel",
"EasyAnimateTransformer3DModel",
+ "Flux2Transformer2DModel",
"FluxControlNetModel",
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
@@ -457,6 +459,7 @@
"EasyAnimateControlPipeline",
"EasyAnimateInpaintPipeline",
"EasyAnimatePipeline",
+ "Flux2Pipeline",
"FluxControlImg2ImgPipeline",
"FluxControlInpaintPipeline",
"FluxControlNetImg2ImgPipeline",
@@ -900,6 +903,7 @@
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
+ AutoencoderKLFlux2,
AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
@@ -929,6 +933,7 @@
CosmosTransformer3DModel,
DiTTransformer2DModel,
EasyAnimateTransformer3DModel,
+ Flux2Transformer2DModel,
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
@@ -1141,6 +1146,7 @@
EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline,
EasyAnimatePipeline,
+ Flux2Pipeline,
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py
index 48507aae038c..4e3eb009533a 100644
--- a/src/diffusers/loaders/__init__.py
+++ b/src/diffusers/loaders/__init__.py
@@ -81,6 +81,7 @@ def text_encoder_attn_modules(text_encoder):
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
"QwenImageLoraLoaderMixin",
+ "Flux2LoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
@@ -113,6 +114,7 @@ def text_encoder_attn_modules(text_encoder):
AuraFlowLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
CogView4LoraLoaderMixin,
+ Flux2LoraLoaderMixin,
FluxLoraLoaderMixin,
HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 2807416f97ae..dc7487e302c7 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -2265,3 +2265,89 @@ def get_alpha_scales(down_weight, alpha_key):
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict
+
+
+def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
+ converted_state_dict = {}
+
+ prefix = "diffusion_model."
+ original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
+
+ num_double_layers = 8
+ num_single_layers = 48
+ lora_keys = ("lora_A", "lora_B")
+ attn_types = ("img_attn", "txt_attn")
+
+ for sl in range(num_single_layers):
+ single_block_prefix = f"single_blocks.{sl}"
+ attn_prefix = f"single_transformer_blocks.{sl}.attn"
+
+ for lora_key in lora_keys:
+ converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
+ f"{single_block_prefix}.linear1.{lora_key}.weight"
+ )
+
+ converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(
+ f"{single_block_prefix}.linear2.{lora_key}.weight"
+ )
+
+ for dl in range(num_double_layers):
+ transformer_block_prefix = f"transformer_blocks.{dl}"
+
+ for lora_key in lora_keys:
+ for attn_type in attn_types:
+ attn_prefix = f"{transformer_block_prefix}.attn"
+ qkv_key = f"double_blocks.{dl}.{attn_type}.qkv.{lora_key}.weight"
+ fused_qkv_weight = original_state_dict.pop(qkv_key)
+
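+ # The LoRA down-projection (lora_A) of the fused QKV layer is shared across Q, K, and V, so the same
+ # weight is copied to each separate projection; the up-projection (lora_B) is chunked into
+ # per-projection pieces below.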
+ if lora_key == "lora_A":
+ diff_attn_proj_keys = (
+ ["to_q", "to_k", "to_v"]
+ if attn_type == "img_attn"
+ else ["add_q_proj", "add_k_proj", "add_v_proj"]
+ )
+ for proj_key in diff_attn_proj_keys:
+ converted_state_dict[f"{attn_prefix}.{proj_key}.{lora_key}.weight"] = torch.cat(
+ [fused_qkv_weight]
+ )
+ else:
+ sample_q, sample_k, sample_v = torch.chunk(fused_qkv_weight, 3, dim=0)
+
+ if attn_type == "img_attn":
+ converted_state_dict[f"{attn_prefix}.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"{attn_prefix}.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"{attn_prefix}.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+ else:
+ converted_state_dict[f"{attn_prefix}.add_q_proj.{lora_key}.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"{attn_prefix}.add_k_proj.{lora_key}.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"{attn_prefix}.add_v_proj.{lora_key}.weight"] = torch.cat([sample_v])
+
+ proj_mappings = [
+ ("img_attn.proj", "attn.to_out.0"),
+ ("txt_attn.proj", "attn.to_add_out"),
+ ]
+ for org_proj, diff_proj in proj_mappings:
+ for lora_key in lora_keys:
+ original_key = f"double_blocks.{dl}.{org_proj}.{lora_key}.weight"
+ diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
+ converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
+
+ mlp_mappings = [
+ ("img_mlp.0", "ff.linear_in"),
+ ("img_mlp.2", "ff.linear_out"),
+ ("txt_mlp.0", "ff_context.linear_in"),
+ ("txt_mlp.2", "ff_context.linear_out"),
+ ]
+ for org_mlp, diff_mlp in mlp_mappings:
+ for lora_key in lora_keys:
+ original_key = f"double_blocks.{dl}.{org_mlp}.{lora_key}.weight"
+ diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
+ converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
+
+ if len(original_state_dict) > 0:
+ raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+ for key in list(converted_state_dict.keys()):
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+ return converted_state_dict
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 25919a896af0..4302d145a6c5 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -45,6 +45,7 @@
_convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_musubi_wan_lora_to_diffusers,
+ _convert_non_diffusers_flux2_lora_to_diffusers,
_convert_non_diffusers_hidream_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_non_diffusers_ltxv_lora_to_diffusers,
@@ -5084,6 +5085,209 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
super().unfuse_lora(components=components, **kwargs)
+
+
+class Flux2LoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`Flux2Transformer2DModel`]. Specific to [`Flux2Pipeline`].
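+
+ Example (illustrative; the LoRA repository id is a placeholder):
+
+ ```py
+ import torch
+ from diffusers import Flux2Pipeline
+
+ pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16)
+ pipe.load_lora_weights("path/to/flux2-lora", adapter_name="example")
+ ```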
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+ state_dict, metadata = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
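+ # Non-diffusers Flux2 LoRAs (e.g. AI-Toolkit-style checkpoints) prefix their keys with "diffusion_model.";
+ # convert them to the diffusers naming scheme before loading.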
+ is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
+ if is_ai_toolkit:
+ state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
+
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Flux2Transformer2DModel
+ def load_lora_into_transformer(
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: Optional[dict] = None,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
+ """
+ lora_layers = {}
+ lora_metadata = {}
+
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
+
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
+ save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 7d65b30659fb..b759e04cbf2d 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -62,6 +62,7 @@
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
+ "Flux2Transformer2DModel": lambda model_cls, weights: weights,
}
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index b53647d47630..7b581ac3eb9c 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -34,6 +34,7 @@
convert_chroma_transformer_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_cosmos_transformer_checkpoint_to_diffusers,
+ convert_flux2_transformer_checkpoint_to_diffusers,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hidream_transformer_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers,
@@ -162,6 +163,10 @@
"checkpoint_mapping_fn": lambda x: x,
"default_subfolder": "transformer",
},
+ "Flux2Transformer2DModel": {
+ "checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
}
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index ef6c41e3ce97..d4676ba2526a 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -140,6 +140,7 @@
"net.blocks.0.self_attn.q_proj.weight",
"net.pos_embedder.dim_spatial_range",
],
+ "flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"],
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -189,6 +190,7 @@
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
+ "flux-2-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev"},
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
"ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
@@ -649,6 +651,9 @@ def infer_diffusers_model_type(checkpoint):
else:
model_type = "animatediff_v3"
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux2"]):
+ model_type = "flux-2-dev"
+
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
if any(
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
@@ -3647,3 +3652,168 @@ def rename_transformer_blocks_(key: str, state_dict):
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
+
+
+def convert_flux2_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
+ # Image and text input projections
+ "img_in": "x_embedder",
+ "txt_in": "context_embedder",
+ # Timestep and guidance embeddings
+ "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
+ "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
+ "guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
+ "guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
+ # Modulation parameters
+ "double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
+ "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
+ "single_stream_modulation.lin": "single_stream_modulation.linear",
+ # Final output layer
+ # "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params
+ "final_layer.linear": "proj_out",
+ }
+
+ FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = {
+ "final_layer.adaLN_modulation.1": "norm_out.linear",
+ }
+
+ FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
+ # Handle fused QKV projections separately as we need to break into Q, K, V projections
+ "img_attn.norm.query_norm": "attn.norm_q",
+ "img_attn.norm.key_norm": "attn.norm_k",
+ "img_attn.proj": "attn.to_out.0",
+ "img_mlp.0": "ff.linear_in",
+ "img_mlp.2": "ff.linear_out",
+ "txt_attn.norm.query_norm": "attn.norm_added_q",
+ "txt_attn.norm.key_norm": "attn.norm_added_k",
+ "txt_attn.proj": "attn.to_add_out",
+ "txt_mlp.0": "ff_context.linear_in",
+ "txt_mlp.2": "ff_context.linear_out",
+ }
+
+ FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = {
+ "linear1": "attn.to_qkv_mlp_proj",
+ "linear2": "attn.to_out",
+ "norm.query_norm": "attn.norm_q",
+ "norm.key_norm": "attn.norm_k",
+ }
+
+ def convert_flux2_single_stream_blocks(key: str, state_dict: dict[str, object]) -> None:
+ # Skip if not a weight, bias, or scale
+ if ".weight" not in key and ".bias" not in key and ".scale" not in key:
+ return
+
+ # Mapping:
+ # - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
+ # - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
+ # - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight
+ # - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight
+ new_prefix = "single_transformer_blocks"
+ if "single_blocks." in key:
+ parts = key.split(".")
+ block_idx = parts[1]
+ within_block_name = ".".join(parts[2:-1])
+ param_type = parts[-1]
+
+ if param_type == "scale":
+ param_type = "weight"
+
+ new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name]
+ new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
+
+ param = state_dict.pop(key)
+ state_dict[new_key] = param
+
+ return
+
+ def convert_ada_layer_norm_weights(key: str, state_dict: dict[str, object]) -> None:
+ # Skip if not a weight
+ if ".weight" not in key:
+ return
+
+ # If adaLN_modulation is in the key, swap scale and shift parameters
+ # Original implementation is (shift, scale); diffusers implementation is (scale, shift)
+ if "adaLN_modulation" in key:
+ key_without_param_type, param_type = key.rsplit(".", maxsplit=1)
+ # Assume all such keys are in the AdaLayerNorm key map
+ new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type]
+ new_key = ".".join([new_key_without_param_type, param_type])
+
+ swapped_weight = swap_scale_shift(state_dict.pop(key), 0)
+ state_dict[new_key] = swapped_weight
+
+ return
+
+ def convert_flux2_double_stream_blocks(key: str, state_dict: dict[str, object]) -> None:
+ # Skip if not a weight, bias, or scale
+ if ".weight" not in key and ".bias" not in key and ".scale" not in key:
+ return
+
+ new_prefix = "transformer_blocks"
+ if "double_blocks." in key:
+ parts = key.split(".")
+ block_idx = parts[1]
+ modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp
+ within_block_name = ".".join(parts[2:-1])
+ param_type = parts[-1]
+
+ if param_type == "scale":
+ param_type = "weight"
+
+ if "qkv" in within_block_name:
+ fused_qkv_weight = state_dict.pop(key)
+ # Split the fused QKV projection once into separate Q, K, V weights
+ to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
+ if "img" in modality_block_name:
+ # double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v}
+ new_q_name = "attn.to_q"
+ new_k_name = "attn.to_k"
+ new_v_name = "attn.to_v"
+ elif "txt" in modality_block_name:
+ # double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj}
+ new_q_name = "attn.add_q_proj"
+ new_k_name = "attn.add_k_proj"
+ new_v_name = "attn.add_v_proj"
+ new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type])
+ new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type])
+ new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type])
+ state_dict[new_q_key] = to_q_weight
+ state_dict[new_k_key] = to_k_weight
+ state_dict[new_v_key] = to_v_weight
+ else:
+ new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name]
+ new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
+
+ param = state_dict.pop(key)
+ state_dict[new_key] = param
+ return
+
+ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+ TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ "adaLN_modulation": convert_ada_layer_norm_weights,
+ "double_blocks": convert_flux2_double_stream_blocks,
+ "single_blocks": convert_flux2_single_stream_blocks,
+ }
+
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+ # Handle official code --> diffusers key remapping via the remap dict
+ for key in list(converted_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in FLUX2_TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+
+ update_state_dict(converted_state_dict, key, new_key)
+
+ # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
+ # special_keys_remap
+ for key in list(converted_state_dict.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, converted_state_dict)
+
+ return converted_state_dict
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index b42e981f71a9..61bc613d88ea 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -35,6 +35,7 @@
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
+ _import_structure["autoencoders.autoencoder_kl_flux2"] = ["AutoencoderKLFlux2"]
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
@@ -92,6 +93,7 @@
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
+ _import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
@@ -140,6 +142,7 @@
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
+ AutoencoderKLFlux2,
AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
@@ -190,6 +193,7 @@
DiTTransformer2DModel,
DualTransformer2DModel,
EasyAnimateTransformer3DModel,
+ Flux2Transformer2DModel,
FluxTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 5164cf311d3c..8b583d1a1cce 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -105,7 +105,7 @@ def fuse_qkv_projections(self):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
for module in self.modules():
- if isinstance(module, AttentionModuleMixin):
+ if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion:
module.fuse_projections()
def unfuse_qkv_projections(self):
@@ -114,13 +114,14 @@ def unfuse_qkv_projections(self):
> [!WARNING] > This API is 🧪 experimental.
"""
for module in self.modules():
- if isinstance(module, AttentionModuleMixin):
+ if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion:
module.unfuse_projections()
class AttentionModuleMixin:
_default_processor_cls = None
_available_processors = []
+ _supports_qkv_fusion = True
fused_projections = False
def set_processor(self, processor: AttentionProcessor) -> None:
@@ -248,6 +249,14 @@ def fuse_projections(self):
"""
Fuse the query, key, and value projections into a single projection for efficiency.
"""
+ # Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2
+ # single stream blocks are always fused)
+ if not self._supports_qkv_fusion:
+ logger.debug(
+ f"{self.__class__.__name__} does not support fusing QKV projections, so `fuse_projections` will no-op."
+ )
+ return
+
# Skip if already fused
if getattr(self, "fused_projections", False):
return
@@ -307,6 +316,11 @@ def unfuse_projections(self):
"""
Unfuse the query, key, and value projections back to separate projections.
"""
+ # Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2
+ # single stream blocks are always fused)
+ if not self._supports_qkv_fusion:
+ return
+
# Skip if not fused
if not getattr(self, "fused_projections", False):
return
diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py
index edfaabb070c5..470979ad33a7 100644
--- a/src/diffusers/models/autoencoders/__init__.py
+++ b/src/diffusers/models/autoencoders/__init__.py
@@ -4,6 +4,7 @@
from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_cosmos import AutoencoderKLCosmos
+from .autoencoder_kl_flux2 import AutoencoderKLFlux2
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py b/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py
new file mode 100644
index 000000000000..7b572f82ad67
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py
@@ -0,0 +1,546 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...loaders.single_file_model import FromOriginalModelMixin
+from ...utils import deprecate
+from ...utils.accelerate_utils import apply_forward_hook
+from ..attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ Attention,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+ FusedAttnProcessor2_0,
+)
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
+
+
+class AutoencoderKLFlux2(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
+ r"""
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D")`):
+ Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D")`):
+ Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(128, 256, 512, 512)`):
+ Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to 32): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `1024`): Sample input size.
+ force_upcast (`bool`, *optional*, default to `True`):
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
+ can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast`
+ can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+ mid_block_add_attention (`bool`, *optional*, default to `True`):
+ If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
+ mid_block will only have resnet blocks
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str, ...] = (
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ ),
+ up_block_types: Tuple[str, ...] = (
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ ),
+ block_out_channels: Tuple[int, ...] = (
+ 128,
+ 256,
+ 512,
+ 512,
+ ),
+ layers_per_block: int = 2,
+ act_fn: str = "silu",
+ latent_channels: int = 32,
+ norm_num_groups: int = 32,
+ sample_size: int = 1024, # YiYi notes: not sure
+ force_upcast: bool = True,
+ use_quant_conv: bool = True,
+ use_post_quant_conv: bool = True,
+ mid_block_add_attention: bool = True,
+ batch_norm_eps: float = 1e-4,
+ batch_norm_momentum: float = 0.1,
+ patch_size: Tuple[int, int] = (2, 2),
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ double_z=True,
+ mid_block_add_attention=mid_block_add_attention,
+ )
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ norm_num_groups=norm_num_groups,
+ act_fn=act_fn,
+ mid_block_add_attention=mid_block_add_attention,
+ )
+
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
+
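+ # BatchNorm over the patchified latent channels (math.prod(patch_size) * latent_channels, i.e. 128 with the
+ # defaults). Assumption: the patchifying/unpatchifying itself is expected to happen outside this module
+ # (e.g. in the pipeline), which is why `bn` is not applied in `encode`/`decode` below.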
+ self.bn = nn.BatchNorm2d(
+ math.prod(patch_size) * latent_channels,
+ eps=batch_norm_eps,
+ momentum=batch_norm_momentum,
+ affine=False,
+ track_running_stats=True,
+ )
+
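+ # Slicing and tiling are disabled by default; assuming the standard diffusers autoencoder API, they can be
+ # toggled with the `enable_slicing()` / `enable_tiling()` helpers inherited from `AutoencoderMixin`.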
+ self.use_slicing = False
+ self.use_tiling = False
+
+ # only relevant if vae tiling is enabled
+ self.tile_sample_min_size = self.config.sample_size
+ sample_size = (
+ self.config.sample_size[0]
+ if isinstance(self.config.sample_size, (list, tuple))
+ else self.config.sample_size
+ )
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+ self.tile_overlap_factor = 0.25
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, height, width = x.shape
+
+ if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
+ return self._tiled_encode(x)
+
+ enc = self.encoder(x)
+ if self.quant_conv is not None:
+ enc = self.quant_conv(enc)
+
+ return enc
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded images. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ if self.post_quant_conv is not None:
+ z = self.post_quant_conv(z)
+
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ """
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+ return b
+
+ def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+ r"""Encode a batch of images using a tiled encoder.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+ output, but they should be much less noticeable.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+
+ Returns:
+ `torch.Tensor`:
+ The latent representation of the encoded images.
+ """
+
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
+
+ # Split the image into overlapping tiles of size `tile_sample_min_size` and encode them separately.
+ rows = []
+ for i in range(0, x.shape[2], overlap_size):
+ row = []
+ for j in range(0, x.shape[3], overlap_size):
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+ tile = self.encoder(tile)
+ if self.config.use_quant_conv:
+ tile = self.quant_conv(tile)
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ enc = torch.cat(result_rows, dim=2)
+ return enc
+
+ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
+ r"""Encode a batch of images using a tiled encoder.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+ output, but they should be much less noticeable.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
+ `tuple` is returned.
+ """
+ deprecation_message = (
+ "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
+ "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
+ "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
+ )
+ deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
+
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
+
+ # Split the image into overlapping tiles of size `tile_sample_min_size` and encode them separately.
+ rows = []
+ for i in range(0, x.shape[2], overlap_size):
+ row = []
+ for j in range(0, x.shape[3], overlap_size):
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+ tile = self.encoder(tile)
+ if self.config.use_quant_conv:
+ tile = self.quant_conv(tile)
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ moments = torch.cat(result_rows, dim=2)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_sample_min_size - blend_extent
+
+ # Split z into overlapping tiles of size `tile_latent_min_size` and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, z.shape[2], overlap_size):
+ row = []
+ for j in range(0, z.shape[3], overlap_size):
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+ if self.config.use_post_quant_conv:
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ dec = torch.cat(result_rows, dim=2)
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Args:
+ sample (`torch.Tensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ > [!WARNING] > This API is 🧪 experimental.
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+        > [!WARNING]
+        > This API is 🧪 experimental.
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 826469237fb1..5a482c5c0e6a 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -26,6 +26,7 @@
from .transformer_cosmos import CosmosTransformer3DModel
from .transformer_easyanimate import EasyAnimateTransformer3DModel
from .transformer_flux import FluxTransformer2DModel
+ from .transformer_flux2 import Flux2Transformer2DModel
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py
new file mode 100644
index 000000000000..d2b3d8a733f3
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_flux2.py
@@ -0,0 +1,908 @@
+# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
+from ..attention import AttentionMixin, AttentionModuleMixin
+from ..attention_dispatch import dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..embeddings import (
+ TimestepEmbedding,
+ Timesteps,
+ apply_rotary_emb,
+ get_1d_rotary_pos_embed,
+)
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ encoder_query = encoder_key = encoder_value = None
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_fused_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+
+    encoder_query = encoder_key = encoder_value = None
+ if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+ encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
+ if attn.fused_projections:
+ return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
+class Flux2SwiGLU(nn.Module):
+ """
+ Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection
+ layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.gate_fn = nn.SiLU()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x1, x2 = x.chunk(2, dim=-1)
+ x = self.gate_fn(x1) * x2
+ return x
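+
+# A minimal shape sketch (illustrative, not part of the model): the chunk-and-gate
+# halves the last dimension, e.g. Flux2SwiGLU()(torch.randn(1, 16, 512)) returns a
+# tensor of shape (1, 16, 256).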
+
+
+class Flux2FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: float = 3.0,
+ inner_dim: Optional[int] = None,
+ bias: bool = False,
+ ):
+ super().__init__()
+ if inner_dim is None:
+ inner_dim = int(dim * mult)
+ dim_out = dim_out or dim
+
+ # Flux2SwiGLU will reduce the dimension by half
+ self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias)
+ self.act_fn = Flux2SwiGLU()
+ self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.linear_in(x)
+ x = self.act_fn(x)
+ x = self.linear_out(x)
+ return x
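+
+# Dimension trace (sketch, e.g. with dim=6144, the default transformer width, and
+# the default mult=3.0, so inner_dim=18432):
+#   linear_in  : 6144 -> 2 * 18432 = 36864 (gate and value halves)
+#   Flux2SwiGLU: 36864 -> 18432
+#   linear_out : 18432 -> 6144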
+
+
+class Flux2AttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+ def __call__(
+ self,
+ attn: "Flux2Attention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+ attn, hidden_states, encoder_hidden_states
+ )
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if attn.added_kv_proj_dim is not None:
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=1)
+ key = torch.cat([encoder_key, key], dim=1)
+ value = torch.cat([encoder_value, value], dim=1)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if encoder_hidden_states is not None:
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = Flux2AttnProcessor
+ _available_processors = [Flux2AttnProcessor]
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ out_bias: bool = True,
+ eps: float = 1e-5,
+ out_dim: int = None,
+ elementwise_affine: bool = True,
+ processor=None,
+ ):
+ super().__init__()
+
+ self.head_dim = dim_head
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+
+ self.use_bias = bias
+ self.dropout = dropout
+
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.added_proj_bias = added_proj_bias
+
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ # QK Norm
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+
+ self.to_out = torch.nn.ModuleList([])
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(torch.nn.Dropout(dropout))
+
+ if added_kv_proj_dim is not None:
+ self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
+class Flux2ParallelSelfAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+ def __call__(
+ self,
+ attn: "Flux2ParallelSelfAttention",
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ # Parallel in (QKV + MLP in) projection
+ hidden_states = attn.to_qkv_mlp_proj(hidden_states)
+ qkv, mlp_hidden_states = torch.split(
+ hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
+ )
+
+ # Handle the attention logic
+ query, key, value = qkv.chunk(3, dim=-1)
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Handle the feedforward (FF) logic
+ mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
+
+ # Concatenate and parallel output projection
+ hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
+ hidden_states = attn.to_out(hidden_states)
+
+ return hidden_states
+
+
+class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
+ """
+ Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
+
+    This implements a parallel transformer block, where the attention QKV projections are fused with the feedforward
+    (FF) input projections, and the attention output projections are fused with the FF output projections. See the
+    [ViT-22B paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.
+ """
+
+ _default_processor_cls = Flux2ParallelSelfAttnProcessor
+ _available_processors = [Flux2ParallelSelfAttnProcessor]
+ # Does not support QKV fusion as the QKV projections are always fused
+ _supports_qkv_fusion = False
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ out_bias: bool = True,
+ eps: float = 1e-5,
+ out_dim: int = None,
+ elementwise_affine: bool = True,
+ mlp_ratio: float = 4.0,
+ mlp_mult_factor: int = 2,
+ processor=None,
+ ):
+ super().__init__()
+
+ self.head_dim = dim_head
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+
+ self.use_bias = bias
+ self.dropout = dropout
+
+ self.mlp_ratio = mlp_ratio
+ self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)
+ self.mlp_mult_factor = mlp_mult_factor
+
+ # Fused QKV projections + MLP input projection
+ self.to_qkv_mlp_proj = torch.nn.Linear(
+ self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias
+ )
+ self.mlp_act_fn = Flux2SwiGLU()
+
+ # QK Norm
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+
+ # Fused attention output projection + MLP output projection
+ self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+ return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
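+
+# Shape sketch (illustrative, using the default transformer config: dim = 48 heads
+# * 128 = 6144, mlp_ratio=3.0, mlp_mult_factor=2):
+#   to_qkv_mlp_proj: 6144 -> 3 * 6144 + 18432 * 2 = 55296 features, which the
+#   processor splits into qkv (18432) and mlp_hidden_states (36864); Flux2SwiGLU
+#   halves the MLP part to 18432, and to_out maps the concatenated
+#   6144 + 18432 = 24576 features back to 6144.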
+
+
+class Flux2SingleTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float = 3.0,
+ eps: float = 1e-6,
+ bias: bool = False,
+ ):
+ super().__init__()
+
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+
+ # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this
+ # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442)
+ # for a visual depiction of this type of transformer block.
+ self.attn = Flux2ParallelSelfAttention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=bias,
+ out_bias=bias,
+ eps=eps,
+ mlp_ratio=mlp_ratio,
+ mlp_mult_factor=2,
+ processor=Flux2ParallelSelfAttnProcessor(),
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor],
+ temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ split_hidden_states: bool = False,
+ text_seq_len: Optional[int] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
+ # concatenated
+ if encoder_hidden_states is not None:
+ text_seq_len = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ mod_shift, mod_scale, mod_gate = temb_mod_params
+
+ norm_hidden_states = self.norm(hidden_states)
+ norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
+
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ hidden_states = hidden_states + mod_gate * attn_output
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ if split_hidden_states:
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
+ return encoder_hidden_states, hidden_states
+ else:
+ return hidden_states
+
+
+class Flux2TransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float = 3.0,
+ eps: float = 1e-6,
+ bias: bool = False,
+ ):
+ super().__init__()
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
+
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+ self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+
+ self.attn = Flux2Attention(
+ query_dim=dim,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=bias,
+ added_proj_bias=bias,
+ out_bias=bias,
+ eps=eps,
+ processor=Flux2AttnProcessor(),
+ )
+
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+ self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
+
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+ self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
+ temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ joint_attention_kwargs = joint_attention_kwargs or {}
+
+ # Modulation parameters shape: [1, 1, self.dim]
+ (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
+ (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
+
+ # Img stream
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa
+
+ # Conditioning txt stream
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
+ norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa
+
+ # Attention on concatenated img + txt stream
+ attention_outputs = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ attn_output, context_attn_output = attention_outputs
+
+ # Process attention outputs for the image stream (`hidden_states`).
+ attn_output = gate_msa * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ ff_output = self.ff(norm_hidden_states)
+ hidden_states = hidden_states + gate_mlp * ff_output
+
+ # Process attention outputs for the text stream (`encoder_hidden_states`).
+ context_attn_output = c_gate_msa * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
+
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
+ if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states
+
+
+class Flux2PosEmbed(nn.Module):
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ # Expected ids shape: [S, len(self.axes_dim)]
+ cos_out = []
+ sin_out = []
+ pos = ids.float()
+ is_mps = ids.device.type == "mps"
+ is_npu = ids.device.type == "npu"
+ freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+ # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
+ for i in range(len(self.axes_dim)):
+ cos, sin = get_1d_rotary_pos_embed(
+ self.axes_dim[i],
+ pos[..., i],
+ theta=self.theta,
+ repeat_interleave_real=True,
+ use_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ cos_out.append(cos)
+ sin_out.append(sin)
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+ return freqs_cos, freqs_sin
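+
+# Sketch: with the default axes_dim=(32, 32, 32, 32) and ids of shape (S, 4), each
+# axis contributes an (S, 32) cos/sin table (repeat-interleaved real RoPE), so the
+# concatenated freqs_cos/freqs_sin have shape (S, 128), matching the default
+# attention_head_dim of 128.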
+
+
+class Flux2TimestepGuidanceEmbeddings(nn.Module):
+ def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(
+ in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
+ )
+
+ self.guidance_embedder = TimestepEmbedding(
+ in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
+ )
+
+ def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
+
+ guidance_proj = self.time_proj(guidance)
+ guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
+
+ time_guidance_emb = timesteps_emb + guidance_emb
+
+ return time_guidance_emb
+
+
+class Flux2Modulation(nn.Module):
+ def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
+ super().__init__()
+ self.mod_param_sets = mod_param_sets
+
+ self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
+ self.act_fn = nn.SiLU()
+
+ def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
+ mod = self.act_fn(temb)
+ mod = self.linear(mod)
+
+ if mod.ndim == 2:
+ mod = mod.unsqueeze(1)
+ mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
+ # Return tuple of 3-tuples of modulation params shift/scale/gate
+ return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
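+
+# Shape sketch (illustrative): for temb of shape (B, dim) and mod_param_sets=2, the
+# linear yields (B, 1, 6 * dim) after the unsqueeze, chunked into six (B, 1, dim)
+# tensors and grouped as ((shift_msa, scale_msa, gate_msa),
+# (shift_mlp, scale_mlp, gate_mlp)).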
+
+
+class Flux2Transformer2DModel(
+ ModelMixin,
+ ConfigMixin,
+ PeftAdapterMixin,
+ FromOriginalModelMixin,
+ FluxTransformer2DLoadersMixin,
+ CacheMixin,
+ AttentionMixin,
+):
+ """
+ The Transformer model introduced in Flux 2.
+
+    Reference: https://bfl.ai/blog/flux-2
+
+ Args:
+ patch_size (`int`, defaults to `1`):
+ Patch size to turn the input data into small patches.
+ in_channels (`int`, defaults to `128`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `None`):
+ The number of channels in the output. If not specified, it defaults to `in_channels`.
+ num_layers (`int`, defaults to `8`):
+ The number of layers of dual stream DiT blocks to use.
+ num_single_layers (`int`, defaults to `48`):
+ The number of layers of single stream DiT blocks to use.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of dimensions to use for each attention head.
+ num_attention_heads (`int`, defaults to `48`):
+ The number of attention heads to use.
+ joint_attention_dim (`int`, defaults to `15360`):
+ The number of dimensions to use for the joint attention (embedding/channel dimension of
+ `encoder_hidden_states`).
+        timestep_guidance_channels (`int`, defaults to `256`):
+            The number of channels of the sinusoidal timestep and guidance embeddings.
+        mlp_ratio (`float`, defaults to `3.0`):
+            The expansion ratio of the feedforward sub-blocks.
+        axes_dims_rope (`Tuple[int, ...]`, defaults to `(32, 32, 32, 32)`):
+            The dimensions to use for the rotary positional embeddings.
+        rope_theta (`int`, defaults to `2000`):
+            The base frequency for the rotary positional embeddings.
+        eps (`float`, defaults to `1e-6`):
+            The epsilon value to use for the normalization layers.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+ _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
+ _cp_plan = {
+ "": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
+ "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 1,
+ in_channels: int = 128,
+ out_channels: Optional[int] = None,
+ num_layers: int = 8,
+ num_single_layers: int = 48,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 48,
+ joint_attention_dim: int = 15360,
+ timestep_guidance_channels: int = 256,
+ mlp_ratio: float = 3.0,
+ axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
+ rope_theta: int = 2000,
+ eps: float = 1e-6,
+ ):
+ super().__init__()
+ self.out_channels = out_channels or in_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ # 1. Sinusoidal positional embedding for RoPE on image and text tokens
+ self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)
+
+ # 2. Combined timestep + guidance embedding
+ self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
+ in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
+ )
+
+ # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
+ # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
+ self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
+ self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
+ # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
+ self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)
+
+ # 4. Input projections
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)
+
+ # 5. Double Stream Transformer Blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ Flux2TransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ mlp_ratio=mlp_ratio,
+ eps=eps,
+ bias=False,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 6. Single Stream Transformer Blocks
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ Flux2SingleTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ mlp_ratio=mlp_ratio,
+ eps=eps,
+ bias=False,
+ )
+ for _ in range(num_single_layers)
+ ]
+ )
+
+ # 7. Output layers
+ self.norm_out = AdaLayerNormContinuous(
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
+ )
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ """
+        The [`Flux2Transformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            timestep (`torch.LongTensor`):
+                Used to indicate denoising step.
+            img_ids (`torch.Tensor`):
+                Positional ids of the image tokens, used to compute the rotary positional embeddings.
+            txt_ids (`torch.Tensor`):
+                Positional ids of the text tokens, used to compute the rotary positional embeddings.
+            guidance (`torch.Tensor`, *optional*):
+                The guidance scale input for the guidance-distilled variant of the model.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ # 0. Handle input arguments
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ num_txt_tokens = encoder_hidden_states.shape[1]
+
+ # 1. Calculate timestep embedding and modulation parameters
+ timestep = timestep.to(hidden_states.dtype) * 1000
+ guidance = guidance.to(hidden_states.dtype) * 1000
+
+ temb = self.time_guidance_embed(timestep, guidance)
+
+ double_stream_mod_img = self.double_stream_modulation_img(temb)
+ double_stream_mod_txt = self.double_stream_modulation_txt(temb)
+ single_stream_mod = self.single_stream_modulation(temb)[0]
+
+ # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
+ hidden_states = self.x_embedder(hidden_states)
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ # 3. Calculate RoPE embeddings from image and text tokens
+        # NOTE: the below logic means that batched inference with images of different resolutions or text prompts
+        # of different lengths is not supported.
+ if img_ids.ndim == 3:
+ img_ids = img_ids[0]
+ if txt_ids.ndim == 3:
+ txt_ids = txt_ids[0]
+
+ if is_torch_npu_available():
+ freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
+ image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
+ freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
+ text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
+ else:
+ image_rotary_emb = self.pos_embed(img_ids)
+ text_rotary_emb = self.pos_embed(txt_ids)
+ concat_rotary_emb = (
+ torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
+ torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
+ )
+
+ # 4. Double Stream Transformer Blocks
+ for index_block, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ double_stream_mod_img,
+ double_stream_mod_txt,
+ concat_rotary_emb,
+ joint_attention_kwargs,
+ )
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb_mod_params_img=double_stream_mod_img,
+ temb_mod_params_txt=double_stream_mod_txt,
+ image_rotary_emb=concat_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+ # Concatenate text and image streams for single-block inference
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ # 5. Single Stream Transformer Blocks
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ None,
+ single_stream_mod,
+ concat_rotary_emb,
+ joint_attention_kwargs,
+ )
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=None,
+ temb_mod_params=single_stream_mod,
+ image_rotary_emb=concat_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+ # Remove text tokens from concatenated stream
+ hidden_states = hidden_states[:, num_txt_tokens:, ...]
+
+ # 6. Output layers
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 69bb14b98edc..e29e6ff1cfdb 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -129,6 +129,7 @@
]
_import_structure["bria"] = ["BriaPipeline"]
_import_structure["bria_fibo"] = ["BriaFiboPipeline"]
+ _import_structure["flux2"] = ["Flux2Pipeline"]
_import_structure["flux"] = [
"FluxControlPipeline",
"FluxControlInpaintPipeline",
@@ -654,6 +655,7 @@
FluxPriorReduxPipeline,
ReduxImageEncoder,
)
+ from .flux2 import Flux2Pipeline
from .hidream_image import HiDreamImagePipeline
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
from .hunyuan_video import (
diff --git a/src/diffusers/pipelines/flux2/__init__.py b/src/diffusers/pipelines/flux2/__init__.py
new file mode 100644
index 000000000000..d986c9a63011
--- /dev/null
+++ b/src/diffusers/pipelines/flux2/__init__.py
@@ -0,0 +1,47 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_additional_imports = {}
+_import_structure = {"pipeline_output": ["Flux2PipelineOutput"]}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_flux2"] = ["Flux2Pipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .pipeline_flux2 import Flux2Pipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
+ for name, value in _additional_imports.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/flux2/image_processor.py b/src/diffusers/pipelines/flux2/image_processor.py
new file mode 100644
index 000000000000..91c8f875dd1d
--- /dev/null
+++ b/src/diffusers/pipelines/flux2/image_processor.py
@@ -0,0 +1,138 @@
+# Copyright 2025 The Black Forest Labs Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import PIL.Image
+
+from ...configuration_utils import register_to_config
+from ...image_processor import VaeImageProcessor
+
+
+class Flux2ImageProcessor(VaeImageProcessor):
+ r"""
+ Image processor to preprocess the reference (character) image for the Flux2 model.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
+ vae_scale_factor (`int`, *optional*, defaults to `16`):
+ VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of
+ this factor.
+ vae_latent_channels (`int`, *optional*, defaults to `32`):
+ VAE latent channels.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image to [-1,1].
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the images to RGB format.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ do_resize: bool = True,
+ vae_scale_factor: int = 16,
+ vae_latent_channels: int = 32,
+ do_normalize: bool = True,
+ do_convert_rgb: bool = True,
+ ):
+ super().__init__(
+ do_resize=do_resize,
+ vae_scale_factor=vae_scale_factor,
+ vae_latent_channels=vae_latent_channels,
+ do_normalize=do_normalize,
+ do_convert_rgb=do_convert_rgb,
+ )
+
+ @staticmethod
+ def check_image_input(
+ image: PIL.Image.Image, max_aspect_ratio: int = 8, min_side_length: int = 64, max_area: int = 1024 * 1024
+ ) -> PIL.Image.Image:
+ """
+ Check if image meets minimum size and aspect ratio requirements.
+
+ Args:
+ image: PIL Image to validate
+ max_aspect_ratio: Maximum allowed aspect ratio (width/height or height/width)
+ min_side_length: Minimum pixels required for width and height
+            max_area: Maximum allowed area in pixels² (accepted for API symmetry; not enforced by this check)
+
+ Returns:
+ The input image if valid
+
+ Raises:
+ ValueError: If image is too small or aspect ratio is too extreme
+ """
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}")
+
+ width, height = image.size
+
+ # Check minimum dimensions
+ if width < min_side_length or height < min_side_length:
+ raise ValueError(
+ f"Image too small: {width}×{height}. Both dimensions must be at least {min_side_length}px"
+ )
+
+ # Check aspect ratio
+ aspect_ratio = max(width / height, height / width)
+ if aspect_ratio > max_aspect_ratio:
+ raise ValueError(
+ f"Aspect ratio too extreme: {width}×{height} (ratio: {aspect_ratio:.1f}:1). "
+ f"Maximum allowed ratio is {max_aspect_ratio}:1"
+ )
+
+ return image
+
+ @staticmethod
+    def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image:
+ image_width, image_height = image.size
+
+ scale = math.sqrt(target_area / (image_width * image_height))
+ width = int(image_width * scale)
+ height = int(image_height * scale)
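+
+        # e.g. (illustrative) a 2048x1024 input is scaled by sqrt(1048576 / 2097152)
+        # ~= 0.707 to roughly 1448x724, preserving the aspect ratio at approximately
+        # the 1 MP target area.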
+
+ return image.resize((width, height), PIL.Image.Resampling.LANCZOS)
+
+ def _resize_and_crop(
+ self,
+ image: PIL.Image.Image,
+ width: int,
+ height: int,
+ ) -> PIL.Image.Image:
+ r"""
+        Center crop the image to the specified width and height.
+
+        Args:
+            image (`PIL.Image.Image`):
+                The image to crop.
+            width (`int`):
+                The width to crop the image to.
+            height (`int`):
+                The height to crop the image to.
+
+        Returns:
+            `PIL.Image.Image`:
+                The cropped image.
+ """
+ image_width, image_height = image.size
+
+ left = (image_width - width) // 2
+ top = (image_height - height) // 2
+ right = left + width
+ bottom = top + height
+
+ return image.crop((left, top, right, bottom))
diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py
new file mode 100644
index 000000000000..676bf6d98429
--- /dev/null
+++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py
@@ -0,0 +1,883 @@
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import AutoProcessor, Mistral3ForConditionalGeneration
+
+from ...loaders import Flux2LoraLoaderMixin
+from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .image_processor import Flux2ImageProcessor
+from .pipeline_output import Flux2PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import Flux2Pipeline
+
+ >>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0]
+ >>> image.save("flux.png")
+ ```
+"""
+
+
+def format_text_input(prompts: List[str], system_message: Optional[str] = None):
+ # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
+ # when truncation is enabled. The processor counts [IMG] tokens and fails
+ # if the count changes after truncation.
+ cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
+
+ return [
+ [
+ {
+ "role": "system",
+ "content": [{"type": "text", "text": system_message}],
+ },
+ {"role": "user", "content": [{"type": "text", "text": prompt}]},
+ ]
+ for prompt in cleaned_txt
+ ]
+
+
+def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
+    # Empirical linear fits of the flow-matching shift `mu` as a function of the
+    # image sequence length: (a1, b1) is the fit at 10 inference steps and
+    # (a2, b2) the fit at 200 inference steps.
+    a1, b1 = 8.73809524e-05, 1.89833333
+    a2, b2 = 0.00016927, 0.45666666
+
+    # For long sequences, use the 200-step fit directly.
+    if image_seq_len > 4300:
+        mu = a2 * image_seq_len + b2
+        return float(mu)
+
+    # Otherwise, interpolate linearly in `num_steps` between the 10-step and
+    # 200-step fits evaluated at this sequence length.
+    m_200 = a2 * image_seq_len + b2
+    m_10 = a1 * image_seq_len + b1
+
+    a = (m_200 - m_10) / 190.0
+    b = m_200 - 200.0 * a
+    mu = a * num_steps + b
+
+    return float(mu)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
+ r"""
+ The Flux2 pipeline for text-to-image generation.
+
+ Reference: [https://bfl.ai/blog/flux-2](https://bfl.ai/blog/flux-2)
+
+ Args:
+ transformer ([`Flux2Transformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLFlux2`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Mistral3ForConditionalGeneration`]):
+ [Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration)
+ tokenizer (`AutoProcessor`):
+ Tokenizer of class
+ [PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLFlux2,
+ text_encoder: Mistral3ForConditionalGeneration,
+ tokenizer: AutoProcessor,
+ transformer: Flux2Transformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ transformer=transformer,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+        # Flux latents are turned into 2x2 patches and packed. This means the latent height and width have to be
+        # divisible by the patch size. So the vae scale factor is multiplied by the patch size to account for this.
+ self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.tokenizer_max_length = 512
+ self.default_sample_size = 128
+
+ # fmt: off
+ self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
+ # fmt: on
+
+ @staticmethod
+ def _get_mistral_3_small_prompt_embeds(
+ text_encoder: Mistral3ForConditionalGeneration,
+ tokenizer: AutoProcessor,
+ prompt: Union[str, List[str]],
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ max_sequence_length: int = 512,
+ # fmt: off
+ system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
+ # fmt: on
+        hidden_states_layers: Tuple[int, ...] = (10, 20, 30),
+ ):
+ dtype = text_encoder.dtype if dtype is None else dtype
+ device = text_encoder.device if device is None else device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ # Format input messages
+ messages_batch = format_text_input(prompts=prompt, system_message=system_message)
+
+ # Process all messages at once
+ inputs = tokenizer.apply_chat_template(
+ messages_batch,
+ add_generation_prompt=False,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=max_sequence_length,
+ )
+
+ # Move to device
+ input_ids = inputs["input_ids"].to(device)
+ attention_mask = inputs["attention_mask"].to(device)
+
+ # Forward pass through the model
+ output = text_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ use_cache=False,
+ )
+
+ # Only use outputs from intermediate layers and stack them
+ out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
+ out = out.to(dtype=dtype, device=device)
+
+ batch_size, num_channels, seq_len, hidden_dim = out.shape
+ prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
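+
+        # With the default hidden_states_layers (10, 20, 30), three hidden states are
+        # stacked and flattened channel-wise; assuming a text encoder hidden size of
+        # 5120, the feature dimension becomes 3 * 5120 = 15360, matching the
+        # transformer's default `joint_attention_dim`.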
+
+ return prompt_embeds
+
+ @staticmethod
+ def _prepare_text_ids(
+ x: torch.Tensor, # (B, L, D) or (L, D)
+ t_coord: Optional[torch.Tensor] = None,
+ ):
+ B, L, _ = x.shape
+ out_ids = []
+
+ for i in range(B):
+ t = torch.arange(1) if t_coord is None else t_coord[i]
+ h = torch.arange(1)
+ w = torch.arange(1)
+ l = torch.arange(L)
+
+ coords = torch.cartesian_prod(t, h, w, l)
+ out_ids.append(coords)
+
+ return torch.stack(out_ids)
+
+ @staticmethod
+ def _prepare_latent_ids(
+ latents: torch.Tensor, # (B, C, H, W)
+ ):
+ r"""
+ Generates 4D position coordinates (T, H, W, L) for latent tensors.
+
+ Args:
+ latents (torch.Tensor):
+ Latent tensor of shape (B, C, H, W)
+
+ Returns:
+ torch.Tensor:
+                Position IDs tensor of shape (B, H*W, 4). All batches share the same coordinate structure: T=0,
+                H=[0..H-1], W=[0..W-1], L=0.
+ """
+
+ batch_size, _, height, width = latents.shape
+
+ t = torch.arange(1) # [0] - time dimension
+ h = torch.arange(height)
+ w = torch.arange(width)
+ l = torch.arange(1) # [0] - layer dimension
+
+ # Create position IDs: (H*W, 4)
+ latent_ids = torch.cartesian_prod(t, h, w, l)
+
+ # Expand to batch: (B, H*W, 4)
+ latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
+
+ return latent_ids
+
+ @staticmethod
+ def _prepare_image_ids(
+ image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
+ scale: int = 10,
+ ):
+ r"""
+ Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
+
+        This function creates a unique coordinate for every pixel/patch across all input latents, which may have
+        different dimensions.
+
+        Args:
+            image_latents (List[torch.Tensor]):
+                A list of image latent feature tensors, each of shape (1, C, H, W).
+            scale (int, optional):
+                A factor used to define the time separation (T-coordinate) between latents. The T-coordinate for the
+                i-th latent is `scale + scale * i`. Defaults to 10.
+
+ Returns:
+ torch.Tensor:
+                The combined coordinate tensor of shape (1, N_total, 4), where N_total is the sum of (H * W) over
+                all input latents.
+
+ Coordinate Components (Dimension 4):
+ - T (Time): The unique index indicating which latent image the coordinate belongs to.
+ - H (Height): The row index within that latent image.
+ - W (Width): The column index within that latent image.
+ - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
+ """
+
+ if not isinstance(image_latents, list):
+ raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
+
+ # create time offset for each reference image
+ t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
+ t_coords = [t.view(-1) for t in t_coords]
+
+ image_latent_ids = []
+ for x, t in zip(image_latents, t_coords):
+ x = x.squeeze(0)
+ _, height, width = x.shape
+
+ x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
+ image_latent_ids.append(x_ids)
+
+ image_latent_ids = torch.cat(image_latent_ids, dim=0)
+ image_latent_ids = image_latent_ids.unsqueeze(0)
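+
+        # Worked example (illustrative): two reference latents with spatial sizes
+        # (4, 4) and (2, 2) and scale=10 get T-coordinates 10 and 20, so the
+        # returned ids have shape (1, 4 * 4 + 2 * 2, 4) = (1, 20, 4).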
+
+ return image_latent_ids
+
+ @staticmethod
+ def _patchify_latents(latents):
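+        # (B, C, H, W) -> (B, 4 * C, H // 2, W // 2): fold each 2x2 spatial patch
+        # into the channel dimension (inverse of `_unpatchify_latents`).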
+ batch_size, num_channels_latents, height, width = latents.shape
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 1, 3, 5, 2, 4)
+ latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
+ return latents
+
+ @staticmethod
+ def _unpatchify_latents(latents):
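+        # (B, 4 * C, H // 2, W // 2) -> (B, C, H, W): inverse of `_patchify_latents`.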
+ batch_size, num_channels_latents, height, width = latents.shape
+ latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
+ latents = latents.permute(0, 1, 4, 2, 5, 3)
+ latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
+ return latents
+
+ @staticmethod
+ def _pack_latents(latents):
+ """
+ pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
+ """
+
+ batch_size, num_channels, height, width = latents.shape
+ latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
+
+ return latents
+
+ @staticmethod
+    def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor:
+        """
+        Use the position ids to scatter packed tokens back into (C, H, W) feature maps, stacked along the batch dim.
+        """
+ x_list = []
+ for data, pos in zip(x, x_ids):
+ _, ch = data.shape # noqa: F841
+ h_ids = pos[:, 1].to(torch.int64)
+ w_ids = pos[:, 2].to(torch.int64)
+
+ h = torch.max(h_ids) + 1
+ w = torch.max(w_ids) + 1
+
+ flat_ids = h_ids * w + w_ids
+
+ out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
+ out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
+
+ # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
+
+ out = out.view(h, w, ch).permute(2, 0, 1)
+ x_list.append(out)
+
+ return torch.stack(x_list, dim=0)
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 512,
+        text_encoder_out_layers: Tuple[int, ...] = (10, 20, 30),
+ ):
+ device = device or self._execution_device
+
+ if prompt is None:
+ prompt = ""
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_mistral_3_small_prompt_embeds(
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ prompt=prompt,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ system_message=self.system_message,
+ hidden_states_layers=text_encoder_out_layers,
+ )
+
+ batch_size, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ text_ids = self._prepare_text_ids(prompt_embeds)
+ text_ids = text_ids.to(device)
+ return prompt_embeds, text_ids
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if image.ndim != 4:
+ raise ValueError(f"Expected image dims 4, got {image.ndim}.")
+
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
+ image_latents = self._patchify_latents(image_latents)
+
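+        # normalize with the VAE's batch-norm running statistics; the same statistics are used to
+        # de-normalize the latents again before decoding in `__call__`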
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
+ latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
+ image_latents = (image_latents - latents_bn_mean) / latents_bn_std
+
+ return image_latents
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_latents_channels,
+ height,
+ width,
+ dtype,
+ device,
+ generator: torch.Generator,
+ latents: Optional[torch.Tensor] = None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ latent_ids = self._prepare_latent_ids(latents)
+ latent_ids = latent_ids.to(device)
+
+ latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
+ return latents, latent_ids
+
+ def prepare_image_latents(
+ self,
+ images: List[torch.Tensor],
+ batch_size,
+ generator: torch.Generator,
+ device,
+ dtype,
+ ):
+ image_latents = []
+ for image in images:
+ image = image.to(device=device, dtype=dtype)
+            image_latent = self._encode_vae_image(image=image, generator=generator)
+            image_latents.append(image_latent)  # (1, 128, 32, 32)
+
+ image_latent_ids = self._prepare_image_ids(image_latents)
+
+ # Pack each latent and concatenate
+ packed_latents = []
+ for latent in image_latents:
+ # latent: (1, 128, 32, 32)
+ packed = self._pack_latents(latent) # (1, 1024, 128)
+ packed = packed.squeeze(0) # (1024, 128) - remove batch dim
+ packed_latents.append(packed)
+
+ # Concatenate all reference tokens along sequence dimension
+ image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128)
+ image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128)
+
+ image_latents = image_latents.repeat(batch_size, 1, 1)
+ image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
+ image_latent_ids = image_latent_ids.to(device)
+
+ return image_latents, image_latent_ids
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+        if (height is not None and height % (self.vae_scale_factor * 2) != 0) or (
+            width is not None and width % (self.vae_scale_factor * 2) != 0
+        ):
+            logger.warning(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly."
+            )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+    def attention_kwargs(self):
+        return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: Optional[float] = 4.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+        text_encoder_out_layers: Tuple[int, ...] = (10, 20, 30),
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+            image (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*):
+                Reference image or images to condition the generation on. Images whose area exceeds 1024x1024 are
+                resized down to that target area, and each image is cropped so its dimensions are multiples of
+                `vae_scale_factor * 2` before VAE encoding.
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            guidance_scale (`float`, *optional*, defaults to 4.0):
+                Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
+                of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely
+                linked to the text `prompt`, usually at the expense of lower image quality.
+            height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.flux2.Flux2PipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, *optional*, defaults to 512):
+                Maximum sequence length to use with the `prompt`.
+            text_encoder_out_layers (`Tuple[int, ...]`, *optional*, defaults to `(10, 20, 30)`):
+                Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ prompt_embeds=prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. prepare text embeddings
+ prompt_embeds, text_ids = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ text_encoder_out_layers=text_encoder_out_layers,
+ )
+
+ # 4. process images
+ if image is not None and not isinstance(image, list):
+ image = [image]
+
+ condition_images = None
+ if image is not None:
+ for img in image:
+ self.image_processor.check_image_input(img)
+
+ condition_images = []
+ for img in image:
+ image_width, image_height = img.size
+ if image_width * image_height > 1024 * 1024:
+ img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
+ image_width, image_height = img.size
+
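+                # snap reference dimensions down to a multiple of (vae_scale_factor * 2) so the encoded
+                # latents patchify exactly into 2x2 blocks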
+ multiple_of = self.vae_scale_factor * 2
+ image_width = (image_width // multiple_of) * multiple_of
+ image_height = (image_height // multiple_of) * multiple_of
+ img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
+ condition_images.append(img)
+ height = height or image_height
+ width = width or image_width
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 5. prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_ids = self.prepare_latents(
+ batch_size=batch_size * num_images_per_prompt,
+ num_latents_channels=num_channels_latents,
+ height=height,
+ width=width,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ image_latents = None
+ image_latent_ids = None
+ if condition_images is not None:
+ image_latents, image_latent_ids = self.prepare_image_latents(
+ images=condition_images,
+ batch_size=batch_size * num_images_per_prompt,
+ generator=generator,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+
+ # 6. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
+ sigmas = None
+ image_seq_len = latents.shape[1]
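+        # mu applies a length-dependent shift to the flow-matching sigma schedule
+        # (longer image token sequences receive a larger shift)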
+ mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
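+        # guidance is passed to the transformer as an embedded conditioning signal; no classifier-free
+        # guidance (i.e. no second forward pass on a negative prompt) is performed here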
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+
+ # 7. Denoising loop
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ latent_model_input = latents.to(self.transformer.dtype)
+ latent_image_ids = latent_ids
+
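+                # reference-image tokens are appended along the sequence dimension; only the first
+                # `latents.size(1)` tokens are kept from the prediction below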
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
+ latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input, # (B, image_seq_len, C)
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids, # B, text_seq_len, 4
+ img_ids=latent_image_ids, # B, image_seq_len, 4
+ joint_attention_kwargs=self._attention_kwargs,
+ return_dict=False,
+ )[0]
+
+                noise_pred = noise_pred[:, : latents.size(1)]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+                        # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents_with_ids(latents, latent_ids)
+
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
+ latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
+ latents.device, latents.dtype
+ )
+ latents = latents * latents_bn_std + latents_bn_mean
+ latents = self._unpatchify_latents(latents)
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return Flux2PipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/flux2/pipeline_output.py b/src/diffusers/pipelines/flux2/pipeline_output.py
new file mode 100644
index 000000000000..58e8ad49c210
--- /dev/null
+++ b/src/diffusers/pipelines/flux2/pipeline_output.py
@@ -0,0 +1,23 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class Flux2PipelineOutput(BaseOutput):
+ """
+ Output class for Flux2 image generation pipelines.
+
+ Args:
+        images (`List[PIL.Image.Image]`, `torch.Tensor`, or `np.ndarray`):
+            List of denoised PIL images of length `batch_size`, or a numpy array or torch tensor of shape
+            `(batch_size, height, width, num_channels)`. PIL images and numpy arrays represent the denoised images of
+            the diffusion pipeline. Torch tensors can represent either the denoised images or the intermediate
+            latents ready to be passed to the decoder.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index f56a8b932505..928f0b977473 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -408,6 +408,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class AutoencoderKLFlux2(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class AutoencoderKLHunyuanImage(metaclass=DummyObject):
_backends = ["torch"]
@@ -843,6 +858,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class Flux2Transformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class FluxControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 9eb123b94e9d..769dda25c125 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -827,6 +827,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class Flux2Pipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class FluxControlImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/lora/test_lora_layers_flux2.py b/tests/lora/test_lora_layers_flux2.py
new file mode 100644
index 000000000000..768d10fec72e
--- /dev/null
+++ b/tests/lora/test_lora_layers_flux2.py
@@ -0,0 +1,127 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import unittest
+
+import torch
+from transformers import AutoProcessor, Mistral3ForConditionalGeneration
+
+from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2Pipeline, Flux2Transformer2DModel
+
+from ..testing_utils import floats_tensor, require_peft_backend
+
+
+sys.path.append(".")
+
+from .utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+@require_peft_backend
+class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = Flux2Pipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "patch_size": 1,
+ "in_channels": 4,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "attention_head_dim": 16,
+ "num_attention_heads": 2,
+ "joint_attention_dim": 16,
+ "timestep_guidance_channels": 256,
+ "axes_dims_rope": [4, 4, 4, 4],
+ }
+ transformer_cls = Flux2Transformer2DModel
+ vae_kwargs = {
+ "sample_size": 32,
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ("DownEncoderBlock2D",),
+ "up_block_types": ("UpDecoderBlock2D",),
+ "block_out_channels": (4,),
+ "layers_per_block": 1,
+ "latent_channels": 1,
+ "norm_num_groups": 1,
+ "use_quant_conv": False,
+ "use_post_quant_conv": False,
+ }
+ vae_cls = AutoencoderKLFlux2
+
+ tokenizer_cls, tokenizer_id = AutoProcessor, "hf-internal-testing/tiny-mistral3-diffusers"
+ text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers"
+ denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"]
+
+ @property
+ def output_shape(self):
+ return (1, 8, 8, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 10
+ num_channels = 4
+ sizes = (32, 32)
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "a dog is dancing",
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 8,
+ "output_type": "np",
+ "text_encoder_out_layers": (1,),
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ @unittest.skip("Not supported in Flux2.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in Flux2.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in Flux2.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Flux2.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Flux2.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Flux2.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Flux2.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Flux2.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
diff --git a/tests/models/transformers/test_models_transformer_flux2.py b/tests/models/transformers/test_models_transformer_flux2.py
new file mode 100644
index 000000000000..316d5fa770bb
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_flux2.py
@@ -0,0 +1,158 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import Flux2Transformer2DModel, attention_backend
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = Flux2Transformer2DModel
+ main_input_name = "hidden_states"
+ # We override the items here because the transformer under consideration is small.
+ model_split_percents = [0.7, 0.6, 0.6]
+
+ # Skip setting testing with default: AttnProcessor
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ return self.prepare_dummy_input()
+
+ @property
+ def input_shape(self):
+ return (16, 4)
+
+ @property
+ def output_shape(self):
+ return (16, 4)
+
+ def prepare_dummy_input(self, height=4, width=4):
+ batch_size = 1
+ num_latent_channels = 4
+ sequence_length = 48
+ embedding_dim = 32
+
+ hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+
+ t_coords = torch.arange(1)
+ h_coords = torch.arange(height)
+ w_coords = torch.arange(width)
+ l_coords = torch.arange(1)
+ image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords) # [height * width, 4]
+ image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
+
+ text_t_coords = torch.arange(1)
+ text_h_coords = torch.arange(1)
+ text_w_coords = torch.arange(1)
+ text_l_coords = torch.arange(sequence_length)
+ text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
+ text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
+
+ timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+ guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "img_ids": image_ids,
+ "txt_ids": text_ids,
+ "timestep": timestep,
+ "guidance": guidance,
+ }
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": 1,
+ "in_channels": 4,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "attention_head_dim": 16,
+ "num_attention_heads": 2,
+ "joint_attention_dim": 32,
+ "timestep_guidance_channels": 256, # Hardcoded in original code
+ "axes_dims_rope": [4, 4, 4, 4],
+ }
+
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ # TODO (Daniel, Sayak): We can remove this test.
+ def test_flux2_consistency(self, seed=0):
+ torch.manual_seed(seed)
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(seed)
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with attention_backend("native"):
+ with torch.no_grad():
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.to_tuple()[0]
+
+ self.assertIsNotNone(output)
+
+ # input & output have to have the same shape
+ input_tensor = inputs_dict[self.main_input_name]
+ expected_shape = input_tensor.shape
+ self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+ # Check against expected slice
+ # fmt: off
+ expected_slice = torch.tensor([-0.3662, 0.4844, 0.6334, -0.3497, 0.2162, 0.0188, 0.0521, -0.2061, -0.2041, -0.0342, -0.7107, 0.4797, -0.3280, 0.7059, -0.0849, 0.4416])
+ # fmt: on
+
+ flat_output = output.cpu().flatten()
+ generated_slice = torch.cat([flat_output[:8], flat_output[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-4))
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"Flux2Transformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = Flux2Transformer2DModel
+ different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
+
+ def prepare_dummy_input(self, height, width):
+ return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
+
+
+class Flux2TransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
+ model_class = Flux2Transformer2DModel
+ different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
+
+ def prepare_dummy_input(self, height, width):
+ return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
diff --git a/tests/pipelines/flux2/__init__.py b/tests/pipelines/flux2/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/flux2/test_pipeline_flux2.py b/tests/pipelines/flux2/test_pipeline_flux2.py
new file mode 100644
index 000000000000..4404dbc51047
--- /dev/null
+++ b/tests/pipelines/flux2/test_pipeline_flux2.py
@@ -0,0 +1,190 @@
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoProcessor, Mistral3Config, Mistral3ForConditionalGeneration
+
+from diffusers import (
+ AutoencoderKLFlux2,
+ FlowMatchEulerDiscreteScheduler,
+ Flux2Pipeline,
+ Flux2Transformer2DModel,
+)
+
+from ...testing_utils import (
+ torch_device,
+)
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+ check_qkv_fused_layers_exist,
+)
+
+
+class Flux2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = Flux2Pipeline
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
+ batch_params = frozenset(["prompt"])
+
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ supports_dduf = False
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = Flux2Transformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ attention_head_dim=16,
+ num_attention_heads=2,
+ joint_attention_dim=16,
+ timestep_guidance_channels=256, # Hardcoded in original code
+ axes_dims_rope=[4, 4, 4, 4],
+ )
+
+ config = Mistral3Config(
+ text_config={
+ "model_type": "mistral",
+ "vocab_size": 32000,
+ "hidden_size": 16,
+ "intermediate_size": 37,
+ "max_position_embeddings": 512,
+ "num_attention_heads": 4,
+ "num_hidden_layers": 1,
+ "num_key_value_heads": 2,
+ "rms_norm_eps": 1e-05,
+ "rope_theta": 1000000000.0,
+ "sliding_window": None,
+ "bos_token_id": 2,
+ "eos_token_id": 3,
+ "pad_token_id": 4,
+ },
+ vision_config={
+ "model_type": "pixtral",
+ "hidden_size": 16,
+ "num_hidden_layers": 1,
+ "num_attention_heads": 4,
+ "intermediate_size": 37,
+ "image_size": 30,
+ "patch_size": 6,
+ "num_channels": 3,
+ },
+ bos_token_id=2,
+ eos_token_id=3,
+ pad_token_id=4,
+            model_type="mistral3",
+ image_seq_length=4,
+ vision_feature_layer=-1,
+ image_token_index=1,
+ )
+ torch.manual_seed(0)
+ text_encoder = Mistral3ForConditionalGeneration(config)
+ tokenizer = AutoProcessor.from_pretrained(
+ "hf-internal-testing/Mistral-Small-3.1-24B-Instruct-2503-only-processor"
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLFlux2(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=("DownEncoderBlock2D",),
+ up_block_types=("UpDecoderBlock2D",),
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "transformer": transformer,
+ "vae": vae,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "a dog is dancing",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 8,
+ "output_type": "np",
+ "text_encoder_out_layers": (1,),
+ }
+ return inputs
+
+ def test_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ original_image_slice = image[0, -3:, -3:, -1]
+
+ # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+ # to the pipeline level.
+ pipe.transformer.fuse_qkv_projections()
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
+ )
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_fused = image[0, -3:, -3:, -1]
+
+ pipe.transformer.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_disabled = image[0, -3:, -3:, -1]
+
+ self.assertTrue(
+ np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
+ ("Fusion of QKV projections shouldn't affect the outputs."),
+ )
+ self.assertTrue(
+ np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
+ ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
+ )
+ self.assertTrue(
+ np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
+ ("Original outputs should match when fused QKV projections are disabled."),
+ )
+
+    def test_flux2_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ self.assertEqual(
+ (output_height, output_width),
+ (expected_height, expected_width),
+ f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
+ )
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index e2bbce7b0ead..22570b28841e 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -103,7 +103,7 @@ def check_qkv_fusion_processors_exist(model):
def check_qkv_fused_layers_exist(model, layer_names):
is_fused_submodules = []
for submodule in model.modules():
- if not isinstance(submodule, AttentionModuleMixin):
+ if not isinstance(submodule, AttentionModuleMixin) or not submodule._supports_qkv_fusion:
continue
is_fused_attribute_set = submodule.fused_projections
is_fused_layer = True