From ce8474fe74af98fd1cda9b17fdb4099ef70eedde Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 31 Oct 2025 14:06:22 +0530 Subject: [PATCH 1/9] start flux modular tests. --- .../modular_pipelines/components_manager.py | 8 ++- .../modular_pipelines/flux/denoise.py | 5 +- .../modular_pipelines/flux/inputs.py | 4 ++ tests/modular_pipelines/flux/__init__.py | 0 .../flux/test_modular_pipeline_flux.py | 52 +++++++++++++++++++ ...st_modular_pipeline_stable_diffusion_xl.py | 20 ++----- 6 files changed, 70 insertions(+), 19 deletions(-) create mode 100644 tests/modular_pipelines/flux/__init__.py create mode 100644 tests/modular_pipelines/flux/test_modular_pipeline_flux.py diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 9dd8035c44e7..e7af70eeae71 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -164,7 +164,13 @@ def __call__(self, hooks, model_id, model, execution_device): device_type = execution_device.type device_module = getattr(torch, device_type, torch.cuda) - mem_on_device = device_module.mem_get_info(execution_device.index)[0] + try: + mem_on_device = device_module.mem_get_info(execution_device.index)[0] + except AttributeError: + try: + mem_on_device = device_module.recommended_max_memory() + except AttributeError: + raise NotImplementedError(f"Do not know how to obtain memory info for {str(device_module)}.") mem_on_device = mem_on_device - self.memory_reserve_margin if current_module_size < mem_on_device: return [] diff --git a/src/diffusers/modular_pipelines/flux/denoise.py b/src/diffusers/modular_pipelines/flux/denoise.py index b1796bb63cb0..eebabd31bda9 100644 --- a/src/diffusers/modular_pipelines/flux/denoise.py +++ b/src/diffusers/modular_pipelines/flux/denoise.py @@ -59,7 +59,7 @@ def inputs(self) -> List[Tuple[str, Any]]: ), InputParam( "guidance", - required=True, + required=False, type_hint=torch.Tensor, description="Guidance scale as a tensor", ), @@ -141,7 +141,7 @@ def inputs(self) -> List[Tuple[str, Any]]: ), InputParam( "guidance", - required=True, + required=False, type_hint=torch.Tensor, description="Guidance scale as a tensor", ), @@ -182,6 +182,7 @@ def __call__( latent_model_input = torch.cat([latent_model_input, image_latents], dim=1) timestep = t.expand(latents.shape[0]).to(latents.dtype) + print(f"{latents.shape=}, {timestep.shape=}") noise_pred = components.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, diff --git a/src/diffusers/modular_pipelines/flux/inputs.py b/src/diffusers/modular_pipelines/flux/inputs.py index e1bc17f5ff4e..8309eebfeb37 100644 --- a/src/diffusers/modular_pipelines/flux/inputs.py +++ b/src/diffusers/modular_pipelines/flux/inputs.py @@ -112,6 +112,10 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip block_state.prompt_embeds = block_state.prompt_embeds.view( block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 ) + pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt) + block_state.pooled_prompt_embeds = pooled_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, -1 + ) self.set_block_state(state, block_state) return components, state diff --git a/tests/modular_pipelines/flux/__init__.py b/tests/modular_pipelines/flux/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/modular_pipelines/flux/test_modular_pipeline_flux.py b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py new file mode 100644 index 000000000000..ea54fdf71c0a --- /dev/null +++ b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py @@ -0,0 +1,52 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers.modular_pipelines import FluxAutoBlocks, FluxModularPipeline + +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +class FluxModularFastTests(ModularPipelineTesterMixin, unittest.TestCase): + pipeline_class = FluxModularPipeline + pipeline_blocks_class = FluxAutoBlocks + repo = "hf-internal-testing/tiny-flux-modular" + params = frozenset(["prompt", "height", "width", "guidance_scale"]) + batch_params = frozenset(["prompt"]) + + def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): + pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager) + pipeline.load_components(torch_dtype=torch_dtype) + return pipeline + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 8, + "width": 8, + "max_sequence_length": 48, + "output_type": "np", + } + return inputs diff --git a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py index d05f818135ab..22347aa5589c 100644 --- a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py +++ b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py @@ -21,24 +21,12 @@ import torch from PIL import Image -from diffusers import ( - ClassifierFreeGuidance, - StableDiffusionXLAutoBlocks, - StableDiffusionXLModularPipeline, -) +from diffusers import ClassifierFreeGuidance, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from diffusers.loaders import ModularIPAdapterMixin -from ...models.unets.test_models_unet_2d_condition import ( - create_ip_adapter_state_dict, -) -from ...testing_utils import ( - enable_full_determinism, - floats_tensor, - torch_device, -) -from ..test_modular_pipelines_common import ( - ModularPipelineTesterMixin, -) +from ...models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict +from ...testing_utils import enable_full_determinism, floats_tensor, torch_device +from ..test_modular_pipelines_common import ModularPipelineTesterMixin enable_full_determinism() From 6d2e102983123480f3e7578dc924f7017fced0bf Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 31 Oct 2025 15:13:54 +0530 Subject: [PATCH 2/9] up --- .../modular_pipelines/flux/encoders.py | 2 +- .../flux/test_modular_pipeline_flux.py | 62 +++++++++++++++++-- 2 files changed, 59 insertions(+), 5 deletions(-) diff --git a/src/diffusers/modular_pipelines/flux/encoders.py b/src/diffusers/modular_pipelines/flux/encoders.py index b71962bd9313..2a84152faea5 100644 --- a/src/diffusers/modular_pipelines/flux/encoders.py +++ b/src/diffusers/modular_pipelines/flux/encoders.py @@ -95,7 +95,7 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec( "image_processor", VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 16}), + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}), default_creation_method="from_config", ), ] diff --git a/tests/modular_pipelines/flux/test_modular_pipeline_flux.py b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py index ea54fdf71c0a..442f3c46d0d7 100644 --- a/tests/modular_pipelines/flux/test_modular_pipeline_flux.py +++ b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py @@ -13,21 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +import random +import tempfile import unittest +import numpy as np import torch -from diffusers.modular_pipelines import FluxAutoBlocks, FluxModularPipeline +from diffusers.image_processor import VaeImageProcessor +from diffusers.modular_pipelines import FluxAutoBlocks, FluxModularPipeline, ModularPipeline +from ...testing_utils import floats_tensor, torch_device from ..test_modular_pipelines_common import ModularPipelineTesterMixin -class FluxModularFastTests(ModularPipelineTesterMixin, unittest.TestCase): +class FluxModularTests: pipeline_class = FluxModularPipeline pipeline_blocks_class = FluxAutoBlocks repo = "hf-internal-testing/tiny-flux-modular" - params = frozenset(["prompt", "height", "width", "guidance_scale"]) - batch_params = frozenset(["prompt"]) def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager) @@ -50,3 +53,54 @@ def get_dummy_inputs(self, device, seed=0): "output_type": "np", } return inputs + + +class FluxModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase): + params = frozenset(["prompt", "height", "width", "guidance_scale"]) + batch_params = frozenset(["prompt"]) + + +class FluxImg2ImgModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase): + params = frozenset(["prompt", "height", "width", "guidance_scale", "image"]) + batch_params = frozenset(["prompt", "image"]) + + def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): + pipeline = super().get_pipeline(components_manager, torch_dtype) + # Override `vae_scale_factor` here as currently, `image_processor` is initialized with + # fixed constants instead of + # https://github.com/huggingface/diffusers/blob/d54622c2679d700b425ad61abce9b80fc36212c0/src/diffusers/pipelines/flux/pipeline_flux_img2img.py#L230C9-L232C10 + pipeline.image_processor = VaeImageProcessor(vae_scale_factor=2) + return pipeline + + def get_dummy_inputs(self, device, seed=0): + inputs = super().get_dummy_inputs(device, seed) + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image / 2 + 0.5 + inputs["image"] = image + inputs["strength"] = 0.8 + inputs["height"] = 8 + inputs["width"] = 8 + return inputs + + def test_save_from_pretrained(self): + pipes = [] + base_pipe = self.get_pipeline().to(torch_device) + pipes.append(base_pipe) + + with tempfile.TemporaryDirectory() as tmpdirname: + base_pipe.save_pretrained(tmpdirname) + pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device) + pipe.load_components(torch_dtype=torch.float32) + pipe.to(torch_device) + pipe.image_processor = VaeImageProcessor(vae_scale_factor=2) + + pipes.append(pipe) + + image_slices = [] + for pipe in pipes: + inputs = self.get_dummy_inputs(torch_device) + image = pipe(**inputs, output="images") + + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 From 1744f62cdc542104d45d60b10382aeeabe4bec63 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 31 Oct 2025 17:34:57 +0530 Subject: [PATCH 3/9] add kontext --- .../modular_pipelines/flux/before_denoise.py | 3 +- .../flux/test_modular_pipeline_flux.py | 50 ++++++++++++++++++- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py index c098b7d4f1e5..5d51e91ff26b 100644 --- a/src/diffusers/modular_pipelines/flux/before_denoise.py +++ b/src/diffusers/modular_pipelines/flux/before_denoise.py @@ -598,7 +598,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip and getattr(block_state, "image_width", None) is not None ): image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2)) - image_latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) + image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2)) img_ids = FluxPipeline._prepare_latent_image_ids( None, image_latent_height // 2, image_latent_width // 2, device, dtype ) @@ -608,6 +608,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype) + print(f"{latent_ids.shape=}, {img_ids.shape=}") if img_ids is not None: latent_ids = torch.cat([latent_ids, img_ids], dim=0) diff --git a/tests/modular_pipelines/flux/test_modular_pipeline_flux.py b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py index 442f3c46d0d7..b7a81b57fd4a 100644 --- a/tests/modular_pipelines/flux/test_modular_pipeline_flux.py +++ b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py @@ -18,15 +18,48 @@ import unittest import numpy as np +import PIL import torch from diffusers.image_processor import VaeImageProcessor -from diffusers.modular_pipelines import FluxAutoBlocks, FluxModularPipeline, ModularPipeline +from diffusers.modular_pipelines import ( + FluxAutoBlocks, + FluxKontextAutoBlocks, + FluxKontextModularPipeline, + FluxModularPipeline, + ModularPipeline, +) +from diffusers.modular_pipelines.flux.modular_blocks import ( + AUTO_BLOCKS_KONTEXT, + FluxKontextAutoVaeEncoderStep, + FluxKontextProcessImagesInputStep, + FluxKontextVaeEncoderStep, + FluxVaeEncoderDynamicStep, +) +from diffusers.modular_pipelines.modular_pipeline_utils import InsertableDict from ...testing_utils import floats_tensor, torch_device from ..test_modular_pipelines_common import ModularPipelineTesterMixin +# Because we should disable `auto_resize` during tests. +FluxKontextVaeEncoderBlocks = InsertableDict( + [ + ("preprocess", FluxKontextProcessImagesInputStep(_auto_resize=False)), + ("encode", FluxVaeEncoderDynamicStep(sample_mode="argmax")), + ] +) +FluxKontextVaeEncoderStep.block_classes = FluxKontextVaeEncoderBlocks.values() +FluxKontextVaeEncoderStep.block_names = FluxKontextVaeEncoderBlocks.keys() +FluxKontextAutoVaeEncoderStep.block_classes = [FluxKontextVaeEncoderStep] + +AUTO_BLOCKS_KONTEXT = AUTO_BLOCKS_KONTEXT.copy() +AUTO_BLOCKS_KONTEXT["image_encoder"] = FluxKontextAutoVaeEncoderStep + +FluxKontextAutoBlocks.block_classes = AUTO_BLOCKS_KONTEXT.values() +FluxKontextAutoBlocks.block_names = AUTO_BLOCKS_KONTEXT.keys() + + class FluxModularTests: pipeline_class = FluxModularPipeline pipeline_blocks_class = FluxAutoBlocks @@ -104,3 +137,18 @@ def test_save_from_pretrained(self): image_slices.append(image[0, -3:, -3:, -1].flatten()) assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + + +class FluxKontextModularPipelineFastTests(FluxImg2ImgModularPipelineFastTests): + pipeline_class = FluxKontextModularPipeline + pipeline_blocks_class = FluxKontextAutoBlocks + repo = "hf-internal-testing/tiny-flux-kontext-pipe" + + def get_dummy_inputs(self, device, seed=0): + inputs = super().get_dummy_inputs(device, seed) + image = PIL.Image.new("RGB", (32, 32), 0) + inputs["image"] = image + inputs["height"] = 8 + inputs["width"] = 8 + inputs["max_area"] = 8 * 8 + return inputs From 9ca72017cc97c3a7cdfb695821b3b1c31b7a6989 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 31 Oct 2025 17:37:56 +0530 Subject: [PATCH 4/9] up --- src/diffusers/modular_pipelines/flux/before_denoise.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py index 5d51e91ff26b..9ad2acb38455 100644 --- a/src/diffusers/modular_pipelines/flux/before_denoise.py +++ b/src/diffusers/modular_pipelines/flux/before_denoise.py @@ -608,8 +608,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype) - print(f"{latent_ids.shape=}, {img_ids.shape=}") - + if img_ids is not None: latent_ids = torch.cat([latent_ids, img_ids], dim=0) From b19654061c2bb9217d79564910cc403b2bf54b97 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 31 Oct 2025 17:43:45 +0530 Subject: [PATCH 5/9] up --- src/diffusers/modular_pipelines/flux/before_denoise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py index 9ad2acb38455..daffec986535 100644 --- a/src/diffusers/modular_pipelines/flux/before_denoise.py +++ b/src/diffusers/modular_pipelines/flux/before_denoise.py @@ -608,7 +608,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype) - + if img_ids is not None: latent_ids = torch.cat([latent_ids, img_ids], dim=0) From bbdc16a6e88d517f88c48ea68c2c9d25e5c1fa70 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 31 Oct 2025 20:15:15 +0530 Subject: [PATCH 6/9] up --- .../modular_pipelines/flux/encoders.py | 9 ++---- .../flux/test_modular_pipeline_flux.py | 28 ++----------------- 2 files changed, 5 insertions(+), 32 deletions(-) diff --git a/src/diffusers/modular_pipelines/flux/encoders.py b/src/diffusers/modular_pipelines/flux/encoders.py index 2a84152faea5..f0314d4771b0 100644 --- a/src/diffusers/modular_pipelines/flux/encoders.py +++ b/src/diffusers/modular_pipelines/flux/encoders.py @@ -143,10 +143,6 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState): class FluxKontextProcessImagesInputStep(ModularPipelineBlocks): model_name = "flux-kontext" - def __init__(self, _auto_resize=True): - self._auto_resize = _auto_resize - super().__init__() - @property def description(self) -> str: return ( @@ -167,7 +163,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: - return [InputParam("image")] + return [InputParam("image"), InputParam("_auto_resize", type_hint=bool, default=True)] @property def intermediate_outputs(self) -> List[OutputParam]: @@ -195,7 +191,8 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState): img = images[0] image_height, image_width = components.image_processor.get_default_height_width(img) aspect_ratio = image_width / image_height - if self._auto_resize: + _auto_resize = block_state._auto_resize + if _auto_resize: # Kontext is trained on specific resolutions, using one of them is recommended _, image_width, image_height = min( (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS diff --git a/tests/modular_pipelines/flux/test_modular_pipeline_flux.py b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py index b7a81b57fd4a..9d70c21aa8cd 100644 --- a/tests/modular_pipelines/flux/test_modular_pipeline_flux.py +++ b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py @@ -29,37 +29,11 @@ FluxModularPipeline, ModularPipeline, ) -from diffusers.modular_pipelines.flux.modular_blocks import ( - AUTO_BLOCKS_KONTEXT, - FluxKontextAutoVaeEncoderStep, - FluxKontextProcessImagesInputStep, - FluxKontextVaeEncoderStep, - FluxVaeEncoderDynamicStep, -) -from diffusers.modular_pipelines.modular_pipeline_utils import InsertableDict from ...testing_utils import floats_tensor, torch_device from ..test_modular_pipelines_common import ModularPipelineTesterMixin -# Because we should disable `auto_resize` during tests. -FluxKontextVaeEncoderBlocks = InsertableDict( - [ - ("preprocess", FluxKontextProcessImagesInputStep(_auto_resize=False)), - ("encode", FluxVaeEncoderDynamicStep(sample_mode="argmax")), - ] -) -FluxKontextVaeEncoderStep.block_classes = FluxKontextVaeEncoderBlocks.values() -FluxKontextVaeEncoderStep.block_names = FluxKontextVaeEncoderBlocks.keys() -FluxKontextAutoVaeEncoderStep.block_classes = [FluxKontextVaeEncoderStep] - -AUTO_BLOCKS_KONTEXT = AUTO_BLOCKS_KONTEXT.copy() -AUTO_BLOCKS_KONTEXT["image_encoder"] = FluxKontextAutoVaeEncoderStep - -FluxKontextAutoBlocks.block_classes = AUTO_BLOCKS_KONTEXT.values() -FluxKontextAutoBlocks.block_names = AUTO_BLOCKS_KONTEXT.keys() - - class FluxModularTests: pipeline_class = FluxModularPipeline pipeline_blocks_class = FluxAutoBlocks @@ -147,8 +121,10 @@ class FluxKontextModularPipelineFastTests(FluxImg2ImgModularPipelineFastTests): def get_dummy_inputs(self, device, seed=0): inputs = super().get_dummy_inputs(device, seed) image = PIL.Image.new("RGB", (32, 32), 0) + _ = inputs.pop("strength") inputs["image"] = image inputs["height"] = 8 inputs["width"] = 8 inputs["max_area"] = 8 * 8 + inputs["_auto_resize"] = False return inputs From 751c7c73ffdc4e52ffb5d7ddadaf0bf28204e178 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 1 Nov 2025 06:56:10 +0530 Subject: [PATCH 7/9] Update src/diffusers/modular_pipelines/flux/denoise.py Co-authored-by: YiYi Xu --- src/diffusers/modular_pipelines/flux/denoise.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/modular_pipelines/flux/denoise.py b/src/diffusers/modular_pipelines/flux/denoise.py index eebabd31bda9..5a769df1036d 100644 --- a/src/diffusers/modular_pipelines/flux/denoise.py +++ b/src/diffusers/modular_pipelines/flux/denoise.py @@ -182,7 +182,6 @@ def __call__( latent_model_input = torch.cat([latent_model_input, image_latents], dim=1) timestep = t.expand(latents.shape[0]).to(latents.dtype) - print(f"{latents.shape=}, {timestep.shape=}") noise_pred = components.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, From c171a1d6a0c66414006fcd763ef2217d08ac4db3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 1 Nov 2025 06:57:40 +0530 Subject: [PATCH 8/9] up --- src/diffusers/modular_pipelines/components_manager.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index e7af70eeae71..7c6d0af7eee5 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -164,14 +164,9 @@ def __call__(self, hooks, model_id, model, execution_device): device_type = execution_device.type device_module = getattr(torch, device_type, torch.cuda) - try: - mem_on_device = device_module.mem_get_info(execution_device.index)[0] - except AttributeError: - try: - mem_on_device = device_module.recommended_max_memory() - except AttributeError: - raise NotImplementedError(f"Do not know how to obtain memory info for {str(device_module)}.") + mem_on_device = device_module.mem_get_info(execution_device.index)[0] mem_on_device = mem_on_device - self.memory_reserve_margin + if current_module_size < mem_on_device: return [] From 94058206d66c899dbf25679e09104ec62ba14353 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 1 Nov 2025 22:45:40 +0530 Subject: [PATCH 9/9] up --- src/diffusers/modular_pipelines/components_manager.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 7c6d0af7eee5..cb7e8fb73697 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -164,9 +164,12 @@ def __call__(self, hooks, model_id, model, execution_device): device_type = execution_device.type device_module = getattr(torch, device_type, torch.cuda) - mem_on_device = device_module.mem_get_info(execution_device.index)[0] - mem_on_device = mem_on_device - self.memory_reserve_margin + try: + mem_on_device = device_module.mem_get_info(execution_device.index)[0] + except AttributeError: + raise AttributeError(f"Do not know how to obtain obtain memory info for {str(device_module)}.") + mem_on_device = mem_on_device - self.memory_reserve_margin if current_module_size < mem_on_device: return [] @@ -700,6 +703,8 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None, if not is_accelerate_available(): raise ImportError("Make sure to install accelerate to use auto_cpu_offload") + # TODO: add a warning if mem_get_info isn't available on `device`. + for name, component in self.components.items(): if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"): remove_hook_from_module(component, recurse=True)