From 2fd4508242456e928f30896527368ecab5012016 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 10:19:59 +0000 Subject: [PATCH 01/23] Check correct model type is passed to `from_pretrained` --- src/diffusers/pipelines/pipeline_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index a504184ea2f2..70ad163c2c22 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -833,6 +833,13 @@ def load_module(name, value): init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} + for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()): + if key not in passed_class_obj: + continue + class_name = passed_class_obj[key].__class__.__name__ + if class_name != expected_class_name: + raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.") + # Special case: safety_checker must be loaded separately when using `from_flax` if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: raise NotImplementedError( From 185a78f294c3e1254cdf6f9a972ac7ea72274262 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 10:42:27 +0000 Subject: [PATCH 02/23] Flax, skip scheduler --- src/diffusers/pipelines/pipeline_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 70ad163c2c22..78b283ea6973 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -834,9 +834,10 @@ def load_module(name, value): init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()): - if key not in passed_class_obj: + if key not in passed_class_obj or key == "scheduler": continue class_name = passed_class_obj[key].__class__.__name__ + class_name = class_name[4:] if class_name.startswith("Flax") else class_name if class_name != expected_class_name: raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.") From 679c18c973c5c5d20eaf09589ed447a9e9c6f0e4 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 10:42:38 +0000 Subject: [PATCH 03/23] test_wrong_model --- tests/pipelines/test_pipelines.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 43b01c40f5bb..1b48dee59b16 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -1802,6 +1802,17 @@ def test_pipe_same_device_id_offload(self): sd.maybe_free_model_hooks() assert sd._offload_gpu_id == 5 + def test_wrong_model(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + with self.assertRaises(ValueError) as error_context: + _ = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer + ) + + assert "Expected" in str(error_context.exception) + assert "text_encoder" in str(error_context.exception) + assert "CLIPTokenizer" in str(error_context.exception) + @slow @require_torch_gpu From 6aad7a799a97e1f612e4ffa352c6272cddbff7d3 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 11:24:43 +0000 Subject: [PATCH 04/23] Fix for scheduler --- src/diffusers/pipelines/pipeline_utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 78b283ea6973..b4b4680486ef 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum import fnmatch import importlib import inspect @@ -811,6 +812,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # in this case they are already instantiated in `kwargs` # extract them here expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) + expected_types = pipeline_class._get_signature_types() passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) @@ -832,13 +834,21 @@ def load_module(name, value): return True init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} + scheduler_types = expected_types["scheduler"][0] + if isinstance(scheduler_types, enum.EnumType): + scheduler_types = list(scheduler_types) + else: + scheduler_types = [str(scheduler_types)] + scheduler_types = [str(scheduler).split(".")[-1].strip("'>") for scheduler in scheduler_types] for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()): - if key not in passed_class_obj or key == "scheduler": + if key not in passed_class_obj: continue class_name = passed_class_obj[key].__class__.__name__ class_name = class_name[4:] if class_name.startswith("Flax") else class_name - if class_name != expected_class_name: + if key == "scheduler" and class_name not in scheduler_types: + raise ValueError(f"Expected {scheduler_types} for {key}, got {class_name}.") + elif key != "scheduler" and class_name != expected_class_name: raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.") # Special case: safety_checker must be loaded separately when using `from_flax` From c1db3bd49b575f3839b4ff6194c4e76063281b4b Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 11:27:40 +0000 Subject: [PATCH 05/23] Update tests/pipelines/test_pipelines.py Co-authored-by: Sayak Paul --- tests/pipelines/test_pipelines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 1b48dee59b16..434d24852656 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -1811,7 +1811,7 @@ def test_wrong_model(self): assert "Expected" in str(error_context.exception) assert "text_encoder" in str(error_context.exception) - assert "CLIPTokenizer" in str(error_context.exception) + assert f"{tokenizer.__class__.__name}" in str(error_context.exception) @slow From b8fa81a82268ca8704cbba461dfca0702ee34bfe Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 11:36:58 +0000 Subject: [PATCH 06/23] EnumMeta --- src/diffusers/pipelines/pipeline_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index b4b4680486ef..7e67c1b7f759 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -835,7 +835,7 @@ def load_module(name, value): init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} scheduler_types = expected_types["scheduler"][0] - if isinstance(scheduler_types, enum.EnumType): + if isinstance(scheduler_types, enum.EnumMeta): scheduler_types = list(scheduler_types) else: scheduler_types = [str(scheduler_types)] From 44f24a47db702cf5e124ea6aed22ce6b764cd665 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 11:52:04 +0000 Subject: [PATCH 07/23] Flax --- src/diffusers/pipelines/pipeline_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 7e67c1b7f759..0c4450f8eff5 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -846,6 +846,7 @@ def load_module(name, value): continue class_name = passed_class_obj[key].__class__.__name__ class_name = class_name[4:] if class_name.startswith("Flax") else class_name + expected_class_name = expected_class_name[4:] if expected_class_name.startswith("Flax") else expected_class_name if key == "scheduler" and class_name not in scheduler_types: raise ValueError(f"Expected {scheduler_types} for {key}, got {class_name}.") elif key != "scheduler" and class_name != expected_class_name: From 5af8c7f2f7ab26755c377425b88daee4d88107ed Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 11:54:14 +0000 Subject: [PATCH 08/23] scheduler in expected types --- src/diffusers/pipelines/pipeline_utils.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 0c4450f8eff5..3bb631d46852 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -834,12 +834,14 @@ def load_module(name, value): return True init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} - scheduler_types = expected_types["scheduler"][0] - if isinstance(scheduler_types, enum.EnumMeta): - scheduler_types = list(scheduler_types) - else: - scheduler_types = [str(scheduler_types)] - scheduler_types = [str(scheduler).split(".")[-1].strip("'>") for scheduler in scheduler_types] + scheduler_types = None + if "scheduler" in expected_types: + scheduler_types = expected_types["scheduler"][0] + if isinstance(scheduler_types, enum.EnumMeta): + scheduler_types = list(scheduler_types) + else: + scheduler_types = [str(scheduler_types)] + scheduler_types = [str(scheduler).split(".")[-1].strip("'>") for scheduler in scheduler_types] for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()): if key not in passed_class_obj: @@ -847,7 +849,7 @@ def load_module(name, value): class_name = passed_class_obj[key].__class__.__name__ class_name = class_name[4:] if class_name.startswith("Flax") else class_name expected_class_name = expected_class_name[4:] if expected_class_name.startswith("Flax") else expected_class_name - if key == "scheduler" and class_name not in scheduler_types: + if key == "scheduler" and scheduler_types is not None and class_name not in scheduler_types: raise ValueError(f"Expected {scheduler_types} for {key}, got {class_name}.") elif key != "scheduler" and class_name != expected_class_name: raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.") From baea1415a4ddae1bcab61a93c975e87e9998b1ba Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 11:55:12 +0000 Subject: [PATCH 09/23] make --- src/diffusers/pipelines/pipeline_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 3bb631d46852..8dd5d7d22bcf 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -848,7 +848,9 @@ def load_module(name, value): continue class_name = passed_class_obj[key].__class__.__name__ class_name = class_name[4:] if class_name.startswith("Flax") else class_name - expected_class_name = expected_class_name[4:] if expected_class_name.startswith("Flax") else expected_class_name + expected_class_name = ( + expected_class_name[4:] if expected_class_name.startswith("Flax") else expected_class_name + ) if key == "scheduler" and scheduler_types is not None and class_name not in scheduler_types: raise ValueError(f"Expected {scheduler_types} for {key}, got {class_name}.") elif key != "scheduler" and class_name != expected_class_name: From c5e1e2ddf8ec61b5557d332798078296230d4f82 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 12:49:29 +0000 Subject: [PATCH 10/23] type object 'CLIPTokenizer' has no attribute '_PipelineFastTests__name' --- tests/pipelines/test_pipelines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 434d24852656..1b48dee59b16 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -1811,7 +1811,7 @@ def test_wrong_model(self): assert "Expected" in str(error_context.exception) assert "text_encoder" in str(error_context.exception) - assert f"{tokenizer.__class__.__name}" in str(error_context.exception) + assert "CLIPTokenizer" in str(error_context.exception) @slow From dba12b600203d15f95d5facf11786ddcb8a7b0de Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 12:49:46 +0000 Subject: [PATCH 11/23] support union --- src/diffusers/pipelines/pipeline_utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 8dd5d7d22bcf..89d83c8504c3 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -22,7 +22,7 @@ import sys from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin +from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin, _UnionGenericAlias import numpy as np import PIL.Image @@ -836,11 +836,12 @@ def load_module(name, value): init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} scheduler_types = None if "scheduler" in expected_types: - scheduler_types = expected_types["scheduler"][0] - if isinstance(scheduler_types, enum.EnumMeta): - scheduler_types = list(scheduler_types) - else: - scheduler_types = [str(scheduler_types)] + scheduler_types = [] + for scheduler_type in expected_types["scheduler"]: + if isinstance(scheduler_type, enum.EnumMeta): + scheduler_types.extend(list(scheduler_type)) + else: + scheduler_types.extend([str(scheduler_type)]) scheduler_types = [str(scheduler).split(".")[-1].strip("'>") for scheduler in scheduler_types] for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()): From 99b0f92e6b60e25704419aaefcc7fc589757fad6 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 12:49:59 +0000 Subject: [PATCH 12/23] fix typing in kandinsky --- src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py | 4 ++-- .../pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py | 4 ++-- .../pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py index 471db61556f5..c912344bd609 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py @@ -17,7 +17,7 @@ import torch from ...models import UNet2DConditionModel, VQModel -from ...schedulers import DDPMScheduler +from ...schedulers import DDPMScheduler, UnCLIPScheduler from ...utils import deprecate, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -83,7 +83,7 @@ class KandinskyV22Pipeline(DiffusionPipeline): def __init__( self, unet: UNet2DConditionModel, - scheduler: DDPMScheduler, + scheduler: Union[DDPMScheduler, UnCLIPScheduler], movq: VQModel, ): super().__init__() diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py index 68334fef3811..9f054bd51198 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py @@ -140,13 +140,13 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline): def __init__( self, unet: UNet2DConditionModel, - scheduler: DDPMScheduler, + scheduler: Union[DDPMScheduler, UnCLIPScheduler], movq: VQModel, prior_prior: PriorTransformer, prior_image_encoder: CLIPVisionModelWithProjection, prior_text_encoder: CLIPTextModelWithProjection, prior_tokenizer: CLIPTokenizer, - prior_scheduler: UnCLIPScheduler, + prior_scheduler: Union[DDPMScheduler, UnCLIPScheduler], prior_image_processor: CLIPImageProcessor, ): super().__init__() diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py index f2134b22b40b..ec1057500e21 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py @@ -5,7 +5,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection from ...models import PriorTransformer -from ...schedulers import UnCLIPScheduler +from ...schedulers import DDPMScheduler, UnCLIPScheduler from ...utils import ( logging, replace_example_docstring, @@ -114,7 +114,7 @@ def __init__( image_encoder: CLIPVisionModelWithProjection, text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, - scheduler: UnCLIPScheduler, + scheduler: Union[DDPMScheduler, UnCLIPScheduler], image_processor: CLIPImageProcessor, ): super().__init__() From 3a43c8a0b4f5d1e2b0f8daf93e1b79db7a82c58b Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 12:50:13 +0000 Subject: [PATCH 13/23] make --- src/diffusers/pipelines/pipeline_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 89d83c8504c3..5fe0a33dcdcf 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -22,7 +22,7 @@ import sys from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin, _UnionGenericAlias +from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin import numpy as np import PIL.Image From 803e33f22e574bd5a226aafb7bd75dcf5b1c3950 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 13:03:09 +0000 Subject: [PATCH 14/23] add LCMScheduler --- src/diffusers/schedulers/scheduling_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index f20224b19009..ce6d04a1add2 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -46,6 +46,7 @@ class KarrasDiffusionSchedulers(Enum): UniPCMultistepScheduler = 13 DPMSolverSDEScheduler = 14 EDMEulerScheduler = 15 + LCMScheduler = 16 AysSchedules = { From c81415b8cebc7b0c38393e0fee84e4aaebc16352 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 13:20:17 +0000 Subject: [PATCH 15/23] 'LCMScheduler' object has no attribute 'sigmas' --- .../stable_diffusion_2/test_stable_diffusion_latent_upscale.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py index 134175bdaffe..67b9d4a4dfbf 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py @@ -250,6 +250,7 @@ def test_karras_schedulers_shape(self): "KDPM2AncestralDiscreteScheduler", "DPMSolverSDEScheduler", "EDMEulerScheduler", + "LCMScheduler", ] components = self.get_dummy_components() pipe = self.pipeline_class(**components) From 13a824e4a66b8c6233b10acad5009a112e9beb04 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 13:37:00 +0000 Subject: [PATCH 16/23] tests for wrong scheduler --- tests/pipelines/test_pipelines.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 1b48dee59b16..3567f4cfc055 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -47,6 +47,8 @@ DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, + FlowMatchEulerDiscreteScheduler, + FluxPipeline, LMSDiscreteScheduler, ModelMixin, PNDMScheduler, @@ -1813,6 +1815,28 @@ def test_wrong_model(self): assert "text_encoder" in str(error_context.exception) assert "CLIPTokenizer" in str(error_context.exception) + def test_wrong_model_scheduler_type(self): + scheduler = EulerDiscreteScheduler.from_pretrained("hf-internal-testing/tiny-flux-pipe", subfolder="scheduler") + with self.assertRaises(ValueError) as error_context: + _ = FluxPipeline.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", scheduler=scheduler + ) + + assert "Expected" in str(error_context.exception) + assert "scheduler" in str(error_context.exception) + assert "EulerDiscreteScheduler" in str(error_context.exception) + + def test_wrong_model_scheduler_enum(self): + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("hf-internal-testing/diffusers-stable-diffusion-tiny-all", subfolder="scheduler") + with self.assertRaises(ValueError) as error_context: + _ = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/diffusers-stable-diffusion-tiny-all", scheduler=scheduler + ) + + assert "Expected" in str(error_context.exception) + assert "scheduler" in str(error_context.exception) + assert "FlowMatchEulerDiscreteScheduler" in str(error_context.exception) + @slow @require_torch_gpu From 56790675d53d58eeca111a1325d8b803f614c8e2 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 11 Dec 2024 13:37:13 +0000 Subject: [PATCH 17/23] make --- tests/pipelines/test_pipelines.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 3567f4cfc055..2b03c61f08bb 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -1818,16 +1818,16 @@ def test_wrong_model(self): def test_wrong_model_scheduler_type(self): scheduler = EulerDiscreteScheduler.from_pretrained("hf-internal-testing/tiny-flux-pipe", subfolder="scheduler") with self.assertRaises(ValueError) as error_context: - _ = FluxPipeline.from_pretrained( - "hf-internal-testing/tiny-flux-pipe", scheduler=scheduler - ) + _ = FluxPipeline.from_pretrained("hf-internal-testing/tiny-flux-pipe", scheduler=scheduler) assert "Expected" in str(error_context.exception) assert "scheduler" in str(error_context.exception) assert "EulerDiscreteScheduler" in str(error_context.exception) def test_wrong_model_scheduler_enum(self): - scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("hf-internal-testing/diffusers-stable-diffusion-tiny-all", subfolder="scheduler") + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + "hf-internal-testing/diffusers-stable-diffusion-tiny-all", subfolder="scheduler" + ) with self.assertRaises(ValueError) as error_context: _ = StableDiffusionPipeline.from_pretrained( "hf-internal-testing/diffusers-stable-diffusion-tiny-all", scheduler=scheduler From 24d79a38daa7fb38965ab61bc0577ea6808ca5a5 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 13 Dec 2024 15:56:45 +0000 Subject: [PATCH 18/23] update --- src/diffusers/pipelines/pipeline_utils.py | 43 ++++++++++++----------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 5fe0a33dcdcf..d978046421fc 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -46,7 +46,7 @@ from ..models.attention_processor import FusedAttnProcessor2_0 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin from ..quantizers.bitsandbytes.utils import _check_bnb_status -from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME +from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin from ..utils import ( CONFIG_NAME, DEPRECATED_REVISION_ARGS, @@ -834,28 +834,31 @@ def load_module(name, value): return True init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} - scheduler_types = None - if "scheduler" in expected_types: - scheduler_types = [] - for scheduler_type in expected_types["scheduler"]: - if isinstance(scheduler_type, enum.EnumMeta): - scheduler_types.extend(list(scheduler_type)) - else: - scheduler_types.extend([str(scheduler_type)]) - scheduler_types = [str(scheduler).split(".")[-1].strip("'>") for scheduler in scheduler_types] - for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()): + for key in init_dict.keys(): if key not in passed_class_obj: continue - class_name = passed_class_obj[key].__class__.__name__ - class_name = class_name[4:] if class_name.startswith("Flax") else class_name - expected_class_name = ( - expected_class_name[4:] if expected_class_name.startswith("Flax") else expected_class_name - ) - if key == "scheduler" and scheduler_types is not None and class_name not in scheduler_types: - raise ValueError(f"Expected {scheduler_types} for {key}, got {class_name}.") - elif key != "scheduler" and class_name != expected_class_name: - raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.") + + class_obj = passed_class_obj[key] + _expected_class_types = [] + for expected_type in expected_types[key]: + if isinstance(expected_type, enum.EnumMeta): + _expected_class_types.extend(expected_type.__members__.keys()) + else: + _expected_class_types.append(expected_type.__name__) + + _is_valid_type = class_obj.__class__.__name__ in _expected_class_types + if isinstance(class_obj, SchedulerMixin) and not _is_valid_type: + _requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types) + _is_flow_match = "FlowMatch" in class_obj.__class__.__name__ + if _requires_flow_match and not _is_flow_match: + raise ValueError(f"Expected FlowMatch scheduler, got {class_obj.__class__.__name__}.") + elif not _requires_flow_match and _is_flow_match: + raise ValueError(f"Expected non-FlowMatch scheduler, got {class_obj.__class__.__name__}.") + elif not _is_valid_type: + raise ValueError( + f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}." + ) # Special case: safety_checker must be loaded separately when using `from_flax` if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: From 3f841d53c90df36dac203e982df9952d185275be Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 13 Dec 2024 21:31:08 +0000 Subject: [PATCH 19/23] warning --- src/diffusers/pipelines/pipeline_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index d978046421fc..2a19feb4fdec 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -852,11 +852,11 @@ def load_module(name, value): _requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types) _is_flow_match = "FlowMatch" in class_obj.__class__.__name__ if _requires_flow_match and not _is_flow_match: - raise ValueError(f"Expected FlowMatch scheduler, got {class_obj.__class__.__name__}.") + logger.warning(f"Expected FlowMatch scheduler, got {class_obj.__class__.__name__}.") elif not _requires_flow_match and _is_flow_match: - raise ValueError(f"Expected non-FlowMatch scheduler, got {class_obj.__class__.__name__}.") + logger.warning(f"Expected non-FlowMatch scheduler, got {class_obj.__class__.__name__}.") elif not _is_valid_type: - raise ValueError( + logger.warning( f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}." ) From f18687fdb41e87e2b8f7e5b4a491c41afeec8588 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 13 Dec 2024 22:02:07 +0000 Subject: [PATCH 20/23] tests --- tests/pipelines/test_pipelines.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 2b03c61f08bb..40b55ee389f7 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -1811,31 +1811,34 @@ def test_wrong_model(self): "hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer ) - assert "Expected" in str(error_context.exception) - assert "text_encoder" in str(error_context.exception) - assert "CLIPTokenizer" in str(error_context.exception) + assert "is of type" in str(error_context.exception) + assert "but should be" in str(error_context.exception) def test_wrong_model_scheduler_type(self): scheduler = EulerDiscreteScheduler.from_pretrained("hf-internal-testing/tiny-flux-pipe", subfolder="scheduler") - with self.assertRaises(ValueError) as error_context: + with self.assertLogs( + logging.get_logger("diffusers.pipelines.pipeline_utils"), level="WARNING" + ) as warning_context: _ = FluxPipeline.from_pretrained("hf-internal-testing/tiny-flux-pipe", scheduler=scheduler) - assert "Expected" in str(error_context.exception) - assert "scheduler" in str(error_context.exception) - assert "EulerDiscreteScheduler" in str(error_context.exception) + assert any("Expected" in message for message in warning_context.output) + assert any("scheduler" in message for message in warning_context.output) + assert any("EulerDiscreteScheduler" in message for message in warning_context.output) def test_wrong_model_scheduler_enum(self): scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( "hf-internal-testing/diffusers-stable-diffusion-tiny-all", subfolder="scheduler" ) - with self.assertRaises(ValueError) as error_context: + with self.assertLogs( + logging.get_logger("diffusers.pipelines.pipeline_utils"), level="WARNING" + ) as warning_context: _ = StableDiffusionPipeline.from_pretrained( "hf-internal-testing/diffusers-stable-diffusion-tiny-all", scheduler=scheduler ) - assert "Expected" in str(error_context.exception) - assert "scheduler" in str(error_context.exception) - assert "FlowMatchEulerDiscreteScheduler" in str(error_context.exception) + assert any("Expected" in message for message in warning_context.output) + assert any("scheduler" in message for message in warning_context.output) + assert any("FlowMatchEulerDiscreteScheduler" in message for message in warning_context.output) @slow From 87f8f0348c2827792469148fd5127d2483f14c47 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 16 Dec 2024 12:07:07 +0000 Subject: [PATCH 21/23] Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Dhruv Nair --- src/diffusers/pipelines/pipeline_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 2a19feb4fdec..475c99394a34 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -848,7 +848,7 @@ def load_module(name, value): _expected_class_types.append(expected_type.__name__) _is_valid_type = class_obj.__class__.__name__ in _expected_class_types - if isinstance(class_obj, SchedulerMixin) and not _is_valid_type: + if (isinstance(class_obj, SchedulerMixin) or isinstance(class_obj, FlaxSchedulerMixin)) and not _is_valid_type: _requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types) _is_flow_match = "FlowMatch" in class_obj.__class__.__name__ if _requires_flow_match and not _is_flow_match: From 56ac8b4bd74d9e3366da78d315accf2f0b83707c Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 16 Dec 2024 12:09:21 +0000 Subject: [PATCH 22/23] import FlaxSchedulerMixin --- src/diffusers/pipelines/pipeline_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 475c99394a34..164a4cdddf4b 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -47,6 +47,7 @@ from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin from ..quantizers.bitsandbytes.utils import _check_bnb_status from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin +from ..schedulers.scheduling_utils_flax import FlaxSchedulerMixin from ..utils import ( CONFIG_NAME, DEPRECATED_REVISION_ARGS, @@ -848,7 +849,9 @@ def load_module(name, value): _expected_class_types.append(expected_type.__name__) _is_valid_type = class_obj.__class__.__name__ in _expected_class_types - if (isinstance(class_obj, SchedulerMixin) or isinstance(class_obj, FlaxSchedulerMixin)) and not _is_valid_type: + if ( + isinstance(class_obj, SchedulerMixin) or isinstance(class_obj, FlaxSchedulerMixin) + ) and not _is_valid_type: _requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types) _is_flow_match = "FlowMatch" in class_obj.__class__.__name__ if _requires_flow_match and not _is_flow_match: From 87dcf54280a29fe63e89bcf3a6e8d00156cc2a73 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 16 Dec 2024 22:03:23 +0000 Subject: [PATCH 23/23] skip scheduler --- .../kandinsky2_2/pipeline_kandinsky2_2.py | 4 +-- .../pipeline_kandinsky2_2_combined.py | 4 +-- .../pipeline_kandinsky2_2_prior.py | 4 +-- src/diffusers/pipelines/pipeline_utils.py | 16 +++-------- src/diffusers/schedulers/scheduling_utils.py | 1 - .../test_stable_diffusion_latent_upscale.py | 1 - tests/pipelines/test_pipelines.py | 28 ------------------- 7 files changed, 10 insertions(+), 48 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py index c912344bd609..471db61556f5 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py @@ -17,7 +17,7 @@ import torch from ...models import UNet2DConditionModel, VQModel -from ...schedulers import DDPMScheduler, UnCLIPScheduler +from ...schedulers import DDPMScheduler from ...utils import deprecate, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -83,7 +83,7 @@ class KandinskyV22Pipeline(DiffusionPipeline): def __init__( self, unet: UNet2DConditionModel, - scheduler: Union[DDPMScheduler, UnCLIPScheduler], + scheduler: DDPMScheduler, movq: VQModel, ): super().__init__() diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py index 9f054bd51198..68334fef3811 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py @@ -140,13 +140,13 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline): def __init__( self, unet: UNet2DConditionModel, - scheduler: Union[DDPMScheduler, UnCLIPScheduler], + scheduler: DDPMScheduler, movq: VQModel, prior_prior: PriorTransformer, prior_image_encoder: CLIPVisionModelWithProjection, prior_text_encoder: CLIPTextModelWithProjection, prior_tokenizer: CLIPTokenizer, - prior_scheduler: Union[DDPMScheduler, UnCLIPScheduler], + prior_scheduler: UnCLIPScheduler, prior_image_processor: CLIPImageProcessor, ): super().__init__() diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py index ec1057500e21..f2134b22b40b 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py @@ -5,7 +5,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection from ...models import PriorTransformer -from ...schedulers import DDPMScheduler, UnCLIPScheduler +from ...schedulers import UnCLIPScheduler from ...utils import ( logging, replace_example_docstring, @@ -114,7 +114,7 @@ def __init__( image_encoder: CLIPVisionModelWithProjection, text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, - scheduler: Union[DDPMScheduler, UnCLIPScheduler], + scheduler: UnCLIPScheduler, image_processor: CLIPImageProcessor, ): super().__init__() diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 164a4cdddf4b..c505c5a262a3 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -46,8 +46,7 @@ from ..models.attention_processor import FusedAttnProcessor2_0 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin from ..quantizers.bitsandbytes.utils import _check_bnb_status -from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin -from ..schedulers.scheduling_utils_flax import FlaxSchedulerMixin +from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..utils import ( CONFIG_NAME, DEPRECATED_REVISION_ARGS, @@ -839,6 +838,8 @@ def load_module(name, value): for key in init_dict.keys(): if key not in passed_class_obj: continue + if "scheduler" in key: + continue class_obj = passed_class_obj[key] _expected_class_types = [] @@ -849,16 +850,7 @@ def load_module(name, value): _expected_class_types.append(expected_type.__name__) _is_valid_type = class_obj.__class__.__name__ in _expected_class_types - if ( - isinstance(class_obj, SchedulerMixin) or isinstance(class_obj, FlaxSchedulerMixin) - ) and not _is_valid_type: - _requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types) - _is_flow_match = "FlowMatch" in class_obj.__class__.__name__ - if _requires_flow_match and not _is_flow_match: - logger.warning(f"Expected FlowMatch scheduler, got {class_obj.__class__.__name__}.") - elif not _requires_flow_match and _is_flow_match: - logger.warning(f"Expected non-FlowMatch scheduler, got {class_obj.__class__.__name__}.") - elif not _is_valid_type: + if not _is_valid_type: logger.warning( f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}." ) diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index ce6d04a1add2..f20224b19009 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -46,7 +46,6 @@ class KarrasDiffusionSchedulers(Enum): UniPCMultistepScheduler = 13 DPMSolverSDEScheduler = 14 EDMEulerScheduler = 15 - LCMScheduler = 16 AysSchedules = { diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py index 67b9d4a4dfbf..134175bdaffe 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py @@ -250,7 +250,6 @@ def test_karras_schedulers_shape(self): "KDPM2AncestralDiscreteScheduler", "DPMSolverSDEScheduler", "EDMEulerScheduler", - "LCMScheduler", ] components = self.get_dummy_components() pipe = self.pipeline_class(**components) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 40b55ee389f7..423c82e0602e 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -47,8 +47,6 @@ DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, - FlowMatchEulerDiscreteScheduler, - FluxPipeline, LMSDiscreteScheduler, ModelMixin, PNDMScheduler, @@ -1814,32 +1812,6 @@ def test_wrong_model(self): assert "is of type" in str(error_context.exception) assert "but should be" in str(error_context.exception) - def test_wrong_model_scheduler_type(self): - scheduler = EulerDiscreteScheduler.from_pretrained("hf-internal-testing/tiny-flux-pipe", subfolder="scheduler") - with self.assertLogs( - logging.get_logger("diffusers.pipelines.pipeline_utils"), level="WARNING" - ) as warning_context: - _ = FluxPipeline.from_pretrained("hf-internal-testing/tiny-flux-pipe", scheduler=scheduler) - - assert any("Expected" in message for message in warning_context.output) - assert any("scheduler" in message for message in warning_context.output) - assert any("EulerDiscreteScheduler" in message for message in warning_context.output) - - def test_wrong_model_scheduler_enum(self): - scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( - "hf-internal-testing/diffusers-stable-diffusion-tiny-all", subfolder="scheduler" - ) - with self.assertLogs( - logging.get_logger("diffusers.pipelines.pipeline_utils"), level="WARNING" - ) as warning_context: - _ = StableDiffusionPipeline.from_pretrained( - "hf-internal-testing/diffusers-stable-diffusion-tiny-all", scheduler=scheduler - ) - - assert any("Expected" in message for message in warning_context.output) - assert any("scheduler" in message for message in warning_context.output) - assert any("FlowMatchEulerDiscreteScheduler" in message for message in warning_context.output) - @slow @require_torch_gpu