Skip to content
8 changes: 7 additions & 1 deletion src/diffusers/modular_pipelines/components_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,11 @@ def __call__(self, hooks, model_id, model, execution_device):

device_type = execution_device.type
device_module = getattr(torch, device_type, torch.cuda)
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
try:
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
except AttributeError:
raise AttributeError(f"Do not know how to obtain obtain memory info for {str(device_module)}.")

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ohh, what happended here?
#12566 (comment)
It was a good catch, I think we should also add a warning sooner during enable_model_offload()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad for misunderstanding!

I have added the error catching. I added a TODO in enable_auto_cpu_offload() to error out early for mem_get_info().

mem_on_device = mem_on_device - self.memory_reserve_margin
if current_module_size < mem_on_device:
return []
Expand Down Expand Up @@ -699,6 +703,8 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None,
if not is_accelerate_available():
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")

# TODO: add a warning if mem_get_info isn't available on `device`.

for name, component in self.components.items():
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
remove_hook_from_module(component, recurse=True)
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/modular_pipelines/flux/before_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
and getattr(block_state, "image_width", None) is not None
):
image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
image_latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
img_ids = FluxPipeline._prepare_latent_image_ids(
None, image_latent_height // 2, image_latent_width // 2, device, dtype
)
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/modular_pipelines/flux/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def inputs(self) -> List[Tuple[str, Any]]:
),
InputParam(
"guidance",
required=True,
required=False,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For Flux.1-schnell, guidance shouldn't be required.

type_hint=torch.Tensor,
description="Guidance scale as a tensor",
),
Expand Down Expand Up @@ -141,7 +141,7 @@ def inputs(self) -> List[Tuple[str, Any]]:
),
InputParam(
"guidance",
required=True,
required=False,
type_hint=torch.Tensor,
description="Guidance scale as a tensor",
),
Expand Down
11 changes: 4 additions & 7 deletions src/diffusers/modular_pipelines/flux/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def expected_components(self) -> List[ComponentSpec]:
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 16}),
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be consistent.

default_creation_method="from_config",
),
]
Expand Down Expand Up @@ -143,10 +143,6 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState):
class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
model_name = "flux-kontext"

def __init__(self, _auto_resize=True):
self._auto_resize = _auto_resize
super().__init__()

@property
def description(self) -> str:
return (
Expand All @@ -167,7 +163,7 @@ def expected_components(self) -> List[ComponentSpec]:

@property
def inputs(self) -> List[InputParam]:
return [InputParam("image")]
return [InputParam("image"), InputParam("_auto_resize", type_hint=bool, default=True)]

@property
def intermediate_outputs(self) -> List[OutputParam]:
Expand Down Expand Up @@ -195,7 +191,8 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState):
img = images[0]
image_height, image_width = components.image_processor.get_default_height_width(img)
aspect_ratio = image_width / image_height
if self._auto_resize:
_auto_resize = block_state._auto_resize
if _auto_resize:
# Kontext is trained on specific resolutions, using one of them is recommended
_, image_width, image_height = min(
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/modular_pipelines/flux/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
block_state.prompt_embeds = block_state.prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt)
block_state.pooled_prompt_embeds = pooled_prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, -1
)
Comment on lines +115 to +118
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was left out of the expansion incorrectly.

self.set_block_state(state, block_state)

return components, state
Expand Down
Empty file.
130 changes: 130 additions & 0 deletions tests/modular_pipelines/flux/test_modular_pipeline_flux.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import tempfile
import unittest

import numpy as np
import PIL
import torch

from diffusers.image_processor import VaeImageProcessor
from diffusers.modular_pipelines import (
FluxAutoBlocks,
FluxKontextAutoBlocks,
FluxKontextModularPipeline,
FluxModularPipeline,
ModularPipeline,
)

from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin


class FluxModularTests:
pipeline_class = FluxModularPipeline
pipeline_blocks_class = FluxAutoBlocks
repo = "hf-internal-testing/tiny-flux-modular"

def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
pipeline.load_components(torch_dtype=torch_dtype)
return pipeline

def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 8,
"width": 8,
"max_sequence_length": 48,
"output_type": "np",
}
return inputs


class FluxModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase):
params = frozenset(["prompt", "height", "width", "guidance_scale"])
batch_params = frozenset(["prompt"])


class FluxImg2ImgModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase):
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
batch_params = frozenset(["prompt", "image"])

def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
pipeline = super().get_pipeline(components_manager, torch_dtype)
# Override `vae_scale_factor` here as currently, `image_processor` is initialized with
# fixed constants instead of
# https://github.com/huggingface/diffusers/blob/d54622c2679d700b425ad61abce9b80fc36212c0/src/diffusers/pipelines/flux/pipeline_flux_img2img.py#L230C9-L232C10
pipeline.image_processor = VaeImageProcessor(vae_scale_factor=2)
return pipeline

def get_dummy_inputs(self, device, seed=0):
inputs = super().get_dummy_inputs(device, seed)
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image / 2 + 0.5
inputs["image"] = image
inputs["strength"] = 0.8
inputs["height"] = 8
inputs["width"] = 8
return inputs

def test_save_from_pretrained(self):
pipes = []
base_pipe = self.get_pipeline().to(torch_device)
pipes.append(base_pipe)

with tempfile.TemporaryDirectory() as tmpdirname:
base_pipe.save_pretrained(tmpdirname)
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
pipe.load_components(torch_dtype=torch.float32)
pipe.to(torch_device)
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)

pipes.append(pipe)

image_slices = []
for pipe in pipes:
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs, output="images")

image_slices.append(image[0, -3:, -3:, -1].flatten())

assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3


class FluxKontextModularPipelineFastTests(FluxImg2ImgModularPipelineFastTests):
pipeline_class = FluxKontextModularPipeline
pipeline_blocks_class = FluxKontextAutoBlocks
repo = "hf-internal-testing/tiny-flux-kontext-pipe"

def get_dummy_inputs(self, device, seed=0):
inputs = super().get_dummy_inputs(device, seed)
image = PIL.Image.new("RGB", (32, 32), 0)
_ = inputs.pop("strength")
inputs["image"] = image
inputs["height"] = 8
inputs["width"] = 8
inputs["max_area"] = 8 * 8
inputs["_auto_resize"] = False
return inputs
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,12 @@
import torch
from PIL import Image

from diffusers import (
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated changes but I thought I would do it (happy to revert).

ClassifierFreeGuidance,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
)
from diffusers import ClassifierFreeGuidance, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
from diffusers.loaders import ModularIPAdapterMixin

from ...models.unets.test_models_unet_2d_condition import (
create_ip_adapter_state_dict,
)
from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modular_pipelines_common import (
ModularPipelineTesterMixin,
)
from ...models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin


enable_full_determinism()
Expand Down
Loading