Commit 1908c47

Deprecate upcast_vae in SDXL based pipelines (#12619)
* update
* update
* Revert "update" (this reverts commit 7390638)
* Revert "update" (this reverts commit 21a03f9)
* update
* update
* update
* update
* update
1 parent 759ea58 commit 1908c47
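Every file in the diff gets the same treatment: the pipeline's `upcast_vae` helper now emits a deprecation warning and casts the whole VAE to float32, and the branch that downcast parts of the decoder when a torch 2.0 or xformers attention processor was active is deleted. A minimal migration sketch (the checkpoint id and fp16 loading are illustrative, not part of the commit):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Deprecated by this commit (still works, warns until removal in 1.0.0):
# pipe.upcast_vae()

# Recommended replacement: cast the VAE module directly.
pipe.vae.to(torch.float32)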

File tree

43 files changed: +170 additions, -771 deletions


examples/community/lpw_stable_diffusion_xl.py

Lines changed: 1 addition & 12 deletions
@@ -29,7 +29,6 @@
     TextualInversionLoaderMixin,
 )
 from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -1328,18 +1327,8 @@ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, d

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
     def upcast_vae(self):
-        dtype = self.vae.dtype
+        deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
         self.vae.to(dtype=torch.float32)
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (AttnProcessor2_0, XFormersAttnProcessor),
-        )
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(dtype)
-            self.vae.decoder.conv_in.to(dtype)
-            self.vae.decoder.mid_block.to(dtype)

     # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
     def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
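The deprecated path keeps working: it still moves the VAE to float32, but first routes through diffusers' `deprecate` utility (which, as of recent diffusers releases, emits a FutureWarning by default). A quick sanity check, assuming `pipe` was loaded as in the sketch above:

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    pipe.upcast_vae()  # behavior unchanged: the VAE ends up in float32

assert pipe.vae.dtype == torch.float32
# The last captured warning should name `upcast_vae` and the 1.0.0 removal
# target (other warnings may also be captured, so inspect rather than index
# blindly):
print(caught[-1].message)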

examples/community/mixture_tiling_sdxl.py

Lines changed: 2 additions & 20 deletions
@@ -30,17 +30,13 @@
     TextualInversionLoaderMixin,
 )
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    FusedAttnProcessor2_0,
-    XFormersAttnProcessor,
-)
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
 from diffusers.utils import (
     USE_PEFT_BACKEND,
+    deprecate,
     is_invisible_watermark_available,
     is_torch_xla_available,
     logging,
@@ -710,22 +706,8 @@ def _gaussian_weights(self, tile_width, tile_height, nbatches, device, dtype):
         return torch.tile(weights_torch, (nbatches, self.unet.config.in_channels, 1, 1))

     def upcast_vae(self):
-        dtype = self.vae.dtype
+        deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
         self.vae.to(dtype=torch.float32)
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-                FusedAttnProcessor2_0,
-            ),
-        )
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(dtype)
-            self.vae.decoder.conv_in.to(dtype)
-            self.vae.decoder.mid_block.to(dtype)

     # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
     def get_guidance_scale_embedding(
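Note the behavioral change hiding in the deletion: the old helper kept `post_quant_conv`, `decoder.conv_in`, and `decoder.mid_block` in the original (typically fp16) dtype whenever a torch 2.0 or xformers attention processor was detected, saving memory during decoding. Callers who relied on that can reproduce it by hand; a sketch mirroring the deleted lines, assuming the standard diffusers AutoencoderKL layout:

dtype = pipe.vae.dtype               # e.g. torch.float16 before upcasting
pipe.vae.to(torch.float32)           # the new recommended call
# Re-downcast the submodules the old fast path left in half precision:
pipe.vae.post_quant_conv.to(dtype)
pipe.vae.decoder.conv_in.to(dtype)
pipe.vae.decoder.mid_block.to(dtype)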

examples/community/mod_controlnet_tile_sr_sdxl.py

Lines changed: 2 additions & 19 deletions
@@ -39,16 +39,13 @@
     MultiControlNetModel,
     UNet2DConditionModel,
 )
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    XFormersAttnProcessor,
-)
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
 from diffusers.utils import (
     USE_PEFT_BACKEND,
+    deprecate,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -1220,23 +1217,9 @@ def prepare_tiles(

         return tile_weights, tile_row_overlaps, tile_col_overlaps

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
     def upcast_vae(self):
-        dtype = self.vae.dtype
+        deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
         self.vae.to(dtype=torch.float32)
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-            ),
-        )
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(dtype)
-            self.vae.decoder.conv_in.to(dtype)
-            self.vae.decoder.mid_block.to(dtype)

     @property
     def guidance_scale(self):

examples/community/pipeline_controlnet_xl_kolors.py

Lines changed: 1 addition & 18 deletions
@@ -40,10 +40,6 @@
     MultiControlNetModel,
     UNet2DConditionModel,
 )
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    XFormersAttnProcessor,
-)
 from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -760,21 +756,8 @@ def _get_add_time_ids(

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
     def upcast_vae(self):
-        dtype = self.vae.dtype
+        deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
         self.vae.to(dtype=torch.float32)
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-            ),
-        )
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(dtype)
-            self.vae.decoder.conv_in.to(dtype)
-            self.vae.decoder.mid_block.to(dtype)

     @property
     def guidance_scale(self):

examples/community/pipeline_controlnet_xl_kolors_img2img.py

Lines changed: 1 addition & 18 deletions
@@ -40,10 +40,6 @@
     MultiControlNetModel,
     UNet2DConditionModel,
 )
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    XFormersAttnProcessor,
-)
 from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -930,21 +926,8 @@ def _get_add_time_ids(

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
     def upcast_vae(self):
-        dtype = self.vae.dtype
+        deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
         self.vae.to(dtype=torch.float32)
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-            ),
-        )
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(dtype)
-            self.vae.decoder.conv_in.to(dtype)
-            self.vae.decoder.mid_block.to(dtype)

     @property
     def guidance_scale(self):

examples/community/pipeline_controlnet_xl_kolors_inpaint.py

Lines changed: 1 addition & 18 deletions
@@ -39,10 +39,6 @@
     MultiControlNetModel,
     UNet2DConditionModel,
 )
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    XFormersAttnProcessor,
-)
 from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -1006,21 +1002,8 @@ def _get_add_time_ids(

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
     def upcast_vae(self):
-        dtype = self.vae.dtype
+        deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
         self.vae.to(dtype=torch.float32)
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-            ),
-        )
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(dtype)
-            self.vae.decoder.conv_in.to(dtype)
-            self.vae.decoder.mid_block.to(dtype)

     @property
     def denoising_end(self):

examples/community/pipeline_demofusion_sdxl.py

Lines changed: 2 additions & 13 deletions
@@ -16,11 +16,11 @@
     TextualInversionLoaderMixin,
 )
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import (
+    deprecate,
     is_accelerate_available,
     is_accelerate_version,
     is_invisible_watermark_available,
@@ -612,20 +612,9 @@ def tiled_decode(self, latents, current_height, current_width):

         return image

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
     def upcast_vae(self):
-        dtype = self.vae.dtype
+        deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
         self.vae.to(dtype=torch.float32)
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (AttnProcessor2_0, XFormersAttnProcessor),
-        )
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(dtype)
-            self.vae.decoder.conv_in.to(dtype)
-            self.vae.decoder.mid_block.to(dtype)

     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)

examples/community/pipeline_faithdiff_stable_diffusion_xl.py

Lines changed: 1 addition & 24 deletions
@@ -40,13 +40,6 @@
     UNet2DConditionLoadersMixin,
 )
 from diffusers.models import AutoencoderKL
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    FusedAttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
@@ -1642,24 +1635,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
         return latents

     def upcast_vae(self):
-        dtype = self.vae.dtype
+        deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
         self.vae.to(dtype=torch.float32)
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
-                FusedAttnProcessor2_0,
-            ),
-        )
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(dtype)
-            self.vae.decoder.conv_in.to(dtype)
-            self.vae.decoder.mid_block.to(dtype)

     # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
     def get_guidance_scale_embedding(

examples/community/pipeline_kolors_differential_img2img.py

Lines changed: 2 additions & 18 deletions
@@ -22,13 +22,12 @@
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin
 from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from diffusers.models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
 from diffusers.pipelines.kolors.pipeline_output import KolorsPipelineOutput
 from diffusers.pipelines.kolors.text_encoder import ChatGLMModel
 from diffusers.pipelines.kolors.tokenizer import ChatGLMTokenizer
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
+from diffusers.utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
 from diffusers.utils.torch_utils import randn_tensor


@@ -709,24 +708,9 @@ def _get_add_time_ids(
         add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
         return add_time_ids

-    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
     def upcast_vae(self):
-        dtype = self.vae.dtype
+        deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
         self.vae.to(dtype=torch.float32)
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-                FusedAttnProcessor2_0,
-            ),
-        )
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(dtype)
-            self.vae.decoder.conv_in.to(dtype)
-            self.vae.decoder.mid_block.to(dtype)

     # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
     def get_guidance_scale_embedding(

examples/community/pipeline_kolors_inpainting.py

Lines changed: 1 addition & 22 deletions
@@ -32,12 +32,6 @@
     TextualInversionLoaderMixin,
 )
 from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
 from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -1008,23 +1002,8 @@ def _get_add_time_ids(

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
     def upcast_vae(self):
-        dtype = self.vae.dtype
+        deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
         self.vae.to(dtype=torch.float32)
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
-            ),
-        )
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(dtype)
-            self.vae.decoder.conv_in.to(dtype)
-            self.vae.decoder.mid_block.to(dtype)

     # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
     def get_guidance_scale_embedding(
