From ee842839ef84c4025ae09848034bdb23edbe148b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 27 Feb 2025 19:18:10 +0100 Subject: [PATCH 01/54] add componentspec and configspec --- src/diffusers/guider.py | 3 + src/diffusers/pipelines/modular_pipeline.py | 159 ++- .../pipeline_stable_diffusion_xl_modular.py | 1092 ++++++++++++++++- 3 files changed, 1215 insertions(+), 39 deletions(-) diff --git a/src/diffusers/guider.py b/src/diffusers/guider.py index 7445b7ba97af..b42dca64d651 100644 --- a/src/diffusers/guider.py +++ b/src/diffusers/guider.py @@ -743,3 +743,6 @@ def apply_guidance( # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) return noise_pred + + +Guiders = Union[CFGGuider, PAGGuider, APGGuider] \ No newline at end of file diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index b50d00dbc219..683b3940c723 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -16,7 +16,7 @@ import warnings from collections import OrderedDict from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple, Union, Optional, Type import torch @@ -338,11 +338,28 @@ def make_doc_string(inputs, intermediates_inputs, outputs, description=""): return output +@dataclass +class ComponentSpec: + """Specification for a pipeline component.""" + name: str + type_hint: Optional[Type] = None + description: Optional[str] = None + default: Any = None # you can create a default component if it is a stateless class like scheduler, guider or image processor + default_class_name: Union[str, List[str], Tuple[str, str]] # Either "class_name" or ["module", "class_name"] + default_repo: Optional[Union[str, List[str]]] = None # either "repo" or ["repo", "subfolder"] + +@dataclass +class ConfigSpec: + """Specification for a pipeline configuration parameter.""" + name: str + default: Any + description: Optional[str] = None + type_hint: Optional[Type] = None + class PipelineBlock: - # YiYi Notes: do we need this? - # pipelie block should set the default value for all expected config/components, so maybe we do not need to explicitly set the list - expected_components = [] - expected_configs = [] + + component_specs: List[ComponentSpec] = [] + config_specs: List[ConfigSpec] = [] model_name = None @property @@ -409,14 +426,45 @@ def __repr__(self): desc = '\n'.join(desc) + '\n' # Components section - expected_components = set(getattr(self, "expected_components", [])) + expected_components = getattr(self, "expected_components", []) + expected_component_names = {comp.name for comp in expected_components} if expected_components else set() loaded_components = set(self.components.keys()) - all_components = sorted(expected_components | loaded_components) + all_components = sorted(expected_component_names | loaded_components) main_components = [] auxiliary_components = [] for k in all_components: - component_str = f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" + # Get component spec if available + component_spec = next((comp for comp in expected_components if comp.name == k), None) + + if k in loaded_components: + component_type = type(self.components[k]).__name__ + component_str = f" - {k}={component_type}" + + # Add expected type info if available + if component_spec and component_spec.class_name: + expected_type = component_spec.class_name + if isinstance(expected_type, (list, tuple)): + expected_type = expected_type[1] # Get class name from [module, class_name] + if expected_type != component_type: + component_str += f" (expected: {expected_type})" + else: + # Component not loaded but expected + if component_spec: + expected_type = component_spec.class_name + if isinstance(expected_type, (list, tuple)): + expected_type = expected_type[1] # Get class name from [module, class_name] + component_str = f" - {k} (expected: {expected_type})" + + # Add repo info if available + if component_spec.default_repo: + repo_info = component_spec.default_repo + if component_spec.subfolder: + repo_info += f", subfolder={component_spec.subfolder}" + component_str += f" [{repo_info}]" + else: + component_str = f" - {k}" + if k in getattr(self, "auxiliary_components", []): auxiliary_components.append(component_str) else: @@ -793,18 +841,52 @@ def __repr__(self): desc = '\n'.join(desc) + '\n' # Components section - expected_components = set(getattr(self, "expected_components", [])) + expected_components = getattr(self, "expected_components", []) + expected_component_names = {comp.name for comp in expected_components} if expected_components else set() loaded_components = set(self.components.keys()) - all_components = sorted(expected_components | loaded_components) - components_str = " Components:\n" + "\n".join( - f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" - for k in all_components - ) + all_components = sorted(expected_component_names | loaded_components) # Auxiliaries section auxiliaries_str = " Auxiliaries:\n" + "\n".join( f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items() ) + main_components = [] + for k in all_components: + # Get component spec if available + component_spec = next((comp for comp in expected_components if comp.name == k), None) + + if k in loaded_components: + component_type = type(self.components[k]).__name__ + component_str = f" - {k}={component_type}" + + # Add expected type info if available + if component_spec and component_spec.class_name: + expected_type = component_spec.class_name + if isinstance(expected_type, (list, tuple)): + expected_type = expected_type[1] # Get class name from [module, class_name] + if expected_type != component_type: + component_str += f" (expected: {expected_type})" + else: + # Component not loaded but expected + if component_spec: + expected_type = component_spec.class_name + if isinstance(expected_type, (list, tuple)): + expected_type = expected_type[1] # Get class name from [module, class_name] + component_str = f" - {k} (expected: {expected_type})" + + # Add repo info if available + if component_spec.default_repo: + repo_info = component_spec.default_repo + if component_spec.subfolder: + repo_info += f", subfolder={component_spec.subfolder}" + component_str += f" [{repo_info}]" + else: + component_str = f" - {k}" + + + main_components.append(component_str) + + components = "Components:\n" + "\n".join(main_components) # Configs section expected_configs = set(getattr(self, "expected_configs", [])) @@ -1188,19 +1270,54 @@ def __repr__(self): desc = '\n'.join(desc) + '\n' # Components section - expected_components = set(getattr(self, "expected_components", [])) + expected_components = getattr(self, "expected_components", []) + expected_component_names = {comp.name for comp in expected_components} if expected_components else set() loaded_components = set(self.components.keys()) - all_components = sorted(expected_components | loaded_components) - components_str = " Components:\n" + "\n".join( - f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" - for k in all_components - ) + all_components = sorted(expected_component_names | loaded_components) # Auxiliaries section auxiliaries_str = " Auxiliaries:\n" + "\n".join( f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items() ) + main_components = [] + for k in all_components: + # Get component spec if available + component_spec = next((comp for comp in expected_components if comp.name == k), None) + + if k in loaded_components: + component_type = type(self.components[k]).__name__ + component_str = f" - {k}={component_type}" + + # Add expected type info if available + if component_spec and component_spec.class_name: + expected_type = component_spec.class_name + if isinstance(expected_type, (list, tuple)): + expected_type = expected_type[1] # Get class name from [module, class_name] + if expected_type != component_type: + component_str += f" (expected: {expected_type})" + else: + # Component not loaded but expected + if component_spec: + expected_type = component_spec.class_name + if isinstance(expected_type, (list, tuple)): + expected_type = expected_type[1] # Get class name from [module, class_name] + component_str = f" - {k} (expected: {expected_type})" + + # Add repo info if available + if component_spec.default_repo: + repo_info = component_spec.default_repo + if component_spec.subfolder: + repo_info += f", subfolder={component_spec.subfolder}" + component_str += f" [{repo_info}]" + else: + component_str = f" - {k}" + + + main_components.append(component_str) + + components = "Components:\n" + "\n".join(main_components) + # Configs section expected_configs = set(getattr(self, "expected_configs", [])) loaded_configs = set(self.configs.keys()) @@ -1558,7 +1675,7 @@ def __repr__(self): return output - # YiYi TO-DO: try to unify the to method with the one in DiffusionPipeline + # YiYi TODO: try to unify the to method with the one in DiffusionPipeline # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to def to(self, *args, **kwargs): r""" diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index f743f442cc40..9c7ad5ec19c0 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -41,12 +41,25 @@ InputParam, OutputParam, SequentialPipelineBlocks, + ComponentSpec, + ConfigSpec, ) from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import ( StableDiffusionXLPipelineOutput, ) +from transformers import ( + CLIPTextModel, + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...schedulers import KarrasDiffusionSchedulers +from ...guider import Guiders, CFGGuider + import numpy as np logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -128,7 +141,23 @@ def retrieve_latents( class StableDiffusionXLLoraStep(PipelineBlock): - expected_components = ["text_encoder", "text_encoder_2", "unet"] + expected_components = [ + ComponentSpec( + name="text_encoder", + type_hint=CLIPTextModel, + default_class_name=["transformers", "CLIPTextModel"], + ), + ComponentSpec( + name="text_encoder_2", + type_hint=CLIPTextModelWithProjection, + default_class_name=["transformers", "CLIPTextModelWithProjection"], + ), + ComponentSpec( + name="unet", + type_hint=UNet2DConditionModel, + default_class_name=["diffusers", "UNet2DConditionModel"], + ) + ] model_name = "stable-diffusion-xl" @property @@ -164,7 +193,23 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLIPAdapterStep(PipelineBlock): - expected_components = ["image_encoder", "feature_extractor", "unet"] + expected_components = [ + ComponentSpec( + name="image_encoder", + type_hint=CLIPVisionModelWithProjection, + default_class_name=["transformers", "CLIPVisionModelWithProjection"], + ), + ComponentSpec( + name="feature_extractor", + type_hint=CLIPImageProcessor, + default_class_name=["transformers", "CLIPImageProcessor"], + ), + ComponentSpec( + name="unet", + type_hint=UNet2DConditionModel, + default_class_name=["diffusers", "UNet2DConditionModel"], + ) + ] model_name = "stable-diffusion-xl" @@ -236,8 +281,37 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLTextEncoderStep(PipelineBlock): - expected_components = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"] - expected_configs = ["force_zeros_for_empty_prompt"] + expected_components = [ + ComponentSpec( + name="text_encoder", + type_hint=CLIPTextModel, + default_class_name=["transformers", "CLIPTextModel"], + ), + ComponentSpec( + name="text_encoder_2", + type_hint=CLIPTextModelWithProjection, + default_class_name=["transformers", "CLIPTextModelWithProjection"], + ), + ComponentSpec( + name="tokenizer", + type_hint=CLIPTokenizer, + default_class_name=["transformers", "CLIPTokenizer"], + ), + ComponentSpec( + name="tokenizer_2", + type_hint=CLIPTokenizer, + default_class_name=["transformers", "CLIPTokenizer"], + ) + ] + + expected_configs = [ + ConfigSpec( + name="force_zeros_for_empty_prompt", + default=True, + type_hint=bool + ) + ] + model_name = "stable-diffusion-xl" @property @@ -357,7 +431,18 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLVaeEncoderStep(PipelineBlock): - expected_components = ["vae"] + expected_components = [ + ComponentSpec( + name="vae", + type_hint=AutoencoderKL, + default_class_name=["diffusers", "AutoencoderKL"], + ), + ComponentSpec( + name="image_processor", + type_hint=VaeImageProcessor, + default=VaeImageProcessor() + ) + ] model_name = "stable-diffusion-xl" @@ -443,7 +528,23 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): - expected_components = ["vae"] + expected_components = [ + ComponentSpec( + name="vae", + type_hint=AutoencoderKL, + default_class_name=["diffusers", "AutoencoderKL"], + ), + ComponentSpec( + name="image_processor", + type_hint=VaeImageProcessor, + default=VaeImageProcessor() + ), + ComponentSpec( + name="mask_processor", + type_hint=VaeImageProcessor, + default=VaeImageProcessor(do_normalize=False, do_binarize=True, do_convert_grayscale=True) + ) + ] model_name = "stable-diffusion-xl" @property @@ -687,7 +788,13 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): - expected_components = ["scheduler"] + expected_components = [ + ComponentSpec( + name="scheduler", + type_hint=KarrasDiffusionSchedulers, + default_class_name=["diffusers", "EulerDiscreteScheduler"], + ) + ] model_name = "stable-diffusion-xl" @property @@ -801,7 +908,13 @@ def denoising_value_valid(dnv): class StableDiffusionXLSetTimestepsStep(PipelineBlock): - expected_components = ["scheduler"] + expected_components = [ + ComponentSpec( + name="scheduler", + type_hint=KarrasDiffusionSchedulers, + default_class_name=["diffusers", "EulerDiscreteScheduler"], + ) + ] model_name = "stable-diffusion-xl" @property @@ -874,7 +987,13 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): - expected_components = ["scheduler"] + expected_components = [ + ComponentSpec( + name="scheduler", + type_hint=KarrasDiffusionSchedulers, + default_class_name=["diffusers", "EulerDiscreteScheduler"], + ) + ] model_name = "stable-diffusion-xl" @property @@ -1024,7 +1143,18 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): - expected_components = ["vae", "scheduler"] + expected_components = [ + ComponentSpec( + name="vae", + type_hint=AutoencoderKL, + default_class_name=["diffusers", "AutoencoderKL"], + ), + ComponentSpec( + name="scheduler", + type_hint=KarrasDiffusionSchedulers, + default_class_name=["diffusers", "EulerDiscreteScheduler"], + ) + ] model_name = "stable-diffusion-xl" @property @@ -1101,7 +1231,13 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLPrepareLatentsStep(PipelineBlock): - expected_components = ["scheduler"] + expected_components = [ + ComponentSpec( + name="scheduler", + type_hint=KarrasDiffusionSchedulers, + default_class_name=["diffusers", "EulerDiscreteScheduler"], + ) + ] model_name = "stable-diffusion-xl" @property @@ -1219,7 +1355,13 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): - expected_configs = ["requires_aesthetics_score"] + expected_configs = [ + ConfigSpec( + name="requires_aesthetics_score", + default=False, + type_hint=bool, + ) + ] model_name = "stable-diffusion-xl" @property @@ -1516,7 +1658,23 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLDenoiseStep(PipelineBlock): - expected_components = ["unet", "scheduler", "guider"] + expected_components = [ + ComponentSpec( + name="unet", + type_hint=UNet2DConditionModel, + default_class_name=["diffusers", "UNet2DConditionModel"], + ), + ComponentSpec( + name="scheduler", + type_hint=KarrasDiffusionSchedulers, + default_class_name=["diffusers", "EulerDiscreteScheduler"], + ), + ComponentSpec( + name="guider", + type_hint=Guiders, + default_class_name=["diffusers", "CFGGuider"], + ) + ] model_name = "stable-diffusion-xl" @property @@ -1803,7 +1961,39 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): - expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"] + expected_components = [ + ComponentSpec( + name="unet", + type_hint=UNet2DConditionModel, + default_class_name=["diffusers", "UNet2DConditionModel"], + ), + ComponentSpec( + name="controlnet", + type_hint=ControlNetModel, + default_class_name=["diffusers", "ControlNetModel"], + ), + ComponentSpec( + name="scheduler", + type_hint=KarrasDiffusionSchedulers, + default_class_name=["diffusers", "EulerDiscreteScheduler"], + ), + ComponentSpec( + name="guider", + type_hint=Guiders, + default_class_name=["diffusers", "CFGGuider"], + ), + ComponentSpec( + name="controlnet_guider", + type_hint=Guiders, + default_class_name=["diffusers", "CFGGuider"], + ), + + ComponentSpec( + name="control_image_processor", + type_hint=VaeImageProcessor, + default=VaeImageProcessor(do_convert_rgb=True, do_normalize=False) + ) + ] model_name = "stable-diffusion-xl" @property @@ -2257,7 +2447,38 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): - expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"] + expected_components = [ + ComponentSpec( + name="unet", + type_hint=UNet2DConditionModel, + default_class_name=["diffusers", "UNet2DConditionModel"], + ), + ComponentSpec( + name="controlnet", + type_hint=ControlNetUnionModel, + default_class_name=["diffusers", "ControlNetUnionModel"], + ), + ComponentSpec( + name="scheduler", + type_hint=KarrasDiffusionSchedulers, + default_class_name=["diffusers", "EulerDiscreteScheduler"], + ), + ComponentSpec( + name="guider", + type_hint=Guiders, + default_class_name=["diffusers", "CFGGuider"], + ), + ComponentSpec( + name="controlnet_guider", + type_hint=Guiders, + default_class_name=["diffusers", "CFGGuider"], + ), + ComponentSpec( + name="control_image_processor", + type_hint=VaeImageProcessor, + default=VaeImageProcessor(do_convert_rgb=True, do_normalize=False) + ) + ] model_name = "stable-diffusion-xl" @property @@ -2700,7 +2921,18 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLDecodeLatentsStep(PipelineBlock): - expected_components = ["vae"] + expected_components = [ + ComponentSpec( + name="vae", + type_hint=AutoencoderKL, + default_class_name=["diffusers", "AutoencoderKL"], + ), + ComponentSpec( + name="image_processor", + type_hint=VaeImageProcessor, + default=VaeImageProcessor() + ) + ] model_name = "stable-diffusion-xl" @property @@ -2899,6 +3131,7 @@ def description(self): " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + \ " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" + class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep] block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] @@ -2912,6 +3145,7 @@ def description(self): " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + \ " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" + class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep] block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] @@ -2955,7 +3189,6 @@ def description(self): " - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided." # After denoise - class StableDiffusionXLDecodeStep(SequentialPipelineBlocks): block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep] block_names = ["decode", "output"] @@ -2981,7 +3214,6 @@ def description(self): " - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." - class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] block_names = ["inpaint", "non-inpaint"] @@ -2994,6 +3226,7 @@ def description(self): " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \ " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." + class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLIPAdapterStep] block_names = ["ip_adapter"] @@ -3003,6 +3236,7 @@ class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks): def description(self): return "Run IP Adapter step if `ip_adapter_image` is provided." + class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decode"] @@ -3907,3 +4141,825 @@ def get_guidance_scale_embedding( emb = torch.nn.functional.pad(emb, (0, 1)) assert emb.shape == (w.shape[0], embedding_dim) return emb + + ) + if data.do_classifier_free_guidance: + data.negative_ip_adapter_embeds = [] + for i, image_embeds in enumerate(data.ip_adapter_embeds): + negative_image_embeds, image_embeds = image_embeds.chunk(2) + data.negative_ip_adapter_embeds.append(negative_image_embeds) + data.ip_adapter_embeds[i] = image_embeds + + self.add_block_state(state, data) + return pipeline, state + + + @property + def default_sample_size(self): + default_sample_size = 128 + if hasattr(self, "unet") and self.unet is not None: + default_sample_size = self.unet.config.sample_size + return default_sample_size + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_unet(self): + num_channels_unet = 4 + if hasattr(self, "unet") and self.unet is not None: + num_channels_unet = self.unet.config.in_channels + return num_channels_unet + + @property + def num_channels_latents(self): + num_channels_latents = 4 + if hasattr(self, "vae") and self.vae is not None: + num_channels_latents = self.vae.config.latent_channels + return num_channels_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids_img2img( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. add crops_coords and resize_mode to preprocess() + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = self.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + else: + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + return image + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + # YiYi TODO: refactor using _encode_vae_image + def prepare_latents_img2img( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + torch.cuda.empty_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents + def prepare_latents_inpaint( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb From bf99ab2f55e03ac42fcb18fec4b4e476a28524dc Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 9 Apr 2025 20:36:45 +0200 Subject: [PATCH 02/54] up --- src/diffusers/pipelines/modular_pipeline.py | 63 +- .../pipeline_stable_diffusion_xl_modular.py | 2445 ++++------------- 2 files changed, 593 insertions(+), 1915 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 683b3940c723..4961d158e10d 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -138,13 +138,31 @@ def format_value(v): return f"BlockState(\n{attributes}\n)" +@dataclass +class ComponentSpec: + """Specification for a pipeline component.""" + name: str + type_hint: Type + description: Optional[str] = None + default: Any = None # you can create a default component if it is a stateless class like scheduler, guider or image processor + default_class_name: Union[str, List[str], Tuple[str, str]] = None # Either "class_name" or ["module", "class_name"] + default_repo: Optional[Union[str, List[str]]] = None # either "repo" or ["repo", "subfolder"] + +@dataclass +class ConfigSpec: + """Specification for a pipeline configuration parameter.""" + name: str + default: Any + description: Optional[str] = None + + @dataclass class InputParam: name: str + type_hint: Any = None default: Any = None required: bool = False description: str = "" - type_hint: Any = Any def __repr__(self): return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" @@ -152,8 +170,8 @@ def __repr__(self): @dataclass class OutputParam: name: str + type_hint: Any description: str = "" - type_hint: Any = Any def __repr__(self): return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" @@ -338,49 +356,40 @@ def make_doc_string(inputs, intermediates_inputs, outputs, description=""): return output -@dataclass -class ComponentSpec: - """Specification for a pipeline component.""" - name: str - type_hint: Optional[Type] = None - description: Optional[str] = None - default: Any = None # you can create a default component if it is a stateless class like scheduler, guider or image processor - default_class_name: Union[str, List[str], Tuple[str, str]] # Either "class_name" or ["module", "class_name"] - default_repo: Optional[Union[str, List[str]]] = None # either "repo" or ["repo", "subfolder"] - -@dataclass -class ConfigSpec: - """Specification for a pipeline configuration parameter.""" - name: str - default: Any - description: Optional[str] = None - type_hint: Optional[Type] = None class PipelineBlock: - - component_specs: List[ComponentSpec] = [] - config_specs: List[ConfigSpec] = [] + model_name = None @property def description(self) -> str: """Description of the block. Must be implemented by subclasses.""" raise NotImplementedError("description method must be implemented in subclasses") + + @property + def components(self) -> List[ComponentSpec]: + return [] + @property + def configs(self) -> List[ConfigSpec]: + return [] + + + # YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable @property def inputs(self) -> List[InputParam]: """List of input parameters. Must be implemented by subclasses.""" - raise NotImplementedError("inputs method must be implemented in subclasses") + return [] @property def intermediates_inputs(self) -> List[InputParam]: """List of intermediate input parameters. Must be implemented by subclasses.""" - raise NotImplementedError("intermediates_inputs method must be implemented in subclasses") + return [] @property def intermediates_outputs(self) -> List[OutputParam]: """List of intermediate output parameters. Must be implemented by subclasses.""" - raise NotImplementedError("intermediates_outputs method must be implemented in subclasses") + return [] # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks @property @@ -403,10 +412,6 @@ def required_intermediates_inputs(self) -> List[str]: input_names.append(input_param.name) return input_names - def __init__(self): - self.components: Dict[str, Any] = {} - self.auxiliaries: Dict[str, Any] = {} - self.configs: Dict[str, Any] = {} def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise NotImplementedError("__call__ method must be implemented in subclasses") diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 9c7ad5ec19c0..5e2b8ae779d7 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -141,23 +141,7 @@ def retrieve_latents( class StableDiffusionXLLoraStep(PipelineBlock): - expected_components = [ - ComponentSpec( - name="text_encoder", - type_hint=CLIPTextModel, - default_class_name=["transformers", "CLIPTextModel"], - ), - ComponentSpec( - name="text_encoder_2", - type_hint=CLIPTextModelWithProjection, - default_class_name=["transformers", "CLIPTextModelWithProjection"], - ), - ComponentSpec( - name="unet", - type_hint=UNet2DConditionModel, - default_class_name=["diffusers", "UNet2DConditionModel"], - ) - ] + model_name = "stable-diffusion-xl" @property @@ -167,49 +151,22 @@ def description(self) -> str: " See [StableDiffusionXLLoraLoaderMixin](https://huggingface.co/docs/diffusers/api/loaders/lora#diffusers.loaders.StableDiffusionXLLoraLoaderMixin)" " for more details" ) - - - @property - def inputs(self) -> List[InputParam]: - return [] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [] @property - def intermediates_outputs(self) -> List[OutputParam]: - return [] + def components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", CLIPTextModel), + ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), + ComponentSpec("unet", UNet2DConditionModel), + ] - def __init__(self): - super().__init__() - self.components["text_encoder"] = None - self.components["text_encoder_2"] = None - self.components["unet"] = None @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise EnvironmentError("StableDiffusionXLLoraStep is desgined to be used to load lora weights, __call__ is not implemented") -class StableDiffusionXLIPAdapterStep(PipelineBlock): - expected_components = [ - ComponentSpec( - name="image_encoder", - type_hint=CLIPVisionModelWithProjection, - default_class_name=["transformers", "CLIPVisionModelWithProjection"], - ), - ComponentSpec( - name="feature_extractor", - type_hint=CLIPImageProcessor, - default_class_name=["transformers", "CLIPImageProcessor"], - ), - ComponentSpec( - name="unet", - type_hint=UNet2DConditionModel, - default_class_name=["diffusers", "UNet2DConditionModel"], - ) - ] +class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin): model_name = "stable-diffusion-xl" @@ -220,26 +177,30 @@ def description(self) -> str: " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" " for more details" ) + + @property + def components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("image_encoder", CLIPVisionModelWithProjection), + ComponentSpec("feature_extractor", CLIPImageProcessor), + ComponentSpec("unet", UNet2DConditionModel), + ] @property def inputs(self) -> List[InputParam]: return [ InputParam( "ip_adapter_image", - required=True, - type_hint=PipelineImageInput, + PipelineImageInput, + required=True, description="The image(s) to be used as ip adapter" ), InputParam( "guidance_scale", default=5.0, - description="Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). Guidance scale is enabled by setting `guidance_scale > 1`." ), ] - @property - def intermediates_inputs(self) -> List[InputParam]: - return [] @property def intermediates_outputs(self) -> List[OutputParam]: @@ -248,12 +209,6 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") ] - def __init__(self): - super().__init__() - self.components["image_encoder"] = None - self.components["feature_extractor"] = None - self.components["unet"] = None - @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -281,36 +236,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLTextEncoderStep(PipelineBlock): - expected_components = [ - ComponentSpec( - name="text_encoder", - type_hint=CLIPTextModel, - default_class_name=["transformers", "CLIPTextModel"], - ), - ComponentSpec( - name="text_encoder_2", - type_hint=CLIPTextModelWithProjection, - default_class_name=["transformers", "CLIPTextModelWithProjection"], - ), - ComponentSpec( - name="tokenizer", - type_hint=CLIPTokenizer, - default_class_name=["transformers", "CLIPTokenizer"], - ), - ComponentSpec( - name="tokenizer_2", - type_hint=CLIPTokenizer, - default_class_name=["transformers", "CLIPTokenizer"], - ) - ] - - expected_configs = [ - ConfigSpec( - name="force_zeros_for_empty_prompt", - default=True, - type_hint=bool - ) - ] model_name = "stable-diffusion-xl" @@ -319,52 +244,32 @@ def description(self) -> str: return( "Text Encoder step that generate text_embeddings to guide the image generation" ) + + @property + def components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", CLIPTextModel), + ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), + ComponentSpec("tokenizer", CLIPTokenizer), + ComponentSpec("tokenizer_2", CLIPTokenizer), + ] + @property + def configs(self) -> List[ConfigSpec]: + return [ConfigSpec("force_zeros_for_empty_prompt", True)] @property def inputs(self) -> List[InputParam]: return [ - InputParam( - name="prompt", - type_hint=Union[str, List[str]], - description="The prompt or prompts to guide the image generation.", - ), - InputParam( - name="prompt_2", - type_hint=Union[str, List[str]], - description="The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders", - ), - InputParam( - name="negative_prompt", - type_hint=Union[str, List[str]], - description="The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).", - ), - InputParam( - name="negative_prompt_2", - type_hint=Union[str, List[str]], - description="The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders", - ), - InputParam( - name="cross_attention_kwargs", - type_hint=Optional[dict], - description="A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor]", - ), - InputParam( - name="guidance_scale", - type_hint=float, - default=5.0, - description="Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.", - ), - InputParam( - name="clip_skip", - type_hint=Optional[int], - ), + InputParam("prompt"), + InputParam("prompt_2"), + InputParam("negative_prompt"), + InputParam("negative_prompt_2"), + InputParam("cross_attention_kwargs"), + InputParam("guidance_scale",default=5.0), + InputParam("clip_skip"), ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [] @property def intermediates_outputs(self) -> List[OutputParam]: @@ -375,14 +280,6 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="negative pooled text embeddings used to guide the image generation"), ] - def __init__(self): - super().__init__() - self.configs["force_zeros_for_empty_prompt"] = True - self.components["text_encoder"] = None - self.components["text_encoder_2"] = None - self.components["tokenizer"] = None - self.components["tokenizer_2"] = None - def check_inputs(self, pipeline, data): if data.prompt is not None and (not isinstance(data.prompt, str) and not isinstance(data.prompt, list)): @@ -431,18 +328,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLVaeEncoderStep(PipelineBlock): - expected_components = [ - ComponentSpec( - name="vae", - type_hint=AutoencoderKL, - default_class_name=["diffusers", "AutoencoderKL"], - ), - ComponentSpec( - name="image_processor", - type_hint=VaeImageProcessor, - default=VaeImageProcessor() - ) - ] + model_name = "stable-diffusion-xl" @@ -452,37 +338,20 @@ def description(self) -> str: "Vae Encoder step that encode the input image into a latent representation" ) + @property + def components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("image_processor", VaeImageProcessor, default=VaeImageProcessor()), + ] + @property def inputs(self) -> List[InputParam]: return [ - InputParam( - name="image", - type_hint=PipelineImageInput, - required=True, - description="The image(s) to modify with the pipeline, for img2img or inpainting task. When using for inpainting task, parts of the image will be masked out with `mask_image` and repainted according to `prompt`." - ), - InputParam( - name="generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)" - "to make generation deterministic." - ), - InputParam( - name="height", - type_hint=Optional[int], - description="The height in pixels of the generated image. This is set to 1024 by default for the best results. " - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions.", - ), - InputParam( - name="width", - type_hint=Optional[int], - description="The width in pixels of the generated image. This is set to 1024 by default for the best results. " - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions.", - ), + InputParam("image", required=True), + InputParam("generator"), + InputParam("height"), + InputParam("width"), ] @property @@ -495,10 +364,6 @@ def intermediates_inputs(self) -> List[InputParam]: def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")] - def __init__(self): - super().__init__() - self.components["vae"] = None - self.auxiliaries["image_processor"] = VaeImageProcessor() @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -528,24 +393,16 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): - expected_components = [ - ComponentSpec( - name="vae", - type_hint=AutoencoderKL, - default_class_name=["diffusers", "AutoencoderKL"], - ), - ComponentSpec( - name="image_processor", - type_hint=VaeImageProcessor, - default=VaeImageProcessor() - ), - ComponentSpec( - name="mask_processor", - type_hint=VaeImageProcessor, - default=VaeImageProcessor(do_normalize=False, do_binarize=True, do_convert_grayscale=True) - ) - ] model_name = "stable-diffusion-xl" + + @property + def components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("image_processor", VaeImageProcessor, default=VaeImageProcessor()), + ComponentSpec("mask_processor", VaeImageProcessor, default=VaeImageProcessor(do_normalize=False, do_binarize=True, do_convert_grayscale=True)), + ] + @property def description(self) -> str: @@ -556,53 +413,12 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam( - "height", - type_hint=Optional[int], - description="The height in pixels of the generated image. This is set to 1024 by default for the best results. " - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions.", - ), - InputParam( - "width", - type_hint=Optional[int], - description="The width in pixels of the generated image. This is set to 1024 by default for the best results. " - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions.", - ), - InputParam( - "generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) " - "to make generation deterministic." - ), - InputParam( - "image", - required=True, - type_hint=PipelineImageInput, - description="The image(s) to modify with the pipeline, for img2img or inpainting task. When using for inpainting task, parts of the image will be masked out with `mask_image` and repainted according to `prompt`." - ), - InputParam( - "mask_image", - required=True, - type_hint=PipelineImageInput, - description="`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be " - "repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted " - "to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) " - "instead of 3, so the expected shape would be `(B, H, W, 1)`." - ), - InputParam( - "padding_mask_crop", - type_hint=Optional[Tuple[int, int]], - description="The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to " - "image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region " - "with the same aspect ratio of the image and contains all masked area, and then expand that area based " - "on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before " - "resizing to the original image size for inpainting. This is useful when the masked area is small while " - "the image is large and contain information irrelevant for inpainting, such as background." - ), + InputParam("height"), + InputParam("width"), + InputParam("generator"), + InputParam("image", required=True), + InputParam("mask_image", required=True), + InputParam("padding_mask_crop"), ] @property @@ -615,12 +431,7 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] - - def __init__(self): - super().__init__() - self.auxiliaries["image_processor"] = VaeImageProcessor() - self.auxiliaries["mask_processor"] = VaeImageProcessor(do_normalize=False, do_binarize=True, do_convert_grayscale=True) - self.components["vae"] = None + @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: @@ -682,12 +493,7 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam( - name="num_images_per_prompt", - type_hint=int, - default=1, - description="The number of images to generate per prompt.", - ), + InputParam("num_images_per_prompt", default=1), ] @property @@ -788,15 +594,14 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): - expected_components = [ - ComponentSpec( - name="scheduler", - type_hint=KarrasDiffusionSchedulers, - default_class_name=["diffusers", "EulerDiscreteScheduler"], - ) - ] model_name = "stable-diffusion-xl" + @property + def components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ] + @property def description(self) -> str: return ( @@ -807,50 +612,14 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam( - "num_inference_steps", - default=50, - type_hint=int, - description="The number of denoising steps. More denoising steps usually lead to a higher quality image at the" - " expense of slower inference." - ), - InputParam( - "timesteps", - type_hint=Optional[torch.Tensor], - description="Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order." - ), - InputParam( - "sigmas", - type_hint=Optional[torch.Tensor], - description="Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used." - ), - InputParam( - "denoising_end", - type_hint=Optional[float], - description="When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a 'Mixture of Denoisers' multi-pipeline setup." - ), - InputParam( - "strength", - default=0.3, - type_hint=float, - description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " - "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " - "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " - "be maximum and the denoising process will run for the full number of iterations specified in " - "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " - "`denoising_start` being declared as an integer, the value of `strength` will be ignored." - ), - InputParam( - "denoising_start", - type_hint=Optional[float], - description="The denoising start value to use for the scheduler. Determines the starting point of the denoising process." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt. Defaults to 1." - ), + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), + InputParam("strength", default=0.3), + InputParam("denoising_start"), + # YiYi TODO: do we need num_images_per_prompt here? + InputParam("num_images_per_prompt", default=1), ] @property @@ -867,9 +636,6 @@ def intermediates_outputs(self) -> List[str]: OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") ] - def __init__(self): - super().__init__() - self.components["scheduler"] = None @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -908,14 +674,14 @@ def denoising_value_valid(dnv): class StableDiffusionXLSetTimestepsStep(PipelineBlock): - expected_components = [ - ComponentSpec( - name="scheduler", - type_hint=KarrasDiffusionSchedulers, - default_class_name=["diffusers", "EulerDiscreteScheduler"], - ) - ] + model_name = "stable-diffusion-xl" + + @property + def components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ] @property def description(self) -> str: @@ -926,27 +692,10 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam( - "num_inference_steps", - default=50, - type_hint=int, - description="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference." - ), - InputParam( - "timesteps", - type_hint=Optional[torch.Tensor], - description="Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order." - ), - InputParam( - "sigmas", - type_hint=Optional[torch.Tensor], - description="Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used." - ), - InputParam( - "denoising_end", - type_hint=Optional[float], - description="When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a 'Mixture of Denoisers' multi-pipeline setup." - ), + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), ] @property @@ -954,13 +703,6 @@ def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")] - @property - def intermediates_inputs(self) -> List[InputParam]: - return [] - - def __init__(self): - super().__init__() - self.components["scheduler"] = None @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -987,15 +729,15 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): - expected_components = [ - ComponentSpec( - name="scheduler", - type_hint=KarrasDiffusionSchedulers, - default_class_name=["diffusers", "EulerDiscreteScheduler"], - ) - ] + model_name = "stable-diffusion-xl" + @property + def components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ] + @property def description(self) -> str: return ( @@ -1005,31 +747,13 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam( - "generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) " - "to make generation deterministic."), - InputParam( - "latents", - type_hint=Optional[torch.Tensor], - description="Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt" - ), - InputParam( - "denoising_start", - type_hint=Optional[float], - description="When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. The initial part of the denoising process is skipped and it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, strength will be ignored. Useful for 'Mixture of Denoisers' multi-pipeline setups." - ), + InputParam("generator"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), InputParam( "strength", - default=0.9999, - type_hint=float, + default=0.9999, description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " @@ -1085,9 +809,6 @@ def intermediates_outputs(self) -> List[str]: OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] - def __init__(self): - super().__init__() - self.components["scheduler"] = None @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: @@ -1143,20 +864,15 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): - expected_components = [ - ComponentSpec( - name="vae", - type_hint=AutoencoderKL, - default_class_name=["diffusers", "AutoencoderKL"], - ), - ComponentSpec( - name="scheduler", - type_hint=KarrasDiffusionSchedulers, - default_class_name=["diffusers", "EulerDiscreteScheduler"], - ) - ] model_name = "stable-diffusion-xl" + @property + def components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ] + @property def description(self) -> str: return ( @@ -1166,28 +882,10 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam( - "generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) " - "to make generation deterministic." - ), - InputParam( - "latents", - type_hint=Optional[torch.Tensor], - description="Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt" - ), - InputParam( - "denoising_start", - type_hint=Optional[float], - description="When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. The initial part of the denoising process is skipped and it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, strength will be ignored. Useful for 'Mixture of Denoisers' multi-pipeline setups." - ), + InputParam("generator"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), ] @property @@ -1202,9 +900,6 @@ def intermediates_inputs(self) -> List[InputParam]: def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] - def __init__(self): - super().__init__() - self.components["scheduler"] = None @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: @@ -1231,15 +926,14 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLPrepareLatentsStep(PipelineBlock): - expected_components = [ - ComponentSpec( - name="scheduler", - type_hint=KarrasDiffusionSchedulers, - default_class_name=["diffusers", "EulerDiscreteScheduler"], - ) - ] model_name = "stable-diffusion-xl" + @property + def components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ] + @property def description(self) -> str: return ( @@ -1249,37 +943,11 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam( - "height", - type_hint=Optional[int], - description="The height in pixels of the generated image. This is set to 1024 by default for the best results. " - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions."), - InputParam( - "width", - type_hint=Optional[int], - description="The width in pixels of the generated image. This is set to 1024 by default for the best results. " - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions."), - InputParam( - "generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) " - "to make generation deterministic." - ), - InputParam( - "latents", - type_hint=Optional[torch.Tensor], - description="Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt" - ), + InputParam("height"), + InputParam("width"), + InputParam("generator"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), ] @property @@ -1308,9 +976,6 @@ def intermediates_outputs(self) -> List[OutputParam]: ) ] - def __init__(self): - super().__init__() - self.components["scheduler"] = None @staticmethod def check_inputs(pipeline, data): @@ -1355,14 +1020,12 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): - expected_configs = [ - ConfigSpec( - name="requires_aesthetics_score", - default=False, - type_hint=bool, - ) - ] + model_name = "stable-diffusion-xl" + + @property + def configs(self) -> List[ConfigSpec]: + return [ConfigSpec("requires_aesthetics_score", default=False),] @property def description(self) -> str: @@ -1373,75 +1036,16 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam( - "original_size", - type_hint=Optional[Tuple[int]], - description="If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. " - "`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as " - "explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "target_size", - type_hint=Optional[Tuple[int]], - description="For most cases, `target_size` should be set to the desired height and width of the generated image. If " - "not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in " - "section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "negative_original_size", - type_hint=Optional[Tuple[int]], - description="To negatively condition the generation process based on a specific image resolution. Part of SDXL's " - "micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "negative_target_size", - type_hint=Optional[Tuple[int]], - description="To negatively condition the generation process based on a target image resolution. It should be as same " - "as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of " - "https://huggingface.co/papers/2307.01952" - ), - InputParam( - "crops_coords_top_left", - default=(0, 0), - type_hint=Tuple[int], - description="`crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position " - "`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning" - ), - InputParam( - "negative_crops_coords_top_left", - default=(0, 0), - type_hint=Tuple[int], - description="To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's " - "micro-conditioning" - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt." - ), - InputParam( - "guidance_scale", - default=5.0, - type_hint=float, - description="Guidance scale as defined in Classifier-Free Diffusion Guidance. `guidance_scale` is defined as `w` of equation 2. " - "Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, " - "usually at the expense of lower image quality." - ), - InputParam( - "aesthetic_score", - default=6.0, - type_hint=float, - description="Used to simulate an aesthetic score of the generated image by influencing the positive text condition. " - "Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "negative_aesthetic_score", - default=2.0, - type_hint=float, - description="Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. " - "Can be used to simulate an aesthetic score of the generated image by influencing the negative text condition." - ), + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + InputParam("guidance_scale", required=True), + InputParam("aesthetic_score", default=6.0), + InputParam("negative_aesthetic_score", default=2.0), ] @property @@ -1458,27 +1062,75 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] - def __init__(self): - super().__init__() - self.configs["requires_aesthetics_score"] = False - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.device = pipeline._execution_device - - data.vae_scale_factor = pipeline.vae_scale_factor - - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * data.vae_scale_factor - data.width = data.width * data.vae_scale_factor - - data.original_size = data.original_size or (data.height, data.width) - data.target_size = data.target_size or (data.height, data.width) + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components + def _get_add_time_ids_img2img( + components, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if components.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) - data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1]) + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features - if data.negative_original_size is None: + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + @torch.no_grad() + def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + data = self.get_block_state(state) + data.device = pipeline._execution_device + + data.vae_scale_factor = pipeline.vae_scale_factor + + data.height, data.width = data.latents.shape[-2:] + data.height = data.height * data.vae_scale_factor + data.width = data.width * data.vae_scale_factor + + data.original_size = data.original_size or (data.height, data.width) + data.target_size = data.target_size or (data.height, data.width) + + data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1]) + + if data.negative_original_size is None: data.negative_original_size = data.original_size if data.negative_target_size is None: data.negative_target_size = data.target_size @@ -1526,55 +1178,14 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam( - "original_size", - type_hint=Tuple[int, int], - default=(1024, 1024), - description="The original size (height, width) of the image that conditions the generation process. If different from target_size, the image will appear to be down- or upsampled. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "target_size", - type_hint=Tuple[int, int], - default=(1024, 1024), - description="The target size (height, width) of the generated image. For most cases, this should be set to the desired output dimensions. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "negative_original_size", - type_hint=Tuple[int, int], - default=(1024, 1024), - description="The negative original size to condition against during generation. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. See: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "negative_target_size", - type_hint=Tuple[int, int], - default=(1024, 1024), - description="The negative target size to condition against during generation. Should typically match target_size. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. See: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "crops_coords_top_left", - default=(0, 0), - type_hint=Tuple[int, int], - description="The top-left coordinates (x, y) used to condition the generation process. Setting this to (0, 0) typically produces well-centered images. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "negative_crops_coords_top_left", - default=(0, 0), - type_hint=Tuple[int, int], - description="The top-left coordinates (x, y) used to negatively condition the generation process. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. For more information, see: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt" - ), - InputParam( - "guidance_scale", - default=5.0, - type_hint=float, - description="Guidance scale as defined in Classifier-Free Diffusion Guidance. `guidance_scale` is defined as `w` of equation 2. " - "Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, " - "usually at the expense of lower image quality."), + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + InputParam("guidance_scale", default=5.0), ] @property @@ -1606,6 +1217,58 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components + def _get_add_time_ids_img2img( + components, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if components.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) @@ -1658,25 +1321,17 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLDenoiseStep(PipelineBlock): - expected_components = [ - ComponentSpec( - name="unet", - type_hint=UNet2DConditionModel, - default_class_name=["diffusers", "UNet2DConditionModel"], - ), - ComponentSpec( - name="scheduler", - type_hint=KarrasDiffusionSchedulers, - default_class_name=["diffusers", "EulerDiscreteScheduler"], - ), - ComponentSpec( - name="guider", - type_hint=Guiders, - default_class_name=["diffusers", "CFGGuider"], - ) - ] + model_name = "stable-diffusion-xl" + @property + def components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("guider", CFGGuider), + ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("unet", UNet2DConditionModel), + ] + @property def description(self) -> str: return ( @@ -1686,47 +1341,13 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam( - "guidance_scale", - type_hint=float, - default=5.0, - description="Guidance scale as defined in Classifier-Free Diffusion Guidance. Higher values encourage images closely linked to the text prompt, potentially at the expense of image quality. Enabled when > 1." - ), - InputParam( - "guidance_rescale", - type_hint=float, - default=0.0, - description="Guidance rescale factor (φ) to fix overexposure when using zero terminal SNR, as proposed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed'." - ), - InputParam( - "cross_attention_kwargs", - type_hint=Optional[Dict[str, Any]], - default=None, - description="Optional kwargs dictionary passed to the AttentionProcessor." - ), - InputParam( - "generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of torch generator(s) to make generation deterministic." - ), - InputParam( - "eta", - type_hint=float, - default=0.0, - description="Parameter η in the DDIM paper. Only applies to DDIMScheduler, ignored for others." - ), - InputParam( - "guider_kwargs", - type_hint=Optional[Dict[str, Any]], - default=None, - description="Optional kwargs dictionary passed to the Guider." - ), - InputParam( - "num_images_per_prompt", - type_hint=int, - default=1, - description="The number of images to generate per prompt." - ), + InputParam("guidance_scale", default=5.0), + InputParam("guidance_rescale", default=0.0), + InputParam("cross_attention_kwargs"), + InputParam("generator"), + InputParam("eta", default=0.0), + InputParam("guider_kwargs"), + InputParam("num_images_per_prompt", default=1), ] @property @@ -1830,11 +1451,6 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - def __init__(self): - super().__init__() - self.components["guider"] = CFGGuider() - self.components["scheduler"] = None - self.components["unet"] = None def check_inputs(self, pipeline, data): @@ -1961,41 +1577,20 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): - expected_components = [ - ComponentSpec( - name="unet", - type_hint=UNet2DConditionModel, - default_class_name=["diffusers", "UNet2DConditionModel"], - ), - ComponentSpec( - name="controlnet", - type_hint=ControlNetModel, - default_class_name=["diffusers", "ControlNetModel"], - ), - ComponentSpec( - name="scheduler", - type_hint=KarrasDiffusionSchedulers, - default_class_name=["diffusers", "EulerDiscreteScheduler"], - ), - ComponentSpec( - name="guider", - type_hint=Guiders, - default_class_name=["diffusers", "CFGGuider"], - ), - ComponentSpec( - name="controlnet_guider", - type_hint=Guiders, - default_class_name=["diffusers", "CFGGuider"], - ), - - ComponentSpec( - name="control_image_processor", - type_hint=VaeImageProcessor, - default=VaeImageProcessor(do_convert_rgb=True, do_normalize=False) - ) - ] + model_name = "stable-diffusion-xl" + @property + def components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("guider", CFGGuider), + ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("controlnet", ControlNetModel), + ComponentSpec("control_image_processor", VaeImageProcessor, default=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), + ComponentSpec("controlnet_guider", CFGGuider), + ] + @property def description(self) -> str: return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" @@ -2003,78 +1598,18 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam( - "control_image", - required=True, - type_hint=PipelineImageInput, - description="The ControlNet input condition to provide guidance to the unet for generation. If passed as torch.Tensor, it is used as-is. PIL.Image.Image inputs are accepted and default to image dimensions. For multiple ControlNets, pass images as a list for proper batching." - ), - InputParam( - "control_guidance_start", - default=0.0, - type_hint=Union[float, List[float]], - description="The percentage of total steps at which the ControlNet starts applying." - ), - InputParam( - "control_guidance_end", - default=1.0, - type_hint=Union[float, List[float]], - description="The percentage of total steps at which the ControlNet stops applying." - ), - InputParam( - "controlnet_conditioning_scale", - default=1.0, - type_hint=Union[float, List[float]], - description="Scale factor for ControlNet outputs before adding to unet residual. For multiple ControlNets, can be set as a list of scales." - ), - InputParam( - "guess_mode", - default=False, - type_hint=bool, - description="Enables ControlNet encoder to recognize input image content without prompts. Recommended guidance_scale: 3.0-5.0." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt." - ), - InputParam( - "guidance_scale", - default=5.0, - type_hint=float, - description="Guidance scale as defined in Classifier-Free Diffusion Guidance. Higher values encourage images closely linked to the text prompt, potentially at the expense of image quality. Enabled when > 1." - ), - InputParam( - "guidance_rescale", - default=0.0, - type_hint=float, - description="Guidance rescale factor (φ) to fix overexposure when using zero terminal SNR, as proposed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed'." - ), - InputParam( - "cross_attention_kwargs", - default=None, - type_hint=Optional[Dict[str, Any]], - description="Optional kwargs dictionary passed to the AttentionProcessor." - ), - InputParam( - "generator", - default=None, - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of torch generator(s) to make generation deterministic." - ), - InputParam( - "eta", - default=0.0, - type_hint=float, - description="Parameter η in the DDIM paper. Only applies to DDIMScheduler, ignored for others." - ), - InputParam( - "guider_kwargs", - default=None, - type_hint=Optional[Dict[str, Any]], - description="Optional kwargs dictionary passed to the Guider." - ), + InputParam("control_image", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + InputParam("guidance_scale", default=5.0), + InputParam("guidance_rescale", default=0.0), + InputParam("cross_attention_kwargs"), + InputParam("generator"), + InputParam("eta", default=0.0), + InputParam("guider_kwargs"), ] @property @@ -2183,16 +1718,6 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - def __init__(self): - super().__init__() - self.components["guider"] = CFGGuider() - self.components["controlnet_guider"] = CFGGuider() - self.components["scheduler"] = None - self.components["unet"] = None - self.components["controlnet"] = None - control_image_processor = VaeImageProcessor(do_convert_rgb=True, do_normalize=False) - self.auxiliaries["control_image_processor"] = control_image_processor - def check_inputs(self, pipeline, data): num_channels_unet = pipeline.unet.config.in_channels @@ -2447,39 +1972,18 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): - expected_components = [ - ComponentSpec( - name="unet", - type_hint=UNet2DConditionModel, - default_class_name=["diffusers", "UNet2DConditionModel"], - ), - ComponentSpec( - name="controlnet", - type_hint=ControlNetUnionModel, - default_class_name=["diffusers", "ControlNetUnionModel"], - ), - ComponentSpec( - name="scheduler", - type_hint=KarrasDiffusionSchedulers, - default_class_name=["diffusers", "EulerDiscreteScheduler"], - ), - ComponentSpec( - name="guider", - type_hint=Guiders, - default_class_name=["diffusers", "CFGGuider"], - ), - ComponentSpec( - name="controlnet_guider", - type_hint=Guiders, - default_class_name=["diffusers", "CFGGuider"], - ), - ComponentSpec( - name="control_image_processor", - type_hint=VaeImageProcessor, - default=VaeImageProcessor(do_convert_rgb=True, do_normalize=False) - ) - ] model_name = "stable-diffusion-xl" + + @property + def components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("controlnet", ControlNetUnionModel), + ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("guider", CFGGuider), + ComponentSpec("controlnet_guider", CFGGuider), + ComponentSpec("control_image_processor", VaeImageProcessor, default=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), + ] @property def description(self) -> str: @@ -2487,75 +1991,19 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam( - "control_image", - required=True, - type_hint=PipelineImageInput, - description="The ControlNet input condition to provide guidance to the unet for generation. If passed as torch.Tensor, it is used as-is. PIL.Image.Image inputs are accepted and default to image dimensions. For multiple ControlNets, pass images as a list for proper batching."), - InputParam( - "control_guidance_start", - default=0.0, - type_hint=Union[float, List[float]], - description="The percentage of total steps at which the ControlNet starts applying."), - InputParam( - "control_guidance_end", - default=1.0, - type_hint=Union[float, List[float]], - description="The percentage of total steps at which the ControlNet stops applying."), - InputParam( - "control_mode", - required=True, - type_hint=List[int], - description="The control mode for union controlnet, 0 for openpose, 1 for depth, 2 for hed/pidi/scribble/ted, 3 for canny/lineart/anime_lineart/mlsd, 4 for normal and 5 for segment" - ), - InputParam( - "controlnet_conditioning_scale", - default=1.0, - type_hint=Union[float, List[float]], - description="Scale factor for ControlNet outputs before adding to unet residual. For multiple ControlNets, can be set as a list of scales." - ), - InputParam( - "guess_mode", - default=False, - type_hint=bool, - description="Enables ControlNet encoder to recognize input image content without prompts. Recommended guidance_scale: 3.0-5.0." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt." - ), - InputParam( - "guidance_scale", - default=5.0, - type_hint=float, - description="Guidance scale as defined in Classifier-Free Diffusion Guidance. Higher values encourage images closely linked to the text prompt, potentially at the expense of image quality. Enabled when > 1."), - InputParam( - "guidance_rescale", - default=0.0, - type_hint=float, - description="Guidance rescale factor (φ) to fix overexposure when using zero terminal SNR, as proposed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed'."), - InputParam( - "cross_attention_kwargs", - default=None, - type_hint=Optional[Dict[str, Any]], - description="Optional kwargs dictionary passed to the AttentionProcessor."), - InputParam( - "generator", - default=None, - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of torch generator(s) to make generation deterministic."), - InputParam( - "eta", - default=0.0, - type_hint=float, - description="Parameter η in the DDIM paper. Only applies to DDIMScheduler, ignored for others."), - InputParam( - "guider_kwargs", - default=None, - type_hint=Optional[Dict[str, Any]], - description="Optional kwargs dictionary passed to the Guider."), + InputParam("control_image", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("control_mode", required=True), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + InputParam("guidance_scale", default=5.0), + InputParam("guidance_rescale", default=0.0), + InputParam("cross_attention_kwargs"), + InputParam("generator"), + InputParam("eta", default=0.0), + InputParam("guider_kwargs") ] @property @@ -2664,16 +2112,6 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - def __init__(self): - super().__init__() - self.components["guider"] = CFGGuider() - self.components["controlnet_guider"] = CFGGuider() - self.components["scheduler"] = None - self.components["unet"] = None - self.components["controlnet"] = None - control_image_processor = VaeImageProcessor(do_convert_rgb=True, do_normalize=False) - self.auxiliaries["control_image_processor"] = control_image_processor - def check_inputs(self, pipeline, data): num_channels_unet = pipeline.unet.config.in_channels @@ -2921,19 +2359,15 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLDecodeLatentsStep(PipelineBlock): - expected_components = [ - ComponentSpec( - name="vae", - type_hint=AutoencoderKL, - default_class_name=["diffusers", "AutoencoderKL"], - ), - ComponentSpec( - name="image_processor", - type_hint=VaeImageProcessor, - default=VaeImageProcessor() - ) - ] + model_name = "stable-diffusion-xl" + + @property + def components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("image_processor", VaeImageProcessor, default=VaeImageProcessor()) + ] @property def description(self) -> str: @@ -2942,12 +2376,7 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam( - "output_type", - type_hint=str, - default="pil", - description="The output format of the generated image. Choose between PIL (PIL.Image.Image), torch.Tensor or np.array." - ), + InputParam("output_type", default="pil"), ] @property @@ -2958,11 +2387,6 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")] - def __init__(self): - super().__init__() - self.components["vae"] = None - self.auxiliaries["image_processor"] = VaeImageProcessor(vae_scale_factor=8) - @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) @@ -3028,24 +2452,9 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam( - "image", - type_hint=PipelineImageInput, - required=True, - description="The image(s) to modify with the pipeline, for img2img or inpainting task. When using for inpainting task, parts of the image will be masked out with `mask_image` and repainted according to `prompt`." - ), - InputParam( - "mask_image", - type_hint=PipelineImageInput, - required=True, - description="The mask image(s) to use for inpainting, white pixels in the mask will be repainted, while black pixels will be preserved. If mask_image is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be (B, H, W, 1). Must be a `PIL.Image.Image`" - ), - InputParam( - "padding_mask_crop", - type_hint=Optional[Tuple[int, int]], - default=None, - description="The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied. If set, it will find a rectangular region with the same aspect ratio as the image that contains all masked areas, then expand that area by this margin. The image and mask_image are cropped to this expanded area before resizing to the original size for inpainting. Useful when the masked area is small in a large image with irrelevant background information." - ), + InputParam("image", required=True), + InputParam("mask_image", required=True), + InputParam("padding_mask_crop"), ] @property @@ -3080,7 +2489,7 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: - return [(InputParam("return_dict", type_hint=bool, default=True, description="Whether or not to return a StableDiffusionXLPipelineOutput instead of a plain tuple."))] + return [InputParam("return_dict", default=True)] @property def intermediates_inputs(self) -> List[str]: @@ -3201,959 +2610,136 @@ def description(self): - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple.""" -class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep] - block_names = ["decode", "mask_overlay", "output"] - - @property - def description(self): - return "Inpaint decode step that decode the denoised latents into images outputs.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images\n" + \ - " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image\n" + \ - " - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." - - -class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] - block_names = ["inpaint", "non-inpaint"] - block_trigger_inputs = ["padding_mask_crop", None] - - @property - def description(self): - return "Decode step that decode the denoised latents into images outputs.\n" + \ - "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \ - " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \ - " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." - - -class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLIPAdapterStep] - block_names = ["ip_adapter"] - block_trigger_inputs = ["ip_adapter_image"] - - @property - def description(self): - return "Run IP Adapter step if `ip_adapter_image` is provided." - - -class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] - block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decode"] - - @property - def description(self): - return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \ - "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \ - "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + \ - "- to run the controlnet workflow, you need to provide `control_image`\n" + \ - "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \ - "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \ - "- for text-to-image generation, all you need to provide is `prompt`" - -# block mapping -TEXT2IMAGE_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLSetTimestepsStep), - ("prepare_latents", StableDiffusionXLPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLDecodeStep) -]) - -IMAGE2IMAGE_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLDecodeStep) -]) - -INPAINT_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLInpaintDecodeStep) -]) - -CONTROLNET_BLOCKS = OrderedDict([ - ("denoise", StableDiffusionXLControlNetDenoiseStep), -]) - -CONTROLNET_UNION_BLOCKS = OrderedDict([ - ("denoise", StableDiffusionXLControlNetUnionDenoiseStep), -]) - -IP_ADAPTER_BLOCKS = OrderedDict([ - ("ip_adapter", StableDiffusionXLIPAdapterStep), -]) - -AUTO_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), - ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), - ("decode", StableDiffusionXLAutoDecodeStep) -]) - -AUTO_CORE_BLOCKS = OrderedDict([ - ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), -]) - - -SDXL_SUPPORTED_BLOCKS = { - "text2img": TEXT2IMAGE_BLOCKS, - "img2img": IMAGE2IMAGE_BLOCKS, - "inpaint": INPAINT_BLOCKS, - "controlnet": CONTROLNET_BLOCKS, - "controlnet_union": CONTROLNET_UNION_BLOCKS, - "ip_adapter": IP_ADAPTER_BLOCKS, - "auto": AUTO_BLOCKS -} - - -class StableDiffusionXLModularPipeline( - ModularPipeline, - StableDiffusionMixin, - TextualInversionLoaderMixin, - StableDiffusionXLLoraLoaderMixin, - ModularIPAdapterMixin, -): - @property - def default_sample_size(self): - default_sample_size = 128 - if hasattr(self, "unet") and self.unet is not None: - default_sample_size = self.unet.config.sample_size - return default_sample_size - - @property - def vae_scale_factor(self): - vae_scale_factor = 8 - if hasattr(self, "vae") and self.vae is not None: - vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - return vae_scale_factor - - @property - def num_channels_unet(self): - num_channels_unet = 4 - if hasattr(self, "unet") and self.unet is not None: - num_channels_unet = self.unet.config.in_channels - return num_channels_unet - - @property - def num_channels_latents(self): - num_channels_latents = 4 - if hasattr(self, "vae") and self.vae is not None: - num_channels_latents = self.vae.config.latent_channels - return num_channels_latents - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids - def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None - ): - add_time_ids = list(original_size + crops_coords_top_left + target_size) - - passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features - - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." - ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids - def _get_add_time_ids_img2img( - self, - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype, - text_encoder_projection_dim=None, - ): - if self.config.requires_aesthetics_score: - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list( - negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) - ) - else: - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) - - passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features - - if ( - expected_add_embed_dim > passed_add_embed_dim - and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." - ) - elif ( - expected_add_embed_dim < passed_add_embed_dim - and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." - ) - elif expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." - ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - - return add_time_ids, add_neg_time_ids - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - - # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image - # 1. return image without apply any guidance - # 2. add crops_coords and resize_mode to preprocess() - def prepare_control_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - crops_coords=None, - ): - if crops_coords is not None: - image = self.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) - else: - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - return image - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt - def encode_prompt( - self, - prompt: str, - prompt_2: Optional[str] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, - negative_prompt: Optional[str] = None, - negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - pooled_prompt_embeds: Optional[torch.Tensor] = None, - negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in both text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizers and text encoders - tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] - text_encoders = ( - [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] - ) - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # textual inversion: process multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, tokenizer) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] - else: - # "2" because SDXL always indexes from the penultimate layer. - prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt - - # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 - ) - - uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] - - negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - negative_prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), - output_hidden_states=True, - ) - # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] - - negative_prompt_embeds_list.append(negative_prompt_embeds) - - negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - - if self.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) - else: - prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - if self.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) - else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - - if self.text_encoder is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds - def prepare_ip_adapter_image_embeds( - self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance - ): - image_embeds = [] - if do_classifier_free_guidance: - negative_image_embeds = [] - if ip_adapter_image_embeds is None: - if not isinstance(ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] - - if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): - raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." - ) - - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers - ): - output_hidden_state = not isinstance(image_proj_layer, ImageProjection) - single_image_embeds, single_negative_image_embeds = self.encode_image( - single_ip_adapter_image, device, 1, output_hidden_state - ) - - image_embeds.append(single_image_embeds[None, :]) - if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None, :]) - else: - for single_image_embeds in ip_adapter_image_embeds: - if do_classifier_free_guidance: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) - image_embeds.append(single_image_embeds) - - ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): - single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if do_classifier_free_guidance: - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - - single_image_embeds = single_image_embeds.to(device=device) - ip_adapter_image_embeds.append(single_image_embeds) - - return ip_adapter_image_embeds - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): - # get the original timestep using init_timestep - if denoising_start is None: - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - t_start = max(num_inference_steps - init_timestep, 0) - - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start * self.scheduler.order) - - return timesteps, num_inference_steps - t_start - - else: - # Strength is irrelevant if we directly request a timestep to start at; - # that is, strength is determined by the denoising_start instead. - discrete_timestep_cutoff = int( - round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) - ) - ) - - num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() - if self.scheduler.order == 2 and num_inference_steps % 2 == 0: - # if the scheduler is a 2nd order scheduler we might have to do +1 - # because `num_inference_steps` might be even given that every timestep - # (except the highest one) is duplicated. If `num_inference_steps` is even it would - # mean that we cut the timesteps in the middle of the denoising step - # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 - # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler - num_inference_steps = num_inference_steps + 1 - - # because t_n+1 >= t_n, we slice the timesteps starting from the end - t_start = len(self.scheduler.timesteps) - num_inference_steps - timesteps = self.scheduler.timesteps[t_start:] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start) - return timesteps, num_inference_steps - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = ( - batch_size, - num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents - # YiYi TODO: refactor using _encode_vae_image - def prepare_latents_img2img( - self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - # Offload text encoder if `enable_model_cpu_offload` was enabled - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.text_encoder_2.to("cpu") - torch.cuda.empty_cache() - - image = image.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - init_latents = image - - else: - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) - # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.config.force_upcast: - image = image.float() - self.vae.to(dtype=torch.float32) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - elif isinstance(generator, list): - if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: - image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) - elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " - ) - - init_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = retrieve_latents(self.vae.encode(image), generator=generator) - - if self.vae.config.force_upcast: - self.vae.to(dtype) - - init_latents = init_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=device, dtype=dtype) - latents_std = latents_std.to(device=device, dtype=dtype) - init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std - else: - init_latents = self.vae.config.scaling_factor * init_latents - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." - ) - else: - init_latents = torch.cat([init_latents], dim=0) - - if add_noise: - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # get latents - init_latents = self.scheduler.add_noise(init_latents, noise, timestep) - - latents = init_latents - - return latents - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents - def prepare_latents_inpaint( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - image=None, - timestep=None, - is_strength_max=True, - add_noise=True, - return_noise=False, - return_image_latents=False, - ): - shape = ( - batch_size, - num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if (image is None or timestep is None) and not is_strength_max: - raise ValueError( - "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." - "However, either the image or the noise timestep has not been provided." - ) - - if image.shape[1] == 4: - image_latents = image.to(device=device, dtype=dtype) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) - elif return_image_latents or (latents is None and not is_strength_max): - image = image.to(device=device, dtype=dtype) - image_latents = self._encode_vae_image(image=image, generator=generator) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) - - if latents is None and add_noise: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # if strength is 1. then initialise the latents to noise, else initial to image + noise - latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) - # if pure noise then scale the initial latents by the Scheduler's init sigma - latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents - elif add_noise: - noise = latents.to(device) - latents = noise * self.scheduler.init_noise_sigma - else: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = image_latents.to(device) - - outputs = (latents,) - - if return_noise: - outputs += (noise,) - - if return_image_latents: - outputs += (image_latents,) - - return outputs - - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image - # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): - - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if self.vae.config.force_upcast: - image = image.float() - self.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) - - if self.vae.config.force_upcast: - self.vae.to(dtype) +class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep] + block_names = ["decode", "mask_overlay", "output"] - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std - else: - image_latents = self.vae.config.scaling_factor * image_latents + @property + def description(self): + return "Inpaint decode step that decode the denoised latents into images outputs.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images\n" + \ + " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image\n" + \ + " - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." - return image_latents - - # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents - # do not accept do_classifier_free_guidance - def prepare_mask_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) +class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] + block_names = ["inpaint", "non-inpaint"] + block_trigger_inputs = ["padding_mask_crop", None] - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + @property + def description(self): + return "Decode step that decode the denoised latents into images outputs.\n" + \ + "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \ + " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \ + " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." - if masked_image is not None and masked_image.shape[1] == 4: - masked_image_latents = masked_image - else: - masked_image_latents = None - if masked_image is not None: - if masked_image_latents is None: - masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(masked_image, generator=generator) +class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks, ModularIPAdapterMixin): + block_classes = [StableDiffusionXLIPAdapterStep] + block_names = ["ip_adapter"] + block_trigger_inputs = ["ip_adapter_image"] - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) + @property + def description(self): + return "Run IP Adapter step if `ip_adapter_image` is provided." - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - return mask, masked_image_latents - - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] +class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] + block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decode"] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta + @property + def description(self): + return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \ + "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \ + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + \ + "- to run the controlnet workflow, you need to provide `control_image`\n" + \ + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \ + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \ + "- for text-to-image generation, all you need to provide is `prompt`" - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs +# block mapping +TEXT2IMAGE_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLSetTimestepsStep), + ("prepare_latents", StableDiffusionXLPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLDecodeStep) +]) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae - def upcast_vae(self): - dtype = self.vae.dtype - self.vae.to(dtype=torch.float32) - use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(dtype) - self.vae.decoder.conv_in.to(dtype) - self.vae.decoder.mid_block.to(dtype) +IMAGE2IMAGE_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLDecodeStep) +]) - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding( - self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 +INPAINT_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLInpaintDecodeStep) +]) - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. +CONTROLNET_BLOCKS = OrderedDict([ + ("denoise", StableDiffusionXLControlNetDenoiseStep), +]) - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. - """ - assert len(w.shape) == 1 - w = w * 1000.0 +CONTROLNET_UNION_BLOCKS = OrderedDict([ + ("denoise", StableDiffusionXLControlNetUnionDenoiseStep), +]) - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb +IP_ADAPTER_BLOCKS = OrderedDict([ + ("ip_adapter", StableDiffusionXLIPAdapterStep), +]) - ) - if data.do_classifier_free_guidance: - data.negative_ip_adapter_embeds = [] - for i, image_embeds in enumerate(data.ip_adapter_embeds): - negative_image_embeds, image_embeds = image_embeds.chunk(2) - data.negative_ip_adapter_embeds.append(negative_image_embeds) - data.ip_adapter_embeds[i] = image_embeds +AUTO_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), + ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), + ("decode", StableDiffusionXLAutoDecodeStep) +]) - self.add_block_state(state, data) - return pipeline, state +AUTO_CORE_BLOCKS = OrderedDict([ + ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), +]) + + +SDXL_SUPPORTED_BLOCKS = { + "text2img": TEXT2IMAGE_BLOCKS, + "img2img": IMAGE2IMAGE_BLOCKS, + "inpaint": INPAINT_BLOCKS, + "controlnet": CONTROLNET_BLOCKS, + "controlnet_union": CONTROLNET_UNION_BLOCKS, + "ip_adapter": IP_ADAPTER_BLOCKS, + "auto": AUTO_BLOCKS +} +class StableDiffusionXLComponents( + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + ModularIPAdapterMixin, +): @property def default_sample_size(self): default_sample_size = 128 @@ -4182,24 +2768,6 @@ def num_channels_latents(self): num_channels_latents = self.vae.config.latent_channels return num_channels_latents - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids - def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None - ): - add_time_ids = list(original_size + crops_coords_top_left + target_size) - - passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features - - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." - ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids def _get_add_time_ids_img2img( @@ -4963,3 +3531,108 @@ def get_guidance_scale_embedding( emb = torch.nn.functional.pad(emb, (0, 1)) assert emb.shape == (w.shape[0], embedding_dim) return emb + + + + +# YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks +sdxl_inputs_map = { + "prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"), + "prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"), + "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), + "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), + "cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), + "guidance_scale": InputParam("guidance_scale", type_hint=float, default=5.0, description="Classifier-Free Diffusion Guidance scale"), + "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), + "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), + "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), + "generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"), + "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), + "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), + "num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"), + "timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"), + "sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"), + "denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"), + # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 + "strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"), + "denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"), + "latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"), + "padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"), + "original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"), + "target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"), + "negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"), + "negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"), + "crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"), + "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), + "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), + "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), + "guidance_rescale": InputParam("guidance_rescale", type_hint=float, default=0.0, description="Guidance rescale factor to fix overexposure"), + "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), + "guider_kwargs": InputParam("guider_kwargs", type_hint=Optional[Dict[str, Any]], description="Kwargs dictionary passed to the Guider"), + "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), + "return_dict": InputParam("return_dict", type_hint=bool, default=True, description="Whether to return a StableDiffusionXLPipelineOutput"), + "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), + "control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"), + "control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"), + "control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"), + "controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"), + "guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"), + "control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet") +} + + +sdxl_intermediate_inputs_map = { + "prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"), + "negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), + "pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"), + "negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), + "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"), + "latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"), + "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"), + "latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"), + "image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"), + "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), + "masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), + "add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"), + "negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), + "negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), + "images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images") +} + + +sdxl_intermediate_outputs_map = { + "prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"), + "negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), + "pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"), + "negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"), + "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"), + "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"), + "masked_image_latents": OutputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), + "crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"), + "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"), + "latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"), + "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"), + "negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"), + "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), + "negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), + "images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images") +} + + +sdxl_outputs_map = { + "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") +} \ No newline at end of file From 9ad1470d48470103a57e430c5bfa2ea7ed298f66 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 11 Apr 2025 18:29:21 +0200 Subject: [PATCH 03/54] up --- src/diffusers/pipelines/modular_pipeline.py | 18 +- .../pipeline_stable_diffusion_xl_modular.py | 2078 +++++++++-------- 2 files changed, 1121 insertions(+), 975 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 4961d158e10d..954b78d417ce 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -144,7 +144,7 @@ class ComponentSpec: name: str type_hint: Type description: Optional[str] = None - default: Any = None # you can create a default component if it is a stateless class like scheduler, guider or image processor + obj: Any = None # you can create a default component if it is a stateless class like scheduler, guider or image processor default_class_name: Union[str, List[str], Tuple[str, str]] = None # Either "class_name" or ["module", "class_name"] default_repo: Optional[Union[str, List[str]]] = None # either "repo" or ["repo", "subfolder"] @@ -185,6 +185,16 @@ def format_inputs_short(inputs): Returns: str: Formatted string of input parameters + + Example: + >>> inputs = [ + ... InputParam(name="prompt", required=True), + ... InputParam(name="image", required=True), + ... InputParam(name="guidance_scale", required=False, default=7.5), + ... InputParam(name="num_inference_steps", required=False, default=50) + ... ] + >>> format_inputs_short(inputs) + 'prompt, image, guidance_scale=7.5, num_inference_steps=50' """ required_inputs = [param for param in inputs if param.required] optional_inputs = [param for param in inputs if not param.required] @@ -367,13 +377,13 @@ def description(self) -> str: raise NotImplementedError("description method must be implemented in subclasses") @property - def components(self) -> List[ComponentSpec]: + def expected_components(self) -> List[ComponentSpec]: return [] @property - def configs(self) -> List[ConfigSpec]: + def expected_configs(self) -> List[ConfigSpec]: return [] - + # YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable @property diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 5e2b8ae779d7..23ea96b8e8a0 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -140,6 +140,7 @@ def retrieve_latents( +# YiYi Notes: I think we do not need this, we can add loader methods on the components class class StableDiffusionXLLoraStep(PipelineBlock): model_name = "stable-diffusion-xl" @@ -153,7 +154,7 @@ def description(self) -> str: ) @property - def components(self) -> List[ComponentSpec]: + def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("text_encoder", CLIPTextModel), ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), @@ -179,7 +180,7 @@ def description(self) -> str: ) @property - def components(self) -> List[ComponentSpec]: + def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("image_encoder", CLIPVisionModelWithProjection), ComponentSpec("feature_extractor", CLIPImageProcessor), @@ -209,6 +210,76 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") ] + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components + def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(components.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = components.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = components.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = components.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + components, single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -246,7 +317,7 @@ def description(self) -> str: ) @property - def components(self) -> List[ComponentSpec]: + def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("text_encoder", CLIPTextModel), ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), @@ -255,7 +326,7 @@ def components(self) -> List[ComponentSpec]: ] @property - def configs(self) -> List[ConfigSpec]: + def expected_configs(self) -> List[ConfigSpec]: return [ConfigSpec("force_zeros_for_empty_prompt", True)] @property @@ -287,231 +358,591 @@ def check_inputs(self, pipeline, data): elif data.prompt_2 is not None and (not isinstance(data.prompt_2, str) and not isinstance(data.prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(data.prompt_2)}") + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + components, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - # Get inputs and intermediates - data = self.get_block_state(state) - self.check_inputs(pipeline, data) - - data.do_classifier_free_guidance = data.guidance_scale > 1.0 - data.device = pipeline._execution_device + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin): + components._lora_scale = lora_scale - # Encode input prompt - data.text_encoder_lora_scale = ( - data.cross_attention_kwargs.get("scale", None) if data.cross_attention_kwargs is not None else None - ) - ( - data.prompt_embeds, - data.negative_prompt_embeds, - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) = pipeline.encode_prompt( - data.prompt, - data.prompt_2, - data.device, - 1, - data.do_classifier_free_guidance, - data.negative_prompt, - data.negative_prompt_2, - prompt_embeds=None, - negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, - lora_scale=data.text_encoder_lora_scale, - clip_skip=data.clip_skip, - ) - # Add outputs - self.add_block_state(state, data) - return pipeline, state + # dynamically adjust the LoRA scale + if components.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder, lora_scale) + else: + scale_lora_layers(components.text_encoder, lora_scale) + if components.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale) + else: + scale_lora_layers(components.text_encoder_2, lora_scale) -class StableDiffusionXLVaeEncoderStep(PipelineBlock): + prompt = [prompt] if isinstance(prompt, str) else prompt - model_name = "stable-diffusion-xl" + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - - @property - def description(self) -> str: - return ( - "Vae Encoder step that encode the input image into a latent representation" + # Define tokenizers and text encoders + tokenizers = [components.tokenizer, components.tokenizer_2] if components.tokenizer is not None else [components.tokenizer_2] + text_encoders = ( + [components.text_encoder, components.text_encoder_2] if components.text_encoder is not None else [components.text_encoder_2] ) - @property - def components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec("image_processor", VaeImageProcessor, default=VaeImageProcessor()), - ] + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("image", required=True), - InputParam("generator"), - InputParam("height"), - InputParam("width"), - ] + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + prompt = components.maybe_convert_prompt(prompt, tokenizer) - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")] + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.preprocess_kwargs = data.preprocess_kwargs or {} - data.device = pipeline._execution_device - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - - data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, **data.preprocess_kwargs) - data.image = data.image.to(device=data.device, dtype=data.dtype) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) - data.batch_size = data.image.shape[0] + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) - if isinstance(data.generator, list) and len(data.generator) != data.batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(data.generator)}, but requested an effective batch" - f" size of {data.batch_size}. Make sure the batch size matches the length of the generators." - ) + prompt_embeds_list.append(prompt_embeds) + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - data.image_latents = pipeline._encode_vae_image(image=data.image, generator=data.generator) + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt - self.add_block_state(state, data) + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) - return pipeline, state + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer) + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if components.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if components.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if components.text_encoder is not None: + if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + data = self.get_block_state(state) + self.check_inputs(pipeline, data) + + data.do_classifier_free_guidance = data.guidance_scale > 1.0 + data.device = pipeline._execution_device + + + # Encode input prompt + data.text_encoder_lora_scale = ( + data.cross_attention_kwargs.get("scale", None) if data.cross_attention_kwargs is not None else None + ) + ( + data.prompt_embeds, + data.negative_prompt_embeds, + data.pooled_prompt_embeds, + data.negative_pooled_prompt_embeds, + ) = self.encode_prompt( + data.prompt, + data.prompt_2, + data.device, + 1, + data.do_classifier_free_guidance, + data.negative_prompt, + data.negative_prompt_2, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + lora_scale=data.text_encoder_lora_scale, + clip_skip=data.clip_skip, + ) + # Add outputs + self.add_block_state(state, data) + return pipeline, state + + +class StableDiffusionXLVaeEncoderStep(PipelineBlock): -class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): model_name = "stable-diffusion-xl" - - @property - def components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec("image_processor", VaeImageProcessor, default=VaeImageProcessor()), - ComponentSpec("mask_processor", VaeImageProcessor, default=VaeImageProcessor(do_normalize=False, do_binarize=True, do_convert_grayscale=True)), - ] - + @property def description(self) -> str: return ( - "Vae encoder step that prepares the image and mask for the inpainting process" + "Vae Encoder step that encode the input image into a latent representation" ) + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor()), + ] + @property def inputs(self) -> List[InputParam]: return [ + InputParam("image", required=True), + InputParam("generator"), InputParam("height"), InputParam("width"), - InputParam("generator"), - InputParam("image", required=True), - InputParam("mask_image", required=True), - InputParam("padding_mask_crop"), ] @property def intermediates_inputs(self) -> List[InputParam]: - return [InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs")] + return [ + InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), - OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] - + return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")] - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) - data = self.get_block_state(state) + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device + if self.vae.config.force_upcast: + self.vae.to(dtype) - if data.padding_mask_crop is not None: - data.crops_coords = pipeline.mask_processor.get_crop_region(data.mask_image, data.width, data.height, pad=data.padding_mask_crop) - data.resize_mode = "fill" + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std else: - data.crops_coords = None - data.resize_mode = "default" - - data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, crops_coords=data.crops_coords, resize_mode=data.resize_mode) - data.image = data.image.to(dtype=torch.float32) + image_latents = self.vae.config.scaling_factor * image_latents - data.mask = pipeline.mask_processor.preprocess(data.mask_image, height=data.height, width=data.width, resize_mode=data.resize_mode, crops_coords=data.crops_coords) - data.masked_image = data.image * (data.mask < 0.5) + return image_latents + - data.batch_size = data.image.shape[0] + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + data = self.get_block_state(state) + data.preprocess_kwargs = data.preprocess_kwargs or {} + data.device = pipeline._execution_device + data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype + + data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, **data.preprocess_kwargs) data.image = data.image.to(device=data.device, dtype=data.dtype) - data.image_latents = pipeline._encode_vae_image(image=data.image, generator=data.generator) - # 7. Prepare mask latent variables - data.mask, data.masked_image_latents = pipeline.prepare_mask_latents( - data.mask, - data.masked_image, - data.batch_size, - data.height, - data.width, - data.dtype, - data.device, - data.generator, - ) + data.batch_size = data.image.shape[0] - self.add_block_state(state, data) + # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) + if isinstance(data.generator, list) and len(data.generator) != data.batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(data.generator)}, but requested an effective batch" + f" size of {data.batch_size}. Make sure the batch size matches the length of the generators." + ) + + + data.image_latents = self._encode_vae_image(image=data.image, generator=data.generator) + self.add_block_state(state, data) return pipeline, state -class StableDiffusionXLInputStep(PipelineBlock): +class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor()), + ComponentSpec("mask_processor", VaeImageProcessor, obj=VaeImageProcessor(do_normalize=False, do_binarize=True, do_convert_grayscale=True)), + ] + @property def description(self) -> str: return ( - "Input processing step that:\n" - " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" - " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" - "All input tensors are expected to have either batch_size=1 or match the batch_size\n" - "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" - "have a final batch_size of batch_size * num_images_per_prompt." + "Vae encoder step that prepares the image and mask for the inpainting process" ) @property def inputs(self) -> List[InputParam]: return [ - InputParam("num_images_per_prompt", default=1), + InputParam("height"), + InputParam("width"), + InputParam("generator"), + InputParam("image", required=True), + InputParam("mask_image", required=True), + InputParam("padding_mask_crop"), ] @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated text embeddings. Can be generated from text_encoder step."), - InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings. Can be generated from text_encoder step."), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated pooled text embeddings. Can be generated from text_encoder step."), - InputParam("negative_pooled_prompt_embeds", description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step."), - InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step."), - InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."), - ] - + def intermediates_inputs(self) -> List[InputParam]: + return [InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs")] + @property - def intermediates_outputs(self) -> List[str]: - return [ - OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), - OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)"), + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), + OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), + OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + + + @torch.no_grad() + def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + + data = self.get_block_state(state) + + data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype + data.device = pipeline._execution_device + + if data.padding_mask_crop is not None: + data.crops_coords = pipeline.mask_processor.get_crop_region(data.mask_image, data.width, data.height, pad=data.padding_mask_crop) + data.resize_mode = "fill" + else: + data.crops_coords = None + data.resize_mode = "default" + + data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, crops_coords=data.crops_coords, resize_mode=data.resize_mode) + data.image = data.image.to(dtype=torch.float32) + + data.mask = pipeline.mask_processor.preprocess(data.mask_image, height=data.height, width=data.width, resize_mode=data.resize_mode, crops_coords=data.crops_coords) + data.masked_image = data.image * (data.mask < 0.5) + + data.batch_size = data.image.shape[0] + data.image = data.image.to(device=data.device, dtype=data.dtype) + data.image_latents = self._encode_vae_image(image=data.image, generator=data.generator) + + # 7. Prepare mask latent variables + data.mask, data.masked_image_latents = self.prepare_mask_latents( + data.mask, + data.masked_image, + data.batch_size, + data.height, + data.width, + data.dtype, + data.device, + data.generator, + ) + + self.add_block_state(state, data) + + + return pipeline, state + + +class StableDiffusionXLInputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_images_per_prompt." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated text embeddings. Can be generated from text_encoder step."), + InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings. Can be generated from text_encoder step."), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated pooled text embeddings. Can be generated from text_encoder step."), + InputParam("negative_pooled_prompt_embeds", description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step."), + InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step."), + InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."), + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [ + OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), + OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)"), OutputParam("prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation"), OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation"), OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="pooled text embeddings used to guide the image generation"), @@ -597,7 +1028,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): model_name = "stable-diffusion-xl" @property - def components(self) -> List[ComponentSpec]: + def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("scheduler", KarrasDiffusionSchedulers), ] @@ -636,6 +1067,47 @@ def intermediates_outputs(self) -> List[str]: OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") ] + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -650,7 +1122,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 - data.timesteps, data.num_inference_steps = pipeline.get_timesteps( + data.timesteps, data.num_inference_steps = self.get_timesteps( data.num_inference_steps, data.strength, data.device, @@ -678,7 +1150,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock): model_name = "stable-diffusion-xl" @property - def components(self) -> List[ComponentSpec]: + def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("scheduler", KarrasDiffusionSchedulers), ] @@ -733,7 +1205,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): model_name = "stable-diffusion-xl" @property - def components(self) -> List[ComponentSpec]: + def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("scheduler", KarrasDiffusionSchedulers), ] @@ -809,19 +1281,135 @@ def intermediates_outputs(self) -> List[str]: OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents + def prepare_latents_inpaint( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device - - data.is_strength_max = data.strength == 1.0 + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) - # for non-inpainting specific unet, we do not need masked_image_latents - if hasattr(pipeline,"unet") and pipeline.unet is not None: - if pipeline.unet.config.in_channels == 4: + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + + @torch.no_grad() + def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + data = self.get_block_state(state) + + data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype + data.device = pipeline._execution_device + + data.is_strength_max = data.strength == 1.0 + + # for non-inpainting specific unet, we do not need masked_image_latents + if hasattr(pipeline,"unet") and pipeline.unet is not None: + if pipeline.unet.config.in_channels == 4: data.masked_image_latents = None data.add_noise = True if data.denoising_start is None else False @@ -829,7 +1417,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin data.height = data.image_latents.shape[-2] * pipeline.vae_scale_factor data.width = data.image_latents.shape[-1] * pipeline.vae_scale_factor - data.latents, data.noise = pipeline.prepare_latents_inpaint( + data.latents, data.noise = self.prepare_latents_inpaint( data.batch_size * data.num_images_per_prompt, pipeline.num_channels_latents, data.height, @@ -847,7 +1435,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin ) # 7. Prepare mask latent variables - data.mask, data.masked_image_latents = pipeline.prepare_mask_latents( + data.mask, data.masked_image_latents = self.prepare_mask_latents( data.mask, data.masked_image_latents, data.batch_size * data.num_images_per_prompt, @@ -867,7 +1455,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): model_name = "stable-diffusion-xl" @property - def components(self) -> List[ComponentSpec]: + def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), ComponentSpec("scheduler", KarrasDiffusionSchedulers), @@ -900,6 +1488,92 @@ def intermediates_inputs(self) -> List[InputParam]: def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + # YiYi TODO: refactor using _encode_vae_image + def prepare_latents_img2img( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + torch.cuda.empty_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: @@ -909,7 +1583,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin data.device = pipeline._execution_device data.add_noise = True if data.denoising_start is None else False if data.latents is None: - data.latents = pipeline.prepare_latents_img2img( + data.latents = self.prepare_latents_img2img( data.image_latents, data.latent_timestep, data.batch_size, @@ -929,7 +1603,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): model_name = "stable-diffusion-xl" @property - def components(self) -> List[ComponentSpec]: + def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("scheduler", KarrasDiffusionSchedulers), ] @@ -989,6 +1663,30 @@ def check_inputs(pipeline, data): f"`height` and `width` have to be divisible by {pipeline.vae_scale_factor} but are {data.height} and {data.width}." ) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) @@ -1003,7 +1701,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin data.height = data.height or pipeline.default_sample_size * pipeline.vae_scale_factor data.width = data.width or pipeline.default_sample_size * pipeline.vae_scale_factor data.num_channels_latents = pipeline.num_channels_latents - data.latents = pipeline.prepare_latents( + data.latents = self.prepare_latents( data.batch_size * data.num_images_per_prompt, data.num_channels_latents, data.height, @@ -1024,7 +1722,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): model_name = "stable-diffusion-xl" @property - def configs(self) -> List[ConfigSpec]: + def expected_configs(self) -> List[ConfigSpec]: return [ConfigSpec("requires_aesthetics_score", default=False),] @property @@ -1114,6 +1812,37 @@ def _get_add_time_ids_img2img( return add_time_ids, add_neg_time_ids + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) @@ -1158,7 +1887,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin and pipeline.unet.config.time_cond_proj_dim is not None ): data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) - data.timestep_cond = pipeline.get_guidance_scale_embedding( + data.timestep_cond = self.get_guidance_scale_embedding( data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim ).to(device=data.device, dtype=data.latents.dtype) @@ -1269,17 +1998,48 @@ def _get_add_time_ids_img2img( return add_time_ids, add_neg_time_ids - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.device = pipeline._execution_device - - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * pipeline.vae_scale_factor - data.width = data.width * pipeline.vae_scale_factor + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - data.original_size = data.original_size or (data.height, data.width) - data.target_size = data.target_size or (data.height, data.width) + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + data = self.get_block_state(state) + data.device = pipeline._execution_device + + data.height, data.width = data.latents.shape[-2:] + data.height = data.height * pipeline.vae_scale_factor + data.width = data.width * pipeline.vae_scale_factor + + data.original_size = data.original_size or (data.height, data.width) + data.target_size = data.target_size or (data.height, data.width) data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1]) @@ -1312,7 +2072,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin and pipeline.unet.config.time_cond_proj_dim is not None ): data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) - data.timestep_cond = pipeline.get_guidance_scale_embedding( + data.timestep_cond = self.get_guidance_scale_embedding( data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim ).to(device=data.device, dtype=data.latents.dtype) @@ -1325,7 +2085,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): model_name = "stable-diffusion-xl" @property - def components(self) -> List[ComponentSpec]: + def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("guider", CFGGuider), ComponentSpec("scheduler", KarrasDiffusionSchedulers), @@ -1471,6 +2231,23 @@ def check_inputs(self, pipeline, data): " `pipeline.unet` or your `mask_image` or `image` input." ) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -1520,7 +2297,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = pipeline.prepare_extra_step_kwargs(data.generator, data.eta) + data.extra_step_kwargs = self.prepare_extra_step_kwargs(data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: @@ -1581,13 +2358,13 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): model_name = "stable-diffusion-xl" @property - def components(self) -> List[ComponentSpec]: + def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("guider", CFGGuider), ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), - ComponentSpec("control_image_processor", VaeImageProcessor, default=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), + ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), ComponentSpec("controlnet_guider", CFGGuider), ] @@ -1737,6 +2514,57 @@ def check_inputs(self, pipeline, data): " `pipeline.unet` or your `mask_image` or `image` input." ) + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. add crops_coords and resize_mode to preprocess() + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = self.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + else: + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -1787,7 +2615,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # (1.5) # control_image if isinstance(controlnet, ControlNetModel): - data.control_image = pipeline.prepare_control_image( + data.control_image = self.prepare_control_image( image=data.control_image, width=data.width, height=data.height, @@ -1801,7 +2629,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: control_images = [] for control_image_ in data.control_image: - control_image = pipeline.prepare_control_image( + control_image = self.prepare_control_image( image=control_image_, width=data.width, height=data.height, @@ -1884,7 +2712,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.control_image = pipeline.controlnet_guider.prepare_input(data.control_image, data.control_image) # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = pipeline.prepare_extra_step_kwargs(data.generator, data.eta) + data.extra_step_kwargs = self.prepare_extra_step_kwargs(data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) # (5) Denoise loop @@ -1975,14 +2803,14 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): model_name = "stable-diffusion-xl" @property - def components(self) -> List[ComponentSpec]: + def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetUnionModel), ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("guider", CFGGuider), ComponentSpec("controlnet_guider", CFGGuider), - ComponentSpec("control_image_processor", VaeImageProcessor, default=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), + ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), ] @property @@ -2131,6 +2959,57 @@ def check_inputs(self, pipeline, data): " `pipeline.unet` or your `mask_image` or `image` input." ) + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. add crops_coords and resize_mode to preprocess() + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = self.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + else: + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) @@ -2182,7 +3061,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # (1.5) # prepare control_image for idx, _ in enumerate(data.control_image): - data.control_image[idx] = pipeline.prepare_control_image( + data.control_image[idx] = self.prepare_control_image( image=data.control_image[idx], width=data.width, height=data.height, @@ -2270,7 +3149,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.control_type = pipeline.controlnet_guider.prepare_input(data.control_type, data.control_type) # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = pipeline.prepare_extra_step_kwargs(data.generator, data.eta) + data.extra_step_kwargs = self.prepare_extra_step_kwargs(data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) @@ -2363,10 +3242,10 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock): model_name = "stable-diffusion-xl" @property - def components(self) -> List[ComponentSpec]: + def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), - ComponentSpec("image_processor", VaeImageProcessor, default=VaeImageProcessor()) + ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor()) ] @property @@ -2387,6 +3266,24 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")] + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) @@ -2396,7 +3293,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.needs_upcasting = pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast if data.needs_upcasting: - pipeline.upcast_vae() + self.upcast_vae() data.latents = data.latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype) elif data.latents.dtype != pipeline.vae.dtype: if torch.backends.mps.is_available(): @@ -2734,7 +3631,9 @@ def description(self): } -class StableDiffusionXLComponents( +# YiYi TODO: rename to components etc. and not inherit from ModularPipeline +class StableDiffusionXLModularPipeline( + ModularPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, @@ -2769,769 +3668,6 @@ def num_channels_latents(self): return num_channels_latents - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids - def _get_add_time_ids_img2img( - self, - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype, - text_encoder_projection_dim=None, - ): - if self.config.requires_aesthetics_score: - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list( - negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) - ) - else: - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) - - passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features - - if ( - expected_add_embed_dim > passed_add_embed_dim - and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." - ) - elif ( - expected_add_embed_dim < passed_add_embed_dim - and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." - ) - elif expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." - ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - - return add_time_ids, add_neg_time_ids - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - - # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image - # 1. return image without apply any guidance - # 2. add crops_coords and resize_mode to preprocess() - def prepare_control_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - crops_coords=None, - ): - if crops_coords is not None: - image = self.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) - else: - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - return image - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt - def encode_prompt( - self, - prompt: str, - prompt_2: Optional[str] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, - negative_prompt: Optional[str] = None, - negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - pooled_prompt_embeds: Optional[torch.Tensor] = None, - negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in both text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizers and text encoders - tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] - text_encoders = ( - [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] - ) - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # textual inversion: process multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, tokenizer) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] - else: - # "2" because SDXL always indexes from the penultimate layer. - prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt - - # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 - ) - - uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] - - negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - negative_prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), - output_hidden_states=True, - ) - # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] - - negative_prompt_embeds_list.append(negative_prompt_embeds) - - negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - - if self.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) - else: - prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - if self.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) - else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - - if self.text_encoder is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds - def prepare_ip_adapter_image_embeds( - self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance - ): - image_embeds = [] - if do_classifier_free_guidance: - negative_image_embeds = [] - if ip_adapter_image_embeds is None: - if not isinstance(ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] - - if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): - raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." - ) - - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers - ): - output_hidden_state = not isinstance(image_proj_layer, ImageProjection) - single_image_embeds, single_negative_image_embeds = self.encode_image( - single_ip_adapter_image, device, 1, output_hidden_state - ) - - image_embeds.append(single_image_embeds[None, :]) - if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None, :]) - else: - for single_image_embeds in ip_adapter_image_embeds: - if do_classifier_free_guidance: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) - image_embeds.append(single_image_embeds) - - ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): - single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if do_classifier_free_guidance: - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - - single_image_embeds = single_image_embeds.to(device=device) - ip_adapter_image_embeds.append(single_image_embeds) - - return ip_adapter_image_embeds - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): - # get the original timestep using init_timestep - if denoising_start is None: - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - t_start = max(num_inference_steps - init_timestep, 0) - - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start * self.scheduler.order) - - return timesteps, num_inference_steps - t_start - - else: - # Strength is irrelevant if we directly request a timestep to start at; - # that is, strength is determined by the denoising_start instead. - discrete_timestep_cutoff = int( - round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) - ) - ) - - num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() - if self.scheduler.order == 2 and num_inference_steps % 2 == 0: - # if the scheduler is a 2nd order scheduler we might have to do +1 - # because `num_inference_steps` might be even given that every timestep - # (except the highest one) is duplicated. If `num_inference_steps` is even it would - # mean that we cut the timesteps in the middle of the denoising step - # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 - # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler - num_inference_steps = num_inference_steps + 1 - - # because t_n+1 >= t_n, we slice the timesteps starting from the end - t_start = len(self.scheduler.timesteps) - num_inference_steps - timesteps = self.scheduler.timesteps[t_start:] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start) - return timesteps, num_inference_steps - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = ( - batch_size, - num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents - # YiYi TODO: refactor using _encode_vae_image - def prepare_latents_img2img( - self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - # Offload text encoder if `enable_model_cpu_offload` was enabled - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.text_encoder_2.to("cpu") - torch.cuda.empty_cache() - - image = image.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - init_latents = image - - else: - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) - # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.config.force_upcast: - image = image.float() - self.vae.to(dtype=torch.float32) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - elif isinstance(generator, list): - if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: - image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) - elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " - ) - - init_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = retrieve_latents(self.vae.encode(image), generator=generator) - - if self.vae.config.force_upcast: - self.vae.to(dtype) - - init_latents = init_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=device, dtype=dtype) - latents_std = latents_std.to(device=device, dtype=dtype) - init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std - else: - init_latents = self.vae.config.scaling_factor * init_latents - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." - ) - else: - init_latents = torch.cat([init_latents], dim=0) - - if add_noise: - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # get latents - init_latents = self.scheduler.add_noise(init_latents, noise, timestep) - - latents = init_latents - - return latents - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents - def prepare_latents_inpaint( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - image=None, - timestep=None, - is_strength_max=True, - add_noise=True, - return_noise=False, - return_image_latents=False, - ): - shape = ( - batch_size, - num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if (image is None or timestep is None) and not is_strength_max: - raise ValueError( - "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." - "However, either the image or the noise timestep has not been provided." - ) - - if image.shape[1] == 4: - image_latents = image.to(device=device, dtype=dtype) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) - elif return_image_latents or (latents is None and not is_strength_max): - image = image.to(device=device, dtype=dtype) - image_latents = self._encode_vae_image(image=image, generator=generator) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) - - if latents is None and add_noise: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # if strength is 1. then initialise the latents to noise, else initial to image + noise - latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) - # if pure noise then scale the initial latents by the Scheduler's init sigma - latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents - elif add_noise: - noise = latents.to(device) - latents = noise * self.scheduler.init_noise_sigma - else: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = image_latents.to(device) - - outputs = (latents,) - - if return_noise: - outputs += (noise,) - - if return_image_latents: - outputs += (image_latents,) - - return outputs - - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image - # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): - - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if self.vae.config.force_upcast: - image = image.float() - self.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) - - if self.vae.config.force_upcast: - self.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std - else: - image_latents = self.vae.config.scaling_factor * image_latents - - return image_latents - - - # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents - # do not accept do_classifier_free_guidance - def prepare_mask_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - - if masked_image is not None and masked_image.shape[1] == 4: - masked_image_latents = masked_image - else: - masked_image_latents = None - - if masked_image is not None: - if masked_image_latents is None: - masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(masked_image, generator=generator) - - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - - return mask, masked_image_latents - - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae - def upcast_vae(self): - dtype = self.vae.dtype - self.vae.to(dtype=torch.float32) - use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(dtype) - self.vae.decoder.conv_in.to(dtype) - self.vae.decoder.mid_block.to(dtype) - - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding( - self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. - - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. - """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb - From d143851309c7eed3ddb3af54fc56943452cf79d5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 12 Apr 2025 11:46:25 +0200 Subject: [PATCH 04/54] move methods to blocks --- src/diffusers/pipelines/components_manager.py | 1 - src/diffusers/pipelines/modular_pipeline.py | 439 ++++++------------ .../pipeline_stable_diffusion_xl_modular.py | 310 ++++++------- 3 files changed, 295 insertions(+), 455 deletions(-) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index 6d7665e29292..8c14321ccfac 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -346,7 +346,6 @@ def get(self, names: Union[str, List[str]]): results.update(result) else: results[name] = result - logger.info(f"Getting multiple components: {list(results.keys())}") return results else: diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 954b78d417ce..785f38cdbf8c 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -170,7 +170,7 @@ def __repr__(self): @dataclass class OutputParam: name: str - type_hint: Any + type_hint: Any = None description: str = "" def __repr__(self): @@ -440,63 +440,31 @@ def __repr__(self): desc.extend(f" {line}" for line in desc_lines[1:]) desc = '\n'.join(desc) + '\n' - # Components section + # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) - expected_component_names = {comp.name for comp in expected_components} if expected_components else set() - loaded_components = set(self.components.keys()) - all_components = sorted(expected_component_names | loaded_components) - - main_components = [] - auxiliary_components = [] - for k in all_components: - # Get component spec if available - component_spec = next((comp for comp in expected_components if comp.name == k), None) + expected_components_str_list = [] + + for component_spec in expected_components: + component_str = f" - {component_spec.name} ({component_spec.type_hint})" - if k in loaded_components: - component_type = type(self.components[k]).__name__ - component_str = f" - {k}={component_type}" - - # Add expected type info if available - if component_spec and component_spec.class_name: - expected_type = component_spec.class_name - if isinstance(expected_type, (list, tuple)): - expected_type = expected_type[1] # Get class name from [module, class_name] - if expected_type != component_type: - component_str += f" (expected: {expected_type})" - else: - # Component not loaded but expected - if component_spec: - expected_type = component_spec.class_name - if isinstance(expected_type, (list, tuple)): - expected_type = expected_type[1] # Get class name from [module, class_name] - component_str = f" - {k} (expected: {expected_type})" - - # Add repo info if available - if component_spec.default_repo: - repo_info = component_spec.default_repo - if component_spec.subfolder: - repo_info += f", subfolder={component_spec.subfolder}" - component_str += f" [{repo_info}]" + # Add repo info if available + if component_spec.default_repo: + if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: + repo_info = component_spec.default_repo[0] + subfolder = component_spec.default_repo[1] + if subfolder: + repo_info += f", subfolder={subfolder}" else: - component_str = f" - {k}" + repo_info = component_spec.default_repo + component_str += f" [{repo_info}]" - if k in getattr(self, "auxiliary_components", []): - auxiliary_components.append(component_str) - else: - main_components.append(component_str) - - components = "Components:\n" + "\n".join(main_components) - if auxiliary_components: - components += "\n Auxiliaries:\n" + "\n".join(auxiliary_components) - - # Configs section - expected_configs = set(getattr(self, "expected_configs", [])) - loaded_configs = set(self.configs.keys()) - all_configs = sorted(expected_configs | loaded_configs) - configs = "Configs:\n" + "\n".join( - f" - {k}={self.configs[k]}" if k in loaded_configs else f" - {k}" - for k in all_configs - ) + expected_components_str_list.append(component_str) + + components = "Components:\n" + "\n".join(expected_components_str_list) + + # Configs section - focus only on expected configs + expected_configs = getattr(self, "expected_configs", []) + configs = "Configs:\n" + "\n".join(f" - {k}" for k in sorted(expected_configs)) # Inputs section inputs_str = format_inputs_short(self.inputs) @@ -672,35 +640,6 @@ def expected_configs(self): expected_configs.append(config) return expected_configs - # YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc - @property - def components(self): - # Combine components from all blocks - components = {} - for block_name, block in self.blocks.items(): - for key, value in block.components.items(): - # Only update if: - # 1. Key doesn't exist yet in components, OR - # 2. New value is not None - if key not in components or value is not None: - components[key] = value - return components - - @property - def auxiliaries(self): - # Combine auxiliaries from all blocks - auxiliaries = {} - for block_name, block in self.blocks.items(): - auxiliaries.update(block.auxiliaries) - return auxiliaries - - @property - def configs(self): - # Combine configs from all blocks - configs = {} - for block_name, block in self.blocks.items(): - configs.update(block.configs) - return configs @property def required_inputs(self) -> List[str]: @@ -855,62 +794,34 @@ def __repr__(self): desc.extend(f" {line}" for line in desc_lines[1:]) desc = '\n'.join(desc) + '\n' - # Components section + # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) - expected_component_names = {comp.name for comp in expected_components} if expected_components else set() - loaded_components = set(self.components.keys()) - all_components = sorted(expected_component_names | loaded_components) - - # Auxiliaries section - auxiliaries_str = " Auxiliaries:\n" + "\n".join( - f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items() - ) - main_components = [] - for k in all_components: - # Get component spec if available - component_spec = next((comp for comp in expected_components if comp.name == k), None) + expected_components_str_list = [] + + for component_spec in expected_components: - if k in loaded_components: - component_type = type(self.components[k]).__name__ - component_str = f" - {k}={component_type}" - - # Add expected type info if available - if component_spec and component_spec.class_name: - expected_type = component_spec.class_name - if isinstance(expected_type, (list, tuple)): - expected_type = expected_type[1] # Get class name from [module, class_name] - if expected_type != component_type: - component_str += f" (expected: {expected_type})" - else: - # Component not loaded but expected - if component_spec: - expected_type = component_spec.class_name - if isinstance(expected_type, (list, tuple)): - expected_type = expected_type[1] # Get class name from [module, class_name] - component_str = f" - {k} (expected: {expected_type})" - - # Add repo info if available - if component_spec.default_repo: - repo_info = component_spec.default_repo - if component_spec.subfolder: - repo_info += f", subfolder={component_spec.subfolder}" - component_str += f" [{repo_info}]" + component_str = f" - {component_spec.name} ({component_spec.type_hint.__name__})" + + # Add repo info if available + if component_spec.default_repo: + if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: + repo_info = component_spec.default_repo[0] + subfolder = component_spec.default_repo[1] + if subfolder: + repo_info += f", subfolder={subfolder}" else: - component_str = f" - {k}" + repo_info = component_spec.default_repo + component_str += f" [{repo_info}]" + expected_components_str_list.append(component_str) - main_components.append(component_str) + components_str = " Components:\n" + "\n".join(expected_components_str_list) - components = "Components:\n" + "\n".join(main_components) - - # Configs section - expected_configs = set(getattr(self, "expected_configs", [])) - loaded_configs = set(self.configs.keys()) - all_configs = sorted(expected_configs | loaded_configs) - configs_str = " Configs:\n" + "\n".join( - f" - {k}={v}" if k in loaded_configs else f" - {k}" for k, v in self.configs.items() - ) + # Configs section - focus only on expected configs + expected_configs = getattr(self, "expected_configs", []) + configs_str = " Configs:\n" + "\n".join(f" - {config.name}" for config in sorted(expected_configs, key=lambda x: x.name)) + # Blocks section blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block @@ -955,6 +866,7 @@ def __repr__(self): blocks_str += f"{indented_intermediates}\n" blocks_str += "\n" + # Inputs and outputs section inputs_str = format_inputs_short(self.inputs) inputs_str = " Inputs:\n " + inputs_str outputs = [out.name for out in self.outputs] @@ -970,7 +882,6 @@ def __repr__(self): f"{header}\n" f"{desc}" f"{components_str}\n" - f"{auxiliaries_str}\n" f"{configs_str}\n" f"{blocks_str}\n" f"{inputs_str}\n" @@ -1037,35 +948,6 @@ def __init__(self): blocks[block_name] = block_cls() self.blocks = blocks - # YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc - @property - def components(self): - # Combine components from all blocks - components = {} - for block_name, block in self.blocks.items(): - for key, value in block.components.items(): - # Only update if: - # 1. Key doesn't exist yet in components, OR - # 2. New value is not None - if key not in components or value is not None: - components[key] = value - return components - - @property - def auxiliaries(self): - # Combine auxiliaries from all blocks - auxiliaries = {} - for block_name, block in self.blocks.items(): - auxiliaries.update(block.auxiliaries) - return auxiliaries - - @property - def configs(self): - # Combine configs from all blocks - configs = {} - for block_name, block in self.blocks.items(): - configs.update(block.configs) - return configs @property def required_inputs(self) -> List[str]: @@ -1284,63 +1166,34 @@ def __repr__(self): desc.extend(f" {line}" for line in desc_lines[1:]) desc = '\n'.join(desc) + '\n' - # Components section + # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) - expected_component_names = {comp.name for comp in expected_components} if expected_components else set() - loaded_components = set(self.components.keys()) - all_components = sorted(expected_component_names | loaded_components) - - # Auxiliaries section - auxiliaries_str = " Auxiliaries:\n" + "\n".join( - f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items() - ) - - main_components = [] - for k in all_components: - # Get component spec if available - component_spec = next((comp for comp in expected_components if comp.name == k), None) + expected_components_str_list = [] + + for component_spec in expected_components: - if k in loaded_components: - component_type = type(self.components[k]).__name__ - component_str = f" - {k}={component_type}" - - # Add expected type info if available - if component_spec and component_spec.class_name: - expected_type = component_spec.class_name - if isinstance(expected_type, (list, tuple)): - expected_type = expected_type[1] # Get class name from [module, class_name] - if expected_type != component_type: - component_str += f" (expected: {expected_type})" - else: - # Component not loaded but expected - if component_spec: - expected_type = component_spec.class_name - if isinstance(expected_type, (list, tuple)): - expected_type = expected_type[1] # Get class name from [module, class_name] - component_str = f" - {k} (expected: {expected_type})" - - # Add repo info if available - if component_spec.default_repo: - repo_info = component_spec.default_repo - if component_spec.subfolder: - repo_info += f", subfolder={component_spec.subfolder}" - component_str += f" [{repo_info}]" + component_str = f" - {component_spec.name} ({component_spec.type_hint.__name__})" + + # Add repo info if available + if component_spec.default_repo: + if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: + repo_info = component_spec.default_repo[0] + subfolder = component_spec.default_repo[1] + if subfolder: + repo_info += f", subfolder={subfolder}" else: - component_str = f" - {k}" + repo_info = component_spec.default_repo + component_str += f" [{repo_info}]" + expected_components_str_list.append(component_str) - main_components.append(component_str) - - components = "Components:\n" + "\n".join(main_components) + components_str = " Components:\n" + "\n".join(expected_components_str_list) - # Configs section - expected_configs = set(getattr(self, "expected_configs", [])) - loaded_configs = set(self.configs.keys()) - all_configs = sorted(expected_configs | loaded_configs) - configs_str = " Configs:\n" + "\n".join( - f" - {k}={self.configs[k]}" if k in loaded_configs else f" - {k}" for k in all_configs - ) + # Configs section - focus only on expected configs + expected_configs = getattr(self, "expected_configs", []) + configs_str = " Configs:\n" + "\n".join(f" - {config.name}" for config in sorted(expected_configs, key=lambda x: x.name)) + # Blocks section blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block @@ -1385,6 +1238,7 @@ def __repr__(self): blocks_str += f"{indented_intermediates}\n" blocks_str += "\n" + # Inputs and outputs section inputs_str = format_inputs_short(self.inputs) inputs_str = " Inputs:\n " + inputs_str outputs = [out.name for out in self.outputs] @@ -1400,7 +1254,6 @@ def __repr__(self): f"{header}\n" f"{desc}" f"{components_str}\n" - f"{auxiliaries_str}\n" f"{configs_str}\n" f"{blocks_str}\n" f"{inputs_str}\n" @@ -1408,6 +1261,7 @@ def __repr__(self): f")" ) + @property def doc(self): return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) @@ -1424,16 +1278,17 @@ class ModularPipeline(ConfigMixin): def __init__(self, block): self.pipeline_block = block - # add default components from pipeline_block (e.g. guider) - for key, value in block.components.items(): - setattr(self, key, value) + for component_spec in self.expected_components: + if component_spec.obj is not None: + setattr(self, component_spec.name, component_spec.obj) + else: + setattr(self, component_spec.name, None) + + default_configs = {} + for config_spec in self.expected_configs: + default_configs[config_spec.name] = config_spec.default + self.register_to_config(**default_configs) - # add default configs from pipeline_block (e.g. force_zeros_for_empty_prompt) - self.register_to_config(**block.configs) - - # add default auxiliaries from pipeline_block (e.g. image_processor) - for key, value in block.auxiliaries.items(): - setattr(self, key, value) @classmethod def from_block(cls, block): @@ -1508,9 +1363,9 @@ def expected_configs(self): @property def components(self): components = {} - for name in self.expected_components: - if hasattr(self, name): - components[name] = getattr(self, name) + for component_spec in self.expected_components: + if hasattr(self, component_spec.name): + components[component_spec.name] = getattr(self, component_spec.name) return components # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.progress_bar @@ -1596,32 +1451,32 @@ def update_states(self, **kwargs): kwargs (dict): Keyword arguments to update the states. """ - for component_name in self.expected_components: - if component_name in kwargs: - if hasattr(self, component_name) and getattr(self, component_name) is not None: - current_component = getattr(self, component_name) - new_component = kwargs[component_name] + for component in self.expected_components: + if component.name in kwargs: + if hasattr(self, component.name) and getattr(self, component.name) is not None: + current_component = getattr(self, component.name) + new_component = kwargs[component.name] if not isinstance(new_component, current_component.__class__): logger.info( - f"Overwriting existing component '{component_name}' " + f"Overwriting existing component '{component.name}' " f"(type: {current_component.__class__.__name__}) " f"with type: {new_component.__class__.__name__})" ) elif isinstance(current_component, torch.nn.Module): if id(current_component) != id(new_component): logger.info( - f"Overwriting existing component '{component_name}' " + f"Overwriting existing component '{component.name}' " f"(type: {type(current_component).__name__}) " f"with new value (type: {type(new_component).__name__})" ) - setattr(self, component_name, kwargs.pop(component_name)) + setattr(self, component.name, kwargs.pop(component.name)) configs_to_add = {} - for config_name in self.expected_configs: - if config_name in kwargs: - configs_to_add[config_name] = kwargs.pop(config_name) + for config in self.expected_configs: + if config.name in kwargs: + configs_to_add[config.name] = kwargs.pop(config.name) self.register_to_config(**configs_to_add) @property @@ -1631,64 +1486,64 @@ def default_call_parameters(self) -> Dict[str, Any]: params[input_param.name] = input_param.default return params - def __repr__(self): - output = "ModularPipeline:\n" - output += "==============================\n\n" + # def __repr__(self): + # output = "ModularPipeline:\n" + # output += "==============================\n\n" - block = self.pipeline_block + # block = self.pipeline_block - # List the pipeline block structure first - output += "Pipeline Block:\n" - output += "--------------\n" - if hasattr(block, "blocks"): - output += f"{block.__class__.__name__}\n" - base_class = block.__class__.__bases__[0].__name__ - output += f" (Class: {base_class})\n" if base_class != "object" else "\n" - for sub_block_name, sub_block in block.blocks.items(): - if hasattr(block, "block_trigger_inputs"): - trigger_input = block.block_to_trigger_map[sub_block_name] - trigger_info = f" [trigger: {trigger_input}]" if trigger_input is not None else " [default]" - output += f" • {sub_block_name} ({sub_block.__class__.__name__}){trigger_info}\n" - else: - output += f" • {sub_block_name} ({sub_block.__class__.__name__})\n" - else: - output += f"{block.__class__.__name__}\n" - output += "\n" - - # List the components registered in the pipeline - output += "Registered Components:\n" - output += "----------------------\n" - for name, component in self.components.items(): - output += f"{name}: {type(component).__name__}" - if hasattr(component, "dtype") and hasattr(component, "device"): - output += f" (dtype={component.dtype}, device={component.device})" - output += "\n" - output += "\n" - - # List the configs registered in the pipeline - output += "Registered Configs:\n" - output += "------------------\n" - for name, config in self.config.items(): - output += f"{name}: {config!r}\n" - output += "\n" - - # Add auto blocks section - if hasattr(block, "trigger_inputs") and block.trigger_inputs: - output += "------------------\n" - output += "This pipeline contains blocks that are selected at runtime based on inputs.\n\n" - output += f"Trigger Inputs: {block.trigger_inputs}\n" - # Get first trigger input as example - example_input = next(t for t in block.trigger_inputs if t is not None) - output += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" - output += "Check `.doc` of returned object for more information.\n\n" - - # List the call parameters - full_doc = self.pipeline_block.doc - if "------------------------" in full_doc: - full_doc = full_doc.split("------------------------")[0].rstrip() - output += full_doc - - return output + # # List the pipeline block structure first + # output += "Pipeline Block:\n" + # output += "--------------\n" + # if hasattr(block, "blocks"): + # output += f"{block.__class__.__name__}\n" + # base_class = block.__class__.__bases__[0].__name__ + # output += f" (Class: {base_class})\n" if base_class != "object" else "\n" + # for sub_block_name, sub_block in block.blocks.items(): + # if hasattr(block, "block_trigger_inputs"): + # trigger_input = block.block_to_trigger_map[sub_block_name] + # trigger_info = f" [trigger: {trigger_input}]" if trigger_input is not None else " [default]" + # output += f" • {sub_block_name} ({sub_block.__class__.__name__}){trigger_info}\n" + # else: + # output += f" • {sub_block_name} ({sub_block.__class__.__name__})\n" + # else: + # output += f"{block.__class__.__name__}\n" + # output += "\n" + + # # List the components registered in the pipeline + # output += "Registered Components:\n" + # output += "----------------------\n" + # for name, component in self.components.items(): + # output += f"{name}: {type(component).__name__}" + # if hasattr(component, "dtype") and hasattr(component, "device"): + # output += f" (dtype={component.dtype}, device={component.device})" + # output += "\n" + # output += "\n" + + # # List the configs registered in the pipeline + # output += "Registered Configs:\n" + # output += "------------------\n" + # for name, config in self.config.items(): + # output += f"{name}: {config!r}\n" + # output += "\n" + + # # Add auto blocks section + # if hasattr(block, "trigger_inputs") and block.trigger_inputs: + # output += "------------------\n" + # output += "This pipeline contains blocks that are selected at runtime based on inputs.\n\n" + # output += f"Trigger Inputs: {block.trigger_inputs}\n" + # # Get first trigger input as example + # example_input = next(t for t in block.trigger_inputs if t is not None) + # output += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" + # output += "Check `.doc` of returned object for more information.\n\n" + + # # List the call parameters + # full_doc = self.pipeline_block.doc + # if "------------------------" in full_doc: + # full_doc = full_doc.split("------------------------")[0].rstrip() + # output += full_doc + + # return output # YiYi TODO: try to unify the to method with the one in DiffusionPipeline # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 23ea96b8e8a0..8e7109308962 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -22,7 +22,7 @@ from ...guider import CFGGuider from ...image_processor import VaeImageProcessor, PipelineImageInput from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin -from ...models import ControlNetModel, ImageProjection +from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor from ...models.lora import adjust_lora_scale_text_encoder from ...utils import ( @@ -211,7 +211,7 @@ def intermediates_outputs(self) -> List[OutputParam]: ] # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components - def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): + def encode_image(self, components, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(components.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): @@ -237,7 +237,7 @@ def encode_image(components, image, device, num_images_per_prompt, output_hidden # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( - components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): image_embeds = [] if do_classifier_free_guidance: @@ -288,7 +288,8 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.do_classifier_free_guidance = data.guidance_scale > 1.0 data.device = pipeline._execution_device - data.ip_adapter_embeds = pipeline.prepare_ip_adapter_image_embeds( + data.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( + pipeline, ip_adapter_image=data.ip_adapter_image, ip_adapter_image_embeds=None, device=data.device, @@ -358,8 +359,9 @@ def check_inputs(self, pipeline, data): elif data.prompt_2 is not None and (not isinstance(data.prompt_2, str) and not isinstance(data.prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(data.prompt_2)}") - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with self -> components def encode_prompt( + self, components, prompt: str, prompt_2: Optional[str] = None, @@ -496,7 +498,7 @@ def encode_prompt( prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) @@ -614,6 +616,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds, ) = self.encode_prompt( + pipeline, data.prompt, data.prompt_2, data.device, @@ -670,40 +673,40 @@ def intermediates_inputs(self) -> List[InputParam]: def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")] - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) dtype = image.dtype - if self.vae.config.force_upcast: + if components.vae.config.force_upcast: image = image.float() - self.vae.to(dtype=torch.float32) + components.vae.to(dtype=torch.float32) if isinstance(generator, list): image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - if self.vae.config.force_upcast: - self.vae.to(dtype) + if components.vae.config.force_upcast: + components.vae.to(dtype) image_latents = image_latents.to(dtype) if latents_mean is not None and latents_std is not None: latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std else: - image_latents = self.vae.config.scaling_factor * image_latents + image_latents = components.vae.config.scaling_factor * image_latents return image_latents @@ -729,7 +732,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ) - data.image_latents = self._encode_vae_image(image=data.image, generator=data.generator) + data.image_latents = self._encode_vae_image(pipeline,image=data.image, generator=data.generator) self.add_block_state(state, data) @@ -776,32 +779,32 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) dtype = image.dtype - if self.vae.config.force_upcast: + if components.vae.config.force_upcast: image = image.float() - self.vae.to(dtype=torch.float32) + components.vae.to(dtype=torch.float32) if isinstance(generator, list): image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - if self.vae.config.force_upcast: - self.vae.to(dtype) + if components.vae.config.force_upcast: + components.vae.to(dtype) image_latents = image_latents.to(dtype) if latents_mean is not None and latents_std is not None: @@ -809,20 +812,20 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): latents_std = latents_std.to(device=image_latents.device, dtype=dtype) image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std else: - image_latents = self.vae.config.scaling_factor * image_latents + image_latents = components.vae.config.scaling_factor * image_latents return image_latents # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents # do not accept do_classifier_free_guidance def prepare_mask_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator ): # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision mask = torch.nn.functional.interpolate( - mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) ) mask = mask.to(device=device, dtype=dtype) @@ -844,7 +847,7 @@ def prepare_mask_latents( if masked_image is not None: if masked_image_latents is None: masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) if masked_image_latents.shape[0] < batch_size: if not batch_size % masked_image_latents.shape[0] == 0: @@ -887,10 +890,11 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin data.batch_size = data.image.shape[0] data.image = data.image.to(device=data.device, dtype=data.dtype) - data.image_latents = self._encode_vae_image(image=data.image, generator=data.generator) + data.image_latents = self._encode_vae_image(pipeline, image=data.image, generator=data.generator) # 7. Prepare mask latent variables data.mask, data.masked_image_latents = self.prepare_mask_latents( + pipeline, data.mask, data.masked_image, data.batch_size, @@ -1067,16 +1071,16 @@ def intermediates_outputs(self) -> List[str]: OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") ] - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self -> components + def get_timesteps(self, components, num_inference_steps, strength, device, denoising_start=None): # get the original timestep using init_timestep if denoising_start is None: init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start * self.scheduler.order) + timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start * components.scheduler.order) return timesteps, num_inference_steps - t_start @@ -1085,13 +1089,13 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N # that is, strength is determined by the denoising_start instead. discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) + components.scheduler.config.num_train_timesteps + - (denoising_start * components.scheduler.config.num_train_timesteps) ) ) - num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() - if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + num_inference_steps = (components.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if components.scheduler.order == 2 and num_inference_steps % 2 == 0: # if the scheduler is a 2nd order scheduler we might have to do +1 # because `num_inference_steps` might be even given that every timestep # (except the highest one) is duplicated. If `num_inference_steps` is even it would @@ -1101,10 +1105,10 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N num_inference_steps = num_inference_steps + 1 # because t_n+1 >= t_n, we slice the timesteps starting from the end - t_start = len(self.scheduler.timesteps) - num_inference_steps - timesteps = self.scheduler.timesteps[t_start:] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start) + t_start = len(components.scheduler.timesteps) - num_inference_steps + timesteps = components.scheduler.timesteps[t_start:] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start) return timesteps, num_inference_steps @@ -1123,6 +1127,7 @@ def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 data.timesteps, data.num_inference_steps = self.get_timesteps( + pipeline, data.num_inference_steps, data.strength, data.device, @@ -1281,9 +1286,10 @@ def intermediates_outputs(self) -> List[str]: OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents with self -> components def prepare_latents_inpaint( self, + components, batch_size, num_channels_latents, height, @@ -1302,8 +1308,8 @@ def prepare_latents_inpaint( shape = ( batch_size, num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -1322,18 +1328,18 @@ def prepare_latents_inpaint( image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) elif return_image_latents or (latents is None and not is_strength_max): image = image.to(device=device, dtype=dtype) - image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = self._encode_vae_image(components, image=image, generator=generator) image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) if latents is None and add_noise: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # if strength is 1. then initialise the latents to noise, else initial to image + noise - latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep) # if pure noise then scale the initial latents by the Scheduler's init sigma - latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents elif add_noise: noise = latents.to(device) - latents = noise * self.scheduler.init_noise_sigma + latents = noise * components.scheduler.init_noise_sigma else: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = image_latents.to(device) @@ -1351,13 +1357,13 @@ def prepare_latents_inpaint( # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents # do not accept do_classifier_free_guidance def prepare_mask_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator ): # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision mask = torch.nn.functional.interpolate( - mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) ) mask = mask.to(device=device, dtype=dtype) @@ -1379,7 +1385,7 @@ def prepare_mask_latents( if masked_image is not None: if masked_image_latents is None: masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) if masked_image_latents.shape[0] < batch_size: if not batch_size % masked_image_latents.shape[0] == 0: @@ -1418,6 +1424,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin data.width = data.image_latents.shape[-1] * pipeline.vae_scale_factor data.latents, data.noise = self.prepare_latents_inpaint( + pipeline, data.batch_size * data.num_images_per_prompt, pipeline.num_channels_latents, data.height, @@ -1436,6 +1443,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin # 7. Prepare mask latent variables data.mask, data.masked_image_latents = self.prepare_mask_latents( + pipeline, data.mask, data.masked_image_latents, data.batch_size * data.num_images_per_prompt, @@ -1488,10 +1496,10 @@ def intermediates_inputs(self) -> List[InputParam]: def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self -> components # YiYi TODO: refactor using _encode_vae_image def prepare_latents_img2img( - self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + self, components, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True ): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( @@ -1499,8 +1507,8 @@ def prepare_latents_img2img( ) # Offload text encoder if `enable_model_cpu_offload` was enabled - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.text_encoder_2.to("cpu") + if hasattr(components, "final_offload_hook") and components.final_offload_hook is not None: + components.text_encoder_2.to("cpu") torch.cuda.empty_cache() image = image.to(device=device, dtype=dtype) @@ -1512,14 +1520,14 @@ def prepare_latents_img2img( else: latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.config.force_upcast: + if components.vae.config.force_upcast: image = image.float() - self.vae.to(dtype=torch.float32) + components.vae.to(dtype=torch.float32) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -1536,23 +1544,23 @@ def prepare_latents_img2img( ) init_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + init_latents = retrieve_latents(components.vae.encode(image), generator=generator) - if self.vae.config.force_upcast: - self.vae.to(dtype) + if components.vae.config.force_upcast: + components.vae.to(dtype) init_latents = init_latents.to(dtype) if latents_mean is not None and latents_std is not None: latents_mean = latents_mean.to(device=device, dtype=dtype) latents_std = latents_std.to(device=device, dtype=dtype) - init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + init_latents = (init_latents - latents_mean) * components.vae.config.scaling_factor / latents_std else: - init_latents = self.vae.config.scaling_factor * init_latents + init_latents = components.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size @@ -1569,7 +1577,7 @@ def prepare_latents_img2img( shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents - init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + init_latents = components.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents @@ -1584,6 +1592,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin data.add_noise = True if data.denoising_start is None else False if data.latents is None: data.latents = self.prepare_latents_img2img( + pipeline, data.image_latents, data.latent_timestep, data.batch_size, @@ -1663,13 +1672,13 @@ def check_inputs(pipeline, data): f"`height` and `width` have to be divisible by {pipeline.vae_scale_factor} but are {data.height} and {data.width}." ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self -> components + def prepare_latents(self, components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = ( batch_size, num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -1683,7 +1692,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma + latents = latents * components.scheduler.init_noise_sigma return latents @@ -1702,6 +1711,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin data.width = data.width or pipeline.default_sample_size * pipeline.vae_scale_factor data.num_channels_latents = pipeline.num_channels_latents data.latents = self.prepare_latents( + pipeline, data.batch_size * data.num_images_per_prompt, data.num_channels_latents, data.height, @@ -1762,6 +1772,7 @@ def intermediates_outputs(self) -> List[OutputParam]: # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components def _get_add_time_ids_img2img( + self, components, original_size, crops_coords_top_left, @@ -1864,7 +1875,8 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin if data.negative_target_size is None: data.negative_target_size = data.target_size - data.add_time_ids, data.negative_add_time_ids = pipeline._get_add_time_ids_img2img( + data.add_time_ids, data.negative_add_time_ids = self._get_add_time_ids_img2img( + pipeline, data.original_size, data.crops_coords_top_left, data.target_size, @@ -1946,57 +1958,24 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components - def _get_add_time_ids_img2img( - components, - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype, - text_encoder_projection_dim=None, + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components + def _get_add_time_ids( + self, components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None ): - if components.config.requires_aesthetics_score: - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list( - negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) - ) - else: - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features - if ( - expected_add_embed_dim > passed_add_embed_dim - and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." - ) - elif ( - expected_add_embed_dim < passed_add_embed_dim - and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." - ) - elif expected_add_embed_dim != passed_add_embed_dim: + if expected_add_embed_dim != passed_add_embed_dim: raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." ) add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - - return add_time_ids, add_neg_time_ids + return add_time_ids # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding def get_guidance_scale_embedding( @@ -2043,7 +2022,8 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1]) - data.add_time_ids = pipeline._get_add_time_ids( + data.add_time_ids = self._get_add_time_ids( + pipeline, data.original_size, data.crops_coords_top_left, data.target_size, @@ -2051,7 +2031,8 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin text_encoder_projection_dim=data.text_encoder_projection_dim, ) if data.negative_original_size is not None and data.negative_target_size is not None: - data.negative_add_time_ids = pipeline._get_add_time_ids( + data.negative_add_time_ids = self._get_add_time_ids( + pipeline, data.negative_original_size, data.negative_crops_coords_top_left, data.negative_target_size, @@ -2087,7 +2068,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", CFGGuider), + ComponentSpec("guider", CFGGuider, obj=CFGGuider()), ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("unet", UNet2DConditionModel), ] @@ -2231,20 +2212,20 @@ def check_inputs(self, pipeline, data): " `pipeline.unet` or your `mask_image` or `image` input." ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components + def prepare_extra_step_kwargs(self, components, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -2297,7 +2278,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = self.prepare_extra_step_kwargs(data.generator, data.eta) + data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: @@ -2360,12 +2341,12 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", CFGGuider), + ComponentSpec("guider", CFGGuider, obj=CFGGuider()), ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), - ComponentSpec("controlnet_guider", CFGGuider), + ComponentSpec("controlnet_guider", CFGGuider, obj=CFGGuider()), ] @property @@ -2519,6 +2500,7 @@ def check_inputs(self, pipeline, data): # 2. add crops_coords and resize_mode to preprocess() def prepare_control_image( self, + components, image, width, height, @@ -2529,9 +2511,9 @@ def prepare_control_image( crops_coords=None, ): if crops_coords is not None: - image = self.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) else: - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: @@ -2546,20 +2528,20 @@ def prepare_control_image( return image - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components + def prepare_extra_step_kwargs(self, components, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -2616,6 +2598,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # control_image if isinstance(controlnet, ControlNetModel): data.control_image = self.prepare_control_image( + pipeline, image=data.control_image, width=data.width, height=data.height, @@ -2630,6 +2613,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: for control_image_ in data.control_image: control_image = self.prepare_control_image( + pipeline, image=control_image_, width=data.width, height=data.height, @@ -2712,7 +2696,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.control_image = pipeline.controlnet_guider.prepare_input(data.control_image, data.control_image) # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = self.prepare_extra_step_kwargs(data.generator, data.eta) + data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) # (5) Denoise loop @@ -2808,8 +2792,8 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetUnionModel), ComponentSpec("scheduler", KarrasDiffusionSchedulers), - ComponentSpec("guider", CFGGuider), - ComponentSpec("controlnet_guider", CFGGuider), + ComponentSpec("guider", CFGGuider, obj=CFGGuider()), + ComponentSpec("controlnet_guider", CFGGuider, obj=CFGGuider()), ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), ] @@ -2965,6 +2949,7 @@ def check_inputs(self, pipeline, data): # 2. add crops_coords and resize_mode to preprocess() def prepare_control_image( self, + components, image, width, height, @@ -2975,9 +2960,9 @@ def prepare_control_image( crops_coords=None, ): if crops_coords is not None: - image = self.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) else: - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: @@ -2992,20 +2977,20 @@ def prepare_control_image( return image - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components + def prepare_extra_step_kwargs(self, components, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -3062,6 +3047,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # prepare control_image for idx, _ in enumerate(data.control_image): data.control_image[idx] = self.prepare_control_image( + pipeline, image=data.control_image[idx], width=data.width, height=data.height, @@ -3149,7 +3135,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.control_type = pipeline.controlnet_guider.prepare_input(data.control_type, data.control_type) # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = self.prepare_extra_step_kwargs(data.generator, data.eta) + data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) @@ -3266,12 +3252,12 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")] - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae - def upcast_vae(self): - dtype = self.vae.dtype - self.vae.to(dtype=torch.float32) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components + def upcast_vae(self, components): + dtype = components.vae.dtype + components.vae.to(dtype=torch.float32) use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, + components.vae.decoder.mid_block.attentions[0].processor, ( AttnProcessor2_0, XFormersAttnProcessor, @@ -3280,9 +3266,9 @@ def upcast_vae(self): # if xformers or torch_2_0 is used attention block does not need # to be in float32 which can save lots of memory if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(dtype) - self.vae.decoder.conv_in.to(dtype) - self.vae.decoder.mid_block.to(dtype) + components.vae.post_quant_conv.to(dtype) + components.vae.decoder.conv_in.to(dtype) + components.vae.decoder.mid_block.to(dtype) @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -3293,7 +3279,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.needs_upcasting = pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast if data.needs_upcasting: - self.upcast_vae() + self.upcast_vae(pipeline) data.latents = data.latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype) elif data.latents.dtype != pipeline.vae.dtype: if torch.backends.mps.is_available(): @@ -3672,7 +3658,7 @@ def num_channels_latents(self): # YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks -sdxl_inputs_map = { +SDXL_INPUTS_SCHEMA = { "prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"), "prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"), "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), @@ -3718,7 +3704,7 @@ def num_channels_latents(self): } -sdxl_intermediate_inputs_map = { +SDXL_INTERMEDIATE_INPUTS_SCHEMA = { "prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"), "negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), "pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"), @@ -3744,7 +3730,7 @@ def num_channels_latents(self): } -sdxl_intermediate_outputs_map = { +SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = { "prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"), "negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), "pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"), @@ -3769,6 +3755,6 @@ def num_channels_latents(self): } -sdxl_outputs_map = { +SDXL_OUTPUTS_SCHEMA = { "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") } \ No newline at end of file From b863bdd6caa99a7bc410f22ef68b3686ca7e222a Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 26 Apr 2025 03:42:42 +0530 Subject: [PATCH 05/54] Modular Diffusers Guiders (#11311) * cfg; slg; pag; sdxl without controlnet * support sdxl controlnet * support controlnet union * update * update * cfg zero* * use unwrap_module for torch compiled modules * remove guider kwargs * remove commented code * remove old guider * fix slg bug * remove debug print * autoguidance * smoothed energy guidance * add note about seg * tangential cfg * cfg plus plus * support cfgpp in ddim * apply review suggestions * refactor * rename enable/disable * remove cfg++ for now * rename do_classifier_free_guidance->prepare_unconditional_embeds * remove unused --- src/diffusers/__init__.py | 27 + src/diffusers/guider.py | 748 ------------------ src/diffusers/guiders/__init__.py | 29 + .../guiders/adaptive_projected_guidance.py | 180 +++++ src/diffusers/guiders/auto_guidance.py | 173 ++++ .../guiders/classifier_free_guidance.py | 128 +++ .../classifier_free_zero_star_guidance.py | 144 ++++ .../guiders/entropy_rectifying_guidance.py | 0 src/diffusers/guiders/guider_utils.py | 215 +++++ src/diffusers/guiders/skip_layer_guidance.py | 247 ++++++ .../guiders/smoothed_energy_guidance.py | 240 ++++++ .../tangential_classifier_free_guidance.py | 133 ++++ src/diffusers/hooks/__init__.py | 2 + src/diffusers/hooks/_common.py | 43 + src/diffusers/hooks/_helpers.py | 271 +++++++ src/diffusers/hooks/layer_skip.py | 229 ++++++ .../hooks/smoothed_energy_guidance_utils.py | 158 ++++ .../pipeline_stable_diffusion_xl_modular.py | 561 ++++++------- src/diffusers/utils/torch_utils.py | 5 + 19 files changed, 2458 insertions(+), 1075 deletions(-) delete mode 100644 src/diffusers/guider.py create mode 100644 src/diffusers/guiders/__init__.py create mode 100644 src/diffusers/guiders/adaptive_projected_guidance.py create mode 100644 src/diffusers/guiders/auto_guidance.py create mode 100644 src/diffusers/guiders/classifier_free_guidance.py create mode 100644 src/diffusers/guiders/classifier_free_zero_star_guidance.py create mode 100644 src/diffusers/guiders/entropy_rectifying_guidance.py create mode 100644 src/diffusers/guiders/guider_utils.py create mode 100644 src/diffusers/guiders/skip_layer_guidance.py create mode 100644 src/diffusers/guiders/smoothed_energy_guidance.py create mode 100644 src/diffusers/guiders/tangential_classifier_free_guidance.py create mode 100644 src/diffusers/hooks/_common.py create mode 100644 src/diffusers/hooks/_helpers.py create mode 100644 src/diffusers/hooks/layer_skip.py create mode 100644 src/diffusers/hooks/smoothed_energy_guidance_utils.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 440c67da629d..a4f55acf8b70 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -33,6 +33,7 @@ _import_structure = { "configuration_utils": ["ConfigMixin"], + "guiders": [], "hooks": [], "loaders": ["FromOriginalModelMixin"], "models": [], @@ -129,12 +130,26 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: + _import_structure["guiders"].extend( + [ + "AdaptiveProjectedGuidance", + "AutoGuidance", + "ClassifierFreeGuidance", + "ClassifierFreeZeroStarGuidance", + "SkipLayerGuidance", + "SmoothedEnergyGuidance", + "TangentialClassifierFreeGuidance", + ] + ) _import_structure["hooks"].extend( [ "FasterCacheConfig", "HookRegistry", "PyramidAttentionBroadcastConfig", + "LayerSkipConfig", + "SmoothedEnergyGuidanceConfig", "apply_faster_cache", + "apply_layer_skip", "apply_pyramid_attention_broadcast", ] ) @@ -711,10 +726,22 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: + from .guiders import ( + AdaptiveProjectedGuidance, + AutoGuidance, + ClassifierFreeGuidance, + ClassifierFreeZeroStarGuidance, + SkipLayerGuidance, + SmoothedEnergyGuidance, + TangentialClassifierFreeGuidance, + ) from .hooks import ( FasterCacheConfig, HookRegistry, + LayerSkipConfig, PyramidAttentionBroadcastConfig, + SmoothedEnergyGuidanceConfig, + apply_layer_skip, apply_faster_cache, apply_pyramid_attention_broadcast, ) diff --git a/src/diffusers/guider.py b/src/diffusers/guider.py deleted file mode 100644 index b42dca64d651..000000000000 --- a/src/diffusers/guider.py +++ /dev/null @@ -1,748 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn - -from .models.attention_processor import ( - Attention, - AttentionProcessor, - PAGCFGIdentitySelfAttnProcessor2_0, - PAGIdentitySelfAttnProcessor2_0, -) -from .utils import logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg -def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - r""" - Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on - Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf). - - Args: - noise_cfg (`torch.Tensor`): - The predicted noise tensor for the guided diffusion process. - noise_pred_text (`torch.Tensor`): - The predicted noise tensor for the text-guided diffusion process. - guidance_rescale (`float`, *optional*, defaults to 0.0): - A rescale factor applied to the noise predictions. - - Returns: - noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. - """ - std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) - std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) - # rescale the results from guidance (fixes overexposure) - noise_pred_rescaled = noise_cfg * (std_text / std_cfg) - # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - return noise_cfg - - -class CFGGuider: - """ - This class is used to guide the pipeline with CFG (Classifier-Free Guidance). - """ - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 and not self._disable_guidance - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def batch_size(self): - return self._batch_size - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): - # a flag to disable CFG, e.g. we disable it for LCM and use a guidance scale embedding instead - disable_guidance = guider_kwargs.get("disable_guidance", False) - guidance_scale = guider_kwargs.get("guidance_scale", None) - if guidance_scale is None: - raise ValueError("guidance_scale is not provided in guider_kwargs") - guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) - batch_size = guider_kwargs.get("batch_size", None) - if batch_size is None: - raise ValueError("batch_size is not provided in guider_kwargs") - self._guidance_scale = guidance_scale - self._guidance_rescale = guidance_rescale - self._batch_size = batch_size - self._disable_guidance = disable_guidance - - def reset_guider(self, pipeline): - pass - - def maybe_update_guider(self, pipeline, timestep): - pass - - def maybe_update_input(self, pipeline, cond_input): - pass - - def _maybe_split_prepared_input(self, cond): - """ - Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). - It determines whether to split the input based on its batch size relative to the expected batch size. - - Args: - cond (torch.Tensor): The conditional input tensor to process. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The negative conditional input (uncond_input) - - The positive conditional input (cond_input) - """ - if cond.shape[0] == self.batch_size * 2: - neg_cond = cond[0 : self.batch_size] - cond = cond[self.batch_size :] - return neg_cond, cond - elif cond.shape[0] == self.batch_size: - return cond, cond - else: - raise ValueError(f"Unsupported input shape: {cond.shape}") - - def _is_prepared_input(self, cond): - """ - Check if the input is already prepared for Classifier-Free Guidance (CFG). - - Args: - cond (torch.Tensor): The conditional input tensor to check. - - Returns: - bool: True if the input is already prepared, False otherwise. - """ - cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - - return cond_tensor.shape[0] == self.batch_size * 2 - - def prepare_input( - self, - cond_input: Union[torch.Tensor, List[torch.Tensor]], - negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Prepare the input for CFG. - - Args: - cond_input (Union[torch.Tensor, List[torch.Tensor]]): - The conditional input. It can be a single tensor or a - list of tensors. It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a - single tensor or a list of tensors. It must have the same length as `cond_input`. - - Returns: - Union[torch.Tensor, List[torch.Tensor]]: The prepared input. - """ - - # we check if cond_input already has CFG applied, and split if it is the case. - if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance: - return cond_input - - if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance: - if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) - else: - negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - - if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None: - raise ValueError( - "`negative_cond_input` is required when cond_input does not already contains negative conditional input" - ) - - if isinstance(cond_input, (list, tuple)): - if not self.do_classifier_free_guidance: - return cond_input - - if len(negative_cond_input) != len(cond_input): - raise ValueError("The length of negative_cond_input and cond_input must be the same.") - prepared_input = [] - for neg_cond, cond in zip(negative_cond_input, cond_input): - if neg_cond.shape[0] != cond.shape[0]: - raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - return prepared_input - - elif isinstance(cond_input, torch.Tensor): - if not self.do_classifier_free_guidance: - return cond_input - else: - return torch.cat([negative_cond_input, cond_input], dim=0) - - else: - raise ValueError(f"Unsupported input type: {type(cond_input)}") - - def apply_guidance( - self, - model_output: torch.Tensor, - timestep: int = None, - latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if not self.do_classifier_free_guidance: - return model_output - - noise_pred_uncond, noise_pred_text = model_output.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - if self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - return noise_pred - - -class PAGGuider: - """ - This class is used to guide the pipeline with CFG (Classifier-Free Guidance). - """ - - def __init__( - self, - pag_applied_layers: Union[str, List[str]], - pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = ( - PAGCFGIdentitySelfAttnProcessor2_0(), - PAGIdentitySelfAttnProcessor2_0(), - ), - ): - r""" - Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. - - Args: - pag_applied_layers (`str` or `List[str]`): - One or more strings identifying the layer names, or a simple regex for matching multiple layers, where - PAG is to be applied. A few ways of expected usage are as follows: - - Single layers specified as - "blocks.{layer_index}" - - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...] - - Multiple layers as a block name - "mid" - - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})" - pag_attn_processors: - (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), - PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention - processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second - attention processor is for PAG with CFG disabled (unconditional only). - """ - - if not isinstance(pag_applied_layers, list): - pag_applied_layers = [pag_applied_layers] - if pag_attn_processors is not None: - if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2: - raise ValueError("Expected a tuple of two attention processors") - - for i in range(len(pag_applied_layers)): - if not isinstance(pag_applied_layers[i], str): - raise ValueError( - f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}" - ) - - self.pag_applied_layers = pag_applied_layers - self._pag_attn_processors = pag_attn_processors - - def _set_pag_attn_processor(self, model, pag_applied_layers, do_classifier_free_guidance): - r""" - Set the attention processor for the PAG layers. - """ - pag_attn_processors = self._pag_attn_processors - pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] - - def is_self_attn(module: nn.Module) -> bool: - r""" - Check if the module is self-attention module based on its name. - """ - return isinstance(module, Attention) and not module.is_cross_attention - - def is_fake_integral_match(layer_id, name): - layer_id = layer_id.split(".")[-1] - name = name.split(".")[-1] - return layer_id.isnumeric() and name.isnumeric() and layer_id == name - - for layer_id in pag_applied_layers: - # for each PAG layer input, we find corresponding self-attention layers in the unet model - target_modules = [] - - for name, module in model.named_modules(): - # Identify the following simple cases: - # (1) Self Attention layer existing - # (2) Whether the module name matches pag layer id even partially - # (3) Make sure it's not a fake integral match if the layer_id ends with a number - # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" - if ( - is_self_attn(module) - and re.search(layer_id, name) is not None - and not is_fake_integral_match(layer_id, name) - ): - logger.debug(f"Applying PAG to layer: {name}") - target_modules.append(module) - - if len(target_modules) == 0: - raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}") - - for module in target_modules: - module.processor = pag_attn_proc - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 and not self._disable_guidance - - @property - def do_perturbed_attention_guidance(self): - return self._pag_scale > 0 and not self._disable_guidance - - @property - def do_pag_adaptive_scaling(self): - return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and not self._disable_guidance - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def batch_size(self): - return self._batch_size - - @property - def pag_scale(self): - return self._pag_scale - - @property - def pag_adaptive_scale(self): - return self._pag_adaptive_scale - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): - pag_scale = guider_kwargs.get("pag_scale", 3.0) - pag_adaptive_scale = guider_kwargs.get("pag_adaptive_scale", 0.0) - - batch_size = guider_kwargs.get("batch_size", None) - if batch_size is None: - raise ValueError("batch_size is a required argument for PAGGuider") - - guidance_scale = guider_kwargs.get("guidance_scale", None) - guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) - disable_guidance = guider_kwargs.get("disable_guidance", False) - - if guidance_scale is None: - raise ValueError("guidance_scale is a required argument for PAGGuider") - - self._pag_scale = pag_scale - self._pag_adaptive_scale = pag_adaptive_scale - self._guidance_scale = guidance_scale - self._disable_guidance = disable_guidance - self._guidance_rescale = guidance_rescale - self._batch_size = batch_size - if not hasattr(pipeline, "original_attn_proc") or pipeline.original_attn_proc is None: - pipeline.original_attn_proc = pipeline.unet.attn_processors - self._set_pag_attn_processor( - model=pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer, - pag_applied_layers=self.pag_applied_layers, - do_classifier_free_guidance=self.do_classifier_free_guidance, - ) - - def reset_guider(self, pipeline): - if ( - self.do_perturbed_attention_guidance - and hasattr(pipeline, "original_attn_proc") - and pipeline.original_attn_proc is not None - ): - pipeline.unet.set_attn_processor(pipeline.original_attn_proc) - pipeline.original_attn_proc = None - - def maybe_update_guider(self, pipeline, timestep): - pass - - def maybe_update_input(self, pipeline, cond_input): - pass - - def _is_prepared_input(self, cond): - """ - Check if the input is already prepared for Perturbed Attention Guidance (PAG). - - Args: - cond (torch.Tensor): The conditional input tensor to check. - - Returns: - bool: True if the input is already prepared, False otherwise. - """ - cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - - return cond_tensor.shape[0] == self.batch_size * 3 - - def _maybe_split_prepared_input(self, cond): - """ - Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). - It determines whether to split the input based on its batch size relative to the expected batch size. - - Args: - cond (torch.Tensor): The conditional input tensor to process. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The negative conditional input (uncond_input) - - The positive conditional input (cond_input) - """ - if cond.shape[0] == self.batch_size * 3: - neg_cond = cond[0 : self.batch_size] - cond = cond[self.batch_size : self.batch_size * 2] - return neg_cond, cond - elif cond.shape[0] == self.batch_size: - return cond, cond - else: - raise ValueError(f"Unsupported input shape: {cond.shape}") - - def prepare_input( - self, - cond_input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]], - negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: - """ - Prepare the input for CFG. - - Args: - cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): - The conditional input. It can be a single tensor or a - list of tensors. It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): - The negative conditional input. It can be a single tensor or a list of tensors. It must have the same - length as `cond_input`. - - Returns: - Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: The prepared input. - """ - - # we check if cond_input already has CFG applied, and split if it is the case. - - if self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance: - return cond_input - - if self._is_prepared_input(cond_input) and not self.do_perturbed_attention_guidance: - if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) - else: - negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - - if not self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance and negative_cond_input is None: - raise ValueError( - "`negative_cond_input` is required when cond_input does not already contains negative conditional input" - ) - - if isinstance(cond_input, (list, tuple)): - if not self.do_perturbed_attention_guidance: - return cond_input - - if len(negative_cond_input) != len(cond_input): - raise ValueError("The length of negative_cond_input and cond_input must be the same.") - - prepared_input = [] - for neg_cond, cond in zip(negative_cond_input, cond_input): - if neg_cond.shape[0] != cond.shape[0]: - raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - - cond = torch.cat([cond] * 2, dim=0) - if self.do_classifier_free_guidance: - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - else: - prepared_input.append(cond) - - return prepared_input - - elif isinstance(cond_input, torch.Tensor): - if not self.do_perturbed_attention_guidance: - return cond_input - - cond_input = torch.cat([cond_input] * 2, dim=0) - if self.do_classifier_free_guidance: - return torch.cat([negative_cond_input, cond_input], dim=0) - else: - return cond_input - - else: - raise ValueError(f"Unsupported input type: {type(negative_cond_input)} and {type(cond_input)}") - - def apply_guidance( - self, - model_output: torch.Tensor, - timestep: int, - latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if not self.do_perturbed_attention_guidance: - return model_output - - if self.do_pag_adaptive_scaling: - pag_scale = max(self._pag_scale - self._pag_adaptive_scale * (1000 - timestep), 0) - else: - pag_scale = self._pag_scale - - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text, noise_pred_perturb = model_output.chunk(3) - noise_pred = ( - noise_pred_uncond - + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - + pag_scale * (noise_pred_text - noise_pred_perturb) - ) - else: - noise_pred_text, noise_pred_perturb = model_output.chunk(2) - noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) - - if self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - - return noise_pred - - -class MomentumBuffer: - def __init__(self, momentum: float): - self.momentum = momentum - self.running_average = 0 - - def update(self, update_value: torch.Tensor): - new_average = self.momentum * self.running_average - self.running_average = update_value + new_average - - -class APGGuider: - """ - This class is used to guide the pipeline with APG (Adaptive Projected Guidance). - """ - - def normalized_guidance( - self, - pred_cond: torch.Tensor, - pred_uncond: torch.Tensor, - guidance_scale: float, - momentum_buffer: MomentumBuffer = None, - norm_threshold: float = 0.0, - eta: float = 1.0, - ): - """ - Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion - Models](https://arxiv.org/pdf/2410.02416) - """ - diff = pred_cond - pred_uncond - if momentum_buffer is not None: - momentum_buffer.update(diff) - diff = momentum_buffer.running_average - if norm_threshold > 0: - ones = torch.ones_like(diff) - diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True) - scale_factor = torch.minimum(ones, norm_threshold / diff_norm) - diff = diff * scale_factor - v0, v1 = diff.double(), pred_cond.double() - v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3]) - v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1 - v0_orthogonal = v0 - v0_parallel - diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype) - normalized_update = diff_orthogonal + eta * diff_parallel - pred_guided = pred_cond + (guidance_scale - 1) * normalized_update - return pred_guided - - @property - def adaptive_projected_guidance_momentum(self): - return self._adaptive_projected_guidance_momentum - - @property - def adaptive_projected_guidance_rescale_factor(self): - return self._adaptive_projected_guidance_rescale_factor - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 and not self._disable_guidance - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def batch_size(self): - return self._batch_size - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): - disable_guidance = guider_kwargs.get("disable_guidance", False) - guidance_scale = guider_kwargs.get("guidance_scale", None) - if guidance_scale is None: - raise ValueError("guidance_scale is not provided in guider_kwargs") - adaptive_projected_guidance_momentum = guider_kwargs.get("adaptive_projected_guidance_momentum", None) - adaptive_projected_guidance_rescale_factor = guider_kwargs.get( - "adaptive_projected_guidance_rescale_factor", 15.0 - ) - guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) - batch_size = guider_kwargs.get("batch_size", None) - if batch_size is None: - raise ValueError("batch_size is not provided in guider_kwargs") - self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum - self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor - self._guidance_scale = guidance_scale - self._guidance_rescale = guidance_rescale - self._batch_size = batch_size - self._disable_guidance = disable_guidance - if adaptive_projected_guidance_momentum is not None: - self.momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum) - else: - self.momentum_buffer = None - self.scheduler = pipeline.scheduler - - def reset_guider(self, pipeline): - pass - - def maybe_update_guider(self, pipeline, timestep): - pass - - def maybe_update_input(self, pipeline, cond_input): - pass - - def _maybe_split_prepared_input(self, cond): - """ - Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). - It determines whether to split the input based on its batch size relative to the expected batch size. - - Args: - cond (torch.Tensor): The conditional input tensor to process. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The negative conditional input (uncond_input) - - The positive conditional input (cond_input) - """ - if cond.shape[0] == self.batch_size * 2: - neg_cond = cond[0 : self.batch_size] - cond = cond[self.batch_size :] - return neg_cond, cond - elif cond.shape[0] == self.batch_size: - return cond, cond - else: - raise ValueError(f"Unsupported input shape: {cond.shape}") - - def _is_prepared_input(self, cond): - """ - Check if the input is already prepared for Classifier-Free Guidance (CFG). - - Args: - cond (torch.Tensor): The conditional input tensor to check. - - Returns: - bool: True if the input is already prepared, False otherwise. - """ - cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - - return cond_tensor.shape[0] == self.batch_size * 2 - - def prepare_input( - self, - cond_input: Union[torch.Tensor, List[torch.Tensor]], - negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Prepare the input for CFG. - - Args: - cond_input (Union[torch.Tensor, List[torch.Tensor]]): - The conditional input. It can be a single tensor or a - list of tensors. It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a - single tensor or a list of tensors. It must have the same length as `cond_input`. - - Returns: - Union[torch.Tensor, List[torch.Tensor]]: The prepared input. - """ - - # we check if cond_input already has CFG applied, and split if it is the case. - if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance: - return cond_input - - if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance: - if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) - else: - negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - - if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None: - raise ValueError( - "`negative_cond_input` is required when cond_input does not already contains negative conditional input" - ) - - if isinstance(cond_input, (list, tuple)): - if not self.do_classifier_free_guidance: - return cond_input - - if len(negative_cond_input) != len(cond_input): - raise ValueError("The length of negative_cond_input and cond_input must be the same.") - prepared_input = [] - for neg_cond, cond in zip(negative_cond_input, cond_input): - if neg_cond.shape[0] != cond.shape[0]: - raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - return prepared_input - - elif isinstance(cond_input, torch.Tensor): - if not self.do_classifier_free_guidance: - return cond_input - else: - return torch.cat([negative_cond_input, cond_input], dim=0) - - else: - raise ValueError(f"Unsupported input type: {type(cond_input)}") - - def apply_guidance( - self, - model_output: torch.Tensor, - timestep: int = None, - latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if not self.do_classifier_free_guidance: - return model_output - - if latents is None: - raise ValueError("APG requires `latents` to convert model output to denoised prediction (x0).") - - sigma = self.scheduler.sigmas[self.scheduler.step_index] - noise_pred = latents - sigma * model_output - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = self.normalized_guidance( - noise_pred_text, - noise_pred_uncond, - self.guidance_scale, - self.momentum_buffer, - self.adaptive_projected_guidance_rescale_factor, - ) - noise_pred = (latents - noise_pred) / sigma - - if self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - return noise_pred - - -Guiders = Union[CFGGuider, PAGGuider, APGGuider] \ No newline at end of file diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py new file mode 100644 index 000000000000..3c1ee293382d --- /dev/null +++ b/src/diffusers/guiders/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +from ..utils import is_torch_available + + +if is_torch_available(): + from .adaptive_projected_guidance import AdaptiveProjectedGuidance + from .auto_guidance import AutoGuidance + from .classifier_free_guidance import ClassifierFreeGuidance + from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance + from .skip_layer_guidance import SkipLayerGuidance + from .smoothed_energy_guidance import SmoothedEnergyGuidance + from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance + + GuiderType = Union[AdaptiveProjectedGuidance, AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance] diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py new file mode 100644 index 000000000000..7da1cc59a365 --- /dev/null +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -0,0 +1,180 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, List, TYPE_CHECKING + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..pipelines.modular_pipeline import BlockState + + +class AdaptiveProjectedGuidance(BaseGuidance): + """ + Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + adaptive_projected_guidance_momentum (`float`, defaults to `None`): + The momentum parameter for the adaptive projected guidance. Disabled if set to `None`. + adaptive_projected_guidance_rescale (`float`, defaults to `15.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + adaptive_projected_guidance_momentum: Optional[float] = None, + adaptive_projected_guidance_rescale: float = 15.0, + eta: float = 1.0, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum + self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale + self.eta = eta + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + self.momentum_buffer = None + + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + if self._step == 0: + if self.adaptive_projected_guidance_momentum is not None: + self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_apg_enabled(): + pred = pred_cond + else: + pred = normalized_guidance( + pred_cond, + pred_uncond, + self.guidance_scale, + self.momentum_buffer, + self.eta, + self.adaptive_projected_guidance_rescale, + self.use_original_formulation, + ) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_apg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_apg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +class MomentumBuffer: + def __init__(self, momentum: float): + self.momentum = momentum + self.running_average = 0 + + def update(self, update_value: torch.Tensor): + new_average = self.momentum * self.running_average + self.running_average = update_value + new_average + + +def normalized_guidance( + pred_cond: torch.Tensor, + pred_uncond: torch.Tensor, + guidance_scale: float, + momentum_buffer: Optional[MomentumBuffer] = None, + eta: float = 1.0, + norm_threshold: float = 0.0, + use_original_formulation: bool = False, +): + diff = pred_cond - pred_uncond + dim = [-i for i in range(1, len(diff.shape))] + + if momentum_buffer is not None: + momentum_buffer.update(diff) + diff = momentum_buffer.running_average + + if norm_threshold > 0: + ones = torch.ones_like(diff) + diff_norm = diff.norm(p=2, dim=dim, keepdim=True) + scale_factor = torch.minimum(ones, norm_threshold / diff_norm) + diff = diff * scale_factor + + v0, v1 = diff.double(), pred_cond.double() + v1 = torch.nn.functional.normalize(v1, dim=dim) + v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 + v0_orthogonal = v0 - v0_parallel + diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) + normalized_update = diff_orthogonal + eta * diff_parallel + + pred = pred_cond if use_original_formulation else pred_uncond + pred = pred + guidance_scale * normalized_update + + return pred diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py new file mode 100644 index 000000000000..bfffb9f39cd2 --- /dev/null +++ b/src/diffusers/guiders/auto_guidance.py @@ -0,0 +1,173 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Union, TYPE_CHECKING + +import torch + +from ..hooks import HookRegistry, LayerSkipConfig +from ..hooks.layer_skip import _apply_layer_skip_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..pipelines.modular_pipeline import BlockState + + +class AutoGuidance(BaseGuidance): + """ + AutoGuidance: https://huggingface.co/papers/2406.02507 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + auto_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not + provided, `skip_layer_config` must be provided. + auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): + The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of + `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided. + dropout (`float`, *optional*): + The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or + `auto_guidance_config`). If not provided, the dropout probability will be set to 1.0. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + auto_guidance_layers: Optional[Union[int, List[int]]] = None, + auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, + dropout: Optional[float] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.auto_guidance_layers = auto_guidance_layers + self.auto_guidance_config = auto_guidance_config + self.dropout = dropout + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if auto_guidance_layers is None and auto_guidance_config is None: + raise ValueError( + "Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable Skip Layer Guidance." + ) + if auto_guidance_layers is not None and auto_guidance_config is not None: + raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.") + if (dropout is None and auto_guidance_layers is not None) or (dropout is not None and auto_guidance_layers is None): + raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.") + + if auto_guidance_layers is not None: + if isinstance(auto_guidance_layers, int): + auto_guidance_layers = [auto_guidance_layers] + if not isinstance(auto_guidance_layers, list): + raise ValueError( + f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}." + ) + auto_guidance_config = [LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers] + + if isinstance(auto_guidance_config, LayerSkipConfig): + auto_guidance_config = [auto_guidance_config] + + if not isinstance(auto_guidance_config, list): + raise ValueError( + f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}." + ) + + self.auto_guidance_config = auto_guidance_config + self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))] + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + self._count_prepared += 1 + if self._is_ag_enabled() and self.is_unconditional: + for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config): + _apply_layer_skip_hook(denoiser, config, name=name) + + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + if self._is_ag_enabled() and self.is_unconditional: + for name in self._auto_guidance_hook_names: + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + registry.remove_hook(name, recurse=True) + + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_ag_enabled(): + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_ag_enabled(): + num_conditions += 1 + return num_conditions + + def _is_ag_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py new file mode 100644 index 000000000000..429f8450410a --- /dev/null +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -0,0 +1,128 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, List, TYPE_CHECKING + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..pipelines.modular_pipeline import BlockState + + +class ClassifierFreeGuidance(BaseGuidance): + """ + Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598 + + CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by + jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during + inference. This allows the model to tradeoff between generation quality and sample diversity. + The original paper proposes scaling and shifting the conditional distribution based on the difference between + conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)] + + Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen + paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in + theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] + + The intution behind the original formulation can be thought of as moving the conditional distribution estimates + further away from the unconditional distribution estimates, while the diffusers-native implementation can be + thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of + the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.) + + The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the + paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False, start: float = 0.0, stop: float = 1.0 + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled(): + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py new file mode 100644 index 000000000000..4c9839ee78f3 --- /dev/null +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -0,0 +1,144 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, List, TYPE_CHECKING + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..pipelines.modular_pipeline import BlockState + + +class ClassifierFreeZeroStarGuidance(BaseGuidance): + """ + Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886 + + This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free + guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion + process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the + quality of generated images. + + The authors of the paper suggest setting zero initialization in the first 4% of the inference steps. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + zero_init_steps (`int`, defaults to `1`): + The number of inference steps for which the noise predictions are zeroed out (see Section 4.2). + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + zero_init_steps: int = 1, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.zero_init_steps = zero_init_steps + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if self._step < self.zero_init_steps: + pred = torch.zeros_like(pred_cond) + elif not self._is_cfg_enabled(): + pred = pred_cond + else: + pred_cond_flat = pred_cond.flatten(1) + pred_uncond_flat = pred_uncond.flatten(1) + alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat) + alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1)) + pred_uncond = pred_uncond * alpha + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + cond_dtype = cond.dtype + cond = cond.float() + uncond = uncond.float() + dot_product = torch.sum(cond * uncond, dim=1, keepdim=True) + squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + scale = dot_product / squared_norm + return scale.to(dtype=cond_dtype) diff --git a/src/diffusers/guiders/entropy_rectifying_guidance.py b/src/diffusers/guiders/entropy_rectifying_guidance.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py new file mode 100644 index 000000000000..7d005442e89c --- /dev/null +++ b/src/diffusers/guiders/guider_utils.py @@ -0,0 +1,215 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union + +import torch + +from ..utils import get_logger + + +if TYPE_CHECKING: + from ..pipelines.modular_pipeline import BlockState + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class BaseGuidance: + r"""Base class providing the skeleton for implementing guidance techniques.""" + + _input_predictions = None + _identifier_key = "__guidance_identifier__" + + def __init__(self, start: float = 0.0, stop: float = 1.0): + self._start = start + self._stop = stop + self._step: int = None + self._num_inference_steps: int = None + self._timestep: torch.LongTensor = None + self._count_prepared = 0 + self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None + self._enabled = True + + if not (0.0 <= start < 1.0): + raise ValueError( + f"Expected `start` to be between 0.0 and 1.0, but got {start}." + ) + if not (start <= stop <= 1.0): + raise ValueError( + f"Expected `stop` to be between {start} and 1.0, but got {stop}." + ) + + if self._input_predictions is None or not isinstance(self._input_predictions, list): + raise ValueError( + "`_input_predictions` must be a list of required prediction names for the guidance technique." + ) + + def disable(self): + self._enabled = False + + def enable(self): + self._enabled = True + + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: + self._step = step + self._num_inference_steps = num_inference_steps + self._timestep = timestep + self._count_prepared = 0 + + def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None: + """ + Set the input fields for the guidance technique. The input fields are used to specify the names of the + returned attributes containing the prepared data after `prepare_inputs` is called. The prepared data is + obtained from the values of the provided keyword arguments to this method. + + Args: + **kwargs (`Dict[str, Union[str, Tuple[str, str]]]`): + A dictionary where the keys are the names of the fields that will be used to store the data once + it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, + which is used to look up the required data provided for preparation. + + If a string is provided, it will be used as the conditional data (or unconditional if used with + a guidance method that requires it). If a tuple of length 2 is provided, the first element must + be the conditional data identifier and the second element must be the unconditional data identifier + or None. + + Example: + + ``` + data = {"prompt_embeds": , "negative_prompt_embeds": , "latents": } + + BaseGuidance.set_input_fields( + latents="latents", + prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), + ) + ``` + """ + for key, value in kwargs.items(): + is_string = isinstance(value, str) + is_tuple_of_str_with_len_2 = isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value) + if not (is_string or is_tuple_of_str_with_len_2): + raise ValueError( + f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}." + ) + self._input_fields = kwargs + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + """ + Prepares the models for the guidance technique on a given batch of data. This method should be overridden in + subclasses to implement specific model preparation logic. + """ + self._count_prepared += 1 + + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + """ + Cleans up the models for the guidance technique after a given batch of data. This method should be overridden in + subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful + modifications made during `prepare_models`. + """ + pass + + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.") + + def __call__(self, data: List["BlockState"]) -> Any: + if not all(hasattr(d, "noise_pred") for d in data): + raise ValueError("Expected all data to have `noise_pred` attribute.") + if len(data) != self.num_conditions: + raise ValueError( + f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data." + ) + forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data} + return self.forward(**forward_inputs) + + def forward(self, *args, **kwargs) -> Any: + raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.") + + @property + def is_conditional(self) -> bool: + raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.") + + @property + def is_unconditional(self) -> bool: + return not self.is_conditional + + @property + def num_conditions(self) -> int: + raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.") + + @classmethod + def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState": + """ + Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of + the `BaseGuidance` class. It prepares the batch based on the provided tuple index. + + Args: + input_fields (`Dict[str, Union[str, Tuple[str, str]]]`): + A dictionary where the keys are the names of the fields that will be used to store the data once + it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, + which is used to look up the required data provided for preparation. + If a string is provided, it will be used as the conditional data (or unconditional if used with + a guidance method that requires it). If a tuple of length 2 is provided, the first element must + be the conditional data identifier and the second element must be the unconditional data identifier + or None. + data (`BlockState`): + The input data to be prepared. + tuple_index (`int`): + The index to use when accessing input fields that are tuples. + + Returns: + `BlockState`: The prepared batch of data. + """ + from ..pipelines.modular_pipeline import BlockState + + if input_fields is None: + raise ValueError("Input fields have not been set. Please call `set_input_fields` before preparing inputs.") + data_batch = {} + for key, value in input_fields.items(): + try: + if isinstance(value, str): + data_batch[key] = getattr(data, value) + elif isinstance(value, tuple): + data_batch[key] = getattr(data, value[tuple_index]) + else: + # We've already checked that value is a string or a tuple of strings with length 2 + pass + except AttributeError: + raise ValueError(f"Expected `data` to have attribute(s) {value}, but it does not. Please check the input data.") + data_batch[cls._identifier_key] = identifier + return BlockState(**data_batch) + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py new file mode 100644 index 000000000000..bdd9e4af81b6 --- /dev/null +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -0,0 +1,247 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Union, TYPE_CHECKING + +import torch + +from ..hooks import HookRegistry, LayerSkipConfig +from ..hooks.layer_skip import _apply_layer_skip_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..pipelines.modular_pipeline import BlockState + + +class SkipLayerGuidance(BaseGuidance): + """ + Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 + + Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664 + + SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by + skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional + batch of data, apart from the conditional and unconditional batches already used in CFG + ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions + based on the difference between conditional without skipping and conditional with skipping predictions. + + The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from + worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse + version of the model for the conditional prediction). + + STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving + generation quality in video diffusion models. + + Additional reading: + - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507) + + The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are + defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + skip_layer_guidance_scale (`float`, defaults to `2.8`): + The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher + values, but it may also lead to overexposure and saturation. + skip_layer_guidance_start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which skip layer guidance starts. + skip_layer_guidance_stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which skip layer guidance stops. + skip_layer_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not + provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion + 3.5 Medium. + skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): + The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of + `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + + def __init__( + self, + guidance_scale: float = 7.5, + skip_layer_guidance_scale: float = 2.8, + skip_layer_guidance_start: float = 0.01, + skip_layer_guidance_stop: float = 0.2, + skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None, + skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.skip_layer_guidance_scale = skip_layer_guidance_scale + self.skip_layer_guidance_start = skip_layer_guidance_start + self.skip_layer_guidance_stop = skip_layer_guidance_stop + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if not (0.0 <= skip_layer_guidance_start < 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}." + ) + if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}." + ) + + if skip_layer_guidance_layers is None and skip_layer_config is None: + raise ValueError( + "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance." + ) + if skip_layer_guidance_layers is not None and skip_layer_config is not None: + raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.") + + if skip_layer_guidance_layers is not None: + if isinstance(skip_layer_guidance_layers, int): + skip_layer_guidance_layers = [skip_layer_guidance_layers] + if not isinstance(skip_layer_guidance_layers, list): + raise ValueError( + f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}." + ) + skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers] + + if isinstance(skip_layer_config, LayerSkipConfig): + skip_layer_config = [skip_layer_config] + + if not isinstance(skip_layer_config, list): + raise ValueError( + f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}." + ) + + self.skip_layer_config = skip_layer_config + self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))] + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + self._count_prepared += 1 + if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: + for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): + _apply_layer_skip_hook(denoiser, config, name=name) + + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._skip_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"] + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_skip: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_slg_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_cond_skip + pred = pred + self.skip_layer_guidance_scale * shift + elif not self._is_slg_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + shift = pred_cond - pred_uncond + shift_skip = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 or self._count_prepared == 3 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_slg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + def _is_slg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + + is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) + + return is_within_range and not is_zero diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py new file mode 100644 index 000000000000..1c7ee45dc3db --- /dev/null +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -0,0 +1,240 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Union, TYPE_CHECKING + +import torch + +from ..hooks import HookRegistry +from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..pipelines.modular_pipeline import BlockState + + +class SmoothedEnergyGuidance(BaseGuidance): + """ + Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760 + + SEG is only supported as an experimental prototype feature for now, so the implementation may be modified + in the future without warning or guarantee of reproducibility. This implementation assumes: + - Generated images are square (height == width) + - The model does not combine different modalities together (e.g., text and image latent streams are + not combined together such as Flux) + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + seg_guidance_scale (`float`, defaults to `3.0`): + The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher + values, but it may also lead to overexposure and saturation. + seg_blur_sigma (`float`, defaults to `9999999.0`): + The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in + infinite blur, which means uniform queries. Controlling it exponentially is empirically effective. + seg_blur_threshold_inf (`float`, defaults to `9999.0`): + The threshold above which the blur is considered infinite. + seg_guidance_start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which smoothed energy guidance starts. + seg_guidance_stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which smoothed energy guidance stops. + seg_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If not + provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion + 3.5 Medium. + seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*): + The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or a list of + `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] + + def __init__( + self, + guidance_scale: float = 7.5, + seg_guidance_scale: float = 2.8, + seg_blur_sigma: float = 9999999.0, + seg_blur_threshold_inf: float = 9999.0, + seg_guidance_start: float = 0.0, + seg_guidance_stop: float = 1.0, + seg_guidance_layers: Optional[Union[int, List[int]]] = None, + seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.seg_guidance_scale = seg_guidance_scale + self.seg_blur_sigma = seg_blur_sigma + self.seg_blur_threshold_inf = seg_blur_threshold_inf + self.seg_guidance_start = seg_guidance_start + self.seg_guidance_stop = seg_guidance_stop + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if not (0.0 <= seg_guidance_start < 1.0): + raise ValueError( + f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}." + ) + if not (seg_guidance_start <= seg_guidance_stop <= 1.0): + raise ValueError( + f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}." + ) + + if seg_guidance_layers is None and seg_guidance_config is None: + raise ValueError( + "Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance." + ) + if seg_guidance_layers is not None and seg_guidance_config is not None: + raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.") + + if seg_guidance_layers is not None: + if isinstance(seg_guidance_layers, int): + seg_guidance_layers = [seg_guidance_layers] + if not isinstance(seg_guidance_layers, list): + raise ValueError( + f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}." + ) + seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers] + + if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig): + seg_guidance_config = [seg_guidance_config] + + if not isinstance(seg_guidance_config, list): + raise ValueError( + f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}." + ) + + self.seg_guidance_config = seg_guidance_config + self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))] + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1: + for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config): + _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name) + + def cleanup_models(self, denoiser: torch.nn.Module): + if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1: + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._seg_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"] + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_seg: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_seg_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_cond_seg + pred = pred_cond if self.use_original_formulation else pred_cond_seg + pred = pred + self.seg_guidance_scale * shift + elif not self._is_seg_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + shift = pred_cond - pred_uncond + shift_seg = pred_cond - pred_cond_seg + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 or self._count_prepared == 3 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_seg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + def _is_seg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self.seg_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + + is_zero = math.isclose(self.seg_guidance_scale, 0.0) + + return is_within_range and not is_zero diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py new file mode 100644 index 000000000000..631f9a5f33b2 --- /dev/null +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -0,0 +1,133 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, List, TYPE_CHECKING + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..pipelines.modular_pipeline import BlockState + + +class TangentialClassifierFreeGuidance(BaseGuidance): + """ + Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_tcfg_enabled(): + pred = pred_cond + else: + pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_tcfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_tcfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor: + cond_dtype = pred_cond.dtype + preds = torch.stack([pred_cond, pred_uncond], dim=1).float() + preds = preds.flatten(2) + U, S, Vh = torch.linalg.svd(preds, full_matrices=False) + Vh_modified = Vh.clone() + Vh_modified[:, 1] = 0 + + uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float() + x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1)) + x_Vh_V = torch.matmul(x_Vh, Vh_modified) + pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype) + + pred = pred_cond if use_original_formulation else pred_uncond + shift = pred_cond - pred_uncond + pred = pred + guidance_scale * shift + + return pred diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 764ceb25b465..9d0e96e9e79e 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -5,5 +5,7 @@ from .faster_cache import FasterCacheConfig, apply_faster_cache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook + from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py new file mode 100644 index 000000000000..3d9c99e8189f --- /dev/null +++ b/src/diffusers/hooks/_common.py @@ -0,0 +1,43 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch + +from ..models.attention import FeedForward, LuminaFeedForward +from ..models.attention_processor import Attention, MochiAttention + + +_ATTENTION_CLASSES = (Attention, MochiAttention) +_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward) + +_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers") +_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) +_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers") + +_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple( + { + *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS, + } +) + + +def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]: + for submodule_name, submodule in module.named_modules(): + if submodule_name == fqn: + return submodule + return None diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py new file mode 100644 index 000000000000..9043ffc41838 --- /dev/null +++ b/src/diffusers/hooks/_helpers.py @@ -0,0 +1,271 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Callable, Type + +from ..models.attention import BasicTransformerBlock +from ..models.attention_processor import AttnProcessor2_0 +from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock +from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor, CogView4TransformerBlock +from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock +from ..models.transformers.transformer_hunyuan_video import ( + HunyuanVideoSingleTransformerBlock, + HunyuanVideoTokenReplaceSingleTransformerBlock, + HunyuanVideoTokenReplaceTransformerBlock, + HunyuanVideoTransformerBlock, +) +from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock +from ..models.transformers.transformer_mochi import MochiTransformerBlock +from ..models.transformers.transformer_wan import WanTransformerBlock + + +@dataclass +class AttentionProcessorMetadata: + skip_processor_output_fn: Callable[[Any], Any] + + +@dataclass +class TransformerBlockMetadata: + skip_block_output_fn: Callable[[Any], Any] + return_hidden_states_index: int = None + return_encoder_hidden_states_index: int = None + + +class AttentionProcessorRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: AttentionProcessorMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> AttentionProcessorMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + +class TransformerBlockRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: TransformerBlockMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> TransformerBlockMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + +def _register_attention_processors_metadata(): + # AttnProcessor2_0 + AttentionProcessorRegistry.register( + model_class=AttnProcessor2_0, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0, + ), + ) + + # CogView4AttnProcessor + AttentionProcessorRegistry.register( + model_class=CogView4AttnProcessor, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor, + ), + ) + + +def _register_transformer_blocks_metadata(): + # BasicTransformerBlock + TransformerBlockRegistry.register( + model_class=BasicTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_BasicTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # CogVideoX + TransformerBlockRegistry.register( + model_class=CogVideoXBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # CogView4 + TransformerBlockRegistry.register( + model_class=CogView4TransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Flux + TransformerBlockRegistry.register( + model_class=FluxTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock, + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + TransformerBlockRegistry.register( + model_class=FluxSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock, + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + + # HunyuanVideo + TransformerBlockRegistry.register( + model_class=HunyuanVideoTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoTokenReplaceTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoTokenReplaceSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # LTXVideo + TransformerBlockRegistry.register( + model_class=LTXVideoTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # Mochi + TransformerBlockRegistry.register( + model_class=MochiTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Wan + TransformerBlockRegistry.register( + model_class=WanTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + +# fmt: off +def _skip_attention___ret___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + return hidden_states + + +def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return hidden_states, encoder_hidden_states + + +_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states +_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states + + +def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + return hidden_states + + +def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return hidden_states, encoder_hidden_states + + +def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return encoder_hidden_states, hidden_states + + +_skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states +_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states +_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +# fmt: on + + +_register_attention_processors_metadata() +_register_transformer_blocks_metadata() diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py new file mode 100644 index 000000000000..c50d2b7471e4 --- /dev/null +++ b/src/diffusers/hooks/layer_skip.py @@ -0,0 +1,229 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Callable, List, Optional + +import torch + +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn +from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_LAYER_SKIP_HOOK = "layer_skip_hook" + + +@dataclass +class LayerSkipConfig: + r""" + Configuration for skipping internal transformer blocks when executing a transformer model. + + Args: + indices (`List[int]`): + The indices of the layer to skip. This is typically the first layer in the transformer block. + fqn (`str`, defaults to `"auto"`): + The fully qualified name identifying the stack of transformer blocks. Typically, this is + `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`. + For automatic detection, set this to `"auto"`. + "auto" only works on DiT models. For UNet models, you must provide the correct fqn. + skip_attention (`bool`, defaults to `True`): + Whether to skip attention blocks. + skip_ff (`bool`, defaults to `True`): + Whether to skip feed-forward blocks. + skip_attention_scores (`bool`, defaults to `False`): + Whether to skip attention score computation in the attention blocks. This is equivalent to using `value` + projections as the output of scaled dot product attention. + dropout (`float`, defaults to `1.0`): + The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`, + meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the + skipped layers are fully retained, which is equivalent to not skipping any layers. + """ + + indices: List[int] + fqn: str = "auto" + skip_attention: bool = True + skip_attention_scores: bool = False + skip_ff: bool = True + dropout: float = 1.0 + + def __post_init__(self): + if not (0 <= self.dropout <= 1): + raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.") + if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores: + raise ValueError( + "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." + ) + + +class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode): + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func is torch.nn.functional.scaled_dot_product_attention: + value = kwargs.get("value", None) + if value is None: + value = args[2] + return value + return func(*args, **kwargs) + + +class AttentionProcessorSkipHook(ModelHook): + def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0): + self.skip_processor_output_fn = skip_processor_output_fn + self.skip_attention_scores = skip_attention_scores + self.dropout = dropout + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.skip_attention_scores: + if not math.isclose(self.dropout, 1.0): + raise ValueError( + "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." + ) + with AttentionScoreSkipFunctionMode(): + output = self.fn_ref.original_forward(*args, **kwargs) + else: + if math.isclose(self.dropout, 1.0): + output = self.skip_processor_output_fn(module, *args, **kwargs) + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) + return output + + +class FeedForwardSkipHook(ModelHook): + def __init__(self, dropout: float): + super().__init__() + self.dropout = dropout + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if math.isclose(self.dropout, 1.0): + output = kwargs.get("hidden_states", None) + if output is None: + output = kwargs.get("x", None) + if output is None and len(args) > 0: + output = args[0] + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) + return output + + +class TransformerBlockSkipHook(ModelHook): + def __init__(self, dropout: float): + super().__init__() + self.dropout = dropout + + def initialize_hook(self, module): + self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if math.isclose(self.dropout, 1.0): + output = self._metadata.skip_block_output_fn(module, *args, **kwargs) + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) + return output + +def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None: + r""" + Apply layer skipping to internal layers of a transformer. + + Args: + module (`torch.nn.Module`): + The transformer model to which the layer skip hook should be applied. + config (`LayerSkipConfig`): + The configuration for the layer skip hook. + + Example: + + ```python + >>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig + >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks") + >>> apply_layer_skip_hook(transformer, config) + ``` + """ + _apply_layer_skip_hook(module, config) + + +def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None: + name = name or _LAYER_SKIP_HOOK + + if config.skip_attention and config.skip_attention_scores: + raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.") + if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores: + raise ValueError("Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0.") + + if config.fqn == "auto": + for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: + if hasattr(module, identifier): + config.fqn = identifier + break + else: + raise ValueError( + "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " + "`fqn` (fully qualified name) that identifies a stack of transformer blocks." + ) + + transformer_blocks = _get_submodule_from_fqn(module, config.fqn) + if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList): + raise ValueError( + f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify " + f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks." + ) + if len(config.indices) == 0: + raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.") + + blocks_found = False + for i, block in enumerate(transformer_blocks): + if i not in config.indices: + continue + + blocks_found = True + + if config.skip_attention and config.skip_ff: + logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'") + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = TransformerBlockSkipHook(config.dropout) + registry.register_hook(hook, name) + + elif config.skip_attention or config.skip_attention_scores: + for submodule_name, submodule in block.named_modules(): + if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention: + logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'") + output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn + registry = HookRegistry.check_if_exists_or_initialize(submodule) + hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout) + registry.register_hook(hook, name) + + if config.skip_ff: + for submodule_name, submodule in block.named_modules(): + if isinstance(submodule, _FEEDFORWARD_CLASSES): + logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'") + registry = HookRegistry.check_if_exists_or_initialize(submodule) + hook = FeedForwardSkipHook(config.dropout) + registry.register_hook(hook, name) + + if not blocks_found: + raise ValueError( + f"Could not find any transformer blocks matching the provided indices {config.indices} and " + f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." + ) diff --git a/src/diffusers/hooks/smoothed_energy_guidance_utils.py b/src/diffusers/hooks/smoothed_energy_guidance_utils.py new file mode 100644 index 000000000000..f0366e29887f --- /dev/null +++ b/src/diffusers/hooks/smoothed_energy_guidance_utils.py @@ -0,0 +1,158 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F + +from ..utils import get_logger +from ._common import _ATTENTION_CLASSES, _get_submodule_from_fqn +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook" + + +@dataclass +class SmoothedEnergyGuidanceConfig: + r""" + Configuration for skipping internal transformer blocks when executing a transformer model. + + Args: + indices (`List[int]`): + The indices of the layer to skip. This is typically the first layer in the transformer block. + fqn (`str`, defaults to `"auto"`): + The fully qualified name identifying the stack of transformer blocks. Typically, this is + `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`. + For automatic detection, set this to `"auto"`. + "auto" only works on DiT models. For UNet models, you must provide the correct fqn. + _query_proj_identifiers (`List[str]`, defaults to `None`): + The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. + If `None`, `to_q` is used by default. + """ + + indices: List[int] + fqn: str = "auto" + _query_proj_identifiers: List[str] = None + + +class SmoothedEnergyGuidanceHook(ModelHook): + def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None: + super().__init__() + self.blur_sigma = blur_sigma + self.blur_threshold_inf = blur_threshold_inf + + def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor: + # Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102 + kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2 + smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf) + return smoothed_output + + +def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None: + name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK + + if config.fqn == "auto": + for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: + if hasattr(module, identifier): + config.fqn = identifier + break + else: + raise ValueError( + "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " + "`fqn` (fully qualified name) that identifies a stack of transformer blocks." + ) + + if config._query_proj_identifiers is None: + config._query_proj_identifiers = ["to_q"] + + transformer_blocks = _get_submodule_from_fqn(module, config.fqn) + blocks_found = False + for i, block in enumerate(transformer_blocks): + if i not in config.indices: + continue + + blocks_found = True + + for submodule_name, submodule in block.named_modules(): + if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention: + continue + for identifier in config._query_proj_identifiers: + query_proj = getattr(submodule, identifier, None) + if query_proj is None or not isinstance(query_proj, torch.nn.Linear): + continue + logger.debug( + f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}" + ) + registry = HookRegistry.check_if_exists_or_initialize(query_proj) + hook = SmoothedEnergyGuidanceHook(blur_sigma) + registry.register_hook(hook, name) + + if not blocks_found: + raise ValueError( + f"Could not find any transformer blocks matching the provided indices {config.indices} and " + f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." + ) + + +# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71 +def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor: + """ + This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian + blur. However, some models use joint text-visual token attention for which this may not be suitable. Additionally, + this implementation also assumes that the visual tokens come from a square image/video. In practice, despite + these assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results + for Smoothed Energy Guidance. + + SEG is only supported as an experimental prototype feature for now, so the implementation may be modified + in the future without warning or guarantee of reproducibility. + """ + assert query.ndim == 3 + + is_inf = sigma > sigma_threshold_inf + batch_size, seq_len, embed_dim = query.shape + + seq_len_sqrt = int(math.sqrt(seq_len)) + num_square_tokens = seq_len_sqrt * seq_len_sqrt + query_slice = query[:, :num_square_tokens, :] + query_slice = query_slice.permute(0, 2, 1) + query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt) + + if is_inf: + kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1)) + kernel_size_half = (kernel_size - 1) / 2 + + x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size) + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + kernel1d = pdf / pdf.sum() + kernel1d = kernel1d.to(query) + kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :]) + kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + query_slice = F.pad(query_slice, padding, mode="reflect") + query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim) + else: + query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True) + + query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens) + query_slice = query_slice.permute(0, 2, 1) + query[:, :num_square_tokens, :] = query_slice.clone() + + return query diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 8e7109308962..2493d5635552 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -19,7 +19,6 @@ import torch from collections import OrderedDict -from ...guider import CFGGuider from ...image_processor import VaeImageProcessor, PipelineImageInput from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel @@ -31,7 +30,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...utils.torch_utils import randn_tensor, unwrap_module from ..controlnet.multicontrolnet import MultiControlNetModel from ..modular_pipeline import ( AutoPipelineBlocks, @@ -58,7 +57,7 @@ ) from ...schedulers import KarrasDiffusionSchedulers -from ...guider import Guiders, CFGGuider +from ...guiders import GuiderType, ClassifierFreeGuidance import numpy as np @@ -185,6 +184,7 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("image_encoder", CLIPVisionModelWithProjection), ComponentSpec("feature_extractor", CLIPImageProcessor), ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("guider", GuiderType), ] @property @@ -195,11 +195,7 @@ def inputs(self) -> List[InputParam]: PipelineImageInput, required=True, description="The image(s) to be used as ip adapter" - ), - InputParam( - "guidance_scale", - default=5.0, - ), + ) ] @@ -237,10 +233,10 @@ def encode_image(self, components, image, device, num_images_per_prompt, output_ # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( - self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds ): image_embeds = [] - if do_classifier_free_guidance: + if prepare_unconditional_embeds: negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): @@ -260,11 +256,11 @@ def prepare_ip_adapter_image_embeds( ) image_embeds.append(single_image_embeds[None, :]) - if do_classifier_free_guidance: + if prepare_unconditional_embeds: negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: - if do_classifier_free_guidance: + if prepare_unconditional_embeds: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) @@ -272,7 +268,7 @@ def prepare_ip_adapter_image_embeds( ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if do_classifier_free_guidance: + if prepare_unconditional_embeds: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) @@ -285,7 +281,7 @@ def prepare_ip_adapter_image_embeds( def __call__(self, pipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) - data.do_classifier_free_guidance = data.guidance_scale > 1.0 + data.prepare_unconditional_embeds = pipeline.guider.num_conditions > 1 data.device = pipeline._execution_device data.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( @@ -294,9 +290,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ip_adapter_image_embeds=None, device=data.device, num_images_per_prompt=1, - do_classifier_free_guidance=data.do_classifier_free_guidance, + prepare_unconditional_embeds=data.prepare_unconditional_embeds, ) - if data.do_classifier_free_guidance: + if data.prepare_unconditional_embeds: data.negative_ip_adapter_embeds = [] for i, image_embeds in enumerate(data.ip_adapter_embeds): negative_image_embeds, image_embeds = image_embeds.chunk(2) @@ -324,6 +320,7 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), ComponentSpec("tokenizer", CLIPTokenizer), ComponentSpec("tokenizer_2", CLIPTokenizer), + ComponentSpec("guider", GuiderType), ] @property @@ -338,7 +335,6 @@ def inputs(self) -> List[InputParam]: InputParam("negative_prompt"), InputParam("negative_prompt_2"), InputParam("cross_attention_kwargs"), - InputParam("guidance_scale",default=5.0), InputParam("clip_skip"), ] @@ -359,7 +355,6 @@ def check_inputs(self, pipeline, data): elif data.prompt_2 is not None and (not isinstance(data.prompt_2, str) and not isinstance(data.prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(data.prompt_2)}") - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with self -> components def encode_prompt( self, components, @@ -367,7 +362,7 @@ def encode_prompt( prompt_2: Optional[str] = None, device: Optional[torch.device] = None, num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, + prepare_unconditional_embeds: bool = True, negative_prompt: Optional[str] = None, negative_prompt_2: Optional[str] = None, prompt_embeds: Optional[torch.Tensor] = None, @@ -390,8 +385,8 @@ def encode_prompt( torch device num_images_per_prompt (`int`): number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is @@ -499,10 +494,10 @@ def encode_prompt( # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif do_classifier_free_guidance and negative_prompt_embeds is None: + elif prepare_unconditional_embeds and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt @@ -563,7 +558,7 @@ def encode_prompt( prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - if do_classifier_free_guidance: + if prepare_unconditional_embeds: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] @@ -578,7 +573,7 @@ def encode_prompt( pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( bs_embed * num_images_per_prompt, -1 ) - if do_classifier_free_guidance: + if prepare_unconditional_embeds: negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( bs_embed * num_images_per_prompt, -1 ) @@ -602,10 +597,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) self.check_inputs(pipeline, data) - data.do_classifier_free_guidance = data.guidance_scale > 1.0 + data.prepare_unconditional_embeds = pipeline.guider.num_conditions > 1 data.device = pipeline._execution_device - # Encode input prompt data.text_encoder_lora_scale = ( data.cross_attention_kwargs.get("scale", None) if data.cross_attention_kwargs is not None else None @@ -621,7 +615,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.prompt_2, data.device, 1, - data.do_classifier_free_guidance, + data.prepare_unconditional_embeds, data.negative_prompt, data.negative_prompt_2, prompt_embeds=None, @@ -1751,7 +1745,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("crops_coords_top_left", default=(0, 0)), InputParam("negative_crops_coords_top_left", default=(0, 0)), InputParam("num_images_per_prompt", default=1), - InputParam("guidance_scale", required=True), InputParam("aesthetic_score", default=6.0), InputParam("negative_aesthetic_score", default=2.0), ] @@ -1898,7 +1891,8 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin and pipeline.unet is not None and pipeline.unet.config.time_cond_proj_dim is not None ): - data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! + data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) data.timestep_cond = self.get_guidance_scale_embedding( data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim ).to(device=data.device, dtype=data.latents.dtype) @@ -1926,7 +1920,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("crops_coords_top_left", default=(0, 0)), InputParam("negative_crops_coords_top_left", default=(0, 0)), InputParam("num_images_per_prompt", default=1), - InputParam("guidance_scale", default=5.0), ] @property @@ -2052,7 +2045,8 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin and pipeline.unet is not None and pipeline.unet.config.time_cond_proj_dim is not None ): - data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! + data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) data.timestep_cond = self.get_guidance_scale_embedding( data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim ).to(device=data.device, dtype=data.latents.dtype) @@ -2068,7 +2062,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", CFGGuider, obj=CFGGuider()), + ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("unet", UNet2DConditionModel), ] @@ -2082,12 +2076,9 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("guidance_scale", default=5.0), - InputParam("guidance_rescale", default=0.0), InputParam("cross_attention_kwargs"), InputParam("generator"), InputParam("eta", default=0.0), - InputParam("guider_kwargs"), InputParam("num_images_per_prompt", default=1), ] @@ -2238,78 +2229,63 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.num_channels_unet = pipeline.unet.config.in_channels data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - - # adding default guider arguments: do_classifier_free_guidance, guidance_scale, guidance_rescale - data.guider_kwargs = data.guider_kwargs or {} - data.guider_kwargs = { - **data.guider_kwargs, - "disable_guidance": data.disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - - pipeline.guider.set_guider(pipeline, data.guider_kwargs) - # Prepare conditional inputs using the guider - data.prompt_embeds = pipeline.guider.prepare_input( - data.prompt_embeds, - data.negative_prompt_embeds, - ) - data.add_time_ids = pipeline.guider.prepare_input( - data.add_time_ids, - data.negative_add_time_ids, - ) - data.pooled_prompt_embeds = pipeline.guider.prepare_input( - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) - - if data.num_channels_unet == 9: - data.mask = pipeline.guider.prepare_input(data.mask, data.mask) - data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) - - data.added_cond_kwargs = { - "text_embeds": data.pooled_prompt_embeds, - "time_ids": data.add_time_ids, - } - - if data.ip_adapter_embeds is not None: - data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) - data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds + if data.disable_guidance: + pipeline.guider.disable() + else: + pipeline.guider.enable() # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) + pipeline.guider.set_input_fields( + prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), + add_time_ids=("add_time_ids", "negative_add_time_ids"), + pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), + ) + with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - # expand the latents if we are doing classifier free guidance - data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) - data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) + guider_data = pipeline.guider.prepare_inputs(data) - # inpainting + data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t) + + # Prepare for inpainting if data.num_channels_unet == 9: - data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) - - # predict the noise residual - data.noise_pred = pipeline.unet( - data.latent_model_input, - t, - encoder_hidden_states=data.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, - return_dict=False, - )[0] - # perform guidance - data.noise_pred = pipeline.guider.apply_guidance( - data.noise_pred, - timestep=t, - latents=data.latents, - ) - # compute the previous noisy sample x_t -> x_t-1 + data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1) + + for batch in guider_data: + pipeline.guider.prepare_models(pipeline.unet) + + # Prepare additional conditionings + batch.added_cond_kwargs = { + "text_embeds": batch.pooled_prompt_embeds, + "time_ids": batch.add_time_ids, + } + if batch.ip_adapter_embeds is not None: + batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds + + # Predict the noise residual + batch.noise_pred = pipeline.unet( + data.scaled_latents, + t, + encoder_hidden_states=batch.prompt_embeds, + timestep_cond=data.timestep_cond, + cross_attention_kwargs=data.cross_attention_kwargs, + added_cond_kwargs=batch.added_cond_kwargs, + return_dict=False, + )[0] + pipeline.guider.cleanup_models(pipeline.unet) + + # Perform guidance + data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) + + # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 @@ -2328,7 +2304,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() - pipeline.guider.reset_guider(pipeline) self.add_block_state(state, data) return pipeline, state @@ -2341,12 +2316,11 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", CFGGuider, obj=CFGGuider()), + ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), - ComponentSpec("controlnet_guider", CFGGuider, obj=CFGGuider()), ] @property @@ -2362,12 +2336,9 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("controlnet_conditioning_scale", default=1.0), InputParam("guess_mode", default=False), InputParam("num_images_per_prompt", default=1), - InputParam("guidance_scale", default=5.0), - InputParam("guidance_rescale", default=0.0), InputParam("cross_attention_kwargs"), InputParam("generator"), InputParam("eta", default=0.0), - InputParam("guider_kwargs"), ] @property @@ -2514,8 +2485,8 @@ def prepare_control_image( image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) else: image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] - if image_batch_size == 1: repeat_by = batch_size else: @@ -2523,9 +2494,7 @@ def prepare_control_image( repeat_by = num_images_per_prompt image = image.repeat_interleave(repeat_by, dim=0) - image = image.to(device=device, dtype=dtype) - return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components @@ -2556,14 +2525,12 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.num_channels_unet = pipeline.unet.config.in_channels # (1) prepare controlnet inputs - data.device = pipeline._execution_device - data.height, data.width = data.latents.shape[-2:] data.height = data.height * pipeline.vae_scale_factor data.width = data.width * pipeline.vae_scale_factor - controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet + controlnet = unwrap_module(pipeline.controlnet) # (1.1) # control_guidance_start/control_guidance_end (align format) @@ -2641,72 +2608,30 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) # (2) Prepare conditional inputs for unet using the guider - # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - data.guider_kwargs = data.guider_kwargs or {} - data.guider_kwargs = { - **data.guider_kwargs, - "disable_guidance": data.disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.guider.set_guider(pipeline, data.guider_kwargs) - data.prompt_embeds = pipeline.guider.prepare_input( - data.prompt_embeds, - data.negative_prompt_embeds, - ) - data.add_time_ids = pipeline.guider.prepare_input( - data.add_time_ids, - data.negative_add_time_ids, - ) - data.pooled_prompt_embeds = pipeline.guider.prepare_input( - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) - if data.num_channels_unet == 9: - data.mask = pipeline.guider.prepare_input(data.mask, data.mask) - data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) - - data.added_cond_kwargs = { - "text_embeds": data.pooled_prompt_embeds, - "time_ids": data.add_time_ids, - } - - if data.ip_adapter_embeds is not None: - data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) - data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds - - # (3) Prepare conditional inputs for controlnet using the guider - data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False - data.controlnet_guider_kwargs = data.guider_kwargs or {} - data.controlnet_guider_kwargs = { - **data.controlnet_guider_kwargs, - "disable_guidance": data.controlnet_disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs) - data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) - data.controlnet_added_cond_kwargs = { - "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds), - "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids), - } - data.control_image = pipeline.controlnet_guider.prepare_input(data.control_image, data.control_image) + if data.disable_guidance: + pipeline.guider.disable() + else: + pipeline.guider.enable() # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) + pipeline.guider.set_input_fields( + prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), + add_time_ids=("add_time_ids", "negative_add_time_ids"), + pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), + ) + # (5) Denoise loop with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - # prepare latents for unet using the guider - data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) + guider_data = pipeline.guider.prepare_inputs(data) - # prepare latents for controlnet using the guider - data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents) + data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t) if isinstance(data.controlnet_keep[i], list): data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] @@ -2715,52 +2640,72 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if isinstance(data.controlnet_cond_scale, list): data.controlnet_cond_scale = data.controlnet_cond_scale[0] data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] + + for batch in guider_data: + pipeline.guider.prepare_models(pipeline.unet) + + # Prepare additional conditionings + batch.added_cond_kwargs = { + "text_embeds": batch.pooled_prompt_embeds, + "time_ids": batch.add_time_ids, + } + if batch.ip_adapter_embeds is not None: + batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds + + # Prepare controlnet additional conditionings + batch.controlnet_added_cond_kwargs = { + "text_embeds": batch.pooled_prompt_embeds, + "time_ids": batch.add_time_ids, + } + + # Will always be run atleast once with every guider + if pipeline.guider.is_conditional or not data.guess_mode: + data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( + data.scaled_latents, + t, + encoder_hidden_states=batch.prompt_embeds, + controlnet_cond=data.control_image, + conditioning_scale=data.cond_scale, + guess_mode=data.guess_mode, + added_cond_kwargs=batch.controlnet_added_cond_kwargs, + return_dict=False, + ) + + batch.down_block_res_samples = data.down_block_res_samples + batch.mid_block_res_sample = data.mid_block_res_sample + + if pipeline.guider.is_unconditional and data.guess_mode: + batch.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples] + batch.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) + + # Prepare for inpainting + if data.num_channels_unet == 9: + data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1) + + batch.noise_pred = pipeline.unet( + data.scaled_latents, + t, + encoder_hidden_states=batch.prompt_embeds, + timestep_cond=data.timestep_cond, + cross_attention_kwargs=data.cross_attention_kwargs, + added_cond_kwargs=batch.added_cond_kwargs, + down_block_additional_residuals=batch.down_block_res_samples, + mid_block_additional_residual=batch.mid_block_res_sample, + return_dict=False, + )[0] + pipeline.guider.cleanup_models(pipeline.unet) + + # Perform guidance + data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) - data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - pipeline.scheduler.scale_model_input(data.control_model_input, t), - t, - encoder_hidden_states=data.controlnet_prompt_embeds, - controlnet_cond=data.control_image, - conditioning_scale=data.cond_scale, - guess_mode=data.guess_mode, - added_cond_kwargs=data.controlnet_added_cond_kwargs, - return_dict=False, - ) - - # when we apply guidance for unet, but not for controlnet: - # add 0 to the unconditional batch - data.down_block_res_samples = pipeline.guider.prepare_input( - data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples] - ) - data.mid_block_res_sample = pipeline.guider.prepare_input( - data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample) - ) - - data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) - if data.num_channels_unet == 9: - data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) - - data.noise_pred = pipeline.unet( - data.latent_model_input, - t, - encoder_hidden_states=data.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, - down_block_additional_residuals=data.down_block_res_samples, - mid_block_additional_residual=data.mid_block_res_sample, - return_dict=False, - )[0] - # perform guidance - data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents) - # compute the previous noisy sample x_t -> x_t-1 + # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 data.latents = data.latents.to(data.latents_dtype) - if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: data.init_latents_proper = data.image_latents @@ -2774,9 +2719,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() - - pipeline.guider.reset_guider(pipeline) - pipeline.controlnet_guider.reset_guider(pipeline) self.add_block_state(state, data) @@ -2792,8 +2734,7 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetUnionModel), ComponentSpec("scheduler", KarrasDiffusionSchedulers), - ComponentSpec("guider", CFGGuider, obj=CFGGuider()), - ComponentSpec("controlnet_guider", CFGGuider, obj=CFGGuider()), + ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), ] @@ -2810,12 +2751,9 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("controlnet_conditioning_scale", default=1.0), InputParam("guess_mode", default=False), InputParam("num_images_per_prompt", default=1), - InputParam("guidance_scale", default=5.0), - InputParam("guidance_rescale", default=0.0), InputParam("cross_attention_kwargs"), InputParam("generator"), InputParam("eta", default=0.0), - InputParam("guider_kwargs") ] @property @@ -3008,7 +2946,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.height = data.height * pipeline.vae_scale_factor data.width = data.width * pipeline.vae_scale_factor - controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet + controlnet = unwrap_module(pipeline.controlnet) # (1.1) # control guidance @@ -3058,7 +2996,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: crops_coords=data.crops_coords, ) data.height, data.width = data.control_image[idx].shape[-2:] - # (1.6) # controlnet_keep @@ -3072,80 +3009,32 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # (2) Prepare conditional inputs for unet using the guider # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - data.guider_kwargs = data.guider_kwargs or {} - data.guider_kwargs = { - **data.guider_kwargs, - "disable_guidance": data.disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.guider.set_guider(pipeline, data.guider_kwargs) - data.prompt_embeds = pipeline.guider.prepare_input( - data.prompt_embeds, - data.negative_prompt_embeds, - ) - data.add_time_ids = pipeline.guider.prepare_input( - data.add_time_ids, - data.negative_add_time_ids, - ) - data.pooled_prompt_embeds = pipeline.guider.prepare_input( - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) - - if data.num_channels_unet == 9: - data.mask = pipeline.guider.prepare_input(data.mask, data.mask) - data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) - - data.added_cond_kwargs = { - "text_embeds": data.pooled_prompt_embeds, - "time_ids": data.add_time_ids, - } - - if data.ip_adapter_embeds is not None: - data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) - data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds - - # (3) Prepare conditional inputs for controlnet using the guider - data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False - data.controlnet_guider_kwargs = data.guider_kwargs or {} - data.controlnet_guider_kwargs = { - **data.controlnet_guider_kwargs, - "disable_guidance": data.controlnet_disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs) - data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) - data.controlnet_added_cond_kwargs = { - "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds), - "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids), - } - for idx, _ in enumerate(data.control_image): - data.control_image[idx] = pipeline.controlnet_guider.prepare_input(data.control_image[idx], data.control_image[idx]) + if data.disable_guidance: + pipeline.guider.disable() + else: + pipeline.guider.enable() - data.control_type = ( - data.control_type.reshape(1, -1) - .to(data.device, dtype=data.prompt_embeds.dtype) - ) + data.control_type = data.control_type.reshape(1, -1).to(data.device, dtype=data.prompt_embeds.dtype) repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0] data.control_type = data.control_type.repeat_interleave(repeat_by, dim=0) - data.control_type = pipeline.controlnet_guider.prepare_input(data.control_type, data.control_type) # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) + pipeline.guider.set_input_fields( + prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), + add_time_ids=("add_time_ids", "negative_add_time_ids"), + pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), + ) with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - # prepare latents for unet using the guider - data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) + guider_data = pipeline.guider.prepare_inputs(data) - # prepare latents for controlnet using the guider - data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents) + data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t) if isinstance(data.controlnet_keep[i], list): data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] @@ -3154,49 +3043,69 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if isinstance(data.controlnet_cond_scale, list): data.controlnet_cond_scale = data.controlnet_cond_scale[0] data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] + + for batch in guider_data: + pipeline.guider.prepare_models(pipeline.unet) + + # Prepare additional conditionings + batch.added_cond_kwargs = { + "text_embeds": batch.pooled_prompt_embeds, + "time_ids": batch.add_time_ids, + } + if batch.ip_adapter_embeds is not None: + batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds + + # Prepare controlnet additional conditionings + batch.controlnet_added_cond_kwargs = { + "text_embeds": batch.pooled_prompt_embeds, + "time_ids": batch.add_time_ids, + } + + # Will always be run atleast once with every guider + if pipeline.guider.is_conditional or not data.guess_mode: + data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( + data.scaled_latents, + t, + encoder_hidden_states=batch.prompt_embeds, + controlnet_cond=data.control_image, + control_type=data.control_type, + control_type_idx=data.control_mode, + conditioning_scale=data.cond_scale, + guess_mode=data.guess_mode, + added_cond_kwargs=batch.controlnet_added_cond_kwargs, + return_dict=False, + ) + + batch.down_block_res_samples = data.down_block_res_samples + batch.mid_block_res_sample = data.mid_block_res_sample + + if pipeline.guider.is_unconditional and data.guess_mode: + batch.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples] + batch.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) + + if data.num_channels_unet == 9: + data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1) + + batch.noise_pred = pipeline.unet( + data.scaled_latents, + t, + encoder_hidden_states=batch.prompt_embeds, + timestep_cond=data.timestep_cond, + cross_attention_kwargs=data.cross_attention_kwargs, + added_cond_kwargs=batch.added_cond_kwargs, + down_block_additional_residuals=batch.down_block_res_samples, + mid_block_additional_residual=batch.mid_block_res_sample, + return_dict=False, + )[0] + pipeline.guider.cleanup_models(pipeline.unet) + + # Perform guidance + data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) - data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - pipeline.scheduler.scale_model_input(data.control_model_input, t), - t, - encoder_hidden_states=data.controlnet_prompt_embeds, - controlnet_cond=data.control_image, - control_type=data.control_type, - control_type_idx=data.control_mode, - conditioning_scale=data.cond_scale, - guess_mode=data.guess_mode, - added_cond_kwargs=data.controlnet_added_cond_kwargs, - return_dict=False, - ) - - # when we apply guidance for unet, but not for controlnet: - # add 0 to the unconditional batch - data.down_block_res_samples = pipeline.guider.prepare_input( - data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples] - ) - data.mid_block_res_sample = pipeline.guider.prepare_input( - data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample) - ) - - data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) - if data.num_channels_unet == 9: - data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) - - data.noise_pred = pipeline.unet( - data.latent_model_input, - t, - encoder_hidden_states=data.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, - down_block_additional_residuals=data.down_block_res_samples, - mid_block_additional_residual=data.mid_block_res_sample, - return_dict=False, - )[0] - # perform guidance - data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents) - # compute the previous noisy sample x_t -> x_t-1 + # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 @@ -3209,14 +3118,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.init_latents_proper = pipeline.scheduler.add_noise( data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) ) - data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() - - pipeline.guider.reset_guider(pipeline) - pipeline.controlnet_guider.reset_guider(pipeline) self.add_block_state(state, data) @@ -3543,6 +3448,11 @@ def description(self): "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \ "- for text-to-image generation, all you need to provide is `prompt`" +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + # block mapping TEXT2IMAGE_BLOCKS = OrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), @@ -3664,7 +3574,6 @@ def num_channels_latents(self): "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), "cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), - "guidance_scale": InputParam("guidance_scale", type_hint=float, default=5.0, description="Classifier-Free Diffusion Guidance scale"), "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), @@ -3689,9 +3598,7 @@ def num_channels_latents(self): "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), - "guidance_rescale": InputParam("guidance_rescale", type_hint=float, default=0.0, description="Guidance rescale factor to fix overexposure"), "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), - "guider_kwargs": InputParam("guider_kwargs", type_hint=Optional[Dict[str, Any]], description="Kwargs dictionary passed to the Guider"), "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), "return_dict": InputParam("return_dict", type_hint=bool, default=True, description="Whether to return a StableDiffusionXLPipelineOutput"), "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), @@ -3757,4 +3664,4 @@ def num_channels_latents(self): SDXL_OUTPUTS_SCHEMA = { "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") -} \ No newline at end of file +} diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 3c8911773e39..06f9981f0138 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool: return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) +def unwrap_module(module): + """Unwraps a module if it was compiled with torch.compile()""" + return module._orig_mod if is_compiled_module(module) else module + + def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor": """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). From 6d5beefe2918395e8a6f36eececfcaf4eea2a164 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Wed, 30 Apr 2025 11:17:20 -1000 Subject: [PATCH 06/54] [modular diffusers] introducing ModularLoader (#11462) * cfg; slg; pag; sdxl without controlnet --------- Co-authored-by: Aryan --- src/diffusers/__init__.py | 8 +- src/diffusers/pipelines/__init__.py | 8 +- src/diffusers/pipelines/components_manager.py | 396 +++++- src/diffusers/pipelines/modular_pipeline.py | 1245 ++++++++--------- .../pipelines/modular_pipeline_utils.py | 592 ++++++++ .../pipelines/pipeline_loading_utils.py | 21 +- src/diffusers/pipelines/pipeline_utils.py | 3 +- .../pipelines/stable_diffusion_xl/__init__.py | 4 +- .../pipeline_stable_diffusion_xl_modular.py | 106 +- src/diffusers/utils/dummy_pt_objects.py | 2 +- .../dummy_torch_and_transformers_objects.py | 2 +- 11 files changed, 1571 insertions(+), 816 deletions(-) create mode 100644 src/diffusers/pipelines/modular_pipeline_utils.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a4f55acf8b70..c9ee38ac6fda 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -249,7 +249,7 @@ "KarrasVePipeline", "LDMPipeline", "LDMSuperResolutionPipeline", - "ModularPipeline", + "ModularLoader", "PNDMPipeline", "RePaintPipeline", "ScoreSdeVePipeline", @@ -502,7 +502,7 @@ "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLModularPipeline", + "StableDiffusionXLModularLoader", "StableDiffusionXLPAGImg2ImgPipeline", "StableDiffusionXLPAGInpaintPipeline", "StableDiffusionXLPAGPipeline", @@ -840,7 +840,7 @@ KarrasVePipeline, LDMPipeline, LDMSuperResolutionPipeline, - ModularPipeline, + ModularLoader, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline, @@ -1071,7 +1071,7 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLModularPipeline, + StableDiffusionXLModularLoader, StableDiffusionXLPAGImg2ImgPipeline, StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index aee275db0336..7b6bd2071ef4 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -46,7 +46,7 @@ "AutoPipelineForInpainting", "AutoPipelineForText2Image", ] - _import_structure["modular_pipeline"] = ["ModularPipeline"] + _import_structure["modular_pipeline"] = ["ModularLoader"] _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] _import_structure["ddim"] = ["DDIMPipeline"] @@ -329,7 +329,7 @@ "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", "StableDiffusionXLPipeline", - "StableDiffusionXLModularPipeline", + "StableDiffusionXLModularLoader", "StableDiffusionXLAutoPipeline", ] ) @@ -468,7 +468,7 @@ from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline - from .modular_pipeline import ModularPipeline + from .modular_pipeline import ModularLoader from .pipeline_utils import ( AudioPipelineOutput, DiffusionPipeline, @@ -693,7 +693,7 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLModularPipeline, + StableDiffusionXLModularLoader, StableDiffusionXLPipeline, StableDiffusionXLAutoPipeline, ) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index 8c14321ccfac..bdff133e22d9 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -26,6 +26,7 @@ logging, ) from ..models.modeling_utils import ModelMixin +from .modular_pipeline_utils import ComponentSpec if is_accelerate_available(): @@ -229,54 +230,175 @@ def search_best_candidate(module_sizes, min_memory_offload): return hooks_to_offload + +from .modular_pipeline_utils import ComponentSpec +import uuid class ComponentsManager: def __init__(self): self.components = OrderedDict() - self.added_time = OrderedDict() # Store when components were added + self.added_time = OrderedDict() # Store when components were added + self.collections = OrderedDict() # collection_name -> set of component_names self.model_hooks = None self._auto_offload_enabled = False - def add(self, name, component): - if name in self.components: - logger.warning(f"Overriding existing component '{name}' in ComponentsManager") - self.components[name] = component - self.added_time[name] = time.time() + + def _get_by_collection(self, collection: str): + """ + Select components by collection name. + """ + selected_components = {} + if collection in self.collections: + component_ids = self.collections[collection] + for component_id in component_ids: + selected_components[component_id] = self.components[component_id] + return selected_components + + def _get_by_load_id(self, load_id: str): + """ + Select components by its load_id. + """ + selected_components = {} + for name, component in self.components.items(): + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: + selected_components[name] = component + return selected_components + + + def add(self, name, component, collection: Optional[str] = None): + + for comp_id, comp in self.components.items(): + if comp == component: + logger.warning(f"Component '{name}' already exists in ComponentsManager") + return comp_id + + component_id = f"{name}_{uuid.uuid4()}" + + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": + components_with_same_load_id = self._get_by_load_id(component._diffusers_load_id) + if components_with_same_load_id: + existing = ", ".join(components_with_same_load_id.keys()) + logger.warning( + f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " + f"To remove a duplicate, call `components_manager.remove('')`." + ) + + + # add component to components manager + self.components[component_id] = component + self.added_time[component_id] = time.time() + if collection: + if collection not in self.collections: + self.collections[collection] = set() + self.collections[collection].add(component_id) + if self._auto_offload_enabled: - self.enable_auto_cpu_offload(self._auto_offload_device) + self.enable_auto_cpu_offload(self._auto_offload_device) + + logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'") + return component_id + + + def remove(self, name: Union[str, List[str]]): - def remove(self, name): if name not in self.components: logger.warning(f"Component '{name}' not found in ComponentsManager") return self.components.pop(name) self.added_time.pop(name) + + for collection in self.collections: + if name in self.collections[collection]: + self.collections[collection].remove(name) if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) - # YiYi TODO: looking into improving the search pattern - def get(self, names: Union[str, List[str]]): + def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None, + as_name_component_tuples: bool = False): """ - Get components by name with simple pattern matching. + Select components by name with simple pattern matching. Args: names: Component name(s) or pattern(s) Patterns: - - "unet" : exact match - - "!unet" : everything except exact match "unet" - - "base_*" : everything starting with "base_" - - "!base_*" : everything NOT starting with "base_" - - "*unet*" : anything containing "unet" - - "!*unet*" : anything NOT containing "unet" - - "refiner|vae|unet" : anything containing any of these terms - - "!refiner|vae|unet" : anything NOT containing any of these terms + - "unet" : match any component with base name "unet" (e.g., unet_123abc) + - "!unet" : everything except components with base name "unet" + - "unet*" : anything with base name starting with "unet" + - "!unet*" : anything with base name NOT starting with "unet" + - "*unet*" : anything with base name containing "unet" + - "!*unet*" : anything with base name NOT containing "unet" + - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet" + - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet" + - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae" + collection: Optional collection to filter by + load_id: Optional load_id to filter by + as_name_component_tuples: If True, returns a list of (name, component) tuples using base names + instead of a dictionary with component IDs as keys Returns: - Single component if names is str and matches one component, - dict of components if names matches multiple components or is a list + Dictionary mapping component IDs to components, + or list of (base_name, component) tuples if as_name_component_tuples=True """ + + if collection: + if collection not in self.collections: + logger.warning(f"Collection '{collection}' not found in ComponentsManager") + return [] if as_name_component_tuples else {} + components = self._get_by_collection(collection) + else: + components = self.components + + if load_id: + components = self._get_by_load_id(load_id) + + # Helper to extract base name from component_id + def get_base_name(component_id): + parts = component_id.split('_') + # If the last part looks like a UUID, remove it + if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: + return '_'.join(parts[:-1]) + return component_id + + if names is None: + if as_name_component_tuples: + return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()] + else: + return components + + # Create mapping from component_id to base_name for all components + base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} + + def matches_pattern(component_id, pattern, exact_match=False): + """ + Helper function to check if a component matches a pattern based on its base name. + + Args: + component_id: The component ID to check + pattern: The pattern to match against + exact_match: If True, only exact matches to base_name are considered + """ + base_name = base_names[component_id] + + # Exact match with base name + if exact_match: + return pattern == base_name + + # Prefix match (ends with *) + elif pattern.endswith('*'): + prefix = pattern[:-1] + return base_name.startswith(prefix) + + # Contains match (starts with *) + elif pattern.startswith('*'): + search = pattern[1:-1] if pattern.endswith('*') else pattern[1:] + return search in base_name + + # Exact match (no wildcards) + else: + return pattern == base_name + if isinstance(names, str): # Check if this is a "not" pattern is_not_pattern = names.startswith('!') @@ -286,33 +408,45 @@ def get(self, names: Union[str, List[str]]): # Handle OR patterns (containing |) if '|' in names: terms = names.split('|') + matches = {} + + for comp_id, comp in components.items(): + # For OR patterns with exact names (no wildcards), we do exact matching on base names + exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms) + + # Check if any of the terms match this component + should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) + + # Flip the decision if this is a NOT pattern + if is_not_pattern: + should_include = not should_include + + if should_include: + matches[comp_id] = comp + + log_msg = "NOT " if is_not_pattern else "" + match_type = "exactly matching" if exact_match else "matching any of patterns" + logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") + + # Try exact match with a base name + elif any(names == base_name for base_name in base_names.values()): + # Find all components with this base name matches = { - name: comp for name, comp in self.components.items() - if any((term in name) != is_not_pattern for term in terms) # Flip condition if not pattern + comp_id: comp for comp_id, comp in components.items() + if (base_names[comp_id] == names) != is_not_pattern } + if is_not_pattern: - logger.info(f"Getting components NOT containing any of {terms}: {list(matches.keys())}") - else: - logger.info(f"Getting components containing any of {terms}: {list(matches.keys())}") - - # Exact match - elif names in self.components: - if is_not_pattern: - matches = { - name: comp for name, comp in self.components.items() - if name != names - } - logger.info(f"Getting all components except '{names}': {list(matches.keys())}") + logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") else: - logger.info(f"Getting component: {names}") - return self.components[names] + logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") # Prefix match (ends with *) elif names.endswith('*'): prefix = names[:-1] matches = { - name: comp for name, comp in self.components.items() - if name.startswith(prefix) != is_not_pattern # Flip condition if not pattern + comp_id: comp for comp_id, comp in components.items() + if base_names[comp_id].startswith(prefix) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") @@ -323,30 +457,46 @@ def get(self, names: Union[str, List[str]]): elif names.startswith('*'): search = names[1:-1] if names.endswith('*') else names[1:] matches = { - name: comp for name, comp in self.components.items() - if (search in name) != is_not_pattern # Flip condition if not pattern + comp_id: comp for comp_id, comp in components.items() + if (search in base_names[comp_id]) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") else: logger.info(f"Getting components containing '{search}': {list(matches.keys())}") + # Substring match (no wildcards, but not an exact component name) + elif any(names in base_name for base_name in base_names.values()): + matches = { + comp_id: comp for comp_id, comp in components.items() + if (names in base_names[comp_id]) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") + else: + logger.info(f"Getting components containing '{names}': {list(matches.keys())}") + else: - raise ValueError(f"Component '{names}' not found in ComponentsManager") + raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") if not matches: raise ValueError(f"No components found matching pattern '{names}'") - return matches if len(matches) > 1 else next(iter(matches.values())) + + if as_name_component_tuples: + return [(base_names[comp_id], comp) for comp_id, comp in matches.items()] + else: + return matches elif isinstance(names, list): results = {} for name in names: - result = self.get(name) - if isinstance(result, dict): - results.update(result) - else: - results[name] = result - return results + result = self.get(name, collection, load_id, as_name_component_tuples=False) + results.update(result) + + if as_name_component_tuples: + return [(base_names[comp_id], comp) for comp_id, comp in results.items()] + else: + return results else: raise ValueError(f"Invalid type for names: {type(names)}") @@ -390,6 +540,7 @@ def disable_auto_cpu_offload(self): self.model_hooks = None self._auto_offload_enabled = False + # YiYi TODO: add quantization info def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: """Get comprehensive information about a component. @@ -412,14 +563,23 @@ def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = No info = { "model_id": name, "added_time": self.added_time[name], + "collection": next((coll for coll, comps in self.collections.items() if name in comps), None), } # Additional info for torch.nn.Module components if isinstance(component, torch.nn.Module): + # Check for hook information + has_hook = hasattr(component, "_hf_hook") + execution_device = None + if has_hook and hasattr(component._hf_hook, "execution_device"): + execution_device = component._hf_hook.execution_device + info.update({ "class_name": component.__class__.__name__, "size_gb": get_memory_footprint(component) / (1024**3), "adapters": None, # Default to None + "has_hook": has_hook, + "execution_device": execution_device, }) # Get adapters if applicable @@ -453,12 +613,56 @@ def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = No return info def __repr__(self): + # Helper to get simple name without UUID + def get_simple_name(name): + # Extract the base name by splitting on underscore and taking first part + # This assumes names are in format "name_uuid" + parts = name.split('_') + # If we have at least 2 parts and the last part looks like a UUID, remove it + if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: + return '_'.join(parts[:-1]) + return name + + # Extract load_id if available + def get_load_id(component): + if hasattr(component, "_diffusers_load_id"): + return component._diffusers_load_id + return "N/A" + + # Format device info compactly + def format_device(component, info): + if not info["has_hook"]: + return str(getattr(component, 'device', 'N/A')) + else: + device = str(getattr(component, 'device', 'N/A')) + exec_device = str(info['execution_device'] or 'N/A') + return f"{device}({exec_device})" + + # Get all simple names to calculate width + simple_names = [get_simple_name(id) for id in self.components.keys()] + + # Get max length of load_ids for models + load_ids = [ + get_load_id(component) + for component in self.components.values() + if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id") + ] + max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 + + # Collection names + collection_names = [ + next((coll for coll, comps in self.collections.items() if name in comps), "N/A") + for name in self.components.keys() + ] + col_widths = { - "id": max(15, max(len(id) for id in self.components.keys())), + "name": max(15, max(len(name) for name in simple_names)), "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), - "device": 10, + "device": 15, # Reduced since using more compact format "dtype": 15, "size": 10, + "load_id": max_load_id_len, + "collection": max(10, max(len(str(c)) for c in collection_names)) } # Create the header lines @@ -475,17 +679,23 @@ def __repr__(self): if models: output += "Models:\n" + dash_line # Column headers - output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | " - output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB)\n" + output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | " + output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | " + output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n" output += dash_line # Model entries for name, component in models.items(): info = self.get_model_info(name) - device = str(getattr(component, "device", "N/A")) + simple_name = get_simple_name(name) + device_str = format_device(component, info) dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" - output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | " - output += f"{device:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | {info['size_gb']:.2f}\n" + load_id = get_load_id(component) + collection = info["collection"] or "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " + output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " + output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n" output += dash_line # Other components section @@ -494,12 +704,16 @@ def __repr__(self): output += "\n" output += "Other Components:\n" + dash_line # Column headers for other components - output += f"{'Component ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}}\n" + output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | Collection\n" output += dash_line # Other component entries for name, component in others.items(): - output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}}\n" + info = self.get_model_info(name) + simple_name = get_simple_name(name) + collection = info["collection"] or "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n" output += dash_line # Add additional component info @@ -507,7 +721,8 @@ def __repr__(self): for name in self.components: info = self.get_model_info(name) if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")): - output += f"\n{name}:\n" + simple_name = get_simple_name(name) + output += f"\n{simple_name}:\n" if info.get("adapters") is not None: output += f" Adapters: {info['adapters']}\n" if info.get("ip_adapter"): @@ -516,7 +731,7 @@ def __repr__(self): return output - def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): + def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): """ Load components from a pretrained model and add them to the manager. @@ -526,17 +741,12 @@ def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[st If provided, components will be named as "{prefix}_{component_name}" **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained() """ - from ..pipelines.pipeline_utils import DiffusionPipeline - - pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) - for name, component in pipe.components.items(): - - if component is None: - continue - - # Add prefix if specified - component_name = f"{prefix}_{name}" if prefix else name - + subfolder = kwargs.pop("subfolder", None) + # YiYi TODO: extend AutoModel to support non-diffusers models + if subfolder: + from ..models import AutoModel + component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs) + component_name = f"{prefix}_{subfolder}" if prefix else subfolder if component_name not in self.components: self.add(component_name, component) else: @@ -545,6 +755,50 @@ def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[st f"1. remove the existing component with remove('{component_name}')\n" f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" ) + else: + from ..pipelines.pipeline_utils import DiffusionPipeline + pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) + for name, component in pipe.components.items(): + + if component is None: + continue + + # Add prefix if specified + component_name = f"{prefix}_{name}" if prefix else name + + if component_name not in self.components: + self.add(component_name, component) + else: + logger.warning( + f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" + f"1. remove the existing component with remove('{component_name}')\n" + f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" + ) + + def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any: + """ + Get a single component by name. Raises an error if multiple components match or none are found. + + Args: + name: Component name or pattern + collection: Optional collection to filter by + load_id: Optional load_id to filter by + + Returns: + A single component + + Raises: + ValueError: If no components match or multiple components match + """ + results = self.get(name, collection, load_id) + + if not results: + raise ValueError(f"No components found matching '{name}'") + + if len(results) > 1: + raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") + + return next(iter(results.values())) def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: """Summarizes a dictionary by finding common prefixes that share the same value. diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 785f38cdbf8c..636b543395df 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -22,25 +22,45 @@ import torch from tqdm.auto import tqdm import re +import os +import importlib -from ..configuration_utils import ConfigMixin +from huggingface_hub.utils import validate_hf_hub_args + +from ..configuration_utils import ConfigMixin, FrozenDict from ..utils import ( is_accelerate_available, is_accelerate_version, logging, + PushToHubMixin, ) -from .pipeline_loading_utils import _get_pipeline_class - +from .pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj,_fetch_class_library_tuple +from .modular_pipeline_utils import ( + ComponentSpec, + ConfigSpec, + InputParam, + OutputParam, + format_components, + format_configs, + format_input_params, + format_inputs_short, + format_intermediates_short, + format_output_params, + format_params, + make_doc_string, +) +from .components_manager import ComponentsManager +from copy import deepcopy if is_accelerate_available(): import accelerate logger = logging.get_logger(__name__) # pylint: disable=invalid-name -MODULAR_PIPELINE_MAPPING = OrderedDict( +MODULAR_LOADER_MAPPING = OrderedDict( [ - ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"), + ("stable-diffusion-xl", "StableDiffusionXLModularLoader"), ] ) @@ -138,236 +158,116 @@ def format_value(v): return f"BlockState(\n{attributes}\n)" -@dataclass -class ComponentSpec: - """Specification for a pipeline component.""" - name: str - type_hint: Type - description: Optional[str] = None - obj: Any = None # you can create a default component if it is a stateless class like scheduler, guider or image processor - default_class_name: Union[str, List[str], Tuple[str, str]] = None # Either "class_name" or ["module", "class_name"] - default_repo: Optional[Union[str, List[str]]] = None # either "repo" or ["repo", "subfolder"] - -@dataclass -class ConfigSpec: - """Specification for a pipeline configuration parameter.""" - name: str - default: Any - description: Optional[str] = None - - -@dataclass -class InputParam: - name: str - type_hint: Any = None - default: Any = None - required: bool = False - description: str = "" - - def __repr__(self): - return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" -@dataclass -class OutputParam: - name: str - type_hint: Any = None - description: str = "" - - def __repr__(self): - return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" - -def format_inputs_short(inputs): +class ModularPipelineMixin: """ - Format input parameters into a string representation, with required params first followed by optional ones. - - Args: - inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params - - Returns: - str: Formatted string of input parameters - - Example: - >>> inputs = [ - ... InputParam(name="prompt", required=True), - ... InputParam(name="image", required=True), - ... InputParam(name="guidance_scale", required=False, default=7.5), - ... InputParam(name="num_inference_steps", required=False, default=50) - ... ] - >>> format_inputs_short(inputs) - 'prompt, image, guidance_scale=7.5, num_inference_steps=50' + Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks """ - required_inputs = [param for param in inputs if param.required] - optional_inputs = [param for param in inputs if not param.required] - - required_str = ", ".join(param.name for param in required_inputs) - optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) - inputs_str = required_str - if optional_str: - inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str - - return inputs_str + def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): + """ + create a mouldar loader, optionally accept modular_repo to load from hub. + """ -def format_intermediates_short(intermediates_inputs: List[InputParam], required_intermediates_inputs: List[str], intermediates_outputs: List[OutputParam]) -> str: - """ - Formats intermediate inputs and outputs of a block into a string representation. - - Args: - intermediates_inputs: List of intermediate input parameters - required_intermediates_inputs: List of required intermediate input names - intermediates_outputs: List of intermediate output parameters - - Returns: - str: Formatted string like: - Intermediates: - - inputs: Required(latents), dtype - - modified: latents # variables that appear in both inputs and outputs - - outputs: images # new outputs only - """ - # Handle inputs - input_parts = [] - for inp in intermediates_inputs: - if inp.name in required_intermediates_inputs: - input_parts.append(f"Required({inp.name})") - else: - input_parts.append(inp.name) - - # Handle modified variables (appear in both inputs and outputs) - inputs_set = {inp.name for inp in intermediates_inputs} - modified_parts = [] - new_output_parts = [] - - for out in intermediates_outputs: - if out.name in inputs_set: - modified_parts.append(out.name) - else: - new_output_parts.append(out.name) - - result = [] - if input_parts: - result.append(f" - inputs: {', '.join(input_parts)}") - if modified_parts: - result.append(f" - modified: {', '.join(modified_parts)}") - if new_output_parts: - result.append(f" - outputs: {', '.join(new_output_parts)}") + # Import components loader (it is model-specific class) + loader_class_name = MODULAR_LOADER_MAPPING[self.model_name] + diffusers_module = importlib.import_module("diffusers") + loader_class = getattr(diffusers_module, loader_class_name) + + # Create deep copies to avoid modifying the original specs + component_specs = deepcopy(self.expected_components) + config_specs = deepcopy(self.expected_configs) + # Create the loader with the updated specs + specs = component_specs + config_specs - return "\n".join(result) if result else " (none)" + self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection) -def format_params(params: List[Union[InputParam, OutputParam]], header: str = "Args", indent_level: int = 4, max_line_length: int = 115) -> str: - """Format a list of InputParam or OutputParam objects into a readable string representation. + @property + def default_call_parameters(self) -> Dict[str, Any]: + params = {} + for input_param in self.inputs: + params[input_param.name] = input_param.default + return params - Args: - params: List of InputParam or OutputParam objects to format - header: Header text to use (e.g. "Args" or "Returns") - indent_level: Number of spaces to indent each parameter line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) + def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): + """ + Run one or more blocks in sequence, optionally you can pass a previous pipeline state. + """ + if state is None: + state = PipelineState() - Returns: - A formatted string representing all parameters - """ - if not params: - return "" - - base_indent = " " * indent_level - param_indent = " " * (indent_level + 4) - desc_indent = " " * (indent_level + 8) - formatted_params = [] - - def get_type_str(type_hint): - if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: - types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] - return f"Union[{', '.join(types)}]" - return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) - - def wrap_text(text: str, indent: str, max_length: int) -> str: - """Wrap text while preserving markdown links and maintaining indentation.""" - words = text.split() - lines = [] - current_line = [] - current_length = 0 - - for word in words: - word_length = len(word) + (1 if current_line else 0) - - if current_line and current_length + word_length > max_length: - lines.append(" ".join(current_line)) - current_line = [word] - current_length = len(word) - else: - current_line.append(word) - current_length += word_length - - if current_line: - lines.append(" ".join(current_line)) - - return f"\n{indent}".join(lines) - - # Add the header - formatted_params.append(f"{base_indent}{header}:") - - for param in params: - # Format parameter name and type - type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" - param_str = f"{param_indent}{param.name} (`{type_str}`" - - # Add optional tag and default value if parameter is an InputParam and optional - if isinstance(param, InputParam): - if not param.required: - param_str += ", *optional*" - if param.default is not None: - param_str += f", defaults to {param.default}" - param_str += "):" - - # Add description on a new line with additional indentation and wrapping - if param.description: - desc = re.sub( - r'\[(.*?)\]\((https?://[^\s\)]+)\)', - r'[\1](\2)', - param.description - ) - wrapped_desc = wrap_text(desc, desc_indent, max_line_length) - param_str += f"\n{desc_indent}{wrapped_desc}" - - formatted_params.append(param_str) - - return "\n\n".join(formatted_params) + if not hasattr(self, "loader"): + raise ValueError("Loader is not set, please call `setup_loader()` first.") -# Then update the original functions to use this combined version: -def format_input_params(input_params: List[InputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - return format_params(input_params, "Args", indent_level, max_line_length) + # Make a copy of the input kwargs + input_params = kwargs.copy() -def format_output_params(output_params: List[OutputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - return format_params(output_params, "Returns", indent_level, max_line_length) + default_params = self.default_call_parameters + # Add inputs to state, using defaults if not provided in the kwargs or the state + # if same input already in the state, will override it if provided in the kwargs + intermediates_inputs = [inp.name for inp in self.intermediates_inputs] + for name, default in default_params.items(): + if name in input_params: + if name not in intermediates_inputs: + state.add_input(name, input_params.pop(name)) + else: + state.add_input(name, input_params[name]) + elif name not in state.inputs: + state.add_input(name, default) -def make_doc_string(inputs, intermediates_inputs, outputs, description=""): - """ - Generates a formatted documentation string describing the pipeline block's parameters and structure. - - Returns: - str: A formatted string containing information about call parameters, intermediate inputs/outputs, - and final intermediate outputs. - """ - output = "" + for name in intermediates_inputs: + if name in input_params: + state.add_intermediate(name, input_params.pop(name)) + + # Warn about unexpected inputs + if len(input_params) > 0: + logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") + # Run the pipeline + with torch.no_grad(): + try: + pipeline, state = self(self.loader, state) + except Exception: + error_msg = f"Error in block: ({self.__class__.__name__}):\n" + logger.error(error_msg) + raise - if description: - desc_lines = description.strip().split('\n') - aligned_desc = '\n'.join(' ' + line for line in desc_lines) - output += aligned_desc + "\n\n" + if output is None: + return state - output += format_input_params(inputs + intermediates_inputs, indent_level=2) - - output += "\n\n" - output += format_output_params(outputs, indent_level=2) - return output + elif isinstance(output, str): + return state.get_intermediate(output) + elif isinstance(output, (list, tuple)): + return state.get_intermediates(output) + else: + raise ValueError(f"Output '{output}' is not a valid output type") + @torch.compiler.disable + def progress_bar(self, iterable=None, total=None): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) -class PipelineBlock: + if iterable is not None: + return tqdm(iterable, **self._progress_bar_config) + elif total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs + + +class PipelineBlock(ModularPipelineMixin): model_name = None @@ -440,31 +340,15 @@ def __repr__(self): desc.extend(f" {line}" for line in desc_lines[1:]) desc = '\n'.join(desc) + '\n' - # Components section - focus only on expected components + # Components section - use format_components with add_empty_lines=False expected_components = getattr(self, "expected_components", []) - expected_components_str_list = [] - - for component_spec in expected_components: - component_str = f" - {component_spec.name} ({component_spec.type_hint})" - - # Add repo info if available - if component_spec.default_repo: - if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: - repo_info = component_spec.default_repo[0] - subfolder = component_spec.default_repo[1] - if subfolder: - repo_info += f", subfolder={subfolder}" - else: - repo_info = component_spec.default_repo - component_str += f" [{repo_info}]" - - expected_components_str_list.append(component_str) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + components = " " + components_str.replace("\n", "\n ") - components = "Components:\n" + "\n".join(expected_components_str_list) - - # Configs section - focus only on expected configs + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) - configs = "Configs:\n" + "\n".join(f" - {k}" for k in sorted(expected_configs)) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + configs = " " + configs_str.replace("\n", "\n ") # Inputs section inputs_str = format_inputs_short(self.inputs) @@ -478,8 +362,8 @@ def __repr__(self): f"{class_name}(\n" f" Class: {base_class}\n" f"{desc}" - f" {components}\n" - f" {configs}\n" + f"{components}\n" + f"{configs}\n" f" {inputs}\n" f" {intermediates}\n" f")" @@ -488,7 +372,15 @@ def __repr__(self): @property def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) def get_block_state(self, state: PipelineState) -> dict: @@ -575,7 +467,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> return list(combined_dict.values()) -class AutoPipelineBlocks: +class AutoPipelineBlocks(ModularPipelineMixin): """ A class that automatically selects a block to run based on the inputs. @@ -796,32 +688,13 @@ def __repr__(self): # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) - expected_components_str_list = [] + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - for component_spec in expected_components: - - component_str = f" - {component_spec.name} ({component_spec.type_hint.__name__})" - - # Add repo info if available - if component_spec.default_repo: - if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: - repo_info = component_spec.default_repo[0] - subfolder = component_spec.default_repo[1] - if subfolder: - repo_info += f", subfolder={subfolder}" - else: - repo_info = component_spec.default_repo - component_str += f" [{repo_info}]" - - expected_components_str_list.append(component_str) - - components_str = " Components:\n" + "\n".join(expected_components_str_list) - - # Configs section - focus only on expected configs + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) - configs_str = " Configs:\n" + "\n".join(f" - {config.name}" for config in sorted(expected_configs, key=lambda x: x.name)) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - # Blocks section + # Blocks section - moved to the end with simplified format blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block @@ -846,54 +719,31 @@ def __repr__(self): indented_desc = desc_lines[0] if len(desc_lines) > 1: indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n" - - # Format inputs - inputs_str = format_inputs_short(block.inputs) - blocks_str += f" inputs: {inputs_str}\n" - - # Format intermediates - intermediates_str = format_intermediates_short( - block.intermediates_inputs, - block.required_intermediates_inputs, - block.intermediates_outputs - ) - if intermediates_str != " (none)": - blocks_str += " intermediates:\n" - indented_intermediates = "\n".join( - " " + line for line in intermediates_str.split("\n") - ) - blocks_str += f"{indented_intermediates}\n" - blocks_str += "\n" - - # Inputs and outputs section - inputs_str = format_inputs_short(self.inputs) - inputs_str = " Inputs:\n " + inputs_str - outputs = [out.name for out in self.outputs] - - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates_str = ( - "\n Intermediates:\n" - f"{intermediates_str}\n" - f" - final outputs: {', '.join(outputs)}" - ) + blocks_str += f" Description: {indented_desc}\n\n" return ( f"{header}\n" - f"{desc}" - f"{components_str}\n" - f"{configs_str}\n" - f"{blocks_str}\n" - f"{inputs_str}\n" - f"{intermediates_str}\n" + f"{desc}\n\n" + f"{components_str}\n\n" + f"{configs_str}\n\n" + f"{blocks_str}" f")" ) + @property def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) -class SequentialPipelineBlocks: +class SequentialPipelineBlocks(ModularPipelineMixin): """ A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. """ @@ -1168,32 +1018,13 @@ def __repr__(self): # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) - expected_components_str_list = [] + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - for component_spec in expected_components: - - component_str = f" - {component_spec.name} ({component_spec.type_hint.__name__})" - - # Add repo info if available - if component_spec.default_repo: - if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: - repo_info = component_spec.default_repo[0] - subfolder = component_spec.default_repo[1] - if subfolder: - repo_info += f", subfolder={subfolder}" - else: - repo_info = component_spec.default_repo - component_str += f" [{repo_info}]" - - expected_components_str_list.append(component_str) - - components_str = " Components:\n" + "\n".join(expected_components_str_list) - - # Configs section - focus only on expected configs + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) - configs_str = " Configs:\n" + "\n".join(f" - {config.name}" for config in sorted(expected_configs, key=lambda x: x.name)) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - # Blocks section + # Blocks section - moved to the end with simplified format blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block @@ -1218,85 +1049,172 @@ def __repr__(self): indented_desc = desc_lines[0] if len(desc_lines) > 1: indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n" - - # Format inputs - inputs_str = format_inputs_short(block.inputs) - blocks_str += f" inputs: {inputs_str}\n" - - # Format intermediates - intermediates_str = format_intermediates_short( - block.intermediates_inputs, - block.required_intermediates_inputs, - block.intermediates_outputs - ) - if intermediates_str != " (none)": - blocks_str += " intermediates:\n" - indented_intermediates = "\n".join( - " " + line for line in intermediates_str.split("\n") - ) - blocks_str += f"{indented_intermediates}\n" - blocks_str += "\n" - - # Inputs and outputs section - inputs_str = format_inputs_short(self.inputs) - inputs_str = " Inputs:\n " + inputs_str - outputs = [out.name for out in self.outputs] - - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates_str = ( - "\n Intermediates:\n" - f"{intermediates_str}\n" - f" - final outputs: {', '.join(outputs)}" - ) + blocks_str += f" Description: {indented_desc}\n\n" return ( f"{header}\n" - f"{desc}" - f"{components_str}\n" - f"{configs_str}\n" - f"{blocks_str}\n" - f"{inputs_str}\n" - f"{intermediates_str}\n" + f"{desc}\n\n" + f"{components_str}\n\n" + f"{configs_str}\n\n" + f"{blocks_str}" f")" ) @property def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) -class ModularPipeline(ConfigMixin): + + +# YiYi TODO: +# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) +# 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader +# 3. add validator for methods where we accpet kwargs to be passed to from_pretrained() +class ModularLoader(ConfigMixin, PushToHubMixin): """ - Base class for all Modular pipelines. + Base class for all Modular pipelines loaders. """ + config_name = "modular_model_index.json" + + + def register_components(self, **kwargs): + """ + Register components with their corresponding specs. + This method is called when component changed or __init__ is called. - config_name = "model_index.json" - _exclude_from_cpu_offload = [] + Args: + **kwargs: Keyword arguments where keys are component names and values are component objects. + + """ + for name, module in kwargs.items(): + + # current component spec + component_spec = self._component_specs.get(name) + if component_spec is None: + logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") + continue + + is_registered = hasattr(self, name) - def __init__(self, block): - self.pipeline_block = block + if module is not None and not hasattr(module, "_diffusers_load_id"): + raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") - for component_spec in self.expected_components: - if component_spec.obj is not None: - setattr(self, component_spec.name, component_spec.obj) + # actual library and class name of the module + + if module is not None: + library, class_name = _fetch_class_library_tuple(module) + new_component_spec = ComponentSpec.from_component(name, module) + component_spec_dict = self._component_spec_to_dict(new_component_spec) + + else: + library, class_name = None, None + # if module is None, we do not update the spec, + # but we still need to update the config to make sure it's synced with the component spec + # (in the case of the first time registration, we initilize the object with component spec, and then we call register_components() to register it to config) + new_component_spec = component_spec + component_spec_dict = self._component_spec_to_dict(component_spec) + + # do not register if component is not to be loaded from pretrained + if new_component_spec.default_creation_method == "from_pretrained": + register_dict = {name: (library, class_name, component_spec_dict)} else: - setattr(self, component_spec.name, None) + register_dict = {} + + # set the component as attribute + # if it is not set yet, just set it and skip the process to check and warn below + if not is_registered: + self.register_to_config(**register_dict) + self._component_specs[name] = new_component_spec + setattr(self, name, module) + if module is not None and self._component_manager is not None: + self._component_manager.add(name, module, self._collection) + continue + + current_module = getattr(self, name, None) + # skip if the component is already registered with the same object + if current_module is module: + logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") + continue + + # it module is not an instance of the expected type, still register it but with a warning + if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint): + logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}") + + # warn if unregister + if current_module is not None and module is None: + logger.info( + f"ModularLoader.register_components: setting '{name}' to None " + f"(was {current_module.__class__.__name__})" + ) + # same type, new instance → debug + elif current_module is not None \ + and module is not None \ + and isinstance(module, current_module.__class__) \ + and current_module != module: + logger.debug( + f"ModularLoader.register_components: replacing existing '{name}' " + f"(same type {type(current_module).__name__}, new instance)" + ) + + # save modular_model_index.json config + self.register_to_config(**register_dict) + # update component spec + self._component_specs[name] = new_component_spec + # finally set models + setattr(self, name, module) + if module is not None and self._component_manager is not None: + self._component_manager.add(name, module, self._collection) + + + + # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name + def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): + """ + Initialize the loader with a list of component specs and config specs. + """ + self._component_manager = component_manager + self._collection = collection + self._component_specs = { + spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec) + } + self._config_specs = { + spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec) + } + + # update component_specs and config_specs from modular_repo + if modular_repo is not None: + config_dict = self.load_config(modular_repo, **kwargs) + + for name, value in config_dict.items(): + if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3: + library, class_name, component_spec_dict = value + component_spec = self._dict_to_component_spec(name, component_spec_dict) + self._component_specs[name] = component_spec + + elif name in self._config_specs: + self._config_specs[name].default = value + + register_components_dict = {} + for name, component_spec in self._component_specs.items(): + register_components_dict[name] = None + self.register_components(**register_components_dict) default_configs = {} - for config_spec in self.expected_configs: - default_configs[config_spec.name] = config_spec.default + for name, config_spec in self._config_specs.items(): + default_configs[name] = config_spec.default self.register_to_config(**default_configs) - @classmethod - def from_block(cls, block): - modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING[block.model_name] - modular_pipeline_class = _get_pipeline_class(cls, class_name=modular_pipeline_class_name) - - return modular_pipeline_class(block) - @property def device(self) -> torch.device: r""" @@ -1320,7 +1238,7 @@ def _execution_device(self): Accelerate's module hooks. """ for name, model in self.components.items(): - if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload: + if not isinstance(model, torch.nn.Module): continue if not hasattr(model, "_hf_hook"): @@ -1333,11 +1251,21 @@ def _execution_device(self): ): return torch.device(module._hf_hook.execution_device) return self.device - - - def get_execution_blocks(self, *trigger_inputs): - return self.pipeline_block.get_execution_blocks(*trigger_inputs) + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + + modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.device + + return torch.device("cpu") + @property def dtype(self) -> torch.dtype: r""" @@ -1352,340 +1280,257 @@ def dtype(self) -> torch.dtype: return torch.float32 - @property - def expected_components(self): - return self.pipeline_block.expected_components - - @property - def expected_configs(self): - return self.pipeline_block.expected_configs @property - def components(self): - components = {} - for component_spec in self.expected_components: - if hasattr(self, component_spec.name): - components[component_spec.name] = getattr(self, component_spec.name) - return components - - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.progress_bar - def progress_bar(self, iterable=None, total=None): - if not hasattr(self, "_progress_bar_config"): - self._progress_bar_config = {} - elif not isinstance(self._progress_bar_config, dict): - raise ValueError( - f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." - ) - - if iterable is not None: - return tqdm(iterable, **self._progress_bar_config) - elif total is not None: - return tqdm(total=total, **self._progress_bar_config) - else: - raise ValueError("Either `total` or `iterable` has to be defined.") - - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.set_progress_bar_config - def set_progress_bar_config(self, **kwargs): - self._progress_bar_config = kwargs - - def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): + def components(self) -> Dict[str, Any]: + # return only components we've actually set as attributes on self + return { + name: getattr(self, name) + for name in self._component_specs.keys() + if hasattr(self, name) + } + + def update(self, **kwargs): """ - Run one or more blocks in sequence, optionally you can pass a previous pipeline state. - """ - if state is None: - state = PipelineState() - - # Make a copy of the input kwargs - input_params = kwargs.copy() - - default_params = self.default_call_parameters + Update components and configs after instance creation. + + Args: - # Add inputs to state, using defaults if not provided in the kwargs or the state - # if same input already in the state, will override it if provided in the kwargs + """ + """ + Update components and configuration values after the loader has been instantiated. + + This method allows you to: + 1. Replace existing components with new ones (e.g., updating the unet or text_encoder) + 2. Update configuration values (e.g., changing requires_safety_checker flag) + + Args: + **kwargs: Component objects or configuration values to update: + - Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, text_encoder=new_encoder`) + - Configuration values: Simple values to update configuration settings (e.g., `requires_safety_checker=False`) + + Raises: + ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute) + + Examples: + ```python + # Update multiple components at once + loader.update( + unet=new_unet_model, + text_encoder=new_text_encoder + ) + + # Update configuration values + loader.update( + requires_safety_checker=False, + guidance_rescale=0.7 + ) + + # Update both components and configs together + loader.update( + unet=new_unet_model, + requires_safety_checker=False + ) + ``` + """ - intermediates_inputs = [inp.name for inp in self.pipeline_block.intermediates_inputs] - for name, default in default_params.items(): - if name in input_params: - if name not in intermediates_inputs: - state.add_input(name, input_params.pop(name)) - else: - state.add_input(name, input_params[name]) - elif name not in state.inputs: - state.add_input(name, default) + # extract component_specs_updates & config_specs_updates from `specs` + passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs} + passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs} - for name in intermediates_inputs: - if name in input_params: - state.add_intermediate(name, input_params.pop(name)) + for name, component in passed_components.items(): + if not hasattr(component, "_diffusers_load_id"): + raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + + if len(kwargs) > 0: + logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") + - # Warn about unexpected inputs - if len(input_params) > 0: - logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") - # Run the pipeline - with torch.no_grad(): - try: - pipeline, state = self.pipeline_block(self, state) - except Exception: - error_msg = f"Error in block: ({self.pipeline_block.__class__.__name__}):\n" - logger.error(error_msg) - raise + self.register_components(**passed_components) - if output is None: - return state + config_to_register = {} + for name, new_value in passed_config_values.items(): - elif isinstance(output, str): - return state.get_intermediate(output) + # e.g. requires_aesthetics_score = False + self._config_specs[name].default = new_value + config_to_register[name] = new_value + self.register_to_config(**config_to_register) - elif isinstance(output, (list, tuple)): - return state.get_intermediates(output) - else: - raise ValueError(f"Output '{output}' is not a valid output type") - def update_states(self, **kwargs): + # YiYi TODO: support map for additional from_pretrained kwargs + def load(self, component_names: Optional[List[str]] = None, **kwargs): """ - Update components and configs after instance creation. Auxiliaries (e.g. image_processor) should be defined for - each pipeline block, does not need to be updated by users. Logs if existing non-None components are being - overwritten. - + Load selectedcomponents from specs. + Args: - kwargs (dict): Keyword arguments to update the states. + component_names: List of component names to load + **kwargs: additional kwargs to be passed to `from_pretrained()`.Can be: + - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16 + - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32} + - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`, `variant`, `revision`, etc. """ - - for component in self.expected_components: - if component.name in kwargs: - if hasattr(self, component.name) and getattr(self, component.name) is not None: - current_component = getattr(self, component.name) - new_component = kwargs[component.name] - - if not isinstance(new_component, current_component.__class__): - logger.info( - f"Overwriting existing component '{component.name}' " - f"(type: {current_component.__class__.__name__}) " - f"with type: {new_component.__class__.__name__})" - ) - elif isinstance(current_component, torch.nn.Module): - if id(current_component) != id(new_component): - logger.info( - f"Overwriting existing component '{component.name}' " - f"(type: {type(current_component).__name__}) " - f"with new value (type: {type(new_component).__name__})" - ) - - setattr(self, component.name, kwargs.pop(component.name)) - - configs_to_add = {} - for config in self.expected_configs: - if config.name in kwargs: - configs_to_add[config.name] = kwargs.pop(config.name) - self.register_to_config(**configs_to_add) - - @property - def default_call_parameters(self) -> Dict[str, Any]: - params = {} - for input_param in self.pipeline_block.inputs: - params[input_param.name] = input_param.default - return params - - # def __repr__(self): - # output = "ModularPipeline:\n" - # output += "==============================\n\n" - - # block = self.pipeline_block + if component_names is None: + component_names = list(self._component_specs.keys()) + elif not isinstance(component_names, list): + component_names = [component_names] + + components_to_load = set([name for name in component_names if name in self._component_specs]) + unknown_component_names = set([name for name in component_names if name not in self._component_specs]) + if len(unknown_component_names) > 0: + logger.warning(f"Unknown components will be ignored: {unknown_component_names}") - # # List the pipeline block structure first - # output += "Pipeline Block:\n" - # output += "--------------\n" - # if hasattr(block, "blocks"): - # output += f"{block.__class__.__name__}\n" - # base_class = block.__class__.__bases__[0].__name__ - # output += f" (Class: {base_class})\n" if base_class != "object" else "\n" - # for sub_block_name, sub_block in block.blocks.items(): - # if hasattr(block, "block_trigger_inputs"): - # trigger_input = block.block_to_trigger_map[sub_block_name] - # trigger_info = f" [trigger: {trigger_input}]" if trigger_input is not None else " [default]" - # output += f" • {sub_block_name} ({sub_block.__class__.__name__}){trigger_info}\n" - # else: - # output += f" • {sub_block_name} ({sub_block.__class__.__name__})\n" - # else: - # output += f"{block.__class__.__name__}\n" - # output += "\n" - - # # List the components registered in the pipeline - # output += "Registered Components:\n" - # output += "----------------------\n" - # for name, component in self.components.items(): - # output += f"{name}: {type(component).__name__}" - # if hasattr(component, "dtype") and hasattr(component, "device"): - # output += f" (dtype={component.dtype}, device={component.device})" - # output += "\n" - # output += "\n" - - # # List the configs registered in the pipeline - # output += "Registered Configs:\n" - # output += "------------------\n" - # for name, config in self.config.items(): - # output += f"{name}: {config!r}\n" - # output += "\n" - - # # Add auto blocks section - # if hasattr(block, "trigger_inputs") and block.trigger_inputs: - # output += "------------------\n" - # output += "This pipeline contains blocks that are selected at runtime based on inputs.\n\n" - # output += f"Trigger Inputs: {block.trigger_inputs}\n" - # # Get first trigger input as example - # example_input = next(t for t in block.trigger_inputs if t is not None) - # output += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" - # output += "Check `.doc` of returned object for more information.\n\n" - - # # List the call parameters - # full_doc = self.pipeline_block.doc - # if "------------------------" in full_doc: - # full_doc = full_doc.split("------------------------")[0].rstrip() - # output += full_doc - - # return output - - # YiYi TODO: try to unify the to method with the one in DiffusionPipeline - # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to - def to(self, *args, **kwargs): - r""" - Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the - arguments of `self.to(*args, **kwargs).` - - - - If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise, - the returned pipeline is a copy of self with the desired torch.dtype and torch.device. - - + components_to_register = {} + for name in components_to_load: + spec = self._component_specs[name] + component_load_kwargs = {} + for key, value in kwargs.items(): + if not isinstance(value, dict): + # if the value is a single value, apply it to all components + component_load_kwargs[key] = value + else: + if name in value: + # if it is a dict, check if the component name is in the dict + component_load_kwargs[key] = value[name] + elif "default" in value: + # check if the default is specified + component_load_kwargs[key] = value["default"] + try: + components_to_register[name] = spec.create(**component_load_kwargs) + except Exception as e: + logger.warning(f"Failed to create component '{name}': {e}") + + # Register all components at once + self.register_components(**components_to_register) + # YiYi TODO: should support to method + def to(self, *args, **kwargs): + pass + + # YiYi TODO: + # 1. should support save some components too! currently only modular_model_index.json is saved + # 2. maybe order the json file to make it more readable: configs first, then components + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs): + + component_names = list(self._component_specs.keys()) + config_names = list(self._config_specs.keys()) + self.register_to_config(_components_names=component_names, _configs_names=config_names) + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + config = dict(self.config) + config.pop("_components_names", None) + config.pop("_configs_names", None) + self._internal_dict = FrozenDict(config) - Here are the ways to call `to`: + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs): + + config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) + expected_component = set(config_dict.pop("_components_names")) + expected_config = set(config_dict.pop("_configs_names")) + + component_specs = [] + config_specs = [] + for name, value in config_dict.items(): + if name in expected_component and isinstance(value, (tuple, list)) and len(value) == 3: + library, class_name, component_spec_dict = value + component_spec = cls._dict_to_component_spec(name, component_spec_dict) + component_specs.append(component_spec) + + elif name in expected_config: + config_specs.append(ConfigSpec(name=name, default=value)) + + for name in expected_component: + for spec in component_specs: + if spec.name == name: + break + else: + # append a empty component spec for these not in modular_model_index + component_specs.append(ComponentSpec(name=name, default_creation_method="from_config")) + return cls(component_specs + config_specs) - - `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified - [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) - - `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified - [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) - - `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the - specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and - [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) + + @staticmethod + def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: + """ + Convert a ComponentSpec into a JSON‐serializable dict for saving in + `modular_model_index.json`. + + This dict contains: + - "type_hint": Tuple[str, str] + The fully‐qualified module path and class name of the component. + - All loading fields defined by `component_spec.loading_fields()`, typically: + - "repo": Optional[str] + The model repository (e.g., "stabilityai/stable-diffusion-xl"). + - "subfolder": Optional[str] + A subfolder within the repo where this component lives. + - "variant": Optional[str] + An optional variant identifier for the model. + - "revision": Optional[str] + A specific git revision (commit hash, tag, or branch). + - ... any other loading fields defined on the spec. - Arguments: - dtype (`torch.dtype`, *optional*): - Returns a pipeline with the specified - [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) - device (`torch.Device`, *optional*): - Returns a pipeline with the specified - [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) - silence_dtype_warnings (`str`, *optional*, defaults to `False`): - Whether to omit warnings if the target `dtype` is not compatible with the target `device`. + Args: + component_spec (ComponentSpec): + The spec object describing one pipeline component. Returns: - [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`. + Dict[str, Any]: A mapping suitable for JSON serialization. + + Example: + >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec + >>> from diffusers.models.unet import UNet2DConditionModel + >>> spec = ComponentSpec( + ... name="unet", + ... type_hint=UNet2DConditionModel, + ... config=None, + ... repo="path/to/repo", + ... subfolder="subfolder", + ... variant=None, + ... revision=None, + ... default_creation_method="from_pretrained", + ... ) + >>> ModularLoader._component_spec_to_dict(spec) + { + "type_hint": ("diffusers.models.unet", "UNet2DConditionModel"), + "repo": "path/to/repo", + "subfolder": "subfolder", + "variant": None, + "revision": None, + } """ - dtype = kwargs.pop("dtype", None) - device = kwargs.pop("device", None) - silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False) - - dtype_arg = None - device_arg = None - if len(args) == 1: - if isinstance(args[0], torch.dtype): - dtype_arg = args[0] - else: - device_arg = torch.device(args[0]) if args[0] is not None else None - elif len(args) == 2: - if isinstance(args[0], torch.dtype): - raise ValueError( - "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`." - ) - device_arg = torch.device(args[0]) if args[0] is not None else None - dtype_arg = args[1] - elif len(args) > 2: - raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`") - - if dtype is not None and dtype_arg is not None: - raise ValueError( - "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two." - ) - - dtype = dtype or dtype_arg - - if device is not None and device_arg is not None: - raise ValueError( - "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two." - ) - - device = device or device_arg - - # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU. - def module_is_sequentially_offloaded(module): - if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): - return False - - return hasattr(module, "_hf_hook") and ( - isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook) - or hasattr(module._hf_hook, "hooks") - and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook) - ) - - def module_is_offloaded(module): - if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"): - return False - - return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload) - - # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer - pipeline_is_sequentially_offloaded = any( - module_is_sequentially_offloaded(module) for _, module in self.components.items() - ) - if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda": - raise ValueError( - "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." - ) - - is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1 - if is_pipeline_device_mapped: - raise ValueError( - "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`." - ) - - # Display a warning in this case (the operation succeeds but the benefits are lost) - pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) - if pipeline_is_offloaded and device and torch.device(device).type == "cuda": - logger.warning( - f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." - ) - - modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)] - - is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded - for module in modules: - is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit - - if is_loaded_in_8bit and dtype is not None: - logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision." - ) - - if is_loaded_in_8bit and device is not None: - logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}." - ) - else: - module.to(device, dtype) - - if ( - module.dtype == torch.float16 - and str(device) in ["cpu"] - and not silence_dtype_warnings - and not is_offloaded - ): - logger.warning( - "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It" - " is not recommended to move them to `cpu` as running them will fail. Please make" - " sure to use an accelerator to run the pipeline in inference, due to the lack of" - " support for`float16` operations on this device in PyTorch. Please, remove the" - " `torch_dtype=torch.float16` argument, or use another device for inference." - ) - return self + if component_spec.type_hint is not None: + lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint) + else: + lib_name = None + cls_name = None + load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} + return { + "type_hint": (lib_name, cls_name), + **load_spec_dict, + } + + @staticmethod + def _dict_to_component_spec( + name: str, + spec_dict: Dict[str, Any], + ) -> ComponentSpec: + """ + Reconstruct a ComponentSpec from a dict. + """ + # make a shallow copy so we can pop() safely + spec_dict = spec_dict.copy() + # pull out and resolve the stored type_hint + lib_name, cls_name = spec_dict.pop("type_hint") + if lib_name is not None and cls_name is not None: + type_hint = simple_get_class_obj(lib_name, cls_name) + else: + type_hint = None + + # re‐assemble the ComponentSpec + return ComponentSpec( + name=name, + type_hint=type_hint, + **spec_dict, + ) \ No newline at end of file diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py new file mode 100644 index 000000000000..c8064a5215aa --- /dev/null +++ b/src/diffusers/pipelines/modular_pipeline_utils.py @@ -0,0 +1,592 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import inspect +from dataclasses import dataclass, asdict, field, fields +from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal + +from ..utils.import_utils import is_torch_available +from ..configuration_utils import FrozenDict, ConfigMixin + +if is_torch_available(): + import torch + + +# YiYi TODO: +# 1. validate the dataclass fields +# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained() +@dataclass +class ComponentSpec: + """Specification for a pipeline component. + + A component can be created in two ways: + 1. From scratch using __init__ with a config dict + 2. using `from_pretrained` + + Attributes: + name: Name of the component + type_hint: Type of the component (e.g. UNet2DConditionModel) + description: Optional description of the component + config: Optional config dict for __init__ creation + repo: Optional repo path for from_pretrained creation + subfolder: Optional subfolder in repo + variant: Optional variant in repo + revision: Optional revision in repo + default_creation_method: Preferred creation method - "from_config" or "from_pretrained" + """ + name: Optional[str] = None + type_hint: Optional[Type] = None + description: Optional[str] = None + config: Optional[FrozenDict[str, Any]] = None + # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name + repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True}) + subfolder: Optional[str] = field(default=None, metadata={"loading": True}) + variant: Optional[str] = field(default=None, metadata={"loading": True}) + revision: Optional[str] = field(default=None, metadata={"loading": True}) + default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" + + + def __hash__(self): + """Make ComponentSpec hashable, using load_id as the hash value.""" + return hash((self.name, self.load_id, self.default_creation_method)) + + def __eq__(self, other): + """Compare ComponentSpec objects based on name and load_id.""" + if not isinstance(other, ComponentSpec): + return False + return (self.name == other.name and + self.load_id == other.load_id and + self.default_creation_method == other.default_creation_method) + + @classmethod + def from_component(cls, name: str, component: torch.nn.Module) -> Any: + """Create a ComponentSpec from a Component created by `create` method.""" + + if not hasattr(component, "_diffusers_load_id"): + raise ValueError("Component is not created by `create` method") + + type_hint = component.__class__ + + if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin): + config = component.config + else: + config = None + + load_spec = cls.decode_load_id(component._diffusers_load_id) + + return cls(name=name, type_hint=type_hint, config=config, **load_spec) + + @classmethod + def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any: + """Create a ComponentSpec from a load_id string.""" + if load_id == "null": + raise ValueError("Cannot create ComponentSpec from null load_id") + + # Decode the load_id into a dictionary of loading fields + load_fields = cls.decode_load_id(load_id) + + # Create a new ComponentSpec instance with the decoded fields + return cls(name=name, **load_fields) + + @classmethod + def loading_fields(cls) -> List[str]: + """ + Return the names of all loading‐related fields + (i.e. those whose field.metadata["loading"] is True). + """ + return [f.name for f in fields(cls) if f.metadata.get("loading", False)] + + + @property + def load_id(self) -> str: + """ + Unique identifier for this spec's pretrained load, + composed of repo|subfolder|variant|revision (no empty segments). + """ + parts = [getattr(self, k) for k in self.loading_fields()] + parts = ["null" if p is None else p for p in parts] + return "|".join(p for p in parts if p) + + @classmethod + def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: + """ + Decode a load_id string back into a dictionary of loading fields and values. + + Args: + load_id: The load_id string to decode, format: "repo|subfolder|variant|revision" + where None values are represented as "null" + + Returns: + Dict mapping loading field names to their values. e.g. + { + "repo": "path/to/repo", + "subfolder": "subfolder", + "variant": "variant", + "revision": "revision" + } + If a segment value is "null", it's replaced with None. + Returns None if load_id is "null" (indicating component not loaded from pretrained). + """ + + # Get all loading fields in order + loading_fields = cls.loading_fields() + result = {f: None for f in loading_fields} + + if load_id == "null": + return result + + # Split the load_id + parts = load_id.split("|") + + # Map parts to loading fields by position + for i, part in enumerate(parts): + if i < len(loading_fields): + # Convert "null" string back to None + result[loading_fields[i]] = None if part == "null" else part + + return result + + # YiYi TODO: add validator + def create(self, **kwargs) -> Any: + """Create the component using the preferred creation method.""" + + # from_pretrained creation + if self.default_creation_method == "from_pretrained": + return self.create_from_pretrained(**kwargs) + elif self.default_creation_method == "from_config": + # from_config creation + return self.create_from_config(**kwargs) + else: + raise ValueError(f"Invalid creation method: {self.default_creation_method}") + + def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: + """Create component using from_config with config.""" + + if self.type_hint is None or not isinstance(self.type_hint, type): + raise ValueError( + f"`type_hint` is required when using from_config creation method." + ) + + config = config or self.config or {} + + if issubclass(self.type_hint, ConfigMixin): + component = self.type_hint.from_config(config, **kwargs) + else: + signature_params = inspect.signature(self.type_hint.__init__).parameters + init_kwargs = {} + for k, v in config.items(): + if k in signature_params: + init_kwargs[k] = v + for k, v in kwargs.items(): + if k in signature_params: + init_kwargs[k] = v + component = self.type_hint(**init_kwargs) + + component._diffusers_load_id = "null" + if hasattr(component, "config"): + self.config = component.config + + return component + + # YiYi TODO: add guard for type of model, if it is supported by from_pretrained + def create_from_pretrained(self, **kwargs) -> Any: + """Create component using from_pretrained.""" + + passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} + load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()} + # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path + repo = load_kwargs.pop("repo", None) + if repo is None: + raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") + + if self.type_hint is None: + try: + from diffusers import AutoModel + component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs) + except Exception as e: + raise ValueError(f"Error creating {self.name} without `type_hint` from pretrained: {e}") + self.type_hint = component.__class__ + else: + try: + component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) + except Exception as e: + raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}") + + if repo != self.repo: + self.repo = repo + for k, v in passed_loading_kwargs.items(): + if v is not None: + setattr(self, k, v) + component._diffusers_load_id = self.load_id + + return component + + + +@dataclass +class ConfigSpec: + """Specification for a pipeline configuration parameter.""" + name: str + default: Any + description: Optional[str] = None +@dataclass +class InputParam: + """Specification for an input parameter.""" + name: str + type_hint: Any = None + default: Any = None + required: bool = False + description: str = "" + + def __repr__(self): + return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" + + +@dataclass +class OutputParam: + """Specification for an output parameter.""" + name: str + type_hint: Any = None + description: str = "" + + def __repr__(self): + return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" + + +def format_inputs_short(inputs): + """ + Format input parameters into a string representation, with required params first followed by optional ones. + + Args: + inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params + + Returns: + str: Formatted string of input parameters + + Example: + >>> inputs = [ + ... InputParam(name="prompt", required=True), + ... InputParam(name="image", required=True), + ... InputParam(name="guidance_scale", required=False, default=7.5), + ... InputParam(name="num_inference_steps", required=False, default=50) + ... ] + >>> format_inputs_short(inputs) + 'prompt, image, guidance_scale=7.5, num_inference_steps=50' + """ + required_inputs = [param for param in inputs if param.required] + optional_inputs = [param for param in inputs if not param.required] + + required_str = ", ".join(param.name for param in required_inputs) + optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) + + inputs_str = required_str + if optional_str: + inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str + + return inputs_str + + +def format_intermediates_short(intermediates_inputs, required_intermediates_inputs, intermediates_outputs): + """ + Formats intermediate inputs and outputs of a block into a string representation. + + Args: + intermediates_inputs: List of intermediate input parameters + required_intermediates_inputs: List of required intermediate input names + intermediates_outputs: List of intermediate output parameters + + Returns: + str: Formatted string like: + Intermediates: + - inputs: Required(latents), dtype + - modified: latents # variables that appear in both inputs and outputs + - outputs: images # new outputs only + """ + # Handle inputs + input_parts = [] + for inp in intermediates_inputs: + if inp.name in required_intermediates_inputs: + input_parts.append(f"Required({inp.name})") + else: + input_parts.append(inp.name) + + # Handle modified variables (appear in both inputs and outputs) + inputs_set = {inp.name for inp in intermediates_inputs} + modified_parts = [] + new_output_parts = [] + + for out in intermediates_outputs: + if out.name in inputs_set: + modified_parts.append(out.name) + else: + new_output_parts.append(out.name) + + result = [] + if input_parts: + result.append(f" - inputs: {', '.join(input_parts)}") + if modified_parts: + result.append(f" - modified: {', '.join(modified_parts)}") + if new_output_parts: + result.append(f" - outputs: {', '.join(new_output_parts)}") + + return "\n".join(result) if result else " (none)" + + +def format_params(params, header="Args", indent_level=4, max_line_length=115): + """Format a list of InputParam or OutputParam objects into a readable string representation. + + Args: + params: List of InputParam or OutputParam objects to format + header: Header text to use (e.g. "Args" or "Returns") + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all parameters + """ + if not params: + return "" + + base_indent = " " * indent_level + param_indent = " " * (indent_level + 4) + desc_indent = " " * (indent_level + 8) + formatted_params = [] + + def get_type_str(type_hint): + if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: + types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] + return f"Union[{', '.join(types)}]" + return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) + + def wrap_text(text, indent, max_length): + """Wrap text while preserving markdown links and maintaining indentation.""" + words = text.split() + lines = [] + current_line = [] + current_length = 0 + + for word in words: + word_length = len(word) + (1 if current_line else 0) + + if current_line and current_length + word_length > max_length: + lines.append(" ".join(current_line)) + current_line = [word] + current_length = len(word) + else: + current_line.append(word) + current_length += word_length + + if current_line: + lines.append(" ".join(current_line)) + + return f"\n{indent}".join(lines) + + # Add the header + formatted_params.append(f"{base_indent}{header}:") + + for param in params: + # Format parameter name and type + type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" + param_str = f"{param_indent}{param.name} (`{type_str}`" + + # Add optional tag and default value if parameter is an InputParam and optional + if hasattr(param, "required"): + if not param.required: + param_str += ", *optional*" + if param.default is not None: + param_str += f", defaults to {param.default}" + param_str += "):" + + # Add description on a new line with additional indentation and wrapping + if param.description: + desc = re.sub( + r'\[(.*?)\]\((https?://[^\s\)]+)\)', + r'[\1](\2)', + param.description + ) + wrapped_desc = wrap_text(desc, desc_indent, max_line_length) + param_str += f"\n{desc_indent}{wrapped_desc}" + + formatted_params.append(param_str) + + return "\n\n".join(formatted_params) + + +def format_input_params(input_params, indent_level=4, max_line_length=115): + """Format a list of InputParam objects into a readable string representation. + + Args: + input_params: List of InputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all input parameters + """ + return format_params(input_params, "Inputs", indent_level, max_line_length) + + +def format_output_params(output_params, indent_level=4, max_line_length=115): + """Format a list of OutputParam objects into a readable string representation. + + Args: + output_params: List of OutputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all output parameters + """ + return format_params(output_params, "Outputs", indent_level, max_line_length) + + +def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True): + """Format a list of ComponentSpec objects into a readable string representation. + + Args: + components: List of ComponentSpec objects to format + indent_level: Number of spaces to indent each component line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between components (default: True) + + Returns: + A formatted string representing all components + """ + if not components: + return "" + + base_indent = " " * indent_level + component_indent = " " * (indent_level + 4) + formatted_components = [] + + # Add the header + formatted_components.append(f"{base_indent}Components:") + if add_empty_lines: + formatted_components.append("") + + # Add each component with optional empty lines between them + for i, component in enumerate(components): + # Get type name, handling special cases + type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) + + component_desc = f"{component_indent}{component.name} (`{type_name}`)" + if component.description: + component_desc += f": {component.description}" + + # Get the loading fields dynamically + loading_field_values = [] + for field_name in component.loading_fields(): + field_value = getattr(component, field_name) + if field_value is not None: + loading_field_values.append(f"{field_name}={field_value}") + + # Add loading field information if available + if loading_field_values: + component_desc += f" [{', '.join(loading_field_values)}]" + + formatted_components.append(component_desc) + + # Add an empty line after each component except the last one + if add_empty_lines and i < len(components) - 1: + formatted_components.append("") + + return "\n".join(formatted_components) + + +def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True): + """Format a list of ConfigSpec objects into a readable string representation. + + Args: + configs: List of ConfigSpec objects to format + indent_level: Number of spaces to indent each config line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between configs (default: True) + + Returns: + A formatted string representing all configs + """ + if not configs: + return "" + + base_indent = " " * indent_level + config_indent = " " * (indent_level + 4) + formatted_configs = [] + + # Add the header + formatted_configs.append(f"{base_indent}Configs:") + if add_empty_lines: + formatted_configs.append("") + + # Add each config with optional empty lines between them + for i, config in enumerate(configs): + config_desc = f"{config_indent}{config.name} (default: {config.default})" + if config.description: + config_desc += f": {config.description}" + formatted_configs.append(config_desc) + + # Add an empty line after each config except the last one + if add_empty_lines and i < len(configs) - 1: + formatted_configs.append("") + + return "\n".join(formatted_configs) + + +def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None): + """ + Generates a formatted documentation string describing the pipeline block's parameters and structure. + + Args: + inputs: List of input parameters + intermediates_inputs: List of intermediate input parameters + outputs: List of output parameters + description (str, *optional*): Description of the block + class_name (str, *optional*): Name of the class to include in the documentation + expected_components (List[ComponentSpec], *optional*): List of expected components + expected_configs (List[ConfigSpec], *optional*): List of expected configurations + + Returns: + str: A formatted string containing information about components, configs, call parameters, + intermediate inputs/outputs, and final outputs. + """ + output = "" + + # Add class name if provided + if class_name: + output += f"class {class_name}\n\n" + + # Add description + if description: + desc_lines = description.strip().split('\n') + aligned_desc = '\n'.join(' ' + line for line in desc_lines) + output += aligned_desc + "\n\n" + + # Add components section if provided + if expected_components and len(expected_components) > 0: + components_str = format_components(expected_components, indent_level=2) + output += components_str + "\n\n" + + # Add configs section if provided + if expected_configs and len(expected_configs) > 0: + configs_str = format_configs(expected_configs, indent_level=2) + output += configs_str + "\n\n" + + # Add inputs section + output += format_input_params(inputs + intermediates_inputs, indent_level=2) + + # Add outputs section + output += "\n\n" + output += format_output_params(outputs, indent_level=2) + + return output \ No newline at end of file diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index a9d6c561af34..48d5992f31ee 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -333,6 +333,20 @@ def maybe_raise_or_warn( ) +# a simpler version of get_class_obj_and_candidates, it won't work with custom code +def simple_get_class_obj(library_name, class_name): + from diffusers import pipelines + is_pipeline_module = hasattr(pipelines, library_name) + + if is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + else: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + + return class_obj + def get_class_obj_and_candidates( library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None ): @@ -414,7 +428,7 @@ def _get_pipeline_class( revision=revision, ) - if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline": + if class_obj.__name__ != "DiffusionPipeline": return class_obj diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) @@ -841,7 +855,10 @@ def _fetch_class_library_tuple(module): library = not_compiled_module.__module__ # retrieve class_name - class_name = not_compiled_module.__class__.__name__ + if isinstance(not_compiled_module, type): + class_name = not_compiled_module.__name__ + else: + class_name = not_compiled_module.__class__.__name__ return (library, class_name) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index c27cd434cd9a..22b0baee2e39 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1917,9 +1917,10 @@ def from_pipe(cls, pipeline, **kwargs): f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs } + optional_components = pipeline._optional_components if hasattr(pipeline, "_optional_components") and pipeline._optional_components else [] missing_modules = ( set(expected_modules) - - set(pipeline._optional_components) + - set(optional_components) - set(pipeline_kwargs.keys()) - set(true_optional_modules) ) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py index 584b260eaaa8..006836fe30d4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -34,7 +34,7 @@ "StableDiffusionXLDecodeLatentsStep", "StableDiffusionXLDenoiseStep", "StableDiffusionXLInputStep", - "StableDiffusionXLModularPipeline", + "StableDiffusionXLModularLoader", "StableDiffusionXLPrepareAdditionalConditioningStep", "StableDiffusionXLPrepareLatentsStep", "StableDiffusionXLSetTimestepsStep", @@ -65,7 +65,7 @@ StableDiffusionXLDecodeLatentsStep, StableDiffusionXLDenoiseStep, StableDiffusionXLInputStep, - StableDiffusionXLModularPipeline, + StableDiffusionXLModularLoader, StableDiffusionXLPrepareAdditionalConditioningStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLSetTimestepsStep, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 2493d5635552..5ae9e63851db 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -34,7 +34,7 @@ from ..controlnet.multicontrolnet import MultiControlNetModel from ..modular_pipeline import ( AutoPipelineBlocks, - ModularPipeline, + ModularLoader, PipelineBlock, PipelineState, InputParam, @@ -56,8 +56,9 @@ CLIPVisionModelWithProjection, ) -from ...schedulers import KarrasDiffusionSchedulers -from ...guiders import GuiderType, ClassifierFreeGuidance +from ...schedulers import EulerDiscreteScheduler +from ...guiders import ClassifierFreeGuidance +from ...configuration_utils import FrozenDict import numpy as np @@ -182,9 +183,13 @@ def description(self) -> str: def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("image_encoder", CLIPVisionModelWithProjection), - ComponentSpec("feature_extractor", CLIPImageProcessor), + ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec("guider", GuiderType), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), ] @property @@ -320,7 +325,11 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), ComponentSpec("tokenizer", CLIPTokenizer), ComponentSpec("tokenizer_2", CLIPTokenizer), - ComponentSpec("guider", GuiderType), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), ] @property @@ -645,7 +654,11 @@ def description(self) -> str: def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), - ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor()), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), ] @property @@ -740,8 +753,16 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), - ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor()), - ComponentSpec("mask_processor", VaeImageProcessor, obj=VaeImageProcessor(do_normalize=False, do_binarize=True, do_convert_grayscale=True)), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), + ComponentSpec( + "mask_processor", + VaeImageProcessor, + config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}), + default_creation_method="from_config"), ] @@ -1028,7 +1049,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("scheduler", EulerDiscreteScheduler), ] @property @@ -1151,7 +1172,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("scheduler", EulerDiscreteScheduler), ] @property @@ -1206,7 +1227,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("scheduler", EulerDiscreteScheduler), ] @property @@ -1460,7 +1481,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("scheduler", EulerDiscreteScheduler), ] @property @@ -1608,7 +1629,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("scheduler", EulerDiscreteScheduler), ] @property @@ -1727,7 +1748,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): @property def expected_configs(self) -> List[ConfigSpec]: - return [ConfigSpec("requires_aesthetics_score", default=False),] + return [ConfigSpec("requires_aesthetics_score", False),] @property def description(self) -> str: @@ -2062,8 +2083,12 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), ComponentSpec("unet", UNet2DConditionModel), ] @@ -2245,7 +2270,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), ) - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: + with self.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) guider_data = pipeline.guider.prepare_inputs(data) @@ -2316,11 +2341,15 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), - ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), + ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), ] @property @@ -2626,7 +2655,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ) # (5) Denoise loop - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: + with self.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) guider_data = pipeline.guider.prepare_inputs(data) @@ -2733,9 +2762,17 @@ def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetUnionModel), - ComponentSpec("scheduler", KarrasDiffusionSchedulers), - ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), - ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec( + "control_image_processor", + VaeImageProcessor, + config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), + default_creation_method="from_config"), ] @property @@ -3029,7 +3066,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), ) - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: + with self.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) guider_data = pipeline.guider.prepare_inputs(data) @@ -3136,7 +3173,11 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), - ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor()) + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), ] @property @@ -3527,9 +3568,14 @@ def description(self): } -# YiYi TODO: rename to components etc. and not inherit from ModularPipeline -class StableDiffusionXLModularPipeline( - ModularPipeline, +# YiYi Notes: model specific components: +## (1) it should inherit from ModularLoader +## (2) acts like a container that holds components and configs +## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents +## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) +## (5) how to use together with Components_manager? +class StableDiffusionXLModularLoader( + ModularLoader, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index bea14cfe9c8d..f3837e39f192 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1328,7 +1328,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class ModularPipeline(metaclass=DummyObject): +class ModularLoader(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 0a2c1eefae12..cbfbb842723a 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2417,7 +2417,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class StableDiffusionXLModularPipeline(metaclass=DummyObject): +class StableDiffusionXLModularLoader(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): From c8b5d5641271f88dc9c0ab41ca48e39ef143df3f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 2 May 2025 00:46:31 +0200 Subject: [PATCH 07/54] make loader optional --- src/diffusers/pipelines/modular_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 636b543395df..c994b91ba8bb 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -199,7 +199,8 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, state = PipelineState() if not hasattr(self, "loader"): - raise ValueError("Loader is not set, please call `setup_loader()` first.") + logger.warning("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.") + self.loader = None # Make a copy of the input kwargs input_params = kwargs.copy() From 7b86fcea31d7c968e774dd16c275f601c2bed0fb Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 2 May 2025 11:31:25 +0200 Subject: [PATCH 08/54] remove lora step and ip-adapter step -> no longer needed --- .../pipeline_stable_diffusion_xl_modular.py | 168 ------------------ 1 file changed, 168 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 5ae9e63851db..0d068f90f7e6 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -140,174 +140,6 @@ def retrieve_latents( -# YiYi Notes: I think we do not need this, we can add loader methods on the components class -class StableDiffusionXLLoraStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Lora step that handles all the lora related tasks: load/unload lora weights into unet and text encoders, manage lora adapters etc" - " See [StableDiffusionXLLoraLoaderMixin](https://huggingface.co/docs/diffusers/api/loaders/lora#diffusers.loaders.StableDiffusionXLLoraLoaderMixin)" - " for more details" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("text_encoder", CLIPTextModel), - ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), - ComponentSpec("unet", UNet2DConditionModel), - ] - - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - raise EnvironmentError("StableDiffusionXLLoraStep is desgined to be used to load lora weights, __call__ is not implemented") - - -class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin): - model_name = "stable-diffusion-xl" - - - @property - def description(self) -> str: - return ( - "IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc" - " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" - " for more details" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("image_encoder", CLIPVisionModelWithProjection), - ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), - ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ] - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "ip_adapter_image", - PipelineImageInput, - required=True, - description="The image(s) to be used as ip adapter" - ) - ] - - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), - OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") - ] - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components - def encode_image(self, components, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(components.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = components.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = components.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = components.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - - # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds - def prepare_ip_adapter_image_embeds( - self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds - ): - image_embeds = [] - if prepare_unconditional_embeds: - negative_image_embeds = [] - if ip_adapter_image_embeds is None: - if not isinstance(ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] - - if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): - raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." - ) - - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers - ): - output_hidden_state = not isinstance(image_proj_layer, ImageProjection) - single_image_embeds, single_negative_image_embeds = self.encode_image( - components, single_ip_adapter_image, device, 1, output_hidden_state - ) - - image_embeds.append(single_image_embeds[None, :]) - if prepare_unconditional_embeds: - negative_image_embeds.append(single_negative_image_embeds[None, :]) - else: - for single_image_embeds in ip_adapter_image_embeds: - if prepare_unconditional_embeds: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) - image_embeds.append(single_image_embeds) - - ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): - single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if prepare_unconditional_embeds: - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - - single_image_embeds = single_image_embeds.to(device=device) - ip_adapter_image_embeds.append(single_image_embeds) - - return ip_adapter_image_embeds - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.prepare_unconditional_embeds = pipeline.guider.num_conditions > 1 - data.device = pipeline._execution_device - - data.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( - pipeline, - ip_adapter_image=data.ip_adapter_image, - ip_adapter_image_embeds=None, - device=data.device, - num_images_per_prompt=1, - prepare_unconditional_embeds=data.prepare_unconditional_embeds, - ) - if data.prepare_unconditional_embeds: - data.negative_ip_adapter_embeds = [] - for i, image_embeds in enumerate(data.ip_adapter_embeds): - negative_image_embeds, image_embeds = image_embeds.chunk(2) - data.negative_ip_adapter_embeds.append(negative_image_embeds) - data.ip_adapter_embeds[i] = image_embeds - - self.add_block_state(state, data) - return pipeline, state - - class StableDiffusionXLTextEncoderStep(PipelineBlock): model_name = "stable-diffusion-xl" From 7ca860c24bc35fccf5a68db2f92af932819f0b24 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 3 May 2025 01:32:59 +0200 Subject: [PATCH 09/54] rename pipeline -> components, data -> block_state --- .../pipeline_stable_diffusion_xl_modular.py | 1554 +++++++++-------- 1 file changed, 872 insertions(+), 682 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 0d068f90f7e6..81808540ee67 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -65,6 +65,51 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder? +# YiYi Notes: model specific components: +## (1) it should inherit from ModularLoader +## (2) acts like a container that holds components and configs +## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents +## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) +## (5) how to use together with Components_manager? +class StableDiffusionXLModularLoader( + ModularLoader, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + ModularIPAdapterMixin, +): + @property + def default_sample_size(self): + default_sample_size = 128 + if hasattr(self, "unet") and self.unet is not None: + default_sample_size = self.unet.config.sample_size + return default_sample_size + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_unet(self): + num_channels_unet = 4 + if hasattr(self, "unet") and self.unet is not None: + num_channels_unet = self.unet.config.in_channels + return num_channels_unet + + @property + def num_channels_latents(self): + num_channels_latents = 4 + if hasattr(self, "vae") and self.vae is not None: + num_channels_latents = self.vae.config.latent_channels + return num_channels_latents + + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -140,6 +185,148 @@ def retrieve_latents( +class StableDiffusionXLIPAdapterStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + + @property + def description(self) -> str: + return ( + "IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc" + " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" + " for more details" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("image_encoder", CLIPVisionModelWithProjection), + ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "ip_adapter_image", + PipelineImageInput, + required=True, + description="The image(s) to be used as ip adapter" + ) + ] + + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), + OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") + ] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components + @staticmethod + def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(components.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = components.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = components.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = components.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds + ): + image_embeds = [] + if prepare_unconditional_embeds: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + components, single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if prepare_unconditional_embeds: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if prepare_unconditional_embeds: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if prepare_unconditional_embeds: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( + components, + ip_adapter_image=block_state.ip_adapter_image, + ip_adapter_image_embeds=None, + device=block_state.device, + num_images_per_prompt=1, + prepare_unconditional_embeds=block_state.prepare_unconditional_embeds, + ) + if block_state.prepare_unconditional_embeds: + block_state.negative_ip_adapter_embeds = [] + for i, image_embeds in enumerate(block_state.ip_adapter_embeds): + negative_image_embeds, image_embeds = image_embeds.chunk(2) + block_state.negative_ip_adapter_embeds.append(negative_image_embeds) + block_state.ip_adapter_embeds[i] = image_embeds + + self.add_block_state(state, block_state) + return components, state + + class StableDiffusionXLTextEncoderStep(PipelineBlock): model_name = "stable-diffusion-xl" @@ -189,15 +376,16 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="negative pooled text embeddings used to guide the image generation"), ] - def check_inputs(self, pipeline, data): + @staticmethod + def check_inputs(block_state): - if data.prompt is not None and (not isinstance(data.prompt, str) and not isinstance(data.prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(data.prompt)}") - elif data.prompt_2 is not None and (not isinstance(data.prompt_2, str) and not isinstance(data.prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(data.prompt_2)}") + if block_state.prompt is not None and (not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + elif block_state.prompt_2 is not None and (not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}") + @staticmethod def encode_prompt( - self, components, prompt: str, prompt_2: Optional[str] = None, @@ -255,7 +443,7 @@ def encode_prompt( Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ - device = device or self._execution_device + device = device or components._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it @@ -433,42 +621,42 @@ def encode_prompt( @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: # Get inputs and intermediates - data = self.get_block_state(state) - self.check_inputs(pipeline, data) + block_state = self.get_block_state(state) + self.check_inputs(block_state) - data.prepare_unconditional_embeds = pipeline.guider.num_conditions > 1 - data.device = pipeline._execution_device + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device # Encode input prompt - data.text_encoder_lora_scale = ( - data.cross_attention_kwargs.get("scale", None) if data.cross_attention_kwargs is not None else None + block_state.text_encoder_lora_scale = ( + block_state.cross_attention_kwargs.get("scale", None) if block_state.cross_attention_kwargs is not None else None ) ( - data.prompt_embeds, - data.negative_prompt_embeds, - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, + block_state.prompt_embeds, + block_state.negative_prompt_embeds, + block_state.pooled_prompt_embeds, + block_state.negative_pooled_prompt_embeds, ) = self.encode_prompt( - pipeline, - data.prompt, - data.prompt_2, - data.device, + components, + block_state.prompt, + block_state.prompt_2, + block_state.device, 1, - data.prepare_unconditional_embeds, - data.negative_prompt, - data.negative_prompt_2, + block_state.prepare_unconditional_embeds, + block_state.negative_prompt, + block_state.negative_prompt_2, prompt_embeds=None, negative_prompt_embeds=None, pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None, - lora_scale=data.text_encoder_lora_scale, - clip_skip=data.clip_skip, + lora_scale=block_state.text_encoder_lora_scale, + clip_skip=block_state.clip_skip, ) # Add outputs - self.add_block_state(state, data) - return pipeline, state + self.add_block_state(state, block_state) + return components, state class StableDiffusionXLVaeEncoderStep(PipelineBlock): @@ -552,30 +740,30 @@ def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Ge @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.preprocess_kwargs = data.preprocess_kwargs or {} - data.device = pipeline._execution_device - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} + block_state.device = components._execution_device + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, **data.preprocess_kwargs) - data.image = data.image.to(device=data.device, dtype=data.dtype) + block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs) + block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) - data.batch_size = data.image.shape[0] + block_state.batch_size = block_state.image.shape[0] # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) - if isinstance(data.generator, list) and len(data.generator) != data.batch_size: + if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: raise ValueError( - f"You have passed a list of generators of length {len(data.generator)}, but requested an effective batch" - f" size of {data.batch_size}. Make sure the batch size matches the length of the generators." + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." ) - data.image_latents = self._encode_vae_image(pipeline,image=data.image, generator=data.generator) + block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): @@ -715,47 +903,47 @@ def prepare_mask_latents( @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) + block_state = self.get_block_state(state) - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device - if data.padding_mask_crop is not None: - data.crops_coords = pipeline.mask_processor.get_crop_region(data.mask_image, data.width, data.height, pad=data.padding_mask_crop) - data.resize_mode = "fill" + if block_state.padding_mask_crop is not None: + block_state.crops_coords = components.mask_processor.get_crop_region(block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop) + block_state.resize_mode = "fill" else: - data.crops_coords = None - data.resize_mode = "default" + block_state.crops_coords = None + block_state.resize_mode = "default" - data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, crops_coords=data.crops_coords, resize_mode=data.resize_mode) - data.image = data.image.to(dtype=torch.float32) + block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode) + block_state.image = block_state.image.to(dtype=torch.float32) - data.mask = pipeline.mask_processor.preprocess(data.mask_image, height=data.height, width=data.width, resize_mode=data.resize_mode, crops_coords=data.crops_coords) - data.masked_image = data.image * (data.mask < 0.5) + block_state.mask = components.mask_processor.preprocess(block_state.mask_image, height=block_state.height, width=block_state.width, resize_mode=block_state.resize_mode, crops_coords=block_state.crops_coords) + block_state.masked_image = block_state.image * (block_state.mask < 0.5) - data.batch_size = data.image.shape[0] - data.image = data.image.to(device=data.device, dtype=data.dtype) - data.image_latents = self._encode_vae_image(pipeline, image=data.image, generator=data.generator) + block_state.batch_size = block_state.image.shape[0] + block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) + block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) # 7. Prepare mask latent variables - data.mask, data.masked_image_latents = self.prepare_mask_latents( - pipeline, - data.mask, - data.masked_image, - data.batch_size, - data.height, - data.width, - data.dtype, - data.device, - data.generator, + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + block_state.mask, + block_state.masked_image, + block_state.batch_size, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, ) - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLInputStep(PipelineBlock): @@ -802,77 +990,77 @@ def intermediates_outputs(self) -> List[str]: OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="negative image embeddings for IP-Adapter"), ] - def check_inputs(self, pipeline, data): + def check_inputs(self, components, block_state): - if data.prompt_embeds is not None and data.negative_prompt_embeds is not None: - if data.prompt_embeds.shape != data.negative_prompt_embeds.shape: + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {data.prompt_embeds.shape} != `negative_prompt_embeds`" - f" {data.negative_prompt_embeds.shape}." + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" + f" {block_state.negative_prompt_embeds.shape}." ) - if data.prompt_embeds is not None and data.pooled_prompt_embeds is None: + if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) - if data.negative_prompt_embeds is not None and data.negative_pooled_prompt_embeds is None: + if block_state.negative_prompt_embeds is not None and block_state.negative_pooled_prompt_embeds is None: raise ValueError( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) - if data.ip_adapter_embeds is not None and not isinstance(data.ip_adapter_embeds, list): + if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list): raise ValueError("`ip_adapter_embeds` must be a list") - if data.negative_ip_adapter_embeds is not None and not isinstance(data.negative_ip_adapter_embeds, list): + if block_state.negative_ip_adapter_embeds is not None and not isinstance(block_state.negative_ip_adapter_embeds, list): raise ValueError("`negative_ip_adapter_embeds` must be a list") - if data.ip_adapter_embeds is not None and data.negative_ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): - if ip_adapter_embed.shape != data.negative_ip_adapter_embeds[i].shape: + if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape: raise ValueError( "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" - f" {data.negative_ip_adapter_embeds[i].shape}." + f" {block_state.negative_ip_adapter_embeds[i].shape}." ) @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - self.check_inputs(pipeline, data) + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) - data.batch_size = data.prompt_embeds.shape[0] - data.dtype = data.prompt_embeds.dtype + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype - _, seq_len, _ = data.prompt_embeds.shape + _, seq_len, _ = block_state.prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - data.prompt_embeds = data.prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.prompt_embeds = data.prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) - if data.negative_prompt_embeds is not None: - _, seq_len, _ = data.negative_prompt_embeds.shape - data.negative_prompt_embeds = data.negative_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.negative_prompt_embeds = data.negative_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) - data.pooled_prompt_embeds = data.pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.pooled_prompt_embeds = data.pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - if data.negative_pooled_prompt_embeds is not None: - data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) + if block_state.negative_pooled_prompt_embeds is not None: + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - if data.ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): - data.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * data.num_images_per_prompt, dim=0) + if block_state.ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + block_state.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) - if data.negative_ip_adapter_embeds is not None: - for i, negative_ip_adapter_embed in enumerate(data.negative_ip_adapter_embeds): - data.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * data.num_images_per_prompt, dim=0) + if block_state.negative_ip_adapter_embeds is not None: + for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds): + block_state.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): @@ -961,40 +1149,40 @@ def get_timesteps(self, components, num_inference_steps, strength, device, denoi @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) - data.device = pipeline._execution_device + block_state.device = components._execution_device - data.timesteps, data.num_inference_steps = retrieve_timesteps( - pipeline.scheduler, data.num_inference_steps, data.device, data.timesteps, data.sigmas + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas ) def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 - data.timesteps, data.num_inference_steps = self.get_timesteps( - pipeline, - data.num_inference_steps, - data.strength, - data.device, - denoising_start=data.denoising_start if denoising_value_valid(data.denoising_start) else None, + block_state.timesteps, block_state.num_inference_steps = self.get_timesteps( + components, + block_state.num_inference_steps, + block_state.strength, + block_state.device, + denoising_start=block_state.denoising_start if denoising_value_valid(block_state.denoising_start) else None, ) - data.latent_timestep = data.timesteps[:1].repeat(data.batch_size * data.num_images_per_prompt) + block_state.latent_timestep = block_state.timesteps[:1].repeat(block_state.batch_size * block_state.num_images_per_prompt) - if data.denoising_end is not None and isinstance(data.denoising_end, float) and data.denoising_end > 0 and data.denoising_end < 1: - data.discrete_timestep_cutoff = int( + if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: + block_state.discrete_timestep_cutoff = int( round( - pipeline.scheduler.config.num_train_timesteps - - (data.denoising_end * pipeline.scheduler.config.num_train_timesteps) + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) ) ) - data.num_inference_steps = len(list(filter(lambda ts: ts >= data.discrete_timestep_cutoff, data.timesteps))) - data.timesteps = data.timesteps[:data.num_inference_steps] + block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) + block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLSetTimestepsStep(PipelineBlock): @@ -1029,27 +1217,27 @@ def intermediates_outputs(self) -> List[OutputParam]: @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) - data.device = pipeline._execution_device + block_state.device = components._execution_device - data.timesteps, data.num_inference_steps = retrieve_timesteps( - pipeline.scheduler, data.num_inference_steps, data.device, data.timesteps, data.sigmas + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas ) - if data.denoising_end is not None and isinstance(data.denoising_end, float) and data.denoising_end > 0 and data.denoising_end < 1: - data.discrete_timestep_cutoff = int( + if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: + block_state.discrete_timestep_cutoff = int( round( - pipeline.scheduler.config.num_train_timesteps - - (data.denoising_end * pipeline.scheduler.config.num_train_timesteps) + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) ) ) - data.num_inference_steps = len(list(filter(lambda ts: ts >= data.discrete_timestep_cutoff, data.timesteps))) - data.timesteps = data.timesteps[:data.num_inference_steps] + block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) + block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] - self.add_block_state(state, data) - return pipeline, state + self.add_block_state(state, block_state) + return components, state class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): @@ -1133,7 +1321,46 @@ def intermediates_outputs(self) -> List[str]: OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents with self -> components + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + @staticmethod + def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument def prepare_latents_inpaint( self, components, @@ -1252,58 +1479,58 @@ def prepare_mask_latents( @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device - data.is_strength_max = data.strength == 1.0 + block_state.is_strength_max = block_state.strength == 1.0 # for non-inpainting specific unet, we do not need masked_image_latents - if hasattr(pipeline,"unet") and pipeline.unet is not None: - if pipeline.unet.config.in_channels == 4: - data.masked_image_latents = None - - data.add_noise = True if data.denoising_start is None else False - - data.height = data.image_latents.shape[-2] * pipeline.vae_scale_factor - data.width = data.image_latents.shape[-1] * pipeline.vae_scale_factor - - data.latents, data.noise = self.prepare_latents_inpaint( - pipeline, - data.batch_size * data.num_images_per_prompt, - pipeline.num_channels_latents, - data.height, - data.width, - data.dtype, - data.device, - data.generator, - data.latents, - image=data.image_latents, - timestep=data.latent_timestep, - is_strength_max=data.is_strength_max, - add_noise=data.add_noise, + if hasattr(components,"unet") and components.unet is not None: + if components.unet.config.in_channels == 4: + block_state.masked_image_latents = None + + block_state.add_noise = True if block_state.denoising_start is None else False + + block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor + block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor + + block_state.latents, block_state.noise = self.prepare_latents_inpaint( + components, + block_state.batch_size * block_state.num_images_per_prompt, + components.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + image=block_state.image_latents, + timestep=block_state.latent_timestep, + is_strength_max=block_state.is_strength_max, + add_noise=block_state.add_noise, return_noise=True, return_image_latents=False, ) # 7. Prepare mask latent variables - data.mask, data.masked_image_latents = self.prepare_mask_latents( - pipeline, - data.mask, - data.masked_image_latents, - data.batch_size * data.num_images_per_prompt, - data.height, - data.width, - data.dtype, - data.device, - data.generator, + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + block_state.mask, + block_state.masked_image_latents, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, ) - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): @@ -1343,21 +1570,17 @@ def intermediates_inputs(self) -> List[InputParam]: def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self -> components + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self -> components # YiYi TODO: refactor using _encode_vae_image + @staticmethod def prepare_latents_img2img( - self, components, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + components, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True ): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - # Offload text encoder if `enable_model_cpu_offload` was enabled - if hasattr(components, "final_offload_hook") and components.final_offload_hook is not None: - components.text_encoder_2.to("cpu") - torch.cuda.empty_cache() - image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt @@ -1431,28 +1654,28 @@ def prepare_latents_img2img( return latents @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device - data.add_noise = True if data.denoising_start is None else False - if data.latents is None: - data.latents = self.prepare_latents_img2img( - pipeline, - data.image_latents, - data.latent_timestep, - data.batch_size, - data.num_images_per_prompt, - data.dtype, - data.device, - data.generator, - data.add_noise, + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + block_state.add_noise = True if block_state.denoising_start is None else False + if block_state.latents is None: + block_state.latents = self.prepare_latents_img2img( + components, + block_state.image_latents, + block_state.latent_timestep, + block_state.batch_size, + block_state.num_images_per_prompt, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.add_noise, ) - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLPrepareLatentsStep(PipelineBlock): @@ -1508,19 +1731,20 @@ def intermediates_outputs(self) -> List[OutputParam]: @staticmethod - def check_inputs(pipeline, data): + def check_inputs(components, block_state): if ( - data.height is not None - and data.height % pipeline.vae_scale_factor != 0 - or data.width is not None - and data.width % pipeline.vae_scale_factor != 0 + block_state.height is not None + and block_state.height % components.vae_scale_factor != 0 + or block_state.width is not None + and block_state.width % components.vae_scale_factor != 0 ): raise ValueError( - f"`height` and `width` have to be divisible by {pipeline.vae_scale_factor} but are {data.height} and {data.width}." + f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self -> components - def prepare_latents(self, components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + @staticmethod + def prepare_latents(components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = ( batch_size, num_channels_latents, @@ -1544,34 +1768,34 @@ def prepare_latents(self, components, batch_size, num_channels_latents, height, @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - if data.dtype is None: - data.dtype = pipeline.vae.dtype - - data.device = pipeline._execution_device - - self.check_inputs(pipeline, data) - - data.height = data.height or pipeline.default_sample_size * pipeline.vae_scale_factor - data.width = data.width or pipeline.default_sample_size * pipeline.vae_scale_factor - data.num_channels_latents = pipeline.num_channels_latents - data.latents = self.prepare_latents( - pipeline, - data.batch_size * data.num_images_per_prompt, - data.num_channels_latents, - data.height, - data.width, - data.dtype, - data.device, - data.generator, - data.latents, + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.dtype is None: + block_state.dtype = components.vae.dtype + + block_state.device = components._execution_device + + self.check_inputs(components, block_state) + + block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor + block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor + block_state.num_channels_latents = components.num_channels_latents + block_state.latents = self.prepare_latents( + components, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, ) - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): @@ -1617,8 +1841,8 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components + @staticmethod def _get_add_time_ids_img2img( - self, components, original_size, crops_coords_top_left, @@ -1670,8 +1894,9 @@ def _get_add_time_ids_img2img( return add_time_ids, add_neg_time_ids # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + @staticmethod def get_guidance_scale_embedding( - self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 ) -> torch.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 @@ -1701,57 +1926,57 @@ def get_guidance_scale_embedding( return emb @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.device = pipeline._execution_device - - data.vae_scale_factor = pipeline.vae_scale_factor - - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * data.vae_scale_factor - data.width = data.width * data.vae_scale_factor - - data.original_size = data.original_size or (data.height, data.width) - data.target_size = data.target_size or (data.height, data.width) - - data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1]) - - if data.negative_original_size is None: - data.negative_original_size = data.original_size - if data.negative_target_size is None: - data.negative_target_size = data.target_size - - data.add_time_ids, data.negative_add_time_ids = self._get_add_time_ids_img2img( - pipeline, - data.original_size, - data.crops_coords_top_left, - data.target_size, - data.aesthetic_score, - data.negative_aesthetic_score, - data.negative_original_size, - data.negative_crops_coords_top_left, - data.negative_target_size, - dtype=data.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=data.text_encoder_projection_dim, + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.vae_scale_factor = components.vae_scale_factor + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * block_state.vae_scale_factor + block_state.width = block_state.width * block_state.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + if block_state.negative_original_size is None: + block_state.negative_original_size = block_state.original_size + if block_state.negative_target_size is None: + block_state.negative_target_size = block_state.target_size + + block_state.add_time_ids, block_state.negative_add_time_ids = self._get_add_time_ids_img2img( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.aesthetic_score, + block_state.negative_aesthetic_score, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + dtype=block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, ) - data.add_time_ids = data.add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) - data.negative_add_time_ids = data.negative_add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) + block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) # Optionally get Guidance Scale Embedding for LCM - data.timestep_cond = None + block_state.timestep_cond = None if ( - hasattr(pipeline, "unet") - and pipeline.unet is not None - and pipeline.unet.config.time_cond_proj_dim is not None + hasattr(components, "unet") + and components.unet is not None + and components.unet.config.time_cond_proj_dim is not None ): # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! - data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) - data.timestep_cond = self.get_guidance_scale_embedding( - data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim - ).to(device=data.device, dtype=data.latents.dtype) + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) - self.add_block_state(state, data) - return pipeline, state + self.add_block_state(state, block_state) + return components, state class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): @@ -1805,8 +2030,9 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components + @staticmethod def _get_add_time_ids( - self, components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None ): add_time_ids = list(original_size + crops_coords_top_left + target_size) @@ -1824,8 +2050,9 @@ def _get_add_time_ids( return add_time_ids # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + @staticmethod def get_guidance_scale_embedding( - self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 ) -> torch.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 @@ -1855,57 +2082,57 @@ def get_guidance_scale_embedding( return emb @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.device = pipeline._execution_device - - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * pipeline.vae_scale_factor - data.width = data.width * pipeline.vae_scale_factor - - data.original_size = data.original_size or (data.height, data.width) - data.target_size = data.target_size or (data.height, data.width) - - data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1]) - - data.add_time_ids = self._get_add_time_ids( - pipeline, - data.original_size, - data.crops_coords_top_left, - data.target_size, - data.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=data.text_encoder_projection_dim, + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + block_state.add_time_ids = self._get_add_time_ids( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, ) - if data.negative_original_size is not None and data.negative_target_size is not None: - data.negative_add_time_ids = self._get_add_time_ids( - pipeline, - data.negative_original_size, - data.negative_crops_coords_top_left, - data.negative_target_size, - data.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=data.text_encoder_projection_dim, + if block_state.negative_original_size is not None and block_state.negative_target_size is not None: + block_state.negative_add_time_ids = self._get_add_time_ids( + components, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, ) else: - data.negative_add_time_ids = data.add_time_ids + block_state.negative_add_time_ids = block_state.add_time_ids - data.add_time_ids = data.add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) - data.negative_add_time_ids = data.negative_add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) + block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) # Optionally get Guidance Scale Embedding for LCM - data.timestep_cond = None + block_state.timestep_cond = None if ( - hasattr(pipeline, "unet") - and pipeline.unet is not None - and pipeline.unet.config.time_cond_proj_dim is not None + hasattr(components, "unet") + and components.unet is not None + and components.unet.config.time_cond_proj_dim is not None ): # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! - data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) - data.timestep_cond = self.get_guidance_scale_embedding( - data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim - ).to(device=data.device, dtype=data.latents.dtype) + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) - self.add_block_state(state, data) - return pipeline, state + self.add_block_state(state, block_state) + return components, state class StableDiffusionXLDenoiseStep(PipelineBlock): @@ -2041,27 +2268,29 @@ def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - def check_inputs(self, pipeline, data): + @staticmethod + def check_inputs(components, block_state): - num_channels_unet = pipeline.unet.config.in_channels + num_channels_unet = components.unet.config.in_channels if num_channels_unet == 9: # default case for runwayml/stable-diffusion-inpainting - if data.mask is None or data.masked_image_latents is None: + if block_state.mask is None or block_state.masked_image_latents is None: raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = data.latents.shape[1] - num_channels_mask = data.mask.shape[1] - num_channels_masked_image = data.masked_image_latents.shape[1] + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" - f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." + " `components.unet` or your `mask_image` or `image` input." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components - def prepare_extra_step_kwargs(self, components, generator, eta): + @staticmethod + def prepare_extra_step_kwargs(components, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 @@ -2079,42 +2308,42 @@ def prepare_extra_step_kwargs(self, components, generator, eta): return extra_step_kwargs @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - self.check_inputs(pipeline, data) + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) - data.num_channels_unet = pipeline.unet.config.in_channels - data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - if data.disable_guidance: - pipeline.guider.disable() + block_state.num_channels_unet = components.unet.config.in_channels + block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False + if block_state.disable_guidance: + components.guider.disable() else: - pipeline.guider.enable() + components.guider.enable() # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) - data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) + block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - pipeline.guider.set_input_fields( + components.guider.set_input_fields( prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), add_time_ids=("add_time_ids", "negative_add_time_ids"), pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), ) - with self.progress_bar(total=data.num_inference_steps) as progress_bar: - for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) - guider_data = pipeline.guider.prepare_inputs(data) + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_data = components.guider.prepare_inputs(block_state) - data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t) + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) # Prepare for inpainting - if data.num_channels_unet == 9: - data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1) + if block_state.num_channels_unet == 9: + block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) for batch in guider_data: - pipeline.guider.prepare_models(pipeline.unet) + components.guider.prepare_models(components.unet) # Prepare additional conditionings batch.added_cond_kwargs = { @@ -2125,45 +2354,45 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds # Predict the noise residual - batch.noise_pred = pipeline.unet( - data.scaled_latents, + batch.noise_pred = components.unet( + block_state.scaled_latents, t, encoder_hidden_states=batch.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, added_cond_kwargs=batch.added_cond_kwargs, return_dict=False, )[0] - pipeline.guider.cleanup_models(pipeline.unet) + components.guider.cleanup_models(components.unet) # Perform guidance - data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) + block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) # Perform scheduler step using the predicted output - data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - if data.latents.dtype != data.latents_dtype: + if block_state.latents.dtype != block_state.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - data.latents = data.latents.to(data.latents_dtype) + block_state.latents = block_state.latents.to(block_state.latents_dtype) - if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: - data.init_latents_proper = data.image_latents - if i < len(data.timesteps) - 1: - data.noise_timestep = data.timesteps[i + 1] - data.init_latents_proper = pipeline.scheduler.add_noise( - data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) + if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.add_noise( + block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) ) - data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents + block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): progress_bar.update() - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): @@ -2308,30 +2537,31 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - def check_inputs(self, pipeline, data): + @staticmethod + def check_inputs(components, block_state): - num_channels_unet = pipeline.unet.config.in_channels + num_channels_unet = components.unet.config.in_channels if num_channels_unet == 9: # default case for runwayml/stable-diffusion-inpainting - if data.mask is None or data.masked_image_latents is None: + if block_state.mask is None or block_state.masked_image_latents is None: raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = data.latents.shape[1] - num_channels_mask = data.mask.shape[1] - num_channels_masked_image = data.masked_image_latents.shape[1] + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" - f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." + " `components.unet` or your `mask_image` or `image` input." ) # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image # 1. return image without apply any guidance # 2. add crops_coords and resize_mode to preprocess() + @staticmethod def prepare_control_image( - self, components, image, width, @@ -2359,7 +2589,8 @@ def prepare_control_image( return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components - def prepare_extra_step_kwargs(self, components, generator, eta): + @staticmethod + def prepare_extra_step_kwargs(components, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 @@ -2378,108 +2609,108 @@ def prepare_extra_step_kwargs(self, components, generator, eta): @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - self.check_inputs(pipeline, data) + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) - data.num_channels_unet = pipeline.unet.config.in_channels + block_state.num_channels_unet = components.unet.config.in_channels # (1) prepare controlnet inputs - data.device = pipeline._execution_device - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * pipeline.vae_scale_factor - data.width = data.width * pipeline.vae_scale_factor + block_state.device = components._execution_device + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor - controlnet = unwrap_module(pipeline.controlnet) + controlnet = unwrap_module(components.controlnet) # (1.1) # control_guidance_start/control_guidance_end (align format) - if not isinstance(data.control_guidance_start, list) and isinstance(data.control_guidance_end, list): - data.control_guidance_start = len(data.control_guidance_end) * [data.control_guidance_start] - elif not isinstance(data.control_guidance_end, list) and isinstance(data.control_guidance_start, list): - data.control_guidance_end = len(data.control_guidance_start) * [data.control_guidance_end] - elif not isinstance(data.control_guidance_start, list) and not isinstance(data.control_guidance_end, list): + if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] + elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] + elif not isinstance(block_state.control_guidance_start, list) and not isinstance(block_state.control_guidance_end, list): mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - data.control_guidance_start, data.control_guidance_end = ( - mult * [data.control_guidance_start], - mult * [data.control_guidance_end], + block_state.control_guidance_start, block_state.control_guidance_end = ( + mult * [block_state.control_guidance_start], + mult * [block_state.control_guidance_end], ) # (1.2) # controlnet_conditioning_scale (align format) - if isinstance(controlnet, MultiControlNetModel) and isinstance(data.controlnet_conditioning_scale, float): - data.controlnet_conditioning_scale = [data.controlnet_conditioning_scale] * len(controlnet.nets) + if isinstance(controlnet, MultiControlNetModel) and isinstance(block_state.controlnet_conditioning_scale, float): + block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets) # (1.3) # global_pool_conditions - data.global_pool_conditions = ( + block_state.global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) # (1.4) # guess_mode - data.guess_mode = data.guess_mode or data.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions # (1.5) # control_image if isinstance(controlnet, ControlNetModel): - data.control_image = self.prepare_control_image( - pipeline, - image=data.control_image, - width=data.width, - height=data.height, - batch_size=data.batch_size * data.num_images_per_prompt, - num_images_per_prompt=data.num_images_per_prompt, - device=data.device, + block_state.control_image = self.prepare_control_image( + components, + image=block_state.control_image, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, dtype=controlnet.dtype, - crops_coords=data.crops_coords, + crops_coords=block_state.crops_coords, ) elif isinstance(controlnet, MultiControlNetModel): control_images = [] - for control_image_ in data.control_image: + for control_image_ in block_state.control_image: control_image = self.prepare_control_image( - pipeline, + components, image=control_image_, - width=data.width, - height=data.height, - batch_size=data.batch_size * data.num_images_per_prompt, - num_images_per_prompt=data.num_images_per_prompt, - device=data.device, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, dtype=controlnet.dtype, - crops_coords=data.crops_coords, + crops_coords=block_state.crops_coords, ) control_images.append(control_image) - data.control_image = control_images + block_state.control_image = control_images else: assert False # (1.6) # controlnet_keep - data.controlnet_keep = [] - for i in range(len(data.timesteps)): + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): keeps = [ - 1.0 - float(i / len(data.timesteps) < s or (i + 1) / len(data.timesteps) > e) - for s, e in zip(data.control_guidance_start, data.control_guidance_end) + 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) + for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) ] - data.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) # (2) Prepare conditional inputs for unet using the guider - data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - if data.disable_guidance: - pipeline.guider.disable() + block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False + if block_state.disable_guidance: + components.guider.disable() else: - pipeline.guider.enable() + components.guider.enable() # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) - data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) + block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - pipeline.guider.set_input_fields( + components.guider.set_input_fields( prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), add_time_ids=("add_time_ids", "negative_add_time_ids"), pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), @@ -2487,23 +2718,23 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ) # (5) Denoise loop - with self.progress_bar(total=data.num_inference_steps) as progress_bar: - for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) - guider_data = pipeline.guider.prepare_inputs(data) + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_data = components.guider.prepare_inputs(block_state) - data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t) + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - if isinstance(data.controlnet_keep[i], list): - data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] + if isinstance(block_state.controlnet_keep[i], list): + block_state.cond_scale = [c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])] else: - data.controlnet_cond_scale = data.controlnet_conditioning_scale - if isinstance(data.controlnet_cond_scale, list): - data.controlnet_cond_scale = data.controlnet_cond_scale[0] - data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] + block_state.controlnet_cond_scale = block_state.controlnet_conditioning_scale + if isinstance(block_state.controlnet_cond_scale, list): + block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] + block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] for batch in guider_data: - pipeline.guider.prepare_models(pipeline.unet) + components.guider.prepare_models(components.unet) # Prepare additional conditionings batch.added_cond_kwargs = { @@ -2520,70 +2751,70 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: } # Will always be run atleast once with every guider - if pipeline.guider.is_conditional or not data.guess_mode: - data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - data.scaled_latents, + if components.guider.is_conditional or not block_state.guess_mode: + block_state.down_block_res_samples, block_state.mid_block_res_sample = components.controlnet( + block_state.scaled_latents, t, encoder_hidden_states=batch.prompt_embeds, - controlnet_cond=data.control_image, - conditioning_scale=data.cond_scale, - guess_mode=data.guess_mode, + controlnet_cond=block_state.control_image, + conditioning_scale=block_state.cond_scale, + guess_mode=block_state.guess_mode, added_cond_kwargs=batch.controlnet_added_cond_kwargs, return_dict=False, ) - batch.down_block_res_samples = data.down_block_res_samples - batch.mid_block_res_sample = data.mid_block_res_sample + batch.down_block_res_samples = block_state.down_block_res_samples + batch.mid_block_res_sample = block_state.mid_block_res_sample - if pipeline.guider.is_unconditional and data.guess_mode: - batch.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples] - batch.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) + if components.guider.is_unconditional and block_state.guess_mode: + batch.down_block_res_samples = [torch.zeros_like(d) for d in block_state.down_block_res_samples] + batch.mid_block_res_sample = torch.zeros_like(block_state.mid_block_res_sample) # Prepare for inpainting - if data.num_channels_unet == 9: - data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1) + if block_state.num_channels_unet == 9: + block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - batch.noise_pred = pipeline.unet( - data.scaled_latents, + batch.noise_pred = components.unet( + block_state.scaled_latents, t, encoder_hidden_states=batch.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, added_cond_kwargs=batch.added_cond_kwargs, down_block_additional_residuals=batch.down_block_res_samples, mid_block_additional_residual=batch.mid_block_res_sample, return_dict=False, )[0] - pipeline.guider.cleanup_models(pipeline.unet) + components.guider.cleanup_models(components.unet) # Perform guidance - data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) + block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) # Perform scheduler step using the predicted output - data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - if data.latents.dtype != data.latents_dtype: + if block_state.latents.dtype != block_state.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - data.latents = data.latents.to(data.latents_dtype) + block_state.latents = block_state.latents.to(block_state.latents_dtype) - if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: - data.init_latents_proper = data.image_latents - if i < len(data.timesteps) - 1: - data.noise_timestep = data.timesteps[i + 1] - data.init_latents_proper = pipeline.scheduler.add_noise( - data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) + if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.add_noise( + block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) ) - data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents + block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): progress_bar.update() - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): @@ -2731,31 +2962,32 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - def check_inputs(self, pipeline, data): + @staticmethod + def check_inputs(components, block_state): - num_channels_unet = pipeline.unet.config.in_channels + num_channels_unet = components.unet.config.in_channels if num_channels_unet == 9: # default case for runwayml/stable-diffusion-inpainting - if data.mask is None or data.masked_image_latents is None: + if block_state.mask is None or block_state.masked_image_latents is None: raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = data.latents.shape[1] - num_channels_mask = data.mask.shape[1] - num_channels_masked_image = data.masked_image_latents.shape[1] + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" - f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." + " `components.unet` or your `mask_image` or `image` input." ) # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image # 1. return image without apply any guidance # 2. add crops_coords and resize_mode to preprocess() + @staticmethod def prepare_control_image( - self, components, image, width, @@ -2785,7 +3017,8 @@ def prepare_control_image( return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components - def prepare_extra_step_kwargs(self, components, generator, eta): + @staticmethod + def prepare_extra_step_kwargs(components, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 @@ -2803,118 +3036,118 @@ def prepare_extra_step_kwargs(self, components, generator, eta): return extra_step_kwargs @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - self.check_inputs(pipeline, data) + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) - data.num_channels_unet = pipeline.unet.config.in_channels + block_state.num_channels_unet = components.unet.config.in_channels # (1) prepare controlnet inputs - data.device = pipeline._execution_device - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * pipeline.vae_scale_factor - data.width = data.width * pipeline.vae_scale_factor + block_state.device = components._execution_device + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor - controlnet = unwrap_module(pipeline.controlnet) + controlnet = unwrap_module(components.controlnet) # (1.1) # control guidance - if not isinstance(data.control_guidance_start, list) and isinstance(data.control_guidance_end, list): - data.control_guidance_start = len(data.control_guidance_end) * [data.control_guidance_start] - elif not isinstance(data.control_guidance_end, list) and isinstance(data.control_guidance_start, list): - data.control_guidance_end = len(data.control_guidance_start) * [data.control_guidance_end] + if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] + elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] # (1.2) # global_pool_conditions & guess_mode - data.global_pool_conditions = controlnet.config.global_pool_conditions - data.guess_mode = data.guess_mode or data.global_pool_conditions + block_state.global_pool_conditions = controlnet.config.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions # (1.3) # control_type - data.num_control_type = controlnet.config.num_control_type + block_state.num_control_type = controlnet.config.num_control_type # (1.4) # control_type - if not isinstance(data.control_image, list): - data.control_image = [data.control_image] + if not isinstance(block_state.control_image, list): + block_state.control_image = [block_state.control_image] - if not isinstance(data.control_mode, list): - data.control_mode = [data.control_mode] + if not isinstance(block_state.control_mode, list): + block_state.control_mode = [block_state.control_mode] - if len(data.control_image) != len(data.control_mode): + if len(block_state.control_image) != len(block_state.control_mode): raise ValueError("Expected len(control_image) == len(control_type)") - data.control_type = [0 for _ in range(data.num_control_type)] - for control_idx in data.control_mode: - data.control_type[control_idx] = 1 + block_state.control_type = [0 for _ in range(block_state.num_control_type)] + for control_idx in block_state.control_mode: + block_state.control_type[control_idx] = 1 - data.control_type = torch.Tensor(data.control_type) + block_state.control_type = torch.Tensor(block_state.control_type) # (1.5) # prepare control_image - for idx, _ in enumerate(data.control_image): - data.control_image[idx] = self.prepare_control_image( - pipeline, - image=data.control_image[idx], - width=data.width, - height=data.height, - batch_size=data.batch_size * data.num_images_per_prompt, - num_images_per_prompt=data.num_images_per_prompt, - device=data.device, + for idx, _ in enumerate(block_state.control_image): + block_state.control_image[idx] = self.prepare_control_image( + components, + image=block_state.control_image[idx], + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, dtype=controlnet.dtype, - crops_coords=data.crops_coords, + crops_coords=block_state.crops_coords, ) - data.height, data.width = data.control_image[idx].shape[-2:] + block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] # (1.6) # controlnet_keep - data.controlnet_keep = [] - for i in range(len(data.timesteps)): - data.controlnet_keep.append( + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + block_state.controlnet_keep.append( 1.0 - - float(i / len(data.timesteps) < data.control_guidance_start or (i + 1) / len(data.timesteps) > data.control_guidance_end) + - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end) ) # (2) Prepare conditional inputs for unet using the guider # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale - data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - if data.disable_guidance: - pipeline.guider.disable() + block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False + if block_state.disable_guidance: + components.guider.disable() else: - pipeline.guider.enable() + components.guider.enable() - data.control_type = data.control_type.reshape(1, -1).to(data.device, dtype=data.prompt_embeds.dtype) - repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0] - data.control_type = data.control_type.repeat_interleave(repeat_by, dim=0) + block_state.control_type = block_state.control_type.reshape(1, -1).to(block_state.device, dtype=block_state.prompt_embeds.dtype) + repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0] + block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) - data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) + block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - pipeline.guider.set_input_fields( + components.guider.set_input_fields( prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), add_time_ids=("add_time_ids", "negative_add_time_ids"), pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), ) - with self.progress_bar(total=data.num_inference_steps) as progress_bar: - for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) - guider_data = pipeline.guider.prepare_inputs(data) + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_data = components.guider.prepare_inputs(block_state) - data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t) + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - if isinstance(data.controlnet_keep[i], list): - data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] + if isinstance(block_state.controlnet_keep[i], list): + block_state.cond_scale = [c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])] else: - data.controlnet_cond_scale = data.controlnet_conditioning_scale - if isinstance(data.controlnet_cond_scale, list): - data.controlnet_cond_scale = data.controlnet_cond_scale[0] - data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] + block_state.controlnet_cond_scale = block_state.controlnet_conditioning_scale + if isinstance(block_state.controlnet_cond_scale, list): + block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] + block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] for batch in guider_data: - pipeline.guider.prepare_models(pipeline.unet) + components.guider.prepare_models(components.unet) # Prepare additional conditionings batch.added_cond_kwargs = { @@ -2931,70 +3164,70 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: } # Will always be run atleast once with every guider - if pipeline.guider.is_conditional or not data.guess_mode: - data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - data.scaled_latents, + if components.guider.is_conditional or not block_state.guess_mode: + block_state.down_block_res_samples, block_state.mid_block_res_sample = components.controlnet( + block_state.scaled_latents, t, encoder_hidden_states=batch.prompt_embeds, - controlnet_cond=data.control_image, - control_type=data.control_type, - control_type_idx=data.control_mode, - conditioning_scale=data.cond_scale, - guess_mode=data.guess_mode, + controlnet_cond=block_state.control_image, + control_type=block_state.control_type, + control_type_idx=block_state.control_mode, + conditioning_scale=block_state.cond_scale, + guess_mode=block_state.guess_mode, added_cond_kwargs=batch.controlnet_added_cond_kwargs, return_dict=False, ) - batch.down_block_res_samples = data.down_block_res_samples - batch.mid_block_res_sample = data.mid_block_res_sample + batch.down_block_res_samples = block_state.down_block_res_samples + batch.mid_block_res_sample = block_state.mid_block_res_sample - if pipeline.guider.is_unconditional and data.guess_mode: - batch.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples] - batch.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) + if components.guider.is_unconditional and block_state.guess_mode: + batch.down_block_res_samples = [torch.zeros_like(d) for d in block_state.down_block_res_samples] + batch.mid_block_res_sample = torch.zeros_like(block_state.mid_block_res_sample) - if data.num_channels_unet == 9: - data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1) + if block_state.num_channels_unet == 9: + block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - batch.noise_pred = pipeline.unet( - data.scaled_latents, + batch.noise_pred = components.unet( + block_state.scaled_latents, t, encoder_hidden_states=batch.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, added_cond_kwargs=batch.added_cond_kwargs, down_block_additional_residuals=batch.down_block_res_samples, mid_block_additional_residual=batch.mid_block_res_sample, return_dict=False, )[0] - pipeline.guider.cleanup_models(pipeline.unet) + components.guider.cleanup_models(components.unet) # Perform guidance - data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) + block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) # Perform scheduler step using the predicted output - data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - if data.latents.dtype != data.latents_dtype: + if block_state.latents.dtype != block_state.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - data.latents = data.latents.to(data.latents_dtype) - - if data.num_channels_unet == 9 and data.mask is not None and data.image_latents is not None: - data.init_latents_proper = data.image_latents - if i < len(data.timesteps) - 1: - data.noise_timestep = data.timesteps[i + 1] - data.init_latents_proper = pipeline.scheduler.add_noise( - data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + if block_state.num_channels_unet == 9 and block_state.mask is not None and block_state.image_latents is not None: + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.add_noise( + block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) ) - data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents + block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): progress_bar.update() - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLDecodeLatentsStep(PipelineBlock): @@ -3031,7 +3264,8 @@ def intermediates_outputs(self) -> List[str]: return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")] # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components - def upcast_vae(self, components): + @staticmethod + def upcast_vae(components): dtype = components.vae.dtype components.vae.to(dtype=torch.float32) use_torch_2_0_or_xformers = isinstance( @@ -3049,57 +3283,57 @@ def upcast_vae(self, components): components.vae.decoder.mid_block.to(dtype) @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) - if not data.output_type == "latent": + if not block_state.output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 - data.needs_upcasting = pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast + block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast - if data.needs_upcasting: - self.upcast_vae(pipeline) - data.latents = data.latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype) - elif data.latents.dtype != pipeline.vae.dtype: + if block_state.needs_upcasting: + self.upcast_vae(components) + block_state.latents = block_state.latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) + elif block_state.latents.dtype != components.vae.dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - pipeline.vae = pipeline.vae.to(data.latents.dtype) + components.vae = components.vae.to(block_state.latents.dtype) # unscale/denormalize the latents # denormalize with the mean and std if available and not None - data.has_latents_mean = ( - hasattr(pipeline.vae.config, "latents_mean") and pipeline.vae.config.latents_mean is not None + block_state.has_latents_mean = ( + hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None ) - data.has_latents_std = ( - hasattr(pipeline.vae.config, "latents_std") and pipeline.vae.config.latents_std is not None + block_state.has_latents_std = ( + hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None ) - if data.has_latents_mean and data.has_latents_std: - data.latents_mean = ( - torch.tensor(pipeline.vae.config.latents_mean).view(1, 4, 1, 1).to(data.latents.device, data.latents.dtype) + if block_state.has_latents_mean and block_state.has_latents_std: + block_state.latents_mean = ( + torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) ) - data.latents_std = ( - torch.tensor(pipeline.vae.config.latents_std).view(1, 4, 1, 1).to(data.latents.device, data.latents.dtype) + block_state.latents_std = ( + torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) ) - data.latents = data.latents * data.latents_std / pipeline.vae.config.scaling_factor + data.latents_mean + block_state.latents = block_state.latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean else: - data.latents = data.latents / pipeline.vae.config.scaling_factor + block_state.latents = block_state.latents / components.vae.config.scaling_factor - data.images = pipeline.vae.decode(data.latents, return_dict=False)[0] + block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0] # cast back to fp16 if needed - if data.needs_upcasting: - pipeline.vae.to(dtype=torch.float16) + if block_state.needs_upcasting: + components.vae.to(dtype=torch.float16) else: - data.images = data.latents + block_state.images = block_state.latents # apply watermark if available - if hasattr(pipeline, "watermark") and pipeline.watermark is not None: - data.images = pipeline.watermark.apply_watermark(data.images) + if hasattr(components, "watermark") and components.watermark is not None: + block_state.images = components.watermark.apply_watermark(block_state.images) - data.images = pipeline.image_processor.postprocess(data.images, output_type=data.output_type) + block_state.images = components.image_processor.postprocess(block_state.images, output_type=block_state.output_type) - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): @@ -3130,15 +3364,15 @@ def intermediates_outputs(self) -> List[str]: return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images with the mask overlayed")] @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) - if data.padding_mask_crop is not None and data.crops_coords is not None: - data.images = [pipeline.image_processor.apply_overlay(data.mask_image, data.image, i, data.crops_coords) for i in data.images] + if block_state.padding_mask_crop is not None and block_state.crops_coords is not None: + block_state.images = [components.image_processor.apply_overlay(block_state.mask_image, block_state.image, i, block_state.crops_coords) for i in block_state.images] - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLOutputStep(PipelineBlock): @@ -3162,15 +3396,15 @@ def intermediates_outputs(self) -> List[str]: @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) - if not data.return_dict: - data.images = (data.images,) + if not block_state.return_dict: + block_state.images = (block_state.images,) else: - data.images = StableDiffusionXLPipelineOutput(images=data.images) - self.add_block_state(state, data) - return pipeline, state + block_state.images = StableDiffusionXLPipelineOutput(images=block_state.images) + self.add_block_state(state, block_state) + return components, state # Encode @@ -3400,50 +3634,6 @@ def description(self): } -# YiYi Notes: model specific components: -## (1) it should inherit from ModularLoader -## (2) acts like a container that holds components and configs -## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents -## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) -## (5) how to use together with Components_manager? -class StableDiffusionXLModularLoader( - ModularLoader, - StableDiffusionMixin, - TextualInversionLoaderMixin, - StableDiffusionXLLoraLoaderMixin, - ModularIPAdapterMixin, -): - @property - def default_sample_size(self): - default_sample_size = 128 - if hasattr(self, "unet") and self.unet is not None: - default_sample_size = self.unet.config.sample_size - return default_sample_size - - @property - def vae_scale_factor(self): - vae_scale_factor = 8 - if hasattr(self, "vae") and self.vae is not None: - vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - return vae_scale_factor - - @property - def num_channels_unet(self): - num_channels_unet = 4 - if hasattr(self, "unet") and self.unet is not None: - num_channels_unet = self.unet.config.in_channels - return num_channels_unet - - @property - def num_channels_latents(self): - num_channels_latents = 4 - if hasattr(self, "vae") and self.vae is not None: - num_channels_latents = self.vae.config.latent_channels - return num_channels_latents - - - - # YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks SDXL_INPUTS_SCHEMA = { From efd70b783871aa7b3e02bd8252afbc8e45eeb314 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 3 May 2025 20:22:05 +0200 Subject: [PATCH 10/54] seperate controlnet step into input + denoise --- .../pipeline_stable_diffusion_xl_modular.py | 466 +++++++++++------- 1 file changed, 299 insertions(+), 167 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 81808540ee67..ea774283437a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2395,27 +2395,20 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt return components, state -class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): +class StableDiffusionXLControlNetInputStep(PipelineBlock): model_name = "stable-diffusion-xl" @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), ] @property def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + return "step that prepare inputs for controlnet" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -2426,9 +2419,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("controlnet_conditioning_scale", default=1.0), InputParam("guess_mode", default=False), InputParam("num_images_per_prompt", default=1), - InputParam("cross_attention_kwargs"), - InputParam("generator"), - InputParam("eta", default=0.0), ] @property @@ -2452,110 +2442,25 @@ def intermediates_inputs(self) -> List[str]: type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), InputParam( "crops_coords", type_hint=Optional[Tuple[int]], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + return [ + OutputParam("control_image", type_hint=torch.Tensor, description="The processed control image"), + OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"), + OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"), + OutputParam("controlnet_conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] - @staticmethod - def check_inputs(components, block_state): - num_channels_unet = components.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if block_state.mask is None or block_state.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = block_state.latents.shape[1] - num_channels_mask = block_state.mask.shape[1] - num_channels_masked_image = block_state.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" - f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `components.unet` or your `mask_image` or `image` input." - ) # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image # 1. return image without apply any guidance @@ -2588,33 +2493,12 @@ def prepare_control_image( image = image.to(device=device, dtype=dtype) return image - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components - @staticmethod - def prepare_extra_step_kwargs(components, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - self.check_inputs(components, block_state) - - block_state.num_channels_unet = components.unet.config.in_channels # (1) prepare controlnet inputs block_state.device = components._execution_device @@ -2699,17 +2583,243 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt ] block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - # (2) Prepare conditional inputs for unet using the guider + + + self.add_block_state(state, block_state) + + return components, state + +class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("controlnet", ControlNetModel), + ] + + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam("num_images_per_prompt", default=1), + InputParam("cross_attention_kwargs"), + InputParam("generator"), + InputParam("eta", default=0.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "control_image", + required=True, + type_hint=torch.Tensor, + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "control_guidance_start", + required=True, + type_hint=float, + description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "control_guidance_end", + required=True, + type_hint=float, + description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "controlnet_conditioning_scale", + required=True, + type_hint=float, + description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "guess_mode", + required=True, + type_hint=bool, + description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "controlnet_keep", + required=True, + type_hint=List[float], + description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." + ), + InputParam( + "negative_prompt_embeds", + type_hint=Optional[torch.Tensor], + description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." + ), + InputParam( + "add_time_ids", + required=True, + type_hint=torch.Tensor, + description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." + ), + InputParam( + "negative_add_time_ids", + type_hint=Optional[torch.Tensor], + description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." + ), + InputParam( + "negative_pooled_prompt_embeds", + type_hint=Optional[torch.Tensor], + description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" + ), + InputParam( + "mask", + type_hint=Optional[torch.Tensor], + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "masked_image_latents", + type_hint=Optional[torch.Tensor], + description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "noise", + type_hint=Optional[torch.Tensor], + description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." + ), + InputParam( + "image_latents", + type_hint=Optional[torch.Tensor], + description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." + ), + InputParam( + "ip_adapter_embeds", + type_hint=Optional[torch.Tensor], + description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." + ), + InputParam( + "negative_ip_adapter_embeds", + type_hint=Optional[torch.Tensor], + description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @staticmethod + def check_inputs(components, block_state): + + num_channels_unet = components.unet.config.in_channels + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + if block_state.mask is None or block_state.masked_image_latents is None: + raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: + raise ValueError( + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `components.unet` or your `mask_image` or `image` input." + ) + + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components + @staticmethod + def prepare_extra_step_kwargs(components, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + block_state.device = components._execution_device + + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + + # (1) setup guider + # disable for LCMs block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False if block_state.disable_guidance: components.guider.disable() else: components.guider.enable() - - # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - components.guider.set_input_fields( prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), add_time_ids=("add_time_ids", "negative_add_time_ids"), @@ -2720,11 +2830,16 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt # (5) Denoise loop with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: for i, t in enumerate(block_state.timesteps): - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - guider_data = components.guider.prepare_inputs(block_state) + # prepare latent input for unet block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + # adjust latent input for inpainting + block_state.num_channels_unet = components.unet.config.in_channels + if block_state.num_channels_unet == 9: + block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + # cond_scale (controlnet input) if isinstance(block_state.controlnet_keep[i], list): block_state.cond_scale = [c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])] else: @@ -2733,62 +2848,69 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] - for batch in guider_data: + # default controlnet output/unet input for guess mode + conditional path + block_state.down_block_res_samples_zeros = None + block_state.mid_block_res_sample_zeros = None + + # guided denoiser step + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(block_state) + + for guider_state_batch in guider_state: components.guider.prepare_models(components.unet) # Prepare additional conditionings - batch.added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": batch.add_time_ids, + guider_state_batch.added_cond_kwargs = { + "text_embeds": guider_state_batch.pooled_prompt_embeds, + "time_ids": guider_state_batch.add_time_ids, } - if batch.ip_adapter_embeds is not None: - batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds + if guider_state_batch.ip_adapter_embeds is not None: + guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds # Prepare controlnet additional conditionings - batch.controlnet_added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": batch.add_time_ids, + guider_state_batch.controlnet_added_cond_kwargs = { + "text_embeds": guider_state_batch.pooled_prompt_embeds, + "time_ids": guider_state_batch.add_time_ids, } - # Will always be run atleast once with every guider - if components.guider.is_conditional or not block_state.guess_mode: - block_state.down_block_res_samples, block_state.mid_block_res_sample = components.controlnet( + if block_state.guess_mode and not components.guider.is_conditional: + # guider always run uncond batch first, so these tensors should be set already + guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros + guider_state_batch.mid_block_res_sample = block_state.mid_block_res_sample_zeros + else: + guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( block_state.scaled_latents, t, - encoder_hidden_states=batch.prompt_embeds, + encoder_hidden_states=guider_state_batch.prompt_embeds, controlnet_cond=block_state.control_image, conditioning_scale=block_state.cond_scale, guess_mode=block_state.guess_mode, - added_cond_kwargs=batch.controlnet_added_cond_kwargs, + added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, return_dict=False, ) - batch.down_block_res_samples = block_state.down_block_res_samples - batch.mid_block_res_sample = block_state.mid_block_res_sample + if block_state.down_block_res_samples_zeros is None: + block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] + if block_state.mid_block_res_sample_zeros is None: + block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) - if components.guider.is_unconditional and block_state.guess_mode: - batch.down_block_res_samples = [torch.zeros_like(d) for d in block_state.down_block_res_samples] - batch.mid_block_res_sample = torch.zeros_like(block_state.mid_block_res_sample) - # Prepare for inpainting - if block_state.num_channels_unet == 9: - block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - - batch.noise_pred = components.unet( + + guider_state_batch.noise_pred = components.unet( block_state.scaled_latents, t, - encoder_hidden_states=batch.prompt_embeds, + encoder_hidden_states=guider_state_batch.prompt_embeds, timestep_cond=block_state.timestep_cond, cross_attention_kwargs=block_state.cross_attention_kwargs, - added_cond_kwargs=batch.added_cond_kwargs, - down_block_additional_residuals=batch.down_block_res_samples, - mid_block_additional_residual=batch.mid_block_res_sample, + added_cond_kwargs=guider_state_batch.added_cond_kwargs, + down_block_additional_residuals=guider_state_batch.down_block_res_samples, + mid_block_additional_residual=guider_state_batch.mid_block_res_sample, return_dict=False, )[0] components.guider.cleanup_models(components.unet) # Perform guidance - block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) + block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) # Perform scheduler step using the predicted output block_state.latents_dtype = block_state.latents.dtype @@ -2799,6 +2921,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 block_state.latents = block_state.latents.to(block_state.latents_dtype) + # adjust latent for inpainting if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: block_state.init_latents_proper = block_state.image_latents if i < len(block_state.timesteps) - 1: @@ -3463,6 +3586,16 @@ def description(self): " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \ " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" +class StableDiffusionXLControlNetStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLControlNetInputStep, StableDiffusionXLControlNetDenoiseStep] + block_names = ["prepare_input", "denoise"] + + @property + def description(self): + return "Controlnet step that denoise the latents.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLControlNetInputStep` is used to prepare the inputs for the denoise step.\n" + \ + " - `StableDiffusionXLControlNetDenoiseStep` is used to denoise the latents." class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep] @@ -3477,10 +3610,9 @@ def description(self): " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \ " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided." - # Denoise class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] + block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep] block_names = ["controlnet_union", "controlnet", "unet"] block_trigger_inputs = ["control_mode", "control_image", None] @@ -3489,7 +3621,7 @@ def description(self): return "Denoise step that denoise the latents.\n" + \ "This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \ " - `StableDiffusionXLControlNetUnionDenoiseStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \ - " - `StableDiffusionXLControlNetDenoiseStep` (controlnet) is used when `control_image` is provided.\n" + \ + " - `StableDiffusionXLControlStep` (controlnet) is used when `control_image` is provided.\n" + \ " - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided." # After denoise @@ -3597,7 +3729,7 @@ def description(self): ]) CONTROLNET_BLOCKS = OrderedDict([ - ("denoise", StableDiffusionXLControlNetDenoiseStep), + ("denoise", StableDiffusionXLControlNetStep), ]) CONTROLNET_UNION_BLOCKS = OrderedDict([ From 43ac1ff7e78ffdf8fa91932769236a7995ac482e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 4 May 2025 22:17:25 +0200 Subject: [PATCH 11/54] refactor controlnet union --- .../pipeline_stable_diffusion_xl_modular.py | 426 ++++++++++++------ 1 file changed, 284 insertions(+), 142 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index ea774283437a..5ebdd383ccbb 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2613,12 +2613,6 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), InputParam("num_images_per_prompt", default=1), InputParam("cross_attention_kwargs"), InputParam("generator"), @@ -2755,6 +2749,12 @@ def intermediates_inputs(self) -> List[str]: type_hint=Optional[torch.Tensor], description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), ] @property @@ -2940,25 +2940,198 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt return components, state +class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetUnionModel), + ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), + ] + + @property + def description(self) -> str: + return "step that prepares inputs for the ControlNetUnion model" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_mode", default=[0]), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of model tensor inputs. Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("control_image", type_hint=List[torch.Tensor], description="The processed control images"), + OutputParam("control_mode", type_hint=List[int], description="The control mode indices"), + OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active"), + OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), + OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), + OutputParam("controlnet_conditioning_scale", type_hint=float, description="The controlnet conditioning scale value"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + controlnet = unwrap_module(components.controlnet) + + device = block_state.device or components._execution_device + dtype = block_state.dtype or components.controlnet.dtype + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] + elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] + + # guess_mode + block_state.global_pool_conditions = controlnet.config.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + + if not isinstance(block_state.control_image, list): + block_state.control_image = [block_state.control_image] + + if not isinstance(block_state.control_mode, list): + block_state.control_mode = [block_state.control_mode] + + if len(block_state.control_image) != len(block_state.control_mode): + raise ValueError("Expected len(control_image) == len(control_type)") + + # control_type + block_state.num_control_type = controlnet.config.num_control_type + block_state.control_type = [0 for _ in range(block_state.num_control_type)] + for control_idx in block_state.control_mode: + block_state.control_type[control_idx] = 1 + block_state.control_type = torch.Tensor(block_state.control_type) + + block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype) + repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0] + block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) + + # prepare control_image + for idx, _ in enumerate(block_state.control_image): + block_state.control_image[idx] = self.prepare_control_image( + components, + image=block_state.control_image[idx], + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=device, + dtype=dtype, + crops_coords=block_state.crops_coords, + ) + block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] + + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + block_state.controlnet_keep.append( + 1.0 + - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end) + ) + + + self.add_block_state(state, block_state) + + return components, state + class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): model_name = "stable-diffusion-xl" @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec("controlnet", ControlNetUnionModel), - ComponentSpec("scheduler", EulerDiscreteScheduler), ComponentSpec( "guider", ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), - ComponentSpec( - "control_image_processor", - VaeImageProcessor, - config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), - default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("controlnet", ControlNetUnionModel), ] @property @@ -2967,12 +3140,6 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("control_image", required=True), - InputParam("control_guidance_start", default=0.0), - InputParam("control_guidance_end", default=1.0), - InputParam("control_mode", required=True), - InputParam("controlnet_conditioning_scale", default=1.0), - InputParam("guess_mode", default=False), InputParam("num_images_per_prompt", default=1), InputParam("cross_attention_kwargs"), InputParam("generator"), @@ -2983,15 +3150,75 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[str]: return [ InputParam( - "latents", + "control_image", + required=True, + type_hint=List[torch.Tensor], + description="The control images to use for conditioning. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "control_mode", + required=True, + type_hint=List[int], + description="The control mode indices. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "control_type", required=True, type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + description="The control type tensor that specifies which control type is active. Can be generated in prepare controlnet inputs step." ), InputParam( - "batch_size", + "num_control_type", required=True, type_hint=int, + description="The number of control types available. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "control_guidance_start", + required=True, + type_hint=float, + description="The control guidance start value. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "control_guidance_end", + required=True, + type_hint=float, + description="The control guidance end value. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "controlnet_conditioning_scale", + required=True, + type_hint=float, + description="The controlnet conditioning scale value. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "guess_mode", + required=True, + type_hint=bool, + description="Whether guess mode is used. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "global_pool_conditions", + required=True, + type_hint=bool, + description="Whether global pool conditions are used. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "controlnet_keep", + required=True, + type_hint=List[float], + description="The controlnet keep values. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." ), InputParam( @@ -3045,23 +3272,23 @@ def intermediates_inputs(self) -> List[str]: description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." ), InputParam( - "mask", - type_hint=Optional[torch.Tensor], + "mask", + type_hint=Optional[torch.Tensor], description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], + "masked_image_latents", + type_hint=Optional[torch.Tensor], description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( - "noise", - type_hint=Optional[torch.Tensor], + "noise", + type_hint=Optional[torch.Tensor], description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." ), InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], + "image_latents", + type_hint=Optional[torch.Tensor], description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( @@ -3070,19 +3297,19 @@ def intermediates_inputs(self) -> List[str]: description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], + "ip_adapter_embeds", + type_hint=Optional[torch.Tensor], description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." ), InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], + "negative_ip_adapter_embeds", + type_hint=Optional[torch.Tensor], description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." ), ] @property - def intermediates_outputs(self) -> List[str]: + def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] @staticmethod @@ -3105,39 +3332,7 @@ def check_inputs(components, block_state): " `components.unet` or your `mask_image` or `image` input." ) - - # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image - # 1. return image without apply any guidance - # 2. add crops_coords and resize_mode to preprocess() - @staticmethod - def prepare_control_image( - components, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - crops_coords=None, - ): - if crops_coords is not None: - image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) - else: - image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components @staticmethod @@ -3164,85 +3359,20 @@ def __call__(self, components, state: PipelineState) -> PipelineState: self.check_inputs(components, block_state) block_state.num_channels_unet = components.unet.config.in_channels - - # (1) prepare controlnet inputs block_state.device = components._execution_device - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * components.vae_scale_factor - block_state.width = block_state.width * components.vae_scale_factor - - controlnet = unwrap_module(components.controlnet) - - # (1.1) - # control guidance - if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): - block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] - elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): - block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] - - # (1.2) - # global_pool_conditions & guess_mode - block_state.global_pool_conditions = controlnet.config.global_pool_conditions - block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions - - # (1.3) - # control_type - block_state.num_control_type = controlnet.config.num_control_type - - # (1.4) - # control_type - if not isinstance(block_state.control_image, list): - block_state.control_image = [block_state.control_image] - - if not isinstance(block_state.control_mode, list): - block_state.control_mode = [block_state.control_mode] - - if len(block_state.control_image) != len(block_state.control_mode): - raise ValueError("Expected len(control_image) == len(control_type)") - - block_state.control_type = [0 for _ in range(block_state.num_control_type)] - for control_idx in block_state.control_mode: - block_state.control_type[control_idx] = 1 - - block_state.control_type = torch.Tensor(block_state.control_type) - # (1.5) - # prepare control_image - for idx, _ in enumerate(block_state.control_image): - block_state.control_image[idx] = self.prepare_control_image( - components, - image=block_state.control_image[idx], - width=block_state.width, - height=block_state.height, - batch_size=block_state.batch_size * block_state.num_images_per_prompt, - num_images_per_prompt=block_state.num_images_per_prompt, - device=block_state.device, - dtype=controlnet.dtype, - crops_coords=block_state.crops_coords, - ) - block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] - - # (1.6) - # controlnet_keep - block_state.controlnet_keep = [] - for i in range(len(block_state.timesteps)): - block_state.controlnet_keep.append( - 1.0 - - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end) - ) + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - # (2) Prepare conditional inputs for unet using the guider - # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale + # Setup guider + # disable for LCMs block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False if block_state.disable_guidance: components.guider.disable() else: components.guider.enable() - block_state.control_type = block_state.control_type.reshape(1, -1).to(block_state.device, dtype=block_state.prompt_embeds.dtype) - repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0] - block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) - # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) @@ -3612,7 +3742,7 @@ def description(self): # Denoise class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep] + block_classes = [StableDiffusionXLControlNetUnionStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep] block_names = ["controlnet_union", "controlnet", "unet"] block_trigger_inputs = ["control_mode", "control_image", None] @@ -3620,8 +3750,8 @@ class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): def description(self): return "Denoise step that denoise the latents.\n" + \ "This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \ - " - `StableDiffusionXLControlNetUnionDenoiseStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \ - " - `StableDiffusionXLControlStep` (controlnet) is used when `control_image` is provided.\n" + \ + " - `StableDiffusionXLControlNetUnionStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \ + " - `StableDiffusionXLControlNetStep` (controlnet) is used when `control_image` is provided.\n" + \ " - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided." # After denoise @@ -3733,7 +3863,7 @@ def description(self): ]) CONTROLNET_UNION_BLOCKS = OrderedDict([ - ("denoise", StableDiffusionXLControlNetUnionDenoiseStep), + ("denoise", StableDiffusionXLControlNetUnionStep), ]) IP_ADAPTER_BLOCKS = OrderedDict([ @@ -3865,3 +3995,15 @@ def description(self): SDXL_OUTPUTS_SCHEMA = { "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") } + + +class StableDiffusionXLControlNetUnionStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetUnionDenoiseStep] + block_names = ["prepare_input", "denoise"] + + @property + def description(self): + return "ControlNetUnion step that denoises the latents.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLControlNetUnionInputStep` is used to prepare the inputs for the denoise step.\n" + \ + " - `StableDiffusionXLControlNetUnionDenoiseStep` is used to denoise the latents using the ControlNetUnion model." From dc4dbfe10711f4f4e70c435a996cfedec00e5218 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 6 May 2025 09:58:44 +0200 Subject: [PATCH 12/54] reefactor pipeline/block states so that it can dynamically accept kwargs --- src/diffusers/pipelines/modular_pipeline.py | 153 ++++++++++++++---- .../pipelines/modular_pipeline_utils.py | 4 +- 2 files changed, 127 insertions(+), 30 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index c994b91ba8bb..1733ad6d4e00 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -73,18 +73,72 @@ class PipelineState: inputs: Dict[str, Any] = field(default_factory=dict) intermediates: Dict[str, Any] = field(default_factory=dict) + input_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) + intermediate_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) - def add_input(self, key: str, value: Any): + def add_input(self, key: str, value: Any, kwargs_type: str = None): + """ + Add an input to the pipeline state with optional metadata. + + Args: + key (str): The key for the input + value (Any): The input value + kwargs_type (str): The kwargs_type to store with the input + """ self.inputs[key] = value + if kwargs_type is not None: + if kwargs_type not in self.input_kwargs: + self.input_kwargs[kwargs_type] = [key] + else: + self.input_kwargs[kwargs_type].append(key) - def add_intermediate(self, key: str, value: Any): + def add_intermediate(self, key: str, value: Any, kwargs_type: str = None): + """ + Add an intermediate value to the pipeline state with optional metadata. + + Args: + key (str): The key for the intermediate value + value (Any): The intermediate value + kwargs_type (str): The kwargs_type to store with the intermediate value + """ self.intermediates[key] = value + if kwargs_type is not None: + if kwargs_type not in self.intermediate_kwargs: + self.intermediate_kwargs[kwargs_type] = [key] + else: + self.intermediate_kwargs[kwargs_type].append(key) def get_input(self, key: str, default: Any = None) -> Any: return self.inputs.get(key, default) def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: return {key: self.inputs.get(key, default) for key in keys} + + def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]: + """ + Get all inputs with matching kwargs_type. + + Args: + kwargs_type (str): The kwargs_type to filter by + + Returns: + Dict[str, Any]: Dictionary of inputs with matching kwargs_type + """ + input_names = self.input_kwargs.get(kwargs_type, []) + return self.get_inputs(input_names) + + def get_intermediates_kwargs(self, kwargs_type: str) -> Dict[str, Any]: + """ + Get all intermediates with matching kwargs_type. + + Args: + kwargs_type (str): The kwargs_type to filter by + + Returns: + Dict[str, Any]: Dictionary of intermediates with matching kwargs_type + """ + intermediate_names = self.intermediate_kwargs.get(kwargs_type, []) + return self.get_intermediates(intermediate_names) def get_intermediate(self, key: str, default: Any = None) -> Any: return self.intermediates.get(key, default) @@ -106,11 +160,17 @@ def format_value(v): inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) + + # Format input_kwargs and intermediate_kwargs + input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items()) + intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items()) return ( f"PipelineState(\n" f" inputs={{\n{inputs}\n }},\n" - f" intermediates={{\n{intermediates}\n }}\n" + f" intermediates={{\n{intermediates}\n }},\n" + f" input_kwargs={{\n{input_kwargs_str}\n }},\n" + f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n" f")" ) @@ -146,10 +206,16 @@ def format_value(v): # Handle dicts with tensor values elif isinstance(v, dict): - if any(hasattr(val, "shape") and hasattr(val, "dtype") for val in v.values()): - shapes = {k: val.shape for k, val in v.items() if hasattr(val, "shape")} - return f"Dict of Tensors with shapes {shapes}" - return repr(v) + formatted_dict = {} + for k, val in v.items(): + if hasattr(val, "shape") and hasattr(val, "dtype"): + formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})" + elif isinstance(val, list) and len(val) > 0 and hasattr(val[0], "shape") and hasattr(val[0], "dtype"): + shapes = [t.shape for t in val] + formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}" + else: + formatted_dict[k] = repr(val) + return formatted_dict # Default case return repr(v) @@ -203,30 +269,34 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, self.loader = None # Make a copy of the input kwargs - input_params = kwargs.copy() + passed_kwargs = kwargs.copy() - default_params = self.default_call_parameters # Add inputs to state, using defaults if not provided in the kwargs or the state # if same input already in the state, will override it if provided in the kwargs intermediates_inputs = [inp.name for inp in self.intermediates_inputs] - for name, default in default_params.items(): - if name in input_params: + for expected_input_param in self.inputs: + name = expected_input_param.name + default = expected_input_param.default + kwargs_type = expected_input_param.kwargs_type + if name in passed_kwargs: if name not in intermediates_inputs: - state.add_input(name, input_params.pop(name)) + state.add_input(name, passed_kwargs.pop(name), kwargs_type) else: - state.add_input(name, input_params[name]) + state.add_input(name, passed_kwargs[name], kwargs_type) elif name not in state.inputs: - state.add_input(name, default) + state.add_input(name, default, kwargs_type) - for name in intermediates_inputs: - if name in input_params: - state.add_intermediate(name, input_params.pop(name)) + for expected_intermediate_param in self.intermediates_inputs: + name = expected_intermediate_param.name + kwargs_type = expected_intermediate_param.kwargs_type + if name in passed_kwargs: + state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type) # Warn about unexpected inputs - if len(input_params) > 0: - logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") + if len(passed_kwargs) > 0: + logger.warning(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") # Run the pipeline with torch.no_grad(): try: @@ -390,25 +460,50 @@ def get_block_state(self, state: PipelineState) -> dict: # Check inputs for input_param in self.inputs: - value = state.get_input(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required input '{input_param.name}' is missing") - data[input_param.name] = value + if input_param.name: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v # Check intermediates for input_param in self.intermediates_inputs: - value = state.get_intermediate(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required intermediate input '{input_param.name}' is missing") - data[input_param.name] = value - + if input_param.name: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all intermediates with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + if intermediates_kwargs: + for k, v in intermediates_kwargs.items(): + if v is not None: + if k not in data: + data[k] = v + data[input_param.kwargs_type][k] = v return BlockState(**data) def add_block_state(self, state: PipelineState, block_state: BlockState): for output_param in self.intermediates_outputs: if not hasattr(block_state, output_param.name): raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") - state.add_intermediate(output_param.name, getattr(block_state, output_param.name)) + param = getattr(block_state, output_param.name) + state.add_intermediate(output_param.name, param, output_param.kwargs_type) def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py index c8064a5215aa..f300f259f9eb 100644 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ b/src/diffusers/pipelines/modular_pipeline_utils.py @@ -244,11 +244,12 @@ class ConfigSpec: @dataclass class InputParam: """Specification for an input parameter.""" - name: str + name: str = None type_hint: Any = None default: Any = None required: bool = False description: str = "" + kwargs_type: str = None def __repr__(self): return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" @@ -260,6 +261,7 @@ class OutputParam: name: str type_hint: Any = None description: str = "" + kwargs_type: str = None def __repr__(self): return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" From f552773572a9a27d80aa35910e45c26883583bc5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 6 May 2025 10:00:14 +0200 Subject: [PATCH 13/54] remove controlnet union denoise step, refactor & reuse controlnet denoisee step to accept aditional contrlnet kwargs --- .../pipeline_stable_diffusion_xl_modular.py | 475 +++--------------- 1 file changed, 57 insertions(+), 418 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 5ebdd383ccbb..119c92e06f1d 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2452,11 +2452,11 @@ def intermediates_inputs(self) -> List[str]: @property def intermediates_outputs(self) -> List[OutputParam]: return [ - OutputParam("control_image", type_hint=torch.Tensor, description="The processed control image"), + OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image", kwargs_type="contronet_kwargs"), OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"), OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"), - OutputParam("controlnet_conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), - OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used", kwargs_type="controlnet_kwargs"), OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), ] @@ -2582,6 +2582,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) ] block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale @@ -2615,15 +2618,16 @@ def inputs(self) -> List[Tuple[str, Any]]: return [ InputParam("num_images_per_prompt", default=1), InputParam("cross_attention_kwargs"), - InputParam("generator"), - InputParam("eta", default=0.0), + InputParam("generator", kwargs_type="scheduler_kwargs"), + InputParam("eta", default=0.0, kwargs_type="scheduler_kwargs"), + InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) ] @property def intermediates_inputs(self) -> List[str]: return [ InputParam( - "control_image", + "controlnet_cond", required=True, type_hint=torch.Tensor, description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." @@ -2641,8 +2645,7 @@ def intermediates_inputs(self) -> List[str]: description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." ), InputParam( - "controlnet_conditioning_scale", - required=True, + "conditioning_scale", type_hint=float, description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." ), @@ -2755,6 +2758,7 @@ def intermediates_inputs(self) -> List[str]: type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." ), + InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") ] @property @@ -2780,26 +2784,16 @@ def check_inputs(components, block_state): f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" " `components.unet` or your `mask_image` or `image` input." ) - - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components @staticmethod - def prepare_extra_step_kwargs(components, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs + return extra_kwargs @torch.no_grad() @@ -2808,9 +2802,15 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state = self.get_block_state(state) self.check_inputs(components, block_state) block_state.device = components._execution_device + print(f" block_state: {block_state}") + + controlnet = unwrap_module(components.controlnet) # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) + # YiYI TODO: refactor scheduler_kwargs and support unet kwargs + block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) + block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) # (1) setup guider @@ -2841,9 +2841,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt # cond_scale (controlnet input) if isinstance(block_state.controlnet_keep[i], list): - block_state.cond_scale = [c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])] + block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] else: - block_state.controlnet_cond_scale = block_state.controlnet_conditioning_scale + block_state.controlnet_cond_scale = block_state.conditioning_scale if isinstance(block_state.controlnet_cond_scale, list): block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] @@ -2882,11 +2882,12 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.scaled_latents, t, encoder_hidden_states=guider_state_batch.prompt_embeds, - controlnet_cond=block_state.control_image, - conditioning_scale=block_state.cond_scale, + controlnet_cond=block_state.controlnet_cond, + conditioning_scale=block_state.conditioning_scale, guess_mode=block_state.guess_mode, added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, return_dict=False, + **block_state.extra_controlnet_kwargs, ) if block_state.down_block_res_samples_zeros is None: @@ -2958,7 +2959,7 @@ def description(self) -> str: def inputs(self) -> List[Tuple[str, Any]]: return [ InputParam("control_image", required=True), - InputParam("control_mode", default=[0]), + InputParam("control_mode", required=True), InputParam("control_guidance_start", default=0.0), InputParam("control_guidance_end", default=1.0), InputParam("controlnet_conditioning_scale", default=1.0), @@ -2973,7 +2974,7 @@ def intermediates_inputs(self) -> List[InputParam]: "latents", required=True, type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step." ), InputParam( "batch_size", @@ -2991,7 +2992,7 @@ def intermediates_inputs(self) -> List[InputParam]: "timesteps", required=True, type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step." ), InputParam( "crops_coords", @@ -3003,13 +3004,13 @@ def intermediates_inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[OutputParam]: return [ - OutputParam("control_image", type_hint=List[torch.Tensor], description="The processed control images"), - OutputParam("control_mode", type_hint=List[int], description="The control mode indices"), - OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active"), + OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images", kwargs_type="controlnet_kwargs"), + OutputParam("control_type_idx", type_hint=List[int], description="The control mode indices", kwargs_type="controlnet_kwargs"), + OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active", kwargs_type="controlnet_kwargs"), OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), - OutputParam("controlnet_conditioning_scale", type_hint=float, description="The controlnet conditioning scale value"), - OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used", kwargs_type="controlnet_kwargs"), OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), ] @@ -3051,7 +3052,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt controlnet = unwrap_module(components.controlnet) - device = block_state.device or components._execution_device + device = components._execution_device dtype = block_state.dtype or components.controlnet.dtype block_state.height, block_state.width = block_state.latents.shape[-2:] @@ -3069,10 +3070,10 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.global_pool_conditions = controlnet.config.global_pool_conditions block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions - + # control_image if not isinstance(block_state.control_image, list): block_state.control_image = [block_state.control_image] - + # control_mode if not isinstance(block_state.control_mode, list): block_state.control_mode = [block_state.control_mode] @@ -3112,371 +3113,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt 1.0 - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end) ) - - - self.add_block_state(state, block_state) - - return components, state - -class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec("controlnet", ControlNetUnionModel), - ] - - @property - def description(self) -> str: - return " The denoising step for the controlnet union model, works for inpainting, image-to-image, and text-to-image tasks" - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("num_images_per_prompt", default=1), - InputParam("cross_attention_kwargs"), - InputParam("generator"), - InputParam("eta", default=0.0), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "control_image", - required=True, - type_hint=List[torch.Tensor], - description="The control images to use for conditioning. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "control_mode", - required=True, - type_hint=List[int], - description="The control mode indices. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "control_type", - required=True, - type_hint=torch.Tensor, - description="The control type tensor that specifies which control type is active. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "num_control_type", - required=True, - type_hint=int, - description="The number of control types available. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "control_guidance_start", - required=True, - type_hint=float, - description="The control guidance start value. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "control_guidance_end", - required=True, - type_hint=float, - description="The control guidance end value. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "controlnet_conditioning_scale", - required=True, - type_hint=float, - description="The controlnet conditioning scale value. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "guess_mode", - required=True, - type_hint=bool, - description="Whether guess mode is used. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "global_pool_conditions", - required=True, - type_hint=bool, - description="Whether global pool conditions are used. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "controlnet_keep", - required=True, - type_hint=List[float], - description="The controlnet keep values. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step. See: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step. " - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. See: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], - description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - @staticmethod - def check_inputs(components, block_state): - - num_channels_unet = components.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if block_state.mask is None or block_state.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = block_state.latents.shape[1] - num_channels_mask = block_state.mask.shape[1] - num_channels_masked_image = block_state.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" - f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `components.unet` or your `mask_image` or `image` input." - ) - - - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components - @staticmethod - def prepare_extra_step_kwargs(components, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - @torch.no_grad() - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - self.check_inputs(components, block_state) - - block_state.num_channels_unet = components.unet.config.in_channels - block_state.device = components._execution_device - - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - - # Setup guider - # disable for LCMs - block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False - if block_state.disable_guidance: - components.guider.disable() - else: - components.guider.enable() - - # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - - components.guider.set_input_fields( - prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), - add_time_ids=("add_time_ids", "negative_add_time_ids"), - pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), - ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), - ) - - with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: - for i, t in enumerate(block_state.timesteps): - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - guider_data = components.guider.prepare_inputs(block_state) - - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - - if isinstance(block_state.controlnet_keep[i], list): - block_state.cond_scale = [c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])] - else: - block_state.controlnet_cond_scale = block_state.controlnet_conditioning_scale - if isinstance(block_state.controlnet_cond_scale, list): - block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] - block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] - - for batch in guider_data: - components.guider.prepare_models(components.unet) - - # Prepare additional conditionings - batch.added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": batch.add_time_ids, - } - if batch.ip_adapter_embeds is not None: - batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - - # Prepare controlnet additional conditionings - batch.controlnet_added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": batch.add_time_ids, - } - - # Will always be run atleast once with every guider - if components.guider.is_conditional or not block_state.guess_mode: - block_state.down_block_res_samples, block_state.mid_block_res_sample = components.controlnet( - block_state.scaled_latents, - t, - encoder_hidden_states=batch.prompt_embeds, - controlnet_cond=block_state.control_image, - control_type=block_state.control_type, - control_type_idx=block_state.control_mode, - conditioning_scale=block_state.cond_scale, - guess_mode=block_state.guess_mode, - added_cond_kwargs=batch.controlnet_added_cond_kwargs, - return_dict=False, - ) - - batch.down_block_res_samples = block_state.down_block_res_samples - batch.mid_block_res_sample = block_state.mid_block_res_sample - - if components.guider.is_unconditional and block_state.guess_mode: - batch.down_block_res_samples = [torch.zeros_like(d) for d in block_state.down_block_res_samples] - batch.mid_block_res_sample = torch.zeros_like(block_state.mid_block_res_sample) - - if block_state.num_channels_unet == 9: - block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - - batch.noise_pred = components.unet( - block_state.scaled_latents, - t, - encoder_hidden_states=batch.prompt_embeds, - timestep_cond=block_state.timestep_cond, - cross_attention_kwargs=block_state.cross_attention_kwargs, - added_cond_kwargs=batch.added_cond_kwargs, - down_block_additional_residuals=batch.down_block_res_samples, - mid_block_additional_residual=batch.mid_block_res_sample, - return_dict=False, - )[0] - components.guider.cleanup_models(components.unet) - - # Perform guidance - block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) - - # Perform scheduler step using the predicted output - block_state.latents_dtype = block_state.latents.dtype - block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - - if block_state.latents.dtype != block_state.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - block_state.latents = block_state.latents.to(block_state.latents_dtype) - - if block_state.num_channels_unet == 9 and block_state.mask is not None and block_state.image_latents is not None: - block_state.init_latents_proper = block_state.image_latents - if i < len(block_state.timesteps) - 1: - block_state.noise_timestep = block_state.timesteps[i + 1] - block_state.init_latents_proper = components.scheduler.add_noise( - block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) - ) - block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - - if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): - progress_bar.update() + block_state.control_type_idx = block_state.control_mode + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale self.add_block_state(state, block_state) @@ -3727,6 +3366,18 @@ def description(self): " - `StableDiffusionXLControlNetInputStep` is used to prepare the inputs for the denoise step.\n" + \ " - `StableDiffusionXLControlNetDenoiseStep` is used to denoise the latents." +class StableDiffusionXLControlNetUnionStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetDenoiseStep] + block_names = ["prepare_input", "denoise"] + + @property + def description(self): + return "ControlNetUnion step that denoises the latents.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLControlNetUnionInputStep` is used to prepare the inputs for the denoise step.\n" + \ + " - `StableDiffusionXLControlNetDenoiseStep` is used to denoise the latents using the ControlNetUnion model." + + class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep] block_names = ["inpaint", "img2img", "text2img"] @@ -3995,15 +3646,3 @@ def description(self): SDXL_OUTPUTS_SCHEMA = { "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") } - - -class StableDiffusionXLControlNetUnionStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetUnionDenoiseStep] - block_names = ["prepare_input", "denoise"] - - @property - def description(self): - return "ControlNetUnion step that denoises the latents.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLControlNetUnionInputStep` is used to prepare the inputs for the denoise step.\n" + \ - " - `StableDiffusionXLControlNetUnionDenoiseStep` is used to denoise the latents using the ControlNetUnion model." From 16b6583fa8bcd0d5595984ac4c7f08c91ab3af3f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 8 May 2025 11:25:31 +0200 Subject: [PATCH 14/54] allow input_fields as input & update message --- src/diffusers/guiders/adaptive_projected_guidance.py | 10 +++++++--- src/diffusers/guiders/auto_guidance.py | 10 +++++++--- src/diffusers/guiders/classifier_free_guidance.py | 10 +++++++--- .../guiders/classifier_free_zero_star_guidance.py | 10 +++++++--- src/diffusers/guiders/guider_utils.py | 4 ++-- src/diffusers/guiders/skip_layer_guidance.py | 10 +++++++--- src/diffusers/guiders/smoothed_energy_guidance.py | 10 +++++++--- .../guiders/tangential_classifier_free_guidance.py | 10 +++++++--- 8 files changed, 51 insertions(+), 23 deletions(-) diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index 7da1cc59a365..83e93c15ff1d 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple import torch @@ -73,14 +73,18 @@ def __init__( self.use_original_formulation = use_original_formulation self.momentum_buffer = None - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + if self._step == 0: if self.adaptive_projected_guidance_momentum is not None: self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py index bfffb9f39cd2..8bb6083781c2 100644 --- a/src/diffusers/guiders/auto_guidance.py +++ b/src/diffusers/guiders/auto_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Optional, Union, TYPE_CHECKING +from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple import torch @@ -120,11 +120,15 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None: registry = HookRegistry.check_if_exists_or_initialize(denoiser) registry.remove_hook(name, recurse=True) - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 429f8450410a..429392e3f9c6 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple import torch @@ -75,11 +75,15 @@ def __init__( self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py index 4c9839ee78f3..220a95e54a8d 100644 --- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple import torch @@ -73,11 +73,15 @@ def __init__( self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 7d005442e89c..18c85f579424 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -174,7 +174,7 @@ def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], da from ..pipelines.modular_pipeline import BlockState if input_fields is None: - raise ValueError("Input fields have not been set. Please call `set_input_fields` before preparing inputs.") + raise ValueError("Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs.") data_batch = {} for key, value in input_fields.items(): try: @@ -186,7 +186,7 @@ def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], da # We've already checked that value is a string or a tuple of strings with length 2 pass except AttributeError: - raise ValueError(f"Expected `data` to have attribute(s) {value}, but it does not. Please check the input data.") + logger.warning(f"`data` does not have attribute(s) {value}, skipping.") data_batch[cls._identifier_key] = identifier return BlockState(**data_batch) diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index bdd9e4af81b6..56dae1903606 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Optional, Union, TYPE_CHECKING +from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple import torch @@ -156,7 +156,11 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None: for hook_name in self._skip_layer_hook_names: registry.remove_hook(hook_name, recurse=True) - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + if self.num_conditions == 1: tuple_indices = [0] input_predictions = ["pred_cond"] @@ -168,7 +172,7 @@ def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index 1c7ee45dc3db..c215cb0afdc9 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Optional, Union, TYPE_CHECKING +from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple import torch @@ -149,7 +149,11 @@ def cleanup_models(self, denoiser: torch.nn.Module): for hook_name in self._seg_layer_hook_names: registry.remove_hook(hook_name, recurse=True) - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + if self.num_conditions == 1: tuple_indices = [0] input_predictions = ["pred_cond"] @@ -161,7 +165,7 @@ def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py index 631f9a5f33b2..9fa8f9454134 100644 --- a/src/diffusers/guiders/tangential_classifier_free_guidance.py +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple import torch @@ -62,11 +62,15 @@ def __init__( self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches From d89631fc50578dc5de0b95400b7d796daa8b0abc Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 8 May 2025 11:27:17 +0200 Subject: [PATCH 15/54] update input formating, consider kwarggs_type inputs with no name, e/g *_controlnet_kwargs --- src/diffusers/pipelines/modular_pipeline_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py index f300f259f9eb..392d6dcd9521 100644 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ b/src/diffusers/pipelines/modular_pipeline_utils.py @@ -322,7 +322,11 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu if inp.name in required_intermediates_inputs: input_parts.append(f"Required({inp.name})") else: - input_parts.append(inp.name) + if inp.name is None and inp.kwargs_type is not None: + inp_name = "*_" + inp.kwargs_type + else: + inp_name = inp.name + input_parts.append(inp_name) # Handle modified variables (appear in both inputs and outputs) inputs_set = {inp.name for inp in intermediates_inputs} From 0f0618ff2b53397485fec8ca8c3fb698434019ef Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 8 May 2025 11:28:52 +0200 Subject: [PATCH 16/54] refactor the denoiseestep using LoopSequential! also add a new file for denoise step --- src/diffusers/pipelines/modular_pipeline.py | 288 +++- .../pipeline_stable_diffusion_xl_modular.py | 1185 +++++++++-------- ...table_diffusion_xl_modular_denoise_loop.py | 729 ++++++++++ 3 files changed, 1607 insertions(+), 595 deletions(-) create mode 100644 src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 1733ad6d4e00..92cb50a8b490 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -184,6 +184,23 @@ def __init__(self, **kwargs): for key, value in kwargs.items(): setattr(self, key, value) + def __getitem__(self, key: str): + # allows block_state["foo"] + return getattr(self, key, None) + + def __setitem__(self, key: str, value: Any): + # allows block_state["foo"] = "bar" + setattr(self, key, value) + + def as_dict(self): + """ + Convert BlockState to a dictionary. + + Returns: + Dict[str, Any]: Dictionary containing all attributes of the BlockState + """ + return {key: value for key, value in self.__dict__.items()} + def __repr__(self): def format_value(v): # Handle tensors directly @@ -523,8 +540,12 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li for block_name, inputs in named_input_lists: for input_param in inputs: - if input_param.name in combined_dict: - current_param = combined_dict[input_param.name] + if input_param.name is None and input_param.kwargs_type is not None: + input_name = "*_" + input_param.kwargs_type + else: + input_name = input_param.name + if input_name in combined_dict: + current_param = combined_dict[input_name] if (current_param.default is not None and input_param.default is not None and current_param.default != input_param.default): @@ -557,7 +578,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> for block_name, outputs in named_output_lists: for output_param in outputs: - if output_param.name not in combined_dict: + if (output_param.name not in combined_dict) or (combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None): combined_dict[output_param.name] = output_param return list(combined_dict.values()) @@ -919,6 +940,9 @@ def required_intermediates_inputs(self) -> List[str]: # YiYi TODO: add test for this @property def inputs(self) -> List[Tuple[str, Any]]: + return self.get_inputs() + + def get_inputs(self): named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] combined_inputs = combine_inputs(*named_inputs) # mark Required inputs only if that input is required any of the blocks @@ -931,6 +955,9 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[str]: + return self.get_intermediates_inputs() + + def get_intermediates_inputs(self): inputs = [] outputs = set() @@ -1169,7 +1196,262 @@ def doc(self): expected_configs=self.expected_configs ) +#YiYi TODO: __repr__ +class LoopSequentialPipelineBlocks(ModularPipelineMixin): + """ + A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence. + """ + + model_name = None + block_classes = [] + block_names = [] + + @property + def description(self) -> str: + """Description of the block. Must be implemented by subclasses.""" + raise NotImplementedError("description method must be implemented in subclasses") + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def loop_expected_configs(self) -> List[ConfigSpec]: + return [] + + @property + def loop_inputs(self) -> List[InputParam]: + """List of input parameters. Must be implemented by subclasses.""" + return [] + + @property + def loop_intermediates_inputs(self) -> List[InputParam]: + """List of intermediate input parameters. Must be implemented by subclasses.""" + return [] + + @property + def loop_intermediates_outputs(self) -> List[OutputParam]: + """List of intermediate output parameters. Must be implemented by subclasses.""" + return [] + + + @property + def loop_required_inputs(self) -> List[str]: + input_names = [] + for input_param in self.loop_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + @property + def loop_required_intermediates_inputs(self) -> List[str]: + input_names = [] + for input_param in self.loop_intermediates_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + # modified from SequentialPipelineBlocks to include loop_expected_components + @property + def expected_components(self): + expected_components = [] + for block in self.blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + for component in self.loop_expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + # modified from SequentialPipelineBlocks to include loop_expected_configs + @property + def expected_configs(self): + expected_configs = [] + for block in self.blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + for config in self.loop_expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + # modified from SequentialPipelineBlocks to include loop_inputs + def get_inputs(self): + named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + named_inputs.append(("loop", self.loop_inputs)) + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required any of the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + # Copied from SequentialPipelineBlocks + @property + def inputs(self): + return self.get_inputs() + + + # modified from SequentialPipelineBlocks to include loop_intermediates_inputs + @property + def intermediates_inputs(self): + intermediates = self.get_intermediates_inputs() + intermediate_names = [input.name for input in intermediates] + for loop_intermediate_input in self.loop_intermediates_inputs: + if loop_intermediate_input.name not in intermediate_names: + intermediates.append(loop_intermediate_input) + return intermediates + + + # Copied from SequentialPipelineBlocks + def get_intermediates_inputs(self): + inputs = [] + outputs = set() + + # Go through all blocks in order + for block in self.blocks.values(): + # Add inputs that aren't in outputs yet + inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) + + # Only add outputs if the block cannot be skipped + should_add_outputs = True + if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: + should_add_outputs = False + + if should_add_outputs: + # Add this block's outputs + block_intermediates_outputs = [out.name for out in block.intermediates_outputs] + outputs.update(block_intermediates_outputs) + return inputs + + # modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block + @property + def required_inputs(self) -> List[str]: + # Get the first block from the dictionary + first_block = next(iter(self.blocks.values())) + required_by_any = set(getattr(first_block, "required_inputs", set())) + + required_by_loop = set(getattr(self, "loop_required_inputs", set())) + required_by_any.update(required_by_loop) + + # Union with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_any.update(block_required) + + return list(required_by_any) + + # modified from SequentialPipelineBlocks, if any additional intermediate input required by the loop is required by the block + @property + def required_intermediates_inputs(self) -> List[str]: + required_intermediates_inputs = [] + for input_param in self.intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + for input_param in self.loop_intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + return required_intermediates_inputs + + + # YiYi TODO: this need to be thought about more + # modified from SequentialPipelineBlocks to include loop_intermediates_outputs + @property + def intermediates_outputs(self) -> List[str]: + named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + for output in self.loop_intermediates_outputs: + if output.name not in set([output.name for output in combined_outputs]): + combined_outputs.append(output) + return combined_outputs + + # YiYi TODO: this need to be thought about more + # copied from SequentialPipelineBlocks + @property + def outputs(self) -> List[str]: + return next(reversed(self.blocks.values())).intermediates_outputs + + + def __init__(self): + blocks = OrderedDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + blocks[block_name] = block_cls() + self.blocks = blocks + + def loop_step(self, components, state: PipelineState, **kwargs): + + for block_name, block in self.blocks.items(): + try: + components, state = block(components, state, **kwargs) + except Exception as e: + error_msg = ( + f"\nError in block: ({block_name}, {block.__class__.__name__})\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + return components, state + + def __call__(self, components, state: PipelineState) -> PipelineState: + raise NotImplementedError("`__call__` method needs to be implemented by the subclass") + + + def get_block_state(self, state: PipelineState) -> dict: + """Get all inputs and intermediates in one dictionary""" + data = {} + + # Check inputs + for input_param in self.inputs: + if input_param.name: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v + + # Check intermediates + for input_param in self.intermediates_inputs: + if input_param.name: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all intermediates with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + if intermediates_kwargs: + for k, v in intermediates_kwargs.items(): + if v is not None: + if k not in data: + data[k] = v + data[input_param.kwargs_type][k] = v + return BlockState(**data) + + def add_block_state(self, state: PipelineState, block_state: BlockState): + for output_param in self.intermediates_outputs: + if not hasattr(block_state, output_param.name): + raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") + param = getattr(block_state, output_param.name) + state.add_intermediate(output_param.name, param, output_param.kwargs_type) # YiYi TODO: # 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 119c92e06f1d..7869e11a9cd5 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -370,10 +370,10 @@ def inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[OutputParam]: return [ - OutputParam("prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="negative pooled text embeddings used to guide the image generation"), + OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields",description="text embeddings used to guide the image generation"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), ] @staticmethod @@ -982,12 +982,12 @@ def intermediates_outputs(self) -> List[str]: return [ OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)"), - OutputParam("prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="negative pooled text embeddings used to guide the image generation"), - OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="image embeddings for IP-Adapter"), - OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="negative image embeddings for IP-Adapter"), + OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="text embeddings used to guide the image generation"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), + OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="image embeddings for IP-Adapter"), + OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="negative image embeddings for IP-Adapter"), ] def check_inputs(self, components, block_state): @@ -1836,8 +1836,8 @@ def intermediates_inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components @@ -2025,8 +2025,8 @@ def intermediates_inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components @@ -2135,264 +2135,265 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt return components, state -class StableDiffusionXLDenoiseStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ] - - @property - def description(self) -> str: - return ( - "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("cross_attention_kwargs"), - InputParam("generator"), - InputParam("eta", default=0.0), - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " - ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - - @staticmethod - def check_inputs(components, block_state): - - num_channels_unet = components.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if block_state.mask is None or block_state.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = block_state.latents.shape[1] - num_channels_mask = block_state.mask.shape[1] - num_channels_masked_image = block_state.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" - f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `components.unet` or your `mask_image` or `image` input." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components - @staticmethod - def prepare_extra_step_kwargs(components, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs +from .pipeline_stable_diffusion_xl_modular_denoise_loop import StableDiffusionXLDenoiseStep +# class StableDiffusionXLDenoiseStep(PipelineBlock): + +# model_name = "stable-diffusion-xl" + +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ] + +# @property +# def description(self) -> str: +# return ( +# "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" +# ) + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("num_images_per_prompt", default=1), +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." +# ) + +# # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components +# @staticmethod +# def prepare_extra_step_kwargs(components, generator, eta): +# # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature +# # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. +# # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 +# # and should be between [0, 1] + +# accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# extra_step_kwargs = {} +# if accepts_eta: +# extra_step_kwargs["eta"] = eta + +# # check if the scheduler accepts generator +# accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# if accepts_generator: +# extra_step_kwargs["generator"] = generator +# return extra_step_kwargs - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - - block_state = self.get_block_state(state) - self.check_inputs(components, block_state) - - block_state.num_channels_unet = components.unet.config.in_channels - block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False - if block_state.disable_guidance: - components.guider.disable() - else: - components.guider.enable() - - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - - components.guider.set_input_fields( - prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), - add_time_ids=("add_time_ids", "negative_add_time_ids"), - pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), - ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), - ) - - with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: - for i, t in enumerate(block_state.timesteps): - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - guider_data = components.guider.prepare_inputs(block_state) - - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) + +# block_state.num_channels_unet = components.unet.config.in_channels +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() + +# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_data = components.guider.prepare_inputs(block_state) + +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - # Prepare for inpainting - if block_state.num_channels_unet == 9: - block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) +# # Prepare for inpainting +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - for batch in guider_data: - components.guider.prepare_models(components.unet) +# for batch in guider_data: +# components.guider.prepare_models(components.unet) - # Prepare additional conditionings - batch.added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": batch.add_time_ids, - } - if batch.ip_adapter_embeds is not None: - batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds +# # Prepare additional conditionings +# batch.added_cond_kwargs = { +# "text_embeds": batch.pooled_prompt_embeds, +# "time_ids": batch.add_time_ids, +# } +# if batch.ip_adapter_embeds is not None: +# batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - # Predict the noise residual - batch.noise_pred = components.unet( - block_state.scaled_latents, - t, - encoder_hidden_states=batch.prompt_embeds, - timestep_cond=block_state.timestep_cond, - cross_attention_kwargs=block_state.cross_attention_kwargs, - added_cond_kwargs=batch.added_cond_kwargs, - return_dict=False, - )[0] - components.guider.cleanup_models(components.unet) - - # Perform guidance - block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) +# # Predict the noise residual +# batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=batch.added_cond_kwargs, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) + +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) - # Perform scheduler step using the predicted output - block_state.latents_dtype = block_state.latents.dtype - block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - - if block_state.latents.dtype != block_state.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - block_state.latents = block_state.latents.to(block_state.latents_dtype) +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + +# if block_state.latents.dtype != block_state.latents_dtype: +# if torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) - if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: - block_state.init_latents_proper = block_state.image_latents - if i < len(block_state.timesteps) - 1: - block_state.noise_timestep = block_state.timesteps[i + 1] - block_state.init_latents_proper = components.scheduler.add_noise( - block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) - ) +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) - block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): - progress_bar.update() +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() - self.add_block_state(state, block_state) +# self.add_block_state(state, block_state) - return components, state +# return components, state class StableDiffusionXLControlNetInputStep(PipelineBlock): @@ -2452,11 +2453,11 @@ def intermediates_inputs(self) -> List[str]: @property def intermediates_outputs(self) -> List[OutputParam]: return [ - OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image", kwargs_type="contronet_kwargs"), + OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"), OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"), OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"), OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), - OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used", kwargs_type="controlnet_kwargs"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), ] @@ -2592,353 +2593,353 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt return components, state -class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec("controlnet", ControlNetModel), - ] - - @property - def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("num_images_per_prompt", default=1), - InputParam("cross_attention_kwargs"), - InputParam("generator", kwargs_type="scheduler_kwargs"), - InputParam("eta", default=0.0, kwargs_type="scheduler_kwargs"), - InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "controlnet_cond", - required=True, - type_hint=torch.Tensor, - description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "control_guidance_start", - required=True, - type_hint=float, - description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "control_guidance_end", - required=True, - type_hint=float, - description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "conditioning_scale", - type_hint=float, - description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "guess_mode", - required=True, - type_hint=bool, - description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "controlnet_keep", - required=True, - type_hint=List[float], - description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], - description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." - ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - @staticmethod - def check_inputs(components, block_state): - - num_channels_unet = components.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if block_state.mask is None or block_state.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = block_state.latents.shape[1] - num_channels_mask = block_state.mask.shape[1] - num_channels_masked_image = block_state.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" - f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `components.unet` or your `mask_image` or `image` input." - ) - @staticmethod - def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - - accepted_kwargs = set(inspect.signature(func).parameters.keys()) - extra_kwargs = {} - for key, value in kwargs.items(): - if key in accepted_kwargs and key not in exclude_kwargs: - extra_kwargs[key] = value - - return extra_kwargs - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular_denoise_loop import StableDiffusionXLControlNetDenoiseStep +# class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): + +# model_name = "stable-diffusion-xl" + +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ComponentSpec("controlnet", ControlNetModel), +# ] + +# @property +# def description(self) -> str: +# return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("num_images_per_prompt", default=1), +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "controlnet_cond", +# required=True, +# type_hint=torch.Tensor, +# description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_start", +# required=True, +# type_hint=float, +# description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_end", +# required=True, +# type_hint=float, +# description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "conditioning_scale", +# type_hint=float, +# description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "guess_mode", +# required=True, +# type_hint=bool, +# description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "controlnet_keep", +# required=True, +# type_hint=List[float], +# description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "crops_coords", +# type_hint=Optional[Tuple[int]], +# description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." +# ) +# @staticmethod +# def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + +# accepted_kwargs = set(inspect.signature(func).parameters.keys()) +# extra_kwargs = {} +# for key, value in kwargs.items(): +# if key in accepted_kwargs and key not in exclude_kwargs: +# extra_kwargs[key] = value + +# return extra_kwargs + + +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - self.check_inputs(components, block_state) - block_state.device = components._execution_device - print(f" block_state: {block_state}") +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) +# block_state.device = components._execution_device +# print(f" block_state: {block_state}") - controlnet = unwrap_module(components.controlnet) +# controlnet = unwrap_module(components.controlnet) - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - # YiYI TODO: refactor scheduler_kwargs and support unet kwargs - block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) - block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) +# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) +# block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - # (1) setup guider - # disable for LCMs - block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False - if block_state.disable_guidance: - components.guider.disable() - else: - components.guider.enable() - components.guider.set_input_fields( - prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), - add_time_ids=("add_time_ids", "negative_add_time_ids"), - pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), - ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), - ) - - # (5) Denoise loop - with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: - for i, t in enumerate(block_state.timesteps): - - # prepare latent input for unet - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - # adjust latent input for inpainting - block_state.num_channels_unet = components.unet.config.in_channels - if block_state.num_channels_unet == 9: - block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - - - # cond_scale (controlnet input) - if isinstance(block_state.controlnet_keep[i], list): - block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] - else: - block_state.controlnet_cond_scale = block_state.conditioning_scale - if isinstance(block_state.controlnet_cond_scale, list): - block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] - block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] +# # (1) setup guider +# # disable for LCMs +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# # (5) Denoise loop +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): + +# # prepare latent input for unet +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) +# # adjust latent input for inpainting +# block_state.num_channels_unet = components.unet.config.in_channels +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + +# # cond_scale (controlnet input) +# if isinstance(block_state.controlnet_keep[i], list): +# block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] +# else: +# block_state.controlnet_cond_scale = block_state.conditioning_scale +# if isinstance(block_state.controlnet_cond_scale, list): +# block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] +# block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] - # default controlnet output/unet input for guess mode + conditional path - block_state.down_block_res_samples_zeros = None - block_state.mid_block_res_sample_zeros = None +# # default controlnet output/unet input for guess mode + conditional path +# block_state.down_block_res_samples_zeros = None +# block_state.mid_block_res_sample_zeros = None - # guided denoiser step - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - guider_state = components.guider.prepare_inputs(block_state) +# # guided denoiser step +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_state = components.guider.prepare_inputs(block_state) - for guider_state_batch in guider_state: - components.guider.prepare_models(components.unet) +# for guider_state_batch in guider_state: +# components.guider.prepare_models(components.unet) - # Prepare additional conditionings - guider_state_batch.added_cond_kwargs = { - "text_embeds": guider_state_batch.pooled_prompt_embeds, - "time_ids": guider_state_batch.add_time_ids, - } - if guider_state_batch.ip_adapter_embeds is not None: - guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds +# # Prepare additional conditionings +# guider_state_batch.added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } +# if guider_state_batch.ip_adapter_embeds is not None: +# guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds - # Prepare controlnet additional conditionings - guider_state_batch.controlnet_added_cond_kwargs = { - "text_embeds": guider_state_batch.pooled_prompt_embeds, - "time_ids": guider_state_batch.add_time_ids, - } - - if block_state.guess_mode and not components.guider.is_conditional: - # guider always run uncond batch first, so these tensors should be set already - guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros - guider_state_batch.mid_block_res_sample = block_state.mid_block_res_sample_zeros - else: - guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( - block_state.scaled_latents, - t, - encoder_hidden_states=guider_state_batch.prompt_embeds, - controlnet_cond=block_state.controlnet_cond, - conditioning_scale=block_state.conditioning_scale, - guess_mode=block_state.guess_mode, - added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, - return_dict=False, - **block_state.extra_controlnet_kwargs, - ) +# # Prepare controlnet additional conditionings +# guider_state_batch.controlnet_added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } + +# if block_state.guess_mode and not components.guider.is_conditional: +# # guider always run uncond batch first, so these tensors should be set already +# guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros +# guider_state_batch.mid_block_res_sample = block_state.mid_block_res_sample_zeros +# else: +# guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# controlnet_cond=block_state.controlnet_cond, +# conditioning_scale=block_state.conditioning_scale, +# guess_mode=block_state.guess_mode, +# added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, +# return_dict=False, +# **block_state.extra_controlnet_kwargs, +# ) - if block_state.down_block_res_samples_zeros is None: - block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] - if block_state.mid_block_res_sample_zeros is None: - block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) +# if block_state.down_block_res_samples_zeros is None: +# block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] +# if block_state.mid_block_res_sample_zeros is None: +# block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) - guider_state_batch.noise_pred = components.unet( - block_state.scaled_latents, - t, - encoder_hidden_states=guider_state_batch.prompt_embeds, - timestep_cond=block_state.timestep_cond, - cross_attention_kwargs=block_state.cross_attention_kwargs, - added_cond_kwargs=guider_state_batch.added_cond_kwargs, - down_block_additional_residuals=guider_state_batch.down_block_res_samples, - mid_block_additional_residual=guider_state_batch.mid_block_res_sample, - return_dict=False, - )[0] - components.guider.cleanup_models(components.unet) +# guider_state_batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=guider_state_batch.added_cond_kwargs, +# down_block_additional_residuals=guider_state_batch.down_block_res_samples, +# mid_block_additional_residual=guider_state_batch.mid_block_res_sample, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) - # Perform guidance - block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) - # Perform scheduler step using the predicted output - block_state.latents_dtype = block_state.latents.dtype - block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - if block_state.latents.dtype != block_state.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - block_state.latents = block_state.latents.to(block_state.latents_dtype) +# if block_state.latents.dtype != block_state.latents_dtype: +# if torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) - # adjust latent for inpainting - if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: - block_state.init_latents_proper = block_state.image_latents - if i < len(block_state.timesteps) - 1: - block_state.noise_timestep = block_state.timesteps[i + 1] - block_state.init_latents_proper = components.scheduler.add_noise( - block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) - ) - - block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - - if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): - progress_bar.update() +# # adjust latent for inpainting +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) + +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() - self.add_block_state(state, block_state) +# self.add_block_state(state, block_state) - return components, state +# return components, state class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): @@ -3004,13 +3005,13 @@ def intermediates_inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[OutputParam]: return [ - OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images", kwargs_type="controlnet_kwargs"), + OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"), OutputParam("control_type_idx", type_hint=List[int], description="The control mode indices", kwargs_type="controlnet_kwargs"), OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active", kwargs_type="controlnet_kwargs"), OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), - OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used", kwargs_type="controlnet_kwargs"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), ] diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py new file mode 100644 index 000000000000..92c07854fc74 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py @@ -0,0 +1,729 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from tqdm.auto import tqdm + +from ...configuration_utils import FrozenDict +from ...models import ControlNetModel, UNet2DConditionModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import unwrap_module +from ..modular_pipeline import ( + PipelineBlock, + PipelineState, + LoopSequentialPipelineBlocks, + InputParam, + OutputParam, + BlockState, + ComponentSpec, +) +from ...guiders import ClassifierFreeGuidance +from .pipeline_stable_diffusion_xl_modular import StableDiffusionXLModularLoader +from dataclasses import asdict + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + +# YiYi experimenting composible denoise loop +# loop step (1): prepare latent input for denoiser +class StableDiffusionXLDenoiseLoopLatentsStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "step within the denoising loop that prepare the latent input for the denoiser" + + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")] + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + + + return components, block_state + +# loop step (1): prepare latent input for denoiser (with inpainting) +class StableDiffusionXLDenoiseLoopInpaintLatentsStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return "step within the denoising loop that prepare the latent input for the denoiser" + + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "mask", + type_hint=Optional[torch.Tensor], + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "masked_image_latents", + type_hint=Optional[torch.Tensor], + description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")] + + @staticmethod + def check_inputs(components, block_state): + + num_channels_unet = components.num_channels_unet + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + if block_state.mask is None or block_state.masked_image_latents is None: + raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: + raise ValueError( + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `components.unet` or your `mask_image` or `image` input." + ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, loop_idx: int, t: int): + + self.check_inputs(components, block_state) + + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + if components.num_channels_unet == 9: + block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + + return components, block_state + +# loop step (2): denoise the latents with guidance +class StableDiffusionXLDenoiseLoopDenoiserStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "scaled_latents", + required=True, + type_hint=torch.Tensor, + description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ), + + ] + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int) -> PipelineState: + + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields ={ + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + "time_ids": ("add_time_ids", "negative_add_time_ids"), + "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), + } + + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. for CFG, we prepare two batches: one for uncond, one for cond + # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds + # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = {k:v for k,v in cond_kwargs.items() if k in guider_input_fields} + prompt_embeds = cond_kwargs.pop("prompt_embeds") + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=cond_kwargs, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + + return components, block_state + +# loop step (2): denoise the latents with guidance (with controlnet) +class StableDiffusionXLDenoiseLoopControlNetDenoiserStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("controlnet", ControlNetModel), + ] + + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "controlnet_cond", + required=True, + type_hint=torch.Tensor, + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "conditioning_scale", + type_hint=float, + description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "guess_mode", + required=True, + type_hint=bool, + description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "controlnet_keep", + required=True, + type_hint=List[float], + description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "scaled_latents", + required=True, + type_hint=torch.Tensor, + description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ), + InputParam( + kwargs_type="controlnet_kwargs", + description=( + "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )" + "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ) + ] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + extra_controlnet_kwargs = self.prepare_extra_kwargs(components.controlnet.forward, **block_state.controlnet_kwargs) + + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields ={ + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + "time_ids": ("add_time_ids", "negative_add_time_ids"), + "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), + } + + + # cond_scale for the timestep (controlnet input) + if isinstance(block_state.controlnet_keep[i], list): + block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] + else: + controlnet_cond_scale = block_state.conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] + + # default controlnet output/unet input for guess mode + conditional path + block_state.down_block_res_samples_zeros = None + block_state.mid_block_res_sample_zeros = None + + # guided denoiser step + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. for CFG, we prepare two batches: one for uncond, one for cond + # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds + # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + + # Prepare additional conditionings + added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None: + added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds + + # Prepare controlnet additional conditionings + controlnet_added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + # run controlnet for the guidance batch + if block_state.guess_mode and not components.guider.is_conditional: + # guider always run uncond batch first, so these tensors should be set already + down_block_res_samples = block_state.down_block_res_samples_zeros + mid_block_res_sample = block_state.mid_block_res_sample_zeros + else: + down_block_res_samples, mid_block_res_sample = components.controlnet( + block_state.scaled_latents, + t, + encoder_hidden_states=guider_state_batch.prompt_embeds, + controlnet_cond=block_state.controlnet_cond, + conditioning_scale=block_state.cond_scale, + guess_mode=block_state.guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + **extra_controlnet_kwargs, + ) + + # assign it to block_state so it will be available for the uncond guidance batch + if block_state.down_block_res_samples_zeros is None: + block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples] + if block_state.mid_block_res_sample_zeros is None: + block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample) + + # Predict the noise + # store the noise_pred in guider_state_batch so we can apply guidance across all batches + guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=guider_state_batch.prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + + return components, block_state + +# loop step (3): scheduler step to update latents +class StableDiffusionXLDenoiseLoopUpdateLatentsStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("generator"), + InputParam("eta", default=0.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + #YiYi TODO: move this out of here + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) + + + # Perform scheduler step using the predicted output + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] + + if block_state.latents.dtype != block_state.latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + return components, block_state + + +class StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("generator"), + InputParam("eta", default=0.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "mask", + type_hint=Optional[torch.Tensor], + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "noise", + type_hint=Optional[torch.Tensor], + description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." + ), + InputParam( + "image_latents", + type_hint=Optional[torch.Tensor], + description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + def check_inputs(self, components, block_state): + if components.num_channels_unet == 4: + if block_state.image_latents is None: + raise ValueError(f"image_latents is required for this step {self.__class__.__name__}") + if block_state.mask is None: + raise ValueError(f"mask is required for this step {self.__class__.__name__}") + if block_state.noise is None: + raise ValueError(f"noise is required for this step {self.__class__.__name__}") + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + self.check_inputs(components, block_state) + + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) + + + # Perform scheduler step using the predicted output + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] + + if block_state.latents.dtype != block_state.latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + # adjust latent for inpainting + if components.num_channels_unet == 4: + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.add_noise( + block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) + ) + + block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + + + + return components, block_state + + +# the loop wrapper that iterates over the timesteps +class StableDiffusionXLDenoiseLoop(LoopSequentialPipelineBlocks): + + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def loop_intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + ] + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False + if block_state.disable_guidance: + components.guider.disable() + else: + components.guider.enable() + + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): + progress_bar.update() + + self.add_block_state(state, block_state) + + return components, state + + + +# StableDiffusionXLControlNetDenoiseStep + +class StableDiffusionXLDenoiseStep(StableDiffusionXLDenoiseLoop): + block_classes = [StableDiffusionXLDenoiseLoopLatentsStep, StableDiffusionXLDenoiseLoopDenoiserStep, StableDiffusionXLDenoiseLoopUpdateLatentsStep] + block_names = ["prepare_latents", "denoiser", "update_latents"] + +class StableDiffusionXLControlNetDenoiseStep(StableDiffusionXLDenoiseLoop): + block_classes = [StableDiffusionXLDenoiseLoopLatentsStep, StableDiffusionXLDenoiseLoopControlNetDenoiserStep, StableDiffusionXLDenoiseLoopUpdateLatentsStep] + block_names = ["prepare_latents", "denoiser", "update_latents"] + +class StableDiffusionXLInpaintDenoiseStep(StableDiffusionXLDenoiseLoop): + block_classes = [StableDiffusionXLDenoiseLoopInpaintLatentsStep, StableDiffusionXLDenoiseLoopDenoiserStep, StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep] + block_names = ["prepare_latents", "denoiser", "update_latents"] + +class StableDiffusionXLInpaintControlNetDenoiseStep(StableDiffusionXLDenoiseLoop): + block_classes = [StableDiffusionXLDenoiseLoopInpaintLatentsStep, StableDiffusionXLDenoiseLoopControlNetDenoiserStep, StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep] + block_names = ["prepare_latents", "denoiser", "update_latents"] + + + From c677d528e4c1c33b6c73c549c7a5f74ab8635f5e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 9 May 2025 08:16:24 +0200 Subject: [PATCH 17/54] change warning to debug --- src/diffusers/guiders/guider_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 18c85f579424..df544c955f33 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -186,7 +186,7 @@ def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], da # We've already checked that value is a string or a tuple of strings with length 2 pass except AttributeError: - logger.warning(f"`data` does not have attribute(s) {value}, skipping.") + logger.debug(f"`data` does not have attribute(s) {value}, skipping.") data_batch[cls._identifier_key] = identifier return BlockState(**data_batch) From 2b361a24132045786b229c1a6bfc3be0bd79e8a1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 9 May 2025 08:17:10 +0200 Subject: [PATCH 18/54] fix get_execusion blocks with loopsequential --- src/diffusers/pipelines/modular_pipeline.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 92cb50a8b490..97a8677bda63 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -1033,16 +1033,17 @@ def trigger_inputs(self): def _traverse_trigger_blocks(self, trigger_inputs): # Convert trigger_inputs to a set for easier manipulation active_triggers = set(trigger_inputs) - def fn_recursive_traverse(block, block_name, active_triggers): result_blocks = OrderedDict() - # sequential or PipelineBlock + # sequential(include loopsequential) or PipelineBlock if not hasattr(block, 'block_trigger_inputs'): if hasattr(block, 'blocks'): - # sequential - for block_name, block in block.blocks.items(): - blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) + # sequential or LoopSequentialPipelineBlocks (keep traversing) + for sub_block_name, sub_block in block.blocks.items(): + blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) + blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) + blocks_to_update = {f"{block_name}.{k}": v for k,v in blocks_to_update.items()} result_blocks.update(blocks_to_update) else: # PipelineBlock @@ -1069,13 +1070,14 @@ def fn_recursive_traverse(block, block_name, active_triggers): matching_trigger = None if this_block is not None: - # sequential/auto + # sequential/auto (keep traversing) if hasattr(this_block, 'blocks'): result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) else: # PipelineBlock result_blocks[block_name] = this_block # Add this block's output names to active triggers if defined + # YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute? if hasattr(this_block, 'outputs'): active_triggers.update(out.name for out in this_block.outputs) From 2017ae56244f87fb2137888cb440afb1c7a87663 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 9 May 2025 08:19:24 +0200 Subject: [PATCH 19/54] fix auto denoise so all tests pass --- .../pipeline_stable_diffusion_xl_modular.py | 699 +----------------- ...table_diffusion_xl_modular_denoise_loop.py | 688 ++++++++++++++++- 2 files changed, 702 insertions(+), 685 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 7869e11a9cd5..acb395345086 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2134,268 +2134,6 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt self.add_block_state(state, block_state) return components, state - -from .pipeline_stable_diffusion_xl_modular_denoise_loop import StableDiffusionXLDenoiseStep -# class StableDiffusionXLDenoiseStep(PipelineBlock): - -# model_name = "stable-diffusion-xl" - -# @property -# def expected_components(self) -> List[ComponentSpec]: -# return [ -# ComponentSpec( -# "guider", -# ClassifierFreeGuidance, -# config=FrozenDict({"guidance_scale": 7.5}), -# default_creation_method="from_config"), -# ComponentSpec("scheduler", EulerDiscreteScheduler), -# ComponentSpec("unet", UNet2DConditionModel), -# ] - -# @property -# def description(self) -> str: -# return ( -# "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" -# ) - -# @property -# def inputs(self) -> List[Tuple[str, Any]]: -# return [ -# InputParam("cross_attention_kwargs"), -# InputParam("generator"), -# InputParam("eta", default=0.0), -# InputParam("num_images_per_prompt", default=1), -# ] - -# @property -# def intermediates_inputs(self) -> List[str]: -# return [ -# InputParam( -# "latents", -# required=True, -# type_hint=torch.Tensor, -# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." -# ), -# InputParam( -# "batch_size", -# required=True, -# type_hint=int, -# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." -# ), -# InputParam( -# "timesteps", -# required=True, -# type_hint=torch.Tensor, -# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "num_inference_steps", -# required=True, -# type_hint=int, -# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "pooled_prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_pooled_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " -# ), -# InputParam( -# "add_time_ids", -# required=True, -# type_hint=torch.Tensor, -# description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "negative_add_time_ids", -# type_hint=Optional[torch.Tensor], -# description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " -# ), -# InputParam( -# "timestep_cond", -# type_hint=Optional[torch.Tensor], -# description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "mask", -# type_hint=Optional[torch.Tensor], -# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "masked_image_latents", -# type_hint=Optional[torch.Tensor], -# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "noise", -# type_hint=Optional[torch.Tensor], -# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." -# ), -# InputParam( -# "image_latents", -# type_hint=Optional[torch.Tensor], -# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "negative_ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# ] - -# @property -# def intermediates_outputs(self) -> List[OutputParam]: -# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - -# @staticmethod -# def check_inputs(components, block_state): - -# num_channels_unet = components.unet.config.in_channels -# if num_channels_unet == 9: -# # default case for runwayml/stable-diffusion-inpainting -# if block_state.mask is None or block_state.masked_image_latents is None: -# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") -# num_channels_latents = block_state.latents.shape[1] -# num_channels_mask = block_state.mask.shape[1] -# num_channels_masked_image = block_state.masked_image_latents.shape[1] -# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: -# raise ValueError( -# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" -# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" -# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" -# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" -# " `components.unet` or your `mask_image` or `image` input." -# ) - -# # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components -# @staticmethod -# def prepare_extra_step_kwargs(components, generator, eta): -# # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature -# # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. -# # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 -# # and should be between [0, 1] - -# accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) -# extra_step_kwargs = {} -# if accepts_eta: -# extra_step_kwargs["eta"] = eta - -# # check if the scheduler accepts generator -# accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) -# if accepts_generator: -# extra_step_kwargs["generator"] = generator -# return extra_step_kwargs - -# @torch.no_grad() -# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - -# block_state = self.get_block_state(state) -# self.check_inputs(components, block_state) - -# block_state.num_channels_unet = components.unet.config.in_channels -# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False -# if block_state.disable_guidance: -# components.guider.disable() -# else: -# components.guider.enable() - -# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline -# block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) -# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - -# components.guider.set_input_fields( -# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), -# add_time_ids=("add_time_ids", "negative_add_time_ids"), -# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), -# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), -# ) - -# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: -# for i, t in enumerate(block_state.timesteps): -# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) -# guider_data = components.guider.prepare_inputs(block_state) - -# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - -# # Prepare for inpainting -# if block_state.num_channels_unet == 9: -# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - -# for batch in guider_data: -# components.guider.prepare_models(components.unet) - -# # Prepare additional conditionings -# batch.added_cond_kwargs = { -# "text_embeds": batch.pooled_prompt_embeds, -# "time_ids": batch.add_time_ids, -# } -# if batch.ip_adapter_embeds is not None: -# batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - -# # Predict the noise residual -# batch.noise_pred = components.unet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=batch.prompt_embeds, -# timestep_cond=block_state.timestep_cond, -# cross_attention_kwargs=block_state.cross_attention_kwargs, -# added_cond_kwargs=batch.added_cond_kwargs, -# return_dict=False, -# )[0] -# components.guider.cleanup_models(components.unet) - -# # Perform guidance -# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) - -# # Perform scheduler step using the predicted output -# block_state.latents_dtype = block_state.latents.dtype -# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - -# if block_state.latents.dtype != block_state.latents_dtype: -# if torch.backends.mps.is_available(): -# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 -# block_state.latents = block_state.latents.to(block_state.latents_dtype) - -# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: -# block_state.init_latents_proper = block_state.image_latents -# if i < len(block_state.timesteps) - 1: -# block_state.noise_timestep = block_state.timesteps[i + 1] -# block_state.init_latents_proper = components.scheduler.add_noise( -# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) -# ) - -# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - -# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): -# progress_bar.update() - -# self.add_block_state(state, block_state) - -# return components, state - - class StableDiffusionXLControlNetInputStep(PipelineBlock): model_name = "stable-diffusion-xl" @@ -2593,355 +2331,6 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt return components, state -from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular_denoise_loop import StableDiffusionXLControlNetDenoiseStep -# class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): - -# model_name = "stable-diffusion-xl" - -# @property -# def expected_components(self) -> List[ComponentSpec]: -# return [ -# ComponentSpec( -# "guider", -# ClassifierFreeGuidance, -# config=FrozenDict({"guidance_scale": 7.5}), -# default_creation_method="from_config"), -# ComponentSpec("scheduler", EulerDiscreteScheduler), -# ComponentSpec("unet", UNet2DConditionModel), -# ComponentSpec("controlnet", ControlNetModel), -# ] - -# @property -# def description(self) -> str: -# return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - -# @property -# def inputs(self) -> List[Tuple[str, Any]]: -# return [ -# InputParam("num_images_per_prompt", default=1), -# InputParam("cross_attention_kwargs"), -# InputParam("generator"), -# InputParam("eta", default=0.0), -# InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) -# ] - -# @property -# def intermediates_inputs(self) -> List[str]: -# return [ -# InputParam( -# "controlnet_cond", -# required=True, -# type_hint=torch.Tensor, -# description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "control_guidance_start", -# required=True, -# type_hint=float, -# description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "control_guidance_end", -# required=True, -# type_hint=float, -# description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "conditioning_scale", -# type_hint=float, -# description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "guess_mode", -# required=True, -# type_hint=bool, -# description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "controlnet_keep", -# required=True, -# type_hint=List[float], -# description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "latents", -# required=True, -# type_hint=torch.Tensor, -# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." -# ), -# InputParam( -# "batch_size", -# required=True, -# type_hint=int, -# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." -# ), -# InputParam( -# "timesteps", -# required=True, -# type_hint=torch.Tensor, -# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "add_time_ids", -# required=True, -# type_hint=torch.Tensor, -# description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." -# ), -# InputParam( -# "negative_add_time_ids", -# type_hint=Optional[torch.Tensor], -# description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." -# ), -# InputParam( -# "pooled_prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_pooled_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "timestep_cond", -# type_hint=Optional[torch.Tensor], -# description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" -# ), -# InputParam( -# "mask", -# type_hint=Optional[torch.Tensor], -# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "masked_image_latents", -# type_hint=Optional[torch.Tensor], -# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "noise", -# type_hint=Optional[torch.Tensor], -# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." -# ), -# InputParam( -# "image_latents", -# type_hint=Optional[torch.Tensor], -# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "crops_coords", -# type_hint=Optional[Tuple[int]], -# description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." -# ), -# InputParam( -# "ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "negative_ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "num_inference_steps", -# required=True, -# type_hint=int, -# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") -# ] - -# @property -# def intermediates_outputs(self) -> List[OutputParam]: -# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - -# @staticmethod -# def check_inputs(components, block_state): - -# num_channels_unet = components.unet.config.in_channels -# if num_channels_unet == 9: -# # default case for runwayml/stable-diffusion-inpainting -# if block_state.mask is None or block_state.masked_image_latents is None: -# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") -# num_channels_latents = block_state.latents.shape[1] -# num_channels_mask = block_state.mask.shape[1] -# num_channels_masked_image = block_state.masked_image_latents.shape[1] -# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: -# raise ValueError( -# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" -# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" -# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" -# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" -# " `components.unet` or your `mask_image` or `image` input." -# ) -# @staticmethod -# def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - -# accepted_kwargs = set(inspect.signature(func).parameters.keys()) -# extra_kwargs = {} -# for key, value in kwargs.items(): -# if key in accepted_kwargs and key not in exclude_kwargs: -# extra_kwargs[key] = value - -# return extra_kwargs - - -# @torch.no_grad() -# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - -# block_state = self.get_block_state(state) -# self.check_inputs(components, block_state) -# block_state.device = components._execution_device -# print(f" block_state: {block_state}") - -# controlnet = unwrap_module(components.controlnet) - -# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline -# block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) -# block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) - -# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - -# # (1) setup guider -# # disable for LCMs -# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False -# if block_state.disable_guidance: -# components.guider.disable() -# else: -# components.guider.enable() -# components.guider.set_input_fields( -# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), -# add_time_ids=("add_time_ids", "negative_add_time_ids"), -# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), -# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), -# ) - -# # (5) Denoise loop -# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: -# for i, t in enumerate(block_state.timesteps): - -# # prepare latent input for unet -# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) -# # adjust latent input for inpainting -# block_state.num_channels_unet = components.unet.config.in_channels -# if block_state.num_channels_unet == 9: -# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - - -# # cond_scale (controlnet input) -# if isinstance(block_state.controlnet_keep[i], list): -# block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] -# else: -# block_state.controlnet_cond_scale = block_state.conditioning_scale -# if isinstance(block_state.controlnet_cond_scale, list): -# block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] -# block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] - -# # default controlnet output/unet input for guess mode + conditional path -# block_state.down_block_res_samples_zeros = None -# block_state.mid_block_res_sample_zeros = None - -# # guided denoiser step -# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) -# guider_state = components.guider.prepare_inputs(block_state) - -# for guider_state_batch in guider_state: -# components.guider.prepare_models(components.unet) - -# # Prepare additional conditionings -# guider_state_batch.added_cond_kwargs = { -# "text_embeds": guider_state_batch.pooled_prompt_embeds, -# "time_ids": guider_state_batch.add_time_ids, -# } -# if guider_state_batch.ip_adapter_embeds is not None: -# guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds - -# # Prepare controlnet additional conditionings -# guider_state_batch.controlnet_added_cond_kwargs = { -# "text_embeds": guider_state_batch.pooled_prompt_embeds, -# "time_ids": guider_state_batch.add_time_ids, -# } - -# if block_state.guess_mode and not components.guider.is_conditional: -# # guider always run uncond batch first, so these tensors should be set already -# guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros -# guider_state_batch.mid_block_res_sample = block_state.mid_block_res_sample_zeros -# else: -# guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=guider_state_batch.prompt_embeds, -# controlnet_cond=block_state.controlnet_cond, -# conditioning_scale=block_state.conditioning_scale, -# guess_mode=block_state.guess_mode, -# added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, -# return_dict=False, -# **block_state.extra_controlnet_kwargs, -# ) - -# if block_state.down_block_res_samples_zeros is None: -# block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] -# if block_state.mid_block_res_sample_zeros is None: -# block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) - - - -# guider_state_batch.noise_pred = components.unet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=guider_state_batch.prompt_embeds, -# timestep_cond=block_state.timestep_cond, -# cross_attention_kwargs=block_state.cross_attention_kwargs, -# added_cond_kwargs=guider_state_batch.added_cond_kwargs, -# down_block_additional_residuals=guider_state_batch.down_block_res_samples, -# mid_block_additional_residual=guider_state_batch.mid_block_res_sample, -# return_dict=False, -# )[0] -# components.guider.cleanup_models(components.unet) - -# # Perform guidance -# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) - -# # Perform scheduler step using the predicted output -# block_state.latents_dtype = block_state.latents.dtype -# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - -# if block_state.latents.dtype != block_state.latents_dtype: -# if torch.backends.mps.is_available(): -# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 -# block_state.latents = block_state.latents.to(block_state.latents_dtype) - -# # adjust latent for inpainting -# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: -# block_state.init_latents_proper = block_state.image_latents -# if i < len(block_state.timesteps) - 1: -# block_state.noise_timestep = block_state.timesteps[i + 1] -# block_state.init_latents_proper = components.scheduler.add_noise( -# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) -# ) - -# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - -# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): -# progress_bar.update() - -# self.add_block_state(state, block_state) - -# return components, state - - class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): model_name = "stable-diffusion-xl" @@ -3123,6 +2512,13 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt return components, state +class StableDiffusionXLControlNetAutoInput(AutoPipelineBlocks): + + block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] + block_names = ["controlnet_union", "controlnet"] + block_trigger_inputs = ["control_mode", "control_image"] + + class StableDiffusionXLDecodeLatentsStep(PipelineBlock): model_name = "stable-diffusion-xl" @@ -3316,8 +2712,8 @@ def description(self): # Before denoise class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] @property def description(self): @@ -3326,12 +2722,13 @@ def description(self): " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + \ " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" + " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] @property def description(self): @@ -3340,12 +2737,13 @@ def description(self): " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] @property def description(self): @@ -3354,29 +2752,8 @@ def description(self): " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" - -class StableDiffusionXLControlNetStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLControlNetInputStep, StableDiffusionXLControlNetDenoiseStep] - block_names = ["prepare_input", "denoise"] - - @property - def description(self): - return "Controlnet step that denoise the latents.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLControlNetInputStep` is used to prepare the inputs for the denoise step.\n" + \ - " - `StableDiffusionXLControlNetDenoiseStep` is used to denoise the latents." - -class StableDiffusionXLControlNetUnionStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetDenoiseStep] - block_names = ["prepare_input", "denoise"] - - @property - def description(self): - return "ControlNetUnion step that denoises the latents.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLControlNetUnionInputStep` is used to prepare the inputs for the denoise step.\n" + \ - " - `StableDiffusionXLControlNetDenoiseStep` is used to denoise the latents using the ControlNetUnion model." + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): @@ -3387,24 +2764,27 @@ class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): @property def description(self): return "Before denoise step that prepare the inputs for the denoise step.\n" + \ - "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n" + \ + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n" + \ " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + \ " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \ - " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided." + " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided.\n" + \ + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + \ + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." -# Denoise -class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLControlNetUnionStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep] - block_names = ["controlnet_union", "controlnet", "unet"] - block_trigger_inputs = ["control_mode", "control_image", None] +# # Denoise +from .pipeline_stable_diffusion_xl_modular_denoise_loop import StableDiffusionXLDenoiseStep, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLAutoDenoiseStep +# class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): +# block_classes = [StableDiffusionXLControlNetUnionStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep] +# block_names = ["controlnet_union", "controlnet", "unet"] +# block_trigger_inputs = ["control_mode", "control_image", None] - @property - def description(self): - return "Denoise step that denoise the latents.\n" + \ - "This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \ - " - `StableDiffusionXLControlNetUnionStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \ - " - `StableDiffusionXLControlNetStep` (controlnet) is used when `control_image` is provided.\n" + \ - " - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided." +# @property +# def description(self): +# return "Denoise step that denoise the latents.\n" + \ +# "This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \ +# " - `StableDiffusionXLControlNetUnionStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \ +# " - `StableDiffusionXLControlNetStep` (controlnet) is used when `control_image` is provided.\n" + \ +# " - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided." # After denoise class StableDiffusionXLDecodeStep(SequentialPipelineBlocks): @@ -3474,6 +2854,7 @@ def description(self): # always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the # configuration of guider is. + # block mapping TEXT2IMAGE_BLOCKS = OrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), @@ -3511,11 +2892,13 @@ def description(self): ]) CONTROLNET_BLOCKS = OrderedDict([ - ("denoise", StableDiffusionXLControlNetStep), + ("controlnet_input", StableDiffusionXLControlNetInputStep), + ("denoise", StableDiffusionXLControlNetDenoiseStep), ]) CONTROLNET_UNION_BLOCKS = OrderedDict([ - ("denoise", StableDiffusionXLControlNetUnionStep), + ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), + ("denoise", StableDiffusionXLControlNetDenoiseStep), ]) IP_ADAPTER_BLOCKS = OrderedDict([ diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py index 92c07854fc74..63d0784a5762 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py @@ -22,10 +22,11 @@ from ...models import ControlNetModel, UNet2DConditionModel from ...schedulers import EulerDiscreteScheduler from ...utils import logging -from ...utils.torch_utils import unwrap_module +from ...utils.torch_utils import unwrap_module from ..modular_pipeline import ( PipelineBlock, PipelineState, + AutoPipelineBlocks, LoopSequentialPipelineBlocks, InputParam, OutputParam, @@ -42,7 +43,7 @@ # YiYi experimenting composible denoise loop # loop step (1): prepare latent input for denoiser -class StableDiffusionXLDenoiseLoopLatentsStep(PipelineBlock): +class StableDiffusionXLDenoiseLoopBeforeDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -83,7 +84,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return components, block_state # loop step (1): prepare latent input for denoiser (with inpainting) -class StableDiffusionXLDenoiseLoopInpaintLatentsStep(PipelineBlock): +class StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -145,7 +146,7 @@ def check_inputs(components, block_state): ) @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, loop_idx: int, t: int): + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): self.check_inputs(components, block_state) @@ -157,7 +158,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return components, block_state # loop step (2): denoise the latents with guidance -class StableDiffusionXLDenoiseLoopDenoiserStep(PipelineBlock): +class StableDiffusionXLDenoiseLoopDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -267,7 +268,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return components, block_state # loop step (2): denoise the latents with guidance (with controlnet) -class StableDiffusionXLDenoiseLoopControlNetDenoiserStep(PipelineBlock): +class StableDiffusionXLControlNetDenoiseLoopDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -468,7 +469,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return components, block_state # loop step (3): scheduler step to update latents -class StableDiffusionXLDenoiseLoopUpdateLatentsStep(PipelineBlock): +class StableDiffusionXLDenoiseLoopAfterDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -535,8 +536,8 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return components, block_state - -class StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep(PipelineBlock): +# loop step (3): scheduler step to update latents (with inpainting) +class StableDiffusionXLInpaintDenoiseLoopAfterDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -643,7 +644,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc # the loop wrapper that iterates over the timesteps -class StableDiffusionXLDenoiseLoop(LoopSequentialPipelineBlocks): +class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): model_name = "stable-diffusion-xl" @@ -706,24 +707,657 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt return components, state +# composing the denoising loops +class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + +# control_cond +class StableDiffusionXLControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + +# mask +class StableDiffusionXLInpaintDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + +# control_cond + mask +class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + + +# all task without controlnet +class StableDiffusionXLDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintDenoiseLoop, StableDiffusionXLDenoiseLoop] + block_names = ["inpaint_denoise", "denoise"] + block_trigger_inputs = ["mask", None] + +# all task with controlnet +class StableDiffusionXLControlNetDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintControlNetDenoiseLoop, StableDiffusionXLControlNetDenoiseLoop] + block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"] + block_trigger_inputs = ["mask", None] + +# all task with or without controlnet +class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] + block_names = ["controlnet_denoise", "denoise"] + block_trigger_inputs = ["controlnet_cond", None] + + + + + + + +# YiYi Notes: alternatively, this is you can just write the denoise loop using a pipeline block, easier but not composible +# class StableDiffusionXLDenoiseStep(PipelineBlock): + +# model_name = "stable-diffusion-xl" + +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ] + +# @property +# def description(self) -> str: +# return ( +# "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" +# ) + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("num_images_per_prompt", default=1), +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." +# ) + +# # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components +# @staticmethod +# def prepare_extra_step_kwargs(components, generator, eta): +# # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature +# # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. +# # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 +# # and should be between [0, 1] + +# accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# extra_step_kwargs = {} +# if accepts_eta: +# extra_step_kwargs["eta"] = eta + +# # check if the scheduler accepts generator +# accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# if accepts_generator: +# extra_step_kwargs["generator"] = generator +# return extra_step_kwargs + +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) + +# block_state.num_channels_unet = components.unet.config.in_channels +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() + +# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_data = components.guider.prepare_inputs(block_state) + +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + +# # Prepare for inpainting +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + +# for batch in guider_data: +# components.guider.prepare_models(components.unet) + +# # Prepare additional conditionings +# batch.added_cond_kwargs = { +# "text_embeds": batch.pooled_prompt_embeds, +# "time_ids": batch.add_time_ids, +# } +# if batch.ip_adapter_embeds is not None: +# batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds + +# # Predict the noise residual +# batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=batch.added_cond_kwargs, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) + +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) + +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + +# if block_state.latents.dtype != block_state.latents_dtype: +# if torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) + +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) + +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() + +# self.add_block_state(state, block_state) + +# return components, state + + + +# class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): + +# model_name = "stable-diffusion-xl" + +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ComponentSpec("controlnet", ControlNetModel), +# ] + +# @property +# def description(self) -> str: +# return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("num_images_per_prompt", default=1), +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "controlnet_cond", +# required=True, +# type_hint=torch.Tensor, +# description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_start", +# required=True, +# type_hint=float, +# description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_end", +# required=True, +# type_hint=float, +# description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "conditioning_scale", +# type_hint=float, +# description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "guess_mode", +# required=True, +# type_hint=bool, +# description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "controlnet_keep", +# required=True, +# type_hint=List[float], +# description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "crops_coords", +# type_hint=Optional[Tuple[int]], +# description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." +# ) +# @staticmethod +# def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + +# accepted_kwargs = set(inspect.signature(func).parameters.keys()) +# extra_kwargs = {} +# for key, value in kwargs.items(): +# if key in accepted_kwargs and key not in exclude_kwargs: +# extra_kwargs[key] = value + +# return extra_kwargs + + +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) +# block_state.device = components._execution_device +# print(f" block_state: {block_state}") -# StableDiffusionXLControlNetDenoiseStep - -class StableDiffusionXLDenoiseStep(StableDiffusionXLDenoiseLoop): - block_classes = [StableDiffusionXLDenoiseLoopLatentsStep, StableDiffusionXLDenoiseLoopDenoiserStep, StableDiffusionXLDenoiseLoopUpdateLatentsStep] - block_names = ["prepare_latents", "denoiser", "update_latents"] - -class StableDiffusionXLControlNetDenoiseStep(StableDiffusionXLDenoiseLoop): - block_classes = [StableDiffusionXLDenoiseLoopLatentsStep, StableDiffusionXLDenoiseLoopControlNetDenoiserStep, StableDiffusionXLDenoiseLoopUpdateLatentsStep] - block_names = ["prepare_latents", "denoiser", "update_latents"] - -class StableDiffusionXLInpaintDenoiseStep(StableDiffusionXLDenoiseLoop): - block_classes = [StableDiffusionXLDenoiseLoopInpaintLatentsStep, StableDiffusionXLDenoiseLoopDenoiserStep, StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep] - block_names = ["prepare_latents", "denoiser", "update_latents"] - -class StableDiffusionXLInpaintControlNetDenoiseStep(StableDiffusionXLDenoiseLoop): - block_classes = [StableDiffusionXLDenoiseLoopInpaintLatentsStep, StableDiffusionXLDenoiseLoopControlNetDenoiserStep, StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep] - block_names = ["prepare_latents", "denoiser", "update_latents"] +# controlnet = unwrap_module(components.controlnet) +# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) +# block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + +# # (1) setup guider +# # disable for LCMs +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# # (5) Denoise loop +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): + +# # prepare latent input for unet +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) +# # adjust latent input for inpainting +# block_state.num_channels_unet = components.unet.config.in_channels +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + +# # cond_scale (controlnet input) +# if isinstance(block_state.controlnet_keep[i], list): +# block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] +# else: +# block_state.controlnet_cond_scale = block_state.conditioning_scale +# if isinstance(block_state.controlnet_cond_scale, list): +# block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] +# block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] + +# # default controlnet output/unet input for guess mode + conditional path +# block_state.down_block_res_samples_zeros = None +# block_state.mid_block_res_sample_zeros = None + +# # guided denoiser step +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_state = components.guider.prepare_inputs(block_state) + +# for guider_state_batch in guider_state: +# components.guider.prepare_models(components.unet) + +# # Prepare additional conditionings +# guider_state_batch.added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } +# if guider_state_batch.ip_adapter_embeds is not None: +# guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds + +# # Prepare controlnet additional conditionings +# guider_state_batch.controlnet_added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } + +# if block_state.guess_mode and not components.guider.is_conditional: +# # guider always run uncond batch first, so these tensors should be set already +# guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros +# guider_state_batch.mid_block_res_sample = block_state.mid_block_res_sample_zeros +# else: +# guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# controlnet_cond=block_state.controlnet_cond, +# conditioning_scale=block_state.conditioning_scale, +# guess_mode=block_state.guess_mode, +# added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, +# return_dict=False, +# **block_state.extra_controlnet_kwargs, +# ) + +# if block_state.down_block_res_samples_zeros is None: +# block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] +# if block_state.mid_block_res_sample_zeros is None: +# block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) + + + +# guider_state_batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=guider_state_batch.added_cond_kwargs, +# down_block_additional_residuals=guider_state_batch.down_block_res_samples, +# mid_block_additional_residual=guider_state_batch.mid_block_res_sample, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) + +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) + +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + +# if block_state.latents.dtype != block_state.latents_dtype: +# if torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) + +# # adjust latent for inpainting +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) + +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() + +# self.add_block_state(state, block_state) +# return components, state \ No newline at end of file From cf01aaeb49a2632458113f4572dd3929426bd009 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 10 May 2025 03:49:30 +0200 Subject: [PATCH 20/54] update imports on guiders --- src/diffusers/guiders/adaptive_projected_guidance.py | 2 +- src/diffusers/guiders/auto_guidance.py | 2 +- src/diffusers/guiders/classifier_free_guidance.py | 2 +- src/diffusers/guiders/classifier_free_zero_star_guidance.py | 2 +- src/diffusers/guiders/guider_utils.py | 4 ++-- src/diffusers/guiders/skip_layer_guidance.py | 2 +- src/diffusers/guiders/smoothed_energy_guidance.py | 2 +- src/diffusers/guiders/tangential_classifier_free_guidance.py | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index 83e93c15ff1d..ef2f3f2c8420 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -20,7 +20,7 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class AdaptiveProjectedGuidance(BaseGuidance): diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py index 8bb6083781c2..791cc582add2 100644 --- a/src/diffusers/guiders/auto_guidance.py +++ b/src/diffusers/guiders/auto_guidance.py @@ -22,7 +22,7 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class AutoGuidance(BaseGuidance): diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 429392e3f9c6..a459e51cd083 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -20,7 +20,7 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class ClassifierFreeGuidance(BaseGuidance): diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py index 220a95e54a8d..a722f2605036 100644 --- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -20,7 +20,7 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class ClassifierFreeZeroStarGuidance(BaseGuidance): diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index df544c955f33..e8e873f5c88f 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState logger = get_logger(__name__) # pylint: disable=invalid-name @@ -171,7 +171,7 @@ def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], da Returns: `BlockState`: The prepared batch of data. """ - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState if input_fields is None: raise ValueError("Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs.") diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 56dae1903606..7c19f6391f41 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -22,7 +22,7 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class SkipLayerGuidance(BaseGuidance): diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index c215cb0afdc9..3986da913f82 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -22,7 +22,7 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class SmoothedEnergyGuidance(BaseGuidance): diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py index 9fa8f9454134..017693fd9f07 100644 --- a/src/diffusers/guiders/tangential_classifier_free_guidance.py +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -20,7 +20,7 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class TangentialClassifierFreeGuidance(BaseGuidance): From 462429b68747dc1c0a313bb1a8f913de207dde6d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 10 May 2025 03:50:10 +0200 Subject: [PATCH 21/54] remove modular reelated change from pipelines folder --- src/diffusers/pipelines/modular_pipeline.py | 1916 ----------- .../pipelines/modular_pipeline_utils.py | 598 ---- .../pipeline_stable_diffusion_xl_modular.py | 3032 ----------------- ...table_diffusion_xl_modular_denoise_loop.py | 1363 -------- 4 files changed, 6909 deletions(-) delete mode 100644 src/diffusers/pipelines/modular_pipeline.py delete mode 100644 src/diffusers/pipelines/modular_pipeline_utils.py delete mode 100644 src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py delete mode 100644 src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py deleted file mode 100644 index 97a8677bda63..000000000000 --- a/src/diffusers/pipelines/modular_pipeline.py +++ /dev/null @@ -1,1916 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import traceback -import warnings -from collections import OrderedDict -from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple, Union, Optional, Type - - -import torch -from tqdm.auto import tqdm -import re -import os -import importlib - -from huggingface_hub.utils import validate_hf_hub_args - -from ..configuration_utils import ConfigMixin, FrozenDict -from ..utils import ( - is_accelerate_available, - is_accelerate_version, - logging, - PushToHubMixin, -) -from .pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj,_fetch_class_library_tuple -from .modular_pipeline_utils import ( - ComponentSpec, - ConfigSpec, - InputParam, - OutputParam, - format_components, - format_configs, - format_input_params, - format_inputs_short, - format_intermediates_short, - format_output_params, - format_params, - make_doc_string, -) -from .components_manager import ComponentsManager - -from copy import deepcopy -if is_accelerate_available(): - import accelerate - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -MODULAR_LOADER_MAPPING = OrderedDict( - [ - ("stable-diffusion-xl", "StableDiffusionXLModularLoader"), - ] -) - - -@dataclass -class PipelineState: - """ - [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks. - """ - - inputs: Dict[str, Any] = field(default_factory=dict) - intermediates: Dict[str, Any] = field(default_factory=dict) - input_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) - intermediate_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) - - def add_input(self, key: str, value: Any, kwargs_type: str = None): - """ - Add an input to the pipeline state with optional metadata. - - Args: - key (str): The key for the input - value (Any): The input value - kwargs_type (str): The kwargs_type to store with the input - """ - self.inputs[key] = value - if kwargs_type is not None: - if kwargs_type not in self.input_kwargs: - self.input_kwargs[kwargs_type] = [key] - else: - self.input_kwargs[kwargs_type].append(key) - - def add_intermediate(self, key: str, value: Any, kwargs_type: str = None): - """ - Add an intermediate value to the pipeline state with optional metadata. - - Args: - key (str): The key for the intermediate value - value (Any): The intermediate value - kwargs_type (str): The kwargs_type to store with the intermediate value - """ - self.intermediates[key] = value - if kwargs_type is not None: - if kwargs_type not in self.intermediate_kwargs: - self.intermediate_kwargs[kwargs_type] = [key] - else: - self.intermediate_kwargs[kwargs_type].append(key) - - def get_input(self, key: str, default: Any = None) -> Any: - return self.inputs.get(key, default) - - def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: - return {key: self.inputs.get(key, default) for key in keys} - - def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]: - """ - Get all inputs with matching kwargs_type. - - Args: - kwargs_type (str): The kwargs_type to filter by - - Returns: - Dict[str, Any]: Dictionary of inputs with matching kwargs_type - """ - input_names = self.input_kwargs.get(kwargs_type, []) - return self.get_inputs(input_names) - - def get_intermediates_kwargs(self, kwargs_type: str) -> Dict[str, Any]: - """ - Get all intermediates with matching kwargs_type. - - Args: - kwargs_type (str): The kwargs_type to filter by - - Returns: - Dict[str, Any]: Dictionary of intermediates with matching kwargs_type - """ - intermediate_names = self.intermediate_kwargs.get(kwargs_type, []) - return self.get_intermediates(intermediate_names) - - def get_intermediate(self, key: str, default: Any = None) -> Any: - return self.intermediates.get(key, default) - - def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]: - return {key: self.intermediates.get(key, default) for key in keys} - - def to_dict(self) -> Dict[str, Any]: - return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates} - - def __repr__(self): - def format_value(v): - if hasattr(v, "shape") and hasattr(v, "dtype"): - return f"Tensor(dtype={v.dtype}, shape={v.shape})" - elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - return f"[Tensor(dtype={v[0].dtype}, shape={v[0].shape}), ...]" - else: - return repr(v) - - inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) - intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) - - # Format input_kwargs and intermediate_kwargs - input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items()) - intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items()) - - return ( - f"PipelineState(\n" - f" inputs={{\n{inputs}\n }},\n" - f" intermediates={{\n{intermediates}\n }},\n" - f" input_kwargs={{\n{input_kwargs_str}\n }},\n" - f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n" - f")" - ) - - -@dataclass -class BlockState: - """ - Container for block state data with attribute access and formatted representation. - """ - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - def __getitem__(self, key: str): - # allows block_state["foo"] - return getattr(self, key, None) - - def __setitem__(self, key: str, value: Any): - # allows block_state["foo"] = "bar" - setattr(self, key, value) - - def as_dict(self): - """ - Convert BlockState to a dictionary. - - Returns: - Dict[str, Any]: Dictionary containing all attributes of the BlockState - """ - return {key: value for key, value in self.__dict__.items()} - - def __repr__(self): - def format_value(v): - # Handle tensors directly - if hasattr(v, "shape") and hasattr(v, "dtype"): - return f"Tensor(dtype={v.dtype}, shape={v.shape})" - - # Handle lists of tensors - elif isinstance(v, list): - if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - shapes = [t.shape for t in v] - return f"List[{len(v)}] of Tensors with shapes {shapes}" - return repr(v) - - # Handle tuples of tensors - elif isinstance(v, tuple): - if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - shapes = [t.shape for t in v] - return f"Tuple[{len(v)}] of Tensors with shapes {shapes}" - return repr(v) - - # Handle dicts with tensor values - elif isinstance(v, dict): - formatted_dict = {} - for k, val in v.items(): - if hasattr(val, "shape") and hasattr(val, "dtype"): - formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})" - elif isinstance(val, list) and len(val) > 0 and hasattr(val[0], "shape") and hasattr(val[0], "dtype"): - shapes = [t.shape for t in val] - formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}" - else: - formatted_dict[k] = repr(val) - return formatted_dict - - # Default case - return repr(v) - - attributes = "\n".join(f" {k}: {format_value(v)}" for k, v in self.__dict__.items()) - return f"BlockState(\n{attributes}\n)" - - - -class ModularPipelineMixin: - """ - Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks - """ - - - def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): - """ - create a mouldar loader, optionally accept modular_repo to load from hub. - """ - - # Import components loader (it is model-specific class) - loader_class_name = MODULAR_LOADER_MAPPING[self.model_name] - diffusers_module = importlib.import_module("diffusers") - loader_class = getattr(diffusers_module, loader_class_name) - - # Create deep copies to avoid modifying the original specs - component_specs = deepcopy(self.expected_components) - config_specs = deepcopy(self.expected_configs) - # Create the loader with the updated specs - specs = component_specs + config_specs - - self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection) - - - @property - def default_call_parameters(self) -> Dict[str, Any]: - params = {} - for input_param in self.inputs: - params[input_param.name] = input_param.default - return params - - def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): - """ - Run one or more blocks in sequence, optionally you can pass a previous pipeline state. - """ - if state is None: - state = PipelineState() - - if not hasattr(self, "loader"): - logger.warning("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.") - self.loader = None - - # Make a copy of the input kwargs - passed_kwargs = kwargs.copy() - - - # Add inputs to state, using defaults if not provided in the kwargs or the state - # if same input already in the state, will override it if provided in the kwargs - - intermediates_inputs = [inp.name for inp in self.intermediates_inputs] - for expected_input_param in self.inputs: - name = expected_input_param.name - default = expected_input_param.default - kwargs_type = expected_input_param.kwargs_type - if name in passed_kwargs: - if name not in intermediates_inputs: - state.add_input(name, passed_kwargs.pop(name), kwargs_type) - else: - state.add_input(name, passed_kwargs[name], kwargs_type) - elif name not in state.inputs: - state.add_input(name, default, kwargs_type) - - for expected_intermediate_param in self.intermediates_inputs: - name = expected_intermediate_param.name - kwargs_type = expected_intermediate_param.kwargs_type - if name in passed_kwargs: - state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type) - - # Warn about unexpected inputs - if len(passed_kwargs) > 0: - logger.warning(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") - # Run the pipeline - with torch.no_grad(): - try: - pipeline, state = self(self.loader, state) - except Exception: - error_msg = f"Error in block: ({self.__class__.__name__}):\n" - logger.error(error_msg) - raise - - if output is None: - return state - - - elif isinstance(output, str): - return state.get_intermediate(output) - - elif isinstance(output, (list, tuple)): - return state.get_intermediates(output) - else: - raise ValueError(f"Output '{output}' is not a valid output type") - - @torch.compiler.disable - def progress_bar(self, iterable=None, total=None): - if not hasattr(self, "_progress_bar_config"): - self._progress_bar_config = {} - elif not isinstance(self._progress_bar_config, dict): - raise ValueError( - f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." - ) - - if iterable is not None: - return tqdm(iterable, **self._progress_bar_config) - elif total is not None: - return tqdm(total=total, **self._progress_bar_config) - else: - raise ValueError("Either `total` or `iterable` has to be defined.") - - def set_progress_bar_config(self, **kwargs): - self._progress_bar_config = kwargs - - -class PipelineBlock(ModularPipelineMixin): - - model_name = None - - @property - def description(self) -> str: - """Description of the block. Must be implemented by subclasses.""" - raise NotImplementedError("description method must be implemented in subclasses") - - @property - def expected_components(self) -> List[ComponentSpec]: - return [] - - @property - def expected_configs(self) -> List[ConfigSpec]: - return [] - - - # YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable - @property - def inputs(self) -> List[InputParam]: - """List of input parameters. Must be implemented by subclasses.""" - return [] - - @property - def intermediates_inputs(self) -> List[InputParam]: - """List of intermediate input parameters. Must be implemented by subclasses.""" - return [] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - """List of intermediate output parameters. Must be implemented by subclasses.""" - return [] - - # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks - @property - def outputs(self) -> List[OutputParam]: - return self.intermediates_outputs - - @property - def required_inputs(self) -> List[str]: - input_names = [] - for input_param in self.inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - - @property - def required_intermediates_inputs(self) -> List[str]: - input_names = [] - for input_param in self.intermediates_inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - - - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - raise NotImplementedError("__call__ method must be implemented in subclasses") - - def __repr__(self): - class_name = self.__class__.__name__ - base_class = self.__class__.__bases__[0].__name__ - - # Format description with proper indentation - desc_lines = self.description.split('\n') - desc = [] - # First line with "Description:" label - desc.append(f" Description: {desc_lines[0]}") - # Subsequent lines with proper indentation - if len(desc_lines) > 1: - desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' - - # Components section - use format_components with add_empty_lines=False - expected_components = getattr(self, "expected_components", []) - components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - components = " " + components_str.replace("\n", "\n ") - - # Configs section - use format_configs with add_empty_lines=False - expected_configs = getattr(self, "expected_configs", []) - configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - configs = " " + configs_str.replace("\n", "\n ") - - # Inputs section - inputs_str = format_inputs_short(self.inputs) - inputs = "Inputs:\n " + inputs_str - - # Intermediates section - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates = f"Intermediates:\n{intermediates_str}" - - return ( - f"{class_name}(\n" - f" Class: {base_class}\n" - f"{desc}" - f"{components}\n" - f"{configs}\n" - f" {inputs}\n" - f" {intermediates}\n" - f")" - ) - - - @property - def doc(self): - return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, - self.description, - class_name=self.__class__.__name__, - expected_components=self.expected_components, - expected_configs=self.expected_configs - ) - - - def get_block_state(self, state: PipelineState) -> dict: - """Get all inputs and intermediates in one dictionary""" - data = {} - - # Check inputs - for input_param in self.inputs: - if input_param.name: - value = state.get_input(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required input '{input_param.name}' is missing") - elif value is not None or (value is None and input_param.name not in data): - data[input_param.name] = value - elif input_param.kwargs_type: - # if kwargs_type is provided, get all inputs with matching kwargs_type - if input_param.kwargs_type not in data: - data[input_param.kwargs_type] = {} - inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) - if inputs_kwargs: - for k, v in inputs_kwargs.items(): - if v is not None: - data[k] = v - data[input_param.kwargs_type][k] = v - - # Check intermediates - for input_param in self.intermediates_inputs: - if input_param.name: - value = state.get_intermediate(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required intermediate input '{input_param.name}' is missing") - elif value is not None or (value is None and input_param.name not in data): - data[input_param.name] = value - elif input_param.kwargs_type: - # if kwargs_type is provided, get all intermediates with matching kwargs_type - if input_param.kwargs_type not in data: - data[input_param.kwargs_type] = {} - intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) - if intermediates_kwargs: - for k, v in intermediates_kwargs.items(): - if v is not None: - if k not in data: - data[k] = v - data[input_param.kwargs_type][k] = v - return BlockState(**data) - - def add_block_state(self, state: PipelineState, block_state: BlockState): - for output_param in self.intermediates_outputs: - if not hasattr(block_state, output_param.name): - raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") - param = getattr(block_state, output_param.name) - state.add_intermediate(output_param.name, param, output_param.kwargs_type) - - -def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: - """ - Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if - current default value is None and new default value is not None. Warns if multiple non-None default values - exist for the same input. - - Args: - named_input_lists: List of tuples containing (block_name, input_param_list) pairs - - Returns: - List[InputParam]: Combined list of unique InputParam objects - """ - combined_dict = {} # name -> InputParam - value_sources = {} # name -> block_name - - for block_name, inputs in named_input_lists: - for input_param in inputs: - if input_param.name is None and input_param.kwargs_type is not None: - input_name = "*_" + input_param.kwargs_type - else: - input_name = input_param.name - if input_name in combined_dict: - current_param = combined_dict[input_name] - if (current_param.default is not None and - input_param.default is not None and - current_param.default != input_param.default): - warnings.warn( - f"Multiple different default values found for input '{input_param.name}': " - f"{current_param.default} (from block '{value_sources[input_param.name]}') and " - f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." - ) - if current_param.default is None and input_param.default is not None: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name - else: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name - - return list(combined_dict.values()) - -def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: - """ - Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, - keeps the first occurrence of each output name. - - Args: - named_output_lists: List of tuples containing (block_name, output_param_list) pairs - - Returns: - List[OutputParam]: Combined list of unique OutputParam objects - """ - combined_dict = {} # name -> OutputParam - - for block_name, outputs in named_output_lists: - for output_param in outputs: - if (output_param.name not in combined_dict) or (combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None): - combined_dict[output_param.name] = output_param - - return list(combined_dict.values()) - - -class AutoPipelineBlocks(ModularPipelineMixin): - """ - A class that automatically selects a block to run based on the inputs. - - Attributes: - block_classes: List of block classes to be used - block_names: List of prefixes for each block - block_trigger_inputs: List of input names that trigger specific blocks, with None for default - """ - - block_classes = [] - block_names = [] - block_trigger_inputs = [] - - def __init__(self): - blocks = OrderedDict() - for block_name, block_cls in zip(self.block_names, self.block_classes): - blocks[block_name] = block_cls() - self.blocks = blocks - if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): - raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.") - default_blocks = [t for t in self.block_trigger_inputs if t is None] - # can only have 1 or 0 default block, and has to put in the last - # the order of blocksmatters here because the first block with matching trigger will be dispatched - # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"] - # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img - if len(default_blocks) > 1 or ( - len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None - ): - raise ValueError( - f"In {self.__class__.__name__}, exactly one None must be specified as the last element " - "in block_trigger_inputs." - ) - - # Map trigger inputs to block objects - self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values())) - self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.blocks.keys())) - self.block_to_trigger_map = dict(zip(self.blocks.keys(), self.block_trigger_inputs)) - - @property - def model_name(self): - return next(iter(self.blocks.values())).model_name - - @property - def description(self): - return "" - - @property - def expected_components(self): - expected_components = [] - for block in self.blocks.values(): - for component in block.expected_components: - if component not in expected_components: - expected_components.append(component) - return expected_components - - @property - def expected_configs(self): - expected_configs = [] - for block in self.blocks.values(): - for config in block.expected_configs: - if config not in expected_configs: - expected_configs.append(config) - return expected_configs - - - @property - def required_inputs(self) -> List[str]: - first_block = next(iter(self.blocks.values())) - required_by_all = set(getattr(first_block, "required_inputs", set())) - - # Intersect with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_inputs", set())) - required_by_all.intersection_update(block_required) - - return list(required_by_all) - - @property - def required_intermediates_inputs(self) -> List[str]: - first_block = next(iter(self.blocks.values())) - required_by_all = set(getattr(first_block, "required_intermediates_inputs", set())) - - # Intersect with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_intermediates_inputs", set())) - required_by_all.intersection_update(block_required) - - return list(required_by_all) - - - # YiYi TODO: add test for this - @property - def inputs(self) -> List[Tuple[str, Any]]: - named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] - combined_inputs = combine_inputs(*named_inputs) - # mark Required inputs only if that input is required by all the blocks - for input_param in combined_inputs: - if input_param.name in self.required_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - - @property - def intermediates_inputs(self) -> List[str]: - named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()] - combined_inputs = combine_inputs(*named_inputs) - # mark Required inputs only if that input is required by all the blocks - for input_param in combined_inputs: - if input_param.name in self.required_intermediates_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - @property - def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - return combined_outputs - - @property - def outputs(self) -> List[str]: - named_outputs = [(name, block.outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - return combined_outputs - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - # Find default block first (if any) - - block = self.trigger_to_block_map.get(None) - for input_name in self.block_trigger_inputs: - if input_name is not None and state.get_input(input_name) is not None: - block = self.trigger_to_block_map[input_name] - break - elif input_name is not None and state.get_intermediate(input_name) is not None: - block = self.trigger_to_block_map[input_name] - break - - if block is None: - logger.warning(f"skipping auto block: {self.__class__.__name__}") - return pipeline, state - - try: - logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}") - return block(pipeline, state) - except Exception as e: - error_msg = ( - f"\nError in block: {block.__class__.__name__}\n" - f"Error details: {str(e)}\n" - f"Traceback:\n{traceback.format_exc()}" - ) - logger.error(error_msg) - raise - - def _get_trigger_inputs(self): - """ - Returns a set of all unique trigger input values found in the blocks. - Returns: Set[str] containing all unique block_trigger_inputs values - """ - def fn_recursive_get_trigger(blocks): - trigger_values = set() - - if blocks is not None: - for name, block in blocks.items(): - # Check if current block has trigger inputs(i.e. auto block) - if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: - # Add all non-None values from the trigger inputs list - trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - - # If block has blocks, recursively check them - if hasattr(block, 'blocks'): - nested_triggers = fn_recursive_get_trigger(block.blocks) - trigger_values.update(nested_triggers) - - return trigger_values - - trigger_inputs = set(self.block_trigger_inputs) - trigger_inputs.update(fn_recursive_get_trigger(self.blocks)) - - return trigger_inputs - - @property - def trigger_inputs(self): - return self._get_trigger_inputs() - - def __repr__(self): - class_name = self.__class__.__name__ - base_class = self.__class__.__bases__[0].__name__ - header = ( - f"{class_name}(\n Class: {base_class}\n" - if base_class and base_class != "object" - else f"{class_name}(\n" - ) - - - if self.trigger_inputs: - header += "\n" - header += " " + "=" * 100 + "\n" - header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" - header += f" Trigger Inputs: {self.trigger_inputs}\n" - # Get first trigger input as example - example_input = next(t for t in self.trigger_inputs if t is not None) - header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" - header += " " + "=" * 100 + "\n\n" - - # Format description with proper indentation - desc_lines = self.description.split('\n') - desc = [] - # First line with "Description:" label - desc.append(f" Description: {desc_lines[0]}") - # Subsequent lines with proper indentation - if len(desc_lines) > 1: - desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' - - # Components section - focus only on expected components - expected_components = getattr(self, "expected_components", []) - components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - - # Configs section - use format_configs with add_empty_lines=False - expected_configs = getattr(self, "expected_configs", []) - configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - - # Blocks section - moved to the end with simplified format - blocks_str = " Blocks:\n" - for i, (name, block) in enumerate(self.blocks.items()): - # Get trigger input for this block - trigger = None - if hasattr(self, 'block_to_trigger_map'): - trigger = self.block_to_trigger_map.get(name) - # Format the trigger info - if trigger is None: - trigger_str = "[default]" - elif isinstance(trigger, (list, tuple)): - trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" - else: - trigger_str = f"[trigger: {trigger}]" - # For AutoPipelineBlocks, add bullet points - blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" - else: - # For SequentialPipelineBlocks, show execution order - blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - - # Add block description - desc_lines = block.description.split('\n') - indented_desc = desc_lines[0] - if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n\n" - - return ( - f"{header}\n" - f"{desc}\n\n" - f"{components_str}\n\n" - f"{configs_str}\n\n" - f"{blocks_str}" - f")" - ) - - - @property - def doc(self): - return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, - self.description, - class_name=self.__class__.__name__, - expected_components=self.expected_components, - expected_configs=self.expected_configs - ) - -class SequentialPipelineBlocks(ModularPipelineMixin): - """ - A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. - """ - block_classes = [] - block_names = [] - - @property - def model_name(self): - return next(iter(self.blocks.values())).model_name - - @property - def description(self): - return "" - - @property - def expected_components(self): - expected_components = [] - for block in self.blocks.values(): - for component in block.expected_components: - if component not in expected_components: - expected_components.append(component) - return expected_components - - @property - def expected_configs(self): - expected_configs = [] - for block in self.blocks.values(): - for config in block.expected_configs: - if config not in expected_configs: - expected_configs.append(config) - return expected_configs - - @classmethod - def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks": - """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. - - Args: - blocks_dict: Dictionary mapping block names to block instances - - Returns: - A new SequentialPipelineBlocks instance - """ - instance = cls() - instance.block_classes = [block.__class__ for block in blocks_dict.values()] - instance.block_names = list(blocks_dict.keys()) - instance.blocks = blocks_dict - return instance - - def __init__(self): - blocks = OrderedDict() - for block_name, block_cls in zip(self.block_names, self.block_classes): - blocks[block_name] = block_cls() - self.blocks = blocks - - - @property - def required_inputs(self) -> List[str]: - # Get the first block from the dictionary - first_block = next(iter(self.blocks.values())) - required_by_any = set(getattr(first_block, "required_inputs", set())) - - # Union with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_inputs", set())) - required_by_any.update(block_required) - - return list(required_by_any) - - @property - def required_intermediates_inputs(self) -> List[str]: - required_intermediates_inputs = [] - for input_param in self.intermediates_inputs: - if input_param.required: - required_intermediates_inputs.append(input_param.name) - return required_intermediates_inputs - - # YiYi TODO: add test for this - @property - def inputs(self) -> List[Tuple[str, Any]]: - return self.get_inputs() - - def get_inputs(self): - named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] - combined_inputs = combine_inputs(*named_inputs) - # mark Required inputs only if that input is required any of the blocks - for input_param in combined_inputs: - if input_param.name in self.required_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - @property - def intermediates_inputs(self) -> List[str]: - return self.get_intermediates_inputs() - - def get_intermediates_inputs(self): - inputs = [] - outputs = set() - - # Go through all blocks in order - for block in self.blocks.values(): - # Add inputs that aren't in outputs yet - inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) - - # Only add outputs if the block cannot be skipped - should_add_outputs = True - if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: - should_add_outputs = False - - if should_add_outputs: - # Add this block's outputs - block_intermediates_outputs = [out.name for out in block.intermediates_outputs] - outputs.update(block_intermediates_outputs) - return inputs - - @property - def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - return combined_outputs - - @property - def outputs(self) -> List[str]: - return next(reversed(self.blocks.values())).intermediates_outputs - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - for block_name, block in self.blocks.items(): - try: - pipeline, state = block(pipeline, state) - except Exception as e: - error_msg = ( - f"\nError in block: ({block_name}, {block.__class__.__name__})\n" - f"Error details: {str(e)}\n" - f"Traceback:\n{traceback.format_exc()}" - ) - logger.error(error_msg) - raise - return pipeline, state - - def _get_trigger_inputs(self): - """ - Returns a set of all unique trigger input values found in the blocks. - Returns: Set[str] containing all unique block_trigger_inputs values - """ - def fn_recursive_get_trigger(blocks): - trigger_values = set() - - if blocks is not None: - for name, block in blocks.items(): - # Check if current block has trigger inputs(i.e. auto block) - if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: - # Add all non-None values from the trigger inputs list - trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - - # If block has blocks, recursively check them - if hasattr(block, 'blocks'): - nested_triggers = fn_recursive_get_trigger(block.blocks) - trigger_values.update(nested_triggers) - - return trigger_values - - return fn_recursive_get_trigger(self.blocks) - - @property - def trigger_inputs(self): - return self._get_trigger_inputs() - - def _traverse_trigger_blocks(self, trigger_inputs): - # Convert trigger_inputs to a set for easier manipulation - active_triggers = set(trigger_inputs) - def fn_recursive_traverse(block, block_name, active_triggers): - result_blocks = OrderedDict() - - # sequential(include loopsequential) or PipelineBlock - if not hasattr(block, 'block_trigger_inputs'): - if hasattr(block, 'blocks'): - # sequential or LoopSequentialPipelineBlocks (keep traversing) - for sub_block_name, sub_block in block.blocks.items(): - blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) - blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) - blocks_to_update = {f"{block_name}.{k}": v for k,v in blocks_to_update.items()} - result_blocks.update(blocks_to_update) - else: - # PipelineBlock - result_blocks[block_name] = block - # Add this block's output names to active triggers if defined - if hasattr(block, 'outputs'): - active_triggers.update(out.name for out in block.outputs) - return result_blocks - - # auto - else: - # Find first block_trigger_input that matches any value in our active_triggers - this_block = None - matching_trigger = None - for trigger_input in block.block_trigger_inputs: - if trigger_input is not None and trigger_input in active_triggers: - this_block = block.trigger_to_block_map[trigger_input] - matching_trigger = trigger_input - break - - # If no matches found, try to get the default (None) block - if this_block is None and None in block.block_trigger_inputs: - this_block = block.trigger_to_block_map[None] - matching_trigger = None - - if this_block is not None: - # sequential/auto (keep traversing) - if hasattr(this_block, 'blocks'): - result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) - else: - # PipelineBlock - result_blocks[block_name] = this_block - # Add this block's output names to active triggers if defined - # YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute? - if hasattr(this_block, 'outputs'): - active_triggers.update(out.name for out in this_block.outputs) - - return result_blocks - - all_blocks = OrderedDict() - for block_name, block in self.blocks.items(): - blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) - all_blocks.update(blocks_to_update) - return all_blocks - - def get_execution_blocks(self, *trigger_inputs): - trigger_inputs_all = self.trigger_inputs - - if trigger_inputs is not None: - - if not isinstance(trigger_inputs, (list, tuple, set)): - trigger_inputs = [trigger_inputs] - invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all] - if invalid_inputs: - logger.warning( - f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}" - ) - trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all] - - if trigger_inputs is None: - if None in trigger_inputs_all: - trigger_inputs = [None] - else: - trigger_inputs = [trigger_inputs_all[0]] - blocks_triggered = self._traverse_trigger_blocks(trigger_inputs) - return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered) - - def __repr__(self): - class_name = self.__class__.__name__ - base_class = self.__class__.__bases__[0].__name__ - header = ( - f"{class_name}(\n Class: {base_class}\n" - if base_class and base_class != "object" - else f"{class_name}(\n" - ) - - - if self.trigger_inputs: - header += "\n" - header += " " + "=" * 100 + "\n" - header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" - header += f" Trigger Inputs: {self.trigger_inputs}\n" - # Get first trigger input as example - example_input = next(t for t in self.trigger_inputs if t is not None) - header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" - header += " " + "=" * 100 + "\n\n" - - # Format description with proper indentation - desc_lines = self.description.split('\n') - desc = [] - # First line with "Description:" label - desc.append(f" Description: {desc_lines[0]}") - # Subsequent lines with proper indentation - if len(desc_lines) > 1: - desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' - - # Components section - focus only on expected components - expected_components = getattr(self, "expected_components", []) - components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - - # Configs section - use format_configs with add_empty_lines=False - expected_configs = getattr(self, "expected_configs", []) - configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - - # Blocks section - moved to the end with simplified format - blocks_str = " Blocks:\n" - for i, (name, block) in enumerate(self.blocks.items()): - # Get trigger input for this block - trigger = None - if hasattr(self, 'block_to_trigger_map'): - trigger = self.block_to_trigger_map.get(name) - # Format the trigger info - if trigger is None: - trigger_str = "[default]" - elif isinstance(trigger, (list, tuple)): - trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" - else: - trigger_str = f"[trigger: {trigger}]" - # For AutoPipelineBlocks, add bullet points - blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" - else: - # For SequentialPipelineBlocks, show execution order - blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - - # Add block description - desc_lines = block.description.split('\n') - indented_desc = desc_lines[0] - if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n\n" - - return ( - f"{header}\n" - f"{desc}\n\n" - f"{components_str}\n\n" - f"{configs_str}\n\n" - f"{blocks_str}" - f")" - ) - - - @property - def doc(self): - return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, - self.description, - class_name=self.__class__.__name__, - expected_components=self.expected_components, - expected_configs=self.expected_configs - ) - -#YiYi TODO: __repr__ -class LoopSequentialPipelineBlocks(ModularPipelineMixin): - """ - A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence. - """ - - model_name = None - block_classes = [] - block_names = [] - - @property - def description(self) -> str: - """Description of the block. Must be implemented by subclasses.""" - raise NotImplementedError("description method must be implemented in subclasses") - - @property - def loop_expected_components(self) -> List[ComponentSpec]: - return [] - - @property - def loop_expected_configs(self) -> List[ConfigSpec]: - return [] - - @property - def loop_inputs(self) -> List[InputParam]: - """List of input parameters. Must be implemented by subclasses.""" - return [] - - @property - def loop_intermediates_inputs(self) -> List[InputParam]: - """List of intermediate input parameters. Must be implemented by subclasses.""" - return [] - - @property - def loop_intermediates_outputs(self) -> List[OutputParam]: - """List of intermediate output parameters. Must be implemented by subclasses.""" - return [] - - - @property - def loop_required_inputs(self) -> List[str]: - input_names = [] - for input_param in self.loop_inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - - @property - def loop_required_intermediates_inputs(self) -> List[str]: - input_names = [] - for input_param in self.loop_intermediates_inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - - # modified from SequentialPipelineBlocks to include loop_expected_components - @property - def expected_components(self): - expected_components = [] - for block in self.blocks.values(): - for component in block.expected_components: - if component not in expected_components: - expected_components.append(component) - for component in self.loop_expected_components: - if component not in expected_components: - expected_components.append(component) - return expected_components - - # modified from SequentialPipelineBlocks to include loop_expected_configs - @property - def expected_configs(self): - expected_configs = [] - for block in self.blocks.values(): - for config in block.expected_configs: - if config not in expected_configs: - expected_configs.append(config) - for config in self.loop_expected_configs: - if config not in expected_configs: - expected_configs.append(config) - return expected_configs - - # modified from SequentialPipelineBlocks to include loop_inputs - def get_inputs(self): - named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] - named_inputs.append(("loop", self.loop_inputs)) - combined_inputs = combine_inputs(*named_inputs) - # mark Required inputs only if that input is required any of the blocks - for input_param in combined_inputs: - if input_param.name in self.required_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - # Copied from SequentialPipelineBlocks - @property - def inputs(self): - return self.get_inputs() - - - # modified from SequentialPipelineBlocks to include loop_intermediates_inputs - @property - def intermediates_inputs(self): - intermediates = self.get_intermediates_inputs() - intermediate_names = [input.name for input in intermediates] - for loop_intermediate_input in self.loop_intermediates_inputs: - if loop_intermediate_input.name not in intermediate_names: - intermediates.append(loop_intermediate_input) - return intermediates - - - # Copied from SequentialPipelineBlocks - def get_intermediates_inputs(self): - inputs = [] - outputs = set() - - # Go through all blocks in order - for block in self.blocks.values(): - # Add inputs that aren't in outputs yet - inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) - - # Only add outputs if the block cannot be skipped - should_add_outputs = True - if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: - should_add_outputs = False - - if should_add_outputs: - # Add this block's outputs - block_intermediates_outputs = [out.name for out in block.intermediates_outputs] - outputs.update(block_intermediates_outputs) - return inputs - - - # modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block - @property - def required_inputs(self) -> List[str]: - # Get the first block from the dictionary - first_block = next(iter(self.blocks.values())) - required_by_any = set(getattr(first_block, "required_inputs", set())) - - required_by_loop = set(getattr(self, "loop_required_inputs", set())) - required_by_any.update(required_by_loop) - - # Union with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_inputs", set())) - required_by_any.update(block_required) - - return list(required_by_any) - - # modified from SequentialPipelineBlocks, if any additional intermediate input required by the loop is required by the block - @property - def required_intermediates_inputs(self) -> List[str]: - required_intermediates_inputs = [] - for input_param in self.intermediates_inputs: - if input_param.required: - required_intermediates_inputs.append(input_param.name) - for input_param in self.loop_intermediates_inputs: - if input_param.required: - required_intermediates_inputs.append(input_param.name) - return required_intermediates_inputs - - - # YiYi TODO: this need to be thought about more - # modified from SequentialPipelineBlocks to include loop_intermediates_outputs - @property - def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - for output in self.loop_intermediates_outputs: - if output.name not in set([output.name for output in combined_outputs]): - combined_outputs.append(output) - return combined_outputs - - # YiYi TODO: this need to be thought about more - # copied from SequentialPipelineBlocks - @property - def outputs(self) -> List[str]: - return next(reversed(self.blocks.values())).intermediates_outputs - - - def __init__(self): - blocks = OrderedDict() - for block_name, block_cls in zip(self.block_names, self.block_classes): - blocks[block_name] = block_cls() - self.blocks = blocks - - def loop_step(self, components, state: PipelineState, **kwargs): - - for block_name, block in self.blocks.items(): - try: - components, state = block(components, state, **kwargs) - except Exception as e: - error_msg = ( - f"\nError in block: ({block_name}, {block.__class__.__name__})\n" - f"Error details: {str(e)}\n" - f"Traceback:\n{traceback.format_exc()}" - ) - logger.error(error_msg) - raise - return components, state - - def __call__(self, components, state: PipelineState) -> PipelineState: - raise NotImplementedError("`__call__` method needs to be implemented by the subclass") - - - def get_block_state(self, state: PipelineState) -> dict: - """Get all inputs and intermediates in one dictionary""" - data = {} - - # Check inputs - for input_param in self.inputs: - if input_param.name: - value = state.get_input(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required input '{input_param.name}' is missing") - elif value is not None or (value is None and input_param.name not in data): - data[input_param.name] = value - elif input_param.kwargs_type: - # if kwargs_type is provided, get all inputs with matching kwargs_type - if input_param.kwargs_type not in data: - data[input_param.kwargs_type] = {} - inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) - if inputs_kwargs: - for k, v in inputs_kwargs.items(): - if v is not None: - data[k] = v - data[input_param.kwargs_type][k] = v - - # Check intermediates - for input_param in self.intermediates_inputs: - if input_param.name: - value = state.get_intermediate(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required intermediate input '{input_param.name}' is missing") - elif value is not None or (value is None and input_param.name not in data): - data[input_param.name] = value - elif input_param.kwargs_type: - # if kwargs_type is provided, get all intermediates with matching kwargs_type - if input_param.kwargs_type not in data: - data[input_param.kwargs_type] = {} - intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) - if intermediates_kwargs: - for k, v in intermediates_kwargs.items(): - if v is not None: - if k not in data: - data[k] = v - data[input_param.kwargs_type][k] = v - return BlockState(**data) - - def add_block_state(self, state: PipelineState, block_state: BlockState): - for output_param in self.intermediates_outputs: - if not hasattr(block_state, output_param.name): - raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") - param = getattr(block_state, output_param.name) - state.add_intermediate(output_param.name, param, output_param.kwargs_type) - -# YiYi TODO: -# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) -# 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader -# 3. add validator for methods where we accpet kwargs to be passed to from_pretrained() -class ModularLoader(ConfigMixin, PushToHubMixin): - """ - Base class for all Modular pipelines loaders. - - """ - config_name = "modular_model_index.json" - - - def register_components(self, **kwargs): - """ - Register components with their corresponding specs. - This method is called when component changed or __init__ is called. - - Args: - **kwargs: Keyword arguments where keys are component names and values are component objects. - - """ - for name, module in kwargs.items(): - - # current component spec - component_spec = self._component_specs.get(name) - if component_spec is None: - logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") - continue - - is_registered = hasattr(self, name) - - if module is not None and not hasattr(module, "_diffusers_load_id"): - raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") - - # actual library and class name of the module - - if module is not None: - library, class_name = _fetch_class_library_tuple(module) - new_component_spec = ComponentSpec.from_component(name, module) - component_spec_dict = self._component_spec_to_dict(new_component_spec) - - else: - library, class_name = None, None - # if module is None, we do not update the spec, - # but we still need to update the config to make sure it's synced with the component spec - # (in the case of the first time registration, we initilize the object with component spec, and then we call register_components() to register it to config) - new_component_spec = component_spec - component_spec_dict = self._component_spec_to_dict(component_spec) - - # do not register if component is not to be loaded from pretrained - if new_component_spec.default_creation_method == "from_pretrained": - register_dict = {name: (library, class_name, component_spec_dict)} - else: - register_dict = {} - - # set the component as attribute - # if it is not set yet, just set it and skip the process to check and warn below - if not is_registered: - self.register_to_config(**register_dict) - self._component_specs[name] = new_component_spec - setattr(self, name, module) - if module is not None and self._component_manager is not None: - self._component_manager.add(name, module, self._collection) - continue - - current_module = getattr(self, name, None) - # skip if the component is already registered with the same object - if current_module is module: - logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") - continue - - # it module is not an instance of the expected type, still register it but with a warning - if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint): - logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}") - - # warn if unregister - if current_module is not None and module is None: - logger.info( - f"ModularLoader.register_components: setting '{name}' to None " - f"(was {current_module.__class__.__name__})" - ) - # same type, new instance → debug - elif current_module is not None \ - and module is not None \ - and isinstance(module, current_module.__class__) \ - and current_module != module: - logger.debug( - f"ModularLoader.register_components: replacing existing '{name}' " - f"(same type {type(current_module).__name__}, new instance)" - ) - - # save modular_model_index.json config - self.register_to_config(**register_dict) - # update component spec - self._component_specs[name] = new_component_spec - # finally set models - setattr(self, name, module) - if module is not None and self._component_manager is not None: - self._component_manager.add(name, module, self._collection) - - - - # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name - def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): - """ - Initialize the loader with a list of component specs and config specs. - """ - self._component_manager = component_manager - self._collection = collection - self._component_specs = { - spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec) - } - self._config_specs = { - spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec) - } - - # update component_specs and config_specs from modular_repo - if modular_repo is not None: - config_dict = self.load_config(modular_repo, **kwargs) - - for name, value in config_dict.items(): - if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3: - library, class_name, component_spec_dict = value - component_spec = self._dict_to_component_spec(name, component_spec_dict) - self._component_specs[name] = component_spec - - elif name in self._config_specs: - self._config_specs[name].default = value - - register_components_dict = {} - for name, component_spec in self._component_specs.items(): - register_components_dict[name] = None - self.register_components(**register_components_dict) - - default_configs = {} - for name, config_spec in self._config_specs.items(): - default_configs[name] = config_spec.default - self.register_to_config(**default_configs) - - - @property - def device(self) -> torch.device: - r""" - Returns: - `torch.device`: The torch device on which the pipeline is located. - """ - modules = self.components.values() - modules = [m for m in modules if isinstance(m, torch.nn.Module)] - - for module in modules: - return module.device - - return torch.device("cpu") - - @property - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from - Accelerate's module hooks. - """ - for name, model in self.components.items(): - if not isinstance(model, torch.nn.Module): - continue - - if not hasattr(model, "_hf_hook"): - return self.device - for module in model.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - @property - def device(self) -> torch.device: - r""" - Returns: - `torch.device`: The torch device on which the pipeline is located. - """ - - modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)] - - for module in modules: - return module.device - - return torch.device("cpu") - - @property - def dtype(self) -> torch.dtype: - r""" - Returns: - `torch.dtype`: The torch dtype on which the pipeline is located. - """ - modules = self.components.values() - modules = [m for m in modules if isinstance(m, torch.nn.Module)] - - for module in modules: - return module.dtype - - return torch.float32 - - - @property - def components(self) -> Dict[str, Any]: - # return only components we've actually set as attributes on self - return { - name: getattr(self, name) - for name in self._component_specs.keys() - if hasattr(self, name) - } - - def update(self, **kwargs): - """ - Update components and configs after instance creation. - - Args: - - """ - """ - Update components and configuration values after the loader has been instantiated. - - This method allows you to: - 1. Replace existing components with new ones (e.g., updating the unet or text_encoder) - 2. Update configuration values (e.g., changing requires_safety_checker flag) - - Args: - **kwargs: Component objects or configuration values to update: - - Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, text_encoder=new_encoder`) - - Configuration values: Simple values to update configuration settings (e.g., `requires_safety_checker=False`) - - Raises: - ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute) - - Examples: - ```python - # Update multiple components at once - loader.update( - unet=new_unet_model, - text_encoder=new_text_encoder - ) - - # Update configuration values - loader.update( - requires_safety_checker=False, - guidance_rescale=0.7 - ) - - # Update both components and configs together - loader.update( - unet=new_unet_model, - requires_safety_checker=False - ) - ``` - """ - - # extract component_specs_updates & config_specs_updates from `specs` - passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs} - passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs} - - for name, component in passed_components.items(): - if not hasattr(component, "_diffusers_load_id"): - raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") - - if len(kwargs) > 0: - logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") - - - self.register_components(**passed_components) - - - config_to_register = {} - for name, new_value in passed_config_values.items(): - - # e.g. requires_aesthetics_score = False - self._config_specs[name].default = new_value - config_to_register[name] = new_value - self.register_to_config(**config_to_register) - - - # YiYi TODO: support map for additional from_pretrained kwargs - def load(self, component_names: Optional[List[str]] = None, **kwargs): - """ - Load selectedcomponents from specs. - - Args: - component_names: List of component names to load - **kwargs: additional kwargs to be passed to `from_pretrained()`.Can be: - - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16 - - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32} - - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`, `variant`, `revision`, etc. - """ - if component_names is None: - component_names = list(self._component_specs.keys()) - elif not isinstance(component_names, list): - component_names = [component_names] - - components_to_load = set([name for name in component_names if name in self._component_specs]) - unknown_component_names = set([name for name in component_names if name not in self._component_specs]) - if len(unknown_component_names) > 0: - logger.warning(f"Unknown components will be ignored: {unknown_component_names}") - - components_to_register = {} - for name in components_to_load: - spec = self._component_specs[name] - component_load_kwargs = {} - for key, value in kwargs.items(): - if not isinstance(value, dict): - # if the value is a single value, apply it to all components - component_load_kwargs[key] = value - else: - if name in value: - # if it is a dict, check if the component name is in the dict - component_load_kwargs[key] = value[name] - elif "default" in value: - # check if the default is specified - component_load_kwargs[key] = value["default"] - try: - components_to_register[name] = spec.create(**component_load_kwargs) - except Exception as e: - logger.warning(f"Failed to create component '{name}': {e}") - - # Register all components at once - self.register_components(**components_to_register) - - # YiYi TODO: should support to method - def to(self, *args, **kwargs): - pass - - # YiYi TODO: - # 1. should support save some components too! currently only modular_model_index.json is saved - # 2. maybe order the json file to make it more readable: configs first, then components - def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs): - - component_names = list(self._component_specs.keys()) - config_names = list(self._config_specs.keys()) - self.register_to_config(_components_names=component_names, _configs_names=config_names) - self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) - config = dict(self.config) - config.pop("_components_names", None) - config.pop("_configs_names", None) - self._internal_dict = FrozenDict(config) - - - @classmethod - @validate_hf_hub_args - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs): - - config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) - expected_component = set(config_dict.pop("_components_names")) - expected_config = set(config_dict.pop("_configs_names")) - - component_specs = [] - config_specs = [] - for name, value in config_dict.items(): - if name in expected_component and isinstance(value, (tuple, list)) and len(value) == 3: - library, class_name, component_spec_dict = value - component_spec = cls._dict_to_component_spec(name, component_spec_dict) - component_specs.append(component_spec) - - elif name in expected_config: - config_specs.append(ConfigSpec(name=name, default=value)) - - for name in expected_component: - for spec in component_specs: - if spec.name == name: - break - else: - # append a empty component spec for these not in modular_model_index - component_specs.append(ComponentSpec(name=name, default_creation_method="from_config")) - return cls(component_specs + config_specs) - - - @staticmethod - def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: - """ - Convert a ComponentSpec into a JSON‐serializable dict for saving in - `modular_model_index.json`. - - This dict contains: - - "type_hint": Tuple[str, str] - The fully‐qualified module path and class name of the component. - - All loading fields defined by `component_spec.loading_fields()`, typically: - - "repo": Optional[str] - The model repository (e.g., "stabilityai/stable-diffusion-xl"). - - "subfolder": Optional[str] - A subfolder within the repo where this component lives. - - "variant": Optional[str] - An optional variant identifier for the model. - - "revision": Optional[str] - A specific git revision (commit hash, tag, or branch). - - ... any other loading fields defined on the spec. - - Args: - component_spec (ComponentSpec): - The spec object describing one pipeline component. - - Returns: - Dict[str, Any]: A mapping suitable for JSON serialization. - - Example: - >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec - >>> from diffusers.models.unet import UNet2DConditionModel - >>> spec = ComponentSpec( - ... name="unet", - ... type_hint=UNet2DConditionModel, - ... config=None, - ... repo="path/to/repo", - ... subfolder="subfolder", - ... variant=None, - ... revision=None, - ... default_creation_method="from_pretrained", - ... ) - >>> ModularLoader._component_spec_to_dict(spec) - { - "type_hint": ("diffusers.models.unet", "UNet2DConditionModel"), - "repo": "path/to/repo", - "subfolder": "subfolder", - "variant": None, - "revision": None, - } - """ - if component_spec.type_hint is not None: - lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint) - else: - lib_name = None - cls_name = None - load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} - return { - "type_hint": (lib_name, cls_name), - **load_spec_dict, - } - - @staticmethod - def _dict_to_component_spec( - name: str, - spec_dict: Dict[str, Any], - ) -> ComponentSpec: - """ - Reconstruct a ComponentSpec from a dict. - """ - # make a shallow copy so we can pop() safely - spec_dict = spec_dict.copy() - # pull out and resolve the stored type_hint - lib_name, cls_name = spec_dict.pop("type_hint") - if lib_name is not None and cls_name is not None: - type_hint = simple_get_class_obj(lib_name, cls_name) - else: - type_hint = None - - # re‐assemble the ComponentSpec - return ComponentSpec( - name=name, - type_hint=type_hint, - **spec_dict, - ) \ No newline at end of file diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py deleted file mode 100644 index 392d6dcd9521..000000000000 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ /dev/null @@ -1,598 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -import inspect -from dataclasses import dataclass, asdict, field, fields -from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal - -from ..utils.import_utils import is_torch_available -from ..configuration_utils import FrozenDict, ConfigMixin - -if is_torch_available(): - import torch - - -# YiYi TODO: -# 1. validate the dataclass fields -# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained() -@dataclass -class ComponentSpec: - """Specification for a pipeline component. - - A component can be created in two ways: - 1. From scratch using __init__ with a config dict - 2. using `from_pretrained` - - Attributes: - name: Name of the component - type_hint: Type of the component (e.g. UNet2DConditionModel) - description: Optional description of the component - config: Optional config dict for __init__ creation - repo: Optional repo path for from_pretrained creation - subfolder: Optional subfolder in repo - variant: Optional variant in repo - revision: Optional revision in repo - default_creation_method: Preferred creation method - "from_config" or "from_pretrained" - """ - name: Optional[str] = None - type_hint: Optional[Type] = None - description: Optional[str] = None - config: Optional[FrozenDict[str, Any]] = None - # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name - repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True}) - subfolder: Optional[str] = field(default=None, metadata={"loading": True}) - variant: Optional[str] = field(default=None, metadata={"loading": True}) - revision: Optional[str] = field(default=None, metadata={"loading": True}) - default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" - - - def __hash__(self): - """Make ComponentSpec hashable, using load_id as the hash value.""" - return hash((self.name, self.load_id, self.default_creation_method)) - - def __eq__(self, other): - """Compare ComponentSpec objects based on name and load_id.""" - if not isinstance(other, ComponentSpec): - return False - return (self.name == other.name and - self.load_id == other.load_id and - self.default_creation_method == other.default_creation_method) - - @classmethod - def from_component(cls, name: str, component: torch.nn.Module) -> Any: - """Create a ComponentSpec from a Component created by `create` method.""" - - if not hasattr(component, "_diffusers_load_id"): - raise ValueError("Component is not created by `create` method") - - type_hint = component.__class__ - - if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin): - config = component.config - else: - config = None - - load_spec = cls.decode_load_id(component._diffusers_load_id) - - return cls(name=name, type_hint=type_hint, config=config, **load_spec) - - @classmethod - def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any: - """Create a ComponentSpec from a load_id string.""" - if load_id == "null": - raise ValueError("Cannot create ComponentSpec from null load_id") - - # Decode the load_id into a dictionary of loading fields - load_fields = cls.decode_load_id(load_id) - - # Create a new ComponentSpec instance with the decoded fields - return cls(name=name, **load_fields) - - @classmethod - def loading_fields(cls) -> List[str]: - """ - Return the names of all loading‐related fields - (i.e. those whose field.metadata["loading"] is True). - """ - return [f.name for f in fields(cls) if f.metadata.get("loading", False)] - - - @property - def load_id(self) -> str: - """ - Unique identifier for this spec's pretrained load, - composed of repo|subfolder|variant|revision (no empty segments). - """ - parts = [getattr(self, k) for k in self.loading_fields()] - parts = ["null" if p is None else p for p in parts] - return "|".join(p for p in parts if p) - - @classmethod - def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: - """ - Decode a load_id string back into a dictionary of loading fields and values. - - Args: - load_id: The load_id string to decode, format: "repo|subfolder|variant|revision" - where None values are represented as "null" - - Returns: - Dict mapping loading field names to their values. e.g. - { - "repo": "path/to/repo", - "subfolder": "subfolder", - "variant": "variant", - "revision": "revision" - } - If a segment value is "null", it's replaced with None. - Returns None if load_id is "null" (indicating component not loaded from pretrained). - """ - - # Get all loading fields in order - loading_fields = cls.loading_fields() - result = {f: None for f in loading_fields} - - if load_id == "null": - return result - - # Split the load_id - parts = load_id.split("|") - - # Map parts to loading fields by position - for i, part in enumerate(parts): - if i < len(loading_fields): - # Convert "null" string back to None - result[loading_fields[i]] = None if part == "null" else part - - return result - - # YiYi TODO: add validator - def create(self, **kwargs) -> Any: - """Create the component using the preferred creation method.""" - - # from_pretrained creation - if self.default_creation_method == "from_pretrained": - return self.create_from_pretrained(**kwargs) - elif self.default_creation_method == "from_config": - # from_config creation - return self.create_from_config(**kwargs) - else: - raise ValueError(f"Invalid creation method: {self.default_creation_method}") - - def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: - """Create component using from_config with config.""" - - if self.type_hint is None or not isinstance(self.type_hint, type): - raise ValueError( - f"`type_hint` is required when using from_config creation method." - ) - - config = config or self.config or {} - - if issubclass(self.type_hint, ConfigMixin): - component = self.type_hint.from_config(config, **kwargs) - else: - signature_params = inspect.signature(self.type_hint.__init__).parameters - init_kwargs = {} - for k, v in config.items(): - if k in signature_params: - init_kwargs[k] = v - for k, v in kwargs.items(): - if k in signature_params: - init_kwargs[k] = v - component = self.type_hint(**init_kwargs) - - component._diffusers_load_id = "null" - if hasattr(component, "config"): - self.config = component.config - - return component - - # YiYi TODO: add guard for type of model, if it is supported by from_pretrained - def create_from_pretrained(self, **kwargs) -> Any: - """Create component using from_pretrained.""" - - passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} - load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()} - # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path - repo = load_kwargs.pop("repo", None) - if repo is None: - raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") - - if self.type_hint is None: - try: - from diffusers import AutoModel - component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs) - except Exception as e: - raise ValueError(f"Error creating {self.name} without `type_hint` from pretrained: {e}") - self.type_hint = component.__class__ - else: - try: - component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) - except Exception as e: - raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}") - - if repo != self.repo: - self.repo = repo - for k, v in passed_loading_kwargs.items(): - if v is not None: - setattr(self, k, v) - component._diffusers_load_id = self.load_id - - return component - - - -@dataclass -class ConfigSpec: - """Specification for a pipeline configuration parameter.""" - name: str - default: Any - description: Optional[str] = None -@dataclass -class InputParam: - """Specification for an input parameter.""" - name: str = None - type_hint: Any = None - default: Any = None - required: bool = False - description: str = "" - kwargs_type: str = None - - def __repr__(self): - return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" - - -@dataclass -class OutputParam: - """Specification for an output parameter.""" - name: str - type_hint: Any = None - description: str = "" - kwargs_type: str = None - - def __repr__(self): - return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" - - -def format_inputs_short(inputs): - """ - Format input parameters into a string representation, with required params first followed by optional ones. - - Args: - inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params - - Returns: - str: Formatted string of input parameters - - Example: - >>> inputs = [ - ... InputParam(name="prompt", required=True), - ... InputParam(name="image", required=True), - ... InputParam(name="guidance_scale", required=False, default=7.5), - ... InputParam(name="num_inference_steps", required=False, default=50) - ... ] - >>> format_inputs_short(inputs) - 'prompt, image, guidance_scale=7.5, num_inference_steps=50' - """ - required_inputs = [param for param in inputs if param.required] - optional_inputs = [param for param in inputs if not param.required] - - required_str = ", ".join(param.name for param in required_inputs) - optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) - - inputs_str = required_str - if optional_str: - inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str - - return inputs_str - - -def format_intermediates_short(intermediates_inputs, required_intermediates_inputs, intermediates_outputs): - """ - Formats intermediate inputs and outputs of a block into a string representation. - - Args: - intermediates_inputs: List of intermediate input parameters - required_intermediates_inputs: List of required intermediate input names - intermediates_outputs: List of intermediate output parameters - - Returns: - str: Formatted string like: - Intermediates: - - inputs: Required(latents), dtype - - modified: latents # variables that appear in both inputs and outputs - - outputs: images # new outputs only - """ - # Handle inputs - input_parts = [] - for inp in intermediates_inputs: - if inp.name in required_intermediates_inputs: - input_parts.append(f"Required({inp.name})") - else: - if inp.name is None and inp.kwargs_type is not None: - inp_name = "*_" + inp.kwargs_type - else: - inp_name = inp.name - input_parts.append(inp_name) - - # Handle modified variables (appear in both inputs and outputs) - inputs_set = {inp.name for inp in intermediates_inputs} - modified_parts = [] - new_output_parts = [] - - for out in intermediates_outputs: - if out.name in inputs_set: - modified_parts.append(out.name) - else: - new_output_parts.append(out.name) - - result = [] - if input_parts: - result.append(f" - inputs: {', '.join(input_parts)}") - if modified_parts: - result.append(f" - modified: {', '.join(modified_parts)}") - if new_output_parts: - result.append(f" - outputs: {', '.join(new_output_parts)}") - - return "\n".join(result) if result else " (none)" - - -def format_params(params, header="Args", indent_level=4, max_line_length=115): - """Format a list of InputParam or OutputParam objects into a readable string representation. - - Args: - params: List of InputParam or OutputParam objects to format - header: Header text to use (e.g. "Args" or "Returns") - indent_level: Number of spaces to indent each parameter line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - - Returns: - A formatted string representing all parameters - """ - if not params: - return "" - - base_indent = " " * indent_level - param_indent = " " * (indent_level + 4) - desc_indent = " " * (indent_level + 8) - formatted_params = [] - - def get_type_str(type_hint): - if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: - types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] - return f"Union[{', '.join(types)}]" - return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) - - def wrap_text(text, indent, max_length): - """Wrap text while preserving markdown links and maintaining indentation.""" - words = text.split() - lines = [] - current_line = [] - current_length = 0 - - for word in words: - word_length = len(word) + (1 if current_line else 0) - - if current_line and current_length + word_length > max_length: - lines.append(" ".join(current_line)) - current_line = [word] - current_length = len(word) - else: - current_line.append(word) - current_length += word_length - - if current_line: - lines.append(" ".join(current_line)) - - return f"\n{indent}".join(lines) - - # Add the header - formatted_params.append(f"{base_indent}{header}:") - - for param in params: - # Format parameter name and type - type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" - param_str = f"{param_indent}{param.name} (`{type_str}`" - - # Add optional tag and default value if parameter is an InputParam and optional - if hasattr(param, "required"): - if not param.required: - param_str += ", *optional*" - if param.default is not None: - param_str += f", defaults to {param.default}" - param_str += "):" - - # Add description on a new line with additional indentation and wrapping - if param.description: - desc = re.sub( - r'\[(.*?)\]\((https?://[^\s\)]+)\)', - r'[\1](\2)', - param.description - ) - wrapped_desc = wrap_text(desc, desc_indent, max_line_length) - param_str += f"\n{desc_indent}{wrapped_desc}" - - formatted_params.append(param_str) - - return "\n\n".join(formatted_params) - - -def format_input_params(input_params, indent_level=4, max_line_length=115): - """Format a list of InputParam objects into a readable string representation. - - Args: - input_params: List of InputParam objects to format - indent_level: Number of spaces to indent each parameter line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - - Returns: - A formatted string representing all input parameters - """ - return format_params(input_params, "Inputs", indent_level, max_line_length) - - -def format_output_params(output_params, indent_level=4, max_line_length=115): - """Format a list of OutputParam objects into a readable string representation. - - Args: - output_params: List of OutputParam objects to format - indent_level: Number of spaces to indent each parameter line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - - Returns: - A formatted string representing all output parameters - """ - return format_params(output_params, "Outputs", indent_level, max_line_length) - - -def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True): - """Format a list of ComponentSpec objects into a readable string representation. - - Args: - components: List of ComponentSpec objects to format - indent_level: Number of spaces to indent each component line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - add_empty_lines: Whether to add empty lines between components (default: True) - - Returns: - A formatted string representing all components - """ - if not components: - return "" - - base_indent = " " * indent_level - component_indent = " " * (indent_level + 4) - formatted_components = [] - - # Add the header - formatted_components.append(f"{base_indent}Components:") - if add_empty_lines: - formatted_components.append("") - - # Add each component with optional empty lines between them - for i, component in enumerate(components): - # Get type name, handling special cases - type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) - - component_desc = f"{component_indent}{component.name} (`{type_name}`)" - if component.description: - component_desc += f": {component.description}" - - # Get the loading fields dynamically - loading_field_values = [] - for field_name in component.loading_fields(): - field_value = getattr(component, field_name) - if field_value is not None: - loading_field_values.append(f"{field_name}={field_value}") - - # Add loading field information if available - if loading_field_values: - component_desc += f" [{', '.join(loading_field_values)}]" - - formatted_components.append(component_desc) - - # Add an empty line after each component except the last one - if add_empty_lines and i < len(components) - 1: - formatted_components.append("") - - return "\n".join(formatted_components) - - -def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True): - """Format a list of ConfigSpec objects into a readable string representation. - - Args: - configs: List of ConfigSpec objects to format - indent_level: Number of spaces to indent each config line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - add_empty_lines: Whether to add empty lines between configs (default: True) - - Returns: - A formatted string representing all configs - """ - if not configs: - return "" - - base_indent = " " * indent_level - config_indent = " " * (indent_level + 4) - formatted_configs = [] - - # Add the header - formatted_configs.append(f"{base_indent}Configs:") - if add_empty_lines: - formatted_configs.append("") - - # Add each config with optional empty lines between them - for i, config in enumerate(configs): - config_desc = f"{config_indent}{config.name} (default: {config.default})" - if config.description: - config_desc += f": {config.description}" - formatted_configs.append(config_desc) - - # Add an empty line after each config except the last one - if add_empty_lines and i < len(configs) - 1: - formatted_configs.append("") - - return "\n".join(formatted_configs) - - -def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None): - """ - Generates a formatted documentation string describing the pipeline block's parameters and structure. - - Args: - inputs: List of input parameters - intermediates_inputs: List of intermediate input parameters - outputs: List of output parameters - description (str, *optional*): Description of the block - class_name (str, *optional*): Name of the class to include in the documentation - expected_components (List[ComponentSpec], *optional*): List of expected components - expected_configs (List[ConfigSpec], *optional*): List of expected configurations - - Returns: - str: A formatted string containing information about components, configs, call parameters, - intermediate inputs/outputs, and final outputs. - """ - output = "" - - # Add class name if provided - if class_name: - output += f"class {class_name}\n\n" - - # Add description - if description: - desc_lines = description.strip().split('\n') - aligned_desc = '\n'.join(' ' + line for line in desc_lines) - output += aligned_desc + "\n\n" - - # Add components section if provided - if expected_components and len(expected_components) > 0: - components_str = format_components(expected_components, indent_level=2) - output += components_str + "\n\n" - - # Add configs section if provided - if expected_configs and len(expected_configs) > 0: - configs_str = format_configs(expected_configs, indent_level=2) - output += configs_str + "\n\n" - - # Add inputs section - output += format_input_params(inputs + intermediates_inputs, indent_level=2) - - # Add outputs section - output += "\n\n" - output += format_output_params(outputs, indent_level=2) - - return output \ No newline at end of file diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py deleted file mode 100644 index acb395345086..000000000000 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ /dev/null @@ -1,3032 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, List, Optional, Tuple, Union, Dict - -import PIL -import torch -from collections import OrderedDict - -from ...image_processor import VaeImageProcessor, PipelineImageInput -from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin -from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel -from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor -from ...models.lora import adjust_lora_scale_text_encoder -from ...utils import ( - USE_PEFT_BACKEND, - logging, - scale_lora_layers, - unscale_lora_layers, -) -from ...utils.torch_utils import randn_tensor, unwrap_module -from ..controlnet.multicontrolnet import MultiControlNetModel -from ..modular_pipeline import ( - AutoPipelineBlocks, - ModularLoader, - PipelineBlock, - PipelineState, - InputParam, - OutputParam, - SequentialPipelineBlocks, - ComponentSpec, - ConfigSpec, -) -from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from .pipeline_output import ( - StableDiffusionXLPipelineOutput, -) - -from transformers import ( - CLIPTextModel, - CLIPImageProcessor, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionModelWithProjection, -) - -from ...schedulers import EulerDiscreteScheduler -from ...guiders import ClassifierFreeGuidance -from ...configuration_utils import FrozenDict - -import numpy as np - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - - -# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder? -# YiYi Notes: model specific components: -## (1) it should inherit from ModularLoader -## (2) acts like a container that holds components and configs -## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents -## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) -## (5) how to use together with Components_manager? -class StableDiffusionXLModularLoader( - ModularLoader, - StableDiffusionMixin, - TextualInversionLoaderMixin, - StableDiffusionXLLoraLoaderMixin, - ModularIPAdapterMixin, -): - @property - def default_sample_size(self): - default_sample_size = 128 - if hasattr(self, "unet") and self.unet is not None: - default_sample_size = self.unet.config.sample_size - return default_sample_size - - @property - def vae_scale_factor(self): - vae_scale_factor = 8 - if hasattr(self, "vae") and self.vae is not None: - vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - return vae_scale_factor - - @property - def num_channels_unet(self): - num_channels_unet = 4 - if hasattr(self, "unet") and self.unet is not None: - num_channels_unet = self.unet.config.in_channels - return num_channels_unet - - @property - def num_channels_latents(self): - num_channels_latents = 4 - if hasattr(self, "vae") and self.vae is not None: - num_channels_latents = self.vae.config.latent_channels - return num_channels_latents - - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - r""" - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - - -class StableDiffusionXLIPAdapterStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - - @property - def description(self) -> str: - return ( - "IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc" - " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" - " for more details" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("image_encoder", CLIPVisionModelWithProjection), - ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), - ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ] - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "ip_adapter_image", - PipelineImageInput, - required=True, - description="The image(s) to be used as ip adapter" - ) - ] - - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), - OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") - ] - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components - @staticmethod - def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(components.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = components.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = components.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = components.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - - # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds - def prepare_ip_adapter_image_embeds( - self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds - ): - image_embeds = [] - if prepare_unconditional_embeds: - negative_image_embeds = [] - if ip_adapter_image_embeds is None: - if not isinstance(ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] - - if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): - raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." - ) - - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers - ): - output_hidden_state = not isinstance(image_proj_layer, ImageProjection) - single_image_embeds, single_negative_image_embeds = self.encode_image( - components, single_ip_adapter_image, device, 1, output_hidden_state - ) - - image_embeds.append(single_image_embeds[None, :]) - if prepare_unconditional_embeds: - negative_image_embeds.append(single_negative_image_embeds[None, :]) - else: - for single_image_embeds in ip_adapter_image_embeds: - if prepare_unconditional_embeds: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) - image_embeds.append(single_image_embeds) - - ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): - single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if prepare_unconditional_embeds: - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - - single_image_embeds = single_image_embeds.to(device=device) - ip_adapter_image_embeds.append(single_image_embeds) - - return ip_adapter_image_embeds - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 - block_state.device = components._execution_device - - block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( - components, - ip_adapter_image=block_state.ip_adapter_image, - ip_adapter_image_embeds=None, - device=block_state.device, - num_images_per_prompt=1, - prepare_unconditional_embeds=block_state.prepare_unconditional_embeds, - ) - if block_state.prepare_unconditional_embeds: - block_state.negative_ip_adapter_embeds = [] - for i, image_embeds in enumerate(block_state.ip_adapter_embeds): - negative_image_embeds, image_embeds = image_embeds.chunk(2) - block_state.negative_ip_adapter_embeds.append(negative_image_embeds) - block_state.ip_adapter_embeds[i] = image_embeds - - self.add_block_state(state, block_state) - return components, state - - -class StableDiffusionXLTextEncoderStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return( - "Text Encoder step that generate text_embeddings to guide the image generation" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("text_encoder", CLIPTextModel), - ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), - ComponentSpec("tokenizer", CLIPTokenizer), - ComponentSpec("tokenizer_2", CLIPTokenizer), - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ] - - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ConfigSpec("force_zeros_for_empty_prompt", True)] - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("prompt"), - InputParam("prompt_2"), - InputParam("negative_prompt"), - InputParam("negative_prompt_2"), - InputParam("cross_attention_kwargs"), - InputParam("clip_skip"), - ] - - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields",description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), - ] - - @staticmethod - def check_inputs(block_state): - - if block_state.prompt is not None and (not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") - elif block_state.prompt_2 is not None and (not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}") - - @staticmethod - def encode_prompt( - components, - prompt: str, - prompt_2: Optional[str] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prepare_unconditional_embeds: bool = True, - negative_prompt: Optional[str] = None, - negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - pooled_prompt_embeds: Optional[torch.Tensor] = None, - negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in both text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prepare_unconditional_embeds (`bool`): - whether to use prepare unconditional embeddings or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - device = device or components._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin): - components._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if components.text_encoder is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(components.text_encoder, lora_scale) - else: - scale_lora_layers(components.text_encoder, lora_scale) - - if components.text_encoder_2 is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale) - else: - scale_lora_layers(components.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizers and text encoders - tokenizers = [components.tokenizer, components.tokenizer_2] if components.tokenizer is not None else [components.tokenizer_2] - text_encoders = ( - [components.text_encoder, components.text_encoder_2] if components.text_encoder is not None else [components.text_encoder_2] - ) - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # textual inversion: process multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - if isinstance(components, TextualInversionLoaderMixin): - prompt = components.maybe_convert_prompt(prompt, tokenizer) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] - else: - # "2" because SDXL always indexes from the penultimate layer. - prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt - if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif prepare_unconditional_embeds and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt - - # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 - ) - - uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] - - negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): - if isinstance(components, TextualInversionLoaderMixin): - negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - negative_prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), - output_hidden_states=True, - ) - # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] - - negative_prompt_embeds_list.append(negative_prompt_embeds) - - negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - - if components.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) - else: - prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - if prepare_unconditional_embeds: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - if components.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) - else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - if prepare_unconditional_embeds: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - - if components.text_encoder is not None: - if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(components.text_encoder, lora_scale) - - if components.text_encoder_2 is not None: - if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(components.text_encoder_2, lora_scale) - - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - # Get inputs and intermediates - block_state = self.get_block_state(state) - self.check_inputs(block_state) - - block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 - block_state.device = components._execution_device - - # Encode input prompt - block_state.text_encoder_lora_scale = ( - block_state.cross_attention_kwargs.get("scale", None) if block_state.cross_attention_kwargs is not None else None - ) - ( - block_state.prompt_embeds, - block_state.negative_prompt_embeds, - block_state.pooled_prompt_embeds, - block_state.negative_pooled_prompt_embeds, - ) = self.encode_prompt( - components, - block_state.prompt, - block_state.prompt_2, - block_state.device, - 1, - block_state.prepare_unconditional_embeds, - block_state.negative_prompt, - block_state.negative_prompt_2, - prompt_embeds=None, - negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, - lora_scale=block_state.text_encoder_lora_scale, - clip_skip=block_state.clip_skip, - ) - # Add outputs - self.add_block_state(state, block_state) - return components, state - - -class StableDiffusionXLVaeEncoderStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - - @property - def description(self) -> str: - return ( - "Vae Encoder step that encode the input image into a latent representation" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), - default_creation_method="from_config"), - ] - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("image", required=True), - InputParam("generator"), - InputParam("height"), - InputParam("width"), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")] - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components - # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std - else: - image_latents = components.vae.config.scaling_factor * image_latents - - return image_latents - - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} - block_state.device = components._execution_device - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - - block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs) - block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) - - block_state.batch_size = block_state.image.shape[0] - - # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) - if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" - f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." - ) - - - block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), - default_creation_method="from_config"), - ComponentSpec( - "mask_processor", - VaeImageProcessor, - config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}), - default_creation_method="from_config"), - ] - - - @property - def description(self) -> str: - return ( - "Vae encoder step that prepares the image and mask for the inpainting process" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("height"), - InputParam("width"), - InputParam("generator"), - InputParam("image", required=True), - InputParam("mask_image", required=True), - InputParam("padding_mask_crop"), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), - OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components - # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std - else: - image_latents = components.vae.config.scaling_factor * image_latents - - return image_latents - - # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents - # do not accept do_classifier_free_guidance - def prepare_mask_latents( - self, components, mask, masked_image, batch_size, height, width, dtype, device, generator - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - - if masked_image is not None and masked_image.shape[1] == 4: - masked_image_latents = masked_image - else: - masked_image_latents = None - - if masked_image is not None: - if masked_image_latents is None: - masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) - - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - - return mask, masked_image_latents - - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - - block_state = self.get_block_state(state) - - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.device = components._execution_device - - if block_state.padding_mask_crop is not None: - block_state.crops_coords = components.mask_processor.get_crop_region(block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop) - block_state.resize_mode = "fill" - else: - block_state.crops_coords = None - block_state.resize_mode = "default" - - block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode) - block_state.image = block_state.image.to(dtype=torch.float32) - - block_state.mask = components.mask_processor.preprocess(block_state.mask_image, height=block_state.height, width=block_state.width, resize_mode=block_state.resize_mode, crops_coords=block_state.crops_coords) - block_state.masked_image = block_state.image * (block_state.mask < 0.5) - - block_state.batch_size = block_state.image.shape[0] - block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) - block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) - - # 7. Prepare mask latent variables - block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( - components, - block_state.mask, - block_state.masked_image, - block_state.batch_size, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, - ) - - self.add_block_state(state, block_state) - - - return components, state - - -class StableDiffusionXLInputStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Input processing step that:\n" - " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" - " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" - "All input tensors are expected to have either batch_size=1 or match the batch_size\n" - "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" - "have a final batch_size of batch_size * num_images_per_prompt." - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated text embeddings. Can be generated from text_encoder step."), - InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings. Can be generated from text_encoder step."), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated pooled text embeddings. Can be generated from text_encoder step."), - InputParam("negative_pooled_prompt_embeds", description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step."), - InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step."), - InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."), - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [ - OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), - OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)"), - OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), - OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="image embeddings for IP-Adapter"), - OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="negative image embeddings for IP-Adapter"), - ] - - def check_inputs(self, components, block_state): - - if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: - if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" - f" {block_state.negative_prompt_embeds.shape}." - ) - - if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if block_state.negative_prompt_embeds is not None and block_state.negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) - - if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list): - raise ValueError("`ip_adapter_embeds` must be a list") - - if block_state.negative_ip_adapter_embeds is not None and not isinstance(block_state.negative_ip_adapter_embeds, list): - raise ValueError("`negative_ip_adapter_embeds` must be a list") - - if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): - if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape: - raise ValueError( - "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" - f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" - f" {block_state.negative_ip_adapter_embeds[i].shape}." - ) - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - self.check_inputs(components, block_state) - - block_state.batch_size = block_state.prompt_embeds.shape[0] - block_state.dtype = block_state.prompt_embeds.dtype - - _, seq_len, _ = block_state.prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) - - if block_state.negative_prompt_embeds is not None: - _, seq_len, _ = block_state.negative_prompt_embeds.shape - block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) - - block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - - if block_state.negative_pooled_prompt_embeds is not None: - block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - - if block_state.ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): - block_state.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) - - if block_state.negative_ip_adapter_embeds is not None: - for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds): - block_state.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Step that sets the timesteps for the scheduler and determines the initial noise level (latent_timestep) for image-to-image/inpainting generation.\n" + \ - "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("num_inference_steps", default=50), - InputParam("timesteps"), - InputParam("sigmas"), - InputParam("denoising_end"), - InputParam("strength", default=0.3), - InputParam("denoising_start"), - # YiYi TODO: do we need num_images_per_prompt here? - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [ - OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), - OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") - ] - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self -> components - def get_timesteps(self, components, num_inference_steps, strength, device, denoising_start=None): - # get the original timestep using init_timestep - if denoising_start is None: - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - t_start = max(num_inference_steps - init_timestep, 0) - - timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] - if hasattr(components.scheduler, "set_begin_index"): - components.scheduler.set_begin_index(t_start * components.scheduler.order) - - return timesteps, num_inference_steps - t_start - - else: - # Strength is irrelevant if we directly request a timestep to start at; - # that is, strength is determined by the denoising_start instead. - discrete_timestep_cutoff = int( - round( - components.scheduler.config.num_train_timesteps - - (denoising_start * components.scheduler.config.num_train_timesteps) - ) - ) - - num_inference_steps = (components.scheduler.timesteps < discrete_timestep_cutoff).sum().item() - if components.scheduler.order == 2 and num_inference_steps % 2 == 0: - # if the scheduler is a 2nd order scheduler we might have to do +1 - # because `num_inference_steps` might be even given that every timestep - # (except the highest one) is duplicated. If `num_inference_steps` is even it would - # mean that we cut the timesteps in the middle of the denoising step - # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 - # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler - num_inference_steps = num_inference_steps + 1 - - # because t_n+1 >= t_n, we slice the timesteps starting from the end - t_start = len(components.scheduler.timesteps) - num_inference_steps - timesteps = components.scheduler.timesteps[t_start:] - if hasattr(components.scheduler, "set_begin_index"): - components.scheduler.set_begin_index(t_start) - return timesteps, num_inference_steps - - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.device = components._execution_device - - block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( - components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas - ) - - def denoising_value_valid(dnv): - return isinstance(dnv, float) and 0 < dnv < 1 - - block_state.timesteps, block_state.num_inference_steps = self.get_timesteps( - components, - block_state.num_inference_steps, - block_state.strength, - block_state.device, - denoising_start=block_state.denoising_start if denoising_value_valid(block_state.denoising_start) else None, - ) - block_state.latent_timestep = block_state.timesteps[:1].repeat(block_state.batch_size * block_state.num_images_per_prompt) - - if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: - block_state.discrete_timestep_cutoff = int( - round( - components.scheduler.config.num_train_timesteps - - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) - ) - ) - block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) - block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLSetTimestepsStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Step that sets the scheduler's timesteps for inference" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("num_inference_steps", default=50), - InputParam("timesteps"), - InputParam("sigmas"), - InputParam("denoising_end"), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")] - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.device = components._execution_device - - block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( - components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas - ) - - if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: - block_state.discrete_timestep_cutoff = int( - round( - components.scheduler.config.num_train_timesteps - - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) - ) - ) - block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) - block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] - - self.add_block_state(state, block_state) - return components, state - - -class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Step that prepares the latents for the inpainting process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("generator"), - InputParam("latents"), - InputParam("num_images_per_prompt", default=1), - InputParam("denoising_start"), - InputParam( - "strength", - default=0.9999, - description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " - "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " - "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " - "be maximum and the denoising process will run for the full number of iterations specified in " - "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " - "`denoising_start` being declared as an integer, the value of `strength` will be ignored." - ), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "latent_timestep", - required=True, - type_hint=torch.Tensor, - description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step." - ), - InputParam( - "image_latents", - required=True, - type_hint=torch.Tensor, - description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step." - ), - InputParam( - "mask", - required=True, - type_hint=torch.Tensor, - description="The mask for the inpainting generation. Can be generated in vae_encode step." - ), - InputParam( - "masked_image_latents", - type_hint=torch.Tensor, - description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step." - ), - InputParam( - "dtype", - type_hint=torch.dtype, - description="The dtype of the model inputs" - ) - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), - OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] - - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components - # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - @staticmethod - def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): - - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std - else: - image_latents = components.vae.config.scaling_factor * image_latents - - return image_latents - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument - def prepare_latents_inpaint( - self, - components, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - image=None, - timestep=None, - is_strength_max=True, - add_noise=True, - return_noise=False, - return_image_latents=False, - ): - shape = ( - batch_size, - num_channels_latents, - int(height) // components.vae_scale_factor, - int(width) // components.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if (image is None or timestep is None) and not is_strength_max: - raise ValueError( - "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." - "However, either the image or the noise timestep has not been provided." - ) - - if image.shape[1] == 4: - image_latents = image.to(device=device, dtype=dtype) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) - elif return_image_latents or (latents is None and not is_strength_max): - image = image.to(device=device, dtype=dtype) - image_latents = self._encode_vae_image(components, image=image, generator=generator) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) - - if latents is None and add_noise: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # if strength is 1. then initialise the latents to noise, else initial to image + noise - latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep) - # if pure noise then scale the initial latents by the Scheduler's init sigma - latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents - elif add_noise: - noise = latents.to(device) - latents = noise * components.scheduler.init_noise_sigma - else: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = image_latents.to(device) - - outputs = (latents,) - - if return_noise: - outputs += (noise,) - - if return_image_latents: - outputs += (image_latents,) - - return outputs - - # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents - # do not accept do_classifier_free_guidance - def prepare_mask_latents( - self, components, mask, masked_image, batch_size, height, width, dtype, device, generator - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - - if masked_image is not None and masked_image.shape[1] == 4: - masked_image_latents = masked_image - else: - masked_image_latents = None - - if masked_image is not None: - if masked_image_latents is None: - masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) - - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - - return mask, masked_image_latents - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.device = components._execution_device - - block_state.is_strength_max = block_state.strength == 1.0 - - # for non-inpainting specific unet, we do not need masked_image_latents - if hasattr(components,"unet") and components.unet is not None: - if components.unet.config.in_channels == 4: - block_state.masked_image_latents = None - - block_state.add_noise = True if block_state.denoising_start is None else False - - block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor - block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor - - block_state.latents, block_state.noise = self.prepare_latents_inpaint( - components, - block_state.batch_size * block_state.num_images_per_prompt, - components.num_channels_latents, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, - block_state.latents, - image=block_state.image_latents, - timestep=block_state.latent_timestep, - is_strength_max=block_state.is_strength_max, - add_noise=block_state.add_noise, - return_noise=True, - return_image_latents=False, - ) - - # 7. Prepare mask latent variables - block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( - components, - block_state.mask, - block_state.masked_image_latents, - block_state.batch_size * block_state.num_images_per_prompt, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, - ) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Step that prepares the latents for the image-to-image generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("generator"), - InputParam("latents"), - InputParam("num_images_per_prompt", default=1), - InputParam("denoising_start"), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), - InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), - InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self -> components - # YiYi TODO: refactor using _encode_vae_image - @staticmethod - def prepare_latents_img2img( - components, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - image = image.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - init_latents = image - - else: - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - # make sure the VAE is in float32 mode, as it overflows in float16 - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - elif isinstance(generator, list): - if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: - image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) - elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " - ) - - init_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - init_latents = init_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=device, dtype=dtype) - latents_std = latents_std.to(device=device, dtype=dtype) - init_latents = (init_latents - latents_mean) * components.vae.config.scaling_factor / latents_std - else: - init_latents = components.vae.config.scaling_factor * init_latents - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." - ) - else: - init_latents = torch.cat([init_latents], dim=0) - - if add_noise: - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # get latents - init_latents = components.scheduler.add_noise(init_latents, noise, timestep) - - latents = init_latents - - return latents - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.device = components._execution_device - block_state.add_noise = True if block_state.denoising_start is None else False - if block_state.latents is None: - block_state.latents = self.prepare_latents_img2img( - components, - block_state.image_latents, - block_state.latent_timestep, - block_state.batch_size, - block_state.num_images_per_prompt, - block_state.dtype, - block_state.device, - block_state.generator, - block_state.add_noise, - ) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLPrepareLatentsStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Prepare latents step that prepares the latents for the text-to-image generation process" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("height"), - InputParam("width"), - InputParam("generator"), - InputParam("latents"), - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "dtype", - type_hint=torch.dtype, - description="The dtype of the model inputs" - ) - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam( - "latents", - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process" - ) - ] - - - @staticmethod - def check_inputs(components, block_state): - if ( - block_state.height is not None - and block_state.height % components.vae_scale_factor != 0 - or block_state.width is not None - and block_state.width % components.vae_scale_factor != 0 - ): - raise ValueError( - f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self -> components - @staticmethod - def prepare_latents(components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = ( - batch_size, - num_channels_latents, - int(height) // components.vae_scale_factor, - int(width) // components.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * components.scheduler.init_noise_sigma - return latents - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - if block_state.dtype is None: - block_state.dtype = components.vae.dtype - - block_state.device = components._execution_device - - self.check_inputs(components, block_state) - - block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor - block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor - block_state.num_channels_latents = components.num_channels_latents - block_state.latents = self.prepare_latents( - components, - block_state.batch_size * block_state.num_images_per_prompt, - block_state.num_channels_latents, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, - block_state.latents, - ) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ConfigSpec("requires_aesthetics_score", False),] - - @property - def description(self) -> str: - return ( - "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("original_size"), - InputParam("target_size"), - InputParam("negative_original_size"), - InputParam("negative_target_size"), - InputParam("crops_coords_top_left", default=(0, 0)), - InputParam("negative_crops_coords_top_left", default=(0, 0)), - InputParam("num_images_per_prompt", default=1), - InputParam("aesthetic_score", default=6.0), - InputParam("negative_aesthetic_score", default=2.0), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."), - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), - OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components - @staticmethod - def _get_add_time_ids_img2img( - components, - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype, - text_encoder_projection_dim=None, - ): - if components.config.requires_aesthetics_score: - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list( - negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) - ) - else: - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) - - passed_add_embed_dim = ( - components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features - - if ( - expected_add_embed_dim > passed_add_embed_dim - and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." - ) - elif ( - expected_add_embed_dim < passed_add_embed_dim - and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." - ) - elif expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." - ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - - return add_time_ids, add_neg_time_ids - - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - @staticmethod - def get_guidance_scale_embedding( - w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. - - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. - """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - block_state.device = components._execution_device - - block_state.vae_scale_factor = components.vae_scale_factor - - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * block_state.vae_scale_factor - block_state.width = block_state.width * block_state.vae_scale_factor - - block_state.original_size = block_state.original_size or (block_state.height, block_state.width) - block_state.target_size = block_state.target_size or (block_state.height, block_state.width) - - block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) - - if block_state.negative_original_size is None: - block_state.negative_original_size = block_state.original_size - if block_state.negative_target_size is None: - block_state.negative_target_size = block_state.target_size - - block_state.add_time_ids, block_state.negative_add_time_ids = self._get_add_time_ids_img2img( - components, - block_state.original_size, - block_state.crops_coords_top_left, - block_state.target_size, - block_state.aesthetic_score, - block_state.negative_aesthetic_score, - block_state.negative_original_size, - block_state.negative_crops_coords_top_left, - block_state.negative_target_size, - dtype=block_state.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=block_state.text_encoder_projection_dim, - ) - block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) - block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) - - # Optionally get Guidance Scale Embedding for LCM - block_state.timestep_cond = None - if ( - hasattr(components, "unet") - and components.unet is not None - and components.unet.config.time_cond_proj_dim is not None - ): - # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! - block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) - block_state.timestep_cond = self.get_guidance_scale_embedding( - block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim - ).to(device=block_state.device, dtype=block_state.latents.dtype) - - self.add_block_state(state, block_state) - return components, state - - -class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that prepares the additional conditioning for the text-to-image generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("original_size"), - InputParam("target_size"), - InputParam("negative_original_size"), - InputParam("negative_target_size"), - InputParam("crops_coords_top_left", default=(0, 0)), - InputParam("negative_crops_coords_top_left", default=(0, 0)), - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), - OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components - @staticmethod - def _get_add_time_ids( - components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None - ): - add_time_ids = list(original_size + crops_coords_top_left + target_size) - - passed_add_embed_dim = ( - components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features - - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." - ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids - - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - @staticmethod - def get_guidance_scale_embedding( - w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. - - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. - """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - block_state.device = components._execution_device - - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * components.vae_scale_factor - block_state.width = block_state.width * components.vae_scale_factor - - block_state.original_size = block_state.original_size or (block_state.height, block_state.width) - block_state.target_size = block_state.target_size or (block_state.height, block_state.width) - - block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) - - block_state.add_time_ids = self._get_add_time_ids( - components, - block_state.original_size, - block_state.crops_coords_top_left, - block_state.target_size, - block_state.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=block_state.text_encoder_projection_dim, - ) - if block_state.negative_original_size is not None and block_state.negative_target_size is not None: - block_state.negative_add_time_ids = self._get_add_time_ids( - components, - block_state.negative_original_size, - block_state.negative_crops_coords_top_left, - block_state.negative_target_size, - block_state.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=block_state.text_encoder_projection_dim, - ) - else: - block_state.negative_add_time_ids = block_state.add_time_ids - - block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) - block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) - - # Optionally get Guidance Scale Embedding for LCM - block_state.timestep_cond = None - if ( - hasattr(components, "unet") - and components.unet is not None - and components.unet.config.time_cond_proj_dim is not None - ): - # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! - block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) - block_state.timestep_cond = self.get_guidance_scale_embedding( - block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim - ).to(device=block_state.device, dtype=block_state.latents.dtype) - - self.add_block_state(state, block_state) - return components, state - -class StableDiffusionXLControlNetInputStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("controlnet", ControlNetModel), - ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), - ] - - @property - def description(self) -> str: - return "step that prepare inputs for controlnet" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("control_image", required=True), - InputParam("control_guidance_start", default=0.0), - InputParam("control_guidance_end", default=1.0), - InputParam("controlnet_conditioning_scale", default=1.0), - InputParam("guess_mode", default=False), - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], - description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"), - OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"), - OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"), - OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), - OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), - OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), - ] - - - - # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image - # 1. return image without apply any guidance - # 2. add crops_coords and resize_mode to preprocess() - @staticmethod - def prepare_control_image( - components, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - crops_coords=None, - ): - if crops_coords is not None: - image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) - else: - image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - - image_batch_size = image.shape[0] - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - image = image.to(device=device, dtype=dtype) - return image - - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - - block_state = self.get_block_state(state) - - # (1) prepare controlnet inputs - block_state.device = components._execution_device - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * components.vae_scale_factor - block_state.width = block_state.width * components.vae_scale_factor - - controlnet = unwrap_module(components.controlnet) - - # (1.1) - # control_guidance_start/control_guidance_end (align format) - if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): - block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] - elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): - block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] - elif not isinstance(block_state.control_guidance_start, list) and not isinstance(block_state.control_guidance_end, list): - mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - block_state.control_guidance_start, block_state.control_guidance_end = ( - mult * [block_state.control_guidance_start], - mult * [block_state.control_guidance_end], - ) - - # (1.2) - # controlnet_conditioning_scale (align format) - if isinstance(controlnet, MultiControlNetModel) and isinstance(block_state.controlnet_conditioning_scale, float): - block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets) - - # (1.3) - # global_pool_conditions - block_state.global_pool_conditions = ( - controlnet.config.global_pool_conditions - if isinstance(controlnet, ControlNetModel) - else controlnet.nets[0].config.global_pool_conditions - ) - # (1.4) - # guess_mode - block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions - - # (1.5) - # control_image - if isinstance(controlnet, ControlNetModel): - block_state.control_image = self.prepare_control_image( - components, - image=block_state.control_image, - width=block_state.width, - height=block_state.height, - batch_size=block_state.batch_size * block_state.num_images_per_prompt, - num_images_per_prompt=block_state.num_images_per_prompt, - device=block_state.device, - dtype=controlnet.dtype, - crops_coords=block_state.crops_coords, - ) - elif isinstance(controlnet, MultiControlNetModel): - control_images = [] - - for control_image_ in block_state.control_image: - control_image = self.prepare_control_image( - components, - image=control_image_, - width=block_state.width, - height=block_state.height, - batch_size=block_state.batch_size * block_state.num_images_per_prompt, - num_images_per_prompt=block_state.num_images_per_prompt, - device=block_state.device, - dtype=controlnet.dtype, - crops_coords=block_state.crops_coords, - ) - - control_images.append(control_image) - - block_state.control_image = control_images - else: - assert False - - # (1.6) - # controlnet_keep - block_state.controlnet_keep = [] - for i in range(len(block_state.timesteps)): - keeps = [ - 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) - for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) - ] - block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - - block_state.controlnet_cond = block_state.control_image - block_state.conditioning_scale = block_state.controlnet_conditioning_scale - - - - self.add_block_state(state, block_state) - - return components, state - -class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("controlnet", ControlNetUnionModel), - ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), - ] - - @property - def description(self) -> str: - return "step that prepares inputs for the ControlNetUnion model" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("control_image", required=True), - InputParam("control_mode", required=True), - InputParam("control_guidance_start", default=0.0), - InputParam("control_guidance_end", default=1.0), - InputParam("controlnet_conditioning_scale", default=1.0), - InputParam("guess_mode", default=False), - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "dtype", - required=True, - type_hint=torch.dtype, - description="The dtype of model tensor inputs. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step." - ), - InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], - description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"), - OutputParam("control_type_idx", type_hint=List[int], description="The control mode indices", kwargs_type="controlnet_kwargs"), - OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active", kwargs_type="controlnet_kwargs"), - OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), - OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), - OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), - OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), - OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), - ] - - # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image - # 1. return image without apply any guidance - # 2. add crops_coords and resize_mode to preprocess() - @staticmethod - def prepare_control_image( - components, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - crops_coords=None, - ): - if crops_coords is not None: - image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) - else: - image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - - image_batch_size = image.shape[0] - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - image = image.to(device=device, dtype=dtype) - return image - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - - block_state = self.get_block_state(state) - - controlnet = unwrap_module(components.controlnet) - - device = components._execution_device - dtype = block_state.dtype or components.controlnet.dtype - - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * components.vae_scale_factor - block_state.width = block_state.width * components.vae_scale_factor - - - # control_guidance_start/control_guidance_end (align format) - if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): - block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] - elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): - block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] - - # guess_mode - block_state.global_pool_conditions = controlnet.config.global_pool_conditions - block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions - - # control_image - if not isinstance(block_state.control_image, list): - block_state.control_image = [block_state.control_image] - # control_mode - if not isinstance(block_state.control_mode, list): - block_state.control_mode = [block_state.control_mode] - - if len(block_state.control_image) != len(block_state.control_mode): - raise ValueError("Expected len(control_image) == len(control_type)") - - # control_type - block_state.num_control_type = controlnet.config.num_control_type - block_state.control_type = [0 for _ in range(block_state.num_control_type)] - for control_idx in block_state.control_mode: - block_state.control_type[control_idx] = 1 - block_state.control_type = torch.Tensor(block_state.control_type) - - block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype) - repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0] - block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) - - # prepare control_image - for idx, _ in enumerate(block_state.control_image): - block_state.control_image[idx] = self.prepare_control_image( - components, - image=block_state.control_image[idx], - width=block_state.width, - height=block_state.height, - batch_size=block_state.batch_size * block_state.num_images_per_prompt, - num_images_per_prompt=block_state.num_images_per_prompt, - device=device, - dtype=dtype, - crops_coords=block_state.crops_coords, - ) - block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] - - # controlnet_keep - block_state.controlnet_keep = [] - for i in range(len(block_state.timesteps)): - block_state.controlnet_keep.append( - 1.0 - - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end) - ) - block_state.control_type_idx = block_state.control_mode - block_state.controlnet_cond = block_state.control_image - block_state.conditioning_scale = block_state.controlnet_conditioning_scale - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLControlNetAutoInput(AutoPipelineBlocks): - - block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] - block_names = ["controlnet_union", "controlnet"] - block_trigger_inputs = ["control_mode", "control_image"] - - -class StableDiffusionXLDecodeLatentsStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), - default_creation_method="from_config"), - ] - - @property - def description(self) -> str: - return "Step that decodes the denoised latents into images" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("output_type", default="pil"), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step")] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")] - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components - @staticmethod - def upcast_vae(components): - dtype = components.vae.dtype - components.vae.to(dtype=torch.float32) - use_torch_2_0_or_xformers = isinstance( - components.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - components.vae.post_quant_conv.to(dtype) - components.vae.decoder.conv_in.to(dtype) - components.vae.decoder.mid_block.to(dtype) - - @torch.no_grad() - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - if not block_state.output_type == "latent": - # make sure the VAE is in float32 mode, as it overflows in float16 - block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast - - if block_state.needs_upcasting: - self.upcast_vae(components) - block_state.latents = block_state.latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) - elif block_state.latents.dtype != components.vae.dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - components.vae = components.vae.to(block_state.latents.dtype) - - # unscale/denormalize the latents - # denormalize with the mean and std if available and not None - block_state.has_latents_mean = ( - hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None - ) - block_state.has_latents_std = ( - hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None - ) - if block_state.has_latents_mean and block_state.has_latents_std: - block_state.latents_mean = ( - torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) - ) - block_state.latents_std = ( - torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) - ) - block_state.latents = block_state.latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean - else: - block_state.latents = block_state.latents / components.vae.config.scaling_factor - - block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0] - - # cast back to fp16 if needed - if block_state.needs_upcasting: - components.vae.to(dtype=torch.float16) - else: - block_state.images = block_state.latents - - # apply watermark if available - if hasattr(components, "watermark") and components.watermark is not None: - block_state.images = components.watermark.apply_watermark(block_state.images) - - block_state.images = components.image_processor.postprocess(block_state.images, output_type=block_state.output_type) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return "A post-processing step that overlays the mask on the image (inpainting task only).\n" + \ - "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("image", required=True), - InputParam("mask_image", required=True), - InputParam("padding_mask_crop"), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step"), - InputParam("crops_coords", required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.") - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images with the mask overlayed")] - - @torch.no_grad() - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - if block_state.padding_mask_crop is not None and block_state.crops_coords is not None: - block_state.images = [components.image_processor.apply_overlay(block_state.mask_image, block_state.image, i, block_state.crops_coords) for i in block_state.images] - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLOutputStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return "final step to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [InputParam("return_dict", default=True)] - - @property - def intermediates_inputs(self) -> List[str]: - return [InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step.")] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", description="The final images output, can be a tuple or a `StableDiffusionXLPipelineOutput`")] - - - @torch.no_grad() - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - if not block_state.return_dict: - block_state.images = (block_state.images,) - else: - block_state.images = StableDiffusionXLPipelineOutput(images=block_state.images) - self.add_block_state(state, block_state) - return components, state - - -# Encode -class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] - block_names = ["inpaint", "img2img"] - block_trigger_inputs = ["mask_image", "image"] - - @property - def description(self): - return "Vae encoder step that encode the image inputs into their latent representations.\n" + \ - "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + \ - " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + \ - " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided." - - -# Before denoise -class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] - - @property - def description(self): - return "Before denoise step that prepare the inputs for the denoise step.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ - " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + \ - " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ - " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" - - -class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] - - @property - def description(self): - return "Before denoise step that prepare the inputs for the denoise step for img2img task.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ - " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ - " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ - " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" - - -class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] - - @property - def description(self): - return "Before denoise step that prepare the inputs for the denoise step for inpainting task.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ - " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ - " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ - " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" - - -class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep] - block_names = ["inpaint", "img2img", "text2img"] - block_trigger_inputs = ["mask", "image_latents", None] - - @property - def description(self): - return "Before denoise step that prepare the inputs for the denoise step.\n" + \ - "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n" + \ - " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + \ - " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \ - " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided.\n" + \ - " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + \ - " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." - -# # Denoise -from .pipeline_stable_diffusion_xl_modular_denoise_loop import StableDiffusionXLDenoiseStep, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLAutoDenoiseStep -# class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): -# block_classes = [StableDiffusionXLControlNetUnionStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep] -# block_names = ["controlnet_union", "controlnet", "unet"] -# block_trigger_inputs = ["control_mode", "control_image", None] - -# @property -# def description(self): -# return "Denoise step that denoise the latents.\n" + \ -# "This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \ -# " - `StableDiffusionXLControlNetUnionStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \ -# " - `StableDiffusionXLControlNetStep` (controlnet) is used when `control_image` is provided.\n" + \ -# " - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided." - -# After denoise -class StableDiffusionXLDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep] - block_names = ["decode", "output"] - - @property - def description(self): - return """Decode step that decode the denoised latents into images outputs. -This is a sequential pipeline blocks: - - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images - - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple.""" - - -class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep] - block_names = ["decode", "mask_overlay", "output"] - - @property - def description(self): - return "Inpaint decode step that decode the denoised latents into images outputs.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images\n" + \ - " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image\n" + \ - " - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." - - -class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] - block_names = ["inpaint", "non-inpaint"] - block_trigger_inputs = ["padding_mask_crop", None] - - @property - def description(self): - return "Decode step that decode the denoised latents into images outputs.\n" + \ - "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \ - " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \ - " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." - - -class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks, ModularIPAdapterMixin): - block_classes = [StableDiffusionXLIPAdapterStep] - block_names = ["ip_adapter"] - block_trigger_inputs = ["ip_adapter_image"] - - @property - def description(self): - return "Run IP Adapter step if `ip_adapter_image` is provided." - - -class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] - block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decode"] - - @property - def description(self): - return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \ - "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \ - "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + \ - "- to run the controlnet workflow, you need to provide `control_image`\n" + \ - "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \ - "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \ - "- for text-to-image generation, all you need to provide is `prompt`" - -# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that -# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by -# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the -# configuration of guider is. - - -# block mapping -TEXT2IMAGE_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLSetTimestepsStep), - ("prepare_latents", StableDiffusionXLPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLDecodeStep) -]) - -IMAGE2IMAGE_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLDecodeStep) -]) - -INPAINT_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLInpaintDecodeStep) -]) - -CONTROLNET_BLOCKS = OrderedDict([ - ("controlnet_input", StableDiffusionXLControlNetInputStep), - ("denoise", StableDiffusionXLControlNetDenoiseStep), -]) - -CONTROLNET_UNION_BLOCKS = OrderedDict([ - ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), - ("denoise", StableDiffusionXLControlNetDenoiseStep), -]) - -IP_ADAPTER_BLOCKS = OrderedDict([ - ("ip_adapter", StableDiffusionXLIPAdapterStep), -]) - -AUTO_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), - ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), - ("decode", StableDiffusionXLAutoDecodeStep) -]) - -AUTO_CORE_BLOCKS = OrderedDict([ - ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), -]) - - -SDXL_SUPPORTED_BLOCKS = { - "text2img": TEXT2IMAGE_BLOCKS, - "img2img": IMAGE2IMAGE_BLOCKS, - "inpaint": INPAINT_BLOCKS, - "controlnet": CONTROLNET_BLOCKS, - "controlnet_union": CONTROLNET_UNION_BLOCKS, - "ip_adapter": IP_ADAPTER_BLOCKS, - "auto": AUTO_BLOCKS -} - - - -# YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks -SDXL_INPUTS_SCHEMA = { - "prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"), - "prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"), - "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), - "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), - "cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), - "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), - "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), - "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), - "generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"), - "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), - "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), - "num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"), - "num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"), - "timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"), - "sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"), - "denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"), - # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 - "strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"), - "denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"), - "latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"), - "padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"), - "original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"), - "target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"), - "negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"), - "negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"), - "crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"), - "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), - "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), - "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), - "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), - "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), - "return_dict": InputParam("return_dict", type_hint=bool, default=True, description="Whether to return a StableDiffusionXLPipelineOutput"), - "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), - "control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"), - "control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"), - "control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"), - "controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"), - "guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"), - "control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet") -} - - -SDXL_INTERMEDIATE_INPUTS_SCHEMA = { - "prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"), - "negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), - "pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"), - "negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), - "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), - "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - "preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"), - "latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"), - "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), - "num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"), - "latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"), - "image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"), - "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), - "masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), - "add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"), - "negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), - "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), - "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), - "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), - "ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), - "negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), - "images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images") -} - - -SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = { - "prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"), - "negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), - "pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"), - "negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), - "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"), - "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - "image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"), - "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"), - "masked_image_latents": OutputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), - "crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), - "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"), - "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"), - "latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"), - "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"), - "negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), - "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), - "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"), - "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), - "ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), - "negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), - "images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images") -} - - -SDXL_OUTPUTS_SCHEMA = { - "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") -} diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py deleted file mode 100644 index 63d0784a5762..000000000000 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py +++ /dev/null @@ -1,1363 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -from tqdm.auto import tqdm - -from ...configuration_utils import FrozenDict -from ...models import ControlNetModel, UNet2DConditionModel -from ...schedulers import EulerDiscreteScheduler -from ...utils import logging -from ...utils.torch_utils import unwrap_module -from ..modular_pipeline import ( - PipelineBlock, - PipelineState, - AutoPipelineBlocks, - LoopSequentialPipelineBlocks, - InputParam, - OutputParam, - BlockState, - ComponentSpec, -) -from ...guiders import ClassifierFreeGuidance -from .pipeline_stable_diffusion_xl_modular import StableDiffusionXLModularLoader -from dataclasses import asdict - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - - -# YiYi experimenting composible denoise loop -# loop step (1): prepare latent input for denoiser -class StableDiffusionXLDenoiseLoopBeforeDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return "step within the denoising loop that prepare the latent input for the denoiser" - - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")] - - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - - - return components, block_state - -# loop step (1): prepare latent input for denoiser (with inpainting) -class StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ] - - @property - def description(self) -> str: - return "step within the denoising loop that prepare the latent input for the denoiser" - - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")] - - @staticmethod - def check_inputs(components, block_state): - - num_channels_unet = components.num_channels_unet - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if block_state.mask is None or block_state.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = block_state.latents.shape[1] - num_channels_mask = block_state.mask.shape[1] - num_channels_masked_image = block_state.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" - f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `components.unet` or your `mask_image` or `image` input." - ) - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - - self.check_inputs(components, block_state) - - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - if components.num_channels_unet == 9: - block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - - - return components, block_state - -# loop step (2): denoise the latents with guidance -class StableDiffusionXLDenoiseLoopDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("unet", UNet2DConditionModel), - ] - - @property - def description(self) -> str: - return ( - "Step within the denoising loop that denoise the latents with guidance" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("cross_attention_kwargs"), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "scaled_latents", - required=True, - type_hint=torch.Tensor, - description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." - ), - InputParam( - kwargs_type="guider_input_fields", - description=( - "All conditional model inputs that need to be prepared with guider. " - "It should contain prompt_embeds/negative_prompt_embeds, " - "add_time_ids/negative_add_time_ids, " - "pooled_prompt_embeds/negative_pooled_prompt_embeds, " - "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." - "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" - ) - ), - - ] - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int) -> PipelineState: - - # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) - # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) - guider_input_fields ={ - "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), - "time_ids": ("add_time_ids", "negative_add_time_ids"), - "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), - "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), - } - - - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - - # Prepare mini‐batches according to guidance method and `guider_input_fields` - # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. - # e.g. for CFG, we prepare two batches: one for uncond, one for cond - # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds - # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds - guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) - - # run the denoiser for each guidance batch - for guider_state_batch in guider_state: - components.guider.prepare_models(components.unet) - cond_kwargs = guider_state_batch.as_dict() - cond_kwargs = {k:v for k,v in cond_kwargs.items() if k in guider_input_fields} - prompt_embeds = cond_kwargs.pop("prompt_embeds") - - # Predict the noise residual - # store the noise_pred in guider_state_batch so that we can apply guidance across all batches - guider_state_batch.noise_pred = components.unet( - block_state.scaled_latents, - t, - encoder_hidden_states=prompt_embeds, - timestep_cond=block_state.timestep_cond, - cross_attention_kwargs=block_state.cross_attention_kwargs, - added_cond_kwargs=cond_kwargs, - return_dict=False, - )[0] - components.guider.cleanup_models(components.unet) - - # Perform guidance - block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) - - return components, block_state - -# loop step (2): denoise the latents with guidance (with controlnet) -class StableDiffusionXLControlNetDenoiseLoopDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec("controlnet", ControlNetModel), - ] - - @property - def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("cross_attention_kwargs"), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "controlnet_cond", - required=True, - type_hint=torch.Tensor, - description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "conditioning_scale", - type_hint=float, - description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "guess_mode", - required=True, - type_hint=bool, - description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "controlnet_keep", - required=True, - type_hint=List[float], - description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "scaled_latents", - required=True, - type_hint=torch.Tensor, - description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - kwargs_type="guider_input_fields", - description=( - "All conditional model inputs that need to be prepared with guider. " - "It should contain prompt_embeds/negative_prompt_embeds, " - "add_time_ids/negative_add_time_ids, " - "pooled_prompt_embeds/negative_pooled_prompt_embeds, " - "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." - "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" - ) - ), - InputParam( - kwargs_type="controlnet_kwargs", - description=( - "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )" - "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" - ) - ) - ] - - @staticmethod - def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - - accepted_kwargs = set(inspect.signature(func).parameters.keys()) - extra_kwargs = {} - for key, value in kwargs.items(): - if key in accepted_kwargs and key not in exclude_kwargs: - extra_kwargs[key] = value - - return extra_kwargs - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - - extra_controlnet_kwargs = self.prepare_extra_kwargs(components.controlnet.forward, **block_state.controlnet_kwargs) - - # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) - # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) - guider_input_fields ={ - "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), - "time_ids": ("add_time_ids", "negative_add_time_ids"), - "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), - "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), - } - - - # cond_scale for the timestep (controlnet input) - if isinstance(block_state.controlnet_keep[i], list): - block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] - else: - controlnet_cond_scale = block_state.conditioning_scale - if isinstance(controlnet_cond_scale, list): - controlnet_cond_scale = controlnet_cond_scale[0] - block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] - - # default controlnet output/unet input for guess mode + conditional path - block_state.down_block_res_samples_zeros = None - block_state.mid_block_res_sample_zeros = None - - # guided denoiser step - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - - # Prepare mini‐batches according to guidance method and `guider_input_fields` - # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. - # e.g. for CFG, we prepare two batches: one for uncond, one for cond - # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds - # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds - guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) - - # run the denoiser for each guidance batch - for guider_state_batch in guider_state: - components.guider.prepare_models(components.unet) - - # Prepare additional conditionings - added_cond_kwargs = { - "text_embeds": guider_state_batch.text_embeds, - "time_ids": guider_state_batch.time_ids, - } - if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None: - added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds - - # Prepare controlnet additional conditionings - controlnet_added_cond_kwargs = { - "text_embeds": guider_state_batch.text_embeds, - "time_ids": guider_state_batch.time_ids, - } - # run controlnet for the guidance batch - if block_state.guess_mode and not components.guider.is_conditional: - # guider always run uncond batch first, so these tensors should be set already - down_block_res_samples = block_state.down_block_res_samples_zeros - mid_block_res_sample = block_state.mid_block_res_sample_zeros - else: - down_block_res_samples, mid_block_res_sample = components.controlnet( - block_state.scaled_latents, - t, - encoder_hidden_states=guider_state_batch.prompt_embeds, - controlnet_cond=block_state.controlnet_cond, - conditioning_scale=block_state.cond_scale, - guess_mode=block_state.guess_mode, - added_cond_kwargs=controlnet_added_cond_kwargs, - return_dict=False, - **extra_controlnet_kwargs, - ) - - # assign it to block_state so it will be available for the uncond guidance batch - if block_state.down_block_res_samples_zeros is None: - block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples] - if block_state.mid_block_res_sample_zeros is None: - block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample) - - # Predict the noise - # store the noise_pred in guider_state_batch so we can apply guidance across all batches - guider_state_batch.noise_pred = components.unet( - block_state.scaled_latents, - t, - encoder_hidden_states=guider_state_batch.prompt_embeds, - timestep_cond=block_state.timestep_cond, - cross_attention_kwargs=block_state.cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - return_dict=False, - )[0] - components.guider.cleanup_models(components.unet) - - # Perform guidance - block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) - - return components, block_state - -# loop step (3): scheduler step to update latents -class StableDiffusionXLDenoiseLoopAfterDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("generator"), - InputParam("eta", default=0.0), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - #YiYi TODO: move this out of here - @staticmethod - def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - - accepted_kwargs = set(inspect.signature(func).parameters.keys()) - extra_kwargs = {} - for key, value in kwargs.items(): - if key in accepted_kwargs and key not in exclude_kwargs: - extra_kwargs[key] = value - - return extra_kwargs - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) - - - # Perform scheduler step using the predicted output - block_state.latents_dtype = block_state.latents.dtype - block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] - - if block_state.latents.dtype != block_state.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - block_state.latents = block_state.latents.to(block_state.latents_dtype) - - return components, block_state - -# loop step (3): scheduler step to update latents (with inpainting) -class StableDiffusionXLInpaintDenoiseLoopAfterDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ] - - @property - def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("generator"), - InputParam("eta", default=0.0), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - @staticmethod - def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - - accepted_kwargs = set(inspect.signature(func).parameters.keys()) - extra_kwargs = {} - for key, value in kwargs.items(): - if key in accepted_kwargs and key not in exclude_kwargs: - extra_kwargs[key] = value - - return extra_kwargs - - def check_inputs(self, components, block_state): - if components.num_channels_unet == 4: - if block_state.image_latents is None: - raise ValueError(f"image_latents is required for this step {self.__class__.__name__}") - if block_state.mask is None: - raise ValueError(f"mask is required for this step {self.__class__.__name__}") - if block_state.noise is None: - raise ValueError(f"noise is required for this step {self.__class__.__name__}") - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - - self.check_inputs(components, block_state) - - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) - - - # Perform scheduler step using the predicted output - block_state.latents_dtype = block_state.latents.dtype - block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] - - if block_state.latents.dtype != block_state.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - block_state.latents = block_state.latents.to(block_state.latents_dtype) - - # adjust latent for inpainting - if components.num_channels_unet == 4: - block_state.init_latents_proper = block_state.image_latents - if i < len(block_state.timesteps) - 1: - block_state.noise_timestep = block_state.timesteps[i + 1] - block_state.init_latents_proper = components.scheduler.add_noise( - block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) - ) - - block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - - - - return components, block_state - - -# the loop wrapper that iterates over the timesteps -class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): - - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" - ) - - @property - def loop_expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ] - - @property - def loop_intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - ] - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False - if block_state.disable_guidance: - components.guider.disable() - else: - components.guider.enable() - - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - - with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: - for i, t in enumerate(block_state.timesteps): - components, block_state = self.loop_step(components, block_state, i=i, t=t) - if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): - progress_bar.update() - - self.add_block_state(state, block_state) - - return components, state - - -# composing the denoising loops -class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] - -# control_cond -class StableDiffusionXLControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] - -# mask -class StableDiffusionXLInpaintDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] - -# control_cond + mask -class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] - - - -# all task without controlnet -class StableDiffusionXLDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintDenoiseLoop, StableDiffusionXLDenoiseLoop] - block_names = ["inpaint_denoise", "denoise"] - block_trigger_inputs = ["mask", None] - -# all task with controlnet -class StableDiffusionXLControlNetDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintControlNetDenoiseLoop, StableDiffusionXLControlNetDenoiseLoop] - block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"] - block_trigger_inputs = ["mask", None] - -# all task with or without controlnet -class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] - block_names = ["controlnet_denoise", "denoise"] - block_trigger_inputs = ["controlnet_cond", None] - - - - - - - -# YiYi Notes: alternatively, this is you can just write the denoise loop using a pipeline block, easier but not composible -# class StableDiffusionXLDenoiseStep(PipelineBlock): - -# model_name = "stable-diffusion-xl" - -# @property -# def expected_components(self) -> List[ComponentSpec]: -# return [ -# ComponentSpec( -# "guider", -# ClassifierFreeGuidance, -# config=FrozenDict({"guidance_scale": 7.5}), -# default_creation_method="from_config"), -# ComponentSpec("scheduler", EulerDiscreteScheduler), -# ComponentSpec("unet", UNet2DConditionModel), -# ] - -# @property -# def description(self) -> str: -# return ( -# "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" -# ) - -# @property -# def inputs(self) -> List[Tuple[str, Any]]: -# return [ -# InputParam("cross_attention_kwargs"), -# InputParam("generator"), -# InputParam("eta", default=0.0), -# InputParam("num_images_per_prompt", default=1), -# ] - -# @property -# def intermediates_inputs(self) -> List[str]: -# return [ -# InputParam( -# "latents", -# required=True, -# type_hint=torch.Tensor, -# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." -# ), -# InputParam( -# "batch_size", -# required=True, -# type_hint=int, -# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." -# ), -# InputParam( -# "timesteps", -# required=True, -# type_hint=torch.Tensor, -# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "num_inference_steps", -# required=True, -# type_hint=int, -# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "pooled_prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_pooled_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " -# ), -# InputParam( -# "add_time_ids", -# required=True, -# type_hint=torch.Tensor, -# description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "negative_add_time_ids", -# type_hint=Optional[torch.Tensor], -# description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " -# ), -# InputParam( -# "timestep_cond", -# type_hint=Optional[torch.Tensor], -# description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "mask", -# type_hint=Optional[torch.Tensor], -# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "masked_image_latents", -# type_hint=Optional[torch.Tensor], -# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "noise", -# type_hint=Optional[torch.Tensor], -# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." -# ), -# InputParam( -# "image_latents", -# type_hint=Optional[torch.Tensor], -# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "negative_ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# ] - -# @property -# def intermediates_outputs(self) -> List[OutputParam]: -# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - -# @staticmethod -# def check_inputs(components, block_state): - -# num_channels_unet = components.unet.config.in_channels -# if num_channels_unet == 9: -# # default case for runwayml/stable-diffusion-inpainting -# if block_state.mask is None or block_state.masked_image_latents is None: -# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") -# num_channels_latents = block_state.latents.shape[1] -# num_channels_mask = block_state.mask.shape[1] -# num_channels_masked_image = block_state.masked_image_latents.shape[1] -# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: -# raise ValueError( -# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" -# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" -# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" -# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" -# " `components.unet` or your `mask_image` or `image` input." -# ) - -# # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components -# @staticmethod -# def prepare_extra_step_kwargs(components, generator, eta): -# # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature -# # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. -# # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 -# # and should be between [0, 1] - -# accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) -# extra_step_kwargs = {} -# if accepts_eta: -# extra_step_kwargs["eta"] = eta - -# # check if the scheduler accepts generator -# accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) -# if accepts_generator: -# extra_step_kwargs["generator"] = generator -# return extra_step_kwargs - -# @torch.no_grad() -# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - -# block_state = self.get_block_state(state) -# self.check_inputs(components, block_state) - -# block_state.num_channels_unet = components.unet.config.in_channels -# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False -# if block_state.disable_guidance: -# components.guider.disable() -# else: -# components.guider.enable() - -# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline -# block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) -# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - -# components.guider.set_input_fields( -# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), -# add_time_ids=("add_time_ids", "negative_add_time_ids"), -# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), -# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), -# ) - -# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: -# for i, t in enumerate(block_state.timesteps): -# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) -# guider_data = components.guider.prepare_inputs(block_state) - -# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - -# # Prepare for inpainting -# if block_state.num_channels_unet == 9: -# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - -# for batch in guider_data: -# components.guider.prepare_models(components.unet) - -# # Prepare additional conditionings -# batch.added_cond_kwargs = { -# "text_embeds": batch.pooled_prompt_embeds, -# "time_ids": batch.add_time_ids, -# } -# if batch.ip_adapter_embeds is not None: -# batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - -# # Predict the noise residual -# batch.noise_pred = components.unet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=batch.prompt_embeds, -# timestep_cond=block_state.timestep_cond, -# cross_attention_kwargs=block_state.cross_attention_kwargs, -# added_cond_kwargs=batch.added_cond_kwargs, -# return_dict=False, -# )[0] -# components.guider.cleanup_models(components.unet) - -# # Perform guidance -# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) - -# # Perform scheduler step using the predicted output -# block_state.latents_dtype = block_state.latents.dtype -# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - -# if block_state.latents.dtype != block_state.latents_dtype: -# if torch.backends.mps.is_available(): -# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 -# block_state.latents = block_state.latents.to(block_state.latents_dtype) - -# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: -# block_state.init_latents_proper = block_state.image_latents -# if i < len(block_state.timesteps) - 1: -# block_state.noise_timestep = block_state.timesteps[i + 1] -# block_state.init_latents_proper = components.scheduler.add_noise( -# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) -# ) - -# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - -# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): -# progress_bar.update() - -# self.add_block_state(state, block_state) - -# return components, state - - - -# class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): - -# model_name = "stable-diffusion-xl" - -# @property -# def expected_components(self) -> List[ComponentSpec]: -# return [ -# ComponentSpec( -# "guider", -# ClassifierFreeGuidance, -# config=FrozenDict({"guidance_scale": 7.5}), -# default_creation_method="from_config"), -# ComponentSpec("scheduler", EulerDiscreteScheduler), -# ComponentSpec("unet", UNet2DConditionModel), -# ComponentSpec("controlnet", ControlNetModel), -# ] - -# @property -# def description(self) -> str: -# return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - -# @property -# def inputs(self) -> List[Tuple[str, Any]]: -# return [ -# InputParam("num_images_per_prompt", default=1), -# InputParam("cross_attention_kwargs"), -# InputParam("generator"), -# InputParam("eta", default=0.0), -# InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) -# ] - -# @property -# def intermediates_inputs(self) -> List[str]: -# return [ -# InputParam( -# "controlnet_cond", -# required=True, -# type_hint=torch.Tensor, -# description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "control_guidance_start", -# required=True, -# type_hint=float, -# description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "control_guidance_end", -# required=True, -# type_hint=float, -# description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "conditioning_scale", -# type_hint=float, -# description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "guess_mode", -# required=True, -# type_hint=bool, -# description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "controlnet_keep", -# required=True, -# type_hint=List[float], -# description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "latents", -# required=True, -# type_hint=torch.Tensor, -# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." -# ), -# InputParam( -# "batch_size", -# required=True, -# type_hint=int, -# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." -# ), -# InputParam( -# "timesteps", -# required=True, -# type_hint=torch.Tensor, -# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "add_time_ids", -# required=True, -# type_hint=torch.Tensor, -# description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." -# ), -# InputParam( -# "negative_add_time_ids", -# type_hint=Optional[torch.Tensor], -# description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." -# ), -# InputParam( -# "pooled_prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_pooled_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "timestep_cond", -# type_hint=Optional[torch.Tensor], -# description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" -# ), -# InputParam( -# "mask", -# type_hint=Optional[torch.Tensor], -# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "masked_image_latents", -# type_hint=Optional[torch.Tensor], -# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "noise", -# type_hint=Optional[torch.Tensor], -# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." -# ), -# InputParam( -# "image_latents", -# type_hint=Optional[torch.Tensor], -# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "crops_coords", -# type_hint=Optional[Tuple[int]], -# description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." -# ), -# InputParam( -# "ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "negative_ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "num_inference_steps", -# required=True, -# type_hint=int, -# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") -# ] - -# @property -# def intermediates_outputs(self) -> List[OutputParam]: -# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - -# @staticmethod -# def check_inputs(components, block_state): - -# num_channels_unet = components.unet.config.in_channels -# if num_channels_unet == 9: -# # default case for runwayml/stable-diffusion-inpainting -# if block_state.mask is None or block_state.masked_image_latents is None: -# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") -# num_channels_latents = block_state.latents.shape[1] -# num_channels_mask = block_state.mask.shape[1] -# num_channels_masked_image = block_state.masked_image_latents.shape[1] -# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: -# raise ValueError( -# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" -# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" -# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" -# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" -# " `components.unet` or your `mask_image` or `image` input." -# ) -# @staticmethod -# def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - -# accepted_kwargs = set(inspect.signature(func).parameters.keys()) -# extra_kwargs = {} -# for key, value in kwargs.items(): -# if key in accepted_kwargs and key not in exclude_kwargs: -# extra_kwargs[key] = value - -# return extra_kwargs - - -# @torch.no_grad() -# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - -# block_state = self.get_block_state(state) -# self.check_inputs(components, block_state) -# block_state.device = components._execution_device -# print(f" block_state: {block_state}") - -# controlnet = unwrap_module(components.controlnet) - -# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline -# block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) -# block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) - -# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - -# # (1) setup guider -# # disable for LCMs -# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False -# if block_state.disable_guidance: -# components.guider.disable() -# else: -# components.guider.enable() -# components.guider.set_input_fields( -# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), -# add_time_ids=("add_time_ids", "negative_add_time_ids"), -# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), -# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), -# ) - -# # (5) Denoise loop -# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: -# for i, t in enumerate(block_state.timesteps): - -# # prepare latent input for unet -# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) -# # adjust latent input for inpainting -# block_state.num_channels_unet = components.unet.config.in_channels -# if block_state.num_channels_unet == 9: -# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - - -# # cond_scale (controlnet input) -# if isinstance(block_state.controlnet_keep[i], list): -# block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] -# else: -# block_state.controlnet_cond_scale = block_state.conditioning_scale -# if isinstance(block_state.controlnet_cond_scale, list): -# block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] -# block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] - -# # default controlnet output/unet input for guess mode + conditional path -# block_state.down_block_res_samples_zeros = None -# block_state.mid_block_res_sample_zeros = None - -# # guided denoiser step -# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) -# guider_state = components.guider.prepare_inputs(block_state) - -# for guider_state_batch in guider_state: -# components.guider.prepare_models(components.unet) - -# # Prepare additional conditionings -# guider_state_batch.added_cond_kwargs = { -# "text_embeds": guider_state_batch.pooled_prompt_embeds, -# "time_ids": guider_state_batch.add_time_ids, -# } -# if guider_state_batch.ip_adapter_embeds is not None: -# guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds - -# # Prepare controlnet additional conditionings -# guider_state_batch.controlnet_added_cond_kwargs = { -# "text_embeds": guider_state_batch.pooled_prompt_embeds, -# "time_ids": guider_state_batch.add_time_ids, -# } - -# if block_state.guess_mode and not components.guider.is_conditional: -# # guider always run uncond batch first, so these tensors should be set already -# guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros -# guider_state_batch.mid_block_res_sample = block_state.mid_block_res_sample_zeros -# else: -# guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=guider_state_batch.prompt_embeds, -# controlnet_cond=block_state.controlnet_cond, -# conditioning_scale=block_state.conditioning_scale, -# guess_mode=block_state.guess_mode, -# added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, -# return_dict=False, -# **block_state.extra_controlnet_kwargs, -# ) - -# if block_state.down_block_res_samples_zeros is None: -# block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] -# if block_state.mid_block_res_sample_zeros is None: -# block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) - - - -# guider_state_batch.noise_pred = components.unet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=guider_state_batch.prompt_embeds, -# timestep_cond=block_state.timestep_cond, -# cross_attention_kwargs=block_state.cross_attention_kwargs, -# added_cond_kwargs=guider_state_batch.added_cond_kwargs, -# down_block_additional_residuals=guider_state_batch.down_block_res_samples, -# mid_block_additional_residual=guider_state_batch.mid_block_res_sample, -# return_dict=False, -# )[0] -# components.guider.cleanup_models(components.unet) - -# # Perform guidance -# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) - -# # Perform scheduler step using the predicted output -# block_state.latents_dtype = block_state.latents.dtype -# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - -# if block_state.latents.dtype != block_state.latents_dtype: -# if torch.backends.mps.is_available(): -# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 -# block_state.latents = block_state.latents.to(block_state.latents_dtype) - -# # adjust latent for inpainting -# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: -# block_state.init_latents_proper = block_state.image_latents -# if i < len(block_state.timesteps) - 1: -# block_state.noise_timestep = block_state.timesteps[i + 1] -# block_state.init_latents_proper = components.scheduler.add_noise( -# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) -# ) - -# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - -# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): -# progress_bar.update() - -# self.add_block_state(state, block_state) - -# return components, state \ No newline at end of file From 0acb5e1460b2fd2769bcaa38a523c6a9a9f063ea Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 10 May 2025 03:50:31 +0200 Subject: [PATCH 22/54] made a modular_pipelines folder! --- src/diffusers/modular_pipelines/__init__.py | 82 + .../modular_pipelines/components_manager.py | 863 ++++++++ .../modular_pipelines/modular_pipeline.py | 1916 +++++++++++++++++ .../modular_pipeline_utils.py | 598 +++++ .../stable_diffusion_xl/__init__.py | 51 + .../stable_diffusion_xl/after_denoise.py | 259 +++ .../stable_diffusion_xl/before_denoise.py | 1766 +++++++++++++++ .../stable_diffusion_xl/denoise.py | 1362 ++++++++++++ .../stable_diffusion_xl/encoders.py | 856 ++++++++ .../stable_diffusion_xl/modular_loader.py | 175 ++ .../modular_pipeline_presets.py | 119 + 11 files changed, 8047 insertions(+) create mode 100644 src/diffusers/modular_pipelines/__init__.py create mode 100644 src/diffusers/modular_pipelines/components_manager.py create mode 100644 src/diffusers/modular_pipelines/modular_pipeline.py create mode 100644 src/diffusers/modular_pipelines/modular_pipeline_utils.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py new file mode 100644 index 000000000000..cb2ed78ce360 --- /dev/null +++ b/src/diffusers/modular_pipelines/__init__.py @@ -0,0 +1,82 @@ +from typing import TYPE_CHECKING + +from ..utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +# These modules contain pipelines from multiple libraries/frameworks +_dummy_objects = {} +_import_structure = {} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_pt_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) +else: + _import_structure["modular_pipeline"] = [ + "ModularPipelineMixin", + "PipelineBlock", + "AutoPipelineBlocks", + "SequentialPipelineBlocks", + "LoopSequentialPipelineBlocks", + "ModularLoader", + "PipelineState", + "BlockState", + ] + _import_structure["modular_pipeline_utils"] = [ + "ComponentSpec", + "ConfigSpec", + "InputParam", + "OutputParam", + ] + _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoPipeline", "StableDiffusionXLModularLoader"] + _import_structure["components_manager"] = ["ComponentsManager"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_pt_objects import * # noqa F403 + else: + from .modular_pipeline import ( + AutoPipelineBlocks, + BlockState, + LoopSequentialPipelineBlocks, + ModularLoader, + ModularPipelineMixin, + PipelineBlock, + PipelineState, + SequentialPipelineBlocks, + ) + from .modular_pipeline_utils import ( + ComponentSpec, + ConfigSpec, + InputParam, + OutputParam, + ) + from .stable_diffusion_xl import ( + StableDiffusionXLAutoPipeline, + StableDiffusionXLModularLoader, + ) + from .components_manager import ComponentsManager +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py new file mode 100644 index 000000000000..0ace1b321e8b --- /dev/null +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -0,0 +1,863 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from itertools import combinations +from typing import List, Optional, Union, Dict, Any +import copy + +import torch +import time +from dataclasses import dataclass + +from ..utils import ( + is_accelerate_available, + logging, +) +from ..models.modeling_utils import ModelMixin +from .modular_pipeline_utils import ComponentSpec + + +import uuid + + +if is_accelerate_available(): + from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module + from accelerate.state import PartialState + from accelerate.utils import send_to_device + from accelerate.utils.memory import clear_device_cache + from accelerate.utils.modeling import convert_file_size_to_int + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# YiYi Notes: copied from modeling_utils.py (decide later where to put this) +def get_memory_footprint(self, return_buffers=True): + r""" + Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. Useful to + benchmark the memory footprint of the current model and design some tests. Solution inspired from the PyTorch + discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 + + Arguments: + return_buffers (`bool`, *optional*, defaults to `True`): + Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers are + tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch norm + layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 + """ + mem = sum([param.nelement() * param.element_size() for param in self.parameters()]) + if return_buffers: + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) + mem = mem + mem_bufs + return mem + + +class CustomOffloadHook(ModelHook): + """ + A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are + on the given device. Optionally offloads other models to the CPU before the forward pass is called. + + Args: + execution_device(`str`, `int` or `torch.device`, *optional*): + The device on which the model should be executed. Will default to the MPS device if it's available, then + GPU 0 if there is a GPU, and finally to the CPU. + """ + + def __init__( + self, + execution_device: Optional[Union[str, int, torch.device]] = None, + other_hooks: Optional[List["UserCustomOffloadHook"]] = None, + offload_strategy: Optional["AutoOffloadStrategy"] = None, + ): + self.execution_device = execution_device if execution_device is not None else PartialState().default_device + self.other_hooks = other_hooks + self.offload_strategy = offload_strategy + self.model_id = None + + def set_strategy(self, offload_strategy: "AutoOffloadStrategy"): + self.offload_strategy = offload_strategy + + def add_other_hook(self, hook: "UserCustomOffloadHook"): + """ + Add a hook to the list of hooks to consider for offloading. + """ + if self.other_hooks is None: + self.other_hooks = [] + self.other_hooks.append(hook) + + def init_hook(self, module): + return module.to("cpu") + + def pre_forward(self, module, *args, **kwargs): + if module.device != self.execution_device: + if self.other_hooks is not None: + hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device] + # offload all other hooks + start_time = time.perf_counter() + if self.offload_strategy is not None: + hooks_to_offload = self.offload_strategy( + hooks=hooks_to_offload, + model_id=self.model_id, + model=module, + execution_device=self.execution_device, + ) + end_time = time.perf_counter() + logger.info( + f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds" + ) + + for hook in hooks_to_offload: + logger.info( + f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu" + ) + hook.offload() + + if hooks_to_offload: + clear_device_cache() + module.to(self.execution_device) + return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device) + + +class UserCustomOffloadHook: + """ + A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs for to call the init method of + the hook or remove it entirely. + """ + + def __init__(self, model_id, model, hook): + self.model_id = model_id + self.model = model + self.hook = hook + + def offload(self): + self.hook.init_hook(self.model) + + def attach(self): + add_hook_to_module(self.model, self.hook) + self.hook.model_id = self.model_id + + def remove(self): + remove_hook_from_module(self.model) + self.hook.model_id = None + + def add_other_hook(self, hook: "UserCustomOffloadHook"): + self.hook.add_other_hook(hook) + + +def custom_offload_with_hook( + model_id: str, + model: torch.nn.Module, + execution_device: Union[str, int, torch.device] = None, + offload_strategy: Optional["AutoOffloadStrategy"] = None, +): + hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy) + user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook) + user_hook.attach() + return user_hook + + +class AutoOffloadStrategy: + """ + Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on + the available memory on the device. + """ + + def __init__(self, memory_reserve_margin="3GB"): + self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin) + + def __call__(self, hooks, model_id, model, execution_device): + if len(hooks) == 0: + return [] + + current_module_size = get_memory_footprint(model) + + mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0] + mem_on_device = mem_on_device - self.memory_reserve_margin + if current_module_size < mem_on_device: + return [] + + min_memory_offload = current_module_size - mem_on_device + logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory") + + # exlucde models that's not currently loaded on the device + module_sizes = dict( + sorted( + {hook.model_id: get_memory_footprint(hook.model) for hook in hooks}.items(), + key=lambda x: x[1], + reverse=True, + ) + ) + + def search_best_candidate(module_sizes, min_memory_offload): + """ + search the optimal combination of models to offload to cpu, given a dictionary of module sizes and a + minimum memory offload size. the combination of models should add up to the smallest modulesize that is + larger than `min_memory_offload` + """ + model_ids = list(module_sizes.keys()) + best_candidate = None + best_size = float("inf") + for r in range(1, len(model_ids) + 1): + for candidate_model_ids in combinations(model_ids, r): + candidate_size = sum( + module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids + ) + if candidate_size < min_memory_offload: + continue + else: + if best_candidate is None or candidate_size < best_size: + best_candidate = candidate_model_ids + best_size = candidate_size + + return best_candidate + + best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload) + + if best_offload_model_ids is None: + # if no combination is found, meaning that we cannot meet the memory requirement, offload all models + logger.warning("no combination of models to offload to cpu is found, offloading all models") + hooks_to_offload = hooks + else: + hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids] + + return hooks_to_offload + + + +class ComponentsManager: + def __init__(self): + self.components = OrderedDict() + self.added_time = OrderedDict() # Store when components were added + self.collections = OrderedDict() # collection_name -> set of component_names + self.model_hooks = None + self._auto_offload_enabled = False + + + def _get_by_collection(self, collection: str): + """ + Select components by collection name. + """ + selected_components = {} + if collection in self.collections: + component_ids = self.collections[collection] + for component_id in component_ids: + selected_components[component_id] = self.components[component_id] + return selected_components + + + def _get_by_load_id(self, load_id: str): + """ + Select components by its load_id. + """ + selected_components = {} + for name, component in self.components.items(): + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: + selected_components[name] = component + return selected_components + + + def add(self, name, component, collection: Optional[str] = None): + + for comp_id, comp in self.components.items(): + if comp == component: + logger.warning(f"Component '{name}' already exists in ComponentsManager") + return comp_id + + component_id = f"{name}_{uuid.uuid4()}" + + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": + components_with_same_load_id = self._get_by_load_id(component._diffusers_load_id) + if components_with_same_load_id: + existing = ", ".join(components_with_same_load_id.keys()) + logger.warning( + f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " + f"To remove a duplicate, call `components_manager.remove('')`." + ) + + + # add component to components manager + self.components[component_id] = component + self.added_time[component_id] = time.time() + if collection: + if collection not in self.collections: + self.collections[collection] = set() + self.collections[collection].add(component_id) + + if self._auto_offload_enabled: + self.enable_auto_cpu_offload(self._auto_offload_device) + + logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'") + return component_id + + + def remove(self, name: Union[str, List[str]]): + + if name not in self.components: + logger.warning(f"Component '{name}' not found in ComponentsManager") + return + + self.components.pop(name) + self.added_time.pop(name) + + for collection in self.collections: + if name in self.collections[collection]: + self.collections[collection].remove(name) + + if self._auto_offload_enabled: + self.enable_auto_cpu_offload(self._auto_offload_device) + + def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None, + as_name_component_tuples: bool = False): + """ + Select components by name with simple pattern matching. + + Args: + names: Component name(s) or pattern(s) + Patterns: + - "unet" : match any component with base name "unet" (e.g., unet_123abc) + - "!unet" : everything except components with base name "unet" + - "unet*" : anything with base name starting with "unet" + - "!unet*" : anything with base name NOT starting with "unet" + - "*unet*" : anything with base name containing "unet" + - "!*unet*" : anything with base name NOT containing "unet" + - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet" + - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet" + - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae" + collection: Optional collection to filter by + load_id: Optional load_id to filter by + as_name_component_tuples: If True, returns a list of (name, component) tuples using base names + instead of a dictionary with component IDs as keys + + Returns: + Dictionary mapping component IDs to components, + or list of (base_name, component) tuples if as_name_component_tuples=True + """ + + if collection: + if collection not in self.collections: + logger.warning(f"Collection '{collection}' not found in ComponentsManager") + return [] if as_name_component_tuples else {} + components = self._get_by_collection(collection) + else: + components = self.components + + if load_id: + components = self._get_by_load_id(load_id) + + # Helper to extract base name from component_id + def get_base_name(component_id): + parts = component_id.split('_') + # If the last part looks like a UUID, remove it + if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: + return '_'.join(parts[:-1]) + return component_id + + if names is None: + if as_name_component_tuples: + return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()] + else: + return components + + # Create mapping from component_id to base_name for all components + base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} + + def matches_pattern(component_id, pattern, exact_match=False): + """ + Helper function to check if a component matches a pattern based on its base name. + + Args: + component_id: The component ID to check + pattern: The pattern to match against + exact_match: If True, only exact matches to base_name are considered + """ + base_name = base_names[component_id] + + # Exact match with base name + if exact_match: + return pattern == base_name + + # Prefix match (ends with *) + elif pattern.endswith('*'): + prefix = pattern[:-1] + return base_name.startswith(prefix) + + # Contains match (starts with *) + elif pattern.startswith('*'): + search = pattern[1:-1] if pattern.endswith('*') else pattern[1:] + return search in base_name + + # Exact match (no wildcards) + else: + return pattern == base_name + + if isinstance(names, str): + # Check if this is a "not" pattern + is_not_pattern = names.startswith('!') + if is_not_pattern: + names = names[1:] # Remove the ! prefix + + # Handle OR patterns (containing |) + if '|' in names: + terms = names.split('|') + matches = {} + + for comp_id, comp in components.items(): + # For OR patterns with exact names (no wildcards), we do exact matching on base names + exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms) + + # Check if any of the terms match this component + should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) + + # Flip the decision if this is a NOT pattern + if is_not_pattern: + should_include = not should_include + + if should_include: + matches[comp_id] = comp + + log_msg = "NOT " if is_not_pattern else "" + match_type = "exactly matching" if exact_match else "matching any of patterns" + logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") + + # Try exact match with a base name + elif any(names == base_name for base_name in base_names.values()): + # Find all components with this base name + matches = { + comp_id: comp for comp_id, comp in components.items() + if (base_names[comp_id] == names) != is_not_pattern + } + + if is_not_pattern: + logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") + else: + logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") + + # Prefix match (ends with *) + elif names.endswith('*'): + prefix = names[:-1] + matches = { + comp_id: comp for comp_id, comp in components.items() + if base_names[comp_id].startswith(prefix) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") + else: + logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}") + + # Contains match (starts with *) + elif names.startswith('*'): + search = names[1:-1] if names.endswith('*') else names[1:] + matches = { + comp_id: comp for comp_id, comp in components.items() + if (search in base_names[comp_id]) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") + else: + logger.info(f"Getting components containing '{search}': {list(matches.keys())}") + + # Substring match (no wildcards, but not an exact component name) + elif any(names in base_name for base_name in base_names.values()): + matches = { + comp_id: comp for comp_id, comp in components.items() + if (names in base_names[comp_id]) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") + else: + logger.info(f"Getting components containing '{names}': {list(matches.keys())}") + + else: + raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") + + if not matches: + raise ValueError(f"No components found matching pattern '{names}'") + + if as_name_component_tuples: + return [(base_names[comp_id], comp) for comp_id, comp in matches.items()] + else: + return matches + + elif isinstance(names, list): + results = {} + for name in names: + result = self.get(name, collection, load_id, as_name_component_tuples=False) + results.update(result) + + if as_name_component_tuples: + return [(base_names[comp_id], comp) for comp_id, comp in results.items()] + else: + return results + + else: + raise ValueError(f"Invalid type for names: {type(names)}") + + def enable_auto_cpu_offload(self, device: Union[str, int, torch.device]="cuda", memory_reserve_margin="3GB"): + for name, component in self.components.items(): + if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"): + remove_hook_from_module(component, recurse=True) + + self.disable_auto_cpu_offload() + offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin) + device = torch.device(device) + if device.index is None: + device = torch.device(f"{device.type}:{0}") + all_hooks = [] + for name, component in self.components.items(): + if isinstance(component, torch.nn.Module): + hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy) + all_hooks.append(hook) + + for hook in all_hooks: + other_hooks = [h for h in all_hooks if h is not hook] + for other_hook in other_hooks: + if other_hook.hook.execution_device == hook.hook.execution_device: + hook.add_other_hook(other_hook) + + self.model_hooks = all_hooks + self._auto_offload_enabled = True + self._auto_offload_device = device + + def disable_auto_cpu_offload(self): + if self.model_hooks is None: + self._auto_offload_enabled = False + return + + for hook in self.model_hooks: + hook.offload() + hook.remove() + if self.model_hooks: + clear_device_cache() + self.model_hooks = None + self._auto_offload_enabled = False + + # YiYi TODO: add quantization info + def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: + """Get comprehensive information about a component. + + Args: + name: Name of the component to get info for + fields: Optional field(s) to return. Can be a string for single field or list of fields. + If None, returns all fields. + + Returns: + Dictionary containing requested component metadata. + If fields is specified, returns only those fields. + If a single field is requested as string, returns just that field's value. + """ + if name not in self.components: + raise ValueError(f"Component '{name}' not found in ComponentsManager") + + component = self.components[name] + + # Build complete info dict first + info = { + "model_id": name, + "added_time": self.added_time[name], + "collection": next((coll for coll, comps in self.collections.items() if name in comps), None), + } + + # Additional info for torch.nn.Module components + if isinstance(component, torch.nn.Module): + # Check for hook information + has_hook = hasattr(component, "_hf_hook") + execution_device = None + if has_hook and hasattr(component._hf_hook, "execution_device"): + execution_device = component._hf_hook.execution_device + + info.update({ + "class_name": component.__class__.__name__, + "size_gb": get_memory_footprint(component) / (1024**3), + "adapters": None, # Default to None + "has_hook": has_hook, + "execution_device": execution_device, + }) + + # Get adapters if applicable + if hasattr(component, "peft_config"): + info["adapters"] = list(component.peft_config.keys()) + + # Check for IP-Adapter scales + if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"): + processors = copy.deepcopy(component.attn_processors) + # First check if any processor is an IP-Adapter + processor_types = [v.__class__.__name__ for v in processors.values()] + if any("IPAdapter" in ptype for ptype in processor_types): + # Then get scales only from IP-Adapter processors + scales = { + k: v.scale + for k, v in processors.items() + if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__ + } + if scales: + info["ip_adapter"] = summarize_dict_by_value_and_parts(scales) + + # If fields specified, filter info + if fields is not None: + if isinstance(fields, str): + # Single field requested, return just that value + return {fields: info.get(fields)} + else: + # List of fields requested, return dict with just those fields + return {k: v for k, v in info.items() if k in fields} + + return info + + def __repr__(self): + # Helper to get simple name without UUID + def get_simple_name(name): + # Extract the base name by splitting on underscore and taking first part + # This assumes names are in format "name_uuid" + parts = name.split('_') + # If we have at least 2 parts and the last part looks like a UUID, remove it + if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: + return '_'.join(parts[:-1]) + return name + + # Extract load_id if available + def get_load_id(component): + if hasattr(component, "_diffusers_load_id"): + return component._diffusers_load_id + return "N/A" + + # Format device info compactly + def format_device(component, info): + if not info["has_hook"]: + return str(getattr(component, 'device', 'N/A')) + else: + device = str(getattr(component, 'device', 'N/A')) + exec_device = str(info['execution_device'] or 'N/A') + return f"{device}({exec_device})" + + # Get all simple names to calculate width + simple_names = [get_simple_name(id) for id in self.components.keys()] + + # Get max length of load_ids for models + load_ids = [ + get_load_id(component) + for component in self.components.values() + if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id") + ] + max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 + + # Collection names + collection_names = [ + next((coll for coll, comps in self.collections.items() if name in comps), "N/A") + for name in self.components.keys() + ] + + col_widths = { + "name": max(15, max(len(name) for name in simple_names)), + "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), + "device": 15, # Reduced since using more compact format + "dtype": 15, + "size": 10, + "load_id": max_load_id_len, + "collection": max(10, max(len(str(c)) for c in collection_names)) + } + + # Create the header lines + sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n" + dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n" + + output = "Components:\n" + sep_line + + # Separate components into models and others + models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)} + others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)} + + # Models section + if models: + output += "Models:\n" + dash_line + # Column headers + output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | " + output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | " + output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n" + output += dash_line + + # Model entries + for name, component in models.items(): + info = self.get_model_info(name) + simple_name = get_simple_name(name) + device_str = format_device(component, info) + dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" + load_id = get_load_id(component) + collection = info["collection"] or "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " + output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " + output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n" + output += dash_line + + # Other components section + if others: + if models: # Add extra newline if we had models section + output += "\n" + output += "Other Components:\n" + dash_line + # Column headers for other components + output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | Collection\n" + output += dash_line + + # Other component entries + for name, component in others.items(): + info = self.get_model_info(name) + simple_name = get_simple_name(name) + collection = info["collection"] or "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n" + output += dash_line + + # Add additional component info + output += "\nAdditional Component Info:\n" + "=" * 50 + "\n" + for name in self.components: + info = self.get_model_info(name) + if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")): + simple_name = get_simple_name(name) + output += f"\n{simple_name}:\n" + if info.get("adapters") is not None: + output += f" Adapters: {info['adapters']}\n" + if info.get("ip_adapter"): + output += f" IP-Adapter: Enabled\n" + output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n" + + return output + + def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): + """ + Load components from a pretrained model and add them to the manager. + + Args: + pretrained_model_name_or_path (str): The path or identifier of the pretrained model + prefix (str, optional): Prefix to add to all component names loaded from this model. + If provided, components will be named as "{prefix}_{component_name}" + **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained() + """ + subfolder = kwargs.pop("subfolder", None) + # YiYi TODO: extend AutoModel to support non-diffusers models + if subfolder: + from ..models import AutoModel + component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs) + component_name = f"{prefix}_{subfolder}" if prefix else subfolder + if component_name not in self.components: + self.add(component_name, component) + else: + logger.warning( + f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" + f"1. remove the existing component with remove('{component_name}')\n" + f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" + ) + else: + from ..pipelines.pipeline_utils import DiffusionPipeline + pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) + for name, component in pipe.components.items(): + + if component is None: + continue + + # Add prefix if specified + component_name = f"{prefix}_{name}" if prefix else name + + if component_name not in self.components: + self.add(component_name, component) + else: + logger.warning( + f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" + f"1. remove the existing component with remove('{component_name}')\n" + f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" + ) + + def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any: + """ + Get a single component by name. Raises an error if multiple components match or none are found. + + Args: + name: Component name or pattern + collection: Optional collection to filter by + load_id: Optional load_id to filter by + + Returns: + A single component + + Raises: + ValueError: If no components match or multiple components match + """ + results = self.get(name, collection, load_id) + + if not results: + raise ValueError(f"No components found matching '{name}'") + + if len(results) > 1: + raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") + + return next(iter(results.values())) + +def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: + """Summarizes a dictionary by finding common prefixes that share the same value. + + For a dictionary with dot-separated keys like: + { + 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6], + 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6], + 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3], + } + + Returns a dictionary where keys are the shortest common prefixes and values are their shared values: + { + 'down_blocks': [0.6], + 'up_blocks': [0.3] + } + """ + # First group by values - convert lists to tuples to make them hashable + value_to_keys = {} + for key, value in d.items(): + value_tuple = tuple(value) if isinstance(value, list) else value + if value_tuple not in value_to_keys: + value_to_keys[value_tuple] = [] + value_to_keys[value_tuple].append(key) + + def find_common_prefix(keys: List[str]) -> str: + """Find the shortest common prefix among a list of dot-separated keys.""" + if not keys: + return "" + if len(keys) == 1: + return keys[0] + + # Split all keys into parts + key_parts = [k.split('.') for k in keys] + + # Find how many initial parts are common + common_length = 0 + for parts in zip(*key_parts): + if len(set(parts)) == 1: # All parts at this position are the same + common_length += 1 + else: + break + + if common_length == 0: + return "" + + # Return the common prefix + return '.'.join(key_parts[0][:common_length]) + + # Create summary by finding common prefixes for each value group + summary = {} + for value_tuple, keys in value_to_keys.items(): + prefix = find_common_prefix(keys) + if prefix: # Only add if we found a common prefix + # Convert tuple back to list if it was originally a list + value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple + summary[prefix] = value + else: + summary[""] = value # Use empty string if no common prefix + + return summary diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py new file mode 100644 index 000000000000..98960fe25bde --- /dev/null +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -0,0 +1,1916 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import traceback +import warnings +from collections import OrderedDict +from dataclasses import dataclass, field +from typing import Any, Dict, List, Tuple, Union, Optional, Type + + +import torch +from tqdm.auto import tqdm +import re +import os +import importlib + +from huggingface_hub.utils import validate_hf_hub_args + +from ..configuration_utils import ConfigMixin, FrozenDict +from ..utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + PushToHubMixin, +) +from ..pipelines.pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj, _fetch_class_library_tuple +from .modular_pipeline_utils import ( + ComponentSpec, + ConfigSpec, + InputParam, + OutputParam, + format_components, + format_configs, + format_input_params, + format_inputs_short, + format_intermediates_short, + format_output_params, + format_params, + make_doc_string, +) +from .components_manager import ComponentsManager + +from copy import deepcopy +if is_accelerate_available(): + import accelerate + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +MODULAR_LOADER_MAPPING = OrderedDict( + [ + ("stable-diffusion-xl", "StableDiffusionXLModularLoader"), + ] +) + + +@dataclass +class PipelineState: + """ + [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks. + """ + + inputs: Dict[str, Any] = field(default_factory=dict) + intermediates: Dict[str, Any] = field(default_factory=dict) + input_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) + intermediate_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) + + def add_input(self, key: str, value: Any, kwargs_type: str = None): + """ + Add an input to the pipeline state with optional metadata. + + Args: + key (str): The key for the input + value (Any): The input value + kwargs_type (str): The kwargs_type to store with the input + """ + self.inputs[key] = value + if kwargs_type is not None: + if kwargs_type not in self.input_kwargs: + self.input_kwargs[kwargs_type] = [key] + else: + self.input_kwargs[kwargs_type].append(key) + + def add_intermediate(self, key: str, value: Any, kwargs_type: str = None): + """ + Add an intermediate value to the pipeline state with optional metadata. + + Args: + key (str): The key for the intermediate value + value (Any): The intermediate value + kwargs_type (str): The kwargs_type to store with the intermediate value + """ + self.intermediates[key] = value + if kwargs_type is not None: + if kwargs_type not in self.intermediate_kwargs: + self.intermediate_kwargs[kwargs_type] = [key] + else: + self.intermediate_kwargs[kwargs_type].append(key) + + def get_input(self, key: str, default: Any = None) -> Any: + return self.inputs.get(key, default) + + def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: + return {key: self.inputs.get(key, default) for key in keys} + + def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]: + """ + Get all inputs with matching kwargs_type. + + Args: + kwargs_type (str): The kwargs_type to filter by + + Returns: + Dict[str, Any]: Dictionary of inputs with matching kwargs_type + """ + input_names = self.input_kwargs.get(kwargs_type, []) + return self.get_inputs(input_names) + + def get_intermediates_kwargs(self, kwargs_type: str) -> Dict[str, Any]: + """ + Get all intermediates with matching kwargs_type. + + Args: + kwargs_type (str): The kwargs_type to filter by + + Returns: + Dict[str, Any]: Dictionary of intermediates with matching kwargs_type + """ + intermediate_names = self.intermediate_kwargs.get(kwargs_type, []) + return self.get_intermediates(intermediate_names) + + def get_intermediate(self, key: str, default: Any = None) -> Any: + return self.intermediates.get(key, default) + + def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]: + return {key: self.intermediates.get(key, default) for key in keys} + + def to_dict(self) -> Dict[str, Any]: + return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates} + + def __repr__(self): + def format_value(v): + if hasattr(v, "shape") and hasattr(v, "dtype"): + return f"Tensor(dtype={v.dtype}, shape={v.shape})" + elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + return f"[Tensor(dtype={v[0].dtype}, shape={v[0].shape}), ...]" + else: + return repr(v) + + inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) + intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) + + # Format input_kwargs and intermediate_kwargs + input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items()) + intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items()) + + return ( + f"PipelineState(\n" + f" inputs={{\n{inputs}\n }},\n" + f" intermediates={{\n{intermediates}\n }},\n" + f" input_kwargs={{\n{input_kwargs_str}\n }},\n" + f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n" + f")" + ) + + +@dataclass +class BlockState: + """ + Container for block state data with attribute access and formatted representation. + """ + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + def __getitem__(self, key: str): + # allows block_state["foo"] + return getattr(self, key, None) + + def __setitem__(self, key: str, value: Any): + # allows block_state["foo"] = "bar" + setattr(self, key, value) + + def as_dict(self): + """ + Convert BlockState to a dictionary. + + Returns: + Dict[str, Any]: Dictionary containing all attributes of the BlockState + """ + return {key: value for key, value in self.__dict__.items()} + + def __repr__(self): + def format_value(v): + # Handle tensors directly + if hasattr(v, "shape") and hasattr(v, "dtype"): + return f"Tensor(dtype={v.dtype}, shape={v.shape})" + + # Handle lists of tensors + elif isinstance(v, list): + if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + shapes = [t.shape for t in v] + return f"List[{len(v)}] of Tensors with shapes {shapes}" + return repr(v) + + # Handle tuples of tensors + elif isinstance(v, tuple): + if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + shapes = [t.shape for t in v] + return f"Tuple[{len(v)}] of Tensors with shapes {shapes}" + return repr(v) + + # Handle dicts with tensor values + elif isinstance(v, dict): + formatted_dict = {} + for k, val in v.items(): + if hasattr(val, "shape") and hasattr(val, "dtype"): + formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})" + elif isinstance(val, list) and len(val) > 0 and hasattr(val[0], "shape") and hasattr(val[0], "dtype"): + shapes = [t.shape for t in val] + formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}" + else: + formatted_dict[k] = repr(val) + return formatted_dict + + # Default case + return repr(v) + + attributes = "\n".join(f" {k}: {format_value(v)}" for k, v in self.__dict__.items()) + return f"BlockState(\n{attributes}\n)" + + + +class ModularPipelineMixin: + """ + Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks + """ + + + def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): + """ + create a mouldar loader, optionally accept modular_repo to load from hub. + """ + + # Import components loader (it is model-specific class) + loader_class_name = MODULAR_LOADER_MAPPING[self.model_name] + diffusers_module = importlib.import_module("diffusers") + loader_class = getattr(diffusers_module, loader_class_name) + + # Create deep copies to avoid modifying the original specs + component_specs = deepcopy(self.expected_components) + config_specs = deepcopy(self.expected_configs) + # Create the loader with the updated specs + specs = component_specs + config_specs + + self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection) + + + @property + def default_call_parameters(self) -> Dict[str, Any]: + params = {} + for input_param in self.inputs: + params[input_param.name] = input_param.default + return params + + def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): + """ + Run one or more blocks in sequence, optionally you can pass a previous pipeline state. + """ + if state is None: + state = PipelineState() + + if not hasattr(self, "loader"): + logger.warning("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.") + self.loader = None + + # Make a copy of the input kwargs + passed_kwargs = kwargs.copy() + + + # Add inputs to state, using defaults if not provided in the kwargs or the state + # if same input already in the state, will override it if provided in the kwargs + + intermediates_inputs = [inp.name for inp in self.intermediates_inputs] + for expected_input_param in self.inputs: + name = expected_input_param.name + default = expected_input_param.default + kwargs_type = expected_input_param.kwargs_type + if name in passed_kwargs: + if name not in intermediates_inputs: + state.add_input(name, passed_kwargs.pop(name), kwargs_type) + else: + state.add_input(name, passed_kwargs[name], kwargs_type) + elif name not in state.inputs: + state.add_input(name, default, kwargs_type) + + for expected_intermediate_param in self.intermediates_inputs: + name = expected_intermediate_param.name + kwargs_type = expected_intermediate_param.kwargs_type + if name in passed_kwargs: + state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type) + + # Warn about unexpected inputs + if len(passed_kwargs) > 0: + logger.warning(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") + # Run the pipeline + with torch.no_grad(): + try: + pipeline, state = self(self.loader, state) + except Exception: + error_msg = f"Error in block: ({self.__class__.__name__}):\n" + logger.error(error_msg) + raise + + if output is None: + return state + + + elif isinstance(output, str): + return state.get_intermediate(output) + + elif isinstance(output, (list, tuple)): + return state.get_intermediates(output) + else: + raise ValueError(f"Output '{output}' is not a valid output type") + + @torch.compiler.disable + def progress_bar(self, iterable=None, total=None): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + if iterable is not None: + return tqdm(iterable, **self._progress_bar_config) + elif total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs + + +class PipelineBlock(ModularPipelineMixin): + + model_name = None + + @property + def description(self) -> str: + """Description of the block. Must be implemented by subclasses.""" + raise NotImplementedError("description method must be implemented in subclasses") + + @property + def expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [] + + + # YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable + @property + def inputs(self) -> List[InputParam]: + """List of input parameters. Must be implemented by subclasses.""" + return [] + + @property + def intermediates_inputs(self) -> List[InputParam]: + """List of intermediate input parameters. Must be implemented by subclasses.""" + return [] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + """List of intermediate output parameters. Must be implemented by subclasses.""" + return [] + + # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks + @property + def outputs(self) -> List[OutputParam]: + return self.intermediates_outputs + + @property + def required_inputs(self) -> List[str]: + input_names = [] + for input_param in self.inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + @property + def required_intermediates_inputs(self) -> List[str]: + input_names = [] + for input_param in self.intermediates_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + raise NotImplementedError("__call__ method must be implemented in subclasses") + + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + + # Components section - use format_components with add_empty_lines=False + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + components = " " + components_str.replace("\n", "\n ") + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + configs = " " + configs_str.replace("\n", "\n ") + + # Inputs section + inputs_str = format_inputs_short(self.inputs) + inputs = "Inputs:\n " + inputs_str + + # Intermediates section + intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) + intermediates = f"Intermediates:\n{intermediates_str}" + + return ( + f"{class_name}(\n" + f" Class: {base_class}\n" + f"{desc}" + f"{components}\n" + f"{configs}\n" + f" {inputs}\n" + f" {intermediates}\n" + f")" + ) + + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) + + + def get_block_state(self, state: PipelineState) -> dict: + """Get all inputs and intermediates in one dictionary""" + data = {} + + # Check inputs + for input_param in self.inputs: + if input_param.name: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v + + # Check intermediates + for input_param in self.intermediates_inputs: + if input_param.name: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all intermediates with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + if intermediates_kwargs: + for k, v in intermediates_kwargs.items(): + if v is not None: + if k not in data: + data[k] = v + data[input_param.kwargs_type][k] = v + return BlockState(**data) + + def add_block_state(self, state: PipelineState, block_state: BlockState): + for output_param in self.intermediates_outputs: + if not hasattr(block_state, output_param.name): + raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") + param = getattr(block_state, output_param.name) + state.add_intermediate(output_param.name, param, output_param.kwargs_type) + + +def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: + """ + Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if + current default value is None and new default value is not None. Warns if multiple non-None default values + exist for the same input. + + Args: + named_input_lists: List of tuples containing (block_name, input_param_list) pairs + + Returns: + List[InputParam]: Combined list of unique InputParam objects + """ + combined_dict = {} # name -> InputParam + value_sources = {} # name -> block_name + + for block_name, inputs in named_input_lists: + for input_param in inputs: + if input_param.name is None and input_param.kwargs_type is not None: + input_name = "*_" + input_param.kwargs_type + else: + input_name = input_param.name + if input_name in combined_dict: + current_param = combined_dict[input_name] + if (current_param.default is not None and + input_param.default is not None and + current_param.default != input_param.default): + warnings.warn( + f"Multiple different default values found for input '{input_param.name}': " + f"{current_param.default} (from block '{value_sources[input_param.name]}') and " + f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." + ) + if current_param.default is None and input_param.default is not None: + combined_dict[input_param.name] = input_param + value_sources[input_param.name] = block_name + else: + combined_dict[input_param.name] = input_param + value_sources[input_param.name] = block_name + + return list(combined_dict.values()) + +def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: + """ + Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, + keeps the first occurrence of each output name. + + Args: + named_output_lists: List of tuples containing (block_name, output_param_list) pairs + + Returns: + List[OutputParam]: Combined list of unique OutputParam objects + """ + combined_dict = {} # name -> OutputParam + + for block_name, outputs in named_output_lists: + for output_param in outputs: + if (output_param.name not in combined_dict) or (combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None): + combined_dict[output_param.name] = output_param + + return list(combined_dict.values()) + + +class AutoPipelineBlocks(ModularPipelineMixin): + """ + A class that automatically selects a block to run based on the inputs. + + Attributes: + block_classes: List of block classes to be used + block_names: List of prefixes for each block + block_trigger_inputs: List of input names that trigger specific blocks, with None for default + """ + + block_classes = [] + block_names = [] + block_trigger_inputs = [] + + def __init__(self): + blocks = OrderedDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + blocks[block_name] = block_cls() + self.blocks = blocks + if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): + raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.") + default_blocks = [t for t in self.block_trigger_inputs if t is None] + # can only have 1 or 0 default block, and has to put in the last + # the order of blocksmatters here because the first block with matching trigger will be dispatched + # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"] + # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img + if len(default_blocks) > 1 or ( + len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None + ): + raise ValueError( + f"In {self.__class__.__name__}, exactly one None must be specified as the last element " + "in block_trigger_inputs." + ) + + # Map trigger inputs to block objects + self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values())) + self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.blocks.keys())) + self.block_to_trigger_map = dict(zip(self.blocks.keys(), self.block_trigger_inputs)) + + @property + def model_name(self): + return next(iter(self.blocks.values())).model_name + + @property + def description(self): + return "" + + @property + def expected_components(self): + expected_components = [] + for block in self.blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + @property + def expected_configs(self): + expected_configs = [] + for block in self.blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + + @property + def required_inputs(self) -> List[str]: + first_block = next(iter(self.blocks.values())) + required_by_all = set(getattr(first_block, "required_inputs", set())) + + # Intersect with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_all.intersection_update(block_required) + + return list(required_by_all) + + @property + def required_intermediates_inputs(self) -> List[str]: + first_block = next(iter(self.blocks.values())) + required_by_all = set(getattr(first_block, "required_intermediates_inputs", set())) + + # Intersect with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_intermediates_inputs", set())) + required_by_all.intersection_update(block_required) + + return list(required_by_all) + + + # YiYi TODO: add test for this + @property + def inputs(self) -> List[Tuple[str, Any]]: + named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required by all the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + + @property + def intermediates_inputs(self) -> List[str]: + named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()] + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required by all the blocks + for input_param in combined_inputs: + if input_param.name in self.required_intermediates_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + @property + def intermediates_outputs(self) -> List[str]: + named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + return combined_outputs + + @property + def outputs(self) -> List[str]: + named_outputs = [(name, block.outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + return combined_outputs + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + # Find default block first (if any) + + block = self.trigger_to_block_map.get(None) + for input_name in self.block_trigger_inputs: + if input_name is not None and state.get_input(input_name) is not None: + block = self.trigger_to_block_map[input_name] + break + elif input_name is not None and state.get_intermediate(input_name) is not None: + block = self.trigger_to_block_map[input_name] + break + + if block is None: + logger.warning(f"skipping auto block: {self.__class__.__name__}") + return pipeline, state + + try: + logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}") + return block(pipeline, state) + except Exception as e: + error_msg = ( + f"\nError in block: {block.__class__.__name__}\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + + def _get_trigger_inputs(self): + """ + Returns a set of all unique trigger input values found in the blocks. + Returns: Set[str] containing all unique block_trigger_inputs values + """ + def fn_recursive_get_trigger(blocks): + trigger_values = set() + + if blocks is not None: + for name, block in blocks.items(): + # Check if current block has trigger inputs(i.e. auto block) + if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: + # Add all non-None values from the trigger inputs list + trigger_values.update(t for t in block.block_trigger_inputs if t is not None) + + # If block has blocks, recursively check them + if hasattr(block, 'blocks'): + nested_triggers = fn_recursive_get_trigger(block.blocks) + trigger_values.update(nested_triggers) + + return trigger_values + + trigger_inputs = set(self.block_trigger_inputs) + trigger_inputs.update(fn_recursive_get_trigger(self.blocks)) + + return trigger_inputs + + @property + def trigger_inputs(self): + return self._get_trigger_inputs() + + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" + if base_class and base_class != "object" + else f"{class_name}(\n" + ) + + + if self.trigger_inputs: + header += "\n" + header += " " + "=" * 100 + "\n" + header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" + header += f" Trigger Inputs: {self.trigger_inputs}\n" + # Get first trigger input as example + example_input = next(t for t in self.trigger_inputs if t is not None) + header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" + header += " " + "=" * 100 + "\n\n" + + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + + # Components section - focus only on expected components + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + + # Blocks section - moved to the end with simplified format + blocks_str = " Blocks:\n" + for i, (name, block) in enumerate(self.blocks.items()): + # Get trigger input for this block + trigger = None + if hasattr(self, 'block_to_trigger_map'): + trigger = self.block_to_trigger_map.get(name) + # Format the trigger info + if trigger is None: + trigger_str = "[default]" + elif isinstance(trigger, (list, tuple)): + trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" + else: + trigger_str = f"[trigger: {trigger}]" + # For AutoPipelineBlocks, add bullet points + blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" + else: + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + + # Add block description + desc_lines = block.description.split('\n') + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + blocks_str += f" Description: {indented_desc}\n\n" + + return ( + f"{header}\n" + f"{desc}\n\n" + f"{components_str}\n\n" + f"{configs_str}\n\n" + f"{blocks_str}" + f")" + ) + + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) + +class SequentialPipelineBlocks(ModularPipelineMixin): + """ + A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. + """ + block_classes = [] + block_names = [] + + @property + def model_name(self): + return next(iter(self.blocks.values())).model_name + + @property + def description(self): + return "" + + @property + def expected_components(self): + expected_components = [] + for block in self.blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + @property + def expected_configs(self): + expected_configs = [] + for block in self.blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + @classmethod + def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks": + """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. + + Args: + blocks_dict: Dictionary mapping block names to block instances + + Returns: + A new SequentialPipelineBlocks instance + """ + instance = cls() + instance.block_classes = [block.__class__ for block in blocks_dict.values()] + instance.block_names = list(blocks_dict.keys()) + instance.blocks = blocks_dict + return instance + + def __init__(self): + blocks = OrderedDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + blocks[block_name] = block_cls() + self.blocks = blocks + + + @property + def required_inputs(self) -> List[str]: + # Get the first block from the dictionary + first_block = next(iter(self.blocks.values())) + required_by_any = set(getattr(first_block, "required_inputs", set())) + + # Union with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_any.update(block_required) + + return list(required_by_any) + + @property + def required_intermediates_inputs(self) -> List[str]: + required_intermediates_inputs = [] + for input_param in self.intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + return required_intermediates_inputs + + # YiYi TODO: add test for this + @property + def inputs(self) -> List[Tuple[str, Any]]: + return self.get_inputs() + + def get_inputs(self): + named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required any of the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + @property + def intermediates_inputs(self) -> List[str]: + return self.get_intermediates_inputs() + + def get_intermediates_inputs(self): + inputs = [] + outputs = set() + + # Go through all blocks in order + for block in self.blocks.values(): + # Add inputs that aren't in outputs yet + inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) + + # Only add outputs if the block cannot be skipped + should_add_outputs = True + if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: + should_add_outputs = False + + if should_add_outputs: + # Add this block's outputs + block_intermediates_outputs = [out.name for out in block.intermediates_outputs] + outputs.update(block_intermediates_outputs) + return inputs + + @property + def intermediates_outputs(self) -> List[str]: + named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + return combined_outputs + + @property + def outputs(self) -> List[str]: + return next(reversed(self.blocks.values())).intermediates_outputs + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + for block_name, block in self.blocks.items(): + try: + pipeline, state = block(pipeline, state) + except Exception as e: + error_msg = ( + f"\nError in block: ({block_name}, {block.__class__.__name__})\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + return pipeline, state + + def _get_trigger_inputs(self): + """ + Returns a set of all unique trigger input values found in the blocks. + Returns: Set[str] containing all unique block_trigger_inputs values + """ + def fn_recursive_get_trigger(blocks): + trigger_values = set() + + if blocks is not None: + for name, block in blocks.items(): + # Check if current block has trigger inputs(i.e. auto block) + if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: + # Add all non-None values from the trigger inputs list + trigger_values.update(t for t in block.block_trigger_inputs if t is not None) + + # If block has blocks, recursively check them + if hasattr(block, 'blocks'): + nested_triggers = fn_recursive_get_trigger(block.blocks) + trigger_values.update(nested_triggers) + + return trigger_values + + return fn_recursive_get_trigger(self.blocks) + + @property + def trigger_inputs(self): + return self._get_trigger_inputs() + + def _traverse_trigger_blocks(self, trigger_inputs): + # Convert trigger_inputs to a set for easier manipulation + active_triggers = set(trigger_inputs) + def fn_recursive_traverse(block, block_name, active_triggers): + result_blocks = OrderedDict() + + # sequential(include loopsequential) or PipelineBlock + if not hasattr(block, 'block_trigger_inputs'): + if hasattr(block, 'blocks'): + # sequential or LoopSequentialPipelineBlocks (keep traversing) + for sub_block_name, sub_block in block.blocks.items(): + blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) + blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) + blocks_to_update = {f"{block_name}.{k}": v for k,v in blocks_to_update.items()} + result_blocks.update(blocks_to_update) + else: + # PipelineBlock + result_blocks[block_name] = block + # Add this block's output names to active triggers if defined + if hasattr(block, 'outputs'): + active_triggers.update(out.name for out in block.outputs) + return result_blocks + + # auto + else: + # Find first block_trigger_input that matches any value in our active_triggers + this_block = None + matching_trigger = None + for trigger_input in block.block_trigger_inputs: + if trigger_input is not None and trigger_input in active_triggers: + this_block = block.trigger_to_block_map[trigger_input] + matching_trigger = trigger_input + break + + # If no matches found, try to get the default (None) block + if this_block is None and None in block.block_trigger_inputs: + this_block = block.trigger_to_block_map[None] + matching_trigger = None + + if this_block is not None: + # sequential/auto (keep traversing) + if hasattr(this_block, 'blocks'): + result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) + else: + # PipelineBlock + result_blocks[block_name] = this_block + # Add this block's output names to active triggers if defined + # YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute? + if hasattr(this_block, 'outputs'): + active_triggers.update(out.name for out in this_block.outputs) + + return result_blocks + + all_blocks = OrderedDict() + for block_name, block in self.blocks.items(): + blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) + all_blocks.update(blocks_to_update) + return all_blocks + + def get_execution_blocks(self, *trigger_inputs): + trigger_inputs_all = self.trigger_inputs + + if trigger_inputs is not None: + + if not isinstance(trigger_inputs, (list, tuple, set)): + trigger_inputs = [trigger_inputs] + invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all] + if invalid_inputs: + logger.warning( + f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}" + ) + trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all] + + if trigger_inputs is None: + if None in trigger_inputs_all: + trigger_inputs = [None] + else: + trigger_inputs = [trigger_inputs_all[0]] + blocks_triggered = self._traverse_trigger_blocks(trigger_inputs) + return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered) + + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" + if base_class and base_class != "object" + else f"{class_name}(\n" + ) + + + if self.trigger_inputs: + header += "\n" + header += " " + "=" * 100 + "\n" + header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" + header += f" Trigger Inputs: {self.trigger_inputs}\n" + # Get first trigger input as example + example_input = next(t for t in self.trigger_inputs if t is not None) + header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" + header += " " + "=" * 100 + "\n\n" + + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + + # Components section - focus only on expected components + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + + # Blocks section - moved to the end with simplified format + blocks_str = " Blocks:\n" + for i, (name, block) in enumerate(self.blocks.items()): + # Get trigger input for this block + trigger = None + if hasattr(self, 'block_to_trigger_map'): + trigger = self.block_to_trigger_map.get(name) + # Format the trigger info + if trigger is None: + trigger_str = "[default]" + elif isinstance(trigger, (list, tuple)): + trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" + else: + trigger_str = f"[trigger: {trigger}]" + # For AutoPipelineBlocks, add bullet points + blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" + else: + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + + # Add block description + desc_lines = block.description.split('\n') + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + blocks_str += f" Description: {indented_desc}\n\n" + + return ( + f"{header}\n" + f"{desc}\n\n" + f"{components_str}\n\n" + f"{configs_str}\n\n" + f"{blocks_str}" + f")" + ) + + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) + +#YiYi TODO: __repr__ +class LoopSequentialPipelineBlocks(ModularPipelineMixin): + """ + A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence. + """ + + model_name = None + block_classes = [] + block_names = [] + + @property + def description(self) -> str: + """Description of the block. Must be implemented by subclasses.""" + raise NotImplementedError("description method must be implemented in subclasses") + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def loop_expected_configs(self) -> List[ConfigSpec]: + return [] + + @property + def loop_inputs(self) -> List[InputParam]: + """List of input parameters. Must be implemented by subclasses.""" + return [] + + @property + def loop_intermediates_inputs(self) -> List[InputParam]: + """List of intermediate input parameters. Must be implemented by subclasses.""" + return [] + + @property + def loop_intermediates_outputs(self) -> List[OutputParam]: + """List of intermediate output parameters. Must be implemented by subclasses.""" + return [] + + + @property + def loop_required_inputs(self) -> List[str]: + input_names = [] + for input_param in self.loop_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + @property + def loop_required_intermediates_inputs(self) -> List[str]: + input_names = [] + for input_param in self.loop_intermediates_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + # modified from SequentialPipelineBlocks to include loop_expected_components + @property + def expected_components(self): + expected_components = [] + for block in self.blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + for component in self.loop_expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + # modified from SequentialPipelineBlocks to include loop_expected_configs + @property + def expected_configs(self): + expected_configs = [] + for block in self.blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + for config in self.loop_expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + # modified from SequentialPipelineBlocks to include loop_inputs + def get_inputs(self): + named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + named_inputs.append(("loop", self.loop_inputs)) + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required any of the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + # Copied from SequentialPipelineBlocks + @property + def inputs(self): + return self.get_inputs() + + + # modified from SequentialPipelineBlocks to include loop_intermediates_inputs + @property + def intermediates_inputs(self): + intermediates = self.get_intermediates_inputs() + intermediate_names = [input.name for input in intermediates] + for loop_intermediate_input in self.loop_intermediates_inputs: + if loop_intermediate_input.name not in intermediate_names: + intermediates.append(loop_intermediate_input) + return intermediates + + + # Copied from SequentialPipelineBlocks + def get_intermediates_inputs(self): + inputs = [] + outputs = set() + + # Go through all blocks in order + for block in self.blocks.values(): + # Add inputs that aren't in outputs yet + inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) + + # Only add outputs if the block cannot be skipped + should_add_outputs = True + if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: + should_add_outputs = False + + if should_add_outputs: + # Add this block's outputs + block_intermediates_outputs = [out.name for out in block.intermediates_outputs] + outputs.update(block_intermediates_outputs) + return inputs + + + # modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block + @property + def required_inputs(self) -> List[str]: + # Get the first block from the dictionary + first_block = next(iter(self.blocks.values())) + required_by_any = set(getattr(first_block, "required_inputs", set())) + + required_by_loop = set(getattr(self, "loop_required_inputs", set())) + required_by_any.update(required_by_loop) + + # Union with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_any.update(block_required) + + return list(required_by_any) + + # modified from SequentialPipelineBlocks, if any additional intermediate input required by the loop is required by the block + @property + def required_intermediates_inputs(self) -> List[str]: + required_intermediates_inputs = [] + for input_param in self.intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + for input_param in self.loop_intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + return required_intermediates_inputs + + + # YiYi TODO: this need to be thought about more + # modified from SequentialPipelineBlocks to include loop_intermediates_outputs + @property + def intermediates_outputs(self) -> List[str]: + named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + for output in self.loop_intermediates_outputs: + if output.name not in set([output.name for output in combined_outputs]): + combined_outputs.append(output) + return combined_outputs + + # YiYi TODO: this need to be thought about more + # copied from SequentialPipelineBlocks + @property + def outputs(self) -> List[str]: + return next(reversed(self.blocks.values())).intermediates_outputs + + + def __init__(self): + blocks = OrderedDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + blocks[block_name] = block_cls() + self.blocks = blocks + + def loop_step(self, components, state: PipelineState, **kwargs): + + for block_name, block in self.blocks.items(): + try: + components, state = block(components, state, **kwargs) + except Exception as e: + error_msg = ( + f"\nError in block: ({block_name}, {block.__class__.__name__})\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + return components, state + + def __call__(self, components, state: PipelineState) -> PipelineState: + raise NotImplementedError("`__call__` method needs to be implemented by the subclass") + + + def get_block_state(self, state: PipelineState) -> dict: + """Get all inputs and intermediates in one dictionary""" + data = {} + + # Check inputs + for input_param in self.inputs: + if input_param.name: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v + + # Check intermediates + for input_param in self.intermediates_inputs: + if input_param.name: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all intermediates with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + if intermediates_kwargs: + for k, v in intermediates_kwargs.items(): + if v is not None: + if k not in data: + data[k] = v + data[input_param.kwargs_type][k] = v + return BlockState(**data) + + def add_block_state(self, state: PipelineState, block_state: BlockState): + for output_param in self.intermediates_outputs: + if not hasattr(block_state, output_param.name): + raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") + param = getattr(block_state, output_param.name) + state.add_intermediate(output_param.name, param, output_param.kwargs_type) + +# YiYi TODO: +# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) +# 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader +# 3. add validator for methods where we accpet kwargs to be passed to from_pretrained() +class ModularLoader(ConfigMixin, PushToHubMixin): + """ + Base class for all Modular pipelines loaders. + + """ + config_name = "modular_model_index.json" + + + def register_components(self, **kwargs): + """ + Register components with their corresponding specs. + This method is called when component changed or __init__ is called. + + Args: + **kwargs: Keyword arguments where keys are component names and values are component objects. + + """ + for name, module in kwargs.items(): + + # current component spec + component_spec = self._component_specs.get(name) + if component_spec is None: + logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") + continue + + is_registered = hasattr(self, name) + + if module is not None and not hasattr(module, "_diffusers_load_id"): + raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + + # actual library and class name of the module + + if module is not None: + library, class_name = _fetch_class_library_tuple(module) + new_component_spec = ComponentSpec.from_component(name, module) + component_spec_dict = self._component_spec_to_dict(new_component_spec) + + else: + library, class_name = None, None + # if module is None, we do not update the spec, + # but we still need to update the config to make sure it's synced with the component spec + # (in the case of the first time registration, we initilize the object with component spec, and then we call register_components() to register it to config) + new_component_spec = component_spec + component_spec_dict = self._component_spec_to_dict(component_spec) + + # do not register if component is not to be loaded from pretrained + if new_component_spec.default_creation_method == "from_pretrained": + register_dict = {name: (library, class_name, component_spec_dict)} + else: + register_dict = {} + + # set the component as attribute + # if it is not set yet, just set it and skip the process to check and warn below + if not is_registered: + self.register_to_config(**register_dict) + self._component_specs[name] = new_component_spec + setattr(self, name, module) + if module is not None and self._component_manager is not None: + self._component_manager.add(name, module, self._collection) + continue + + current_module = getattr(self, name, None) + # skip if the component is already registered with the same object + if current_module is module: + logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") + continue + + # it module is not an instance of the expected type, still register it but with a warning + if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint): + logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}") + + # warn if unregister + if current_module is not None and module is None: + logger.info( + f"ModularLoader.register_components: setting '{name}' to None " + f"(was {current_module.__class__.__name__})" + ) + # same type, new instance → debug + elif current_module is not None \ + and module is not None \ + and isinstance(module, current_module.__class__) \ + and current_module != module: + logger.debug( + f"ModularLoader.register_components: replacing existing '{name}' " + f"(same type {type(current_module).__name__}, new instance)" + ) + + # save modular_model_index.json config + self.register_to_config(**register_dict) + # update component spec + self._component_specs[name] = new_component_spec + # finally set models + setattr(self, name, module) + if module is not None and self._component_manager is not None: + self._component_manager.add(name, module, self._collection) + + + + # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name + def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): + """ + Initialize the loader with a list of component specs and config specs. + """ + self._component_manager = component_manager + self._collection = collection + self._component_specs = { + spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec) + } + self._config_specs = { + spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec) + } + + # update component_specs and config_specs from modular_repo + if modular_repo is not None: + config_dict = self.load_config(modular_repo, **kwargs) + + for name, value in config_dict.items(): + if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3: + library, class_name, component_spec_dict = value + component_spec = self._dict_to_component_spec(name, component_spec_dict) + self._component_specs[name] = component_spec + + elif name in self._config_specs: + self._config_specs[name].default = value + + register_components_dict = {} + for name, component_spec in self._component_specs.items(): + register_components_dict[name] = None + self.register_components(**register_components_dict) + + default_configs = {} + for name, config_spec in self._config_specs.items(): + default_configs[name] = config_spec.default + self.register_to_config(**default_configs) + + + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + modules = self.components.values() + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.device + + return torch.device("cpu") + + @property + # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from + Accelerate's module hooks. + """ + for name, model in self.components.items(): + if not isinstance(model, torch.nn.Module): + continue + + if not hasattr(model, "_hf_hook"): + return self.device + for module in model.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + + modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.device + + return torch.device("cpu") + + @property + def dtype(self) -> torch.dtype: + r""" + Returns: + `torch.dtype`: The torch dtype on which the pipeline is located. + """ + modules = self.components.values() + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.dtype + + return torch.float32 + + + @property + def components(self) -> Dict[str, Any]: + # return only components we've actually set as attributes on self + return { + name: getattr(self, name) + for name in self._component_specs.keys() + if hasattr(self, name) + } + + def update(self, **kwargs): + """ + Update components and configs after instance creation. + + Args: + + """ + """ + Update components and configuration values after the loader has been instantiated. + + This method allows you to: + 1. Replace existing components with new ones (e.g., updating the unet or text_encoder) + 2. Update configuration values (e.g., changing requires_safety_checker flag) + + Args: + **kwargs: Component objects or configuration values to update: + - Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, text_encoder=new_encoder`) + - Configuration values: Simple values to update configuration settings (e.g., `requires_safety_checker=False`) + + Raises: + ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute) + + Examples: + ```python + # Update multiple components at once + loader.update( + unet=new_unet_model, + text_encoder=new_text_encoder + ) + + # Update configuration values + loader.update( + requires_safety_checker=False, + guidance_rescale=0.7 + ) + + # Update both components and configs together + loader.update( + unet=new_unet_model, + requires_safety_checker=False + ) + ``` + """ + + # extract component_specs_updates & config_specs_updates from `specs` + passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs} + passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs} + + for name, component in passed_components.items(): + if not hasattr(component, "_diffusers_load_id"): + raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + + if len(kwargs) > 0: + logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") + + + self.register_components(**passed_components) + + + config_to_register = {} + for name, new_value in passed_config_values.items(): + + # e.g. requires_aesthetics_score = False + self._config_specs[name].default = new_value + config_to_register[name] = new_value + self.register_to_config(**config_to_register) + + + # YiYi TODO: support map for additional from_pretrained kwargs + def load(self, component_names: Optional[List[str]] = None, **kwargs): + """ + Load selectedcomponents from specs. + + Args: + component_names: List of component names to load + **kwargs: additional kwargs to be passed to `from_pretrained()`.Can be: + - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16 + - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32} + - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`, `variant`, `revision`, etc. + """ + if component_names is None: + component_names = list(self._component_specs.keys()) + elif not isinstance(component_names, list): + component_names = [component_names] + + components_to_load = set([name for name in component_names if name in self._component_specs]) + unknown_component_names = set([name for name in component_names if name not in self._component_specs]) + if len(unknown_component_names) > 0: + logger.warning(f"Unknown components will be ignored: {unknown_component_names}") + + components_to_register = {} + for name in components_to_load: + spec = self._component_specs[name] + component_load_kwargs = {} + for key, value in kwargs.items(): + if not isinstance(value, dict): + # if the value is a single value, apply it to all components + component_load_kwargs[key] = value + else: + if name in value: + # if it is a dict, check if the component name is in the dict + component_load_kwargs[key] = value[name] + elif "default" in value: + # check if the default is specified + component_load_kwargs[key] = value["default"] + try: + components_to_register[name] = spec.create(**component_load_kwargs) + except Exception as e: + logger.warning(f"Failed to create component '{name}': {e}") + + # Register all components at once + self.register_components(**components_to_register) + + # YiYi TODO: should support to method + def to(self, *args, **kwargs): + pass + + # YiYi TODO: + # 1. should support save some components too! currently only modular_model_index.json is saved + # 2. maybe order the json file to make it more readable: configs first, then components + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs): + + component_names = list(self._component_specs.keys()) + config_names = list(self._config_specs.keys()) + self.register_to_config(_components_names=component_names, _configs_names=config_names) + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + config = dict(self.config) + config.pop("_components_names", None) + config.pop("_configs_names", None) + self._internal_dict = FrozenDict(config) + + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs): + + config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) + expected_component = set(config_dict.pop("_components_names")) + expected_config = set(config_dict.pop("_configs_names")) + + component_specs = [] + config_specs = [] + for name, value in config_dict.items(): + if name in expected_component and isinstance(value, (tuple, list)) and len(value) == 3: + library, class_name, component_spec_dict = value + component_spec = cls._dict_to_component_spec(name, component_spec_dict) + component_specs.append(component_spec) + + elif name in expected_config: + config_specs.append(ConfigSpec(name=name, default=value)) + + for name in expected_component: + for spec in component_specs: + if spec.name == name: + break + else: + # append a empty component spec for these not in modular_model_index + component_specs.append(ComponentSpec(name=name, default_creation_method="from_config")) + return cls(component_specs + config_specs) + + + @staticmethod + def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: + """ + Convert a ComponentSpec into a JSON‐serializable dict for saving in + `modular_model_index.json`. + + This dict contains: + - "type_hint": Tuple[str, str] + The fully‐qualified module path and class name of the component. + - All loading fields defined by `component_spec.loading_fields()`, typically: + - "repo": Optional[str] + The model repository (e.g., "stabilityai/stable-diffusion-xl"). + - "subfolder": Optional[str] + A subfolder within the repo where this component lives. + - "variant": Optional[str] + An optional variant identifier for the model. + - "revision": Optional[str] + A specific git revision (commit hash, tag, or branch). + - ... any other loading fields defined on the spec. + + Args: + component_spec (ComponentSpec): + The spec object describing one pipeline component. + + Returns: + Dict[str, Any]: A mapping suitable for JSON serialization. + + Example: + >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec + >>> from diffusers.models.unet import UNet2DConditionModel + >>> spec = ComponentSpec( + ... name="unet", + ... type_hint=UNet2DConditionModel, + ... config=None, + ... repo="path/to/repo", + ... subfolder="subfolder", + ... variant=None, + ... revision=None, + ... default_creation_method="from_pretrained", + ... ) + >>> ModularLoader._component_spec_to_dict(spec) + { + "type_hint": ("diffusers.models.unet", "UNet2DConditionModel"), + "repo": "path/to/repo", + "subfolder": "subfolder", + "variant": None, + "revision": None, + } + """ + if component_spec.type_hint is not None: + lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint) + else: + lib_name = None + cls_name = None + load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} + return { + "type_hint": (lib_name, cls_name), + **load_spec_dict, + } + + @staticmethod + def _dict_to_component_spec( + name: str, + spec_dict: Dict[str, Any], + ) -> ComponentSpec: + """ + Reconstruct a ComponentSpec from a dict. + """ + # make a shallow copy so we can pop() safely + spec_dict = spec_dict.copy() + # pull out and resolve the stored type_hint + lib_name, cls_name = spec_dict.pop("type_hint") + if lib_name is not None and cls_name is not None: + type_hint = simple_get_class_obj(lib_name, cls_name) + else: + type_hint = None + + # re‐assemble the ComponentSpec + return ComponentSpec( + name=name, + type_hint=type_hint, + **spec_dict, + ) \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py new file mode 100644 index 000000000000..392d6dcd9521 --- /dev/null +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -0,0 +1,598 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import inspect +from dataclasses import dataclass, asdict, field, fields +from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal + +from ..utils.import_utils import is_torch_available +from ..configuration_utils import FrozenDict, ConfigMixin + +if is_torch_available(): + import torch + + +# YiYi TODO: +# 1. validate the dataclass fields +# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained() +@dataclass +class ComponentSpec: + """Specification for a pipeline component. + + A component can be created in two ways: + 1. From scratch using __init__ with a config dict + 2. using `from_pretrained` + + Attributes: + name: Name of the component + type_hint: Type of the component (e.g. UNet2DConditionModel) + description: Optional description of the component + config: Optional config dict for __init__ creation + repo: Optional repo path for from_pretrained creation + subfolder: Optional subfolder in repo + variant: Optional variant in repo + revision: Optional revision in repo + default_creation_method: Preferred creation method - "from_config" or "from_pretrained" + """ + name: Optional[str] = None + type_hint: Optional[Type] = None + description: Optional[str] = None + config: Optional[FrozenDict[str, Any]] = None + # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name + repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True}) + subfolder: Optional[str] = field(default=None, metadata={"loading": True}) + variant: Optional[str] = field(default=None, metadata={"loading": True}) + revision: Optional[str] = field(default=None, metadata={"loading": True}) + default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" + + + def __hash__(self): + """Make ComponentSpec hashable, using load_id as the hash value.""" + return hash((self.name, self.load_id, self.default_creation_method)) + + def __eq__(self, other): + """Compare ComponentSpec objects based on name and load_id.""" + if not isinstance(other, ComponentSpec): + return False + return (self.name == other.name and + self.load_id == other.load_id and + self.default_creation_method == other.default_creation_method) + + @classmethod + def from_component(cls, name: str, component: torch.nn.Module) -> Any: + """Create a ComponentSpec from a Component created by `create` method.""" + + if not hasattr(component, "_diffusers_load_id"): + raise ValueError("Component is not created by `create` method") + + type_hint = component.__class__ + + if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin): + config = component.config + else: + config = None + + load_spec = cls.decode_load_id(component._diffusers_load_id) + + return cls(name=name, type_hint=type_hint, config=config, **load_spec) + + @classmethod + def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any: + """Create a ComponentSpec from a load_id string.""" + if load_id == "null": + raise ValueError("Cannot create ComponentSpec from null load_id") + + # Decode the load_id into a dictionary of loading fields + load_fields = cls.decode_load_id(load_id) + + # Create a new ComponentSpec instance with the decoded fields + return cls(name=name, **load_fields) + + @classmethod + def loading_fields(cls) -> List[str]: + """ + Return the names of all loading‐related fields + (i.e. those whose field.metadata["loading"] is True). + """ + return [f.name for f in fields(cls) if f.metadata.get("loading", False)] + + + @property + def load_id(self) -> str: + """ + Unique identifier for this spec's pretrained load, + composed of repo|subfolder|variant|revision (no empty segments). + """ + parts = [getattr(self, k) for k in self.loading_fields()] + parts = ["null" if p is None else p for p in parts] + return "|".join(p for p in parts if p) + + @classmethod + def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: + """ + Decode a load_id string back into a dictionary of loading fields and values. + + Args: + load_id: The load_id string to decode, format: "repo|subfolder|variant|revision" + where None values are represented as "null" + + Returns: + Dict mapping loading field names to their values. e.g. + { + "repo": "path/to/repo", + "subfolder": "subfolder", + "variant": "variant", + "revision": "revision" + } + If a segment value is "null", it's replaced with None. + Returns None if load_id is "null" (indicating component not loaded from pretrained). + """ + + # Get all loading fields in order + loading_fields = cls.loading_fields() + result = {f: None for f in loading_fields} + + if load_id == "null": + return result + + # Split the load_id + parts = load_id.split("|") + + # Map parts to loading fields by position + for i, part in enumerate(parts): + if i < len(loading_fields): + # Convert "null" string back to None + result[loading_fields[i]] = None if part == "null" else part + + return result + + # YiYi TODO: add validator + def create(self, **kwargs) -> Any: + """Create the component using the preferred creation method.""" + + # from_pretrained creation + if self.default_creation_method == "from_pretrained": + return self.create_from_pretrained(**kwargs) + elif self.default_creation_method == "from_config": + # from_config creation + return self.create_from_config(**kwargs) + else: + raise ValueError(f"Invalid creation method: {self.default_creation_method}") + + def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: + """Create component using from_config with config.""" + + if self.type_hint is None or not isinstance(self.type_hint, type): + raise ValueError( + f"`type_hint` is required when using from_config creation method." + ) + + config = config or self.config or {} + + if issubclass(self.type_hint, ConfigMixin): + component = self.type_hint.from_config(config, **kwargs) + else: + signature_params = inspect.signature(self.type_hint.__init__).parameters + init_kwargs = {} + for k, v in config.items(): + if k in signature_params: + init_kwargs[k] = v + for k, v in kwargs.items(): + if k in signature_params: + init_kwargs[k] = v + component = self.type_hint(**init_kwargs) + + component._diffusers_load_id = "null" + if hasattr(component, "config"): + self.config = component.config + + return component + + # YiYi TODO: add guard for type of model, if it is supported by from_pretrained + def create_from_pretrained(self, **kwargs) -> Any: + """Create component using from_pretrained.""" + + passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} + load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()} + # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path + repo = load_kwargs.pop("repo", None) + if repo is None: + raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") + + if self.type_hint is None: + try: + from diffusers import AutoModel + component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs) + except Exception as e: + raise ValueError(f"Error creating {self.name} without `type_hint` from pretrained: {e}") + self.type_hint = component.__class__ + else: + try: + component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) + except Exception as e: + raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}") + + if repo != self.repo: + self.repo = repo + for k, v in passed_loading_kwargs.items(): + if v is not None: + setattr(self, k, v) + component._diffusers_load_id = self.load_id + + return component + + + +@dataclass +class ConfigSpec: + """Specification for a pipeline configuration parameter.""" + name: str + default: Any + description: Optional[str] = None +@dataclass +class InputParam: + """Specification for an input parameter.""" + name: str = None + type_hint: Any = None + default: Any = None + required: bool = False + description: str = "" + kwargs_type: str = None + + def __repr__(self): + return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" + + +@dataclass +class OutputParam: + """Specification for an output parameter.""" + name: str + type_hint: Any = None + description: str = "" + kwargs_type: str = None + + def __repr__(self): + return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" + + +def format_inputs_short(inputs): + """ + Format input parameters into a string representation, with required params first followed by optional ones. + + Args: + inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params + + Returns: + str: Formatted string of input parameters + + Example: + >>> inputs = [ + ... InputParam(name="prompt", required=True), + ... InputParam(name="image", required=True), + ... InputParam(name="guidance_scale", required=False, default=7.5), + ... InputParam(name="num_inference_steps", required=False, default=50) + ... ] + >>> format_inputs_short(inputs) + 'prompt, image, guidance_scale=7.5, num_inference_steps=50' + """ + required_inputs = [param for param in inputs if param.required] + optional_inputs = [param for param in inputs if not param.required] + + required_str = ", ".join(param.name for param in required_inputs) + optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) + + inputs_str = required_str + if optional_str: + inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str + + return inputs_str + + +def format_intermediates_short(intermediates_inputs, required_intermediates_inputs, intermediates_outputs): + """ + Formats intermediate inputs and outputs of a block into a string representation. + + Args: + intermediates_inputs: List of intermediate input parameters + required_intermediates_inputs: List of required intermediate input names + intermediates_outputs: List of intermediate output parameters + + Returns: + str: Formatted string like: + Intermediates: + - inputs: Required(latents), dtype + - modified: latents # variables that appear in both inputs and outputs + - outputs: images # new outputs only + """ + # Handle inputs + input_parts = [] + for inp in intermediates_inputs: + if inp.name in required_intermediates_inputs: + input_parts.append(f"Required({inp.name})") + else: + if inp.name is None and inp.kwargs_type is not None: + inp_name = "*_" + inp.kwargs_type + else: + inp_name = inp.name + input_parts.append(inp_name) + + # Handle modified variables (appear in both inputs and outputs) + inputs_set = {inp.name for inp in intermediates_inputs} + modified_parts = [] + new_output_parts = [] + + for out in intermediates_outputs: + if out.name in inputs_set: + modified_parts.append(out.name) + else: + new_output_parts.append(out.name) + + result = [] + if input_parts: + result.append(f" - inputs: {', '.join(input_parts)}") + if modified_parts: + result.append(f" - modified: {', '.join(modified_parts)}") + if new_output_parts: + result.append(f" - outputs: {', '.join(new_output_parts)}") + + return "\n".join(result) if result else " (none)" + + +def format_params(params, header="Args", indent_level=4, max_line_length=115): + """Format a list of InputParam or OutputParam objects into a readable string representation. + + Args: + params: List of InputParam or OutputParam objects to format + header: Header text to use (e.g. "Args" or "Returns") + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all parameters + """ + if not params: + return "" + + base_indent = " " * indent_level + param_indent = " " * (indent_level + 4) + desc_indent = " " * (indent_level + 8) + formatted_params = [] + + def get_type_str(type_hint): + if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: + types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] + return f"Union[{', '.join(types)}]" + return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) + + def wrap_text(text, indent, max_length): + """Wrap text while preserving markdown links and maintaining indentation.""" + words = text.split() + lines = [] + current_line = [] + current_length = 0 + + for word in words: + word_length = len(word) + (1 if current_line else 0) + + if current_line and current_length + word_length > max_length: + lines.append(" ".join(current_line)) + current_line = [word] + current_length = len(word) + else: + current_line.append(word) + current_length += word_length + + if current_line: + lines.append(" ".join(current_line)) + + return f"\n{indent}".join(lines) + + # Add the header + formatted_params.append(f"{base_indent}{header}:") + + for param in params: + # Format parameter name and type + type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" + param_str = f"{param_indent}{param.name} (`{type_str}`" + + # Add optional tag and default value if parameter is an InputParam and optional + if hasattr(param, "required"): + if not param.required: + param_str += ", *optional*" + if param.default is not None: + param_str += f", defaults to {param.default}" + param_str += "):" + + # Add description on a new line with additional indentation and wrapping + if param.description: + desc = re.sub( + r'\[(.*?)\]\((https?://[^\s\)]+)\)', + r'[\1](\2)', + param.description + ) + wrapped_desc = wrap_text(desc, desc_indent, max_line_length) + param_str += f"\n{desc_indent}{wrapped_desc}" + + formatted_params.append(param_str) + + return "\n\n".join(formatted_params) + + +def format_input_params(input_params, indent_level=4, max_line_length=115): + """Format a list of InputParam objects into a readable string representation. + + Args: + input_params: List of InputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all input parameters + """ + return format_params(input_params, "Inputs", indent_level, max_line_length) + + +def format_output_params(output_params, indent_level=4, max_line_length=115): + """Format a list of OutputParam objects into a readable string representation. + + Args: + output_params: List of OutputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all output parameters + """ + return format_params(output_params, "Outputs", indent_level, max_line_length) + + +def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True): + """Format a list of ComponentSpec objects into a readable string representation. + + Args: + components: List of ComponentSpec objects to format + indent_level: Number of spaces to indent each component line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between components (default: True) + + Returns: + A formatted string representing all components + """ + if not components: + return "" + + base_indent = " " * indent_level + component_indent = " " * (indent_level + 4) + formatted_components = [] + + # Add the header + formatted_components.append(f"{base_indent}Components:") + if add_empty_lines: + formatted_components.append("") + + # Add each component with optional empty lines between them + for i, component in enumerate(components): + # Get type name, handling special cases + type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) + + component_desc = f"{component_indent}{component.name} (`{type_name}`)" + if component.description: + component_desc += f": {component.description}" + + # Get the loading fields dynamically + loading_field_values = [] + for field_name in component.loading_fields(): + field_value = getattr(component, field_name) + if field_value is not None: + loading_field_values.append(f"{field_name}={field_value}") + + # Add loading field information if available + if loading_field_values: + component_desc += f" [{', '.join(loading_field_values)}]" + + formatted_components.append(component_desc) + + # Add an empty line after each component except the last one + if add_empty_lines and i < len(components) - 1: + formatted_components.append("") + + return "\n".join(formatted_components) + + +def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True): + """Format a list of ConfigSpec objects into a readable string representation. + + Args: + configs: List of ConfigSpec objects to format + indent_level: Number of spaces to indent each config line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between configs (default: True) + + Returns: + A formatted string representing all configs + """ + if not configs: + return "" + + base_indent = " " * indent_level + config_indent = " " * (indent_level + 4) + formatted_configs = [] + + # Add the header + formatted_configs.append(f"{base_indent}Configs:") + if add_empty_lines: + formatted_configs.append("") + + # Add each config with optional empty lines between them + for i, config in enumerate(configs): + config_desc = f"{config_indent}{config.name} (default: {config.default})" + if config.description: + config_desc += f": {config.description}" + formatted_configs.append(config_desc) + + # Add an empty line after each config except the last one + if add_empty_lines and i < len(configs) - 1: + formatted_configs.append("") + + return "\n".join(formatted_configs) + + +def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None): + """ + Generates a formatted documentation string describing the pipeline block's parameters and structure. + + Args: + inputs: List of input parameters + intermediates_inputs: List of intermediate input parameters + outputs: List of output parameters + description (str, *optional*): Description of the block + class_name (str, *optional*): Name of the class to include in the documentation + expected_components (List[ComponentSpec], *optional*): List of expected components + expected_configs (List[ConfigSpec], *optional*): List of expected configurations + + Returns: + str: A formatted string containing information about components, configs, call parameters, + intermediate inputs/outputs, and final outputs. + """ + output = "" + + # Add class name if provided + if class_name: + output += f"class {class_name}\n\n" + + # Add description + if description: + desc_lines = description.strip().split('\n') + aligned_desc = '\n'.join(' ' + line for line in desc_lines) + output += aligned_desc + "\n\n" + + # Add components section if provided + if expected_components and len(expected_components) > 0: + components_str = format_components(expected_components, indent_level=2) + output += components_str + "\n\n" + + # Add configs section if provided + if expected_configs and len(expected_configs) > 0: + configs_str = format_configs(expected_configs, indent_level=2) + output += configs_str + "\n\n" + + # Add inputs section + output += format_input_params(inputs + intermediates_inputs, indent_level=2) + + # Add outputs section + output += "\n\n" + output += format_output_params(outputs, indent_level=2) + + return output \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py new file mode 100644 index 000000000000..6d06c1f2e3df --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_pipeline_presets"] = ["StableDiffusionXLAutoPipeline"] + _import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"] + _import_structure["encoders"] = ["StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLTextEncoderStep", "StableDiffusionXLAutoVaeEncoderStep"] + _import_structure["after_denoise"] = ["StableDiffusionXLAutoDecodeStep"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_pipeline_presets import StableDiffusionXLAutoPipeline + from .modular_loader import StableDiffusionXLModularLoader + from .encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep + from .after_denoise import StableDiffusionXLAutoDecodeStep +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py new file mode 100644 index 000000000000..9746832506d7 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py @@ -0,0 +1,259 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, List, Optional, Tuple, Union, Dict + +import PIL +import torch +import numpy as np +from collections import OrderedDict + +from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...models import AutoencoderKL +from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor +from ...utils import logging + +from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from ...configuration_utils import FrozenDict + +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline import ( + AutoPipelineBlocks, + PipelineBlock, + PipelineState, + SequentialPipelineBlocks, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + + +class StableDiffusionXLDecodeLatentsStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step")] + + @property + def intermediates_outputs(self) -> List[str]: + return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components + @staticmethod + def upcast_vae(components): + dtype = components.vae.dtype + components.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + components.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + components.vae.post_quant_conv.to(dtype) + components.vae.decoder.conv_in.to(dtype) + components.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if not block_state.output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast + + if block_state.needs_upcasting: + self.upcast_vae(components) + block_state.latents = block_state.latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) + elif block_state.latents.dtype != components.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + components.vae = components.vae.to(block_state.latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + block_state.has_latents_mean = ( + hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None + ) + block_state.has_latents_std = ( + hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None + ) + if block_state.has_latents_mean and block_state.has_latents_std: + block_state.latents_mean = ( + torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) + ) + block_state.latents_std = ( + torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) + ) + block_state.latents = block_state.latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean + else: + block_state.latents = block_state.latents / components.vae.config.scaling_factor + + block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0] + + # cast back to fp16 if needed + if block_state.needs_upcasting: + components.vae.to(dtype=torch.float16) + else: + block_state.images = block_state.latents + + # apply watermark if available + if hasattr(components, "watermark") and components.watermark is not None: + block_state.images = components.watermark.apply_watermark(block_state.images) + + block_state.images = components.image_processor.postprocess(block_state.images, output_type=block_state.output_type) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return "A post-processing step that overlays the mask on the image (inpainting task only).\n" + \ + "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("image", required=True), + InputParam("mask_image", required=True), + InputParam("padding_mask_crop"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step"), + InputParam("crops_coords", required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.") + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images with the mask overlayed")] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.padding_mask_crop is not None and block_state.crops_coords is not None: + block_state.images = [components.image_processor.apply_overlay(block_state.mask_image, block_state.image, i, block_state.crops_coords) for i in block_state.images] + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLOutputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return "final step to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [InputParam("return_dict", default=True)] + + @property + def intermediates_inputs(self) -> List[str]: + return [InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step.")] + + @property + def intermediates_outputs(self) -> List[str]: + return [OutputParam("images", description="The final images output, can be a tuple or a `StableDiffusionXLPipelineOutput`")] + + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if not block_state.return_dict: + block_state.images = (block_state.images,) + else: + block_state.images = StableDiffusionXLPipelineOutput(images=block_state.images) + self.add_block_state(state, block_state) + return components, state + + +# After denoise +class StableDiffusionXLDecodeStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep] + block_names = ["decode", "output"] + + @property + def description(self): + return """Decode step that decode the denoised latents into images outputs. +This is a sequential pipeline blocks: + - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images + - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple.""" + + +class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep] + block_names = ["decode", "mask_overlay", "output"] + + @property + def description(self): + return "Inpaint decode step that decode the denoised latents into images outputs.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images\n" + \ + " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image\n" + \ + " - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." + + +class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] + block_names = ["inpaint", "non-inpaint"] + block_trigger_inputs = ["padding_mask_crop", None] + + @property + def description(self): + return "Decode step that decode the denoised latents into images outputs.\n" + \ + "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \ + " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \ + " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." + + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py new file mode 100644 index 000000000000..6809b4cd8e2e --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -0,0 +1,1766 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, List, Optional, Tuple, Union, Dict + +import PIL +import torch +from collections import OrderedDict + +from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin +from ...models import ControlNetModel, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel +from ...utils import logging +from ...utils.torch_utils import randn_tensor, unwrap_module + +from ...pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel +from ...schedulers import EulerDiscreteScheduler +from ...configuration_utils import FrozenDict + +from .modular_loader import StableDiffusionXLModularLoader +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline import ( + AutoPipelineBlocks, + ModularLoader, + PipelineBlock, + PipelineState, + SequentialPipelineBlocks, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + + + +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionXLInputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_images_per_prompt." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated text embeddings. Can be generated from text_encoder step."), + InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings. Can be generated from text_encoder step."), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated pooled text embeddings. Can be generated from text_encoder step."), + InputParam("negative_pooled_prompt_embeds", description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step."), + InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step."), + InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."), + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [ + OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), + OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)"), + OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="text embeddings used to guide the image generation"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), + OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="image embeddings for IP-Adapter"), + OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="negative image embeddings for IP-Adapter"), + ] + + def check_inputs(self, components, block_state): + + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" + f" {block_state.negative_prompt_embeds.shape}." + ) + + if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if block_state.negative_prompt_embeds is not None and block_state.negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list): + raise ValueError("`ip_adapter_embeds` must be a list") + + if block_state.negative_ip_adapter_embeds is not None and not isinstance(block_state.negative_ip_adapter_embeds, list): + raise ValueError("`negative_ip_adapter_embeds` must be a list") + + if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape: + raise ValueError( + "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" + f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" + f" {block_state.negative_ip_adapter_embeds[i].shape}." + ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) + + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + + if block_state.negative_pooled_prompt_embeds is not None: + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + + if block_state.ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + block_state.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) + + if block_state.negative_ip_adapter_embeds is not None: + for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds): + block_state.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that sets the timesteps for the scheduler and determines the initial noise level (latent_timestep) for image-to-image/inpainting generation.\n" + \ + "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), + InputParam("strength", default=0.3), + InputParam("denoising_start"), + # YiYi TODO: do we need num_images_per_prompt here? + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), + OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") + ] + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self -> components + def get_timesteps(self, components, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start * components.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (denoising_start * components.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (components.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if components.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(components.scheduler.timesteps) - num_inference_steps + timesteps = components.scheduler.timesteps[t_start:] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas + ) + + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + block_state.timesteps, block_state.num_inference_steps = self.get_timesteps( + components, + block_state.num_inference_steps, + block_state.strength, + block_state.device, + denoising_start=block_state.denoising_start if denoising_value_valid(block_state.denoising_start) else None, + ) + block_state.latent_timestep = block_state.timesteps[:1].repeat(block_state.batch_size * block_state.num_images_per_prompt) + + if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: + block_state.discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) + ) + ) + block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) + block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLSetTimestepsStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that sets the scheduler's timesteps for inference" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")] + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas + ) + + if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: + block_state.discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) + ) + ) + block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) + block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that prepares the latents for the inpainting process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("generator"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), + InputParam( + "strength", + default=0.9999, + description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " + "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " + "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " + "be maximum and the denoising process will run for the full number of iterations specified in " + "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " + "`denoising_start` being declared as an integer, the value of `strength` will be ignored." + ), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "latent_timestep", + required=True, + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step." + ), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step." + ), + InputParam( + "mask", + required=True, + type_hint=torch.Tensor, + description="The mask for the inpainting generation. Can be generated in vae_encode step." + ), + InputParam( + "masked_image_latents", + type_hint=torch.Tensor, + description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step." + ), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The dtype of the model inputs" + ) + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), + OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), + OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] + + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + @staticmethod + def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument + def prepare_latents_inpaint( + self, + components, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(components, image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * components.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + + block_state.is_strength_max = block_state.strength == 1.0 + + # for non-inpainting specific unet, we do not need masked_image_latents + if hasattr(components,"unet") and components.unet is not None: + if components.unet.config.in_channels == 4: + block_state.masked_image_latents = None + + block_state.add_noise = True if block_state.denoising_start is None else False + + block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor + block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor + + block_state.latents, block_state.noise = self.prepare_latents_inpaint( + components, + block_state.batch_size * block_state.num_images_per_prompt, + components.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + image=block_state.image_latents, + timestep=block_state.latent_timestep, + is_strength_max=block_state.is_strength_max, + add_noise=block_state.add_noise, + return_noise=True, + return_image_latents=False, + ) + + # 7. Prepare mask latent variables + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + block_state.mask, + block_state.masked_image_latents, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + ) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that prepares the latents for the image-to-image generation process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("generator"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), + InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), + InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self -> components + # YiYi TODO: refactor using _encode_vae_image + @staticmethod + def prepare_latents_img2img( + components, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + # make sure the VAE is in float32 mode, as it overflows in float16 + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + init_latents = components.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = components.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + block_state.add_noise = True if block_state.denoising_start is None else False + if block_state.latents is None: + block_state.latents = self.prepare_latents_img2img( + components, + block_state.image_latents, + block_state.latent_timestep, + block_state.batch_size, + block_state.num_images_per_prompt, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.add_noise, + ) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLPrepareLatentsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Prepare latents step that prepares the latents for the text-to-image generation process" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height"), + InputParam("width"), + InputParam("generator"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The dtype of the model inputs" + ) + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process" + ) + ] + + + @staticmethod + def check_inputs(components, block_state): + if ( + block_state.height is not None + and block_state.height % components.vae_scale_factor != 0 + or block_state.width is not None + and block_state.width % components.vae_scale_factor != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self -> components + @staticmethod + def prepare_latents(components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * components.scheduler.init_noise_sigma + return latents + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.dtype is None: + block_state.dtype = components.vae.dtype + + block_state.device = components._execution_device + + self.check_inputs(components, block_state) + + block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor + block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor + block_state.num_channels_latents = components.num_channels_latents + block_state.latents = self.prepare_latents( + components, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ConfigSpec("requires_aesthetics_score", False),] + + @property + def description(self) -> str: + return ( + "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + InputParam("aesthetic_score", default=6.0), + InputParam("negative_aesthetic_score", default=2.0), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components + @staticmethod + def _get_add_time_ids_img2img( + components, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if components.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + @staticmethod + def get_guidance_scale_embedding( + w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.vae_scale_factor = components.vae_scale_factor + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * block_state.vae_scale_factor + block_state.width = block_state.width * block_state.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + if block_state.negative_original_size is None: + block_state.negative_original_size = block_state.original_size + if block_state.negative_target_size is None: + block_state.negative_target_size = block_state.target_size + + block_state.add_time_ids, block_state.negative_add_time_ids = self._get_add_time_ids_img2img( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.aesthetic_score, + block_state.negative_aesthetic_score, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + dtype=block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + + # Optionally get Guidance Scale Embedding for LCM + block_state.timestep_cond = None + if ( + hasattr(components, "unet") + and components.unet is not None + and components.unet.config.time_cond_proj_dim is not None + ): + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Step that prepares the additional conditioning for the text-to-image generation process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components + @staticmethod + def _get_add_time_ids( + components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + @staticmethod + def get_guidance_scale_embedding( + w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + block_state.add_time_ids = self._get_add_time_ids( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + if block_state.negative_original_size is not None and block_state.negative_target_size is not None: + block_state.negative_add_time_ids = self._get_add_time_ids( + components, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + else: + block_state.negative_add_time_ids = block_state.add_time_ids + + block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + + # Optionally get Guidance Scale Embedding for LCM + block_state.timestep_cond = None + if ( + hasattr(components, "unet") + and components.unet is not None + and components.unet.config.time_cond_proj_dim is not None + ): + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLControlNetInputStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetModel), + ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), + ] + + @property + def description(self) -> str: + return "step that prepare inputs for controlnet" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"), + OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"), + OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"), + OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] + + + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + # (1) prepare controlnet inputs + block_state.device = components._execution_device + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + controlnet = unwrap_module(components.controlnet) + + # (1.1) + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] + elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] + elif not isinstance(block_state.control_guidance_start, list) and not isinstance(block_state.control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + block_state.control_guidance_start, block_state.control_guidance_end = ( + mult * [block_state.control_guidance_start], + mult * [block_state.control_guidance_end], + ) + + # (1.2) + # controlnet_conditioning_scale (align format) + if isinstance(controlnet, MultiControlNetModel) and isinstance(block_state.controlnet_conditioning_scale, float): + block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets) + + # (1.3) + # global_pool_conditions + block_state.global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + # (1.4) + # guess_mode + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + # (1.5) + # control_image + if isinstance(controlnet, ControlNetModel): + block_state.control_image = self.prepare_control_image( + components, + image=block_state.control_image, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in block_state.control_image: + control_image = self.prepare_control_image( + components, + image=control_image_, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + + control_images.append(control_image) + + block_state.control_image = control_images + else: + assert False + + # (1.6) + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + keeps = [ + 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) + for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) + ] + block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale + + + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetUnionModel), + ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), + ] + + @property + def description(self) -> str: + return "step that prepares inputs for the ControlNetUnion model" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_mode", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of model tensor inputs. Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step." + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"), + OutputParam("control_type_idx", type_hint=List[int], description="The control mode indices", kwargs_type="controlnet_kwargs"), + OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active", kwargs_type="controlnet_kwargs"), + OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), + OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), + OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + controlnet = unwrap_module(components.controlnet) + + device = components._execution_device + dtype = block_state.dtype or components.controlnet.dtype + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] + elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] + + # guess_mode + block_state.global_pool_conditions = controlnet.config.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + # control_image + if not isinstance(block_state.control_image, list): + block_state.control_image = [block_state.control_image] + # control_mode + if not isinstance(block_state.control_mode, list): + block_state.control_mode = [block_state.control_mode] + + if len(block_state.control_image) != len(block_state.control_mode): + raise ValueError("Expected len(control_image) == len(control_type)") + + # control_type + block_state.num_control_type = controlnet.config.num_control_type + block_state.control_type = [0 for _ in range(block_state.num_control_type)] + for control_idx in block_state.control_mode: + block_state.control_type[control_idx] = 1 + block_state.control_type = torch.Tensor(block_state.control_type) + + block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype) + repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0] + block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) + + # prepare control_image + for idx, _ in enumerate(block_state.control_image): + block_state.control_image[idx] = self.prepare_control_image( + components, + image=block_state.control_image[idx], + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=device, + dtype=dtype, + crops_coords=block_state.crops_coords, + ) + block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] + + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + block_state.controlnet_keep.append( + 1.0 + - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end) + ) + block_state.control_type_idx = block_state.control_mode + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLControlNetAutoInput(AutoPipelineBlocks): + + block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] + block_names = ["controlnet_union", "controlnet"] + block_trigger_inputs = ["control_mode", "control_image"] + + + +# Before denoise +class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ + " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + \ + " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + \ + " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + + +class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step for img2img task.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ + " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + \ + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + + +class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step for inpainting task.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ + " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \ + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + + +class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep] + block_names = ["inpaint", "img2img", "text2img"] + block_trigger_inputs = ["mask", "image_latents", None] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step.\n" + \ + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n" + \ + " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + \ + " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \ + " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided.\n" + \ + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + \ + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py new file mode 100644 index 000000000000..f605d0ab00aa --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -0,0 +1,1362 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from tqdm.auto import tqdm + +from ...configuration_utils import FrozenDict +from ...models import ControlNetModel, UNet2DConditionModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import unwrap_module + +from ...guiders import ClassifierFreeGuidance +from .modular_loader import StableDiffusionXLModularLoader +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline import ( + PipelineBlock, + PipelineState, + AutoPipelineBlocks, + LoopSequentialPipelineBlocks, + BlockState, +) +from dataclasses import asdict + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + +# YiYi experimenting composible denoise loop +# loop step (1): prepare latent input for denoiser +class StableDiffusionXLDenoiseLoopBeforeDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "step within the denoising loop that prepare the latent input for the denoiser" + + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")] + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + + + return components, block_state + +# loop step (1): prepare latent input for denoiser (with inpainting) +class StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return "step within the denoising loop that prepare the latent input for the denoiser" + + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "mask", + type_hint=Optional[torch.Tensor], + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "masked_image_latents", + type_hint=Optional[torch.Tensor], + description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")] + + @staticmethod + def check_inputs(components, block_state): + + num_channels_unet = components.num_channels_unet + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + if block_state.mask is None or block_state.masked_image_latents is None: + raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: + raise ValueError( + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `components.unet` or your `mask_image` or `image` input." + ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + self.check_inputs(components, block_state) + + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + if components.num_channels_unet == 9: + block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + + return components, block_state + +# loop step (2): denoise the latents with guidance +class StableDiffusionXLDenoiseLoopDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "scaled_latents", + required=True, + type_hint=torch.Tensor, + description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ), + + ] + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int) -> PipelineState: + + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields ={ + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + "time_ids": ("add_time_ids", "negative_add_time_ids"), + "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), + } + + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. for CFG, we prepare two batches: one for uncond, one for cond + # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds + # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = {k:v for k,v in cond_kwargs.items() if k in guider_input_fields} + prompt_embeds = cond_kwargs.pop("prompt_embeds") + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=cond_kwargs, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + + return components, block_state + +# loop step (2): denoise the latents with guidance (with controlnet) +class StableDiffusionXLControlNetDenoiseLoopDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("controlnet", ControlNetModel), + ] + + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "controlnet_cond", + required=True, + type_hint=torch.Tensor, + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "conditioning_scale", + type_hint=float, + description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "guess_mode", + required=True, + type_hint=bool, + description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "controlnet_keep", + required=True, + type_hint=List[float], + description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "scaled_latents", + required=True, + type_hint=torch.Tensor, + description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ), + InputParam( + kwargs_type="controlnet_kwargs", + description=( + "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )" + "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ) + ] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + extra_controlnet_kwargs = self.prepare_extra_kwargs(components.controlnet.forward, **block_state.controlnet_kwargs) + + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields ={ + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + "time_ids": ("add_time_ids", "negative_add_time_ids"), + "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), + } + + + # cond_scale for the timestep (controlnet input) + if isinstance(block_state.controlnet_keep[i], list): + block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] + else: + controlnet_cond_scale = block_state.conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] + + # default controlnet output/unet input for guess mode + conditional path + block_state.down_block_res_samples_zeros = None + block_state.mid_block_res_sample_zeros = None + + # guided denoiser step + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. for CFG, we prepare two batches: one for uncond, one for cond + # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds + # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + + # Prepare additional conditionings + added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None: + added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds + + # Prepare controlnet additional conditionings + controlnet_added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + # run controlnet for the guidance batch + if block_state.guess_mode and not components.guider.is_conditional: + # guider always run uncond batch first, so these tensors should be set already + down_block_res_samples = block_state.down_block_res_samples_zeros + mid_block_res_sample = block_state.mid_block_res_sample_zeros + else: + down_block_res_samples, mid_block_res_sample = components.controlnet( + block_state.scaled_latents, + t, + encoder_hidden_states=guider_state_batch.prompt_embeds, + controlnet_cond=block_state.controlnet_cond, + conditioning_scale=block_state.cond_scale, + guess_mode=block_state.guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + **extra_controlnet_kwargs, + ) + + # assign it to block_state so it will be available for the uncond guidance batch + if block_state.down_block_res_samples_zeros is None: + block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples] + if block_state.mid_block_res_sample_zeros is None: + block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample) + + # Predict the noise + # store the noise_pred in guider_state_batch so we can apply guidance across all batches + guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=guider_state_batch.prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + + return components, block_state + +# loop step (3): scheduler step to update latents +class StableDiffusionXLDenoiseLoopAfterDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("generator"), + InputParam("eta", default=0.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + #YiYi TODO: move this out of here + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) + + + # Perform scheduler step using the predicted output + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] + + if block_state.latents.dtype != block_state.latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + return components, block_state + +# loop step (3): scheduler step to update latents (with inpainting) +class StableDiffusionXLInpaintDenoiseLoopAfterDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("generator"), + InputParam("eta", default=0.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "mask", + type_hint=Optional[torch.Tensor], + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "noise", + type_hint=Optional[torch.Tensor], + description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." + ), + InputParam( + "image_latents", + type_hint=Optional[torch.Tensor], + description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + def check_inputs(self, components, block_state): + if components.num_channels_unet == 4: + if block_state.image_latents is None: + raise ValueError(f"image_latents is required for this step {self.__class__.__name__}") + if block_state.mask is None: + raise ValueError(f"mask is required for this step {self.__class__.__name__}") + if block_state.noise is None: + raise ValueError(f"noise is required for this step {self.__class__.__name__}") + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + self.check_inputs(components, block_state) + + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) + + + # Perform scheduler step using the predicted output + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] + + if block_state.latents.dtype != block_state.latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + # adjust latent for inpainting + if components.num_channels_unet == 4: + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.add_noise( + block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) + ) + + block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + + + + return components, block_state + + +# the loop wrapper that iterates over the timesteps +class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def loop_intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + ] + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False + if block_state.disable_guidance: + components.guider.disable() + else: + components.guider.enable() + + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): + progress_bar.update() + + self.add_block_state(state, block_state) + + return components, state + + +# composing the denoising loops +class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + +# control_cond +class StableDiffusionXLControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + +# mask +class StableDiffusionXLInpaintDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + +# control_cond + mask +class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + + +# all task without controlnet +class StableDiffusionXLDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintDenoiseLoop, StableDiffusionXLDenoiseLoop] + block_names = ["inpaint_denoise", "denoise"] + block_trigger_inputs = ["mask", None] + +# all task with controlnet +class StableDiffusionXLControlNetDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintControlNetDenoiseLoop, StableDiffusionXLControlNetDenoiseLoop] + block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"] + block_trigger_inputs = ["mask", None] + +# all task with or without controlnet +class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] + block_names = ["controlnet_denoise", "denoise"] + block_trigger_inputs = ["controlnet_cond", None] + + + + + + + +# YiYi Notes: alternatively, this is you can just write the denoise loop using a pipeline block, easier but not composible +# class StableDiffusionXLDenoiseStep(PipelineBlock): + +# model_name = "stable-diffusion-xl" + +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ] + +# @property +# def description(self) -> str: +# return ( +# "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" +# ) + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("num_images_per_prompt", default=1), +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." +# ) + +# # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components +# @staticmethod +# def prepare_extra_step_kwargs(components, generator, eta): +# # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature +# # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. +# # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 +# # and should be between [0, 1] + +# accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# extra_step_kwargs = {} +# if accepts_eta: +# extra_step_kwargs["eta"] = eta + +# # check if the scheduler accepts generator +# accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# if accepts_generator: +# extra_step_kwargs["generator"] = generator +# return extra_step_kwargs + +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) + +# block_state.num_channels_unet = components.unet.config.in_channels +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() + +# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_data = components.guider.prepare_inputs(block_state) + +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + +# # Prepare for inpainting +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + +# for batch in guider_data: +# components.guider.prepare_models(components.unet) + +# # Prepare additional conditionings +# batch.added_cond_kwargs = { +# "text_embeds": batch.pooled_prompt_embeds, +# "time_ids": batch.add_time_ids, +# } +# if batch.ip_adapter_embeds is not None: +# batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds + +# # Predict the noise residual +# batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=batch.added_cond_kwargs, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) + +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) + +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + +# if block_state.latents.dtype != block_state.latents_dtype: +# if torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) + +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) + +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() + +# self.add_block_state(state, block_state) + +# return components, state + + + +# class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): + +# model_name = "stable-diffusion-xl" + +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ComponentSpec("controlnet", ControlNetModel), +# ] + +# @property +# def description(self) -> str: +# return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("num_images_per_prompt", default=1), +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "controlnet_cond", +# required=True, +# type_hint=torch.Tensor, +# description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_start", +# required=True, +# type_hint=float, +# description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_end", +# required=True, +# type_hint=float, +# description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "conditioning_scale", +# type_hint=float, +# description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "guess_mode", +# required=True, +# type_hint=bool, +# description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "controlnet_keep", +# required=True, +# type_hint=List[float], +# description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "crops_coords", +# type_hint=Optional[Tuple[int]], +# description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." +# ) +# @staticmethod +# def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + +# accepted_kwargs = set(inspect.signature(func).parameters.keys()) +# extra_kwargs = {} +# for key, value in kwargs.items(): +# if key in accepted_kwargs and key not in exclude_kwargs: +# extra_kwargs[key] = value + +# return extra_kwargs + + +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) +# block_state.device = components._execution_device +# print(f" block_state: {block_state}") + +# controlnet = unwrap_module(components.controlnet) + +# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) +# block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) + +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + +# # (1) setup guider +# # disable for LCMs +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# # (5) Denoise loop +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): + +# # prepare latent input for unet +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) +# # adjust latent input for inpainting +# block_state.num_channels_unet = components.unet.config.in_channels +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + +# # cond_scale (controlnet input) +# if isinstance(block_state.controlnet_keep[i], list): +# block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] +# else: +# block_state.controlnet_cond_scale = block_state.conditioning_scale +# if isinstance(block_state.controlnet_cond_scale, list): +# block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] +# block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] + +# # default controlnet output/unet input for guess mode + conditional path +# block_state.down_block_res_samples_zeros = None +# block_state.mid_block_res_sample_zeros = None + +# # guided denoiser step +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_state = components.guider.prepare_inputs(block_state) + +# for guider_state_batch in guider_state: +# components.guider.prepare_models(components.unet) + +# # Prepare additional conditionings +# guider_state_batch.added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } +# if guider_state_batch.ip_adapter_embeds is not None: +# guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds + +# # Prepare controlnet additional conditionings +# guider_state_batch.controlnet_added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } + +# if block_state.guess_mode and not components.guider.is_conditional: +# # guider always run uncond batch first, so these tensors should be set already +# guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros +# guider_state_batch.mid_block_res_sample = block_state.mid_block_res_sample_zeros +# else: +# guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# controlnet_cond=block_state.controlnet_cond, +# conditioning_scale=block_state.conditioning_scale, +# guess_mode=block_state.guess_mode, +# added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, +# return_dict=False, +# **block_state.extra_controlnet_kwargs, +# ) + +# if block_state.down_block_res_samples_zeros is None: +# block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] +# if block_state.mid_block_res_sample_zeros is None: +# block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) + + + +# guider_state_batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=guider_state_batch.added_cond_kwargs, +# down_block_additional_residuals=guider_state_batch.down_block_res_samples, +# mid_block_additional_residual=guider_state_batch.mid_block_res_sample, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) + +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) + +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + +# if block_state.latents.dtype != block_state.latents_dtype: +# if torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) + +# # adjust latent for inpainting +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) + +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() + +# self.add_block_state(state, block_state) + +# return components, state \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py new file mode 100644 index 000000000000..3c84fc71c8af --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -0,0 +1,856 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, List, Optional, Tuple, Union, Dict + +import PIL +import torch +from collections import OrderedDict + +from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin +from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel +from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor +from ...models.lora import adjust_lora_scale_text_encoder +from ...utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor, unwrap_module +from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel +from ...configuration_utils import FrozenDict + +from transformers import ( + CLIPTextModel, + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...schedulers import EulerDiscreteScheduler +from ...guiders import ClassifierFreeGuidance + +from .modular_loader import StableDiffusionXLModularLoader +from ..modular_pipeline import PipelineBlock, PipelineState, AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec + +import numpy as np + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionXLIPAdapterStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + + @property + def description(self) -> str: + return ( + "IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc" + " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" + " for more details" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("image_encoder", CLIPVisionModelWithProjection), + ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "ip_adapter_image", + PipelineImageInput, + required=True, + description="The image(s) to be used as ip adapter" + ) + ] + + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), + OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") + ] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components + @staticmethod + def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(components.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = components.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = components.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = components.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds + ): + image_embeds = [] + if prepare_unconditional_embeds: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + components, single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if prepare_unconditional_embeds: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if prepare_unconditional_embeds: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if prepare_unconditional_embeds: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( + components, + ip_adapter_image=block_state.ip_adapter_image, + ip_adapter_image_embeds=None, + device=block_state.device, + num_images_per_prompt=1, + prepare_unconditional_embeds=block_state.prepare_unconditional_embeds, + ) + if block_state.prepare_unconditional_embeds: + block_state.negative_ip_adapter_embeds = [] + for i, image_embeds in enumerate(block_state.ip_adapter_embeds): + negative_image_embeds, image_embeds = image_embeds.chunk(2) + block_state.negative_ip_adapter_embeds.append(negative_image_embeds) + block_state.ip_adapter_embeds[i] = image_embeds + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLTextEncoderStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return( + "Text Encoder step that generate text_embeddings to guide the image generation" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", CLIPTextModel), + ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), + ComponentSpec("tokenizer", CLIPTokenizer), + ComponentSpec("tokenizer_2", CLIPTokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ConfigSpec("force_zeros_for_empty_prompt", True)] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("prompt_2"), + InputParam("negative_prompt"), + InputParam("negative_prompt_2"), + InputParam("cross_attention_kwargs"), + InputParam("clip_skip"), + ] + + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields",description="text embeddings used to guide the image generation"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), + ] + + @staticmethod + def check_inputs(block_state): + + if block_state.prompt is not None and (not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + elif block_state.prompt_2 is not None and (not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}") + + @staticmethod + def encode_prompt( + components, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prepare_unconditional_embeds: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or components._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin): + components._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if components.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder, lora_scale) + else: + scale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale) + else: + scale_lora_layers(components.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [components.tokenizer, components.tokenizer_2] if components.tokenizer is not None else [components.tokenizer_2] + text_encoders = ( + [components.text_encoder, components.text_encoder_2] if components.text_encoder is not None else [components.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + prompt = components.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt + if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif prepare_unconditional_embeds and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if components.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if prepare_unconditional_embeds: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if components.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if prepare_unconditional_embeds: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if components.text_encoder is not None: + if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + # Encode input prompt + block_state.text_encoder_lora_scale = ( + block_state.cross_attention_kwargs.get("scale", None) if block_state.cross_attention_kwargs is not None else None + ) + ( + block_state.prompt_embeds, + block_state.negative_prompt_embeds, + block_state.pooled_prompt_embeds, + block_state.negative_pooled_prompt_embeds, + ) = self.encode_prompt( + components, + block_state.prompt, + block_state.prompt_2, + block_state.device, + 1, + block_state.prepare_unconditional_embeds, + block_state.negative_prompt, + block_state.negative_prompt_2, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + lora_scale=block_state.text_encoder_lora_scale, + clip_skip=block_state.clip_skip, + ) + # Add outputs + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLVaeEncoderStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + + @property + def description(self) -> str: + return ( + "Vae Encoder step that encode the input image into a latent representation" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("image", required=True), + InputParam("generator"), + InputParam("height"), + InputParam("width"), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} + block_state.device = components._execution_device + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + + block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs) + block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) + + block_state.batch_size = block_state.image.shape[0] + + # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) + if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." + ) + + + block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), + ComponentSpec( + "mask_processor", + VaeImageProcessor, + config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}), + default_creation_method="from_config"), + ] + + + @property + def description(self) -> str: + return ( + "Vae encoder step that prepares the image and mask for the inpainting process" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height"), + InputParam("width"), + InputParam("generator"), + InputParam("image", required=True), + InputParam("mask_image", required=True), + InputParam("padding_mask_crop"), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs")] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), + OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), + OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + + if block_state.padding_mask_crop is not None: + block_state.crops_coords = components.mask_processor.get_crop_region(block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop) + block_state.resize_mode = "fill" + else: + block_state.crops_coords = None + block_state.resize_mode = "default" + + block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode) + block_state.image = block_state.image.to(dtype=torch.float32) + + block_state.mask = components.mask_processor.preprocess(block_state.mask_image, height=block_state.height, width=block_state.width, resize_mode=block_state.resize_mode, crops_coords=block_state.crops_coords) + block_state.masked_image = block_state.image * (block_state.mask < 0.5) + + block_state.batch_size = block_state.image.shape[0] + block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) + block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) + + # 7. Prepare mask latent variables + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + block_state.mask, + block_state.masked_image, + block_state.batch_size, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + ) + + self.add_block_state(state, block_state) + + + return components, state + + + +# auto blocks (YiYi TODO: maybe move all the auto blocks to a separate file) +# Encode +class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] + block_names = ["inpaint", "img2img"] + block_trigger_inputs = ["mask_image", "image"] + + @property + def description(self): + return "Vae encoder step that encode the image inputs into their latent representations.\n" + \ + "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + \ + " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + \ + " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided." + + +class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks, ModularIPAdapterMixin): + block_classes = [StableDiffusionXLIPAdapterStep] + block_names = ["ip_adapter"] + block_trigger_inputs = ["ip_adapter_image"] + + @property + def description(self): + return "Run IP Adapter step if `ip_adapter_image` is provided." + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py new file mode 100644 index 000000000000..53f27571092a --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py @@ -0,0 +1,175 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Optional, Tuple, Union, Dict +import PIL +import torch +import numpy as np + +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin +from ...image_processor import PipelineImageInput +from ...pipelines.pipeline_utils import StableDiffusionMixin +from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from ...utils import logging + +from ..modular_pipeline import ModularLoader +from ..modular_pipeline_utils import InputParam, OutputParam + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + +# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder? +# YiYi Notes: model specific components: +## (1) it should inherit from ModularLoader +## (2) acts like a container that holds components and configs +## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents +## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) +## (5) how to use together with Components_manager? +class StableDiffusionXLModularLoader( + ModularLoader, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + ModularIPAdapterMixin, +): + @property + def default_sample_size(self): + default_sample_size = 128 + if hasattr(self, "unet") and self.unet is not None: + default_sample_size = self.unet.config.sample_size + return default_sample_size + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_unet(self): + num_channels_unet = 4 + if hasattr(self, "unet") and self.unet is not None: + num_channels_unet = self.unet.config.in_channels + return num_channels_unet + + @property + def num_channels_latents(self): + num_channels_latents = 4 + if hasattr(self, "vae") and self.vae is not None: + num_channels_latents = self.vae.config.latent_channels + return num_channels_latents + + + +# YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks +SDXL_INPUTS_SCHEMA = { + "prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"), + "prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"), + "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), + "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), + "cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), + "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), + "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), + "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), + "generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"), + "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), + "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), + "num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"), + "timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"), + "sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"), + "denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"), + # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 + "strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"), + "denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"), + "latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"), + "padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"), + "original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"), + "target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"), + "negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"), + "negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"), + "crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"), + "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), + "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), + "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), + "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), + "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), + "return_dict": InputParam("return_dict", type_hint=bool, default=True, description="Whether to return a StableDiffusionXLPipelineOutput"), + "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), + "control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"), + "control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"), + "control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"), + "controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"), + "guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"), + "control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet") +} + + +SDXL_INTERMEDIATE_INPUTS_SCHEMA = { + "prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"), + "negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), + "pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"), + "negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), + "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"), + "latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"), + "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"), + "latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"), + "image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"), + "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), + "masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), + "add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"), + "negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), + "negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), + "images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images") +} + + +SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = { + "prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"), + "negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), + "pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"), + "negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"), + "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"), + "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"), + "masked_image_latents": OutputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), + "crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"), + "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"), + "latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"), + "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"), + "negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"), + "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), + "negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), + "images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images") +} + + +SDXL_OUTPUTS_SCHEMA = { + "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") +} + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py new file mode 100644 index 000000000000..80f1780595c2 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py @@ -0,0 +1,119 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Optional, Tuple, Union, Dict +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks + +from .denoise import StableDiffusionXLAutoDenoiseStep +from .before_denoise import StableDiffusionXLAutoBeforeDenoiseStep +from .after_denoise import StableDiffusionXLAutoDecodeStep +from .encoders import StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] + block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "after_denoise"] + + @property + def description(self): + return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \ + "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \ + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + \ + "- to run the controlnet workflow, you need to provide `control_image`\n" + \ + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \ + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \ + "- for text-to-image generation, all you need to provide is `prompt`" + + + +# YiYi notes: comment out for now, work on this later +# # block mapping +# TEXT2IMAGE_BLOCKS = OrderedDict([ +# ("text_encoder", StableDiffusionXLTextEncoderStep), +# ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), +# ("input", StableDiffusionXLInputStep), +# ("set_timesteps", StableDiffusionXLSetTimestepsStep), +# ("prepare_latents", StableDiffusionXLPrepareLatentsStep), +# ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), +# ("denoise", StableDiffusionXLDenoiseStep), +# ("decode", StableDiffusionXLDecodeStep) +# ]) + +# IMAGE2IMAGE_BLOCKS = OrderedDict([ +# ("text_encoder", StableDiffusionXLTextEncoderStep), +# ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), +# ("image_encoder", StableDiffusionXLVaeEncoderStep), +# ("input", StableDiffusionXLInputStep), +# ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), +# ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), +# ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), +# ("denoise", StableDiffusionXLDenoiseStep), +# ("decode", StableDiffusionXLDecodeStep) +# ]) + +# INPAINT_BLOCKS = OrderedDict([ +# ("text_encoder", StableDiffusionXLTextEncoderStep), +# ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), +# ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), +# ("input", StableDiffusionXLInputStep), +# ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), +# ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), +# ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), +# ("denoise", StableDiffusionXLDenoiseStep), +# ("decode", StableDiffusionXLInpaintDecodeStep) +# ]) + +# CONTROLNET_BLOCKS = OrderedDict([ +# ("controlnet_input", StableDiffusionXLControlNetInputStep), +# ("denoise", StableDiffusionXLControlNetDenoiseStep), +# ]) + +# CONTROLNET_UNION_BLOCKS = OrderedDict([ +# ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), +# ("denoise", StableDiffusionXLControlNetDenoiseStep), +# ]) + +# IP_ADAPTER_BLOCKS = OrderedDict([ +# ("ip_adapter", StableDiffusionXLIPAdapterStep), +# ]) + +# AUTO_BLOCKS = OrderedDict([ +# ("text_encoder", StableDiffusionXLTextEncoderStep), +# ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), +# ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), +# ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), +# ("denoise", StableDiffusionXLAutoDenoiseStep), +# ("decode", StableDiffusionXLAutoDecodeStep) +# ]) + +# AUTO_CORE_BLOCKS = OrderedDict([ +# ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), +# ("denoise", StableDiffusionXLAutoDenoiseStep), +# ]) + + +# SDXL_SUPPORTED_BLOCKS = { +# "text2img": TEXT2IMAGE_BLOCKS, +# "img2img": IMAGE2IMAGE_BLOCKS, +# "inpaint": INPAINT_BLOCKS, +# "controlnet": CONTROLNET_BLOCKS, +# "controlnet_union": CONTROLNET_UNION_BLOCKS, +# "ip_adapter": IP_ADAPTER_BLOCKS, +# "auto": AUTO_BLOCKS +# } + + From 153ae34ff6d8c0832b7d2db2aabcf4e27f0eb1e4 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 10 May 2025 03:50:47 +0200 Subject: [PATCH 23/54] update __init__ --- src/diffusers/__init__.py | 48 +++++++++++++++++++++++++++++++++------ 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index fa3e88d999b5..7a3de0b95747 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -39,6 +39,7 @@ "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], + "modular_pipelines": [], "quantizers.quantization_config": [], "schedulers": [], "utils": [ @@ -254,13 +255,19 @@ "KarrasVePipeline", "LDMPipeline", "LDMSuperResolutionPipeline", - "ModularLoader", "PNDMPipeline", "RePaintPipeline", "ScoreSdeVePipeline", "StableDiffusionMixin", ] ) + _import_structure["modular_pipelines"].extend( + [ + "ModularLoader", + "ComponentSpec", + "ComponentsManager", + ] + ) _import_structure["quantizers"] = ["DiffusersQuantizer"] _import_structure["schedulers"].extend( [ @@ -509,12 +516,10 @@ "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLModularLoader", "StableDiffusionXLPAGImg2ImgPipeline", "StableDiffusionXLPAGInpaintPipeline", "StableDiffusionXLPAGPipeline", "StableDiffusionXLPipeline", - "StableDiffusionXLAutoPipeline", "StableUnCLIPImg2ImgPipeline", "StableUnCLIPPipeline", "StableVideoDiffusionPipeline", @@ -541,6 +546,24 @@ ] ) + +try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_torch_and_transformers_objects # noqa F403 + + _import_structure["utils.dummy_torch_and_transformers_objects"] = [ + name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") + ] + +else: + _import_structure["modular_pipelines"].extend( + [ + "StableDiffusionXLAutoPipeline", + "StableDiffusionXLModularLoader", + ] + ) try: if not (is_torch_available() and is_transformers_available() and is_opencv_available()): raise OptionalDependencyNotAvailable() @@ -864,12 +887,16 @@ KarrasVePipeline, LDMPipeline, LDMSuperResolutionPipeline, - ModularLoader, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline, StableDiffusionMixin, ) + from .modular_pipelines import ( + ModularLoader, + ComponentSpec, + ComponentsManager, + ) from .quantizers import DiffusersQuantizer from .schedulers import ( AmusedScheduler, @@ -1097,12 +1124,10 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLModularLoader, StableDiffusionXLPAGImg2ImgPipeline, StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline, StableDiffusionXLPipeline, - StableDiffusionXLAutoPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, StableVideoDiffusionPipeline, @@ -1127,7 +1152,16 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) - + try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_pipelines import ( + StableDiffusionXLAutoPipeline, + StableDiffusionXLModularLoader, + ) try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() From 796453cad12d62dbe48db156df925cd5392cca31 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 12 May 2025 01:14:43 +0200 Subject: [PATCH 24/54] add notes --- .../modular_pipelines/modular_pipeline_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index 392d6dcd9521..a82f83fc38d9 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -241,6 +241,13 @@ class ConfigSpec: name: str default: Any description: Optional[str] = None + + +# YiYi Notes: both inputs and intermediates_inputs are InputParam objects +# however some fields are not relevant for intermediates_inputs +# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed +# default is not used for intermediates_inputs, we only use default from inputs, so it is ignored if it is set for intermediates_inputs +# -> should we use different class for inputs and intermediates_inputs? @dataclass class InputParam: """Specification for an input parameter.""" @@ -249,7 +256,7 @@ class InputParam: default: Any = None required: bool = False description: str = "" - kwargs_type: str = None + kwargs_type: str = None # YiYi Notes: experimenting with this, not sure if we should keep it def __repr__(self): return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" From 144eae4e0bb3368d9f617d7c54761e86128a0289 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 12 May 2025 01:16:42 +0200 Subject: [PATCH 25/54] add block state will also make sure modifed intermediates_inputs will be updated --- .../modular_pipelines/modular_pipeline.py | 241 +++++++++++++++--- 1 file changed, 206 insertions(+), 35 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 98960fe25bde..3eeff41dd1de 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -282,7 +282,7 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, state = PipelineState() if not hasattr(self, "loader"): - logger.warning("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.") + logger.info("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.") self.loader = None # Make a copy of the input kwargs @@ -313,7 +313,7 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, # Warn about unexpected inputs if len(passed_kwargs) > 0: - logger.warning(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") + warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") # Run the pipeline with torch.no_grad(): try: @@ -373,7 +373,6 @@ def expected_configs(self) -> List[ConfigSpec]: return [] - # YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable @property def inputs(self) -> List[InputParam]: """List of input parameters. Must be implemented by subclasses.""" @@ -389,13 +388,16 @@ def intermediates_outputs(self) -> List[OutputParam]: """List of intermediate output parameters. Must be implemented by subclasses.""" return [] + def _get_outputs(self): + return self.intermediates_outputs + + # YiYi TODO: is it too easy for user to unintentionally override these properties? # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks @property def outputs(self) -> List[OutputParam]: - return self.intermediates_outputs + return self._get_outputs() - @property - def required_inputs(self) -> List[str]: + def _get_required_inputs(self): input_names = [] for input_param in self.inputs: if input_param.required: @@ -403,13 +405,23 @@ def required_inputs(self) -> List[str]: return input_names @property - def required_intermediates_inputs(self) -> List[str]: + def required_inputs(self) -> List[str]: + return self._get_required_inputs() + + + def _get_required_intermediates_inputs(self): input_names = [] for input_param in self.intermediates_inputs: if input_param.required: input_names.append(input_param.name) return input_names + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block + @property + def required_intermediates_inputs(self) -> List[str]: + return self._get_required_intermediates_inputs() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise NotImplementedError("__call__ method must be implemented in subclasses") @@ -521,6 +533,30 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") param = getattr(block_state, output_param.name) state.add_intermediate(output_param.name, param, output_param.kwargs_type) + + for input_param in self.intermediates_inputs: + if hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(input_param.name, param, input_param.kwargs_type) + + for input_param in self.intermediates_inputs: + if input_param.name and hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(input_param.name, param, input_param.kwargs_type) + elif input_param.kwargs_type: + # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters + # we need to first find out which inputs are and loop through them. + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + for param_name, current_value in intermediates_kwargs.items(): + param = getattr(block_state, param_name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(param_name, param, input_param.kwargs_type) def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: @@ -550,16 +586,16 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li input_param.default is not None and current_param.default != input_param.default): warnings.warn( - f"Multiple different default values found for input '{input_param.name}': " - f"{current_param.default} (from block '{value_sources[input_param.name]}') and " + f"Multiple different default values found for input '{input_name}': " + f"{current_param.default} (from block '{value_sources[input_name]}') and " f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." ) if current_param.default is None and input_param.default is not None: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name + combined_dict[input_name] = input_param + value_sources[input_name] = block_name else: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name + combined_dict[input_name] = input_param + value_sources[input_name] = block_name return list(combined_dict.values()) @@ -661,7 +697,9 @@ def required_inputs(self) -> List[str]: required_by_all.intersection_update(block_required) return list(required_by_all) - + + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block @property def required_intermediates_inputs(self) -> List[str]: first_block = next(iter(self.blocks.values())) @@ -838,14 +876,21 @@ def __repr__(self): indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) blocks_str += f" Description: {indented_desc}\n\n" - return ( - f"{header}\n" - f"{desc}\n\n" - f"{components_str}\n\n" - f"{configs_str}\n\n" - f"{blocks_str}" - f")" - ) + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result @property @@ -867,13 +912,15 @@ class SequentialPipelineBlocks(ModularPipelineMixin): block_classes = [] block_names = [] - @property - def model_name(self): - return next(iter(self.blocks.values())).model_name @property def description(self): return "" + + @property + def model_name(self): + return next(iter(self.blocks.values())).model_name + @property def expected_components(self): @@ -929,6 +976,8 @@ def required_inputs(self) -> List[str]: return list(required_by_any) + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block @property def required_intermediates_inputs(self) -> List[str]: required_intermediates_inputs = [] @@ -960,11 +1009,15 @@ def intermediates_inputs(self) -> List[str]: def get_intermediates_inputs(self): inputs = [] outputs = set() + added_inputs = set() # Go through all blocks in order for block in self.blocks.values(): # Add inputs that aren't in outputs yet - inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) + for inp in block.intermediates_inputs: + if inp.name not in outputs and inp.name not in added_inputs: + inputs.append(inp) + added_inputs.add(inp.name) # Only add outputs if the block cannot be skipped should_add_outputs = True @@ -1176,14 +1229,21 @@ def __repr__(self): indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) blocks_str += f" Description: {indented_desc}\n\n" - return ( - f"{header}\n" - f"{desc}\n\n" - f"{components_str}\n\n" - f"{configs_str}\n\n" - f"{blocks_str}" - f")" - ) + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result @property @@ -1348,7 +1408,8 @@ def required_inputs(self) -> List[str]: return list(required_by_any) - # modified from SequentialPipelineBlocks, if any additional intermediate input required by the loop is required by the block + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block @property def required_intermediates_inputs(self) -> List[str]: required_intermediates_inputs = [] @@ -1384,6 +1445,22 @@ def __init__(self): for block_name, block_cls in zip(self.block_names, self.block_classes): blocks[block_name] = block_cls() self.blocks = blocks + + @classmethod + def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks": + """Creates a LoopSequentialPipelineBlocks instance from a dictionary of blocks. + + Args: + blocks_dict: Dictionary mapping block names to block instances + + Returns: + A new LoopSequentialPipelineBlocks instance + """ + instance = cls() + instance.block_classes = [block.__class__ for block in blocks_dict.values()] + instance.block_names = list(blocks_dict.keys()) + instance.blocks = blocks_dict + return instance def loop_step(self, components, state: PipelineState, **kwargs): @@ -1455,6 +1532,100 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): param = getattr(block_state, output_param.name) state.add_intermediate(output_param.name, param, output_param.kwargs_type) + for input_param in self.intermediates_inputs: + if input_param.name and hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(input_param.name, param, input_param.kwargs_type) + elif input_param.kwargs_type: + # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters + # we need to first find out which inputs are and loop through them. + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + for param_name, current_value in intermediates_kwargs.items(): + if not hasattr(block_state, param_name): + continue + param = getattr(block_state, param_name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(param_name, param, input_param.kwargs_type) + + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) + + # modified from SequentialPipelineBlocks, + #(does not need trigger_inputs related part so removed them, + # do not need to support auto block for loop blocks) + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" + if base_class and base_class != "object" + else f"{class_name}(\n" + ) + + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + + # Components section - focus only on expected components + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + + # Blocks section - moved to the end with simplified format + blocks_str = " Blocks:\n" + for i, (name, block) in enumerate(self.blocks.items()): + + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + + # Add block description + desc_lines = block.description.split('\n') + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + blocks_str += f" Description: {indented_desc}\n\n" + + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result + + + + # YiYi TODO: # 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) # 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader From 522e82762566597de63afd185f9bc02589035674 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 12 May 2025 01:17:45 +0200 Subject: [PATCH 26/54] move block mappings to its own file --- .../modular_pipeline_block_mappings.py | 128 ++++++++++++++++++ .../modular_pipeline_presets.py | 76 ----------- 2 files changed, 128 insertions(+), 76 deletions(-) create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py new file mode 100644 index 000000000000..c739a24e9759 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py @@ -0,0 +1,128 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict + +# Import all the necessary block classes +from .denoise import ( + StableDiffusionXLAutoDenoiseStep, + StableDiffusionXLDenoiseStep, + StableDiffusionXLControlNetDenoiseStep +) +from .before_denoise import ( + StableDiffusionXLAutoBeforeDenoiseStep, + StableDiffusionXLInputStep, + StableDiffusionXLSetTimestepsStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLPrepareAdditionalConditioningStep, + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLImg2ImgPrepareLatentsStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + StableDiffusionXLInpaintPrepareLatentsStep, + StableDiffusionXLControlNetInputStep, + StableDiffusionXLControlNetUnionInputStep +) +from .encoders import ( + StableDiffusionXLTextEncoderStep, + StableDiffusionXLAutoIPAdapterStep, + StableDiffusionXLAutoVaeEncoderStep, + StableDiffusionXLVaeEncoderStep, + StableDiffusionXLInpaintVaeEncoderStep, + StableDiffusionXLIPAdapterStep +) +from .after_denoise import ( + StableDiffusionXLDecodeStep, + StableDiffusionXLInpaintDecodeStep +) +from .after_denoise import StableDiffusionXLAutoDecodeStep + + +# YiYi notes: comment out for now, work on this later +# block mapping +TEXT2IMAGE_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLSetTimestepsStep), + ("prepare_latents", StableDiffusionXLPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLDecodeStep) +]) + +IMAGE2IMAGE_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLDecodeStep) +]) + +INPAINT_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLInpaintDecodeStep) +]) + +CONTROLNET_BLOCKS = OrderedDict([ + ("controlnet_input", StableDiffusionXLControlNetInputStep), + ("denoise", StableDiffusionXLControlNetDenoiseStep), +]) + +CONTROLNET_UNION_BLOCKS = OrderedDict([ + ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), + ("denoise", StableDiffusionXLControlNetDenoiseStep), +]) + +IP_ADAPTER_BLOCKS = OrderedDict([ + ("ip_adapter", StableDiffusionXLIPAdapterStep), +]) + +AUTO_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), + ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), + ("decode", StableDiffusionXLAutoDecodeStep) +]) + +AUTO_CORE_BLOCKS = OrderedDict([ + ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), +]) + + +SDXL_SUPPORTED_BLOCKS = { + "text2img": TEXT2IMAGE_BLOCKS, + "img2img": IMAGE2IMAGE_BLOCKS, + "inpaint": INPAINT_BLOCKS, + "controlnet": CONTROLNET_BLOCKS, + "controlnet_union": CONTROLNET_UNION_BLOCKS, + "ip_adapter": IP_ADAPTER_BLOCKS, + "auto": AUTO_BLOCKS +} + + + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py index 80f1780595c2..6ea327047740 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py @@ -40,80 +40,4 @@ def description(self): -# YiYi notes: comment out for now, work on this later -# # block mapping -# TEXT2IMAGE_BLOCKS = OrderedDict([ -# ("text_encoder", StableDiffusionXLTextEncoderStep), -# ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), -# ("input", StableDiffusionXLInputStep), -# ("set_timesteps", StableDiffusionXLSetTimestepsStep), -# ("prepare_latents", StableDiffusionXLPrepareLatentsStep), -# ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), -# ("denoise", StableDiffusionXLDenoiseStep), -# ("decode", StableDiffusionXLDecodeStep) -# ]) - -# IMAGE2IMAGE_BLOCKS = OrderedDict([ -# ("text_encoder", StableDiffusionXLTextEncoderStep), -# ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), -# ("image_encoder", StableDiffusionXLVaeEncoderStep), -# ("input", StableDiffusionXLInputStep), -# ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), -# ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), -# ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), -# ("denoise", StableDiffusionXLDenoiseStep), -# ("decode", StableDiffusionXLDecodeStep) -# ]) - -# INPAINT_BLOCKS = OrderedDict([ -# ("text_encoder", StableDiffusionXLTextEncoderStep), -# ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), -# ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), -# ("input", StableDiffusionXLInputStep), -# ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), -# ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), -# ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), -# ("denoise", StableDiffusionXLDenoiseStep), -# ("decode", StableDiffusionXLInpaintDecodeStep) -# ]) - -# CONTROLNET_BLOCKS = OrderedDict([ -# ("controlnet_input", StableDiffusionXLControlNetInputStep), -# ("denoise", StableDiffusionXLControlNetDenoiseStep), -# ]) - -# CONTROLNET_UNION_BLOCKS = OrderedDict([ -# ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), -# ("denoise", StableDiffusionXLControlNetDenoiseStep), -# ]) - -# IP_ADAPTER_BLOCKS = OrderedDict([ -# ("ip_adapter", StableDiffusionXLIPAdapterStep), -# ]) - -# AUTO_BLOCKS = OrderedDict([ -# ("text_encoder", StableDiffusionXLTextEncoderStep), -# ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), -# ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), -# ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), -# ("denoise", StableDiffusionXLAutoDenoiseStep), -# ("decode", StableDiffusionXLAutoDecodeStep) -# ]) - -# AUTO_CORE_BLOCKS = OrderedDict([ -# ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), -# ("denoise", StableDiffusionXLAutoDenoiseStep), -# ]) - - -# SDXL_SUPPORTED_BLOCKS = { -# "text2img": TEXT2IMAGE_BLOCKS, -# "img2img": IMAGE2IMAGE_BLOCKS, -# "inpaint": INPAINT_BLOCKS, -# "controlnet": CONTROLNET_BLOCKS, -# "controlnet_union": CONTROLNET_UNION_BLOCKS, -# "ip_adapter": IP_ADAPTER_BLOCKS, -# "auto": AUTO_BLOCKS -# } - From 5cde77f9159d9bf1deeb948a4db79d109df461d7 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 13 May 2025 01:52:51 +0200 Subject: [PATCH 27/54] make inputs truly immutable, remove the output logic in sequential pipeline, and update so that intermediates_outputs are only new variables --- .../modular_pipelines/modular_pipeline.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 3eeff41dd1de..5dcb903db495 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -17,6 +17,7 @@ from collections import OrderedDict from dataclasses import dataclass, field from typing import Any, Dict, List, Tuple, Union, Optional, Type +from copy import deepcopy import torch @@ -109,7 +110,9 @@ def add_intermediate(self, key: str, value: Any, kwargs_type: str = None): self.intermediate_kwargs[kwargs_type].append(key) def get_input(self, key: str, default: Any = None) -> Any: - return self.inputs.get(key, default) + value = self.inputs.get(key, default) + if value is not None: + return deepcopy(value) def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: return {key: self.inputs.get(key, default) for key in keys} @@ -483,6 +486,7 @@ def doc(self): ) + # YiYi TODO: input and inteermediate inputs with same name? should warn? def get_block_state(self, state: PipelineState) -> dict: """Get all inputs and intermediates in one dictionary""" data = {} @@ -1032,14 +1036,21 @@ def get_intermediates_inputs(self): @property def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + named_outputs = [] + for name, block in self.blocks.items(): + inp_names = set([inp.name for inp in block.intermediates_inputs]) + # so we only need to list new variables as intermediates_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce) + # filter out them here so they do not end up as intermediates_outputs + if name not in inp_names: + named_outputs.append((name, block.intermediates_outputs)) combined_outputs = combine_outputs(*named_outputs) return combined_outputs + # YiYi TODO: I think we can remove the outputs property @property def outputs(self) -> List[str]: - return next(reversed(self.blocks.values())).intermediates_outputs - + # return next(reversed(self.blocks.values())).intermediates_outputs + return self.intermediates_outputs @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: for block_name, block in self.blocks.items(): From 58358c2d003f7a25120aea9c4545571d6feefe21 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 13 May 2025 01:57:47 +0200 Subject: [PATCH 28/54] decode block, if skip decoding do not need to update latent --- .../stable_diffusion_xl/after_denoise.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py index 9746832506d7..6ce59b5c35b9 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py @@ -98,16 +98,17 @@ def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) if not block_state.output_type == "latent": + latents = block_state.latents # make sure the VAE is in float32 mode, as it overflows in float16 block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast if block_state.needs_upcasting: self.upcast_vae(components) - block_state.latents = block_state.latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) - elif block_state.latents.dtype != components.vae.dtype: + latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != components.vae.dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - components.vae = components.vae.to(block_state.latents.dtype) + components.vae = components.vae.to(latents.dtype) # unscale/denormalize the latents # denormalize with the mean and std if available and not None @@ -119,16 +120,16 @@ def __call__(self, components, state: PipelineState) -> PipelineState: ) if block_state.has_latents_mean and block_state.has_latents_std: block_state.latents_mean = ( - torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) + torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) block_state.latents_std = ( - torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) + torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) - block_state.latents = block_state.latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean + latents = latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean else: - block_state.latents = block_state.latents / components.vae.config.scaling_factor + latents = latents / components.vae.config.scaling_factor - block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0] + block_state.images = components.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed if block_state.needs_upcasting: @@ -186,6 +187,7 @@ def __call__(self, components, state: PipelineState) -> PipelineState: return components, state +# YiYi TODO: remove this, we don't need this in modular class StableDiffusionXLOutputStep(PipelineBlock): model_name = "stable-diffusion-xl" From 506a8ea09c19d806103c23e69d3dd52aa7e84110 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 13 May 2025 04:36:06 +0200 Subject: [PATCH 29/54] fix imports --- .../pipelines/stable_diffusion_xl/__init__.py | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py index 006836fe30d4..8088fbcfceba 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -29,18 +29,6 @@ _import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"] _import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"] _import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"] - _import_structure["pipeline_stable_diffusion_xl_modular"] = [ - "StableDiffusionXLControlNetDenoiseStep", - "StableDiffusionXLDecodeLatentsStep", - "StableDiffusionXLDenoiseStep", - "StableDiffusionXLInputStep", - "StableDiffusionXLModularLoader", - "StableDiffusionXLPrepareAdditionalConditioningStep", - "StableDiffusionXLPrepareLatentsStep", - "StableDiffusionXLSetTimestepsStep", - "StableDiffusionXLTextEncoderStep", - "StableDiffusionXLAutoPipeline", - ] if is_transformers_available() and is_flax_available(): from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState @@ -60,18 +48,6 @@ from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline - from .pipeline_stable_diffusion_xl_modular import ( - StableDiffusionXLControlNetDenoiseStep, - StableDiffusionXLDecodeLatentsStep, - StableDiffusionXLDenoiseStep, - StableDiffusionXLInputStep, - StableDiffusionXLModularLoader, - StableDiffusionXLPrepareAdditionalConditioningStep, - StableDiffusionXLPrepareLatentsStep, - StableDiffusionXLSetTimestepsStep, - StableDiffusionXLTextEncoderStep, - StableDiffusionXLAutoPipeline, - ) try: if not (is_transformers_available() and is_flax_available()): From e2491af650b33c43294f0aaac02f0b7fdbbcf7e0 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 13 May 2025 20:42:57 +0200 Subject: [PATCH 30/54] fix import --- src/diffusers/pipelines/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 0567eb687c62..a988fb6702aa 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -47,7 +47,6 @@ "AutoPipelineForInpainting", "AutoPipelineForText2Image", ] - _import_structure["modular_pipeline"] = ["ModularLoader"] _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] _import_structure["ddim"] = ["DDIMPipeline"] @@ -481,7 +480,6 @@ from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline - from .modular_pipeline import ModularLoader from .pipeline_utils import ( AudioPipelineOutput, DiffusionPipeline, From a0deefb6061408a5ff6523ceed24a0fa31c30b20 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 13 May 2025 20:51:21 +0200 Subject: [PATCH 31/54] fix more --- src/diffusers/pipelines/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index a988fb6702aa..011f23ed371c 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -329,8 +329,6 @@ "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", "StableDiffusionXLPipeline", - "StableDiffusionXLModularLoader", - "StableDiffusionXLAutoPipeline", ] ) _import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] @@ -704,9 +702,7 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLModularLoader, StableDiffusionXLPipeline, - StableDiffusionXLAutoPipeline, ) from .stable_video_diffusion import StableVideoDiffusionPipeline from .t2i_adapter import ( From a7fb2d2a2243d4687a2b9c05ca0fdec21fdb9ffb Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 13 May 2025 22:15:54 +0200 Subject: [PATCH 32/54] remove the output step --- .../stable_diffusion_xl/after_denoise.py | 56 ++----------------- .../stable_diffusion_xl/modular_loader.py | 1 - 2 files changed, 5 insertions(+), 52 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py index 6ce59b5c35b9..ca848e20984f 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py @@ -41,7 +41,7 @@ -class StableDiffusionXLDecodeLatentsStep(PipelineBlock): +class StableDiffusionXLDecodeStep(PipelineBlock): model_name = "stable-diffusion-xl" @@ -187,63 +187,17 @@ def __call__(self, components, state: PipelineState) -> PipelineState: return components, state -# YiYi TODO: remove this, we don't need this in modular -class StableDiffusionXLOutputStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return "final step to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [InputParam("return_dict", default=True)] - - @property - def intermediates_inputs(self) -> List[str]: - return [InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step.")] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", description="The final images output, can be a tuple or a `StableDiffusionXLPipelineOutput`")] - - - @torch.no_grad() - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - if not block_state.return_dict: - block_state.images = (block_state.images,) - else: - block_state.images = StableDiffusionXLPipelineOutput(images=block_state.images) - self.add_block_state(state, block_state) - return components, state - - -# After denoise -class StableDiffusionXLDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep] - block_names = ["decode", "output"] - - @property - def description(self): - return """Decode step that decode the denoised latents into images outputs. -This is a sequential pipeline blocks: - - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images - - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple.""" - class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep] - block_names = ["decode", "mask_overlay", "output"] + block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep] + block_names = ["decode", "mask_overlay"] @property def description(self): return "Inpaint decode step that decode the denoised latents into images outputs.\n" + \ "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images\n" + \ - " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image\n" + \ - " - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." + " - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n" + \ + " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image" class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py index 53f27571092a..4af942af64e6 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py @@ -107,7 +107,6 @@ def num_channels_latents(self): "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), - "return_dict": InputParam("return_dict", type_hint=bool, default=True, description="Whether to return a StableDiffusionXLPipelineOutput"), "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), "control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"), "control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"), From 8ad14a52cbc3b3e0d7f97305dc95fee629564b97 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 13 May 2025 23:25:56 +0200 Subject: [PATCH 33/54] make generator intermediates (it is mutable) --- .../stable_diffusion_xl/before_denoise.py | 6 +++--- .../modular_pipelines/stable_diffusion_xl/denoise.py | 4 ++-- .../modular_pipelines/stable_diffusion_xl/encoders.py | 8 +++++--- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index 6809b4cd8e2e..8f083f1870e7 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -440,7 +440,6 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("generator"), InputParam("latents"), InputParam("num_images_per_prompt", default=1), InputParam("denoising_start"), @@ -459,6 +458,7 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[str]: return [ + InputParam("generator"), InputParam( "batch_size", required=True, @@ -733,7 +733,6 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("generator"), InputParam("latents"), InputParam("num_images_per_prompt", default=1), InputParam("denoising_start"), @@ -742,6 +741,7 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[InputParam]: return [ + InputParam("generator"), InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), @@ -879,7 +879,6 @@ def inputs(self) -> List[InputParam]: return [ InputParam("height"), InputParam("width"), - InputParam("generator"), InputParam("latents"), InputParam("num_images_per_prompt", default=1), ] @@ -887,6 +886,7 @@ def inputs(self) -> List[InputParam]: @property def intermediates_inputs(self) -> List[InputParam]: return [ + InputParam("generator"), InputParam( "batch_size", required=True, diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index f605d0ab00aa..b29920764acb 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -485,13 +485,13 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("generator"), InputParam("eta", default=0.0), ] @property def intermediates_inputs(self) -> List[str]: return [ + InputParam("generator"), InputParam( "latents", required=True, @@ -554,13 +554,13 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("generator"), InputParam("eta", default=0.0), ] @property def intermediates_inputs(self) -> List[str]: return [ + InputParam("generator"), InputParam( "timesteps", required=True, diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py index 3c84fc71c8af..ca4efe2c4a7f 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -568,7 +568,6 @@ def expected_components(self) -> List[ComponentSpec]: def inputs(self) -> List[InputParam]: return [ InputParam("image", required=True), - InputParam("generator"), InputParam("height"), InputParam("width"), ] @@ -576,6 +575,7 @@ def inputs(self) -> List[InputParam]: @property def intermediates_inputs(self) -> List[InputParam]: return [ + InputParam("generator"), InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] @@ -680,7 +680,6 @@ def inputs(self) -> List[InputParam]: return [ InputParam("height"), InputParam("width"), - InputParam("generator"), InputParam("image", required=True), InputParam("mask_image", required=True), InputParam("padding_mask_crop"), @@ -688,7 +687,10 @@ def inputs(self) -> List[InputParam]: @property def intermediates_inputs(self) -> List[InputParam]: - return [InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs")] + return [ + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + InputParam("generator"), + ] @property def intermediates_outputs(self) -> List[OutputParam]: From 96ce6744fe4c7a569fd1cb5e42ce7d188b85eb1e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 15 May 2025 00:45:45 +0200 Subject: [PATCH 34/54] after_denoise -> decoders --- .../modular_pipelines/stable_diffusion_xl/__init__.py | 4 ++-- .../stable_diffusion_xl/{after_denoise.py => decoders.py} | 0 .../stable_diffusion_xl/modular_pipeline_presets.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) rename src/diffusers/modular_pipelines/stable_diffusion_xl/{after_denoise.py => decoders.py} (100%) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py index 6d06c1f2e3df..f3f961d61a13 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -24,7 +24,7 @@ _import_structure["modular_pipeline_presets"] = ["StableDiffusionXLAutoPipeline"] _import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"] _import_structure["encoders"] = ["StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLTextEncoderStep", "StableDiffusionXLAutoVaeEncoderStep"] - _import_structure["after_denoise"] = ["StableDiffusionXLAutoDecodeStep"] + _import_structure["decoders"] = ["StableDiffusionXLAutoDecodeStep"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -36,7 +36,7 @@ from .modular_pipeline_presets import StableDiffusionXLAutoPipeline from .modular_loader import StableDiffusionXLModularLoader from .encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep - from .after_denoise import StableDiffusionXLAutoDecodeStep + from .decoders import StableDiffusionXLAutoDecodeStep else: import sys diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py similarity index 100% rename from src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py rename to src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py index 6ea327047740..637c7ac306d7 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py @@ -18,7 +18,7 @@ from .denoise import StableDiffusionXLAutoDenoiseStep from .before_denoise import StableDiffusionXLAutoBeforeDenoiseStep -from .after_denoise import StableDiffusionXLAutoDecodeStep +from .decoders import StableDiffusionXLAutoDecodeStep from .encoders import StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -26,7 +26,7 @@ class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] - block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "after_denoise"] + block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decoder"] @property def description(self): From 27c1158b23fc06c03a1bb8f9d730d22c394421f5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 18 May 2025 18:50:03 +0200 Subject: [PATCH 35/54] add a to-do for guider cconfig mixin --- src/diffusers/hooks/layer_skip.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index c50d2b7471e4..65a99464ba2f 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -30,6 +30,8 @@ _LAYER_SKIP_HOOK = "layer_skip_hook" +# Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed +# either remove or make it serializable @dataclass class LayerSkipConfig: r""" From d0fbf745e6e27185a8c465ced3373e2f77cf37e2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 18 May 2025 18:52:12 +0200 Subject: [PATCH 36/54] refactor component spec: replace create/create_from_pretrained/create_from_config to just create and load method --- .../modular_pipeline_utils.py | 72 ++++++++----------- 1 file changed, 31 insertions(+), 41 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index a82f83fc38d9..0c6d1b585589 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -71,34 +71,31 @@ def __eq__(self, other): self.default_creation_method == other.default_creation_method) @classmethod - def from_component(cls, name: str, component: torch.nn.Module) -> Any: - """Create a ComponentSpec from a Component created by `create` method.""" + def from_component(cls, name: str, component: Any) -> Any: + """Create a ComponentSpec from a Component created by `create` or `load` method.""" if not hasattr(component, "_diffusers_load_id"): - raise ValueError("Component is not created by `create` method") + raise ValueError("Component is not created by `create` or `load` method") + # throw a error if component is created with `create` method but not a subclass of ConfigMixin + # YiYi TODO: remove this check if we remove support for non configmixin in `create()` method + if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin): + raise ValueError( + "We currently only support creating ComponentSpec from a component with " + "created with `ComponentSpec.load` method" + "or created with `ComponentSpec.create` and a subclass of ConfigMixin" + ) type_hint = component.__class__ + default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained" - if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin): + if isinstance(component, ConfigMixin): config = component.config else: config = None load_spec = cls.decode_load_id(component._diffusers_load_id) - return cls(name=name, type_hint=type_hint, config=config, **load_spec) - - @classmethod - def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any: - """Create a ComponentSpec from a load_id string.""" - if load_id == "null": - raise ValueError("Cannot create ComponentSpec from null load_id") - - # Decode the load_id into a dictionary of loading fields - load_fields = cls.decode_load_id(load_id) - - # Create a new ComponentSpec instance with the decoded fields - return cls(name=name, **load_fields) + return cls(name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec) @classmethod def loading_fields(cls) -> List[str]: @@ -137,7 +134,7 @@ def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: "revision": "revision" } If a segment value is "null", it's replaced with None. - Returns None if load_id is "null" (indicating component not loaded from pretrained). + Returns None if load_id is "null" (indicating component not created with `load` method). """ # Get all loading fields in order @@ -158,20 +155,12 @@ def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: return result - # YiYi TODO: add validator - def create(self, **kwargs) -> Any: - """Create the component using the preferred creation method.""" - - # from_pretrained creation - if self.default_creation_method == "from_pretrained": - return self.create_from_pretrained(**kwargs) - elif self.default_creation_method == "from_config": - # from_config creation - return self.create_from_config(**kwargs) - else: - raise ValueError(f"Invalid creation method: {self.default_creation_method}") - def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: + # YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin) + # otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component) + # the config info is lost in the process + # remove error check in from_component spec and ModularLoader.update() if we remove support for non configmixin in `create()` method + def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: """Create component using from_config with config.""" if self.type_hint is None or not isinstance(self.type_hint, type): @@ -201,34 +190,35 @@ def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] return component # YiYi TODO: add guard for type of model, if it is supported by from_pretrained - def create_from_pretrained(self, **kwargs) -> Any: - """Create component using from_pretrained.""" + def load(self, **kwargs) -> Any: + """Load component using from_pretrained.""" + # select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} + # merge loading field value in the spec with user passed values to create load_kwargs load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()} # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path repo = load_kwargs.pop("repo", None) if repo is None: - raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") + raise ValueError(f"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") if self.type_hint is None: try: from diffusers import AutoModel component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs) except Exception as e: - raise ValueError(f"Error creating {self.name} without `type_hint` from pretrained: {e}") + raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}") + # update type_hint if AutoModel load successfully self.type_hint = component.__class__ else: try: component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) except Exception as e: - raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}") + raise ValueError(f"Unable to load {self.name} using load method: {e}") - if repo != self.repo: - self.repo = repo - for k, v in passed_loading_kwargs.items(): - if v is not None: - setattr(self, k, v) + self.repo = repo + for k, v in load_kwargs.items(): + setattr(self, k, v) component._diffusers_load_id = self.load_id return component From 163341d3dd6c7ca8d375630a3b41363d1da3c9ce Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 18 May 2025 18:58:26 +0200 Subject: [PATCH 37/54] refactor modular loader: 1. load only load (pretrained components only if not specific names) 2. update acceept create spec 3. move the updte _componeent_spec logic outside register_components to each method that create/update the component: __init__/update/load --- .../modular_pipelines/modular_pipeline.py | 124 ++++++++++++------ 1 file changed, 85 insertions(+), 39 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 5dcb903db495..1c67a3871764 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1651,54 +1651,68 @@ class ModularLoader(ConfigMixin, PushToHubMixin): def register_components(self, **kwargs): """ - Register components with their corresponding specs. - This method is called when component changed or __init__ is called. - + Register components with their corresponding specifications. + + This method is responsible for: + 1. Sets component objects as attributes on the loader (e.g., self.unet = unet) + 2. Updates the modular_model_index.json configuration for serialization + 4. Adds components to the component manager if one is attached + + This method is called when: + - Components are first initialized in __init__: + - from_pretrained components not loaded during __init__ so they are registered as None; + - non from_pretrained components are created during __init__ and registered as the object itself + - Components are updated with the `update()` method: e.g. loader.update(unet=unet) or loader.update(guider=guider_spec) + - (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(component_names=["unet"]) + Args: **kwargs: Keyword arguments where keys are component names and values are component objects. + E.g., register_components(unet=unet_model, text_encoder=encoder_model) + Notes: + - Components must be created from ComponentSpec (have _diffusers_load_id attribute) + - When registering None for a component, it updates the modular_model_index.json config but sets attribute to None """ for name, module in kwargs.items(): - # current component spec component_spec = self._component_specs.get(name) if component_spec is None: logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") continue + # check if it is the first time registration, i.e. calling from __init__ is_registered = hasattr(self, name) + # make sure the component is created from ComponentSpec if module is not None and not hasattr(module, "_diffusers_load_id"): raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") - # actual library and class name of the module - if module is not None: - library, class_name = _fetch_class_library_tuple(module) - new_component_spec = ComponentSpec.from_component(name, module) - component_spec_dict = self._component_spec_to_dict(new_component_spec) + # actual library and class name of the module + library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel") + + # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config + # e.g. {"repo": "stabilityai/stable-diffusion-2-1", + # "type_hint": ("diffusers", "UNet2DConditionModel"), + # "subfolder": "unet", + # "variant": None, + # "revision": None} + component_spec_dict = self._component_spec_to_dict(component_spec) else: + # if module is None, e.g. self.register_components(unet=None) during __init__ + # we do not update the spec, + # but we still need to update the modular_model_index.json config based oncomponent spec library, class_name = None, None - # if module is None, we do not update the spec, - # but we still need to update the config to make sure it's synced with the component spec - # (in the case of the first time registration, we initilize the object with component spec, and then we call register_components() to register it to config) - new_component_spec = component_spec component_spec_dict = self._component_spec_to_dict(component_spec) - - # do not register if component is not to be loaded from pretrained - if new_component_spec.default_creation_method == "from_pretrained": - register_dict = {name: (library, class_name, component_spec_dict)} - else: - register_dict = {} + register_dict = {name: (library, class_name, component_spec_dict)} # set the component as attribute # if it is not set yet, just set it and skip the process to check and warn below if not is_registered: self.register_to_config(**register_dict) - self._component_specs[name] = new_component_spec setattr(self, name, module) - if module is not None and self._component_manager is not None: + if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None: self._component_manager.add(name, module, self._collection) continue @@ -1707,10 +1721,6 @@ def register_components(self, **kwargs): if current_module is module: logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") continue - - # it module is not an instance of the expected type, still register it but with a warning - if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint): - logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}") # warn if unregister if current_module is not None and module is None: @@ -1718,7 +1728,7 @@ def register_components(self, **kwargs): f"ModularLoader.register_components: setting '{name}' to None " f"(was {current_module.__class__.__name__})" ) - # same type, new instance → debug + # same type, new instance → replace but send debug log elif current_module is not None \ and module is not None \ and isinstance(module, current_module.__class__) \ @@ -1728,13 +1738,12 @@ def register_components(self, **kwargs): f"(same type {type(current_module).__name__}, new instance)" ) - # save modular_model_index.json config + # update modular_model_index.json config self.register_to_config(**register_dict) - # update component spec - self._component_specs[name] = new_component_spec # finally set models setattr(self, name, module) - if module is not None and self._component_manager is not None: + # add to component manager if one is attached + if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None: self._component_manager.add(name, module, self._collection) @@ -1758,6 +1767,7 @@ def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: config_dict = self.load_config(modular_repo, **kwargs) for name, value in config_dict.items(): + # only update component_spec for from_pretrained components if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3: library, class_name, component_spec_dict = value component_spec = self._dict_to_component_spec(name, component_spec_dict) @@ -1768,7 +1778,11 @@ def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: register_components_dict = {} for name, component_spec in self._component_specs.items(): - register_components_dict[name] = None + if component_spec.default_creation_method == "from_config": + component = component_spec.create() + else: + component = None + register_components_dict[name] = component self.register_components(**register_components_dict) default_configs = {} @@ -1870,6 +1884,7 @@ def update(self, **kwargs): **kwargs: Component objects or configuration values to update: - Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, text_encoder=new_encoder`) - Configuration values: Simple values to update configuration settings (e.g., `requires_safety_checker=False`) + - ComponentSpec objects: if passed a ComponentSpec object, only support from_config spec, will call create() method to create it Raises: ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute) @@ -1893,22 +1908,52 @@ def update(self, **kwargs): unet=new_unet_model, requires_safety_checker=False ) + # update with ComponentSpec objects + loader.update( + guider=ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={"guidance_scale": 5.0}, default_creation_method="from_config") + ) ``` """ # extract component_specs_updates & config_specs_updates from `specs` - passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs} + passed_component_specs = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec)} + passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and not isinstance(kwargs[k], ComponentSpec)} passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs} for name, component in passed_components.items(): if not hasattr(component, "_diffusers_load_id"): - raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + + # YiYi TODO: remove this if we remove support for non config mixin components in `create()` method + if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin): + raise ValueError( + f"The passed component '{name}' is not supported in update() method " + f"because it is not supported in `ComponentSpec.from_component()`. " + f"Please pass a ComponentSpec object instead." + ) + current_component_spec = self._component_specs[name] + # warn if type changed + if current_component_spec.type_hint is not None and not isinstance(component, current_component_spec.type_hint): + logger.warning(f"ModularLoader.update: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}") + # update _component_specs based on the new component + new_component_spec = ComponentSpec.from_component(name, component) + self._component_specs[name] = new_component_spec if len(kwargs) > 0: logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") - - self.register_components(**passed_components) + created_components = {} + for name, component_spec in passed_component_specs.items(): + if component_spec.default_creation_method == "from_pretrained": + raise ValueError(f"ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update() method") + created_components[name] = component_spec.create() + current_component_spec = self._component_specs[name] + # warn if type changed + if current_component_spec.type_hint is not None and not isinstance(created_components[name], current_component_spec.type_hint): + logger.warning(f"ModularLoader.update: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}") + # update _component_specs based on the user passed component_spec + self._component_specs[name] = component_spec + self.register_components(**passed_components, **created_components) config_to_register = {} @@ -1932,8 +1977,9 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs): - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32} - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`, `variant`, `revision`, etc. """ + # if not specific name, load all the components with default_creation_method == "from_pretrained" if component_names is None: - component_names = list(self._component_specs.keys()) + component_names = [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained"] elif not isinstance(component_names, list): component_names = [component_names] @@ -1958,7 +2004,7 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs): # check if the default is specified component_load_kwargs[key] = value["default"] try: - components_to_register[name] = spec.create(**component_load_kwargs) + components_to_register[name] = spec.load(**component_load_kwargs) except Exception as e: logger.warning(f"Failed to create component '{name}': {e}") @@ -1986,7 +2032,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: @classmethod @validate_hf_hub_args - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) expected_component = set(config_dict.pop("_components_names")) @@ -2010,7 +2056,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P else: # append a empty component spec for these not in modular_model_index component_specs.append(ComponentSpec(name=name, default_creation_method="from_config")) - return cls(component_specs + config_specs) + return cls(component_specs + config_specs, component_manager=component_manager, collection=collection) @staticmethod From 73ab5725c2fad4f62589554c9432c7b0dd673268 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 18 May 2025 19:09:01 +0200 Subject: [PATCH 38/54] update components manager --- .../modular_pipelines/components_manager.py | 143 +++++++++++------- 1 file changed, 89 insertions(+), 54 deletions(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 0ace1b321e8b..88910baf90f4 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -243,78 +243,112 @@ def __init__(self): self._auto_offload_enabled = False - def _get_by_collection(self, collection: str): + def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None): """ - Select components by collection name. + Lookup component_ids by name, collection, or load_id. """ - selected_components = {} - if collection in self.collections: - component_ids = self.collections[collection] - for component_id in component_ids: - selected_components[component_id] = self.components[component_id] - return selected_components + if components is None: + components = self.components + + if name: + ids_by_name = set() + for component_id, component in components.items(): + comp_name = "_".join(component_id.split("_")[:-1]) + if comp_name == name: + ids_by_name.add(component_id) + else: + ids_by_name = set(components.keys()) + if collection: + ids_by_collection = set() + for component_id, component in components.items(): + if component_id in self.collections[collection]: + ids_by_collection.add(component_id) + else: + ids_by_collection = set(self.collections.keys()) + if load_id: + ids_by_load_id = set() + for name, component in components.items(): + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: + ids_by_load_id.add(name) + else: + ids_by_load_id = set(components.keys()) - - def _get_by_load_id(self, load_id: str): - """ - Select components by its load_id. - """ - selected_components = {} - for name, component in self.components.items(): - if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: - selected_components[name] = component - return selected_components + ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id) + return ids def add(self, name, component, collection: Optional[str] = None): + + component_id = f"{name}_{uuid.uuid4()}" + # check for duplicated components for comp_id, comp in self.components.items(): if comp == component: - logger.warning(f"Component '{name}' already exists in ComponentsManager") - return comp_id + logger.warning( + f"component '{component}' already exists as '{comp_id}'" + ) + # if name is the same, use the existing component_id + if comp_id.split("_")[:-1] == component_id.split("_")[:-1]: + component_id = comp_id + break - component_id = f"{name}_{uuid.uuid4()}" + # check for duplicated load_id and warn (we do not delete for you) if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": - components_with_same_load_id = self._get_by_load_id(component._diffusers_load_id) + components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id) + if components_with_same_load_id: - existing = ", ".join(components_with_same_load_id.keys()) + existing = ", ".join(components_with_same_load_id) logger.warning( - f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " - f"To remove a duplicate, call `components_manager.remove('')`." + f"Adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " + f"To remove a duplicate, call `components_manager.remove('')`." ) - # add component to components manager self.components[component_id] = component self.added_time[component_id] = time.time() + if collection: if collection not in self.collections: self.collections[collection] = set() + comp_ids_in_collection = self._lookup_ids(name=name, collection=collection) + for comp_id in comp_ids_in_collection: + logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}") + self.remove(comp_id) self.collections[collection].add(component_id) + logger.info(f"Added component '{name}' in collection '{collection}': {component_id}") + else: + logger.info(f"Added component '{name}' as '{component_id}'") if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) - logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'") return component_id - def remove(self, name: Union[str, List[str]]): + def remove(self, component_id: str = None): - if name not in self.components: - logger.warning(f"Component '{name}' not found in ComponentsManager") + if component_id not in self.components: + logger.warning(f"Component '{component_id}' not found in ComponentsManager") return - - self.components.pop(name) - self.added_time.pop(name) + + component = self.components.pop(component_id) + self.added_time.pop(component_id) for collection in self.collections: - if name in self.collections[collection]: - self.collections[collection].remove(name) + if component_id in self.collections[collection]: + self.collections[collection].remove(component_id) if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) + else: + if isinstance(component, torch.nn.Module): + component.to("cpu") + del component + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None, as_name_component_tuples: bool = False): @@ -343,16 +377,8 @@ def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = N or list of (base_name, component) tuples if as_name_component_tuples=True """ - if collection: - if collection not in self.collections: - logger.warning(f"Collection '{collection}' not found in ComponentsManager") - return [] if as_name_component_tuples else {} - components = self._get_by_collection(collection) - else: - components = self.components - - if load_id: - components = self._get_by_load_id(load_id) + selected_ids = self._lookup_ids(collection=collection, load_id=load_id) + components = {k: self.components[k] for k in selected_ids} # Helper to extract base name from component_id def get_base_name(component_id): @@ -542,11 +568,11 @@ def disable_auto_cpu_offload(self): self._auto_offload_enabled = False # YiYi TODO: add quantization info - def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: + def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: """Get comprehensive information about a component. Args: - name: Name of the component to get info for + component_id: Name of the component to get info for fields: Optional field(s) to return. Can be a string for single field or list of fields. If None, returns all fields. @@ -555,16 +581,16 @@ def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = No If fields is specified, returns only those fields. If a single field is requested as string, returns just that field's value. """ - if name not in self.components: - raise ValueError(f"Component '{name}' not found in ComponentsManager") + if component_id not in self.components: + raise ValueError(f"Component '{component_id}' not found in ComponentsManager") - component = self.components[name] + component = self.components[component_id] # Build complete info dict first info = { - "model_id": name, - "added_time": self.added_time[name], - "collection": next((coll for coll, comps in self.collections.items() if name in comps), None), + "model_id": component_id, + "added_time": self.added_time[component_id], + "collection": next((coll for coll, comps in self.collections.items() if component_id in comps), None), } # Additional info for torch.nn.Module components @@ -776,7 +802,7 @@ def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" ) - def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any: + def get_one(self, component_id: Optional[str] = None, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any: """ Get a single component by name. Raises an error if multiple components match or none are found. @@ -791,6 +817,15 @@ def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, Raises: ValueError: If no components match or multiple components match """ + + # if component_id is provided, return the component + if component_id is not None and (name is not None or collection is not None or load_id is not None): + raise ValueError(" if component_id is provided, name, collection, and load_id must be None") + elif component_id is not None: + if component_id not in self.components: + raise ValueError(f"Component '{component_id}' not found in ComponentsManager") + return self.components[component_id] + results = self.get(name, collection, load_id) if not results: From 61dac3bbe4ad8d71c5239e9d6158819f99069a20 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 19 May 2025 22:39:32 +0200 Subject: [PATCH 39/54] up --- .../modular_pipelines/components_manager.py | 84 +++++++++++++------ .../modular_pipelines/modular_pipeline.py | 15 ++-- 2 files changed, 65 insertions(+), 34 deletions(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 88910baf90f4..992353389b95 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -253,7 +253,7 @@ def _lookup_ids(self, name=None, collection=None, load_id=None, components: Orde if name: ids_by_name = set() for component_id, component in components.items(): - comp_name = "_".join(component_id.split("_")[:-1]) + comp_name = self._id_to_name(component_id) if comp_name == name: ids_by_name.add(component_id) else: @@ -264,7 +264,7 @@ def _lookup_ids(self, name=None, collection=None, load_id=None, components: Orde if component_id in self.collections[collection]: ids_by_collection.add(component_id) else: - ids_by_collection = set(self.collections.keys()) + ids_by_collection = set(components.keys()) if load_id: ids_by_load_id = set() for name, component in components.items(): @@ -276,6 +276,9 @@ def _lookup_ids(self, name=None, collection=None, load_id=None, components: Orde ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id) return ids + @staticmethod + def _id_to_name(component_id: str): + return "_".join(component_id.split("_")[:-1]) def add(self, name, component, collection: Optional[str] = None): @@ -284,18 +287,24 @@ def add(self, name, component, collection: Optional[str] = None): # check for duplicated components for comp_id, comp in self.components.items(): if comp == component: - logger.warning( - f"component '{component}' already exists as '{comp_id}'" - ) - # if name is the same, use the existing component_id - if comp_id.split("_")[:-1] == component_id.split("_")[:-1]: + comp_name = self._id_to_name(comp_id) + if comp_name == name: + logger.warning( + f"component '{name}' already exists as '{comp_id}'" + ) component_id = comp_id break + else: + logger.warning( + f"Adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'" + f"To remove a duplicate, call `components_manager.remove('')`." + ) # check for duplicated load_id and warn (we do not delete for you) if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id) + components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id] if components_with_same_load_id: existing = ", ".join(components_with_same_load_id) @@ -311,12 +320,13 @@ def add(self, name, component, collection: Optional[str] = None): if collection: if collection not in self.collections: self.collections[collection] = set() - comp_ids_in_collection = self._lookup_ids(name=name, collection=collection) - for comp_id in comp_ids_in_collection: - logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}") - self.remove(comp_id) - self.collections[collection].add(component_id) - logger.info(f"Added component '{name}' in collection '{collection}': {component_id}") + if not component_id in self.collections[collection]: + comp_ids_in_collection = self._lookup_ids(name=name, collection=collection) + for comp_id in comp_ids_in_collection: + logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}") + self.remove(comp_id) + self.collections[collection].add(component_id) + logger.info(f"Added component '{name}' in collection '{collection}': {component_id}") else: logger.info(f"Added component '{name}' as '{component_id}'") @@ -590,7 +600,7 @@ def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str info = { "model_id": component_id, "added_time": self.added_time[component_id], - "collection": next((coll for coll, comps in self.collections.items() if component_id in comps), None), + "collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps]) or None, } # Additional info for torch.nn.Module components @@ -676,11 +686,19 @@ def format_device(component, info): ] max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 - # Collection names - collection_names = [ - next((coll for coll, comps in self.collections.items() if name in comps), "N/A") - for name in self.components.keys() - ] + # Get all collections for each component + component_collections = {} + for name in self.components.keys(): + component_collections[name] = [] + for coll, comps in self.collections.items(): + if name in comps: + component_collections[name].append(coll) + if not component_collections[name]: + component_collections[name] = ["N/A"] + + # Find the maximum collection name length + all_collections = [coll for colls in component_collections.values() for coll in colls] + max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10 col_widths = { "name": max(15, max(len(name) for name in simple_names)), @@ -689,7 +707,7 @@ def format_device(component, info): "dtype": 15, "size": 10, "load_id": max_load_id_len, - "collection": max(10, max(len(str(c)) for c in collection_names)) + "collection": max_collection_len } # Create the header lines @@ -718,11 +736,21 @@ def format_device(component, info): device_str = format_device(component, info) dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" load_id = get_load_id(component) - collection = info["collection"] or "N/A" + + # Print first collection on the main line + first_collection = component_collections[name][0] if component_collections[name] else "N/A" output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " - output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n" + output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n" + + # Print additional collections on separate lines if they exist + for i in range(1, len(component_collections[name])): + collection = component_collections[name][i] + output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | " + output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | " + output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n" + output += dash_line # Other components section @@ -738,9 +766,17 @@ def format_device(component, info): for name, component in others.items(): info = self.get_model_info(name) simple_name = get_simple_name(name) - collection = info["collection"] or "N/A" - output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n" + # Print first collection on the main line + first_collection = component_collections[name][0] if component_collections[name] else "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n" + + # Print additional collections on separate lines if they exist + for i in range(1, len(component_collections[name])): + collection = component_collections[name][i] + output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | {collection}\n" + output += dash_line # Add additional component info diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 1c67a3871764..36273da11f5a 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -2043,19 +2043,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P for name, value in config_dict.items(): if name in expected_component and isinstance(value, (tuple, list)) and len(value) == 3: library, class_name, component_spec_dict = value - component_spec = cls._dict_to_component_spec(name, component_spec_dict) - component_specs.append(component_spec) + # only pick up pretrained components from the repo + if component_spec_dict.get("repo", None) is not None: + component_spec = cls._dict_to_component_spec(name, component_spec_dict) + component_specs.append(component_spec) elif name in expected_config: config_specs.append(ConfigSpec(name=name, default=value)) - - for name in expected_component: - for spec in component_specs: - if spec.name == name: - break - else: - # append a empty component spec for these not in modular_model_index - component_specs.append(ComponentSpec(name=name, default_creation_method="from_config")) + return cls(component_specs + config_specs, component_manager=component_manager, collection=collection) From 808dff09cb3898803ba38a2ada860616561fdaa3 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 20 May 2025 11:42:51 +0200 Subject: [PATCH 40/54] [WIP] Modular Diffusers support custom code/pipeline blocks (#11539) * update * update --- src/diffusers/pipelines/modular_pipeline.py | 471 ++++++++++--------- src/diffusers/utils/dynamic_modules_utils.py | 85 +++- 2 files changed, 334 insertions(+), 222 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 636b543395df..4db8433768e4 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -12,29 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect +import importlib +import os import traceback import warnings from collections import OrderedDict +from copy import deepcopy from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple, Union, Optional, Type - +from typing import Any, Dict, List, Optional, Tuple, Union import torch -from tqdm.auto import tqdm -import re -import os -import importlib - from huggingface_hub.utils import validate_hf_hub_args +from tqdm.auto import tqdm from ..configuration_utils import ConfigMixin, FrozenDict from ..utils import ( + PushToHubMixin, is_accelerate_available, - is_accelerate_version, logging, - PushToHubMixin, ) -from .pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj,_fetch_class_library_tuple +from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code +from .components_manager import ComponentsManager from .modular_pipeline_utils import ( ComponentSpec, ConfigSpec, @@ -42,18 +41,15 @@ OutputParam, format_components, format_configs, - format_input_params, format_inputs_short, format_intermediates_short, - format_output_params, - format_params, make_doc_string, ) -from .components_manager import ComponentsManager +from .pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj + -from copy import deepcopy if is_accelerate_available(): - import accelerate + pass logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -108,18 +104,16 @@ def format_value(v): intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) return ( - f"PipelineState(\n" - f" inputs={{\n{inputs}\n }},\n" - f" intermediates={{\n{intermediates}\n }}\n" - f")" + f"PipelineState(\n" f" inputs={{\n{inputs}\n }},\n" f" intermediates={{\n{intermediates}\n }}\n" f")" ) -@dataclass +@dataclass class BlockState: """ Container for block state data with attribute access and formatted representation. """ + def __init__(self, **kwargs): for key, value in kwargs.items(): setattr(self, key, value) @@ -129,28 +123,28 @@ def format_value(v): # Handle tensors directly if hasattr(v, "shape") and hasattr(v, "dtype"): return f"Tensor(dtype={v.dtype}, shape={v.shape})" - + # Handle lists of tensors elif isinstance(v, list): if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): shapes = [t.shape for t in v] return f"List[{len(v)}] of Tensors with shapes {shapes}" return repr(v) - + # Handle tuples of tensors elif isinstance(v, tuple): if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): shapes = [t.shape for t in v] return f"Tuple[{len(v)}] of Tensors with shapes {shapes}" return repr(v) - + # Handle dicts with tensor values elif isinstance(v, dict): if any(hasattr(val, "shape") and hasattr(val, "dtype") for val in v.values()): shapes = {k: val.shape for k, val in v.items() if hasattr(val, "shape")} return f"Dict of Tensors with shapes {shapes}" return repr(v) - + # Default case return repr(v) @@ -158,31 +152,92 @@ def format_value(v): return f"BlockState(\n{attributes}\n)" - -class ModularPipelineMixin: +class ModularPipelineMixin(ConfigMixin): """ Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks """ - - def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): + config_name = "config.json" + + @classmethod + def _get_signature_keys(cls, obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - {"self"} + + return expected_modules, optional_parameters + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + trust_remote_code: Optional[bool] = None, + **kwargs, + ): + hub_kwargs_names = [ + "cache_dir", + "force_download", + "local_files_only", + "proxies", + "resume_download", + "revision", + "subfolder", + "token", + ] + hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} + + config = cls.load_config(pretrained_model_name_or_path) + has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"] + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_remote_code + ) + if not (has_remote_code and trust_remote_code): + raise ValueError("") + + class_ref = config["auto_map"][cls.__name__] + module_file, class_name = class_ref.split(".") + module_file = module_file + ".py" + block_cls = get_class_from_dynamic_module( + pretrained_model_name_or_path, + module_file=module_file, + class_name=class_name, + is_modular=True, + **hub_kwargs, + **kwargs, + ) + expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls) + block_kwargs = { + name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs + } + + return block_cls(**block_kwargs) + + def setup_loader( + self, + modular_repo: Optional[Union[str, os.PathLike]] = None, + component_manager: Optional[ComponentsManager] = None, + collection: Optional[str] = None, + ): """ - create a mouldar loader, optionally accept modular_repo to load from hub. + create a ModularLoader, optionally accept modular_repo to load from hub. """ # Import components loader (it is model-specific class) - loader_class_name = MODULAR_LOADER_MAPPING[self.model_name] + loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__) + diffusers_module = importlib.import_module("diffusers") loader_class = getattr(diffusers_module, loader_class_name) - + # Create deep copies to avoid modifying the original specs component_specs = deepcopy(self.expected_components) config_specs = deepcopy(self.expected_configs) # Create the loader with the updated specs specs = component_specs + config_specs - - self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection) + self.loader = loader_class( + specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection + ) @property def default_call_parameters(self) -> Dict[str, Any]: @@ -238,7 +293,6 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, if output is None: return state - elif isinstance(output, str): return state.get_intermediate(output) @@ -268,9 +322,8 @@ def set_progress_bar_config(self, **kwargs): class PipelineBlock(ModularPipelineMixin): - model_name = None - + @property def description(self) -> str: """Description of the block. Must be implemented by subclasses.""" @@ -279,12 +332,11 @@ def description(self) -> str: @property def expected_components(self) -> List[ComponentSpec]: return [] - + @property def expected_configs(self) -> List[ConfigSpec]: return [] - # YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable @property def inputs(self) -> List[InputParam]: @@ -322,7 +374,6 @@ def required_intermediates_inputs(self) -> List[str]: input_names.append(input_param.name) return input_names - def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise NotImplementedError("__call__ method must be implemented in subclasses") @@ -331,14 +382,14 @@ def __repr__(self): base_class = self.__class__.__bases__[0].__name__ # Format description with proper indentation - desc_lines = self.description.split('\n') + desc_lines = self.description.split("\n") desc = [] # First line with "Description:" label desc.append(f" Description: {desc_lines[0]}") # Subsequent lines with proper indentation if len(desc_lines) > 1: desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' + desc = "\n".join(desc) + "\n" # Components section - use format_components with add_empty_lines=False expected_components = getattr(self, "expected_components", []) @@ -355,7 +406,9 @@ def __repr__(self): inputs = "Inputs:\n " + inputs_str # Intermediates section - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) + intermediates_str = format_intermediates_short( + self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs + ) intermediates = f"Intermediates:\n{intermediates_str}" return ( @@ -369,24 +422,22 @@ def __repr__(self): f")" ) - @property def doc(self): return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, + self.inputs, + self.intermediates_inputs, + self.outputs, self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, - expected_configs=self.expected_configs + expected_configs=self.expected_configs, ) - def get_block_state(self, state: PipelineState) -> dict: """Get all inputs and intermediates in one dictionary""" data = {} - + # Check inputs for input_param in self.inputs: value = state.get_input(input_param.name) @@ -402,7 +453,7 @@ def get_block_state(self, state: PipelineState) -> dict: data[input_param.name] = value return BlockState(**data) - + def add_block_state(self, state: PipelineState, block_state: BlockState): for output_param in self.intermediates_outputs: if not hasattr(block_state, output_param.name): @@ -412,26 +463,28 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: """ - Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if - current default value is None and new default value is not None. Warns if multiple non-None default values + Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if + current default value is None and new default value is not None. Warns if multiple non-None default values exist for the same input. Args: named_input_lists: List of tuples containing (block_name, input_param_list) pairs - + Returns: List[InputParam]: Combined list of unique InputParam objects """ combined_dict = {} # name -> InputParam value_sources = {} # name -> block_name - + for block_name, inputs in named_input_lists: for input_param in inputs: if input_param.name in combined_dict: current_param = combined_dict[input_param.name] - if (current_param.default is not None and - input_param.default is not None and - current_param.default != input_param.default): + if ( + current_param.default is not None + and input_param.default is not None + and current_param.default != input_param.default + ): warnings.warn( f"Multiple different default values found for input '{input_param.name}': " f"{current_param.default} (from block '{value_sources[input_param.name]}') and " @@ -443,9 +496,10 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li else: combined_dict[input_param.name] = input_param value_sources[input_param.name] = block_name - + return list(combined_dict.values()) + def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: """ Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, @@ -453,17 +507,17 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> Args: named_output_lists: List of tuples containing (block_name, output_param_list) pairs - + Returns: List[OutputParam]: Combined list of unique OutputParam objects """ combined_dict = {} # name -> OutputParam - + for block_name, outputs in named_output_lists: for output_param in outputs: if output_param.name not in combined_dict: combined_dict[output_param.name] = output_param - + return list(combined_dict.values()) @@ -487,15 +541,15 @@ def __init__(self): blocks[block_name] = block_cls() self.blocks = blocks if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): - raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.") + raise ValueError( + f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same." + ) default_blocks = [t for t in self.block_trigger_inputs if t is None] - # can only have 1 or 0 default block, and has to put in the last + # can only have 1 or 0 default block, and has to put in the last # the order of blocksmatters here because the first block with matching trigger will be dispatched # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"] # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img - if len(default_blocks) > 1 or ( - len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None - ): + if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None): raise ValueError( f"In {self.__class__.__name__}, exactly one None must be specified as the last element " "in block_trigger_inputs." @@ -509,7 +563,7 @@ def __init__(self): @property def model_name(self): return next(iter(self.blocks.values())).model_name - + @property def description(self): return "" @@ -532,7 +586,6 @@ def expected_configs(self): expected_configs.append(config) return expected_configs - @property def required_inputs(self) -> List[str]: first_block = next(iter(self.blocks.values())) @@ -557,7 +610,6 @@ def required_intermediates_inputs(self) -> List[str]: return list(required_by_all) - # YiYi TODO: add test for this @property def inputs(self) -> List[Tuple[str, Any]]: @@ -571,7 +623,6 @@ def inputs(self) -> List[Tuple[str, Any]]: input_param.required = False return combined_inputs - @property def intermediates_inputs(self) -> List[str]: named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()] @@ -589,7 +640,7 @@ def intermediates_outputs(self) -> List[str]: named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] combined_outputs = combine_outputs(*named_outputs) return combined_outputs - + @property def outputs(self) -> List[str]: named_outputs = [(name, block.outputs) for name, block in self.blocks.items()] @@ -630,26 +681,27 @@ def _get_trigger_inputs(self): Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique block_trigger_inputs values """ + def fn_recursive_get_trigger(blocks): trigger_values = set() - + if blocks is not None: for name, block in blocks.items(): # Check if current block has trigger inputs(i.e. auto block) - if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: + if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None: # Add all non-None values from the trigger inputs list trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - + # If block has blocks, recursively check them - if hasattr(block, 'blocks'): + if hasattr(block, "blocks"): nested_triggers = fn_recursive_get_trigger(block.blocks) trigger_values.update(nested_triggers) - + return trigger_values - + trigger_inputs = set(self.block_trigger_inputs) trigger_inputs.update(fn_recursive_get_trigger(self.blocks)) - + return trigger_inputs @property @@ -660,12 +712,9 @@ def __repr__(self): class_name = self.__class__.__name__ base_class = self.__class__.__bases__[0].__name__ header = ( - f"{class_name}(\n Class: {base_class}\n" - if base_class and base_class != "object" - else f"{class_name}(\n" + f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n" ) - if self.trigger_inputs: header += "\n" header += " " + "=" * 100 + "\n" @@ -677,19 +726,19 @@ def __repr__(self): header += " " + "=" * 100 + "\n\n" # Format description with proper indentation - desc_lines = self.description.split('\n') + desc_lines = self.description.split("\n") desc = [] # First line with "Description:" label desc.append(f" Description: {desc_lines[0]}") # Subsequent lines with proper indentation if len(desc_lines) > 1: desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' + desc = "\n".join(desc) + "\n" # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) @@ -699,7 +748,7 @@ def __repr__(self): for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block trigger = None - if hasattr(self, 'block_to_trigger_map'): + if hasattr(self, "block_to_trigger_map"): trigger = self.block_to_trigger_map.get(name) # Format the trigger info if trigger is None: @@ -713,47 +762,41 @@ def __repr__(self): else: # For SequentialPipelineBlocks, show execution order blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - + # Add block description - desc_lines = block.description.split('\n') + desc_lines = block.description.split("\n") indented_desc = desc_lines[0] if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:]) blocks_str += f" Description: {indented_desc}\n\n" - return ( - f"{header}\n" - f"{desc}\n\n" - f"{components_str}\n\n" - f"{configs_str}\n\n" - f"{blocks_str}" - f")" - ) - + return f"{header}\n" f"{desc}\n\n" f"{components_str}\n\n" f"{configs_str}\n\n" f"{blocks_str}" f")" @property def doc(self): return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, + self.inputs, + self.intermediates_inputs, + self.outputs, self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, - expected_configs=self.expected_configs + expected_configs=self.expected_configs, ) + class SequentialPipelineBlocks(ModularPipelineMixin): """ A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. """ + block_classes = [] block_names = [] @property def model_name(self): return next(iter(self.blocks.values())).model_name - + @property def description(self): return "" @@ -779,10 +822,10 @@ def expected_configs(self): @classmethod def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks": """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. - + Args: blocks_dict: Dictionary mapping block names to block instances - + Returns: A new SequentialPipelineBlocks instance """ @@ -791,14 +834,13 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlo instance.block_names = list(blocks_dict.keys()) instance.blocks = blocks_dict return instance - + def __init__(self): blocks = OrderedDict() for block_name, block_cls in zip(self.block_names, self.block_classes): blocks[block_name] = block_cls() self.blocks = blocks - @property def required_inputs(self) -> List[str]: # Get the first block from the dictionary @@ -809,9 +851,9 @@ def required_inputs(self) -> List[str]: for block in list(self.blocks.values())[1:]: block_required = set(getattr(block, "required_inputs", set())) required_by_any.update(block_required) - + return list(required_by_any) - + @property def required_intermediates_inputs(self) -> List[str]: required_intermediates_inputs = [] @@ -847,7 +889,7 @@ def intermediates_inputs(self) -> List[str]: should_add_outputs = True if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: should_add_outputs = False - + if should_add_outputs: # Add this block's outputs block_intermediates_outputs = [out.name for out in block.intermediates_outputs] @@ -859,11 +901,11 @@ def intermediates_outputs(self) -> List[str]: named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] combined_outputs = combine_outputs(*named_outputs) return combined_outputs - + @property def outputs(self) -> List[str]: return next(reversed(self.blocks.values())).intermediates_outputs - + @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: for block_name, block in self.blocks.items(): @@ -878,29 +920,30 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: logger.error(error_msg) raise return pipeline, state - + def _get_trigger_inputs(self): """ Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique block_trigger_inputs values """ + def fn_recursive_get_trigger(blocks): trigger_values = set() - + if blocks is not None: for name, block in blocks.items(): # Check if current block has trigger inputs(i.e. auto block) - if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: + if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None: # Add all non-None values from the trigger inputs list trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - + # If block has blocks, recursively check them - if hasattr(block, 'blocks'): + if hasattr(block, "blocks"): nested_triggers = fn_recursive_get_trigger(block.blocks) trigger_values.update(nested_triggers) - + return trigger_values - + return fn_recursive_get_trigger(self.blocks) @property @@ -913,10 +956,10 @@ def _traverse_trigger_blocks(self, trigger_inputs): def fn_recursive_traverse(block, block_name, active_triggers): result_blocks = OrderedDict() - + # sequential or PipelineBlock - if not hasattr(block, 'block_trigger_inputs'): - if hasattr(block, 'blocks'): + if not hasattr(block, "block_trigger_inputs"): + if hasattr(block, "blocks"): # sequential for block_name, block in block.blocks.items(): blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) @@ -925,10 +968,10 @@ def fn_recursive_traverse(block, block_name, active_triggers): # PipelineBlock result_blocks[block_name] = block # Add this block's output names to active triggers if defined - if hasattr(block, 'outputs'): + if hasattr(block, "outputs"): active_triggers.update(out.name for out in block.outputs) return result_blocks - + # auto else: # Find first block_trigger_input that matches any value in our active_triggers @@ -939,36 +982,35 @@ def fn_recursive_traverse(block, block_name, active_triggers): this_block = block.trigger_to_block_map[trigger_input] matching_trigger = trigger_input break - + # If no matches found, try to get the default (None) block if this_block is None and None in block.block_trigger_inputs: this_block = block.trigger_to_block_map[None] matching_trigger = None - + if this_block is not None: # sequential/auto - if hasattr(this_block, 'blocks'): + if hasattr(this_block, "blocks"): result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) else: # PipelineBlock result_blocks[block_name] = this_block # Add this block's output names to active triggers if defined - if hasattr(this_block, 'outputs'): + if hasattr(this_block, "outputs"): active_triggers.update(out.name for out in this_block.outputs) return result_blocks - + all_blocks = OrderedDict() for block_name, block in self.blocks.items(): blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) all_blocks.update(blocks_to_update) return all_blocks - + def get_execution_blocks(self, *trigger_inputs): trigger_inputs_all = self.trigger_inputs if trigger_inputs is not None: - if not isinstance(trigger_inputs, (list, tuple, set)): trigger_inputs = [trigger_inputs] invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all] @@ -977,7 +1019,7 @@ def get_execution_blocks(self, *trigger_inputs): f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}" ) trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all] - + if trigger_inputs is None: if None in trigger_inputs_all: trigger_inputs = [None] @@ -985,17 +1027,14 @@ def get_execution_blocks(self, *trigger_inputs): trigger_inputs = [trigger_inputs_all[0]] blocks_triggered = self._traverse_trigger_blocks(trigger_inputs) return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered) - + def __repr__(self): class_name = self.__class__.__name__ base_class = self.__class__.__bases__[0].__name__ header = ( - f"{class_name}(\n Class: {base_class}\n" - if base_class and base_class != "object" - else f"{class_name}(\n" + f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n" ) - if self.trigger_inputs: header += "\n" header += " " + "=" * 100 + "\n" @@ -1007,19 +1046,19 @@ def __repr__(self): header += " " + "=" * 100 + "\n\n" # Format description with proper indentation - desc_lines = self.description.split('\n') + desc_lines = self.description.split("\n") desc = [] # First line with "Description:" label desc.append(f" Description: {desc_lines[0]}") # Subsequent lines with proper indentation if len(desc_lines) > 1: desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' + desc = "\n".join(desc) + "\n" # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) @@ -1029,7 +1068,7 @@ def __repr__(self): for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block trigger = None - if hasattr(self, 'block_to_trigger_map'): + if hasattr(self, "block_to_trigger_map"): trigger = self.block_to_trigger_map.get(name) # Format the trigger info if trigger is None: @@ -1043,39 +1082,30 @@ def __repr__(self): else: # For SequentialPipelineBlocks, show execution order blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - + # Add block description - desc_lines = block.description.split('\n') + desc_lines = block.description.split("\n") indented_desc = desc_lines[0] if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:]) blocks_str += f" Description: {indented_desc}\n\n" - return ( - f"{header}\n" - f"{desc}\n\n" - f"{components_str}\n\n" - f"{configs_str}\n\n" - f"{blocks_str}" - f")" - ) - + return f"{header}\n" f"{desc}\n\n" f"{components_str}\n\n" f"{configs_str}\n\n" f"{blocks_str}" f")" @property def doc(self): return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, + self.inputs, + self.intermediates_inputs, + self.outputs, self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, - expected_configs=self.expected_configs + expected_configs=self.expected_configs, ) - -# YiYi TODO: +# YiYi TODO: # 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) # 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader # 3. add validator for methods where we accpet kwargs to be passed to from_pretrained() @@ -1084,30 +1114,29 @@ class ModularLoader(ConfigMixin, PushToHubMixin): Base class for all Modular pipelines loaders. """ - config_name = "modular_model_index.json" + config_name = "modular_model_index.json" def register_components(self, **kwargs): """ - Register components with their corresponding specs. + Register components with their corresponding specs. This method is called when component changed or __init__ is called. Args: **kwargs: Keyword arguments where keys are component names and values are component objects. - + """ for name, module in kwargs.items(): - # current component spec component_spec = self._component_specs.get(name) if component_spec is None: logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") continue - + is_registered = hasattr(self, name) if module is not None and not hasattr(module, "_diffusers_load_id"): - raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.") # actual library and class name of the module @@ -1115,10 +1144,10 @@ def register_components(self, **kwargs): library, class_name = _fetch_class_library_tuple(module) new_component_spec = ComponentSpec.from_component(name, module) component_spec_dict = self._component_spec_to_dict(new_component_spec) - + else: library, class_name = None, None - # if module is None, we do not update the spec, + # if module is None, we do not update the spec, # but we still need to update the config to make sure it's synced with the component spec # (in the case of the first time registration, we initilize the object with component spec, and then we call register_components() to register it to config) new_component_spec = component_spec @@ -1139,16 +1168,24 @@ def register_components(self, **kwargs): if module is not None and self._component_manager is not None: self._component_manager.add(name, module, self._collection) continue - + current_module = getattr(self, name, None) # skip if the component is already registered with the same object if current_module is module: - logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") + logger.info( + f"ModularLoader.register_components: {name} is already registered with same object, skipping" + ) continue - + # it module is not an instance of the expected type, still register it but with a warning - if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint): - logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}") + if ( + module is not None + and component_spec.type_hint is not None + and not isinstance(module, component_spec.type_hint) + ): + logger.warning( + f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}" + ) # warn if unregister if current_module is not None and module is None: @@ -1157,10 +1194,12 @@ def register_components(self, **kwargs): f"(was {current_module.__class__.__name__})" ) # same type, new instance → debug - elif current_module is not None \ - and module is not None \ - and isinstance(module, current_module.__class__) \ - and current_module != module: + elif ( + current_module is not None + and module is not None + and isinstance(module, current_module.__class__) + and current_module != module + ): logger.debug( f"ModularLoader.register_components: replacing existing '{name}' " f"(same type {type(current_module).__name__}, new instance)" @@ -1175,46 +1214,51 @@ def register_components(self, **kwargs): if module is not None and self._component_manager is not None: self._component_manager.add(name, module, self._collection) - - # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name - def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): + def __init__( + self, + specs: List[Union[ComponentSpec, ConfigSpec]], + modular_repo: Optional[str] = None, + component_manager: Optional[ComponentsManager] = None, + collection: Optional[str] = None, + **kwargs, + ): """ Initialize the loader with a list of component specs and config specs. """ self._component_manager = component_manager self._collection = collection - self._component_specs = { - spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec) - } - self._config_specs = { - spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec) - } + self._component_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec)} + self._config_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec)} # update component_specs and config_specs from modular_repo if modular_repo is not None: config_dict = self.load_config(modular_repo, **kwargs) for name, value in config_dict.items(): - if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3: + if ( + name in self._component_specs + and self._component_specs[name].default_creation_method == "from_pretrained" + and isinstance(value, (tuple, list)) + and len(value) == 3 + ): library, class_name, component_spec_dict = value component_spec = self._dict_to_component_spec(name, component_spec_dict) self._component_specs[name] = component_spec elif name in self._config_specs: self._config_specs[name].default = value - + register_components_dict = {} for name, component_spec in self._component_specs.items(): register_components_dict[name] = None self.register_components(**register_components_dict) - + default_configs = {} for name, config_spec in self._config_specs.items(): default_configs[name] = config_spec.default self.register_to_config(**default_configs) - @property def device(self) -> torch.device: r""" @@ -1251,7 +1295,7 @@ def _execution_device(self): ): return torch.device(module._hf_hook.execution_device) return self.device - + @property def device(self) -> torch.device: r""" @@ -1280,23 +1324,18 @@ def dtype(self) -> torch.dtype: return torch.float32 - @property def components(self) -> Dict[str, Any]: # return only components we've actually set as attributes on self - return { - name: getattr(self, name) - for name in self._component_specs.keys() - if hasattr(self, name) - } + return {name: getattr(self, name) for name in self._component_specs.keys() if hasattr(self, name)} def update(self, **kwargs): """ Update components and configs after instance creation. - + Args: - """ + """ """ Update components and configuration values after the loader has been instantiated. @@ -1332,7 +1371,7 @@ def update(self, **kwargs): requires_safety_checker=False ) ``` - """ + """ # extract component_specs_updates & config_specs_updates from `specs` passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs} @@ -1340,29 +1379,25 @@ def update(self, **kwargs): for name, component in passed_components.items(): if not hasattr(component, "_diffusers_load_id"): - raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") - + raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.") + if len(kwargs) > 0: logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") - self.register_components(**passed_components) - config_to_register = {} for name, new_value in passed_config_values.items(): - # e.g. requires_aesthetics_score = False self._config_specs[name].default = new_value config_to_register[name] = new_value self.register_to_config(**config_to_register) - # YiYi TODO: support map for additional from_pretrained kwargs def load(self, component_names: Optional[List[str]] = None, **kwargs): """ Load selectedcomponents from specs. - + Args: component_names: List of component names to load **kwargs: additional kwargs to be passed to `from_pretrained()`.Can be: @@ -1379,7 +1414,7 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs): unknown_component_names = set([name for name in component_names if name not in self._component_specs]) if len(unknown_component_names) > 0: logger.warning(f"Unknown components will be ignored: {unknown_component_names}") - + components_to_register = {} for name in components_to_load: spec = self._component_specs[name] @@ -1399,7 +1434,7 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs): components_to_register[name] = spec.create(**component_load_kwargs) except Exception as e: logger.warning(f"Failed to create component '{name}': {e}") - + # Register all components at once self.register_components(**components_to_register) @@ -1407,11 +1442,12 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs): def to(self, *args, **kwargs): pass - # YiYi TODO: + # YiYi TODO: # 1. should support save some components too! currently only modular_model_index.json is saved # 2. maybe order the json file to make it more readable: configs first, then components - def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs): - + def save_pretrained( + self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs + ): component_names = list(self._component_specs.keys()) config_names = list(self._config_specs.keys()) self.register_to_config(_components_names=component_names, _configs_names=config_names) @@ -1421,11 +1457,11 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: config.pop("_configs_names", None) self._internal_dict = FrozenDict(config) - @classmethod @validate_hf_hub_args - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs): - + def from_pretrained( + cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs + ): config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) expected_component = set(config_dict.pop("_components_names")) expected_config = set(config_dict.pop("_configs_names")) @@ -1440,7 +1476,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P elif name in expected_config: config_specs.append(ConfigSpec(name=name, default=value)) - + for name in expected_component: for spec in component_specs: if spec.name == name: @@ -1450,7 +1486,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P component_specs.append(ComponentSpec(name=name, default_creation_method="from_config")) return cls(component_specs + config_specs) - @staticmethod def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: """ @@ -1533,4 +1568,4 @@ def _dict_to_component_spec( name=name, type_hint=type_hint, **spec_dict, - ) \ No newline at end of file + ) diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index 5d0752af8983..5d5eb23969ab 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -15,13 +15,16 @@ """Utilities to dynamically load objects from the Hub.""" import importlib +import signal import inspect import json import os import re import shutil import sys +import threading from pathlib import Path +from types import ModuleType from typing import Dict, Optional, Union from urllib import request @@ -37,6 +40,8 @@ # See https://huggingface.co/datasets/diffusers/community-pipelines-mirror COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror" +TIME_OUT_REMOTE_CODE = int(os.getenv("DIFFUSERS_TIMEOUT_REMOTE_CODE", 15)) +_HF_REMOTE_CODE_LOCK = threading.Lock() def get_diffusers_versions(): @@ -154,15 +159,87 @@ def check_imports(filename): return get_relative_imports(filename) -def get_class_in_module(class_name, module_path): +def _raise_timeout_error(signum, frame): + raise ValueError( + "Loading this model requires you to execute custom code contained in the model repository on your local " + "machine. Please set the option `trust_remote_code=True` to permit loading of this model." + ) + + +def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code): + if trust_remote_code is None: + if has_remote_code and TIME_OUT_REMOTE_CODE > 0: + prev_sig_handler = None + try: + prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error) + signal.alarm(TIME_OUT_REMOTE_CODE) + while trust_remote_code is None: + answer = input( + f"The repository for {model_name} contains custom code which must be executed to correctly " + f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n" + f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n" + f"Do you wish to run the custom code? [y/N] " + ) + if answer.lower() in ["yes", "y", "1"]: + trust_remote_code = True + elif answer.lower() in ["no", "n", "0", ""]: + trust_remote_code = False + signal.alarm(0) + except Exception: + # OS which does not support signal.SIGALRM + raise ValueError( + f"The repository for {model_name} contains custom code which must be executed to correctly " + f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n" + f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." + ) + finally: + if prev_sig_handler is not None: + signal.signal(signal.SIGALRM, prev_sig_handler) + signal.alarm(0) + elif has_remote_code: + # For the CI which puts the timeout at 0 + _raise_timeout_error(None, None) + + if has_remote_code and not trust_remote_code: + raise ValueError( + f"Loading {model_name} requires you to execute the configuration file in that" + " repo on your local machine. Make sure you have read the code there to avoid malicious use, then" + " set the option `trust_remote_code=True` to remove this error." + ) + + return trust_remote_code + + +def get_class_in_module(class_name, module_path, force_reload=False): """ Import a module on the cache directory for modules and extract a class from it. """ - module_path = module_path.replace(os.path.sep, ".") - module = importlib.import_module(module_path) + name = os.path.normpath(module_path) + if name.endswith(".py"): + name = name[:-3] + name = name.replace(os.path.sep, ".") + module_file: Path = Path(HF_MODULES_CACHE) / module_path + + with _HF_REMOTE_CODE_LOCK: + if force_reload: + sys.modules.pop(name, None) + importlib.invalidate_caches() + cached_module: Optional[ModuleType] = sys.modules.get(name) + module_spec = importlib.util.spec_from_file_location(name, location=module_file) + + module: ModuleType + if cached_module is None: + module = importlib.util.module_from_spec(module_spec) + # insert it into sys.modules before any loading begins + sys.modules[name] = module + else: + module = cached_module + + module_spec.loader.exec_module(module) if class_name is None: return find_pipeline_class(module) + return getattr(module, class_name) @@ -454,4 +531,4 @@ def get_class_from_dynamic_module( revision=revision, local_files_only=local_files_only, ) - return get_class_in_module(class_name, final_module.replace(".py", "")) + return get_class_in_module(class_name, final_module) From 4968edc5dc499d472e88c5637dc2afd968f5bcbe Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 20 May 2025 18:07:27 +0200 Subject: [PATCH 41/54] remove the duplicated components_manager file I forgot to deletee --- src/diffusers/pipelines/components_manager.py | 862 ------------------ 1 file changed, 862 deletions(-) delete mode 100644 src/diffusers/pipelines/components_manager.py diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py deleted file mode 100644 index bdff133e22d9..000000000000 --- a/src/diffusers/pipelines/components_manager.py +++ /dev/null @@ -1,862 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import OrderedDict -from itertools import combinations -from typing import List, Optional, Union, Dict, Any -import copy - -import torch -import time -from dataclasses import dataclass - -from ..utils import ( - is_accelerate_available, - logging, -) -from ..models.modeling_utils import ModelMixin -from .modular_pipeline_utils import ComponentSpec - - -if is_accelerate_available(): - from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module - from accelerate.state import PartialState - from accelerate.utils import send_to_device - from accelerate.utils.memory import clear_device_cache - from accelerate.utils.modeling import convert_file_size_to_int - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# YiYi Notes: copied from modeling_utils.py (decide later where to put this) -def get_memory_footprint(self, return_buffers=True): - r""" - Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. Useful to - benchmark the memory footprint of the current model and design some tests. Solution inspired from the PyTorch - discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 - - Arguments: - return_buffers (`bool`, *optional*, defaults to `True`): - Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers are - tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch norm - layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 - """ - mem = sum([param.nelement() * param.element_size() for param in self.parameters()]) - if return_buffers: - mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) - mem = mem + mem_bufs - return mem - - -class CustomOffloadHook(ModelHook): - """ - A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are - on the given device. Optionally offloads other models to the CPU before the forward pass is called. - - Args: - execution_device(`str`, `int` or `torch.device`, *optional*): - The device on which the model should be executed. Will default to the MPS device if it's available, then - GPU 0 if there is a GPU, and finally to the CPU. - """ - - def __init__( - self, - execution_device: Optional[Union[str, int, torch.device]] = None, - other_hooks: Optional[List["UserCustomOffloadHook"]] = None, - offload_strategy: Optional["AutoOffloadStrategy"] = None, - ): - self.execution_device = execution_device if execution_device is not None else PartialState().default_device - self.other_hooks = other_hooks - self.offload_strategy = offload_strategy - self.model_id = None - - def set_strategy(self, offload_strategy: "AutoOffloadStrategy"): - self.offload_strategy = offload_strategy - - def add_other_hook(self, hook: "UserCustomOffloadHook"): - """ - Add a hook to the list of hooks to consider for offloading. - """ - if self.other_hooks is None: - self.other_hooks = [] - self.other_hooks.append(hook) - - def init_hook(self, module): - return module.to("cpu") - - def pre_forward(self, module, *args, **kwargs): - if module.device != self.execution_device: - if self.other_hooks is not None: - hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device] - # offload all other hooks - start_time = time.perf_counter() - if self.offload_strategy is not None: - hooks_to_offload = self.offload_strategy( - hooks=hooks_to_offload, - model_id=self.model_id, - model=module, - execution_device=self.execution_device, - ) - end_time = time.perf_counter() - logger.info( - f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds" - ) - - for hook in hooks_to_offload: - logger.info( - f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu" - ) - hook.offload() - - if hooks_to_offload: - clear_device_cache() - module.to(self.execution_device) - return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device) - - -class UserCustomOffloadHook: - """ - A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs for to call the init method of - the hook or remove it entirely. - """ - - def __init__(self, model_id, model, hook): - self.model_id = model_id - self.model = model - self.hook = hook - - def offload(self): - self.hook.init_hook(self.model) - - def attach(self): - add_hook_to_module(self.model, self.hook) - self.hook.model_id = self.model_id - - def remove(self): - remove_hook_from_module(self.model) - self.hook.model_id = None - - def add_other_hook(self, hook: "UserCustomOffloadHook"): - self.hook.add_other_hook(hook) - - -def custom_offload_with_hook( - model_id: str, - model: torch.nn.Module, - execution_device: Union[str, int, torch.device] = None, - offload_strategy: Optional["AutoOffloadStrategy"] = None, -): - hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy) - user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook) - user_hook.attach() - return user_hook - - -class AutoOffloadStrategy: - """ - Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on - the available memory on the device. - """ - - def __init__(self, memory_reserve_margin="3GB"): - self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin) - - def __call__(self, hooks, model_id, model, execution_device): - if len(hooks) == 0: - return [] - - current_module_size = get_memory_footprint(model) - - mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0] - mem_on_device = mem_on_device - self.memory_reserve_margin - if current_module_size < mem_on_device: - return [] - - min_memory_offload = current_module_size - mem_on_device - logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory") - - # exlucde models that's not currently loaded on the device - module_sizes = dict( - sorted( - {hook.model_id: get_memory_footprint(hook.model) for hook in hooks}.items(), - key=lambda x: x[1], - reverse=True, - ) - ) - - def search_best_candidate(module_sizes, min_memory_offload): - """ - search the optimal combination of models to offload to cpu, given a dictionary of module sizes and a - minimum memory offload size. the combination of models should add up to the smallest modulesize that is - larger than `min_memory_offload` - """ - model_ids = list(module_sizes.keys()) - best_candidate = None - best_size = float("inf") - for r in range(1, len(model_ids) + 1): - for candidate_model_ids in combinations(model_ids, r): - candidate_size = sum( - module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids - ) - if candidate_size < min_memory_offload: - continue - else: - if best_candidate is None or candidate_size < best_size: - best_candidate = candidate_model_ids - best_size = candidate_size - - return best_candidate - - best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload) - - if best_offload_model_ids is None: - # if no combination is found, meaning that we cannot meet the memory requirement, offload all models - logger.warning("no combination of models to offload to cpu is found, offloading all models") - hooks_to_offload = hooks - else: - hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids] - - return hooks_to_offload - - - -from .modular_pipeline_utils import ComponentSpec -import uuid -class ComponentsManager: - def __init__(self): - self.components = OrderedDict() - self.added_time = OrderedDict() # Store when components were added - self.collections = OrderedDict() # collection_name -> set of component_names - self.model_hooks = None - self._auto_offload_enabled = False - - - def _get_by_collection(self, collection: str): - """ - Select components by collection name. - """ - selected_components = {} - if collection in self.collections: - component_ids = self.collections[collection] - for component_id in component_ids: - selected_components[component_id] = self.components[component_id] - return selected_components - - - def _get_by_load_id(self, load_id: str): - """ - Select components by its load_id. - """ - selected_components = {} - for name, component in self.components.items(): - if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: - selected_components[name] = component - return selected_components - - - def add(self, name, component, collection: Optional[str] = None): - - for comp_id, comp in self.components.items(): - if comp == component: - logger.warning(f"Component '{name}' already exists in ComponentsManager") - return comp_id - - component_id = f"{name}_{uuid.uuid4()}" - - if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": - components_with_same_load_id = self._get_by_load_id(component._diffusers_load_id) - if components_with_same_load_id: - existing = ", ".join(components_with_same_load_id.keys()) - logger.warning( - f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " - f"To remove a duplicate, call `components_manager.remove('')`." - ) - - - # add component to components manager - self.components[component_id] = component - self.added_time[component_id] = time.time() - if collection: - if collection not in self.collections: - self.collections[collection] = set() - self.collections[collection].add(component_id) - - if self._auto_offload_enabled: - self.enable_auto_cpu_offload(self._auto_offload_device) - - logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'") - return component_id - - - def remove(self, name: Union[str, List[str]]): - - if name not in self.components: - logger.warning(f"Component '{name}' not found in ComponentsManager") - return - - self.components.pop(name) - self.added_time.pop(name) - - for collection in self.collections: - if name in self.collections[collection]: - self.collections[collection].remove(name) - - if self._auto_offload_enabled: - self.enable_auto_cpu_offload(self._auto_offload_device) - - def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None, - as_name_component_tuples: bool = False): - """ - Select components by name with simple pattern matching. - - Args: - names: Component name(s) or pattern(s) - Patterns: - - "unet" : match any component with base name "unet" (e.g., unet_123abc) - - "!unet" : everything except components with base name "unet" - - "unet*" : anything with base name starting with "unet" - - "!unet*" : anything with base name NOT starting with "unet" - - "*unet*" : anything with base name containing "unet" - - "!*unet*" : anything with base name NOT containing "unet" - - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet" - - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet" - - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae" - collection: Optional collection to filter by - load_id: Optional load_id to filter by - as_name_component_tuples: If True, returns a list of (name, component) tuples using base names - instead of a dictionary with component IDs as keys - - Returns: - Dictionary mapping component IDs to components, - or list of (base_name, component) tuples if as_name_component_tuples=True - """ - - if collection: - if collection not in self.collections: - logger.warning(f"Collection '{collection}' not found in ComponentsManager") - return [] if as_name_component_tuples else {} - components = self._get_by_collection(collection) - else: - components = self.components - - if load_id: - components = self._get_by_load_id(load_id) - - # Helper to extract base name from component_id - def get_base_name(component_id): - parts = component_id.split('_') - # If the last part looks like a UUID, remove it - if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: - return '_'.join(parts[:-1]) - return component_id - - if names is None: - if as_name_component_tuples: - return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()] - else: - return components - - # Create mapping from component_id to base_name for all components - base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} - - def matches_pattern(component_id, pattern, exact_match=False): - """ - Helper function to check if a component matches a pattern based on its base name. - - Args: - component_id: The component ID to check - pattern: The pattern to match against - exact_match: If True, only exact matches to base_name are considered - """ - base_name = base_names[component_id] - - # Exact match with base name - if exact_match: - return pattern == base_name - - # Prefix match (ends with *) - elif pattern.endswith('*'): - prefix = pattern[:-1] - return base_name.startswith(prefix) - - # Contains match (starts with *) - elif pattern.startswith('*'): - search = pattern[1:-1] if pattern.endswith('*') else pattern[1:] - return search in base_name - - # Exact match (no wildcards) - else: - return pattern == base_name - - if isinstance(names, str): - # Check if this is a "not" pattern - is_not_pattern = names.startswith('!') - if is_not_pattern: - names = names[1:] # Remove the ! prefix - - # Handle OR patterns (containing |) - if '|' in names: - terms = names.split('|') - matches = {} - - for comp_id, comp in components.items(): - # For OR patterns with exact names (no wildcards), we do exact matching on base names - exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms) - - # Check if any of the terms match this component - should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) - - # Flip the decision if this is a NOT pattern - if is_not_pattern: - should_include = not should_include - - if should_include: - matches[comp_id] = comp - - log_msg = "NOT " if is_not_pattern else "" - match_type = "exactly matching" if exact_match else "matching any of patterns" - logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") - - # Try exact match with a base name - elif any(names == base_name for base_name in base_names.values()): - # Find all components with this base name - matches = { - comp_id: comp for comp_id, comp in components.items() - if (base_names[comp_id] == names) != is_not_pattern - } - - if is_not_pattern: - logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") - else: - logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") - - # Prefix match (ends with *) - elif names.endswith('*'): - prefix = names[:-1] - matches = { - comp_id: comp for comp_id, comp in components.items() - if base_names[comp_id].startswith(prefix) != is_not_pattern - } - if is_not_pattern: - logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") - else: - logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}") - - # Contains match (starts with *) - elif names.startswith('*'): - search = names[1:-1] if names.endswith('*') else names[1:] - matches = { - comp_id: comp for comp_id, comp in components.items() - if (search in base_names[comp_id]) != is_not_pattern - } - if is_not_pattern: - logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") - else: - logger.info(f"Getting components containing '{search}': {list(matches.keys())}") - - # Substring match (no wildcards, but not an exact component name) - elif any(names in base_name for base_name in base_names.values()): - matches = { - comp_id: comp for comp_id, comp in components.items() - if (names in base_names[comp_id]) != is_not_pattern - } - if is_not_pattern: - logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") - else: - logger.info(f"Getting components containing '{names}': {list(matches.keys())}") - - else: - raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") - - if not matches: - raise ValueError(f"No components found matching pattern '{names}'") - - if as_name_component_tuples: - return [(base_names[comp_id], comp) for comp_id, comp in matches.items()] - else: - return matches - - elif isinstance(names, list): - results = {} - for name in names: - result = self.get(name, collection, load_id, as_name_component_tuples=False) - results.update(result) - - if as_name_component_tuples: - return [(base_names[comp_id], comp) for comp_id, comp in results.items()] - else: - return results - - else: - raise ValueError(f"Invalid type for names: {type(names)}") - - def enable_auto_cpu_offload(self, device: Union[str, int, torch.device]="cuda", memory_reserve_margin="3GB"): - for name, component in self.components.items(): - if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"): - remove_hook_from_module(component, recurse=True) - - self.disable_auto_cpu_offload() - offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin) - device = torch.device(device) - if device.index is None: - device = torch.device(f"{device.type}:{0}") - all_hooks = [] - for name, component in self.components.items(): - if isinstance(component, torch.nn.Module): - hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy) - all_hooks.append(hook) - - for hook in all_hooks: - other_hooks = [h for h in all_hooks if h is not hook] - for other_hook in other_hooks: - if other_hook.hook.execution_device == hook.hook.execution_device: - hook.add_other_hook(other_hook) - - self.model_hooks = all_hooks - self._auto_offload_enabled = True - self._auto_offload_device = device - - def disable_auto_cpu_offload(self): - if self.model_hooks is None: - self._auto_offload_enabled = False - return - - for hook in self.model_hooks: - hook.offload() - hook.remove() - if self.model_hooks: - clear_device_cache() - self.model_hooks = None - self._auto_offload_enabled = False - - # YiYi TODO: add quantization info - def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: - """Get comprehensive information about a component. - - Args: - name: Name of the component to get info for - fields: Optional field(s) to return. Can be a string for single field or list of fields. - If None, returns all fields. - - Returns: - Dictionary containing requested component metadata. - If fields is specified, returns only those fields. - If a single field is requested as string, returns just that field's value. - """ - if name not in self.components: - raise ValueError(f"Component '{name}' not found in ComponentsManager") - - component = self.components[name] - - # Build complete info dict first - info = { - "model_id": name, - "added_time": self.added_time[name], - "collection": next((coll for coll, comps in self.collections.items() if name in comps), None), - } - - # Additional info for torch.nn.Module components - if isinstance(component, torch.nn.Module): - # Check for hook information - has_hook = hasattr(component, "_hf_hook") - execution_device = None - if has_hook and hasattr(component._hf_hook, "execution_device"): - execution_device = component._hf_hook.execution_device - - info.update({ - "class_name": component.__class__.__name__, - "size_gb": get_memory_footprint(component) / (1024**3), - "adapters": None, # Default to None - "has_hook": has_hook, - "execution_device": execution_device, - }) - - # Get adapters if applicable - if hasattr(component, "peft_config"): - info["adapters"] = list(component.peft_config.keys()) - - # Check for IP-Adapter scales - if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"): - processors = copy.deepcopy(component.attn_processors) - # First check if any processor is an IP-Adapter - processor_types = [v.__class__.__name__ for v in processors.values()] - if any("IPAdapter" in ptype for ptype in processor_types): - # Then get scales only from IP-Adapter processors - scales = { - k: v.scale - for k, v in processors.items() - if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__ - } - if scales: - info["ip_adapter"] = summarize_dict_by_value_and_parts(scales) - - # If fields specified, filter info - if fields is not None: - if isinstance(fields, str): - # Single field requested, return just that value - return {fields: info.get(fields)} - else: - # List of fields requested, return dict with just those fields - return {k: v for k, v in info.items() if k in fields} - - return info - - def __repr__(self): - # Helper to get simple name without UUID - def get_simple_name(name): - # Extract the base name by splitting on underscore and taking first part - # This assumes names are in format "name_uuid" - parts = name.split('_') - # If we have at least 2 parts and the last part looks like a UUID, remove it - if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: - return '_'.join(parts[:-1]) - return name - - # Extract load_id if available - def get_load_id(component): - if hasattr(component, "_diffusers_load_id"): - return component._diffusers_load_id - return "N/A" - - # Format device info compactly - def format_device(component, info): - if not info["has_hook"]: - return str(getattr(component, 'device', 'N/A')) - else: - device = str(getattr(component, 'device', 'N/A')) - exec_device = str(info['execution_device'] or 'N/A') - return f"{device}({exec_device})" - - # Get all simple names to calculate width - simple_names = [get_simple_name(id) for id in self.components.keys()] - - # Get max length of load_ids for models - load_ids = [ - get_load_id(component) - for component in self.components.values() - if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id") - ] - max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 - - # Collection names - collection_names = [ - next((coll for coll, comps in self.collections.items() if name in comps), "N/A") - for name in self.components.keys() - ] - - col_widths = { - "name": max(15, max(len(name) for name in simple_names)), - "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), - "device": 15, # Reduced since using more compact format - "dtype": 15, - "size": 10, - "load_id": max_load_id_len, - "collection": max(10, max(len(str(c)) for c in collection_names)) - } - - # Create the header lines - sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n" - dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n" - - output = "Components:\n" + sep_line - - # Separate components into models and others - models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)} - others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)} - - # Models section - if models: - output += "Models:\n" + dash_line - # Column headers - output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | " - output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | " - output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n" - output += dash_line - - # Model entries - for name, component in models.items(): - info = self.get_model_info(name) - simple_name = get_simple_name(name) - device_str = format_device(component, info) - dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" - load_id = get_load_id(component) - collection = info["collection"] or "N/A" - - output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " - output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " - output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n" - output += dash_line - - # Other components section - if others: - if models: # Add extra newline if we had models section - output += "\n" - output += "Other Components:\n" + dash_line - # Column headers for other components - output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | Collection\n" - output += dash_line - - # Other component entries - for name, component in others.items(): - info = self.get_model_info(name) - simple_name = get_simple_name(name) - collection = info["collection"] or "N/A" - - output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n" - output += dash_line - - # Add additional component info - output += "\nAdditional Component Info:\n" + "=" * 50 + "\n" - for name in self.components: - info = self.get_model_info(name) - if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")): - simple_name = get_simple_name(name) - output += f"\n{simple_name}:\n" - if info.get("adapters") is not None: - output += f" Adapters: {info['adapters']}\n" - if info.get("ip_adapter"): - output += f" IP-Adapter: Enabled\n" - output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n" - - return output - - def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): - """ - Load components from a pretrained model and add them to the manager. - - Args: - pretrained_model_name_or_path (str): The path or identifier of the pretrained model - prefix (str, optional): Prefix to add to all component names loaded from this model. - If provided, components will be named as "{prefix}_{component_name}" - **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained() - """ - subfolder = kwargs.pop("subfolder", None) - # YiYi TODO: extend AutoModel to support non-diffusers models - if subfolder: - from ..models import AutoModel - component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs) - component_name = f"{prefix}_{subfolder}" if prefix else subfolder - if component_name not in self.components: - self.add(component_name, component) - else: - logger.warning( - f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" - f"1. remove the existing component with remove('{component_name}')\n" - f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" - ) - else: - from ..pipelines.pipeline_utils import DiffusionPipeline - pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) - for name, component in pipe.components.items(): - - if component is None: - continue - - # Add prefix if specified - component_name = f"{prefix}_{name}" if prefix else name - - if component_name not in self.components: - self.add(component_name, component) - else: - logger.warning( - f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" - f"1. remove the existing component with remove('{component_name}')\n" - f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" - ) - - def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any: - """ - Get a single component by name. Raises an error if multiple components match or none are found. - - Args: - name: Component name or pattern - collection: Optional collection to filter by - load_id: Optional load_id to filter by - - Returns: - A single component - - Raises: - ValueError: If no components match or multiple components match - """ - results = self.get(name, collection, load_id) - - if not results: - raise ValueError(f"No components found matching '{name}'") - - if len(results) > 1: - raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") - - return next(iter(results.values())) - -def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: - """Summarizes a dictionary by finding common prefixes that share the same value. - - For a dictionary with dot-separated keys like: - { - 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6], - 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6], - 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3], - } - - Returns a dictionary where keys are the shortest common prefixes and values are their shared values: - { - 'down_blocks': [0.6], - 'up_blocks': [0.3] - } - """ - # First group by values - convert lists to tuples to make them hashable - value_to_keys = {} - for key, value in d.items(): - value_tuple = tuple(value) if isinstance(value, list) else value - if value_tuple not in value_to_keys: - value_to_keys[value_tuple] = [] - value_to_keys[value_tuple].append(key) - - def find_common_prefix(keys: List[str]) -> str: - """Find the shortest common prefix among a list of dot-separated keys.""" - if not keys: - return "" - if len(keys) == 1: - return keys[0] - - # Split all keys into parts - key_parts = [k.split('.') for k in keys] - - # Find how many initial parts are common - common_length = 0 - for parts in zip(*key_parts): - if len(set(parts)) == 1: # All parts at this position are the same - common_length += 1 - else: - break - - if common_length == 0: - return "" - - # Return the common prefix - return '.'.join(key_parts[0][:common_length]) - - # Create summary by finding common prefixes for each value group - summary = {} - for value_tuple, keys in value_to_keys.items(): - prefix = find_common_prefix(keys) - if prefix: # Only add if we found a common prefix - # Convert tuple back to list if it was originally a list - value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple - summary[prefix] = value - else: - summary[""] = value # Use empty string if no common prefix - - return summary From de6ab6b49d17b9e735638f9f21df9b173fd2d5b0 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 20 May 2025 18:07:58 +0200 Subject: [PATCH 42/54] fix import in block mapping --- .../stable_diffusion_xl/modular_pipeline_block_mappings.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py index c739a24e9759..6d909ab5a4a0 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py @@ -41,11 +41,11 @@ StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLIPAdapterStep ) -from .after_denoise import ( +from .decoders import ( StableDiffusionXLDecodeStep, - StableDiffusionXLInpaintDecodeStep + StableDiffusionXLInpaintDecodeStep, + StableDiffusionXLAutoDecodeStep ) -from .after_denoise import StableDiffusionXLAutoDecodeStep # YiYi notes: comment out for now, work on this later From eb9415031a54b6aba5a44a52ead90197502f806f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 20 May 2025 18:08:28 +0200 Subject: [PATCH 43/54] add a to-do for modular loader --- src/diffusers/modular_pipelines/modular_pipeline.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 36273da11f5a..ef725c32f4f9 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1638,9 +1638,10 @@ def __repr__(self): # YiYi TODO: -# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) -# 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader -# 3. add validator for methods where we accpet kwargs to be passed to from_pretrained() +# 1. move the modular_repo arg and the logic to fetch info from repo out of __init__ so that __init__ alwasy create an default modular_model_index config +# 2. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) +# 3. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader +# 4. add validator for methods where we accpet kwargs to be passed to from_pretrained() class ModularLoader(ConfigMixin, PushToHubMixin): """ Base class for all Modular pipelines loaders. From 1b89ac144c6eba7d66ca34924c83a6323944ccac Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 20 May 2025 18:10:06 +0200 Subject: [PATCH 44/54] prepare_latents_img2img pipeline method -> function, maybe do the same for others? --- .../stable_diffusion_xl/before_denoise.py | 168 +++++++++--------- 1 file changed, 83 insertions(+), 85 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index 8f083f1870e7..07f096249c0d 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -127,6 +127,86 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") +def prepare_latents_img2img(vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True): + + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + latents_mean = latents_std = None + if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None: + latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None: + latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1) + # make sure the VAE is in float32 mode, as it overflows in float16 + if vae.config.force_upcast: + image = image.float() + vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(vae.encode(image), generator=generator) + + if vae.config.force_upcast: + vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * vae.config.scaling_factor / latents_std + else: + init_latents = vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + class StableDiffusionXLInputStep(PipelineBlock): model_name = "stable-diffusion-xl" @@ -751,89 +831,6 @@ def intermediates_inputs(self) -> List[InputParam]: def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self -> components - # YiYi TODO: refactor using _encode_vae_image - @staticmethod - def prepare_latents_img2img( - components, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - image = image.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - init_latents = image - - else: - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - # make sure the VAE is in float32 mode, as it overflows in float16 - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - elif isinstance(generator, list): - if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: - image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) - elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " - ) - - init_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - init_latents = init_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=device, dtype=dtype) - latents_std = latents_std.to(device=device, dtype=dtype) - init_latents = (init_latents - latents_mean) * components.vae.config.scaling_factor / latents_std - else: - init_latents = components.vae.config.scaling_factor * init_latents - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." - ) - else: - init_latents = torch.cat([init_latents], dim=0) - - if add_noise: - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # get latents - init_latents = components.scheduler.add_noise(init_latents, noise, timestep) - - latents = init_latents - - return latents - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -842,8 +839,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.device = components._execution_device block_state.add_noise = True if block_state.denoising_start is None else False if block_state.latents is None: - block_state.latents = self.prepare_latents_img2img( - components, + block_state.latents = prepare_latents_img2img( + components.vae, + components.scheduler, block_state.image_latents, block_state.latent_timestep, block_state.batch_size, From d136ae36c87b66cb6e53c30098d09cd307641588 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 20 May 2025 18:11:05 +0200 Subject: [PATCH 45/54] update input for loop blocks, do not need to include intermediate --- .../stable_diffusion_xl/denoise.py | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index b29920764acb..bc567a6b034f 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -68,18 +68,11 @@ def intermediates_inputs(self) -> List[str]: ), ] - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")] - - - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - return components, block_state # loop step (1): prepare latent input for denoiser (with inpainting) @@ -120,9 +113,6 @@ def intermediates_inputs(self) -> List[str]: ), ] - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")] @staticmethod def check_inputs(components, block_state): @@ -187,12 +177,6 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[str]: return [ - InputParam( - "scaled_latents", - required=True, - type_hint=torch.Tensor, - description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." - ), InputParam( "num_inference_steps", required=True, @@ -319,12 +303,6 @@ def intermediates_inputs(self) -> List[str]: type_hint=List[float], description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." ), - InputParam( - "scaled_latents", - required=True, - type_hint=torch.Tensor, - description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." - ), InputParam( "timestep_cond", type_hint=Optional[torch.Tensor], @@ -492,12 +470,6 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[str]: return [ InputParam("generator"), - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), ] @property From 72e1b74638ecc6d0658673aa5cb94018997708b9 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 20 May 2025 20:26:51 +0200 Subject: [PATCH 46/54] solve merge conflict: manually add back the remote code change to modular_pipeline --- .../modular_pipelines/modular_pipeline.py | 77 ++++++++++++++++--- 1 file changed, 67 insertions(+), 10 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index ef725c32f4f9..02ceb49e7b34 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -11,12 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect + import traceback import warnings from collections import OrderedDict from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple, Union, Optional, Type +from typing import Any, Dict, List, Tuple, Union, Optional from copy import deepcopy @@ -31,11 +33,10 @@ from ..configuration_utils import ConfigMixin, FrozenDict from ..utils import ( is_accelerate_available, - is_accelerate_version, logging, PushToHubMixin, ) -from ..pipelines.pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj, _fetch_class_library_tuple +from ..pipelines.pipeline_loading_utils import simple_get_class_obj, _fetch_class_library_tuple from .modular_pipeline_utils import ( ComponentSpec, ConfigSpec, @@ -43,14 +44,12 @@ OutputParam, format_components, format_configs, - format_input_params, format_inputs_short, format_intermediates_short, - format_output_params, - format_params, make_doc_string, ) from .components_manager import ComponentsManager +from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code from copy import deepcopy if is_accelerate_available(): @@ -245,19 +244,76 @@ def format_value(v): -class ModularPipelineMixin: +class ModularPipelineMixin(ConfigMixin): """ Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks """ + + config_name = "config.json" + + @classmethod + def _get_signature_keys(cls, obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - {"self"} + + return expected_modules, optional_parameters + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + trust_remote_code: Optional[bool] = None, + **kwargs, + ): + hub_kwargs_names = [ + "cache_dir", + "force_download", + "local_files_only", + "proxies", + "resume_download", + "revision", + "subfolder", + "token", + ] + hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} + + config = cls.load_config(pretrained_model_name_or_path) + has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"] + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_remote_code + ) + if not (has_remote_code and trust_remote_code): + raise ValueError("TODO") + + class_ref = config["auto_map"][cls.__name__] + module_file, class_name = class_ref.split(".") + module_file = module_file + ".py" + block_cls = get_class_from_dynamic_module( + pretrained_model_name_or_path, + module_file=module_file, + class_name=class_name, + is_modular=True, + **hub_kwargs, + **kwargs, + ) + expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls) + block_kwargs = { + name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs + } + print(f"block_kwargs: {block_kwargs}") + + return block_cls(**block_kwargs) + def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): """ - create a mouldar loader, optionally accept modular_repo to load from hub. + create a ModularLoader, optionally accept modular_repo to load from hub. """ # Import components loader (it is model-specific class) - loader_class_name = MODULAR_LOADER_MAPPING[self.model_name] + loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__) diffusers_module = importlib.import_module("diffusers") loader_class = getattr(diffusers_module, loader_class_name) @@ -365,7 +421,8 @@ class PipelineBlock(ModularPipelineMixin): @property def description(self) -> str: """Description of the block. Must be implemented by subclasses.""" - raise NotImplementedError("description method must be implemented in subclasses") + # raise NotImplementedError("description method must be implemented in subclasses") + return "TODO: add a description" @property def expected_components(self) -> List[ComponentSpec]: From 29de29f02c5c4cb3bf8fd18eaaa31b19a306e5e2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 21 May 2025 22:31:10 +0200 Subject: [PATCH 47/54] add node_utils --- .../modular_pipelines/modular_pipeline.py | 5 +- src/diffusers/modular_pipelines/node_utils.py | 351 ++++++++++++++++++ 2 files changed, 355 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/modular_pipelines/node_utils.py diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 02ceb49e7b34..3136c3bb11f1 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -303,7 +303,6 @@ def from_pretrained( block_kwargs = { name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs } - print(f"block_kwargs: {block_kwargs}") return block_cls(**block_kwargs) @@ -749,6 +748,8 @@ def expected_configs(self): @property def required_inputs(self) -> List[str]: + if None not in self.block_trigger_inputs: + return [] first_block = next(iter(self.blocks.values())) required_by_all = set(getattr(first_block, "required_inputs", set())) @@ -763,6 +764,8 @@ def required_inputs(self) -> List[str]: # intermediate_inputs is by default required, unless you manually handle it inside the block @property def required_intermediates_inputs(self) -> List[str]: + if None not in self.block_trigger_inputs: + return [] first_block = next(iter(self.blocks.values())) required_by_all = set(getattr(first_block, "required_intermediates_inputs", set())) diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py new file mode 100644 index 000000000000..2dfb85a5f903 --- /dev/null +++ b/src/diffusers/modular_pipelines/node_utils.py @@ -0,0 +1,351 @@ +from ..configuration_utils import ConfigMixin +from .modular_pipeline import SequentialPipelineBlocks +from .modular_pipeline_utils import InputParam, OutputParam +from ..image_processor import PipelineImageInput + +from typing import Union, List, Optional, Tuple +import torch +import PIL +import numpy as np +import logging +logger = logging.getLogger(__name__) + +# YiYi Notes: this is actually for SDXL, put it here for now +SDXL_INPUTS_SCHEMA = { + "prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"), + "prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"), + "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), + "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), + "cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), + "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), + "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), + "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), + "generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"), + "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), + "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), + "num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"), + "timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"), + "sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"), + "denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"), + # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 + "strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"), + "denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"), + "latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"), + "padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"), + "original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"), + "target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"), + "negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"), + "negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"), + "crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"), + "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), + "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), + "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), + "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), + "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), + "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), + "control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"), + "control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"), + "control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"), + "controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"), + "guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"), + "control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet") +} + +SDXL_INTERMEDIATE_INPUTS_SCHEMA = { + "prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"), + "negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), + "pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"), + "negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), + "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"), + "latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"), + "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"), + "latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"), + "image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"), + "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), + "masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), + "add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"), + "negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), + "negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), + "images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images") +} + +SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = { + "prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"), + "negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), + "pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"), + "negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"), + "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"), + "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"), + "masked_image_latents": OutputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), + "crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"), + "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"), + "latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"), + "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"), + "negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"), + "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), + "negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), + "images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images") +} + +DEFAULT_PARAM_MAPS = { + "prompt": { + "label": "Prompt", + "type": "string", + "default": "a bear sitting in a chair drinking a milkshake", + "display": "textarea", + }, + "negative_prompt": { + "label": "Negative Prompt", + "type": "string", + "default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality", + "display": "textarea", + }, + + "num_inference_steps": { + "label": "Steps", + "type": "int", + "default": 25, + "min": 1, + "max": 1000, + }, + "seed": { + "label": "Seed", + "type": "int", + "default": 0, + "min": 0, + "display": "random", + }, + "width": { + "label": "Width", + "type": "int", + "display": "text", + "default": 1024, + "min": 8, + "max": 8192, + "step": 8, + "group": "dimensions", + }, + "height": { + "label": "Height", + "type": "int", + "display": "text", + "default": 1024, + "min": 8, + "max": 8192, + "step": 8, + "group": "dimensions", + }, + "images": { + "label": "Images", + "type": "image", + "display": "output", + }, + "image": { + "label": "Image", + "type": "image", + "display": "input", + }, +} + +DEFAULT_TYPE_MAPS ={ + "int": { + "type": "int", + "default": 0, + "min": 0, + }, + "float": { + "type": "float", + "default": 0.0, + "min": 0.0, + }, + "str": { + "type": "string", + "default": "", + }, + "bool": { + "type": "boolean", + "default": False, + }, + "image": { + "type": "image", + }, +} + +DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"] +DEFAULT_CATEGORY = "Modular Diffusers" +DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"] +DEFAULT_PARAMS_GROUPS_KEYS = { + "text_encoders": ["text_encoder", "tokenizer"], + "ip_adapter_embeds": ["ip_adapter_embeds"], + "text_embeds": ["prompt_embeds"], +} + + +def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS): + """ + Get the group name for a given parameter name, if not part of a group, return None + e.g. "prompt_embeds" -> "text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None + """ + for group_name, group_keys in group_params_keys.items(): + for group_key in group_keys: + if group_key in name: + return group_name + return None + +class MellonNode(ConfigMixin): + + block_class = None + config_name = "node_config.json" + + + def __init__(self, category=DEFAULT_CATEGORY, label=None, input_params=None, intermediate_params=None, component_params=None, output_params=None): + self.blocks = self.block_class() + + if label is None: + label = self.blocks.__class__.__name__ + + expected_inputs = [inp.name for inp in self.blocks.inputs] + expected_intermediates = [inp.name for inp in self.blocks.intermediates_inputs] + expected_components = [comp.name for comp in self.blocks.expected_components] + expected_outputs = [out.name for out in self.blocks.intermediates_outputs] + + if input_params is None: + input_params ={} + for inp in self.blocks.inputs: + # create a param dict for each input e.g. for prompt, param = {"prompt": {"label": "Prompt", "type": "string", "default": "a bear sitting in a chair drinking a milkshake", "display": "textarea"} } + param = {} + if inp.name in DEFAULT_PARAM_MAPS: + # first check if it's in the default param map, if so, directly use that + param[inp.name] = DEFAULT_PARAM_MAPS[inp.name] + elif inp.required: + group_name = get_group_name(inp.name) + if group_name: + param = group_name + else: + # if not, check if it's in the SDXL input schema, if so, + # 1. use the type hint to determine the type + # 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}} + inp_spec = SDXL_INPUTS_SCHEMA.get(inp.name, None) + if inp_spec: + type_str = str(inp_spec.type_hint).lower() + for type_key, type_param in DEFAULT_TYPE_MAPS.items(): + if type_key in type_str: + param[inp.name] = type_param + param[inp.name]["display"] = "input" + break + else: + param = inp.name + # add the param dict to the inp_params dict + if param: + input_params[inp.name] = param + + if intermediate_params is None: + intermediate_params = {} + for inp in self.blocks.intermediates_inputs: + param = {} + if inp.name in DEFAULT_PARAM_MAPS: + param[inp.name] = DEFAULT_PARAM_MAPS[inp.name] + elif inp.required: + group_name = get_group_name(inp.name) + if group_name: + param = group_name + else: + inp_spec = SDXL_INTERMEDIATE_INPUTS_SCHEMA.get(inp.name, None) + if inp_spec: + type_str = str(inp_spec.type_hint).lower() + for type_key, type_param in DEFAULT_TYPE_MAPS.items(): + if type_key in type_str: + param[inp.name] = type_param + param[inp.name]["display"] = "input" + break + else: + param = inp.name + # add the param dict to the intermediate_params dict + if param: + intermediate_params[inp.name] = param + + if component_params is None: + component_params = {} + for comp in self.blocks.expected_components: + to_exclude = False + for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS: + if exclude_key in comp.name: + to_exclude = True + break + if to_exclude: + continue + + param = {} + group_name = get_group_name(comp.name) + if group_name: + param = group_name + elif comp.name in DEFAULT_MODEL_KEYS: + param[comp.name] = { + "label": comp.name, + "type": "diffusers_auto_model", + "display": "input", + } + else: + param = comp.name + # add the param dict to the model_params dict + if param: + component_params[comp.name] = param + + if output_params is None: + output_params = {} + if isinstance(self.blocks, SequentialPipelineBlocks): + last_block_name = list(self.blocks.blocks.keys())[-1] + outputs = self.blocks.blocks[last_block_name].intermediates_outputs + else: + outputs = self.blocks.intermediates_outputs + + for out in outputs: + param = {} + if out.name in DEFAULT_PARAM_MAPS: + param[out.name] = DEFAULT_PARAM_MAPS[out.name] + param[out.name]["display"] = "output" + else: + group_name = get_group_name(out.name) + if group_name: + param = group_name + else: + param = out.name + # add the param dict to the outputs dict + if param: + output_params[out.name] = param + + register_dict = { + "category": category, + "label": label, + "input_params": input_params, + "intermediate_params": intermediate_params, + "component_params": component_params, + "output_params": output_params, + } + self.register_to_config(**register_dict) + + + + + + + + + + + + From 87f63d424a6efb0d309ced1e67b827bd10881b7c Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 22 May 2025 11:50:36 +0200 Subject: [PATCH 48/54] modular node! --- .../modular_pipeline_utils.py | 4 +- src/diffusers/modular_pipelines/node_utils.py | 440 ++++++++++++------ 2 files changed, 306 insertions(+), 138 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index 0c6d1b585589..6d6704f4eb38 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -246,7 +246,7 @@ class InputParam: default: Any = None required: bool = False description: str = "" - kwargs_type: str = None # YiYi Notes: experimenting with this, not sure if we should keep it + kwargs_type: str = None # YiYi Notes: remove this feature (maybe) def __repr__(self): return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" @@ -258,7 +258,7 @@ class OutputParam: name: str type_hint: Any = None description: str = "" - kwargs_type: str = None + kwargs_type: str = None # YiYi notes: remove this feature (maybe) def __repr__(self): return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py index 2dfb85a5f903..9ee9c069277d 100644 --- a/src/diffusers/modular_pipelines/node_utils.py +++ b/src/diffusers/modular_pipelines/node_utils.py @@ -1,7 +1,10 @@ from ..configuration_utils import ConfigMixin -from .modular_pipeline import SequentialPipelineBlocks +from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineMixin from .modular_pipeline_utils import InputParam, OutputParam from ..image_processor import PipelineImageInput +from pathlib import Path +import json +import os from typing import Union, List, Optional, Tuple import torch @@ -77,29 +80,8 @@ "images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images") } -SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = { - "prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"), - "negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), - "pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"), - "negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), - "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"), - "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - "image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"), - "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"), - "masked_image_latents": OutputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), - "crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), - "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"), - "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"), - "latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"), - "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"), - "negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), - "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), - "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"), - "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), - "ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), - "negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), - "images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images") -} +SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA} + DEFAULT_PARAM_MAPS = { "prompt": { @@ -191,7 +173,7 @@ DEFAULT_PARAMS_GROUPS_KEYS = { "text_encoders": ["text_encoder", "tokenizer"], "ip_adapter_embeds": ["ip_adapter_embeds"], - "text_embeds": ["prompt_embeds"], + "prompt_embeddings": ["prompt_embeds"], } @@ -200,144 +182,330 @@ def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS): Get the group name for a given parameter name, if not part of a group, return None e.g. "prompt_embeds" -> "text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None """ + if name is None: + return None for group_name, group_keys in group_params_keys.items(): for group_key in group_keys: if group_key in name: return group_name return None + + +class ModularNode(ConfigMixin): -class MellonNode(ConfigMixin): - - block_class = None config_name = "node_config.json" + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + trust_remote_code: Optional[bool] = None, + **kwargs, + ): + blocks = ModularPipelineMixin.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs) + return cls(blocks, **kwargs) - def __init__(self, category=DEFAULT_CATEGORY, label=None, input_params=None, intermediate_params=None, component_params=None, output_params=None): - self.blocks = self.block_class() + def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): + self.blocks = blocks if label is None: label = self.blocks.__class__.__name__ - - expected_inputs = [inp.name for inp in self.blocks.inputs] - expected_intermediates = [inp.name for inp in self.blocks.intermediates_inputs] - expected_components = [comp.name for comp in self.blocks.expected_components] - expected_outputs = [out.name for out in self.blocks.intermediates_outputs] - - if input_params is None: - input_params ={} - for inp in self.blocks.inputs: - # create a param dict for each input e.g. for prompt, param = {"prompt": {"label": "Prompt", "type": "string", "default": "a bear sitting in a chair drinking a milkshake", "display": "textarea"} } - param = {} - if inp.name in DEFAULT_PARAM_MAPS: - # first check if it's in the default param map, if so, directly use that - param[inp.name] = DEFAULT_PARAM_MAPS[inp.name] - elif inp.required: - group_name = get_group_name(inp.name) - if group_name: - param = group_name - else: - # if not, check if it's in the SDXL input schema, if so, - # 1. use the type hint to determine the type - # 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}} - inp_spec = SDXL_INPUTS_SCHEMA.get(inp.name, None) - if inp_spec: - type_str = str(inp_spec.type_hint).lower() - for type_key, type_param in DEFAULT_TYPE_MAPS.items(): - if type_key in type_str: - param[inp.name] = type_param - param[inp.name]["display"] = "input" - break - else: - param = inp.name - # add the param dict to the inp_params dict - if param: - input_params[inp.name] = param - - if intermediate_params is None: - intermediate_params = {} - for inp in self.blocks.intermediates_inputs: - param = {} - if inp.name in DEFAULT_PARAM_MAPS: - param[inp.name] = DEFAULT_PARAM_MAPS[inp.name] - elif inp.required: - group_name = get_group_name(inp.name) - if group_name: - param = group_name - else: - inp_spec = SDXL_INTERMEDIATE_INPUTS_SCHEMA.get(inp.name, None) - if inp_spec: - type_str = str(inp_spec.type_hint).lower() - for type_key, type_param in DEFAULT_TYPE_MAPS.items(): - if type_key in type_str: - param[inp.name] = type_param - param[inp.name]["display"] = "input" - break - else: - param = inp.name - # add the param dict to the intermediate_params dict - if param: - intermediate_params[inp.name] = param - - if component_params is None: - component_params = {} - for comp in self.blocks.expected_components: - to_exclude = False - for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS: - if exclude_key in comp.name: - to_exclude = True + # blocks param name -> mellon param name + self.name_mapping = {} + + input_params = {} + # pass or create a default param dict for each input + # e.g. for prompt, + # prompt = { + # "name": "text_input", # the name of the input in node defination, could be different from the input name in diffusers + # "label": "Prompt", + # "type": "string", + # "default": "a bear sitting in a chair drinking a milkshake", + # "display": "textarea"} + # if type is not specified, it'll be a "custom" param of its own type + # e.g. you can pass ModularNode(scheduler = {name :"scheduler"}) + # it will get this spec in node defination {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}} + # name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}} + inputs = self.blocks.inputs + self.blocks.intermediates_inputs + for inp in inputs: + param = kwargs.pop(inp.name, None) + if param: + # user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...}) + input_params[inp.name] = param + mellon_name = param.pop("name", inp.name) + if mellon_name != inp.name: + self.name_mapping[inp.name] = mellon_name + continue + + if not inp.name in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name): + continue + + if inp.name in DEFAULT_PARAM_MAPS: + # first check if it's in the default param map, if so, directly use that + param = DEFAULT_PARAM_MAPS[inp.name].copy() + elif get_group_name(inp.name): + param = get_group_name(inp.name) + if inp.name not in self.name_mapping: + self.name_mapping[inp.name] = param + else: + # if not, check if it's in the SDXL input schema, if so, + # 1. use the type hint to determine the type + # 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}} + if inp.type_hint is not None: + type_str = str(inp.type_hint).lower() + else: + inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None) + type_str = str(inp_spec.type_hint).lower() if inp_spec else "" + for type_key, type_param in DEFAULT_TYPE_MAPS.items(): + if type_key in type_str: + param = type_param.copy() + param["label"] = inp.name + param["display"] = "input" break - if to_exclude: - continue - - param = {} - group_name = get_group_name(comp.name) + else: + param = inp.name + # add the param dict to the inp_params dict + input_params[inp.name] = param + + + component_params = {} + for comp in self.blocks.expected_components: + param = kwargs.pop(comp.name, None) + if param: + component_params[comp.name] = param + mellon_name = param.pop("name", comp.name) + if mellon_name != comp.name: + self.name_mapping[comp.name] = mellon_name + continue + + to_exclude = False + for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS: + if exclude_key in comp.name: + to_exclude = True + break + if to_exclude: + continue + + if get_group_name(comp.name): + param = get_group_name(comp.name) + if comp.name not in self.name_mapping: + self.name_mapping[comp.name] = param + elif comp.name in DEFAULT_MODEL_KEYS: + param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"} + else: + param = comp.name + # add the param dict to the model_params dict + component_params[comp.name] = param + + output_params = {} + if isinstance(self.blocks, SequentialPipelineBlocks): + last_block_name = list(self.blocks.blocks.keys())[-1] + outputs = self.blocks.blocks[last_block_name].intermediates_outputs + else: + outputs = self.blocks.intermediates_outputs + + for out in outputs: + param = kwargs.pop(out.name, None) + if param: + output_params[out.name] = param + mellon_name = param.pop("name", out.name) + if mellon_name != out.name: + self.name_mapping[out.name] = mellon_name + continue + + if out.name in DEFAULT_PARAM_MAPS: + param = DEFAULT_PARAM_MAPS[out.name].copy() + param["display"] = "output" + else: + group_name = get_group_name(out.name) if group_name: param = group_name - elif comp.name in DEFAULT_MODEL_KEYS: - param[comp.name] = { - "label": comp.name, - "type": "diffusers_auto_model", - "display": "input", - } + if out.name not in self.name_mapping: + self.name_mapping[out.name] = param else: - param = comp.name - # add the param dict to the model_params dict - if param: - component_params[comp.name] = param - - if output_params is None: - output_params = {} - if isinstance(self.blocks, SequentialPipelineBlocks): - last_block_name = list(self.blocks.blocks.keys())[-1] - outputs = self.blocks.blocks[last_block_name].intermediates_outputs - else: - outputs = self.blocks.intermediates_outputs + param = out.name + # add the param dict to the outputs dict + output_params[out.name] = param - for out in outputs: - param = {} - if out.name in DEFAULT_PARAM_MAPS: - param[out.name] = DEFAULT_PARAM_MAPS[out.name] - param[out.name]["display"] = "output" - else: - group_name = get_group_name(out.name) - if group_name: - param = group_name - else: - param = out.name - # add the param dict to the outputs dict - if param: - output_params[out.name] = param + if len(kwargs) > 0: + logger.warning(f"Unused kwargs: {kwargs}") register_dict = { "category": category, "label": label, "input_params": input_params, - "intermediate_params": intermediate_params, "component_params": component_params, "output_params": output_params, + "name_mapping": self.name_mapping, } self.register_to_config(**register_dict) + + def setup(self, components, collection=None): + self.blocks.setup_loader(component_manager=components, collection=collection) + self._components_manager = components + + @property + def mellon_config(self): + return self._convert_to_mellon_config() + + def _convert_to_mellon_config(self): + + node = {} + node["label"] = self.config.label + node["category"] = self.config.category + + node_param = {} + for inp_name, inp_param in self.config.input_params.items(): + if inp_name in self.name_mapping: + mellon_name = self.name_mapping[inp_name] + else: + mellon_name = inp_name + if isinstance(inp_param, str): + param = { + "label": inp_param, + "type": inp_param, + "display": "input", + } + else: + param = inp_param + + if mellon_name not in node_param: + node_param[mellon_name] = param + else: + logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}") + + + for comp_name, comp_param in self.config.component_params.items(): + if comp_name in self.name_mapping: + mellon_name = self.name_mapping[comp_name] + else: + mellon_name = comp_name + if isinstance(comp_param, str): + param = { + "label": comp_param, + "type": comp_param, + "display": "input", + } + else: + param = comp_param + + if mellon_name not in node_param: + node_param[mellon_name] = param + else: + logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}") + + + for out_name, out_param in self.config.output_params.items(): + if out_name in self.name_mapping: + mellon_name = self.name_mapping[out_name] + else: + mellon_name = out_name + if isinstance(out_param, str): + param = { + "label": out_param, + "type": out_param, + "display": "output", + } + else: + param = out_param + + if mellon_name not in node_param: + node_param[mellon_name] = param + else: + logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}") + node["params"] = node_param + return node + + def save_mellon_config(self, file_path): + """ + Save the Mellon configuration to a JSON file. + + Args: + file_path (str or Path): Path where the JSON file will be saved + + Returns: + Path: Path to the saved config file + """ + file_path = Path(file_path) + + # Create directory if it doesn't exist + os.makedirs(file_path.parent, exist_ok=True) + + # Create a combined dictionary with module definition and name mapping + config = { + "module": self.mellon_config, + "name_mapping": self.name_mapping + } + + # Save the config to file + with open(file_path, 'w', encoding='utf-8') as f: + json.dump(config, f, indent=2) + + logger.info(f"Mellon config and name mapping saved to {file_path}") + + return file_path + + @classmethod + def load_mellon_config(cls, file_path): + """ + Load a Mellon configuration from a JSON file. + + Args: + file_path (str or Path): Path to the JSON file containing Mellon config + + Returns: + dict: The loaded combined configuration containing 'module' and 'name_mapping' + """ + file_path = Path(file_path) + + if not file_path.exists(): + raise FileNotFoundError(f"Config file not found: {file_path}") + + with open(file_path, 'r', encoding='utf-8') as f: + config = json.load(f) + + logger.info(f"Mellon config loaded from {file_path}") + + + return config + + def process_inputs(self, **kwargs): + + params_components = {} + for comp_name, comp_param in self.config.component_params.items(): + logger.debug(f"component: {comp_name}") + mellon_comp_name = self.name_mapping.get(comp_name, comp_name) + if mellon_comp_name in kwargs: + if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]: + comp = kwargs[mellon_comp_name].pop(comp_name) + else: + comp = kwargs.pop(mellon_comp_name) + if comp: + params_components[comp_name] = self._components_manager.get_one(comp["model_id"]) + + + params_run = {} + for inp_name, inp_param in self.config.input_params.items(): + logger.debug(f"input: {inp_name}") + mellon_inp_name = self.name_mapping.get(inp_name, inp_name) + if mellon_inp_name in kwargs: + if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]: + inp = kwargs[mellon_inp_name].pop(inp_name) + else: + inp = kwargs.pop(mellon_inp_name) + if inp is not None: + params_run[inp_name] = inp + + return_output_names = list(self.config.output_params.keys()) + + return params_components, params_run, return_output_names + + def execute(self, **kwargs): + params_components, params_run, return_output_names = self.process_inputs(**kwargs) + self.blocks.loader.update(**params_components) + output = self.blocks.run(**params_run, output=return_output_names) + return output From f16e9c78078021b212a38d1ec56b8f2d52f707f4 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 10 Jun 2025 23:10:17 +0200 Subject: [PATCH 49/54] add --- .../modular_pipeline_utils.py | 19 +++++ .../stable_diffusion_xl/denoise.py | 76 ++++++++++++++++--- .../modular_pipeline_block_mappings.py | 32 ++++---- 3 files changed, 101 insertions(+), 26 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index 6d6704f4eb38..a6ca13dbff26 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -19,11 +19,30 @@ from ..utils.import_utils import is_torch_available from ..configuration_utils import FrozenDict, ConfigMixin +from collections import OrderedDict if is_torch_available(): import torch +class InsertableOrderedDict(OrderedDict): + def insert(self, key, value, index): + items = list(self.items()) + + # Remove key if it already exists to avoid duplicates + items = [(k, v) for k, v in items if k != key] + + # Insert at the specified index + items.insert(index, (key, value)) + + # Clear and update self + self.clear() + self.update(items) + + # Return self for method chaining + return self + + # YiYi TODO: # 1. validate the dataclass fields # 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained() diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index bc567a6b034f..4d7ab12cf009 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -54,7 +54,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return "step within the denoising loop that prepare the latent input for the denoiser" + return "step within the denoising loop that prepare the latent input for the denoiser. Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" @property @@ -89,7 +89,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return "step within the denoising loop that prepare the latent input for the denoiser" + return "step within the denoising loop that prepare the latent input for the denoiser (for inpainting workflow only). Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" @property @@ -165,7 +165,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: return ( - "Step within the denoising loop that denoise the latents with guidance" + "Step within the denoising loop that denoise the latents with guidance. Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" ) @property @@ -269,7 +269,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + return "step within the denoising loop that denoise the latents with guidance (with controlnet). Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -458,7 +458,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + return "step within the denoising loop that update the latents. Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -521,7 +521,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + return "step within the denoising loop that update the latents (for inpainting workflow only). Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -622,7 +622,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): @property def description(self) -> str: return ( - "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" + "Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `blocks` attributes" ) @property @@ -683,21 +683,52 @@ class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] block_names = ["before_denoiser", "denoiser", "after_denoiser"] + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. " + "Its loop logic is defined in parent class `StableDiffusionXLDenoiseLoopWrapper` " + "and at each iteration, it runs blocks defined in `blocks` sequencially, i.e. `StableDiffusionXLDenoiseLoopBeforeDenoiser` and `StableDiffusionXLDenoiseLoopDenoiser`, " + "and finally `StableDiffusionXLDenoiseLoopAfterDenoiser` to update the latents." + ) + # control_cond class StableDiffusionXLControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] block_names = ["before_denoiser", "denoiser", "after_denoiser"] + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents with controlnet. " + "Its loop logic is defined in parent class `StableDiffusionXLDenoiseLoopWrapper` " + "and at each iteration, it runs blocks defined in `blocks` sequencially, i.e. `StableDiffusionXLDenoiseLoopBeforeDenoiser` and `StableDiffusionXLControlNetDenoiseLoopDenoiser`, " + "and finally `StableDiffusionXLDenoiseLoopAfterDenoiser` to update the latents." + ) # mask class StableDiffusionXLInpaintDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] block_names = ["before_denoiser", "denoiser", "after_denoiser"] - + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents(for inpainting task only). " + "Its loop logic is defined in parent class `StableDiffusionXLDenoiseLoopWrapper` " + "and at each iteration, it runs blocks defined in `blocks` sequencially, i.e. `StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser` and `StableDiffusionXLDenoiseLoopDenoiser`, " + "and finally `StableDiffusionXLInpaintDenoiseLoopAfterDenoiser` to update the latents." + ) # control_cond + mask class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] block_names = ["before_denoiser", "denoiser", "after_denoiser"] - + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents(for inpainting task only) with controlnet. " + "Its loop logic is defined in parent class `StableDiffusionXLDenoiseLoopWrapper` " + "and at each iteration, it runs blocks defined in `blocks` sequencially, i.e. `StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser` and `StableDiffusionXLControlNetDenoiseLoopDenoiser`, " + "and finally `StableDiffusionXLInpaintDenoiseLoopAfterDenoiser` to update the latents." + ) # all task without controlnet @@ -706,18 +737,45 @@ class StableDiffusionXLDenoiseStep(AutoPipelineBlocks): block_names = ["inpaint_denoise", "denoise"] block_trigger_inputs = ["mask", None] + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. " + "This is a auto pipeline block that works for text2img, img2img and inpainting tasks." + " - `StableDiffusionXLDenoiseStep` (denoise) is used when no mask is provided." + " - `StableDiffusionXLInpaintDenoiseStep` (inpaint_denoise) is used when mask is provided." + ) + # all task with controlnet class StableDiffusionXLControlNetDenoiseStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLInpaintControlNetDenoiseLoop, StableDiffusionXLControlNetDenoiseLoop] block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"] block_trigger_inputs = ["mask", None] + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents with controlnet. " + "This is a auto pipeline block that works for text2img, img2img and inpainting tasks." + " - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when no mask is provided." + " - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided." + ) + # all task with or without controlnet class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] block_names = ["controlnet_denoise", "denoise"] block_trigger_inputs = ["controlnet_cond", None] + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. " + "This is a auto pipeline block that works for text2img, img2img and inpainting tasks. And can be used with or without controlnet." + " - `StableDiffusionXLDenoiseStep` (denoise) is used when no controlnet_cond is provided (work for text2img, img2img and inpainting tasks)." + " - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when controlnet_cond is provided (work for text2img, img2img and inpainting tasks)." + ) + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py index 6d909ab5a4a0..00cd5ca3735a 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import OrderedDict +from ..modular_pipeline_utils import InsertableOrderedDict # Import all the necessary block classes from .denoise import ( StableDiffusionXLAutoDenoiseStep, - StableDiffusionXLDenoiseStep, - StableDiffusionXLControlNetDenoiseStep + StableDiffusionXLControlNetDenoiseStep, + StableDiffusionXLDenoiseLoop, + StableDiffusionXLInpaintDenoiseLoop ) from .before_denoise import ( StableDiffusionXLAutoBeforeDenoiseStep, @@ -50,56 +51,53 @@ # YiYi notes: comment out for now, work on this later # block mapping -TEXT2IMAGE_BLOCKS = OrderedDict([ +TEXT2IMAGE_BLOCKS = InsertableOrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), ("input", StableDiffusionXLInputStep), ("set_timesteps", StableDiffusionXLSetTimestepsStep), ("prepare_latents", StableDiffusionXLPrepareLatentsStep), ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), + ("denoise", StableDiffusionXLDenoiseLoop), ("decode", StableDiffusionXLDecodeStep) ]) -IMAGE2IMAGE_BLOCKS = OrderedDict([ +IMAGE2IMAGE_BLOCKS = InsertableOrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), ("image_encoder", StableDiffusionXLVaeEncoderStep), ("input", StableDiffusionXLInputStep), ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), + ("denoise", StableDiffusionXLDenoiseLoop), ("decode", StableDiffusionXLDecodeStep) ]) -INPAINT_BLOCKS = OrderedDict([ +INPAINT_BLOCKS = InsertableOrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), ("input", StableDiffusionXLInputStep), ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), + ("denoise", StableDiffusionXLInpaintDenoiseLoop), ("decode", StableDiffusionXLInpaintDecodeStep) ]) -CONTROLNET_BLOCKS = OrderedDict([ +CONTROLNET_BLOCKS = InsertableOrderedDict([ ("controlnet_input", StableDiffusionXLControlNetInputStep), ("denoise", StableDiffusionXLControlNetDenoiseStep), ]) -CONTROLNET_UNION_BLOCKS = OrderedDict([ +CONTROLNET_UNION_BLOCKS = InsertableOrderedDict([ ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), ("denoise", StableDiffusionXLControlNetDenoiseStep), ]) -IP_ADAPTER_BLOCKS = OrderedDict([ +IP_ADAPTER_BLOCKS = InsertableOrderedDict([ ("ip_adapter", StableDiffusionXLIPAdapterStep), ]) -AUTO_BLOCKS = OrderedDict([ +AUTO_BLOCKS = InsertableOrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), @@ -108,7 +106,7 @@ ("decode", StableDiffusionXLAutoDecodeStep) ]) -AUTO_CORE_BLOCKS = OrderedDict([ +AUTO_CORE_BLOCKS = InsertableOrderedDict([ ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), ("denoise", StableDiffusionXLAutoDenoiseStep), ]) From cb6d5fed19ce4672857d6dfbf95ba2848feea5b5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 18 Jun 2025 10:11:22 +0200 Subject: [PATCH 50/54] refator based on dhruv's feedbacks --- src/diffusers/modular_pipelines/__init__.py | 4 +- .../modular_pipelines/modular_pipeline.py | 259 ++++++++++-------- src/diffusers/modular_pipelines/node_utils.py | 4 +- 3 files changed, 146 insertions(+), 121 deletions(-) diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index cb2ed78ce360..8a23219761eb 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -23,7 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) else: _import_structure["modular_pipeline"] = [ - "ModularPipelineMixin", + "ModularPipelineBlocks", "PipelineBlock", "AutoPipelineBlocks", "SequentialPipelineBlocks", @@ -53,7 +53,7 @@ BlockState, LoopSequentialPipelineBlocks, ModularLoader, - ModularPipelineMixin, + ModularPipelineBlocks, PipelineBlock, PipelineState, SequentialPipelineBlocks, diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 3136c3bb11f1..5a93a2995180 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -243,8 +243,7 @@ def format_value(v): return f"BlockState(\n{attributes}\n)" - -class ModularPipelineMixin(ConfigMixin): +class ModularPipelineBlocks(ConfigMixin): """ Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks """ @@ -305,13 +304,10 @@ def from_pretrained( } return block_cls(**block_kwargs) - - def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): + def init_pipeline(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): """ create a ModularLoader, optionally accept modular_repo to load from hub. """ - - # Import components loader (it is model-specific class) loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__) diffusers_module = importlib.import_module("diffusers") loader_class = getattr(diffusers_module, loader_class_name) @@ -322,98 +318,12 @@ def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, c # Create the loader with the updated specs specs = component_specs + config_specs - self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection) - - - @property - def default_call_parameters(self) -> Dict[str, Any]: - params = {} - for input_param in self.inputs: - params[input_param.name] = input_param.default - return params - - def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): - """ - Run one or more blocks in sequence, optionally you can pass a previous pipeline state. - """ - if state is None: - state = PipelineState() - - if not hasattr(self, "loader"): - logger.info("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.") - self.loader = None - - # Make a copy of the input kwargs - passed_kwargs = kwargs.copy() - - - # Add inputs to state, using defaults if not provided in the kwargs or the state - # if same input already in the state, will override it if provided in the kwargs - - intermediates_inputs = [inp.name for inp in self.intermediates_inputs] - for expected_input_param in self.inputs: - name = expected_input_param.name - default = expected_input_param.default - kwargs_type = expected_input_param.kwargs_type - if name in passed_kwargs: - if name not in intermediates_inputs: - state.add_input(name, passed_kwargs.pop(name), kwargs_type) - else: - state.add_input(name, passed_kwargs[name], kwargs_type) - elif name not in state.inputs: - state.add_input(name, default, kwargs_type) - - for expected_intermediate_param in self.intermediates_inputs: - name = expected_intermediate_param.name - kwargs_type = expected_intermediate_param.kwargs_type - if name in passed_kwargs: - state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type) - - # Warn about unexpected inputs - if len(passed_kwargs) > 0: - warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") - # Run the pipeline - with torch.no_grad(): - try: - pipeline, state = self(self.loader, state) - except Exception: - error_msg = f"Error in block: ({self.__class__.__name__}):\n" - logger.error(error_msg) - raise + loader = loader_class(specs=specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection) + modular_pipeline = ModularPipeline(blocks=self, loader=loader) + return modular_pipeline - if output is None: - return state - - - elif isinstance(output, str): - return state.get_intermediate(output) - - elif isinstance(output, (list, tuple)): - return state.get_intermediates(output) - else: - raise ValueError(f"Output '{output}' is not a valid output type") - @torch.compiler.disable - def progress_bar(self, iterable=None, total=None): - if not hasattr(self, "_progress_bar_config"): - self._progress_bar_config = {} - elif not isinstance(self._progress_bar_config, dict): - raise ValueError( - f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." - ) - - if iterable is not None: - return tqdm(iterable, **self._progress_bar_config) - elif total is not None: - return tqdm(total=total, **self._progress_bar_config) - else: - raise ValueError("Either `total` or `iterable` has to be defined.") - - def set_progress_bar_config(self, **kwargs): - self._progress_bar_config = kwargs - - -class PipelineBlock(ModularPipelineMixin): +class PipelineBlock(ModularPipelineBlocks): model_name = None @@ -680,7 +590,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> return list(combined_dict.values()) -class AutoPipelineBlocks(ModularPipelineMixin): +class AutoPipelineBlocks(ModularPipelineBlocks): """ A class that automatically selects a block to run based on the inputs. @@ -969,7 +879,8 @@ def doc(self): expected_configs=self.expected_configs ) -class SequentialPipelineBlocks(ModularPipelineMixin): + +class SequentialPipelineBlocks(ModularPipelineBlocks): """ A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. """ @@ -1009,15 +920,24 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlo """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. Args: - blocks_dict: Dictionary mapping block names to block instances + blocks_dict: Dictionary mapping block names to block classes or instances Returns: A new SequentialPipelineBlocks instance """ instance = cls() - instance.block_classes = [block.__class__ for block in blocks_dict.values()] - instance.block_names = list(blocks_dict.keys()) - instance.blocks = blocks_dict + + # Create instances if classes are provided + blocks = {} + for name, block in blocks_dict.items(): + if inspect.isclass(block): + blocks[name] = block() + else: + blocks[name] = block + + instance.block_classes = [block.__class__ for block in blocks.values()] + instance.block_names = list(blocks.keys()) + instance.blocks = blocks return instance def __init__(self): @@ -1330,7 +1250,7 @@ def doc(self): ) #YiYi TODO: __repr__ -class LoopSequentialPipelineBlocks(ModularPipelineMixin): +class LoopSequentialPipelineBlocks(ModularPipelineBlocks): """ A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence. """ @@ -1694,7 +1614,24 @@ def __repr__(self): return result + @torch.compiler.disable + def progress_bar(self, iterable=None, total=None): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + if iterable is not None: + return tqdm(iterable, **self._progress_bar_config) + elif total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs # YiYi TODO: @@ -1889,19 +1826,6 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - @property - def device(self) -> torch.device: - r""" - Returns: - `torch.device`: The torch device on which the pipeline is located. - """ - - modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)] - - for module in modules: - return module.device - - return torch.device("cpu") @property def dtype(self) -> torch.dtype: @@ -2197,4 +2121,105 @@ def _dict_to_component_spec( name=name, type_hint=type_hint, **spec_dict, - ) \ No newline at end of file + ) + + +class ModularPipeline: + """ + Base class for all Modular pipelines. + + Args: + blocks: ModularPipelineBlocks, the blocks to be used in the pipeline + loader: ModularLoader, the loader to be used in the pipeline + """ + + def __init__(self, blocks: ModularPipelineBlocks, loader: ModularLoader): + self.blocks = blocks + self.loader = loader + + + @property + def default_call_parameters(self) -> Dict[str, Any]: + params = {} + for input_param in self.blocks.inputs: + params[input_param.name] = input_param.default + return params + + def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): + """ + Run one or more blocks in sequence, optionally you can pass a previous pipeline state. + """ + if state is None: + state = PipelineState() + + + # Make a copy of the input kwargs + passed_kwargs = kwargs.copy() + + + # Add inputs to state, using defaults if not provided in the kwargs or the state + # if same input already in the state, will override it if provided in the kwargs + + intermediates_inputs = [inp.name for inp in self.blocks.intermediates_inputs] + for expected_input_param in self.blocks.inputs: + name = expected_input_param.name + default = expected_input_param.default + kwargs_type = expected_input_param.kwargs_type + if name in passed_kwargs: + if name not in intermediates_inputs: + state.add_input(name, passed_kwargs.pop(name), kwargs_type) + else: + state.add_input(name, passed_kwargs[name], kwargs_type) + elif name not in state.inputs: + state.add_input(name, default, kwargs_type) + + for expected_intermediate_param in self.blocks.intermediates_inputs: + name = expected_intermediate_param.name + kwargs_type = expected_intermediate_param.kwargs_type + if name in passed_kwargs: + state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type) + + # Warn about unexpected inputs + if len(passed_kwargs) > 0: + warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") + # Run the pipeline + with torch.no_grad(): + try: + pipeline, state = self.blocks(self.loader, state) + except Exception: + error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n" + logger.error(error_msg) + raise + + if output is None: + return state + + + elif isinstance(output, str): + return state.get_intermediate(output) + + elif isinstance(output, (list, tuple)): + return state.get_intermediates(output) + else: + raise ValueError(f"Output '{output}' is not a valid output type") + + + def load_components(self, component_names: Optional[List[str]] = None, **kwargs): + self.loader.load(component_names=component_names, **kwargs) + + def update_components(self, **kwargs): + self.loader.update(**kwargs) + + def from_pretrained(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + loader = ModularLoader.from_pretrained(pretrained_model_name_or_path, **kwargs) + blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, **kwargs) + return ModularPipeline(blocks=blocks, loader=loader) + + def save_pretrained(self, save_directory: Optional[Union[str, os.PathLike]] = None, push_to_hub: bool = False, **kwargs): + self.blocks.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + self.loader.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + + @property + def doc(self): + return self.blocks.doc \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py index 9ee9c069277d..5f5e1c6c782d 100644 --- a/src/diffusers/modular_pipelines/node_utils.py +++ b/src/diffusers/modular_pipelines/node_utils.py @@ -1,5 +1,5 @@ from ..configuration_utils import ConfigMixin -from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineMixin +from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineBlocks from .modular_pipeline_utils import InputParam, OutputParam from ..image_processor import PipelineImageInput from pathlib import Path @@ -202,7 +202,7 @@ def from_pretrained( trust_remote_code: Optional[bool] = None, **kwargs, ): - blocks = ModularPipelineMixin.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs) + blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs) return cls(blocks, **kwargs) def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): From 58e9565719700a83071c2eb4a4264641ced852d1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 19 Jun 2025 02:24:51 +0200 Subject: [PATCH 51/54] update doc format for kwargs_type --- src/diffusers/modular_pipelines/modular_pipeline_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index a6ca13dbff26..ced059551f9a 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -421,7 +421,9 @@ def wrap_text(text, indent, max_length): for param in params: # Format parameter name and type type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" - param_str = f"{param_indent}{param.name} (`{type_str}`" + # YiYi Notes: remove this line if we remove kwargs_type + name = f'**{param.kwargs_type}' if param.name is None and param.kwargs_type is not None else param.name + param_str = f"{param_indent}{name} (`{type_str}`" # Add optional tag and default value if parameter is an InputParam and optional if hasattr(param, "required"): From de631947cc1620a2edf771119754f212093eb734 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 19 Jun 2025 04:45:20 +0200 Subject: [PATCH 52/54] up --- src/diffusers/modular_pipelines/modular_pipeline.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 5a93a2995180..1b4a606d84a3 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -304,6 +304,7 @@ def from_pretrained( } return block_cls(**block_kwargs) + def init_pipeline(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): """ create a ModularLoader, optionally accept modular_repo to load from hub. @@ -2137,6 +2138,10 @@ def __init__(self, blocks: ModularPipelineBlocks, loader: ModularLoader): self.blocks = blocks self.loader = loader + def __repr__(self): + blocks_class = self.blocks.__class__.__name__ + loader_class = self.loader.__class__.__name__ + return f"ModularPipeline(blocks={blocks_class}, loader={loader_class})" @property def default_call_parameters(self) -> Dict[str, Any]: From 8423652b357b15cba540a7553e03d33ccc5a3a2f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 19 Jun 2025 05:30:18 +0200 Subject: [PATCH 53/54] updatee modular_pipeline.from_pretrained, modular_repo ->pretrained_model_name_or_path --- src/diffusers/__init__.py | 4 ++++ src/diffusers/modular_pipelines/__init__.py | 2 ++ .../modular_pipelines/modular_pipeline.py | 20 ++++++++++--------- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 7a3de0b95747..d78b759c85c1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -264,6 +264,8 @@ _import_structure["modular_pipelines"].extend( [ "ModularLoader", + "ModularPipeline", + "ModularPipelineBlocks", "ComponentSpec", "ComponentsManager", ] @@ -894,6 +896,8 @@ ) from .modular_pipelines import ( ModularLoader, + ModularPipeline, + ModularPipelineBlocks, ComponentSpec, ComponentsManager, ) diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 8a23219761eb..4499634d9fbd 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -24,6 +24,7 @@ else: _import_structure["modular_pipeline"] = [ "ModularPipelineBlocks", + "ModularPipeline", "PipelineBlock", "AutoPipelineBlocks", "SequentialPipelineBlocks", @@ -54,6 +55,7 @@ LoopSequentialPipelineBlocks, ModularLoader, ModularPipelineBlocks, + ModularPipeline, PipelineBlock, PipelineState, SequentialPipelineBlocks, diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 1b4a606d84a3..196687e2d0c5 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -305,7 +305,7 @@ def from_pretrained( return block_cls(**block_kwargs) - def init_pipeline(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): + def init_pipeline(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): """ create a ModularLoader, optionally accept modular_repo to load from hub. """ @@ -319,7 +319,7 @@ def init_pipeline(self, modular_repo: Optional[Union[str, os.PathLike]] = None, # Create the loader with the updated specs specs = component_specs + config_specs - loader = loader_class(specs=specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection) + loader = loader_class(specs=specs, pretrained_model_name_or_path=pretrained_model_name_or_path, component_manager=component_manager, collection=collection) modular_pipeline = ModularPipeline(blocks=self, loader=loader) return modular_pipeline @@ -1748,7 +1748,7 @@ def register_components(self, **kwargs): # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name - def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): + def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], pretrained_model_name_or_path: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): """ Initialize the loader with a list of component specs and config specs. """ @@ -1762,8 +1762,8 @@ def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: } # update component_specs and config_specs from modular_repo - if modular_repo is not None: - config_dict = self.load_config(modular_repo, **kwargs) + if pretrained_model_name_or_path is not None: + config_dict = self.load_config(pretrained_model_name_or_path, **kwargs) for name, value in config_dict.items(): # only update component_spec for from_pretrained components @@ -2215,10 +2215,12 @@ def load_components(self, component_names: Optional[List[str]] = None, **kwargs) def update_components(self, **kwargs): self.loader.update(**kwargs) - def from_pretrained(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): - loader = ModularLoader.from_pretrained(pretrained_model_name_or_path, **kwargs) - blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, **kwargs) - return ModularPipeline(blocks=blocks, loader=loader) + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], trust_remote_code: Optional[bool] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): + blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs) + pipeline = blocks.init_pipeline(pretrained_model_name_or_path, component_manager=component_manager, collection=collection, **kwargs) + return pipeline def save_pretrained(self, save_directory: Optional[Union[str, os.PathLike]] = None, push_to_hub: bool = False, **kwargs): self.blocks.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) From 79be5a1b808027cff0970eaf1cd73cb6ff903068 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 19 Jun 2025 09:21:39 +0530 Subject: [PATCH 54/54] save_pretrained for serializing config. (#11603) * save_pretrained for serializing config. * remove pushtohub * diffusers-cli rough --------- Co-authored-by: YiYi Xu --- src/diffusers/commands/custom_blocks.py | 133 ++++++++++++++++++ src/diffusers/commands/diffusers_cli.py | 2 + .../modular_pipelines/modular_pipeline.py | 15 ++ 3 files changed, 150 insertions(+) create mode 100644 src/diffusers/commands/custom_blocks.py diff --git a/src/diffusers/commands/custom_blocks.py b/src/diffusers/commands/custom_blocks.py new file mode 100644 index 000000000000..d2f2de3a8f9a --- /dev/null +++ b/src/diffusers/commands/custom_blocks.py @@ -0,0 +1,133 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage example: + TODO +""" + +import ast +from argparse import ArgumentParser, Namespace +from pathlib import Path +import importlib.util +import os +from ..utils import logging +from . import BaseDiffusersCLICommand + + +EXPECTED_PARENT_CLASSES = ["PipelineBlock"] +CONFIG = "config.json" + +def conversion_command_factory(args: Namespace): + return CustomBlocksCommand(args.block_module_name, args.block_class_name) + + +class CustomBlocksCommand(BaseDiffusersCLICommand): + @staticmethod + def register_subcommand(parser: ArgumentParser): + conversion_parser = parser.add_parser("custom_blocks") + conversion_parser.add_argument( + "--block_module_name", + type=str, + default="block.py", + help="Module filename in which the custom block will be implemented.", + ) + conversion_parser.add_argument( + "--block_class_name", type=str, default=None, help="Name of the custom block. If provided None, we will try to infer it." + ) + conversion_parser.set_defaults(func=conversion_command_factory) + + def __init__(self, block_module_name: str = "block.py", block_class_name: str = None): + self.logger = logging.get_logger("diffusers-cli/custom_blocks") + self.block_module_name = Path(block_module_name) + self.block_class_name = block_class_name + + def run(self): + # determine the block to be saved. + out = self._get_class_names(self.block_module_name) + classes_found = list({cls for cls, _ in out}) + + if self.block_class_name is not None: + child_class, parent_class = self._choose_block(out, self.block_class_name) + if child_class is None and parent_class is None: + raise ValueError( + "`block_class_name` could not be retrieved. Available classes from " + f"{self.block_module_name}:\n{classes_found}" + ) + else: + self.logger.info( + f"Found classes: {classes_found} will be using {classes_found[0]}. " + "If this needs to be changed, re-run the command specifying `block_class_name`." + ) + child_class, parent_class = out[0][0], out[0][1] + + # dynamically get the custom block and initialize it to call `save_pretrained` in the current directory. + # the user is responsible for running it, so I guess that is safe? + module_name = f"__dynamic__{self.block_module_name.stem}" + spec = importlib.util.spec_from_file_location(module_name, str(self.block_module_name)) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + getattr(module, child_class)().save_pretrained(os.getcwd()) + + # or, we could create it manually. + # automap = self._create_automap(parent_class=parent_class, child_class=child_class) + # with open(CONFIG, "w") as f: + # json.dump(automap, f) + with open("requirements.txt", "w") as f: + f.write("") + + def _choose_block(self, candidates, chosen=None): + for cls, base in candidates: + if cls == chosen: + return cls, base + return None, None + + def _get_class_names(self, file_path): + source = file_path.read_text(encoding="utf-8") + try: + tree = ast.parse(source, filename=file_path) + except SyntaxError as e: + raise ValueError(f"Could not parse {file_path!r}: {e}") from e + + results: list[tuple[str, str]] = [] + for node in tree.body: + if not isinstance(node, ast.ClassDef): + continue + + # extract all base names for this class + base_names = [ + bname for b in node.bases + if (bname := self._get_base_name(b)) is not None + ] + + # for each allowed base that appears in the class's bases, emit a tuple + for allowed in EXPECTED_PARENT_CLASSES: + if allowed in base_names: + results.append((node.name, allowed)) + + return results + + def _get_base_name(self, node: ast.expr): + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Attribute): + val = self._get_base_name(node.value) + return f"{val}.{node.attr}" if val else node.attr + return None + + def _create_automap(self, parent_class, child_class): + module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1] + auto_map = {f"{parent_class}": f"{module}.{child_class}"} + return {"auto_map": auto_map} + diff --git a/src/diffusers/commands/diffusers_cli.py b/src/diffusers/commands/diffusers_cli.py index f582c3bcd0df..cdc7dad166f0 100644 --- a/src/diffusers/commands/diffusers_cli.py +++ b/src/diffusers/commands/diffusers_cli.py @@ -17,6 +17,7 @@ from .env import EnvironmentCommand from .fp16_safetensors import FP16SafetensorsCommand +from .custom_blocks import CustomBlocksCommand def main(): @@ -26,6 +27,7 @@ def main(): # Register commands EnvironmentCommand.register_subcommand(commands_parser) FP16SafetensorsCommand.register_subcommand(commands_parser) + CustomBlocksCommand.register_subcommand(commands_parser) # Let's go args = parser.parse_args() diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 196687e2d0c5..84b9b594d758 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -529,6 +529,21 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): if current_value is not param: # Using identity comparison to check if object was modified state.add_intermediate(param_name, param, input_param.kwargs_type) + def save_pretrained(self, save_directory, push_to_hub = False, **kwargs): + # TODO: factor out this logic. + cls_name = self.__class__.__name__ + + full_mod = type(self).__module__ + module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "") + parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0] + auto_map = {f"{parent_module}": f"{module}.{cls_name}"} + _component_names = [c.name for c in self.expected_components] + + self.register_to_config(auto_map=auto_map, _component_names=_component_names) + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + config = dict(self.config) + self._internal_dict = FrozenDict(config) + def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: """