From 7350d07bd2d8b09d3b330f202f2c7b4bf4d34081 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 26 Nov 2025 10:31:36 +0530 Subject: [PATCH 1/9] feat: implement caption upsampling for flux.2. --- .../pipelines/flux2/image_processor.py | 43 +++- .../pipelines/flux2/pipeline_flux2.py | 192 ++++++++++++++++-- .../pipelines/flux2/system_messages.py | 30 +++ 3 files changed, 250 insertions(+), 15 deletions(-) create mode 100644 src/diffusers/pipelines/flux2/system_messages.py diff --git a/src/diffusers/pipelines/flux2/image_processor.py b/src/diffusers/pipelines/flux2/image_processor.py index 91c8f875dd1d..6a36c959e415 100644 --- a/src/diffusers/pipelines/flux2/image_processor.py +++ b/src/diffusers/pipelines/flux2/image_processor.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Tuple +from typing import List import PIL.Image @@ -98,10 +98,15 @@ def check_image_input( return image @staticmethod - def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> Tuple[int, int]: + def _resize_to_target_area( + image: PIL.Image.Image, target_area: int = 1024 * 1024, return_if_small_image: bool = False + ) -> PIL.Image.Image: image_width, image_height = image.size + pixel_count = image_width * image_height + if return_if_small_image and pixel_count <= target_area: + return image - scale = math.sqrt(target_area / (image_width * image_height)) + scale = math.sqrt(target_area / pixel_count) width = int(image_width * scale) height = int(image_height * scale) @@ -136,3 +141,35 @@ def _resize_and_crop( bottom = top + height return image.crop((left, top, right, bottom)) + + # Taken from + # https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L310C1-L339C19 + @staticmethod + def concatenate_images(images: List[PIL.Image.Image]) -> PIL.Image.Image: + """ + Concatenate a list of PIL images horizontally with center alignment and white background. 
+ """ + + # If only one image, return a copy of it + if len(images) == 1: + return images[0].copy() + + # Convert all images to RGB if not already + images = [img.convert("RGB") if img.mode != "RGB" else img for img in images] + + # Calculate dimensions for horizontal concatenation + total_width = sum(img.width for img in images) + max_height = max(img.height for img in images) + + # Create new image with white background + background_color = (255, 255, 255) + new_img = PIL.Image.new("RGB", (total_width, max_height), background_color) + + # Paste images with center alignment + x_offset = 0 + for img in images: + y_offset = (max_height - img.height) // 2 + new_img.paste(img, (x_offset, y_offset)) + x_offset += img.width + + return new_img diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 676bf6d98429..eb6c4c5871d9 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -28,6 +28,7 @@ from ..pipeline_utils import DiffusionPipeline from .image_processor import Flux2ImageProcessor from .pipeline_output import Flux2PipelineOutput +from .system_messages import SYSTEM_MESSAGE, SYSTEM_MESSAGE_UPSAMPLING_I2I, SYSTEM_MESSAGE_UPSAMPLING_T2I if is_torch_xla_available(): @@ -56,25 +57,125 @@ ``` """ +# def format_text_input(prompts: List[str], system_message: str = None): +# # Remove [IMG] tokens from prompts to avoid Pixtral validation issues +# # when truncation is enabled. The processor counts [IMG] tokens and fails +# # if the count changes after truncation. 
+# cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + +# return [ +# [ +# { +# "role": "system", +# "content": [{"type": "text", "text": system_message}], +# }, +# {"role": "user", "content": [{"type": "text", "text": prompt}]}, +# ] +# for prompt in cleaned_txt +# ] + + +# Adapted from +# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68 +def format_input( + prompts: List[str], + system_message: str = SYSTEM_MESSAGE, + images: Optional[Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]]] = None, +): + """ + Format a batch of text prompts into the conversation format expected by apply_chat_template. Optionally, add images + to the input. + + Args: + prompts: List of text prompts + system_message: System message to use (default: SYSTEM_MESSAGE) + images (optional): List of images to add to the input. -def format_text_input(prompts: List[str], system_message: str = None): + Returns: + List of conversations, where each conversation is a list of message dicts + """ # Remove [IMG] tokens from prompts to avoid Pixtral validation issues # when truncation is enabled. The processor counts [IMG] tokens and fails # if the count changes after truncation. cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] - return [ + if images is None or len(images) == 0: + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + else: + assert len(images) == len(prompts), "Number of images must match number of prompts" + images = _validate_and_process_images(images) + + messages = [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + ] + for _ in cleaned_txt + ] + + for i, (el, images) in enumerate(zip(messages, images)): + # optionally add the images per batch element. 
+ if images is not None: + el.append( + { + "role": "user", + "content": [{"type": "image", "image": image_obj} for image_obj in images], + } + ) + # add the text. + el.append( + { + "role": "user", + "content": [{"type": "text", "text": cleaned_txt[i]}], + } + ) + + return messages + + +# Adapted from +# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L49C5-L66C19 +def _validate_and_process_images( + images: List[List[PIL.Image.Image]] | List[PIL.Image.Image], + image_processor: Flux2ImageProcessor, + upsampling_max_image_size: int, +) -> List[List[PIL.Image.Image]]: + # Simple validation: ensure it's a list of PIL images or list of lists of PIL images + if not images: + return [] + + # Check if it's a list of lists or a list of images + if isinstance(images[0], PIL.Image.Image): + # It's a list of images, convert to list of lists + images = [[im] for im in images] + + # potentially concatenate multiple images to reduce the size + images = [[image_processor.concatenate_images(img_i)] if len(img_i) > 1 else img_i for img_i in images] + + # cap the pixels + images = [ [ - { - "role": "system", - "content": [{"type": "text", "text": system_message}], - }, - {"role": "user", "content": [{"type": "text", "text": prompt}]}, + image_processor._resize_to_target_area(img_i, upsampling_max_image_size, return_if_small_image=True) + for img_i in img_i ] - for prompt in cleaned_txt + for img_i in images ] + return images +# Taken from +# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L251 def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: a1, b1 = 8.73809524e-05, 1.89833333 a2, b2 = 0.00016927, 0.45666666 @@ -214,9 +315,10 @@ def __init__( self.tokenizer_max_length = 512 self.default_sample_size = 128 - # fmt: off - self.system_message = "You are an AI that reasons about image descriptions. 
You give structured responses focusing on object relationships, object attribution and actions without speculation." - # fmt: on + self.system_message = SYSTEM_MESSAGE + self.system_message_upsampling_t2i = SYSTEM_MESSAGE_UPSAMPLING_T2I + self.system_message_upsampling_i2i = SYSTEM_MESSAGE_UPSAMPLING_I2I + self.upsampling_max_image_size = 768**2 @staticmethod def _get_mistral_3_small_prompt_embeds( @@ -237,7 +339,7 @@ def _get_mistral_3_small_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt # Format input messages - messages_batch = format_text_input(prompts=prompt, system_message=system_message) + messages_batch = format_input(prompts=prompt, system_message=system_message) # Process all messages at once inputs = tokenizer.apply_chat_template( @@ -426,6 +528,63 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch return torch.stack(x_list, dim=0) + def upsample_prompt( + self, + prompts: List[str], + images: Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]] = None, + temperature: float = 0.15, + device: torch.device = None, + ) -> List[str]: + device = device or self._execution_device + + # Set system message based on whether images are provided + if images is None or len(images) == 0 or images[0] is None: + system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I + else: + system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I + + # Format input messages + messages_batch = format_input(prompts=prompts, system_message=system_message, images=images) + + # Process all messages at once + # with image processing a too short max length can throw an error in here. 
+ inputs = self.tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=2048, + ) + + # Move to device + inputs["input_ids"] = inputs["input_ids"].to(device) + inputs["attention_mask"] = inputs["attention_mask"].to(device) + + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(device, self.text_encoder.dtype) + + # Generate text using the model's generate method + generated_ids = self.text_encoder.generate( + **inputs, + max_new_tokens=512, + do_sample=True, + temperature=temperature, + use_cache=True, + ) + + # Decode only the newly generated tokens (skip input tokens) + # Extract only the generated portion + input_length = inputs["input_ids"].shape[1] + generated_tokens = generated_ids[:, input_length:] + + raw_txt = self.tokenizer.tokenizer.batch_decode( + generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + return raw_txt + def encode_prompt( self, prompt: Union[str, List[str]], @@ -620,6 +779,7 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, text_encoder_out_layers: Tuple[int] = (10, 20, 30), + caption_upsample_temperature: float = None, ): r""" Function invoked when calling the pipeline for generation. @@ -684,6 +844,9 @@ def __call__( max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. text_encoder_out_layers (`Tuple[int]`): Layer indices to use in the `text_encoder` to derive the final prompt embeddings. + caption_upsample_temperature (`float`): + When specified, we will try to perform caption upsampling for potentially improved outputs. We + recommend setting it to 0.15 if caption upsampling is to be performed. Examples: @@ -718,6 +881,11 @@ def __call__( device = self._execution_device # 3. 
prepare text embeddings + if caption_upsample_temperature: + prompt = self.upsample_prompt( + prompt, images=image, temperature=caption_upsample_temperature, device=device + ) + prompt_embeds, text_ids = self.encode_prompt( prompt=prompt, prompt_embeds=prompt_embeds, diff --git a/src/diffusers/pipelines/flux2/system_messages.py b/src/diffusers/pipelines/flux2/system_messages.py new file mode 100644 index 000000000000..a2e11d55f348 --- /dev/null +++ b/src/diffusers/pipelines/flux2/system_messages.py @@ -0,0 +1,30 @@ +SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, +object attribution and actions without speculation.""" + +SYSTEM_MESSAGE_UPSAMPLING_T2I = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while +strictly preserving their core subject and intent. + +Guidelines: +1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed + paragraphs. +2. Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), + shadows, spatial relationships, and environmental context. +3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text + for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates + gibberish. + +Output only the revised prompt and nothing else.""" + +SYSTEM_MESSAGE_UPSAMPLING_I2I = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction +(50-80 words, ~30 for brief requests). + +Rules: +- Single instruction only, no commentary +- Use clear, analytical language (avoid "whimsical," "cascading," etc.) 
+- Specify what changes AND what stays the same (face, lighting, composition) +- Reference actual image elements +- Turn negatives into positives ("don't change X" → "keep X") +- Make abstractions concrete ("futuristic" → "glowing cyan neon, metallic panels") +- Keep content PG-13 + +Output only the final instruction in plain text and nothing else.""" From e6a0ab62442ee0b38691ead31c6379978b5790da Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 26 Nov 2025 11:04:34 +0530 Subject: [PATCH 2/9] doc --- docs/source/en/api/pipelines/flux2.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/en/api/pipelines/flux2.md b/docs/source/en/api/pipelines/flux2.md index 90eaedc245b7..177cde817d11 100644 --- a/docs/source/en/api/pipelines/flux2.md +++ b/docs/source/en/api/pipelines/flux2.md @@ -26,6 +26,12 @@ Original model checkpoints for Flux can be found [here](https://huggingface.co/b > > [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs. +## Caption upsampling + +Flux.2 can potentially generate better outputs with better prompts. We can "upsample" +an input prompt by setting the `caption_upsample_temperature` argument in the pipeline call arguments. +The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L140) recommends this value to be 0.15. 
+ ## Flux2Pipeline [[autodoc]] Flux2Pipeline From b4a840698b99eddf4913c79364eaa855b240f168 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 26 Nov 2025 11:29:29 +0530 Subject: [PATCH 3/9] up --- .../pipelines/flux2/pipeline_flux2.py | 23 ++----------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index eb6c4c5871d9..53da7235f1df 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -57,24 +57,6 @@ ``` """ -# def format_text_input(prompts: List[str], system_message: str = None): -# # Remove [IMG] tokens from prompts to avoid Pixtral validation issues -# # when truncation is enabled. The processor counts [IMG] tokens and fails -# # if the count changes after truncation. -# cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] - -# return [ -# [ -# { -# "role": "system", -# "content": [{"type": "text", "text": system_message}], -# }, -# {"role": "user", "content": [{"type": "text", "text": prompt}]}, -# ] -# for prompt in cleaned_txt -# ] - - # Adapted from # https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68 def format_input( @@ -328,9 +310,7 @@ def _get_mistral_3_small_prompt_embeds( dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, max_sequence_length: int = 512, - # fmt: off - system_message: str = "You are an AI that reasons about image descriptions. 
You give structured responses focusing on object relationships, object attribution and actions without speculation.", - # fmt: on + system_message: str = SYSTEM_MESSAGE, hidden_states_layers: List[int] = (10, 20, 30), ): dtype = text_encoder.dtype if dtype is None else dtype @@ -885,6 +865,7 @@ def __call__( prompt = self.upsample_prompt( prompt, images=image, temperature=caption_upsample_temperature, device=device ) + print(f"{prompt=}") prompt_embeds, text_ids = self.encode_prompt( prompt=prompt, From 0b1f88445944cd0d48ded6478b73fa1ab1b44485 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 26 Nov 2025 12:14:04 +0530 Subject: [PATCH 4/9] fix --- .../pipelines/flux2/pipeline_flux2.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 53da7235f1df..003f5154f06b 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -57,6 +57,7 @@ ``` """ + # Adapted from # https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68 def format_input( @@ -510,12 +511,13 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch def upsample_prompt( self, - prompts: List[str], + prompt: Union[str, List[str]], images: Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]] = None, temperature: float = 0.15, device: torch.device = None, ) -> List[str]: - device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + device = self.text_encoder.device if device is None else device # Set system message based on whether images are provided if images is None or len(images) == 0 or images[0] is None: @@ -524,7 +526,7 @@ def upsample_prompt( system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I # Format input messages - messages_batch = format_input(prompts=prompts, 
system_message=system_message, images=images) + messages_batch = format_input(prompts=prompt, system_message=system_message, images=images) # Process all messages at once # with image processing a too short max length can throw an error in here. @@ -560,10 +562,10 @@ def upsample_prompt( input_length = inputs["input_ids"].shape[1] generated_tokens = generated_ids[:, input_length:] - raw_txt = self.tokenizer.tokenizer.batch_decode( + upsampled_prompt = self.tokenizer.tokenizer.batch_decode( generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True ) - return raw_txt + return upsampled_prompt def encode_prompt( self, @@ -775,11 +777,11 @@ def __call__( The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. guidance_scale (`float`, *optional*, defaults to 1.0): - Guidance scale as defined in [Classifier-Free Diffusion - Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. - of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting - `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. + Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. 
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -865,8 +867,6 @@ def __call__( prompt = self.upsample_prompt( prompt, images=image, temperature=caption_upsample_temperature, device=device ) - print(f"{prompt=}") - prompt_embeds, text_ids = self.encode_prompt( prompt=prompt, prompt_embeds=prompt_embeds, From ceb8a3a2f9d84c5eaf2c6f491965a36fe6ca5252 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 26 Nov 2025 12:27:23 +0530 Subject: [PATCH 5/9] up --- src/diffusers/pipelines/flux2/system_messages.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/pipelines/flux2/system_messages.py b/src/diffusers/pipelines/flux2/system_messages.py index a2e11d55f348..f520b051b155 100644 --- a/src/diffusers/pipelines/flux2/system_messages.py +++ b/src/diffusers/pipelines/flux2/system_messages.py @@ -1,3 +1,8 @@ +""" +These system prompts come from: +https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/system_messages.py#L54 +""" + SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. 
You give structured responses focusing on object relationships, object attribution and actions without speculation.""" From 82685f2e11539dcfc6c8895d46ac8ad58e54ed75 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 26 Nov 2025 14:41:02 +0530 Subject: [PATCH 6/9] =?UTF-8?q?fix=20system=20prompts=20=F0=9F=A4=B7?= =?UTF-8?q?=E2=80=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../pipelines/flux2/system_messages.py | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/src/diffusers/pipelines/flux2/system_messages.py b/src/diffusers/pipelines/flux2/system_messages.py index f520b051b155..bbcb207ff62f 100644 --- a/src/diffusers/pipelines/flux2/system_messages.py +++ b/src/diffusers/pipelines/flux2/system_messages.py @@ -3,25 +3,19 @@ https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/system_messages.py#L54 """ -SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, -object attribution and actions without speculation.""" +SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object +attribution and actions without speculation.""" -SYSTEM_MESSAGE_UPSAMPLING_T2I = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while -strictly preserving their core subject and intent. +SYSTEM_MESSAGE_UPSAMPLING_T2I = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent. Guidelines: -1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed - paragraphs. -2. 
Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), - shadows, spatial relationships, and environmental context. -3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text - for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates - gibberish. +1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed paragraphs. +2. Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), shadows, spatial relationships, and environmental context. +3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates gibberish. Output only the revised prompt and nothing else.""" -SYSTEM_MESSAGE_UPSAMPLING_I2I = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction -(50-80 words, ~30 for brief requests). +SYSTEM_MESSAGE_UPSAMPLING_I2I = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction (50-80 words, ~30 for brief requests). 
Rules: - Single instruction only, no commentary From 6397a67892479c2e1e68307620cc3041bb223002 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 27 Nov 2025 10:30:18 +0530 Subject: [PATCH 7/9] up --- .../pipelines/flux2/image_processor.py | 19 +++++++++++-------- .../pipelines/flux2/pipeline_flux2.py | 11 +++++++---- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/flux2/image_processor.py b/src/diffusers/pipelines/flux2/image_processor.py index 6a36c959e415..a5920548ca86 100644 --- a/src/diffusers/pipelines/flux2/image_processor.py +++ b/src/diffusers/pipelines/flux2/image_processor.py @@ -96,21 +96,24 @@ def check_image_input( ) return image - + @staticmethod - def _resize_to_target_area( - image: PIL.Image.Image, target_area: int = 1024 * 1024, return_if_small_image: bool = False - ) -> PIL.Image.Image: + def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image: image_width, image_height = image.size - pixel_count = image_width * image_height - if return_if_small_image and pixel_count <= target_area: - return image - scale = math.sqrt(target_area / pixel_count) + scale = math.sqrt(target_area / (image_width * image_height)) width = int(image_width * scale) height = int(image_height * scale) return image.resize((width, height), PIL.Image.Resampling.LANCZOS) + + @staticmethod + def _resize_if_exceeds_area(image, target_area=1024 * 1024) -> PIL.Image.Image: + image_width, image_height = image.size + pixel_count = image_width * image_height + if pixel_count <= target_area: + return image + return Flux2ImageProcessor._resize_to_target_area(image, target_area) def _resize_and_crop( self, diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 003f5154f06b..db63f8d1e966 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -57,6 +57,7 @@ ``` """ 
+UPSAMPLING_MAX_IMAGE_SIZE = 768**2 # Adapted from # https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68 @@ -95,8 +96,6 @@ def format_input( ] else: assert len(images) == len(prompts), "Number of images must match number of prompts" - images = _validate_and_process_images(images) - messages = [ [ { @@ -149,7 +148,7 @@ def _validate_and_process_images( # cap the pixels images = [ [ - image_processor._resize_to_target_area(img_i, upsampling_max_image_size, return_if_small_image=True) + image_processor._resize_if_exceeds_area(img_i, upsampling_max_image_size) for img_i in img_i ] for img_i in images @@ -301,7 +300,7 @@ def __init__( self.system_message = SYSTEM_MESSAGE self.system_message_upsampling_t2i = SYSTEM_MESSAGE_UPSAMPLING_T2I self.system_message_upsampling_i2i = SYSTEM_MESSAGE_UPSAMPLING_I2I - self.upsampling_max_image_size = 768**2 + self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE @staticmethod def _get_mistral_3_small_prompt_embeds( @@ -525,6 +524,10 @@ def upsample_prompt( else: system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I + # Validate and process the input images + if images: + images = _validate_and_process_images(images, self.image_processor, self.upsampling_max_image_size) + # Format input messages messages_batch = format_input(prompts=prompt, system_message=system_message, images=images) From 7b65aa73430b05a341bda2cfdb26ae4a2a1b415f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Dec 2025 06:32:10 +0800 Subject: [PATCH 8/9] up --- src/diffusers/pipelines/flux2/image_processor.py | 4 ++-- src/diffusers/pipelines/flux2/pipeline_flux2.py | 8 +++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/flux2/image_processor.py b/src/diffusers/pipelines/flux2/image_processor.py index a5920548ca86..f1a8742491f7 100644 --- a/src/diffusers/pipelines/flux2/image_processor.py +++ b/src/diffusers/pipelines/flux2/image_processor.py @@ -96,7 
+96,7 @@ def check_image_input( ) return image - + @staticmethod def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image: image_width, image_height = image.size @@ -106,7 +106,7 @@ def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 102 height = int(image_height * scale) return image.resize((width, height), PIL.Image.Resampling.LANCZOS) - + @staticmethod def _resize_if_exceeds_area(image, target_area=1024 * 1024) -> PIL.Image.Image: image_width, image_height = image.size diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 1ea250b1559e..b54a43dd89a5 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -59,6 +59,7 @@ UPSAMPLING_MAX_IMAGE_SIZE = 768**2 + # Adapted from # https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68 def format_input( @@ -147,10 +148,7 @@ def _validate_and_process_images( # cap the pixels images = [ - [ - image_processor._resize_if_exceeds_area(img_i, upsampling_max_image_size) - for img_i in img_i - ] + [image_processor._resize_if_exceeds_area(img_i, upsampling_max_image_size) for img_i in img_i] for img_i in images ] return images @@ -527,7 +525,7 @@ def upsample_prompt( # Validate and process the input images if images: images = _validate_and_process_images(images, self.image_processor, self.upsampling_max_image_size) - + # Format input messages messages_batch = format_input(prompts=prompt, system_message=system_message, images=images) From ff4fe75a5d12125cb4db905a1de3934f2ea099b8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Dec 2025 06:37:17 +0800 Subject: [PATCH 9/9] up --- pyproject.toml | 3 +++ src/diffusers/pipelines/flux2/system_messages.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index a864ea34b888..fdda8a6977be 100644 
--- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,8 @@ [tool.ruff] line-length = 119 +extend-exclude = [ + "src/diffusers/pipelines/flux2/system_messages.py", +] [tool.ruff.lint] # Never enforce `E501` (line length violations). diff --git a/src/diffusers/pipelines/flux2/system_messages.py b/src/diffusers/pipelines/flux2/system_messages.py index bbcb207ff62f..ecdb1371f0d4 100644 --- a/src/diffusers/pipelines/flux2/system_messages.py +++ b/src/diffusers/pipelines/flux2/system_messages.py @@ -1,11 +1,14 @@ +# docstyle-ignore """ These system prompts come from: https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/system_messages.py#L54 """ +# docstyle-ignore SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.""" +# docstyle-ignore SYSTEM_MESSAGE_UPSAMPLING_T2I = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent. Guidelines: @@ -15,6 +18,7 @@ Output only the revised prompt and nothing else.""" +# docstyle-ignore SYSTEM_MESSAGE_UPSAMPLING_I2I = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction (50-80 words, ~30 for brief requests). Rules: