From 199c240a3444016a36ff06da3a9b490994bb9320 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 21 Oct 2024 03:46:22 +0200 Subject: [PATCH 01/33] update --- src/diffusers/__init__.py | 6 + src/diffusers/models/__init__.py | 4 + src/diffusers/models/autoencoders/__init__.py | 1 + .../autoencoders/autoencoder_kl_allegro.py | 995 +++++++++++ src/diffusers/models/embeddings.py | 52 + src/diffusers/models/normalization.py | 36 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_allegro.py | 1586 +++++++++++++++++ src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/allegro/__init__.py | 48 + .../pipelines/allegro/pipeline_allegro.py | 829 +++++++++ .../pipelines/allegro/pipeline_output.py | 23 + 12 files changed, 3583 insertions(+) create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_allegro.py create mode 100644 src/diffusers/models/transformers/transformer_allegro.py create mode 100644 src/diffusers/pipelines/allegro/__init__.py create mode 100644 src/diffusers/pipelines/allegro/pipeline_allegro.py create mode 100644 src/diffusers/pipelines/allegro/pipeline_output.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a1d126f3823b..dab0ee1db1a8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -80,10 +80,12 @@ "AsymmetricAutoencoderKL", "AuraFlowTransformer2DModel", "AutoencoderKL", + "AutoencoderKLAllegro", "AutoencoderKLCogVideoX", "AutoencoderKLTemporalDecoder", "AutoencoderOobleck", "AutoencoderTiny", + "AllegroTransformer3DModel", "CogVideoXTransformer3DModel", "CogView3PlusTransformer2DModel", "ConsistencyDecoderVAE", @@ -237,6 +239,7 @@ else: _import_structure["pipelines"].extend( [ + "AllegroPipeline", "AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline", "AmusedImg2ImgPipeline", @@ -558,7 +561,9 @@ from .models import ( AsymmetricAutoencoderKL, AuraFlowTransformer2DModel, + AllegroTransformer3DModel, AutoencoderKL, + AutoencoderKLAllegro, AutoencoderKLCogVideoX, AutoencoderKLTemporalDecoder, AutoencoderOobleck, @@ -697,6 +702,7 @@ from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipelines import ( + AllegroPipeline, AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, AmusedImg2ImgPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 4dda8c36ba1c..310c35c4cb72 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -28,6 +28,7 @@ _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] + _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"] _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] @@ -54,6 +55,7 @@ _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"] _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] + _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] @@ -81,6 +83,7 @@ from .autoencoders import ( AsymmetricAutoencoderKL, AutoencoderKL, + AutoencoderKLAllegro, AutoencoderKLCogVideoX, AutoencoderKLTemporalDecoder, AutoencoderOobleck, @@ -98,6 +101,7 @@ from .modeling_utils import ModelMixin from .transformers import ( AuraFlowTransformer2DModel, + AllegroTransformer3DModel, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, DiTTransformer2DModel, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index ccf4552b2a5e..9628fe7f21b0 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -1,5 +1,6 @@ from .autoencoder_asym_kl import AsymmetricAutoencoderKL from .autoencoder_kl import AutoencoderKL +from .autoencoder_kl_allegro import AutoencoderKLAllegro from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_oobleck import AutoencoderOobleck diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py new file mode 100644 index 000000000000..2ec0855635b4 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -0,0 +1,995 @@ +# Copyright 2024 The RhymesAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +import os +from typing import Dict, Optional, Tuple, Union +from einops import rearrange + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..modeling_outputs import AutoencoderKLOutput +from ..autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution +from ..attention_processor import Attention +from ..resnet import ResnetBlock2D +from ..upsampling import Upsample2D +from ..downsampling import Downsample2D +from ..attention_processor import SpatialNorm + + +class TemporalConvBlock(nn.Module): + """ + Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: + https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 + """ + + def __init__(self, in_dim, out_dim=None, dropout=0.0, up_sample=False, down_sample=False, spa_stride=1): + super().__init__() + out_dim = out_dim or in_dim + self.in_dim = in_dim + self.out_dim = out_dim + spa_pad = int((spa_stride-1)*0.5) + temp_pad = 0 + self.temp_pad = temp_pad + + if down_sample: + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (2, spa_stride, spa_stride), stride=(2,1,1), padding=(0, spa_pad, spa_pad)) + ) + elif up_sample: + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim*2, (1, spa_stride, spa_stride), padding=(0, spa_pad, spa_pad)) + ) + else: + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)) + ) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)), + ) + self.conv3 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)), + ) + self.conv4 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)), + ) + + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv4[-1].weight) + nn.init.zeros_(self.conv4[-1].bias) + + self.down_sample = down_sample + self.up_sample = up_sample + + + def forward(self, hidden_states): + identity = hidden_states + + if self.down_sample: + identity = identity[:,:,::2] + elif self.up_sample: + hidden_states_new = torch.cat((hidden_states,hidden_states),dim=2) + hidden_states_new[:, :, 0::2] = hidden_states + hidden_states_new[:, :, 1::2] = hidden_states + identity = hidden_states_new + del hidden_states_new + + if self.down_sample or self.up_sample: + hidden_states = self.conv1(hidden_states) + else: + hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2) + hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2) + hidden_states = self.conv1(hidden_states) + + + if self.up_sample: + hidden_states = rearrange(hidden_states, 'b (d c) f h w -> b c (f d) h w', d=2) + + hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2) + hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2) + hidden_states = self.conv2(hidden_states) + hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2) + hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2) + hidden_states = self.conv3(hidden_states) + hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2) + hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2) + hidden_states = self.conv4(hidden_states) + + hidden_states = identity + hidden_states + + return hidden_states + + +class AllegroDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + add_temp_downsample=False, + downsample_padding=1, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvBlock( + out_channels, + out_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_temp_downsample: + self.temp_convs_down = TemporalConvBlock( + out_channels, + out_channels, + dropout=0.1, + down_sample=True, + spa_stride=3 + ) + self.add_temp_downsample = add_temp_downsample + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def _set_partial_grad(self): + for temp_conv in self.temp_convs: + temp_conv.requires_grad_(True) + if self.downsamplers: + for down_layer in self.downsamplers: + down_layer.requires_grad_(True) + + def forward(self, hidden_states): + bz = hidden_states.shape[0] + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') + hidden_states = resnet(hidden_states, temb=None) + hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + hidden_states = temp_conv(hidden_states) + if self.add_temp_downsample: + hidden_states = self.temp_convs_down(hidden_states) + + if self.downsamplers is not None: + hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') + for upsampler in self.downsamplers: + hidden_states = upsampler(hidden_states) + hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + return hidden_states + + +class AllegroUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + add_temp_upsample=False, + temb_channels=None, + ): + super().__init__() + self.add_upsample = add_upsample + + resnets = [] + temp_convs = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvBlock( + out_channels, + out_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + self.add_temp_upsample = add_temp_upsample + if add_temp_upsample: + self.temp_conv_up = TemporalConvBlock( + out_channels, + out_channels, + dropout=0.1, + up_sample=True, + spa_stride=3 + ) + + + if self.add_upsample: + # self.upsamplers = nn.ModuleList([PSUpsample2D(out_channels, use_conv=True, use_pixel_shuffle=True, out_channels=out_channels)]) + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def _set_partial_grad(self): + for temp_conv in self.temp_convs: + temp_conv.requires_grad_(True) + if self.add_upsample: + self.upsamplers.requires_grad_(True) + + def forward(self, hidden_states): + bz = hidden_states.shape[0] + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') + hidden_states = resnet(hidden_states, temb=None) + hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + hidden_states = temp_conv(hidden_states) + if self.add_temp_upsample: + hidden_states = self.temp_conv_up(hidden_states) + + if self.upsamplers is not None: + hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + return hidden_states + + +class UNetMidBlock3DConv(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + TemporalConvBlock( + in_channels, + in_channels, + dropout=0.1, + ) + ] + attentions = [] + + if attention_head_dim is None: + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + temp_convs.append( + TemporalConvBlock( + in_channels, + in_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + + def _set_partial_grad(self): + for temp_conv in self.temp_convs: + temp_conv.requires_grad_(True) + + def forward( + self, + hidden_states, + ): + bz = hidden_states.shape[0] + hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') + + hidden_states = self.resnets[0](hidden_states, temb=None) + hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + hidden_states = self.temp_convs[0](hidden_states) + hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') + + for attn, resnet, temp_conv in zip( + self.attentions, self.resnets[1:], self.temp_convs[1:] + ): + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states, temb=None) + hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + hidden_states = temp_conv(hidden_states) + return hidden_states + + +class AllegroEncoder3D(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D"), + blocks_temp_li=[False, False, False, False], + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + ): + super().__init__() + + self.layers_per_block = layers_per_block + self.blocks_temp_li = blocks_temp_li + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + ) + + self.temp_conv_in = nn.Conv3d( + in_channels=block_out_channels[0], + out_channels=block_out_channels[0], + kernel_size=(3, 1, 1), + padding=(1, 0, 0) + ) + + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "AllegroDownBlock3D": + down_block = AllegroDownBlock3D( + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + add_temp_downsample=blocks_temp_li[i], + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + ) + else: + raise ValueError("Invalid `down_block_type` encountered. Must be `AllegroDownBlock3D`") + + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock3DConv( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + + self.temp_conv_out = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3,1,1), padding = (1, 0, 0)) + + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x): + ''' + x: [b, c, (tb f), h, w] + ''' + bz = x.shape[0] + sample = rearrange(x, 'b c n h w -> (b n) c h w') + sample = self.conv_in(sample) + + sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) + temp_sample = sample + sample = self.temp_conv_in(sample) + sample = sample+temp_sample + # down + for b_id, down_block in enumerate(self.down_blocks): + sample = down_block(sample) + # middle + sample = self.mid_block(sample) + + # post-process + sample = rearrange(sample, 'b c n h w -> (b n) c h w') + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) + + temp_sample = sample + sample = self.temp_conv_out(sample) + sample = sample+temp_sample + sample = rearrange(sample, 'b c n h w -> (b n) c h w') + + sample = self.conv_out(sample) + sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) + return sample + +class AllegroDecoder3D(nn.Module): + def __init__( + self, + in_channels: int = 4, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D"), + blocks_temp_li=[False, False, False, False], + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + ): + super().__init__() + self.layers_per_block = layers_per_block + self.blocks_temp_li = blocks_temp_li + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.temp_conv_in = nn.Conv3d( + block_out_channels[-1], + block_out_channels[-1], + (3,1,1), + padding = (1, 0, 0) + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock3DConv( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + if up_block_type == "AllegroUpBlock3D": + up_block = AllegroUpBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block, + add_temp_upsample=blocks_temp_li[i], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + else: + raise ValueError("Invalid `UP_block_type` encountered. Must be `AllegroUpBlock3D`") + + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3,1,1), padding = (1, 0, 0)) + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, z): + bz = z.shape[0] + sample = rearrange(z, 'b c n h w -> (b n) c h w') + sample = self.conv_in(sample) + + sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) + temp_sample = sample + sample = self.temp_conv_in(sample) + sample = sample+temp_sample + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + # middle + sample = self.mid_block(sample) + sample = sample.to(upscale_dtype) + + # up + for b_id, up_block in enumerate(self.up_blocks): + sample = up_block(sample) + + # post-process + sample = rearrange(sample, 'b c n h w -> (b n) c h w') + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) + temp_sample = sample + sample = self.temp_conv_out(sample) + sample = sample+temp_sample + sample = rearrange(sample, 'b c n h w -> (b n) c h w') + + sample = self.conv_out(sample) + sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) + return sample + + +class AutoencoderKLAllegro(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `256`): Spatial Tiling Size. + tile_overlap (`tuple`, *optional*, defaults to `(120, 80`): Spatial overlapping size while tiling (height, width) + chunk_len (`int`, *optional*, defaults to `24`): Temporal Tiling Size. + t_over (`int`, *optional*, defaults to `8`): Temporal overlapping size while tiling + scaling_factor (`float`, *optional*, defaults to 0.13235): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + blocks_tempdown_li (`List`, *optional*, defaults to `[True, True, False, False]`): Each item indicates whether each TemporalBlock in the Encoder performs temporal downsampling. + blocks_tempup_li (`List`, *optional*, defaults to `[False, True, True, False]`): Each item indicates whether each TemporalBlock in the Decoder performs temporal upsampling. + load_mode (`str`, *optional*, defaults to `full`): Load mode for the model. Can be one of `full`, `encoder_only`, `decoder_only`. which corresponds to loading the full model state dicts, only the encoder state dicts, or only the decoder state dicts. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ( + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + ), + up_block_types: Tuple[str, ...] = ( + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + latent_channels: int = 4, + layers_per_block: int = 2, + act_fn: str = "silu", + norm_num_groups: int = 32, + temporal_compression_ratio: float = 4, + sample_size: int = 320, + scaling_factor: float = 0.13235, + force_upcast: bool = True, + tile_overlap: tuple = (120, 80), + chunk_len: int = 24, + t_over: int = 8, + blocks_tempdown_li=[True, True, False, False], + blocks_tempup_li=[False, True, True, False], + ) -> None: + super().__init__() + + self.blocks_tempdown_li = blocks_tempdown_li + self.blocks_tempup_li = blocks_tempup_li + + self.encoder = AllegroEncoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + blocks_temp_li=blocks_tempdown_li, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + self.decoder = AllegroDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + blocks_temp_li=blocks_tempup_li, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + sample_size = ( + sample_size[0] + if isinstance(sample_size, (list, tuple)) + else sample_size + ) + self.tile_overlap = tile_overlap + self.vae_scale_factor=[4, 8, 8] + self.sample_size = sample_size + self.chunk_len = chunk_len + self.t_over = t_over + + self.latent_chunk_len = self.chunk_len//4 + self.latent_t_over = self.t_over//4 + self.kernel = (self.chunk_len, self.sample_size, self.sample_size) #(24, 256, 256) + self.stride = (self.chunk_len - self.t_over, self.sample_size-self.tile_overlap[0], self.sample_size-self.tile_overlap[1]) # (16, 112, 192) + + + def encode(self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch_size=1) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + KERNEL = self.kernel + STRIDE = self.stride + LOCAL_BS = local_batch_size + OUT_C = 8 + + B, C, N, H, W = input_imgs.shape + + + out_n = math.floor((N - KERNEL[0]) / STRIDE[0]) + 1 + out_h = math.floor((H - KERNEL[1]) / STRIDE[1]) + 1 + out_w = math.floor((W - KERNEL[2]) / STRIDE[2]) + 1 + + ## cut video into overlapped small cubes and batch forward + num = 0 + + out_latent = torch.zeros((out_n*out_h*out_w, OUT_C, KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8), device=input_imgs.device, dtype=input_imgs.dtype) + vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_imgs.device, dtype=input_imgs.dtype) + + for i in range(out_n): + for j in range(out_h): + for k in range(out_w): + n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0] + h_start, h_end = j * STRIDE[1], j * STRIDE[1] + KERNEL[1] + w_start, w_end = k * STRIDE[2], k * STRIDE[2] + KERNEL[2] + video_cube = input_imgs[:, :, n_start:n_end, h_start:h_end, w_start:w_end] + vae_batch_input[num%LOCAL_BS] = video_cube + + if num%LOCAL_BS == LOCAL_BS-1 or num == out_n*out_h*out_w-1: + latent = self.encoder(vae_batch_input) + + if num == out_n*out_h*out_w-1 and num%LOCAL_BS != LOCAL_BS-1: + out_latent[num-num%LOCAL_BS:] = latent[:num%LOCAL_BS+1] + else: + out_latent[num-LOCAL_BS+1:num+1] = latent + vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_imgs.device, dtype=input_imgs.dtype) + num+=1 + + ## flatten the batched out latent to videos and supress the overlapped parts + B, C, N, H, W = input_imgs.shape + + out_video_cube = torch.zeros((B, OUT_C, N//4, H//8, W//8), device=input_imgs.device, dtype=input_imgs.dtype) + OUT_KERNEL = KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8 + OUT_STRIDE = STRIDE[0]//4, STRIDE[1]//8, STRIDE[2]//8 + OVERLAP = OUT_KERNEL[0]-OUT_STRIDE[0], OUT_KERNEL[1]-OUT_STRIDE[1], OUT_KERNEL[2]-OUT_STRIDE[2] + + for i in range(out_n): + n_start, n_end = i * OUT_STRIDE[0], i * OUT_STRIDE[0] + OUT_KERNEL[0] + for j in range(out_h): + h_start, h_end = j * OUT_STRIDE[1], j * OUT_STRIDE[1] + OUT_KERNEL[1] + for k in range(out_w): + w_start, w_end = k * OUT_STRIDE[2], k * OUT_STRIDE[2] + OUT_KERNEL[2] + latent_mean_blend = prepare_for_blend((i, out_n, OVERLAP[0]), (j, out_h, OVERLAP[1]), (k, out_w, OVERLAP[2]), out_latent[i*out_h*out_w+j*out_w+k].unsqueeze(0)) + out_video_cube[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean_blend + + ## final conv + out_video_cube = rearrange(out_video_cube, 'b c n h w -> (b n) c h w') + out_video_cube = self.quant_conv(out_video_cube) + out_video_cube = rearrange(out_video_cube, '(b n) c h w -> b c n h w', b=B) + + posterior = DiagonalGaussianDistribution(out_video_cube) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + + def decode(self, input_latents: torch.Tensor, return_dict: bool = True, local_batch_size=1) -> Union[DecoderOutput, torch.Tensor]: + KERNEL = self.kernel + STRIDE = self.stride + + LOCAL_BS = local_batch_size + OUT_C = 3 + IN_KERNEL = KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8 + IN_STRIDE = STRIDE[0]//4, STRIDE[1]//8, STRIDE[2]//8 + + B, C, N, H, W = input_latents.shape + + ## post quant conv (a mapping) + input_latents = rearrange(input_latents, 'b c n h w -> (b n) c h w') + input_latents = self.post_quant_conv(input_latents) + input_latents = rearrange(input_latents, '(b n) c h w -> b c n h w', b=B) + + ## out tensor shape + out_n = math.floor((N - IN_KERNEL[0]) / IN_STRIDE[0]) + 1 + out_h = math.floor((H - IN_KERNEL[1]) / IN_STRIDE[1]) + 1 + out_w = math.floor((W - IN_KERNEL[2]) / IN_STRIDE[2]) + 1 + + ## cut latent into overlapped small cubes and batch forward + num = 0 + decoded_cube = torch.zeros((out_n*out_h*out_w, OUT_C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype) + vae_batch_input = torch.zeros((LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype) + for i in range(out_n): + for j in range(out_h): + for k in range(out_w): + n_start, n_end = i * IN_STRIDE[0], i * IN_STRIDE[0] + IN_KERNEL[0] + h_start, h_end = j * IN_STRIDE[1], j * IN_STRIDE[1] + IN_KERNEL[1] + w_start, w_end = k * IN_STRIDE[2], k * IN_STRIDE[2] + IN_KERNEL[2] + latent_cube = input_latents[:, :, n_start:n_end, h_start:h_end, w_start:w_end] + vae_batch_input[num%LOCAL_BS] = latent_cube + if num%LOCAL_BS == LOCAL_BS-1 or num == out_n*out_h*out_w-1: + + latent = self.decoder(vae_batch_input) + + if num == out_n*out_h*out_w-1 and num%LOCAL_BS != LOCAL_BS-1: + decoded_cube[num-num%LOCAL_BS:] = latent[:num%LOCAL_BS+1] + else: + decoded_cube[num-LOCAL_BS+1:num+1] = latent + vae_batch_input = torch.zeros((LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype) + num+=1 + B, C, N, H, W = input_latents.shape + + out_video = torch.zeros((B, OUT_C, N*4, H*8, W*8), device=input_latents.device, dtype=input_latents.dtype) + OVERLAP = KERNEL[0]-STRIDE[0], KERNEL[1]-STRIDE[1], KERNEL[2]-STRIDE[2] + for i in range(out_n): + n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0] + for j in range(out_h): + h_start, h_end = j * STRIDE[1], j * STRIDE[1] + KERNEL[1] + for k in range(out_w): + w_start, w_end = k * STRIDE[2], k * STRIDE[2] + KERNEL[2] + out_video_blend = prepare_for_blend((i, out_n, OVERLAP[0]), (j, out_h, OVERLAP[1]), (k, out_w, OVERLAP[2]), decoded_cube[i*out_h*out_w+j*out_w+k].unsqueeze(0)) + out_video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend + + out_video = rearrange(out_video, 'b c t h w -> b t c h w').contiguous() + + decoded = out_video + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + encoder_local_batch_size: int = 2, + decoder_local_batch_size: int = 2, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + generator (`torch.Generator`, *optional*): + PyTorch random number generator. + encoder_local_batch_size (`int`, *optional*, defaults to 2): + Local batch size for the encoder's batch inference. + decoder_local_batch_size (`int`, *optional*, defaults to 2): + Local batch size for the decoder's batch inference. + """ + x = sample + posterior = self.encode(x, local_batch_size=encoder_local_batch_size).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, local_batch_size=decoder_local_batch_size).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + kwargs["torch_type"] = torch.float32 + return super().from_pretrained(pretrained_model_name_or_path, **kwargs) + + +def prepare_for_blend(n_param, h_param, w_param, x): + n, n_max, overlap_n = n_param + h, h_max, overlap_h = h_param + w, w_max, overlap_w = w_param + if overlap_n > 0: + if n > 0: # the head overlap part decays from 0 to 1 + x[:,:,0:overlap_n,:,:] = x[:,:,0:overlap_n,:,:] * (torch.arange(0, overlap_n).float().to(x.device) / overlap_n).reshape(overlap_n,1,1) + if n < n_max-1: # the tail overlap part decays from 1 to 0 + x[:,:,-overlap_n:,:,:] = x[:,:,-overlap_n:,:,:] * (1 - torch.arange(0, overlap_n).float().to(x.device) / overlap_n).reshape(overlap_n,1,1) + if h > 0: + x[:,:,:,0:overlap_h,:] = x[:,:,:,0:overlap_h,:] * (torch.arange(0, overlap_h).float().to(x.device) / overlap_h).reshape(overlap_h,1) + if h < h_max-1: + x[:,:,:,-overlap_h:,:] = x[:,:,:,-overlap_h:,:] * (1 - torch.arange(0, overlap_h).float().to(x.device) / overlap_h).reshape(overlap_h,1) + if w > 0: + x[:,:,:,:,0:overlap_w] = x[:,:,:,:,0:overlap_w] * (torch.arange(0, overlap_w).float().to(x.device) / overlap_w) + if w < w_max-1: + x[:,:,:,:,-overlap_w:] = x[:,:,:,:,-overlap_w:] * (1 - torch.arange(0, overlap_w).float().to(x.device) / overlap_w) + return x diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 44f01c46ebe8..777920ded186 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1544,6 +1544,58 @@ def forward( return objs +class AllegroCombinedTimestepSizeEmbeddings(nn.Module): + """ + For Allegro. TODO(aryan) + """ + + def __init__(self, embedding_dim: int, size_emb_dim: int, use_additional_conditions: bool = False): + super().__init__() + + self.outdim = size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.use_additional_conditions = use_additional_conditions + if use_additional_conditions: + self.use_additional_conditions = True + self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + + def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module): + if size.ndim == 1: + size = size[:, None] + + if size.shape[0] != batch_size: + size = size.repeat(batch_size // size.shape[0], 1) + if size.shape[0] != batch_size: + raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.") + + current_batch_size, dims = size.shape[0], size.shape[1] + size = size.reshape(-1) + size_freq = self.additional_condition_proj(size).to(size.dtype) + + size_emb = embedder(size_freq) + size_emb = size_emb.reshape(current_batch_size, dims * self.outdim) + return size_emb + + def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + if self.use_additional_conditions: + resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder) + aspect_ratio = self.apply_condition( + aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder + ) + conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1) + else: + conditioning = timesteps_emb + + return conditioning + + class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): """ For PixArt-Alpha. diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 21e9d3cd6fc5..b0a7d3dfeb5c 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -23,6 +23,7 @@ from ..utils import is_torch_version from .activations import get_activation from .embeddings import ( + AllegroCombinedTimestepSizeEmbeddings, CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings, ) @@ -355,6 +356,41 @@ def forward( return x +class AllegroAdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): + super().__init__() + + self.emb = AllegroCombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + batch_size: int = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + embedded_timestep = self.emb( + timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None + ) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + class CogView3PlusAdaLayerNormZeroTextImage(nn.Module): r""" Norm layer adaptive layer norm zero (adaLN-Zero). diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 58787c079ea8..873a2bbecf05 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -14,6 +14,7 @@ from .stable_audio_transformer import StableAudioDiTModel from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel + from .transformer_allegro import AllegroTransformer3DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_flux import FluxTransformer2DModel from .transformer_sd3 import SD3Transformer2DModel diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py new file mode 100644 index 000000000000..3c4386829543 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -0,0 +1,1586 @@ +# Copyright 2024 The RhymesAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from dataclasses import dataclass +from functools import partial +from importlib import import_module +from typing import Any, Callable, Dict, Optional, Tuple + +import numpy as np +import torch +import collections +import torch.nn.functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel +from ...configuration_utils import ConfigMixin, register_to_config +from ..activations import GEGLU, GELU, ApproximateGELU +from ..attention_processor import ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + AttnProcessor, + CustomDiffusionAttnProcessor, + CustomDiffusionAttnProcessor2_0, + CustomDiffusionXFormersAttnProcessor, + LoRAAttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + SlicedAttnAddedKVProcessor, + SlicedAttnProcessor, + SpatialNorm, + XFormersAttnAddedKVProcessor, + XFormersAttnProcessor, +) +from ..embeddings import PixArtAlphaTextProjection, SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps, PatchEmbed +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNorm, AdaLayerNormZero +from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_xformers_available +from ...utils.torch_utils import maybe_allow_in_graph +from einops import rearrange, repeat +from torch import nn +from ..normalization import AllegroAdaLayerNormSingle +from ..modeling_outputs import Transformer2DModelOutput +from ..attention import FeedForward + + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + +from diffusers.utils import logging + +logger = logging.get_logger(__name__) + + +class PositionGetter3D(object): + """ return positions of patches """ + + def __init__(self, ): + self.cache_positions = {} + + def __call__(self, b, t, h, w, device): + if not (b, t,h,w) in self.cache_positions: + x = torch.arange(w, device=device) + y = torch.arange(h, device=device) + z = torch.arange(t, device=device) + pos = torch.cartesian_prod(z, y, x) + + pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, 1, -1).contiguous().expand(3, b, -1).clone() + poses = (pos[0].contiguous(), pos[1].contiguous(), pos[2].contiguous()) + max_poses = (int(poses[0].max()), int(poses[1].max()), int(poses[2].max())) + + self.cache_positions[b, t, h, w] = (poses, max_poses) + pos = self.cache_positions[b, t, h, w] + + return pos + + +class RoPE3D(torch.nn.Module): + + def __init__(self, freq=10000.0, F0=1.0, interpolation_scale_thw=(1, 1, 1)): + super().__init__() + self.base = freq + self.F0 = F0 + self.interpolation_scale_t = interpolation_scale_thw[0] + self.interpolation_scale_h = interpolation_scale_thw[1] + self.interpolation_scale_w = interpolation_scale_thw[2] + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype, interpolation_scale=1): + if (D, seq_len, device, dtype) not in self.cache: + inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) / interpolation_scale + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D, seq_len, device, dtype] = (cos, sin) + return self.cache[D, seq_len, device, dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim == 2 + + # for (batch_size x ntokens x nheads x dim) + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 3 (t, y and x position of each token) + output: + * tokens after appplying RoPE3D (batch_size x nheads x ntokens x x dim) + """ + assert tokens.size(3) % 3 == 0, "number of dimensions should be a multiple of three" + D = tokens.size(3) // 3 + poses, max_poses = positions + assert len(poses) == 3 and poses[0].ndim == 2# Batch, Seq, 3 + cos_t, sin_t = self.get_cos_sin(D, max_poses[0] + 1, tokens.device, tokens.dtype, self.interpolation_scale_t) + cos_y, sin_y = self.get_cos_sin(D, max_poses[1] + 1, tokens.device, tokens.dtype, self.interpolation_scale_h) + cos_x, sin_x = self.get_cos_sin(D, max_poses[2] + 1, tokens.device, tokens.dtype, self.interpolation_scale_w) + # split features into three along the feature dimension, and apply rope1d on each half + t, y, x = tokens.chunk(3, dim=-1) + t = self.apply_rope1d(t, poses[0], cos_t, sin_t) + y = self.apply_rope1d(y, poses[1], cos_y, sin_y) + x = self.apply_rope1d(x, poses[2], cos_x, sin_x) + tokens = torch.cat((t, y, x), dim=-1) + return tokens + +class PatchEmbed2D(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + num_frames=1, + height=224, + width=224, + patch_size_t=1, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + use_abs_pos=False, + ): + super().__init__() + self.use_abs_pos = use_abs_pos + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.patch_size_t = patch_size_t + self.patch_size = patch_size + + def forward(self, latent): + b, _, _, _, _ = latent.shape + video_latent = None + + latent = rearrange(latent, 'b c t h w -> (b t) c h w') + + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BT C H W -> BT N C + if self.layer_norm: + latent = self.norm(latent) + + latent = rearrange(latent, '(b t) n c -> b (t n) c', b=b) + video_latent = latent + + return video_latent + + +@maybe_allow_in_graph +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + use_rope: bool = False, + interpolation_scale_thw=None, + ): + super().__init__() + self.inner_dim = dim_head * heads + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.use_rope = use_rope + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + linear_cls = nn.Linear + + + self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = ( + AttnProcessor2_0( + use_rope, + interpolation_scale_thw=interpolation_scale_thw, + ) + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ) -> None: + r""" + Set whether to use memory efficient attention from `xformers` or not. + + Args: + use_memory_efficient_attention_xformers (`bool`): + Whether to use memory efficient attention from `xformers` or not. + attention_op (`Callable`, *optional*): + The attention operation to use. Defaults to `None` which uses the default attention operation from + `xformers`. + """ + is_lora = hasattr(self, "processor") + is_custom_diffusion = hasattr(self, "processor") and isinstance( + self.processor, + (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), + ) + is_added_kv_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + LoRAAttnAddedKVProcessor, + ), + ) + + if use_memory_efficient_attention_xformers: + if is_added_kv_processor and (is_lora or is_custom_diffusion): + raise NotImplementedError( + f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}" + ) + if not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + + if is_lora: + # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers + # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0? + processor = LoRAXFormersAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + elif is_custom_diffusion: + processor = CustomDiffusionXFormersAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_added_kv_processor: + # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + # throw warning + logger.info( + "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." + ) + processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) + else: + processor = XFormersAttnProcessor(attention_op=attention_op) + else: + if is_lora: + attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + processor = attn_processor_class( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + elif is_custom_diffusion: + attn_processor_class = ( + CustomDiffusionAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else CustomDiffusionAttnProcessor + ) + processor = attn_processor_class( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + + self.set_processor(processor) + + def set_attention_slice(self, slice_size: int) -> None: + r""" + Set the slice size for attention computation. + + Args: + slice_size (`int`): + The slice size for attention computation. + """ + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = AttnAddedKVProcessor() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + _remove_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to remove LoRA layers from the model. + """ + if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None: + deprecate( + "set_processor to offload LoRA", + "0.26.0", + "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.", + ) + # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete + # We need to remove all LoRA layers + # Don't forget to remove ALL `_remove_lora` from the codebase + for module in self.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False): + r""" + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible + # serialization format for LoRA Attention Processors. It should be deleted once the integration + # with PEFT is completed. + is_lora_activated = { + name: module.lora_layer is not None + for name, module in self.named_modules() + if hasattr(module, "lora_layer") + } + + # 1. if no layer has a LoRA activated we can return the processor as usual + if not any(is_lora_activated.values()): + return self.processor + + # If doesn't apply LoRA do `add_k_proj` or `add_v_proj` + is_lora_activated.pop("add_k_proj", None) + is_lora_activated.pop("add_v_proj", None) + # 2. else it is not posssible that only some layers have LoRA activated + if not all(is_lora_activated.values()): + raise ValueError( + f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}" + ) + + # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor + non_lora_processor_cls_name = self.processor.__class__.__name__ + lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name) + + hidden_size = self.inner_dim + + # now create a LoRA attention processor from the LoRA layers + if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]: + kwargs = { + "cross_attention_dim": self.cross_attention_dim, + "rank": self.to_q.lora_layer.rank, + "network_alpha": self.to_q.lora_layer.network_alpha, + "q_rank": self.to_q.lora_layer.rank, + "q_hidden_size": self.to_q.lora_layer.out_features, + "k_rank": self.to_k.lora_layer.rank, + "k_hidden_size": self.to_k.lora_layer.out_features, + "v_rank": self.to_v.lora_layer.rank, + "v_hidden_size": self.to_v.lora_layer.out_features, + "out_rank": self.to_out[0].lora_layer.rank, + "out_hidden_size": self.to_out[0].lora_layer.out_features, + } + + if hasattr(self.processor, "attention_op"): + kwargs["attention_op"] = self.processor.attention_op + + lora_processor = lora_processor_cls(hidden_size, **kwargs) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + elif lora_processor_cls == LoRAAttnAddedKVProcessor: + lora_processor = lora_processor_cls( + hidden_size, + cross_attention_dim=self.add_k_proj.weight.shape[0], + rank=self.to_q.lora_layer.rank, + network_alpha=self.to_q.lora_layer.network_alpha, + ) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + + # only save if used + if self.add_k_proj.lora_layer is not None: + lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict()) + lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict()) + else: + lora_processor.add_k_proj_lora = None + lora_processor.add_v_proj_lora = None + else: + raise ValueError(f"{lora_processor_cls} does not exist.") + + return lora_processor + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + + return tensor + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3, head_size = None, + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = head_size if head_size is not None else self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + def _init_compress(self): + self.sr.bias.data.zero_() + self.norm = nn.LayerNorm(self.inner_dim) + + +class AttnProcessor2_0(nn.Module): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self, use_rope=False, interpolation_scale_thw=None): + super().__init__() + self.use_rope = use_rope + self.interpolation_scale_thw = interpolation_scale_thw + + if self.use_rope: + self._init_rope(interpolation_scale_thw) + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def _init_rope(self, interpolation_scale_thw): + self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw) + self.position_getter = PositionGetter3D() + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + frame: int = 8, + height: int = 16, + width: int = 16, + ) -> torch.FloatTensor: + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + attn_heads = attn.heads + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn_heads + + query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) + + if self.use_rope: + # require the shape of (batch_size x nheads x ntokens x dim) + pos_thw = self.position_getter(batch_size, t=frame, h=height, w=width, device=query.device) + query = self.rope(query, pos_thw) + key = self.rope(key, pos_thw) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' + norm_eps: float = 1e-5, + final_dropout: bool = False, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + use_rope: bool = False, + interpolation_scale_thw: Tuple[int] = (1, 1, 1), + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + use_rope=use_rope, + interpolation_scale_thw=interpolation_scale_thw, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + use_rope=False, # do not position in cross attention + interpolation_scale_thw=interpolation_scale_thw, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if not self.use_ada_layer_norm_single: + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + ) + + # 5. Scale-shift for PixArt-Alpha. + if self.use_ada_layer_norm_single: + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + frame: int = None, + height: int = None, + width: int = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.use_layer_norm: + norm_hidden_states = self.norm1(hidden_states) + elif self.use_ada_layer_norm_single: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + frame=frame, + height=height, + width=width, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.use_ada_layer_norm_single: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1. Cross-Attention + if self.attn2 is not None: + + if self.use_ada_layer_norm: + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.use_ada_layer_norm_zero or self.use_layer_norm: + norm_hidden_states = self.norm2(hidden_states) + elif self.use_ada_layer_norm_single: + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.use_ada_layer_norm_single is False: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + + # 2. Feed-forward + if not self.use_ada_layer_norm_single: + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.use_ada_layer_norm_single: + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.use_ada_layer_norm_single: + ff_output = gate_mlp * ff_output + + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class AllegroTransformer3DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + +# { +# "_class_name": "AllegroTransformer3DModel", +# "_diffusers_version": "0.30.3", +# "_name_or_path": "/cpfs/data/user/larrytsai/Projects/Yi-VG/allegro/transformer", +# "activation_fn": "gelu-approximate", +# "attention_bias": true, +# "attention_head_dim": 96, +# "ca_attention_mode": "xformers", +# "caption_channels": 4096, +# "cross_attention_dim": 2304, +# "double_self_attention": false, +# "downsampler": null, +# "dropout": 0.0, +# "in_channels": 4, +# "interpolation_scale_h": 2.0, +# "interpolation_scale_t": 2.2, +# "interpolation_scale_w": 2.0, +# "model_max_length": 300, +# "norm_elementwise_affine": false, +# "norm_eps": 1e-06, +# "norm_type": "ada_norm_single", +# "num_attention_heads": 24, +# "num_embeds_ada_norm": 1000, +# "num_layers": 32, +# "only_cross_attention": false, +# "out_channels": 4, +# "patch_size": 2, +# "patch_size_t": 1, +# "sa_attention_mode": "flash", +# "sample_size": [ +# 90, +# 160 +# ], +# "sample_size_t": 22, +# "upcast_attention": false, +# "use_additional_conditions": null, +# "use_linear_projection": false, +# "use_rope": true +# } + + + @register_to_config + def __init__( + self, + patch_size: int = 2, + patch_size_temporal: int = 1, + num_attention_heads: int = 24, + attention_head_dim: int = 96, + in_channels: int = 4, + out_channels: int = 4, + num_layers: int = 32, + dropout: float = 0.0, + cross_attention_dim: int = 2304, + attention_bias: bool = True, + sample_height: int = 90, + sample_width: int = 160, + sample_frames: int = 22, + activation_fn: str = "gelu-approximate", + num_embeds_ada_norm: int = 1000, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "ada_norm_single", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + caption_channels: int = 4096, + interpolation_scale_h: float = 2.0, + interpolation_scale_w: float = 2.0, + interpolation_scale_t: float = 2.2, + use_additional_conditions: Optional[bool] = None, + use_rotary_positional_embeddings: bool = True, + model_max_length: int = 300, + ): + super().__init__() + + self.inner_dim = num_attention_heads * attention_head_dim + self.out_channels = in_channels if out_channels is None else out_channels + + interpolation_scale_t = ( + interpolation_scale_t if interpolation_scale_t is not None else ((sample_frames - 1) // 16 + 1) if sample_frames % 2 == 1 else sample_frames // 16 + ) + interpolation_scale_h = interpolation_scale_h if interpolation_scale_h is not None else sample_height / 30 + interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40 + + self.pos_embed = PatchEmbed2D( + height=sample_height, + width=sample_width, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=self.inner_dim, + # pos_embed_type=None, + ) + interpolation_scale_thw = (interpolation_scale_t, interpolation_scale_h, interpolation_scale_w) + + # 3. Define transformers blocks, spatial attention + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + use_rope=use_rotary_positional_embeddings, + interpolation_scale_thw=interpolation_scale_thw, + ) + for _ in range(num_layers) + ] + ) + + # 4. Define output layers + if norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) + self.proj_out_2 = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels) + elif norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels) + + # 5. PixArt-Alpha blocks. + self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + # self.use_additional_conditions = self.config.sample_size[0] == 128 # False, 128 -> 1024 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AllegroAdaLayerNormSingle(self.inner_dim, use_additional_conditions=self.use_additional_conditions) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=self.inner_dim + ) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + added_cond_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AdaLayerNormSingle` + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + batch_size, c, frame, h, w = hidden_states.shape + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) attention_mask_vid, attention_mask_img = None, None + if attention_mask is not None and attention_mask.ndim == 4: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + # b, frame+use_image_num, h, w -> a video with images + # b, 1, h, w -> only images + attention_mask = attention_mask.to(self.dtype) + attention_mask_vid = attention_mask[:, :frame] # b, frame, h, w + + if attention_mask_vid.numel() > 0: + attention_mask_vid = attention_mask_vid.unsqueeze(1) # b 1 t h w + attention_mask_vid = F.max_pool3d(attention_mask_vid, kernel_size=(self.config.patch_size_temporal, self.config.patch_size, self.config.patch_size), + stride=(self.config.patch_size_temporal, self.config.patch_size, self.config.patch_size)) + attention_mask_vid = rearrange(attention_mask_vid, 'b 1 t h w -> (b 1) 1 (t h w)') + + attention_mask_vid = (1 - attention_mask_vid.bool().to(self.dtype)) * -10000.0 if attention_mask_vid.numel() > 0 else None + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: + # b, 1+use_image_num, l -> a video with images + # b, 1, l -> only images + encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 + encoder_attention_mask_vid = rearrange(encoder_attention_mask, 'b 1 l -> (b 1) 1 l') if encoder_attention_mask.numel() > 0 else None + + # 1. Input + frame = frame // self.config.patch_size_temporal + height = hidden_states.shape[-2] // self.config.patch_size + width = hidden_states.shape[-1] // self.config.patch_size + + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} if added_cond_kwargs is None else added_cond_kwargs + hidden_states, encoder_hidden_states_vid, timestep_vid, embedded_timestep_vid = self._operate_on_patched_inputs( + hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size, + ) + + for _, block in enumerate(self.transformer_blocks): + # TODO(aryan): Implement gradient checkpointing + hidden_states = block( + hidden_states, + attention_mask_vid, + encoder_hidden_states_vid, + encoder_attention_mask_vid, + timestep_vid, + cross_attention_kwargs, + class_labels, + frame=frame, + height=height, + width=width, + ) + + # 3. Output + output = None + if hidden_states is not None: + output = self._get_output_for_patched_inputs( + hidden_states=hidden_states, + timestep=timestep_vid, + class_labels=class_labels, + embedded_timestep=embedded_timestep_vid, + num_frames=frame, + height=height, + width=width, + ) # b c t h w + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + def _operate_on_patched_inputs(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, added_cond_kwargs: Dict[str, Any], batch_size: int): + hidden_states = self.pos_embed(hidden_states.to(self.dtype)) # TODO(aryan): remove dtype conversion here and move to pipeline if needed + + timestep_vid = None + embedded_timestep_vid = None + encoder_hidden_states_vid = None + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype + ) # b 6d, b d + + timestep_vid = timestep + embedded_timestep_vid = embedded_timestep + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) # b, 1+use_image_num, l, d or b, 1, l, d + encoder_hidden_states_vid = rearrange(encoder_hidden_states[:, :1], 'b 1 l d -> (b 1) l d') + + return hidden_states, encoder_hidden_states_vid, timestep_vid, embedded_timestep_vid + + def _get_output_for_patched_inputs( + self, hidden_states, timestep, class_labels, embedded_timestep, num_frames, height=None, width=None + ) -> torch.Tensor: + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=self.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, num_frames, height, width, self.config.patch_size_temporal, self.config.patch_size, self.config.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nthwopqc->nctohpwq", hidden_states) + output = hidden_states.reshape(-1, self.out_channels, num_frames * self.config.patch_size_temporal, height * self.config.patch_size, width * self.config.patch_size) + return output diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7366520f4692..634088f1b51a 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -116,6 +116,7 @@ "VersatileDiffusionTextToImagePipeline", ] ) + _import_structure["allegro"] = ["AllegroPipeline"] _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"] _import_structure["animatediff"] = [ "AnimateDiffPipeline", @@ -454,6 +455,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_objects import * else: + from .allegro import AllegroPipeline from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline from .animatediff import ( AnimateDiffControlNetPipeline, diff --git a/src/diffusers/pipelines/allegro/__init__.py b/src/diffusers/pipelines/allegro/__init__.py new file mode 100644 index 000000000000..2162b825e0a2 --- /dev/null +++ b/src/diffusers/pipelines/allegro/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_allegro"] = ["AllegroPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_allegro import AllegroPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py new file mode 100644 index 000000000000..d3972f8a9019 --- /dev/null +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -0,0 +1,829 @@ +# Copyright 2024 The RhymesAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import math +import re +import urllib.parse as ul +from typing import Callable, List, Optional, Tuple, Union +import torch +from dataclasses import dataclass +from transformers import T5EncoderModel, T5Tokenizer +import tqdm + +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + BACKENDS_MAPPING, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, + BaseOutput +) +from ...utils.torch_utils import randn_tensor +from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro +from .pipeline_output import AllegroPipelineOutput +from ...video_processor import VideoProcessor + +logger = logging.get_logger(__name__) + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> # You can replace the your_path_to_model with your own path. + >>> pipe = AllegroPipeline.from_pretrained(your_path_to_model, torch_dtype=torch.float16, trust_remote_code=True) + + >>> prompt = "A small cactus with a happy face in the Sahara desert." + >>> image = pipe(prompt).video[0] + ``` +""" + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AllegroPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Allegro. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AllegroAutoEncoderKL3D`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. PixArt-Alpha uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`AllegroTransformer3DModel`]): + A text conditioned `AllegroTransformer3DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLAllegro, + transformer: AllegroTransformer3DModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. + """ + embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + num_frames, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + + if num_frames <= 0: + raise ValueError(f"`num_frames` have to be positive but is {num_frames}.") + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", + # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", + # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + # caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", + # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + return caption.strip() + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if num_frames % 2 == 0: + num_frames = math.ceil(num_frames / self.vae_scale_factor_temporal) + else: + num_frames = math.ceil((num_frames - 1) / self.vae_scale_factor_temporal) + 1 + + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + frames = frames.permute(0, 2, 1, 3, 4) # [batch_size, channels, num_frames, height, width] + return frames + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 100, + timesteps: List[int] = None, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + max_sequence_length: int = 300, + verbose: bool = True, + ) -> Union[AllegroPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_frames: (`int`, *optional*, defaults to 88): + The number controls the generated video frames. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + # 1. Check inputs. Raise error if not correct + num_frames = num_frames or self.transformer.config.sample_size_t * self.vae_scale_factor_temporal + height = height or self.transformer.config.sample_size[0] * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size[1] * self.vae_scale_factor_spatial + + self.check_inputs( + prompt, + num_frames, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + progress_wrap = tqdm.tqdm if verbose else (lambda x: x) + for i, t in progress_wrap(list(enumerate(timesteps))): + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + if prompt_embeds.ndim == 3: + prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d + if prompt_attention_mask.ndim == 2: + prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l + # prepare attention_mask. + # b c t h w -> b t h w + attention_mask = torch.ones_like(latent_model_input)[:, 0] + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + attention_mask=attention_mask, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if latent_channels == self.transformer.config.out_channels // 2: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + video = self.decode_latents(latents) + video = video[:, :, :num_frames, :height, :width] + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AllegroPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/allegro/pipeline_output.py b/src/diffusers/pipelines/allegro/pipeline_output.py new file mode 100644 index 000000000000..ed8ca1862540 --- /dev/null +++ b/src/diffusers/pipelines/allegro/pipeline_output.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import torch +import PIL + +from diffusers.utils import BaseOutput + + +@dataclass +class AllegroPipelineOutput(BaseOutput): + r""" + Output class for Allegro pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]] From 901d10ebeb47eacd7229fa84695aa60613e10ed2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 21 Oct 2024 17:40:18 +0200 Subject: [PATCH 02/33] refactor transformer part 1 --- src/diffusers/models/attention_processor.py | 92 ++ src/diffusers/models/embeddings.py | 45 +- .../transformers/transformer_allegro.py | 1196 +---------------- .../pipelines/allegro/pipeline_allegro.py | 68 +- 4 files changed, 268 insertions(+), 1133 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d333590982e3..ca91dd436a39 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1506,6 +1506,98 @@ def __call__( return hidden_states, encoder_hidden_states +class AllegroAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the Allegro model. It applies a s normalization layer and rotary embedding on query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # Apply RoPE if needed + if image_rotary_emb is not None and not attn.is_cross_attention: + from .embeddings import apply_rotary_emb_allegro + + query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1]) + key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1]) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class AuraFlowAttnProcessor2_0: """Attention processor used typically in processing Aura Flow.""" diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 777920ded186..99389f3ab8f1 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -564,6 +564,31 @@ def combine_time_height_width(freqs_t, freqs_h, freqs_w): return cos, sin +def get_3d_rotary_pos_embed_allegro( + embed_dim, crops_coords, grid_size, temporal_size, interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0), theta: int = 10000 +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + # TODO(aryan): docs + start, stop = crops_coords + grid_size_h, grid_size_w = grid_size + interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) + + # Compute dimensions for each axis + dim_t = embed_dim // 3 + dim_h = embed_dim // 3 + dim_w = embed_dim // 3 + + # Temporal frequencies + freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t / interpolation_scale_t, theta=theta, use_real=True, repeat_interleave_real=False) + # Spatial frequencies for height and width + freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h / interpolation_scale_h, theta=theta, use_real=True, repeat_interleave_real=False) + freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w / interpolation_scale_w, theta=theta, use_real=True, repeat_interleave_real=False) + + return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w + + def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): """ RoPE for image tokens with 2d structure. @@ -684,7 +709,7 @@ def get_1d_rotary_pos_embed( freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] return freqs_cos, freqs_sin elif use_real: - # stable audio + # stable audio, allegro freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] return freqs_cos, freqs_sin @@ -743,6 +768,24 @@ def apply_rotary_emb( return x_out.type_as(x) +def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions): + # TODO(aryan): rewrite + def apply_1d_rope(tokens, pos, cos, sin): + cos = F.embedding(pos, cos)[:, None, :, :] + sin = F.embedding(pos, sin)[:, None, :, :] + x1, x2 = tokens[..., : tokens.shape[-1] // 2], tokens[..., tokens.shape[-1] // 2:] + tokens_rotated = torch.cat((-x2, x1), dim=-1) + return (tokens.float() * cos + tokens_rotated.float() * sin).to(tokens.dtype) + + (t_cos, t_sin), (h_cos, h_sin), (w_cos, w_sin) = freqs_cis + t, h, w = x.chunk(3, dim=-1) + t = apply_1d_rope(t, positions[0], t_cos, t_sin) + h = apply_1d_rope(h, positions[1], h_cos, h_sin) + w = apply_1d_rope(w, positions[2], w_cos, w_sin) + x = torch.cat([t, h, w], dim=-1) + return x + + class FluxPosEmbed(nn.Module): # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 def __init__(self, theta: int, axes_dim: List[int]): diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 3c4386829543..03793684b6cf 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -28,21 +28,8 @@ from ...configuration_utils import ConfigMixin, register_to_config from ..activations import GEGLU, GELU, ApproximateGELU from ..attention_processor import ( - AttnAddedKVProcessor, - AttnAddedKVProcessor2_0, - AttnProcessor, - CustomDiffusionAttnProcessor, - CustomDiffusionAttnProcessor2_0, - CustomDiffusionXFormersAttnProcessor, - LoRAAttnAddedKVProcessor, - LoRAAttnProcessor, - LoRAAttnProcessor2_0, - LoRAXFormersAttnProcessor, - SlicedAttnAddedKVProcessor, - SlicedAttnProcessor, - SpatialNorm, - XFormersAttnAddedKVProcessor, - XFormersAttnProcessor, + Attention, + AllegroAttnProcessor2_0, ) from ..embeddings import PixArtAlphaTextProjection, SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps, PatchEmbed from ..modeling_utils import ModelMixin @@ -50,105 +37,15 @@ from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_xformers_available from ...utils.torch_utils import maybe_allow_in_graph from einops import rearrange, repeat -from torch import nn +import torch.nn as nn from ..normalization import AllegroAdaLayerNormSingle from ..modeling_outputs import Transformer2DModelOutput from ..attention import FeedForward - - - -if is_xformers_available(): - import xformers - import xformers.ops -else: - xformers = None - -from diffusers.utils import logging +from ...utils import logging logger = logging.get_logger(__name__) -class PositionGetter3D(object): - """ return positions of patches """ - - def __init__(self, ): - self.cache_positions = {} - - def __call__(self, b, t, h, w, device): - if not (b, t,h,w) in self.cache_positions: - x = torch.arange(w, device=device) - y = torch.arange(h, device=device) - z = torch.arange(t, device=device) - pos = torch.cartesian_prod(z, y, x) - - pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, 1, -1).contiguous().expand(3, b, -1).clone() - poses = (pos[0].contiguous(), pos[1].contiguous(), pos[2].contiguous()) - max_poses = (int(poses[0].max()), int(poses[1].max()), int(poses[2].max())) - - self.cache_positions[b, t, h, w] = (poses, max_poses) - pos = self.cache_positions[b, t, h, w] - - return pos - - -class RoPE3D(torch.nn.Module): - - def __init__(self, freq=10000.0, F0=1.0, interpolation_scale_thw=(1, 1, 1)): - super().__init__() - self.base = freq - self.F0 = F0 - self.interpolation_scale_t = interpolation_scale_thw[0] - self.interpolation_scale_h = interpolation_scale_thw[1] - self.interpolation_scale_w = interpolation_scale_thw[2] - self.cache = {} - - def get_cos_sin(self, D, seq_len, device, dtype, interpolation_scale=1): - if (D, seq_len, device, dtype) not in self.cache: - inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) / interpolation_scale - freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) - freqs = torch.cat((freqs, freqs), dim=-1) - cos = freqs.cos() # (Seq, Dim) - sin = freqs.sin() - self.cache[D, seq_len, device, dtype] = (cos, sin) - return self.cache[D, seq_len, device, dtype] - - @staticmethod - def rotate_half(x): - x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - def apply_rope1d(self, tokens, pos1d, cos, sin): - assert pos1d.ndim == 2 - - # for (batch_size x ntokens x nheads x dim) - cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] - sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] - return (tokens * cos) + (self.rotate_half(tokens) * sin) - - def forward(self, tokens, positions): - """ - input: - * tokens: batch_size x nheads x ntokens x dim - * positions: batch_size x ntokens x 3 (t, y and x position of each token) - output: - * tokens after appplying RoPE3D (batch_size x nheads x ntokens x x dim) - """ - assert tokens.size(3) % 3 == 0, "number of dimensions should be a multiple of three" - D = tokens.size(3) // 3 - poses, max_poses = positions - assert len(poses) == 3 and poses[0].ndim == 2# Batch, Seq, 3 - cos_t, sin_t = self.get_cos_sin(D, max_poses[0] + 1, tokens.device, tokens.dtype, self.interpolation_scale_t) - cos_y, sin_y = self.get_cos_sin(D, max_poses[1] + 1, tokens.device, tokens.dtype, self.interpolation_scale_h) - cos_x, sin_x = self.get_cos_sin(D, max_poses[2] + 1, tokens.device, tokens.dtype, self.interpolation_scale_w) - # split features into three along the feature dimension, and apply rope1d on each half - t, y, x = tokens.chunk(3, dim=-1) - t = self.apply_rope1d(t, poses[0], cos_t, sin_t) - y = self.apply_rope1d(y, poses[1], cos_y, sin_y) - x = self.apply_rope1d(x, poses[2], cos_x, sin_x) - tokens = torch.cat((t, y, x), dim=-1) - return tokens - class PatchEmbed2D(nn.Module): """2D Image to Patch Embedding""" @@ -201,806 +98,9 @@ def forward(self, latent): @maybe_allow_in_graph -class Attention(nn.Module): +class AllegroTransformerBlock(nn.Module): r""" - A cross attention layer. - - Parameters: - query_dim (`int`): - The number of channels in the query. - cross_attention_dim (`int`, *optional*): - The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. - heads (`int`, *optional*, defaults to 8): - The number of heads to use for multi-head attention. - dim_head (`int`, *optional*, defaults to 64): - The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): - The dropout probability to use. - bias (`bool`, *optional*, defaults to False): - Set to `True` for the query, key, and value linear layers to contain a bias parameter. - upcast_attention (`bool`, *optional*, defaults to False): - Set to `True` to upcast the attention computation to `float32`. - upcast_softmax (`bool`, *optional*, defaults to False): - Set to `True` to upcast the softmax computation to `float32`. - cross_attention_norm (`str`, *optional*, defaults to `None`): - The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. - cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups to use for the group norm in the cross attention. - added_kv_proj_dim (`int`, *optional*, defaults to `None`): - The number of channels to use for the added key and value projections. If `None`, no projection is used. - norm_num_groups (`int`, *optional*, defaults to `None`): - The number of groups to use for the group norm in the attention. - spatial_norm_dim (`int`, *optional*, defaults to `None`): - The number of channels to use for the spatial normalization. - out_bias (`bool`, *optional*, defaults to `True`): - Set to `True` to use a bias in the output linear layer. - scale_qk (`bool`, *optional*, defaults to `True`): - Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. - only_cross_attention (`bool`, *optional*, defaults to `False`): - Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if - `added_kv_proj_dim` is not `None`. - eps (`float`, *optional*, defaults to 1e-5): - An additional value added to the denominator in group normalization that is used for numerical stability. - rescale_output_factor (`float`, *optional*, defaults to 1.0): - A factor to rescale the output by dividing it with this value. - residual_connection (`bool`, *optional*, defaults to `False`): - Set to `True` to add the residual connection to the output. - _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): - Set to `True` if the attention block is loaded from a deprecated state dict. - processor (`AttnProcessor`, *optional*, defaults to `None`): - The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and - `AttnProcessor` otherwise. - """ - - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - cross_attention_norm: Optional[str] = None, - cross_attention_norm_num_groups: int = 32, - added_kv_proj_dim: Optional[int] = None, - norm_num_groups: Optional[int] = None, - spatial_norm_dim: Optional[int] = None, - out_bias: bool = True, - scale_qk: bool = True, - only_cross_attention: bool = False, - eps: float = 1e-5, - rescale_output_factor: float = 1.0, - residual_connection: bool = False, - _from_deprecated_attn_block: bool = False, - processor: Optional["AttnProcessor"] = None, - use_rope: bool = False, - interpolation_scale_thw=None, - ): - super().__init__() - self.inner_dim = dim_head * heads - self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.upcast_attention = upcast_attention - self.upcast_softmax = upcast_softmax - self.rescale_output_factor = rescale_output_factor - self.residual_connection = residual_connection - self.dropout = dropout - self.use_rope = use_rope - - # we make use of this private variable to know whether this class is loaded - # with an deprecated state dict so that we can convert it on the fly - self._from_deprecated_attn_block = _from_deprecated_attn_block - - self.scale_qk = scale_qk - self.scale = dim_head**-0.5 if self.scale_qk else 1.0 - - self.heads = heads - # for slice_size > 0 the attention score computation - # is split across the batch axis to save memory - # You can set slice_size with `set_attention_slice` - self.sliceable_head_dim = heads - - self.added_kv_proj_dim = added_kv_proj_dim - self.only_cross_attention = only_cross_attention - - if self.added_kv_proj_dim is None and self.only_cross_attention: - raise ValueError( - "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." - ) - - if norm_num_groups is not None: - self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) - else: - self.group_norm = None - - if spatial_norm_dim is not None: - self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) - else: - self.spatial_norm = None - - if cross_attention_norm is None: - self.norm_cross = None - elif cross_attention_norm == "layer_norm": - self.norm_cross = nn.LayerNorm(self.cross_attention_dim) - elif cross_attention_norm == "group_norm": - if self.added_kv_proj_dim is not None: - # The given `encoder_hidden_states` are initially of shape - # (batch_size, seq_len, added_kv_proj_dim) before being projected - # to (batch_size, seq_len, cross_attention_dim). The norm is applied - # before the projection, so we need to use `added_kv_proj_dim` as - # the number of channels for the group norm. - norm_cross_num_channels = added_kv_proj_dim - else: - norm_cross_num_channels = self.cross_attention_dim - - self.norm_cross = nn.GroupNorm( - num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True - ) - else: - raise ValueError( - f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" - ) - - linear_cls = nn.Linear - - - self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) - - if not self.only_cross_attention: - # only relevant for the `AddedKVProcessor` classes - self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) - self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) - else: - self.to_k = None - self.to_v = None - - if self.added_kv_proj_dim is not None: - self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) - self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) - - self.to_out = nn.ModuleList([]) - self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias)) - self.to_out.append(nn.Dropout(dropout)) - - # set attention processor - # We use the AttnProcessor2_0 by default when torch 2.x is used which uses - # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention - # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 - if processor is None: - processor = ( - AttnProcessor2_0( - use_rope, - interpolation_scale_thw=interpolation_scale_thw, - ) - if hasattr(F, "scaled_dot_product_attention") and self.scale_qk - else AttnProcessor() - ) - self.set_processor(processor) - - def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None - ) -> None: - r""" - Set whether to use memory efficient attention from `xformers` or not. - - Args: - use_memory_efficient_attention_xformers (`bool`): - Whether to use memory efficient attention from `xformers` or not. - attention_op (`Callable`, *optional*): - The attention operation to use. Defaults to `None` which uses the default attention operation from - `xformers`. - """ - is_lora = hasattr(self, "processor") - is_custom_diffusion = hasattr(self, "processor") and isinstance( - self.processor, - (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), - ) - is_added_kv_processor = hasattr(self, "processor") and isinstance( - self.processor, - ( - AttnAddedKVProcessor, - AttnAddedKVProcessor2_0, - SlicedAttnAddedKVProcessor, - XFormersAttnAddedKVProcessor, - LoRAAttnAddedKVProcessor, - ), - ) - - if use_memory_efficient_attention_xformers: - if is_added_kv_processor and (is_lora or is_custom_diffusion): - raise NotImplementedError( - f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}" - ) - if not is_xformers_available(): - raise ModuleNotFoundError( - ( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers" - ), - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" - " only available for GPU " - ) - else: - try: - # Make sure we can run the memory efficient attention - _ = xformers.ops.memory_efficient_attention( - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - ) - except Exception as e: - raise e - - if is_lora: - # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers - # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0? - processor = LoRAXFormersAttnProcessor( - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - rank=self.processor.rank, - attention_op=attention_op, - ) - processor.load_state_dict(self.processor.state_dict()) - processor.to(self.processor.to_q_lora.up.weight.device) - elif is_custom_diffusion: - processor = CustomDiffusionXFormersAttnProcessor( - train_kv=self.processor.train_kv, - train_q_out=self.processor.train_q_out, - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - attention_op=attention_op, - ) - processor.load_state_dict(self.processor.state_dict()) - if hasattr(self.processor, "to_k_custom_diffusion"): - processor.to(self.processor.to_k_custom_diffusion.weight.device) - elif is_added_kv_processor: - # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP - # which uses this type of cross attention ONLY because the attention mask of format - # [0, ..., -10.000, ..., 0, ...,] is not supported - # throw warning - logger.info( - "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." - ) - processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) - else: - processor = XFormersAttnProcessor(attention_op=attention_op) - else: - if is_lora: - attn_processor_class = ( - LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor - ) - processor = attn_processor_class( - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - rank=self.processor.rank, - ) - processor.load_state_dict(self.processor.state_dict()) - processor.to(self.processor.to_q_lora.up.weight.device) - elif is_custom_diffusion: - attn_processor_class = ( - CustomDiffusionAttnProcessor2_0 - if hasattr(F, "scaled_dot_product_attention") - else CustomDiffusionAttnProcessor - ) - processor = attn_processor_class( - train_kv=self.processor.train_kv, - train_q_out=self.processor.train_q_out, - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - ) - processor.load_state_dict(self.processor.state_dict()) - if hasattr(self.processor, "to_k_custom_diffusion"): - processor.to(self.processor.to_k_custom_diffusion.weight.device) - else: - # set attention processor - # We use the AttnProcessor2_0 by default when torch 2.x is used which uses - # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention - # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 - processor = ( - AttnProcessor2_0() - if hasattr(F, "scaled_dot_product_attention") and self.scale_qk - else AttnProcessor() - ) - - self.set_processor(processor) - - def set_attention_slice(self, slice_size: int) -> None: - r""" - Set the slice size for attention computation. - - Args: - slice_size (`int`): - The slice size for attention computation. - """ - if slice_size is not None and slice_size > self.sliceable_head_dim: - raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") - - if slice_size is not None and self.added_kv_proj_dim is not None: - processor = SlicedAttnAddedKVProcessor(slice_size) - elif slice_size is not None: - processor = SlicedAttnProcessor(slice_size) - elif self.added_kv_proj_dim is not None: - processor = AttnAddedKVProcessor() - else: - # set attention processor - # We use the AttnProcessor2_0 by default when torch 2.x is used which uses - # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention - # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 - processor = ( - AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() - ) - - self.set_processor(processor) - - def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None: - r""" - Set the attention processor to use. - - Args: - processor (`AttnProcessor`): - The attention processor to use. - _remove_lora (`bool`, *optional*, defaults to `False`): - Set to `True` to remove LoRA layers from the model. - """ - if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None: - deprecate( - "set_processor to offload LoRA", - "0.26.0", - "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.", - ) - # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete - # We need to remove all LoRA layers - # Don't forget to remove ALL `_remove_lora` from the codebase - for module in self.modules(): - if hasattr(module, "set_lora_layer"): - module.set_lora_layer(None) - - # if current processor is in `self._modules` and if passed `processor` is not, we need to - # pop `processor` from `self._modules` - if ( - hasattr(self, "processor") - and isinstance(self.processor, torch.nn.Module) - and not isinstance(processor, torch.nn.Module) - ): - logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") - self._modules.pop("processor") - - self.processor = processor - - def get_processor(self, return_deprecated_lora: bool = False): - r""" - Get the attention processor in use. - - Args: - return_deprecated_lora (`bool`, *optional*, defaults to `False`): - Set to `True` to return the deprecated LoRA attention processor. - - Returns: - "AttentionProcessor": The attention processor in use. - """ - if not return_deprecated_lora: - return self.processor - - # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible - # serialization format for LoRA Attention Processors. It should be deleted once the integration - # with PEFT is completed. - is_lora_activated = { - name: module.lora_layer is not None - for name, module in self.named_modules() - if hasattr(module, "lora_layer") - } - - # 1. if no layer has a LoRA activated we can return the processor as usual - if not any(is_lora_activated.values()): - return self.processor - - # If doesn't apply LoRA do `add_k_proj` or `add_v_proj` - is_lora_activated.pop("add_k_proj", None) - is_lora_activated.pop("add_v_proj", None) - # 2. else it is not posssible that only some layers have LoRA activated - if not all(is_lora_activated.values()): - raise ValueError( - f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}" - ) - - # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor - non_lora_processor_cls_name = self.processor.__class__.__name__ - lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name) - - hidden_size = self.inner_dim - - # now create a LoRA attention processor from the LoRA layers - if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]: - kwargs = { - "cross_attention_dim": self.cross_attention_dim, - "rank": self.to_q.lora_layer.rank, - "network_alpha": self.to_q.lora_layer.network_alpha, - "q_rank": self.to_q.lora_layer.rank, - "q_hidden_size": self.to_q.lora_layer.out_features, - "k_rank": self.to_k.lora_layer.rank, - "k_hidden_size": self.to_k.lora_layer.out_features, - "v_rank": self.to_v.lora_layer.rank, - "v_hidden_size": self.to_v.lora_layer.out_features, - "out_rank": self.to_out[0].lora_layer.rank, - "out_hidden_size": self.to_out[0].lora_layer.out_features, - } - - if hasattr(self.processor, "attention_op"): - kwargs["attention_op"] = self.processor.attention_op - - lora_processor = lora_processor_cls(hidden_size, **kwargs) - lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) - lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) - lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) - lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) - elif lora_processor_cls == LoRAAttnAddedKVProcessor: - lora_processor = lora_processor_cls( - hidden_size, - cross_attention_dim=self.add_k_proj.weight.shape[0], - rank=self.to_q.lora_layer.rank, - network_alpha=self.to_q.lora_layer.network_alpha, - ) - lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) - lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) - lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) - lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) - - # only save if used - if self.add_k_proj.lora_layer is not None: - lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict()) - lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict()) - else: - lora_processor.add_k_proj_lora = None - lora_processor.add_v_proj_lora = None - else: - raise ValueError(f"{lora_processor_cls} does not exist.") - - return lora_processor - - def forward( - self, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - **cross_attention_kwargs, - ) -> torch.Tensor: - r""" - The forward method of the `Attention` class. - - Args: - hidden_states (`torch.Tensor`): - The hidden states of the query. - encoder_hidden_states (`torch.Tensor`, *optional*): - The hidden states of the encoder. - attention_mask (`torch.Tensor`, *optional*): - The attention mask to use. If `None`, no mask is applied. - **cross_attention_kwargs: - Additional keyword arguments to pass along to the cross attention. - - Returns: - `torch.Tensor`: The output of the attention layer. - """ - # The `Attention` class can call different attention processors / attention functions - # here we simply pass along all tensors to the selected processor class - # For standard processors that are defined here, `**cross_attention_kwargs` is empty - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: - r""" - Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` - is the number of heads initialized while constructing the `Attention` class. - - Args: - tensor (`torch.Tensor`): The tensor to reshape. - - Returns: - `torch.Tensor`: The reshaped tensor. - """ - head_size = self.heads - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: - r""" - Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is - the number of heads initialized while constructing the `Attention` class. - - Args: - tensor (`torch.Tensor`): The tensor to reshape. - out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is - reshaped to `[batch_size * heads, seq_len, dim // heads]`. - - Returns: - `torch.Tensor`: The reshaped tensor. - """ - head_size = self.heads - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3) - - if out_dim == 3: - tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) - - return tensor - - def get_attention_scores( - self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None - ) -> torch.Tensor: - r""" - Compute the attention scores. - - Args: - query (`torch.Tensor`): The query tensor. - key (`torch.Tensor`): The key tensor. - attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. - - Returns: - `torch.Tensor`: The attention probabilities/scores. - """ - dtype = query.dtype - if self.upcast_attention: - query = query.float() - key = key.float() - - if attention_mask is None: - baddbmm_input = torch.empty( - query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device - ) - beta = 0 - else: - baddbmm_input = attention_mask - beta = 1 - - attention_scores = torch.baddbmm( - baddbmm_input, - query, - key.transpose(-1, -2), - beta=beta, - alpha=self.scale, - ) - del baddbmm_input - - if self.upcast_softmax: - attention_scores = attention_scores.float() - - attention_probs = attention_scores.softmax(dim=-1) - del attention_scores - - attention_probs = attention_probs.to(dtype) - - return attention_probs - - def prepare_attention_mask( - self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3, head_size = None, - ) -> torch.Tensor: - r""" - Prepare the attention mask for the attention computation. - - Args: - attention_mask (`torch.Tensor`): - The attention mask to prepare. - target_length (`int`): - The target length of the attention mask. This is the length of the attention mask after padding. - batch_size (`int`): - The batch size, which is used to repeat the attention mask. - out_dim (`int`, *optional*, defaults to `3`): - The output dimension of the attention mask. Can be either `3` or `4`. - - Returns: - `torch.Tensor`: The prepared attention mask. - """ - head_size = head_size if head_size is not None else self.heads - if attention_mask is None: - return attention_mask - - current_length: int = attention_mask.shape[-1] - if current_length != target_length: - if attention_mask.device.type == "mps": - # HACK: MPS: Does not support padding by greater than dimension of input tensor. - # Instead, we can manually construct the padding tensor. - padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) - padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) - attention_mask = torch.cat([attention_mask, padding], dim=2) - else: - # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: - # we want to instead pad by (0, remaining_length), where remaining_length is: - # remaining_length: int = target_length - current_length - # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding - attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) - - if out_dim == 3: - if attention_mask.shape[0] < batch_size * head_size: - attention_mask = attention_mask.repeat_interleave(head_size, dim=0) - elif out_dim == 4: - attention_mask = attention_mask.unsqueeze(1) - attention_mask = attention_mask.repeat_interleave(head_size, dim=1) - - return attention_mask - - def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: - r""" - Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the - `Attention` class. - - Args: - encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. - - Returns: - `torch.Tensor`: The normalized encoder hidden states. - """ - assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" - - if isinstance(self.norm_cross, nn.LayerNorm): - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - elif isinstance(self.norm_cross, nn.GroupNorm): - # Group norm norms along the channels dimension and expects - # input to be in the shape of (N, C, *). In this case, we want - # to norm along the hidden dimension, so we need to move - # (batch_size, sequence_length, hidden_size) -> - # (batch_size, hidden_size, sequence_length) - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - else: - assert False - - return encoder_hidden_states - - def _init_compress(self): - self.sr.bias.data.zero_() - self.norm = nn.LayerNorm(self.inner_dim) - - -class AttnProcessor2_0(nn.Module): - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__(self, use_rope=False, interpolation_scale_thw=None): - super().__init__() - self.use_rope = use_rope - self.interpolation_scale_thw = interpolation_scale_thw - - if self.use_rope: - self._init_rope(interpolation_scale_thw) - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def _init_rope(self, interpolation_scale_thw): - self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw) - self.position_getter = PositionGetter3D() - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - temb: Optional[torch.FloatTensor] = None, - frame: int = 8, - height: int = 16, - width: int = 16, - ) -> torch.FloatTensor: - - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - attn_heads = attn.heads - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn_heads - - query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) - - if self.use_rope: - # require the shape of (batch_size x nheads x ntokens x dim) - pos_thw = self.position_getter(batch_size, t=frame, h=height, w=width, device=query.device) - query = self.rope(query, pos_thw) - key = self.rope(key, pos_thw) - - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -@maybe_allow_in_graph -class BasicTransformerBlock(nn.Module): - r""" - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. - only_cross_attention (`bool`, *optional*): - Whether to use only cross-attention layers. In this case two cross attention layers are used. - double_self_attention (`bool`, *optional*): - Whether to use two self-attention layers. In this case no cross attention layers are used. - upcast_attention (`bool`, *optional*): - Whether to upcast the attention computation to float32. This is useful for mixed precision training. - norm_elementwise_affine (`bool`, *optional*, defaults to `True`): - Whether to use learnable elementwise affine parameters for normalization. - norm_type (`str`, *optional*, defaults to `"layer_norm"`): - The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. - final_dropout (`bool` *optional*, defaults to False): - Whether to apply a final dropout after the last feed-forward layer. - positional_embeddings (`str`, *optional*, defaults to `None`): - The type of positional embeddings to apply to. - num_positional_embeddings (`int`, *optional*, defaults to `None`): - The maximum number of positional embeddings to apply. + TODO(aryan): docs """ def __init__( @@ -1011,52 +111,17 @@ def __init__( dropout=0.0, cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, - double_self_attention: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' norm_eps: float = 1e-5, final_dropout: bool = False, - positional_embeddings: Optional[str] = None, - num_positional_embeddings: Optional[int] = None, - use_rope: bool = False, - interpolation_scale_thw: Tuple[int] = (1, 1, 1), ): super().__init__() - self.only_cross_attention = only_cross_attention - self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" - self.use_ada_layer_norm_single = norm_type == "ada_norm_single" - self.use_layer_norm = norm_type == "layer_norm" - - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - - if positional_embeddings and (num_positional_embeddings is None): - raise ValueError( - "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." - ) - - if positional_embeddings == "sinusoidal": - self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) - else: - self.pos_embed = None - - # Define 3 blocks. Each block has its own normalization layer. - # 1. Self-Attn - if self.use_ada_layer_norm: - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif self.use_ada_layer_norm_zero: - self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) - else: - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + # 1. Self Attention + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) self.attn1 = Attention( query_dim=dim, @@ -1066,38 +131,24 @@ def __init__( bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, - use_rope=use_rope, - interpolation_scale_thw=interpolation_scale_thw, + processor=AllegroAttnProcessor2_0(), ) - # 2. Cross-Attn - if cross_attention_dim is not None or double_self_attention: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - self.norm2 = ( - AdaLayerNorm(dim, num_embeds_ada_norm) - if self.use_ada_layer_norm - else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - ) - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - use_rope=False, # do not position in cross attention - interpolation_scale_thw=interpolation_scale_thw, - ) # is self-attn if encoder_hidden_states is none - else: - self.norm2 = None - self.attn2 = None + # 2. Cross Attention + self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + processor=AllegroAttnProcessor2_0(), + ) # is self-attn if encoder_hidden_states is none - # 3. Feed-forward - if not self.use_ada_layer_norm_single: - self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + # 3. Feed Forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) self.ff = FeedForward( dim, @@ -1106,63 +157,35 @@ def __init__( final_dropout=final_dropout, ) - # 5. Scale-shift for PixArt-Alpha. - if self.use_ada_layer_norm_single: - self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + # 4. Scale-shift + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) def forward( self, - hidden_states: torch.FloatTensor, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - timestep: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[torch.LongTensor] = None, - frame: int = None, - height: int = None, - width: int = None, - ) -> torch.FloatTensor: - # Notice that normalization is always applied before the real computation in the following blocks. - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + temb: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb = None, + ) -> torch.Tensor: # 0. Self-Attention batch_size = hidden_states.shape[0] - if self.use_ada_layer_norm: - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.use_ada_layer_norm_zero: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - elif self.use_layer_norm: - norm_hidden_states = self.norm1(hidden_states) - elif self.use_ada_layer_norm_single: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) - ).chunk(6, dim=1) - norm_hidden_states = self.norm1(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - norm_hidden_states = norm_hidden_states.squeeze(1) - else: - raise ValueError("Incorrect norm used") - - if self.pos_embed is not None: - norm_hidden_states = self.pos_embed(norm_hidden_states) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + temb.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) attn_output = self.attn1( norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - frame=frame, - height=height, - width=width, - **cross_attention_kwargs, + encoder_hidden_states=None, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, ) - if self.use_ada_layer_norm_zero: - attn_output = gate_msa.unsqueeze(1) * attn_output - elif self.use_ada_layer_norm_single: - attn_output = gate_msa * attn_output + attn_output = gate_msa * attn_output hidden_states = attn_output + hidden_states if hidden_states.ndim == 4: @@ -1170,50 +193,26 @@ def forward( # 1. Cross-Attention if self.attn2 is not None: - - if self.use_ada_layer_norm: - norm_hidden_states = self.norm2(hidden_states, timestep) - elif self.use_ada_layer_norm_zero or self.use_layer_norm: - norm_hidden_states = self.norm2(hidden_states) - elif self.use_ada_layer_norm_single: - # For PixArt norm2 isn't applied here: - # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 - norm_hidden_states = hidden_states - else: - raise ValueError("Incorrect norm") - - if self.pos_embed is not None and self.use_ada_layer_norm_single is False: - norm_hidden_states = self.pos_embed(norm_hidden_states) + norm_hidden_states = hidden_states attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - **cross_attention_kwargs, + image_rotary_emb=None, ) hidden_states = attn_output + hidden_states - # 2. Feed-forward - if not self.use_ada_layer_norm_single: - norm_hidden_states = self.norm3(hidden_states) - - if self.use_ada_layer_norm_zero: - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - if self.use_ada_layer_norm_single: - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp ff_output = self.ff(norm_hidden_states) - - if self.use_ada_layer_norm_zero: - ff_output = gate_mlp.unsqueeze(1) * ff_output - elif self.use_ada_layer_norm_single: - ff_output = gate_mlp * ff_output - + ff_output = gate_mlp * ff_output hidden_states = ff_output + hidden_states + + # TODO(aryan): maybe following line is not required if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) @@ -1308,12 +307,7 @@ def __init__( sample_width: int = 160, sample_frames: int = 22, activation_fn: str = "gelu-approximate", - num_embeds_ada_norm: int = 1000, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, upcast_attention: bool = False, - norm_type: str = "ada_norm_single", norm_elementwise_affine: bool = False, norm_eps: float = 1e-6, caption_channels: int = 4096, @@ -1348,46 +342,29 @@ def __init__( # 3. Define transformers blocks, spatial attention self.transformer_blocks = nn.ModuleList( [ - BasicTransformerBlock( + AllegroTransformerBlock( self.inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - double_self_attention=double_self_attention, upcast_attention=upcast_attention, - norm_type=norm_type, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, - use_rope=use_rotary_positional_embeddings, - interpolation_scale_thw=interpolation_scale_thw, ) for _ in range(num_layers) ] ) # 4. Define output layers - if norm_type != "ada_norm_single": - self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) - self.proj_out_2 = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels) - elif norm_type == "ada_norm_single": - self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) - self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels) + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels) # 5. PixArt-Alpha blocks. - self.adaln_single = None - self.use_additional_conditions = False - if norm_type == "ada_norm_single": - # self.use_additional_conditions = self.config.sample_size[0] == 128 # False, 128 -> 1024 - # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use - # additional conditions until we find better name - self.adaln_single = AllegroAdaLayerNormSingle(self.inner_dim, use_additional_conditions=self.use_additional_conditions) + self.adaln_single = AllegroAdaLayerNormSingle(self.inner_dim, use_additional_conditions=False) self.caption_projection = None if caption_channels is not None: @@ -1410,48 +387,9 @@ def forward( cross_attention_kwargs: Dict[str, Any] = None, attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, return_dict: bool = True, ): - """ - The [`Transformer2DModel`] forward method. - - Args: - hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous): - Input `hidden_states`. - encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `torch.LongTensor`, *optional*): - Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. - class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): - Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in - `AdaLayerZeroNorm`. - added_cond_kwargs ( `Dict[str, Any]`, *optional*): - A kwargs dictionary that if specified is passed along to the `AdaLayerNormSingle` - cross_attention_kwargs ( `Dict[str, Any]`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - attention_mask ( `torch.Tensor`, *optional*): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. - encoder_attention_mask ( `torch.Tensor`, *optional*): - Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: - - * Mask `(batch, sequence_length)` True = keep, False = discard. - * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. - - If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format - above. This bias will be added to the cross-attention scores. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ batch_size, c, frame, h, w = hidden_states.shape # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. @@ -1476,8 +414,7 @@ def forward( if attention_mask_vid.numel() > 0: attention_mask_vid = attention_mask_vid.unsqueeze(1) # b 1 t h w - attention_mask_vid = F.max_pool3d(attention_mask_vid, kernel_size=(self.config.patch_size_temporal, self.config.patch_size, self.config.patch_size), - stride=(self.config.patch_size_temporal, self.config.patch_size, self.config.patch_size)) + attention_mask_vid = F.max_pool3d(attention_mask_vid, kernel_size=(self.config.patch_size_temporal, self.config.patch_size, self.config.patch_size), stride=(self.config.patch_size_temporal, self.config.patch_size, self.config.patch_size)) attention_mask_vid = rearrange(attention_mask_vid, 'b 1 t h w -> (b 1) 1 (t h w)') attention_mask_vid = (1 - attention_mask_vid.bool().to(self.dtype)) * -10000.0 if attention_mask_vid.numel() > 0 else None @@ -1501,17 +438,14 @@ def forward( for _, block in enumerate(self.transformer_blocks): # TODO(aryan): Implement gradient checkpointing - hidden_states = block( - hidden_states, - attention_mask_vid, - encoder_hidden_states_vid, - encoder_attention_mask_vid, - timestep_vid, - cross_attention_kwargs, - class_labels, - frame=frame, - height=height, - width=width, + block: AllegroTransformerBlock + hidden_states = block.forward( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states_vid, + temb=timestep_vid, + attention_mask=attention_mask_vid, + encoder_attention_mask=encoder_attention_mask_vid, + image_rotary_emb=image_rotary_emb, ) # 3. Output @@ -1540,7 +474,7 @@ def _operate_on_patched_inputs(self, hidden_states: torch.Tensor, encoder_hidden encoder_hidden_states_vid = None if self.adaln_single is not None: - if self.use_additional_conditions and added_cond_kwargs is None: + if self.config.use_additional_conditions and added_cond_kwargs is None: raise ValueError( "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." ) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index d3972f8a9019..fcf848cfd7b9 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -36,6 +36,7 @@ ) from ...utils.torch_utils import randn_tensor from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro +from ...models.embeddings import get_3d_rotary_pos_embed_allegro from .pipeline_output import AllegroPipelineOutput from ...video_processor import VideoProcessor @@ -106,6 +107,25 @@ def retrieve_timesteps( return timesteps, num_inference_steps +# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + class AllegroPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using Allegro. @@ -563,10 +583,51 @@ def prepare_latents( def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: latents = 1 / self.vae.config.scaling_factor * latents - frames = self.vae.decode(latents).sample frames = frames.permute(0, 2, 1, 3, 4) # [batch_size, channels, num_frames, height, width] return frames + + def _prepare_rotary_positional_embeddings( + self, + batch_size: int, + height: int, + width: int, + num_frames: int, + device: torch.device, + ): + attention_head_dim = 96 + vae_scale_factor_spatial = 8 + patch_size = 2 + + grid_height = height // (vae_scale_factor_spatial * patch_size) + grid_width = width // (vae_scale_factor_spatial * patch_size) + base_size_width = 1280 // (vae_scale_factor_spatial * patch_size) + base_size_height = 720 // (vae_scale_factor_spatial * patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w = get_3d_rotary_pos_embed_allegro( + embed_dim=attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + interpolation_scale=(self.transformer.config.interpolation_scale_t, self.transformer.config.interpolation_scale_h, self.transformer.config.interpolation_scale_w) + ) + + grid_t = torch.from_numpy(grid_t).to(device=device, dtype=torch.long) + grid_h = torch.from_numpy(grid_h).to(device=device, dtype=torch.long) + grid_w = torch.from_numpy(grid_w).to(device=device, dtype=torch.long) + + pos = torch.cartesian_prod(grid_t, grid_h, grid_w) + pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1).contiguous().expand(3, batch_size, -1) + grid_t, grid_h, grid_w = pos + + freqs_t = (freqs_t[0].to(device=device), freqs_t[1].to(device=device)) + freqs_h = (freqs_h[0].to(device=device), freqs_h[1].to(device=device)) + freqs_w = (freqs_w[0].to(device=device), freqs_w[1].to(device=device)) + + return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w) @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -750,6 +811,8 @@ def __call__( # 6.1 Prepare micro-conditions. added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + image_rotary_emb = self._prepare_rotary_positional_embeddings(batch_size, height, width, latents.size(2), device) + # 7. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -778,9 +841,11 @@ def __call__( prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d if prompt_attention_mask.ndim == 2: prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l + # prepare attention_mask. # b c t h w -> b t h w attention_mask = torch.ones_like(latent_model_input)[:, 0] + # predict noise model_output noise_pred = self.transformer( latent_model_input, @@ -789,6 +854,7 @@ def __call__( encoder_attention_mask=prompt_attention_mask, timestep=current_timestep, added_cond_kwargs=added_cond_kwargs, + image_rotary_emb=image_rotary_emb, return_dict=False, )[0] From ec05bbdf340d08021dc421f1c6d0ccbd03877572 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 22 Oct 2024 01:42:44 +0200 Subject: [PATCH 03/33] refactor part 2 --- .../transformers/transformer_allegro.py | 149 ++++++------------ .../pipelines/allegro/pipeline_allegro.py | 7 +- 2 files changed, 49 insertions(+), 107 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 03793684b6cf..86a13cc5d582 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -314,14 +314,12 @@ def __init__( interpolation_scale_h: float = 2.0, interpolation_scale_w: float = 2.0, interpolation_scale_t: float = 2.2, - use_additional_conditions: Optional[bool] = None, use_rotary_positional_embeddings: bool = True, model_max_length: int = 300, ): super().__init__() self.inner_dim = num_attention_heads * attention_head_dim - self.out_channels = in_channels if out_channels is None else out_channels interpolation_scale_t = ( interpolation_scale_t if interpolation_scale_t is not None else ((sample_frames - 1) // 16 + 1) if sample_frames % 2 == 1 else sample_frames // 16 @@ -329,6 +327,7 @@ def __init__( interpolation_scale_h = interpolation_scale_h if interpolation_scale_h is not None else sample_height / 30 interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40 + # 1. Patch embedding self.pos_embed = PatchEmbed2D( height=sample_height, width=sample_width, @@ -337,9 +336,8 @@ def __init__( embed_dim=self.inner_dim, # pos_embed_type=None, ) - interpolation_scale_thw = (interpolation_scale_t, interpolation_scale_h, interpolation_scale_w) - # 3. Define transformers blocks, spatial attention + # 2. Transformer blocks self.transformer_blocks = nn.ModuleList( [ AllegroTransformerBlock( @@ -358,19 +356,18 @@ def __init__( ] ) - # 4. Define output layers + # 3. Output projection & norm self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) - self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels) - # 5. PixArt-Alpha blocks. + # 4. Timestep embeddings self.adaln_single = AllegroAdaLayerNormSingle(self.inner_dim, use_additional_conditions=False) - self.caption_projection = None - if caption_channels is not None: - self.caption_projection = PixArtAlphaTextProjection( - in_features=caption_channels, hidden_size=self.inner_dim - ) + # 5. Caption projection + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=self.inner_dim + ) self.gradient_checkpointing = False @@ -382,15 +379,14 @@ def forward( hidden_states: torch.Tensor, timestep: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, - added_cond_kwargs: Dict[str, torch.Tensor] = None, - class_labels: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, return_dict: bool = True, ): - batch_size, c, frame, h, w = hidden_states.shape + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t = self.config.patch_size_temporal + p = self.config.patch_size # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. @@ -409,112 +405,61 @@ def forward( # (keep = +0, discard = -10000.0) # b, frame+use_image_num, h, w -> a video with images # b, 1, h, w -> only images - attention_mask = attention_mask.to(self.dtype) - attention_mask_vid = attention_mask[:, :frame] # b, frame, h, w + attention_mask = attention_mask.to(hidden_states.dtype) + attention_mask = attention_mask[:, :num_frames] # [batch_size, num_frames, height, width] - if attention_mask_vid.numel() > 0: - attention_mask_vid = attention_mask_vid.unsqueeze(1) # b 1 t h w - attention_mask_vid = F.max_pool3d(attention_mask_vid, kernel_size=(self.config.patch_size_temporal, self.config.patch_size, self.config.patch_size), stride=(self.config.patch_size_temporal, self.config.patch_size, self.config.patch_size)) - attention_mask_vid = rearrange(attention_mask_vid, 'b 1 t h w -> (b 1) 1 (t h w)') + if attention_mask.numel() > 0: + attention_mask = attention_mask.unsqueeze(1) # [batch_size, 1, num_frames, height, width] + attention_mask = F.max_pool3d(attention_mask, kernel_size=(p_t, p, p), stride=(p_t, p, p)) + attention_mask = attention_mask.flatten(1).view(batch_size, 1, -1) - attention_mask_vid = (1 - attention_mask_vid.bool().to(self.dtype)) * -10000.0 if attention_mask_vid.numel() > 0 else None + attention_mask = (1 - attention_mask.bool().to(hidden_states.dtype)) * -10000.0 if attention_mask.numel() > 0 else None # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # b, 1+use_image_num, l -> a video with images # b, 1, l -> only images encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 - encoder_attention_mask_vid = rearrange(encoder_attention_mask, 'b 1 l -> (b 1) 1 l') if encoder_attention_mask.numel() > 0 else None + encoder_attention_mask = rearrange(encoder_attention_mask, 'b 1 l -> (b 1) 1 l') if encoder_attention_mask.numel() > 0 else None # 1. Input - frame = frame // self.config.patch_size_temporal - height = hidden_states.shape[-2] // self.config.patch_size - width = hidden_states.shape[-1] // self.config.patch_size + post_patch_num_frames = num_frames // self.config.patch_size_temporal + post_patch_height = height // self.config.patch_size + post_patch_width = width // self.config.patch_size - added_cond_kwargs = {"resolution": None, "aspect_ratio": None} if added_cond_kwargs is None else added_cond_kwargs - hidden_states, encoder_hidden_states_vid, timestep_vid, embedded_timestep_vid = self._operate_on_patched_inputs( - hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size, - ) + timestep, embedded_timestep = self.adaln_single(timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype) + + hidden_states = self.pos_embed(hidden_states) # TODO(aryan): remove dtype conversion here and move to pipeline if needed + + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1]) - for _, block in enumerate(self.transformer_blocks): + for i, block in enumerate(self.transformer_blocks): # TODO(aryan): Implement gradient checkpointing - block: AllegroTransformerBlock - hidden_states = block.forward( + hidden_states = block( hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states_vid, - temb=timestep_vid, - attention_mask=attention_mask_vid, - encoder_attention_mask=encoder_attention_mask_vid, + encoder_hidden_states=encoder_hidden_states, + temb=timestep, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, image_rotary_emb=image_rotary_emb, ) # 3. Output - output = None - if hidden_states is not None: - output = self._get_output_for_patched_inputs( - hidden_states=hidden_states, - timestep=timestep_vid, - class_labels=class_labels, - embedded_timestep=embedded_timestep_vid, - num_frames=frame, - height=height, - width=width, - ) # b c t h w + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.reshape(batch_size, -1, num_frames, height, width) if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) - - def _operate_on_patched_inputs(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, added_cond_kwargs: Dict[str, Any], batch_size: int): - hidden_states = self.pos_embed(hidden_states.to(self.dtype)) # TODO(aryan): remove dtype conversion here and move to pipeline if needed - - timestep_vid = None - embedded_timestep_vid = None - encoder_hidden_states_vid = None - - if self.adaln_single is not None: - if self.config.use_additional_conditions and added_cond_kwargs is None: - raise ValueError( - "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." - ) - timestep, embedded_timestep = self.adaln_single( - timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype - ) # b 6d, b d - - timestep_vid = timestep - embedded_timestep_vid = embedded_timestep - - if self.caption_projection is not None: - encoder_hidden_states = self.caption_projection(encoder_hidden_states) # b, 1+use_image_num, l, d or b, 1, l, d - encoder_hidden_states_vid = rearrange(encoder_hidden_states[:, :1], 'b 1 l d -> (b 1) l d') - - return hidden_states, encoder_hidden_states_vid, timestep_vid, embedded_timestep_vid - - def _get_output_for_patched_inputs( - self, hidden_states, timestep, class_labels, embedded_timestep, num_frames, height=None, width=None - ) -> torch.Tensor: - if self.config.norm_type != "ada_norm_single": - conditioning = self.transformer_blocks[0].norm1.emb( - timestep, class_labels, hidden_dtype=self.dtype - ) - shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] - hidden_states = self.proj_out_2(hidden_states) - elif self.config.norm_type == "ada_norm_single": - shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) - # Modulation - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.squeeze(1) - - # unpatchify - if self.adaln_single is None: - height = width = int(hidden_states.shape[1] ** 0.5) - hidden_states = hidden_states.reshape( - shape=(-1, num_frames, height, width, self.config.patch_size_temporal, self.config.patch_size, self.config.patch_size, self.out_channels) - ) - hidden_states = torch.einsum("nthwopqc->nctohpwq", hidden_states) - output = hidden_states.reshape(-1, self.out_channels, num_frames * self.config.patch_size_temporal, height * self.config.patch_size, width * self.config.patch_size) - return output diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index fcf848cfd7b9..6a0a71275012 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -808,12 +808,10 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 6.1 Prepare micro-conditions. - added_cond_kwargs = {"resolution": None, "aspect_ratio": None} - + # 7. Prepare rotary embeddings image_rotary_emb = self._prepare_rotary_positional_embeddings(batch_size, height, width, latents.size(2), device) - # 7. Denoising loop + # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) progress_wrap = tqdm.tqdm if verbose else (lambda x: x) @@ -853,7 +851,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, timestep=current_timestep, - added_cond_kwargs=added_cond_kwargs, image_rotary_emb=image_rotary_emb, return_dict=False, )[0] From 892b70dee47edaa766c659148da734f501220065 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 22 Oct 2024 02:07:07 +0200 Subject: [PATCH 04/33] refactor part 3 --- .../autoencoders/autoencoder_kl_allegro.py | 274 ++++++++---------- 1 file changed, 128 insertions(+), 146 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index 2ec0855635b4..7440c4070e14 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -34,66 +34,63 @@ from ..attention_processor import SpatialNorm -class TemporalConvBlock(nn.Module): - """ - Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: +class AllegroTemporalConvBlock(nn.Module): + r""" + Temporal convolutional layer that can be used for video (sequence of images) input. Code adapted from: https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 """ - def __init__(self, in_dim, out_dim=None, dropout=0.0, up_sample=False, down_sample=False, spa_stride=1): + def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0, up_sample: bool = False, down_sample: bool = False, stride: int = 1) -> None: super().__init__() + out_dim = out_dim or in_dim - self.in_dim = in_dim - self.out_dim = out_dim - spa_pad = int((spa_stride-1)*0.5) - temp_pad = 0 - self.temp_pad = temp_pad + pad_h = pad_w = int((stride - 1) * 0.5) + pad_t = 0 + + self.down_sample = down_sample + self.up_sample = up_sample if down_sample: self.conv1 = nn.Sequential( nn.GroupNorm(32, in_dim), nn.SiLU(), - nn.Conv3d(in_dim, out_dim, (2, spa_stride, spa_stride), stride=(2,1,1), padding=(0, spa_pad, spa_pad)) + nn.Conv3d(in_dim, out_dim, (2, stride, stride), stride=(2,1,1), padding=(0, pad_h, pad_w)) ) elif up_sample: self.conv1 = nn.Sequential( nn.GroupNorm(32, in_dim), nn.SiLU(), - nn.Conv3d(in_dim, out_dim*2, (1, spa_stride, spa_stride), padding=(0, spa_pad, spa_pad)) + nn.Conv3d(in_dim, out_dim*2, (1, stride, stride), padding=(0, pad_h, pad_w)) ) else: self.conv1 = nn.Sequential( nn.GroupNorm(32, in_dim), nn.SiLU(), - nn.Conv3d(in_dim, out_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)) + nn.Conv3d(in_dim, out_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)) ) self.conv2 = nn.Sequential( nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), - nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)), + nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)), ) self.conv3 = nn.Sequential( nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), - nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)), + nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)), ) self.conv4 = nn.Sequential( nn.GroupNorm(32, out_dim), nn.SiLU(), - nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)), + nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)), ) - # zero out the last layer params,so the conv block is identity + # zero out the last layer params, so the conv block is identity nn.init.zeros_(self.conv4[-1].weight) nn.init.zeros_(self.conv4[-1].bias) - self.down_sample = down_sample - self.up_sample = up_sample - - - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: identity = hidden_states if self.down_sample: @@ -112,7 +109,6 @@ def forward(self, hidden_states): hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2) hidden_states = self.conv1(hidden_states) - if self.up_sample: hidden_states = rearrange(hidden_states, 'b (d c) f h w -> b c (f d) h w', d=2) @@ -149,6 +145,7 @@ def __init__( downsample_padding=1, ): super().__init__() + resnets = [] temp_convs = [] @@ -169,24 +166,24 @@ def __init__( ) ) temp_convs.append( - TemporalConvBlock( + AllegroTemporalConvBlock( out_channels, out_channels, dropout=0.1, - ) ) + ) self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) if add_temp_downsample: - self.temp_convs_down = TemporalConvBlock( - out_channels, - out_channels, - dropout=0.1, - down_sample=True, - spa_stride=3 - ) + self.temp_convs_down = AllegroTemporalConvBlock( + out_channels, + out_channels, + dropout=0.1, + down_sample=True, + stride=3 + ) self.add_temp_downsample = add_temp_downsample if add_downsample: @@ -200,29 +197,24 @@ def __init__( else: self.downsamplers = None - def _set_partial_grad(self): - for temp_conv in self.temp_convs: - temp_conv.requires_grad_(True) - if self.downsamplers: - for down_layer in self.downsamplers: - down_layer.requires_grad_(True) - - def forward(self, hidden_states): - bz = hidden_states.shape[0] + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size = hidden_states.shape[0] for resnet, temp_conv in zip(self.resnets, self.temp_convs): - hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = resnet(hidden_states, temb=None) - hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) hidden_states = temp_conv(hidden_states) + if self.add_temp_downsample: hidden_states = self.temp_convs_down(hidden_states) if self.downsamplers is not None: - hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') - for upsampler in self.downsamplers: - hidden_states = upsampler(hidden_states) - hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return hidden_states @@ -267,55 +259,49 @@ def __init__( ) ) temp_convs.append( - TemporalConvBlock( + AllegroTemporalConvBlock( out_channels, out_channels, dropout=0.1, - ) ) + ) self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) self.add_temp_upsample = add_temp_upsample if add_temp_upsample: - self.temp_conv_up = TemporalConvBlock( - out_channels, - out_channels, - dropout=0.1, - up_sample=True, - spa_stride=3 - ) - + self.temp_conv_up = AllegroTemporalConvBlock( + out_channels, + out_channels, + dropout=0.1, + up_sample=True, + stride=3 + ) if self.add_upsample: - # self.upsamplers = nn.ModuleList([PSUpsample2D(out_channels, use_conv=True, use_pixel_shuffle=True, out_channels=out_channels)]) self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None - - def _set_partial_grad(self): - for temp_conv in self.temp_convs: - temp_conv.requires_grad_(True) - if self.add_upsample: - self.upsamplers.requires_grad_(True) - def forward(self, hidden_states): - bz = hidden_states.shape[0] + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size = hidden_states.shape[0] for resnet, temp_conv in zip(self.resnets, self.temp_convs): - hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = resnet(hidden_states, temb=None) - hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) hidden_states = temp_conv(hidden_states) + if self.add_temp_upsample: hidden_states = self.temp_conv_up(hidden_states) if self.upsamplers is not None: - hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) - hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return hidden_states @@ -355,7 +341,7 @@ def __init__( ) ] temp_convs = [ - TemporalConvBlock( + AllegroTemporalConvBlock( in_channels, in_channels, dropout=0.1, @@ -402,7 +388,7 @@ def __init__( ) temp_convs.append( - TemporalConvBlock( + AllegroTemporalConvBlock( in_channels, in_channels, dropout=0.1, @@ -412,30 +398,27 @@ def __init__( self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) self.attentions = nn.ModuleList(attentions) - - def _set_partial_grad(self): - for temp_conv in self.temp_convs: - temp_conv.requires_grad_(True) def forward( self, - hidden_states, + hidden_states: torch.Tensor ): - bz = hidden_states.shape[0] - hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') - + batch_size = hidden_states.shape[0] + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = self.resnets[0](hidden_states, temb=None) - hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) hidden_states = self.temp_convs[0](hidden_states) - hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) for attn, resnet, temp_conv in zip( self.attentions, self.resnets[1:], self.temp_convs[1:] ): hidden_states = attn(hidden_states) hidden_states = resnet(hidden_states, temb=None) - hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) hidden_states = temp_conv(hidden_states) + return hidden_states @@ -522,39 +505,42 @@ def __init__( self.gradient_checkpointing = False - def forward(self, x): - ''' - x: [b, c, (tb f), h, w] - ''' - bz = x.shape[0] - sample = rearrange(x, 'b c n h w -> (b n) c h w') + def forward(self, sample: torch.Tensor) -> torch.Tensor: + batch_size = sample.shape[0] + + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) sample = self.conv_in(sample) - sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) - temp_sample = sample + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + residual = sample sample = self.temp_conv_in(sample) - sample = sample+temp_sample - # down - for b_id, down_block in enumerate(self.down_blocks): + sample = sample + residual + + # Down blocks + for down_block in self.down_blocks: sample = down_block(sample) - # middle + + # Mid block sample = self.mid_block(sample) - # post-process - sample = rearrange(sample, 'b c n h w -> (b n) c h w') + # Post process + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) sample = self.conv_norm_out(sample) sample = self.conv_act(sample) - sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - temp_sample = sample + residual = sample sample = self.temp_conv_out(sample) - sample = sample+temp_sample - sample = rearrange(sample, 'b c n h w -> (b n) c h w') + sample = sample + residual + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) sample = self.conv_out(sample) - sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return sample - + + class AllegroDecoder3D(nn.Module): def __init__( self, @@ -569,6 +555,7 @@ def __init__( norm_type: str = "group", # group, spatial ): super().__init__() + self.layers_per_block = layers_per_block self.blocks_temp_li = blocks_temp_li @@ -637,84 +624,85 @@ def __init__( self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) else: self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3,1,1), padding = (1, 0, 0)) self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + # TODO(aryan): implement gradient checkpointing self.gradient_checkpointing = False - def forward(self, z): - bz = z.shape[0] - sample = rearrange(z, 'b c n h w -> (b n) c h w') - sample = self.conv_in(sample) + def forward(self, sample: torch.Tensor) -> torch.Tensor: + batch_size = sample.shape[0] - sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) - temp_sample = sample + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_in(sample) + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + residual = sample sample = self.temp_conv_in(sample) - sample = sample+temp_sample + sample = sample + residual upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - # middle + + # Mid block sample = self.mid_block(sample) sample = sample.to(upscale_dtype) - # up - for b_id, up_block in enumerate(self.up_blocks): + # Up blocks + for up_block in self.up_blocks: sample = up_block(sample) - # post-process - sample = rearrange(sample, 'b c n h w -> (b n) c h w') + # Post process + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) sample = self.conv_norm_out(sample) sample = self.conv_act(sample) - - sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) - temp_sample = sample + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + residual = sample sample = self.temp_conv_out(sample) - sample = sample+temp_sample - sample = rearrange(sample, 'b c n h w -> (b n) c h w') - + sample = sample + residual + + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) sample = self.conv_out(sample) - sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return sample class AutoencoderKLAllegro(ModelMixin, ConfigMixin): r""" - A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in + [Allegro](https://github.com/rhymes-ai/Allegro). This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). Parameters: - in_channels (int, *optional*, defaults to 3): Number of channels in the input image. - out_channels (int, *optional*, defaults to 3): Number of channels in the output. - down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + in_channels (int, defaults to `3`): Number of channels in the input image. + out_channels (int, defaults to `3`): Number of channels in the output. + down_block_types (`Tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")`): Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + up_block_types (`Tuple[str]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + block_out_channels (`Tuple[int]`, defaults to `(128, 256, 512, 512)`): Tuple of block output channels. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. - sample_size (`int`, *optional*, defaults to `256`): Spatial Tiling Size. - tile_overlap (`tuple`, *optional*, defaults to `(120, 80`): Spatial overlapping size while tiling (height, width) - chunk_len (`int`, *optional*, defaults to `24`): Temporal Tiling Size. - t_over (`int`, *optional*, defaults to `8`): Temporal overlapping size while tiling - scaling_factor (`float`, *optional*, defaults to 0.13235): + act_fn (`str`, defaults to `"silu"`): + The activation function to use. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, defaults to `0.13235`): The component-wise standard deviation of the trained latent space computed using the first batch of the training set. This is used to scale the latent space to have unit variance when training the diffusion model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. - force_upcast (`bool`, *optional*, default to `True`): + force_upcast (`bool`, default to `True`): If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE can be fine-tuned / trained to a lower range without loosing too much precision in which case `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix - blocks_tempdown_li (`List`, *optional*, defaults to `[True, True, False, False]`): Each item indicates whether each TemporalBlock in the Encoder performs temporal downsampling. - blocks_tempup_li (`List`, *optional*, defaults to `[False, True, True, False]`): Each item indicates whether each TemporalBlock in the Decoder performs temporal upsampling. - load_mode (`str`, *optional*, defaults to `full`): Load mode for the model. Can be one of `full`, `encoder_only`, `decoder_only`. which corresponds to loading the full model state dicts, only the encoder state dicts, or only the decoder state dicts. + TODO(aryan): docs """ _supports_gradient_checkpointing = True @@ -800,8 +788,8 @@ def __init__( self.kernel = (self.chunk_len, self.sample_size, self.sample_size) #(24, 256, 256) self.stride = (self.chunk_len - self.t_over, self.sample_size-self.tile_overlap[0], self.sample_size-self.tile_overlap[1]) # (16, 112, 192) - def encode(self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch_size=1) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + # TODO(aryan): rewrite to encode and tiled_encode KERNEL = self.kernel STRIDE = self.stride LOCAL_BS = local_batch_size @@ -809,7 +797,6 @@ def encode(self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch B, C, N, H, W = input_imgs.shape - out_n = math.floor((N - KERNEL[0]) / STRIDE[0]) + 1 out_h = math.floor((H - KERNEL[1]) / STRIDE[1]) + 1 out_w = math.floor((W - KERNEL[2]) / STRIDE[2]) + 1 @@ -868,8 +855,8 @@ def encode(self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch return AutoencoderKLOutput(latent_dist=posterior) - def decode(self, input_latents: torch.Tensor, return_dict: bool = True, local_batch_size=1) -> Union[DecoderOutput, torch.Tensor]: + # TODO(aryan): rewrite to decode and tiled_decode KERNEL = self.kernel STRIDE = self.stride @@ -968,11 +955,6 @@ def forward( return (dec,) return DecoderOutput(sample=dec) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): - kwargs["torch_type"] = torch.float32 - return super().from_pretrained(pretrained_model_name_or_path, **kwargs) def prepare_for_blend(n_param, h_param, w_param, x): From fd18f9adcb7ee82d528fb43c5fc5bd4430674bc4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 22 Oct 2024 02:07:44 +0200 Subject: [PATCH 05/33] make style --- src/diffusers/__init__.py | 4 +- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/attention_processor.py | 6 +- .../autoencoders/autoencoder_kl_allegro.py | 377 ++++++++++-------- src/diffusers/models/embeddings.py | 21 +- .../transformers/transformer_allegro.py | 166 ++++---- .../pipelines/allegro/pipeline_allegro.py | 64 +-- .../pipelines/allegro/pipeline_output.py | 2 +- 8 files changed, 361 insertions(+), 281 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index dab0ee1db1a8..22731be196c2 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -77,6 +77,7 @@ else: _import_structure["models"].extend( [ + "AllegroTransformer3DModel", "AsymmetricAutoencoderKL", "AuraFlowTransformer2DModel", "AutoencoderKL", @@ -85,7 +86,6 @@ "AutoencoderKLTemporalDecoder", "AutoencoderOobleck", "AutoencoderTiny", - "AllegroTransformer3DModel", "CogVideoXTransformer3DModel", "CogView3PlusTransformer2DModel", "ConsistencyDecoderVAE", @@ -559,9 +559,9 @@ from .utils.dummy_pt_objects import * # noqa F403 else: from .models import ( + AllegroTransformer3DModel, AsymmetricAutoencoderKL, AuraFlowTransformer2DModel, - AllegroTransformer3DModel, AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 310c35c4cb72..38dd2819133d 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -100,8 +100,8 @@ from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( - AuraFlowTransformer2DModel, AllegroTransformer3DModel, + AuraFlowTransformer2DModel, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, DiTTransformer2DModel, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ca91dd436a39..d0d20d177557 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1514,7 +1514,9 @@ class AllegroAttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + raise ImportError( + "AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) def __call__( self, @@ -1569,7 +1571,7 @@ def __call__( # Apply RoPE if needed if image_rotary_emb is not None and not attn.is_cross_attention: from .embeddings import apply_rotary_emb_allegro - + query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1]) key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1]) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index 7440c4070e14..33e4cfbb1b35 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -14,24 +14,20 @@ # limitations under the License. import math -from dataclasses import dataclass -import os -from typing import Dict, Optional, Tuple, Union -from einops import rearrange +from typing import Optional, Tuple, Union import torch import torch.nn as nn -import torch.nn.functional as F +from einops import rearrange from ...configuration_utils import ConfigMixin, register_to_config -from ..modeling_utils import ModelMixin -from ..modeling_outputs import AutoencoderKLOutput +from ..attention_processor import Attention, SpatialNorm from ..autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution -from ..attention_processor import Attention +from ..downsampling import Downsample2D +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin from ..resnet import ResnetBlock2D from ..upsampling import Upsample2D -from ..downsampling import Downsample2D -from ..attention_processor import SpatialNorm class AllegroTemporalConvBlock(nn.Module): @@ -40,9 +36,17 @@ class AllegroTemporalConvBlock(nn.Module): https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 """ - def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0, up_sample: bool = False, down_sample: bool = False, stride: int = 1) -> None: + def __init__( + self, + in_dim: int, + out_dim: Optional[int] = None, + dropout: float = 0.0, + up_sample: bool = False, + down_sample: bool = False, + stride: int = 1, + ) -> None: super().__init__() - + out_dim = out_dim or in_dim pad_h = pad_w = int((stride - 1) * 0.5) pad_t = 0 @@ -52,21 +56,21 @@ def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = if down_sample: self.conv1 = nn.Sequential( - nn.GroupNorm(32, in_dim), - nn.SiLU(), - nn.Conv3d(in_dim, out_dim, (2, stride, stride), stride=(2,1,1), padding=(0, pad_h, pad_w)) + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (2, stride, stride), stride=(2, 1, 1), padding=(0, pad_h, pad_w)), ) elif up_sample: self.conv1 = nn.Sequential( - nn.GroupNorm(32, in_dim), - nn.SiLU(), - nn.Conv3d(in_dim, out_dim*2, (1, stride, stride), padding=(0, pad_h, pad_w)) + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim * 2, (1, stride, stride), padding=(0, pad_h, pad_w)), ) else: self.conv1 = nn.Sequential( - nn.GroupNorm(32, in_dim), - nn.SiLU(), - nn.Conv3d(in_dim, out_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)) + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)), ) self.conv2 = nn.Sequential( nn.GroupNorm(32, out_dim), @@ -94,38 +98,38 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: identity = hidden_states if self.down_sample: - identity = identity[:,:,::2] + identity = identity[:, :, ::2] elif self.up_sample: - hidden_states_new = torch.cat((hidden_states,hidden_states),dim=2) + hidden_states_new = torch.cat((hidden_states, hidden_states), dim=2) hidden_states_new[:, :, 0::2] = hidden_states hidden_states_new[:, :, 1::2] = hidden_states identity = hidden_states_new del hidden_states_new - + if self.down_sample or self.up_sample: hidden_states = self.conv1(hidden_states) else: - hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2) - hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2) + hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2) + hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2) hidden_states = self.conv1(hidden_states) if self.up_sample: - hidden_states = rearrange(hidden_states, 'b (d c) f h w -> b c (f d) h w', d=2) + hidden_states = rearrange(hidden_states, "b (d c) f h w -> b c (f d) h w", d=2) - hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2) - hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2) + hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2) + hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2) hidden_states = self.conv2(hidden_states) - hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2) - hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2) + hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2) + hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2) hidden_states = self.conv3(hidden_states) - hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2) - hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2) + hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2) + hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2) hidden_states = self.conv4(hidden_states) hidden_states = identity + hidden_states return hidden_states - + class AllegroDownBlock3D(nn.Module): def __init__( @@ -178,11 +182,7 @@ def __init__( if add_temp_downsample: self.temp_convs_down = AllegroTemporalConvBlock( - out_channels, - out_channels, - dropout=0.1, - down_sample=True, - stride=3 + out_channels, out_channels, dropout=0.1, down_sample=True, stride=3 ) self.add_temp_downsample = add_temp_downsample @@ -196,16 +196,16 @@ def __init__( ) else: self.downsamplers = None - + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size = hidden_states.shape[0] - + for resnet, temp_conv in zip(self.resnets, self.temp_convs): hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = resnet(hidden_states, temb=None) hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) hidden_states = temp_conv(hidden_states) - + if self.add_temp_downsample: hidden_states = self.temp_convs_down(hidden_states) @@ -214,7 +214,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - + return hidden_states @@ -259,24 +259,20 @@ def __init__( ) ) temp_convs.append( - AllegroTemporalConvBlock( + AllegroTemporalConvBlock( out_channels, out_channels, dropout=0.1, ) ) - + self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) self.add_temp_upsample = add_temp_upsample if add_temp_upsample: self.temp_conv_up = AllegroTemporalConvBlock( - out_channels, - out_channels, - dropout=0.1, - up_sample=True, - stride=3 + out_channels, out_channels, dropout=0.1, up_sample=True, stride=3 ) if self.add_upsample: @@ -286,13 +282,13 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size = hidden_states.shape[0] - + for resnet, temp_conv in zip(self.resnets, self.temp_convs): hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = resnet(hidden_states, temb=None) hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) hidden_states = temp_conv(hidden_states) - + if self.add_temp_upsample: hidden_states = self.temp_conv_up(hidden_states) @@ -301,10 +297,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - + return hidden_states - + class UNetMidBlock3DConv(nn.Module): def __init__( self, @@ -399,35 +395,35 @@ def __init__( self.temp_convs = nn.ModuleList(temp_convs) self.attentions = nn.ModuleList(attentions) - def forward( - self, - hidden_states: torch.Tensor - ): + def forward(self, hidden_states: torch.Tensor): batch_size = hidden_states.shape[0] - + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = self.resnets[0](hidden_states, temb=None) hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) hidden_states = self.temp_convs[0](hidden_states) hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) - for attn, resnet, temp_conv in zip( - self.attentions, self.resnets[1:], self.temp_convs[1:] - ): + for attn, resnet, temp_conv in zip(self.attentions, self.resnets[1:], self.temp_convs[1:]): hidden_states = attn(hidden_states) hidden_states = resnet(hidden_states, temb=None) hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) hidden_states = temp_conv(hidden_states) - + return hidden_states - + class AllegroEncoder3D(nn.Module): def __init__( self, in_channels: int = 3, out_channels: int = 3, - down_block_types: Tuple[str, ...] = ("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D"), + down_block_types: Tuple[str, ...] = ( + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + ), blocks_temp_li=[False, False, False, False], block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), layers_per_block: int = 2, @@ -452,7 +448,7 @@ def __init__( in_channels=block_out_channels[0], out_channels=block_out_channels[0], kernel_size=(3, 1, 1), - padding=(1, 0, 0) + padding=(1, 0, 0), ) self.down_blocks = nn.ModuleList([]) @@ -478,7 +474,7 @@ def __init__( ) else: raise ValueError("Invalid `down_block_type` encountered. Must be `AllegroDownBlock3D`") - + self.down_blocks.append(down_block) # mid @@ -499,7 +495,7 @@ def __init__( conv_out_channels = 2 * out_channels if double_z else out_channels - self.temp_conv_out = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3,1,1), padding = (1, 0, 0)) + self.temp_conv_out = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0)) self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) @@ -507,20 +503,20 @@ def __init__( def forward(self, sample: torch.Tensor) -> torch.Tensor: batch_size = sample.shape[0] - + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) sample = self.conv_in(sample) sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - + residual = sample - sample = self.temp_conv_in(sample) + sample = self.temp_conv_in(sample) sample = sample + residual - + # Down blocks for down_block in self.down_blocks: sample = down_block(sample) - + # Mid block sample = self.mid_block(sample) @@ -531,13 +527,13 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) residual = sample - sample = self.temp_conv_out(sample) + sample = self.temp_conv_out(sample) sample = sample + residual sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) sample = self.conv_out(sample) sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - + return sample @@ -546,7 +542,12 @@ def __init__( self, in_channels: int = 4, out_channels: int = 3, - up_block_types: Tuple[str, ...] = ("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D"), + up_block_types: Tuple[str, ...] = ( + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + ), blocks_temp_li=[False, False, False, False], block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), layers_per_block: int = 2, @@ -567,12 +568,7 @@ def __init__( padding=1, ) - self.temp_conv_in = nn.Conv3d( - block_out_channels[-1], - block_out_channels[-1], - (3,1,1), - padding = (1, 0, 0) - ) + self.temp_conv_in = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0)) self.mid_block = None self.up_blocks = nn.ModuleList([]) @@ -615,7 +611,7 @@ def __init__( ) else: raise ValueError("Invalid `UP_block_type` encountered. Must be `AllegroUpBlock3D`") - + self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -624,10 +620,10 @@ def __init__( self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) else: self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) - + self.conv_act = nn.SiLU() - self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3,1,1), padding = (1, 0, 0)) + self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3, 1, 1), padding=(1, 0, 0)) self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) # TODO(aryan): implement gradient checkpointing @@ -639,13 +635,13 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) sample = self.conv_in(sample) sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - + residual = sample - sample = self.temp_conv_in(sample) + sample = self.temp_conv_in(sample) sample = sample + residual upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - + # Mid block sample = self.mid_block(sample) sample = sample.to(upscale_dtype) @@ -659,17 +655,17 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - + residual = sample sample = self.temp_conv_out(sample) sample = sample + residual - + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) sample = self.conv_out(sample) sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - + return sample - + class AutoencoderKLAllegro(ModelMixin, ConfigMixin): r""" @@ -743,7 +739,7 @@ def __init__( self.blocks_tempdown_li = blocks_tempdown_li self.blocks_tempup_li = blocks_tempup_li - + self.encoder = AllegroEncoder3D( in_channels=in_channels, out_channels=latent_channels, @@ -772,23 +768,25 @@ def __init__( self.use_tiling = False # only relevant if vae tiling is enabled - sample_size = ( - sample_size[0] - if isinstance(sample_size, (list, tuple)) - else sample_size - ) + sample_size = sample_size[0] if isinstance(sample_size, (list, tuple)) else sample_size self.tile_overlap = tile_overlap - self.vae_scale_factor=[4, 8, 8] + self.vae_scale_factor = [4, 8, 8] self.sample_size = sample_size self.chunk_len = chunk_len self.t_over = t_over - self.latent_chunk_len = self.chunk_len//4 - self.latent_t_over = self.t_over//4 - self.kernel = (self.chunk_len, self.sample_size, self.sample_size) #(24, 256, 256) - self.stride = (self.chunk_len - self.t_over, self.sample_size-self.tile_overlap[0], self.sample_size-self.tile_overlap[1]) # (16, 112, 192) - - def encode(self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch_size=1) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + self.latent_chunk_len = self.chunk_len // 4 + self.latent_t_over = self.t_over // 4 + self.kernel = (self.chunk_len, self.sample_size, self.sample_size) # (24, 256, 256) + self.stride = ( + self.chunk_len - self.t_over, + self.sample_size - self.tile_overlap[0], + self.sample_size - self.tile_overlap[1], + ) # (16, 112, 192) + + def encode( + self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch_size=1 + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: # TODO(aryan): rewrite to encode and tiled_encode KERNEL = self.kernel STRIDE = self.stride @@ -796,16 +794,22 @@ def encode(self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch OUT_C = 8 B, C, N, H, W = input_imgs.shape - + out_n = math.floor((N - KERNEL[0]) / STRIDE[0]) + 1 out_h = math.floor((H - KERNEL[1]) / STRIDE[1]) + 1 out_w = math.floor((W - KERNEL[2]) / STRIDE[2]) + 1 - + ## cut video into overlapped small cubes and batch forward num = 0 - out_latent = torch.zeros((out_n*out_h*out_w, OUT_C, KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8), device=input_imgs.device, dtype=input_imgs.dtype) - vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_imgs.device, dtype=input_imgs.dtype) + out_latent = torch.zeros( + (out_n * out_h * out_w, OUT_C, KERNEL[0] // 4, KERNEL[1] // 8, KERNEL[2] // 8), + device=input_imgs.device, + dtype=input_imgs.dtype, + ) + vae_batch_input = torch.zeros( + (LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_imgs.device, dtype=input_imgs.dtype + ) for i in range(out_n): for j in range(out_h): @@ -814,39 +818,50 @@ def encode(self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch h_start, h_end = j * STRIDE[1], j * STRIDE[1] + KERNEL[1] w_start, w_end = k * STRIDE[2], k * STRIDE[2] + KERNEL[2] video_cube = input_imgs[:, :, n_start:n_end, h_start:h_end, w_start:w_end] - vae_batch_input[num%LOCAL_BS] = video_cube - - if num%LOCAL_BS == LOCAL_BS-1 or num == out_n*out_h*out_w-1: + vae_batch_input[num % LOCAL_BS] = video_cube + + if num % LOCAL_BS == LOCAL_BS - 1 or num == out_n * out_h * out_w - 1: latent = self.encoder(vae_batch_input) - - if num == out_n*out_h*out_w-1 and num%LOCAL_BS != LOCAL_BS-1: - out_latent[num-num%LOCAL_BS:] = latent[:num%LOCAL_BS+1] + + if num == out_n * out_h * out_w - 1 and num % LOCAL_BS != LOCAL_BS - 1: + out_latent[num - num % LOCAL_BS :] = latent[: num % LOCAL_BS + 1] else: - out_latent[num-LOCAL_BS+1:num+1] = latent - vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_imgs.device, dtype=input_imgs.dtype) - num+=1 - + out_latent[num - LOCAL_BS + 1 : num + 1] = latent + vae_batch_input = torch.zeros( + (LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), + device=input_imgs.device, + dtype=input_imgs.dtype, + ) + num += 1 + ## flatten the batched out latent to videos and supress the overlapped parts B, C, N, H, W = input_imgs.shape - out_video_cube = torch.zeros((B, OUT_C, N//4, H//8, W//8), device=input_imgs.device, dtype=input_imgs.dtype) - OUT_KERNEL = KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8 - OUT_STRIDE = STRIDE[0]//4, STRIDE[1]//8, STRIDE[2]//8 - OVERLAP = OUT_KERNEL[0]-OUT_STRIDE[0], OUT_KERNEL[1]-OUT_STRIDE[1], OUT_KERNEL[2]-OUT_STRIDE[2] - + out_video_cube = torch.zeros( + (B, OUT_C, N // 4, H // 8, W // 8), device=input_imgs.device, dtype=input_imgs.dtype + ) + OUT_KERNEL = KERNEL[0] // 4, KERNEL[1] // 8, KERNEL[2] // 8 + OUT_STRIDE = STRIDE[0] // 4, STRIDE[1] // 8, STRIDE[2] // 8 + OVERLAP = OUT_KERNEL[0] - OUT_STRIDE[0], OUT_KERNEL[1] - OUT_STRIDE[1], OUT_KERNEL[2] - OUT_STRIDE[2] + for i in range(out_n): n_start, n_end = i * OUT_STRIDE[0], i * OUT_STRIDE[0] + OUT_KERNEL[0] for j in range(out_h): h_start, h_end = j * OUT_STRIDE[1], j * OUT_STRIDE[1] + OUT_KERNEL[1] for k in range(out_w): w_start, w_end = k * OUT_STRIDE[2], k * OUT_STRIDE[2] + OUT_KERNEL[2] - latent_mean_blend = prepare_for_blend((i, out_n, OVERLAP[0]), (j, out_h, OVERLAP[1]), (k, out_w, OVERLAP[2]), out_latent[i*out_h*out_w+j*out_w+k].unsqueeze(0)) + latent_mean_blend = prepare_for_blend( + (i, out_n, OVERLAP[0]), + (j, out_h, OVERLAP[1]), + (k, out_w, OVERLAP[2]), + out_latent[i * out_h * out_w + j * out_w + k].unsqueeze(0), + ) out_video_cube[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean_blend - + ## final conv - out_video_cube = rearrange(out_video_cube, 'b c n h w -> (b n) c h w') + out_video_cube = rearrange(out_video_cube, "b c n h w -> (b n) c h w") out_video_cube = self.quant_conv(out_video_cube) - out_video_cube = rearrange(out_video_cube, '(b n) c h w -> b c n h w', b=B) + out_video_cube = rearrange(out_video_cube, "(b n) c h w -> b c n h w", b=B) posterior = DiagonalGaussianDistribution(out_video_cube) @@ -854,24 +869,26 @@ def encode(self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch return (posterior,) return AutoencoderKLOutput(latent_dist=posterior) - - def decode(self, input_latents: torch.Tensor, return_dict: bool = True, local_batch_size=1) -> Union[DecoderOutput, torch.Tensor]: + + def decode( + self, input_latents: torch.Tensor, return_dict: bool = True, local_batch_size=1 + ) -> Union[DecoderOutput, torch.Tensor]: # TODO(aryan): rewrite to decode and tiled_decode KERNEL = self.kernel STRIDE = self.stride - + LOCAL_BS = local_batch_size OUT_C = 3 - IN_KERNEL = KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8 - IN_STRIDE = STRIDE[0]//4, STRIDE[1]//8, STRIDE[2]//8 + IN_KERNEL = KERNEL[0] // 4, KERNEL[1] // 8, KERNEL[2] // 8 + IN_STRIDE = STRIDE[0] // 4, STRIDE[1] // 8, STRIDE[2] // 8 B, C, N, H, W = input_latents.shape ## post quant conv (a mapping) - input_latents = rearrange(input_latents, 'b c n h w -> (b n) c h w') + input_latents = rearrange(input_latents, "b c n h w -> (b n) c h w") input_latents = self.post_quant_conv(input_latents) - input_latents = rearrange(input_latents, '(b n) c h w -> b c n h w', b=B) - + input_latents = rearrange(input_latents, "(b n) c h w -> b c n h w", b=B) + ## out tensor shape out_n = math.floor((N - IN_KERNEL[0]) / IN_STRIDE[0]) + 1 out_h = math.floor((H - IN_KERNEL[1]) / IN_STRIDE[1]) + 1 @@ -879,8 +896,16 @@ def decode(self, input_latents: torch.Tensor, return_dict: bool = True, local_ba ## cut latent into overlapped small cubes and batch forward num = 0 - decoded_cube = torch.zeros((out_n*out_h*out_w, OUT_C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype) - vae_batch_input = torch.zeros((LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype) + decoded_cube = torch.zeros( + (out_n * out_h * out_w, OUT_C, KERNEL[0], KERNEL[1], KERNEL[2]), + device=input_latents.device, + dtype=input_latents.dtype, + ) + vae_batch_input = torch.zeros( + (LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), + device=input_latents.device, + dtype=input_latents.dtype, + ) for i in range(out_n): for j in range(out_h): for k in range(out_w): @@ -888,38 +913,48 @@ def decode(self, input_latents: torch.Tensor, return_dict: bool = True, local_ba h_start, h_end = j * IN_STRIDE[1], j * IN_STRIDE[1] + IN_KERNEL[1] w_start, w_end = k * IN_STRIDE[2], k * IN_STRIDE[2] + IN_KERNEL[2] latent_cube = input_latents[:, :, n_start:n_end, h_start:h_end, w_start:w_end] - vae_batch_input[num%LOCAL_BS] = latent_cube - if num%LOCAL_BS == LOCAL_BS-1 or num == out_n*out_h*out_w-1: - + vae_batch_input[num % LOCAL_BS] = latent_cube + if num % LOCAL_BS == LOCAL_BS - 1 or num == out_n * out_h * out_w - 1: latent = self.decoder(vae_batch_input) - - if num == out_n*out_h*out_w-1 and num%LOCAL_BS != LOCAL_BS-1: - decoded_cube[num-num%LOCAL_BS:] = latent[:num%LOCAL_BS+1] + + if num == out_n * out_h * out_w - 1 and num % LOCAL_BS != LOCAL_BS - 1: + decoded_cube[num - num % LOCAL_BS :] = latent[: num % LOCAL_BS + 1] else: - decoded_cube[num-LOCAL_BS+1:num+1] = latent - vae_batch_input = torch.zeros((LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype) - num+=1 + decoded_cube[num - LOCAL_BS + 1 : num + 1] = latent + vae_batch_input = torch.zeros( + (LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), + device=input_latents.device, + dtype=input_latents.dtype, + ) + num += 1 B, C, N, H, W = input_latents.shape - - out_video = torch.zeros((B, OUT_C, N*4, H*8, W*8), device=input_latents.device, dtype=input_latents.dtype) - OVERLAP = KERNEL[0]-STRIDE[0], KERNEL[1]-STRIDE[1], KERNEL[2]-STRIDE[2] + + out_video = torch.zeros( + (B, OUT_C, N * 4, H * 8, W * 8), device=input_latents.device, dtype=input_latents.dtype + ) + OVERLAP = KERNEL[0] - STRIDE[0], KERNEL[1] - STRIDE[1], KERNEL[2] - STRIDE[2] for i in range(out_n): n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0] for j in range(out_h): h_start, h_end = j * STRIDE[1], j * STRIDE[1] + KERNEL[1] for k in range(out_w): w_start, w_end = k * STRIDE[2], k * STRIDE[2] + KERNEL[2] - out_video_blend = prepare_for_blend((i, out_n, OVERLAP[0]), (j, out_h, OVERLAP[1]), (k, out_w, OVERLAP[2]), decoded_cube[i*out_h*out_w+j*out_w+k].unsqueeze(0)) + out_video_blend = prepare_for_blend( + (i, out_n, OVERLAP[0]), + (j, out_h, OVERLAP[1]), + (k, out_w, OVERLAP[2]), + decoded_cube[i * out_h * out_w + j * out_w + k].unsqueeze(0), + ) out_video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend - - out_video = rearrange(out_video, 'b c t h w -> b t c h w').contiguous() + + out_video = rearrange(out_video, "b c t h w -> b t c h w").contiguous() decoded = out_video if not return_dict: return (decoded,) return DecoderOutput(sample=decoded) - + def forward( self, sample: torch.Tensor, @@ -936,7 +971,7 @@ def forward( Whether to sample from the posterior. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`DecoderOutput`] instead of a plain tuple. - generator (`torch.Generator`, *optional*): + generator (`torch.Generator`, *optional*): PyTorch random number generator. encoder_local_batch_size (`int`, *optional*, defaults to 2): Local batch size for the encoder's batch inference. @@ -962,16 +997,28 @@ def prepare_for_blend(n_param, h_param, w_param, x): h, h_max, overlap_h = h_param w, w_max, overlap_w = w_param if overlap_n > 0: - if n > 0: # the head overlap part decays from 0 to 1 - x[:,:,0:overlap_n,:,:] = x[:,:,0:overlap_n,:,:] * (torch.arange(0, overlap_n).float().to(x.device) / overlap_n).reshape(overlap_n,1,1) - if n < n_max-1: # the tail overlap part decays from 1 to 0 - x[:,:,-overlap_n:,:,:] = x[:,:,-overlap_n:,:,:] * (1 - torch.arange(0, overlap_n).float().to(x.device) / overlap_n).reshape(overlap_n,1,1) + if n > 0: # the head overlap part decays from 0 to 1 + x[:, :, 0:overlap_n, :, :] = x[:, :, 0:overlap_n, :, :] * ( + torch.arange(0, overlap_n).float().to(x.device) / overlap_n + ).reshape(overlap_n, 1, 1) + if n < n_max - 1: # the tail overlap part decays from 1 to 0 + x[:, :, -overlap_n:, :, :] = x[:, :, -overlap_n:, :, :] * ( + 1 - torch.arange(0, overlap_n).float().to(x.device) / overlap_n + ).reshape(overlap_n, 1, 1) if h > 0: - x[:,:,:,0:overlap_h,:] = x[:,:,:,0:overlap_h,:] * (torch.arange(0, overlap_h).float().to(x.device) / overlap_h).reshape(overlap_h,1) - if h < h_max-1: - x[:,:,:,-overlap_h:,:] = x[:,:,:,-overlap_h:,:] * (1 - torch.arange(0, overlap_h).float().to(x.device) / overlap_h).reshape(overlap_h,1) + x[:, :, :, 0:overlap_h, :] = x[:, :, :, 0:overlap_h, :] * ( + torch.arange(0, overlap_h).float().to(x.device) / overlap_h + ).reshape(overlap_h, 1) + if h < h_max - 1: + x[:, :, :, -overlap_h:, :] = x[:, :, :, -overlap_h:, :] * ( + 1 - torch.arange(0, overlap_h).float().to(x.device) / overlap_h + ).reshape(overlap_h, 1) if w > 0: - x[:,:,:,:,0:overlap_w] = x[:,:,:,:,0:overlap_w] * (torch.arange(0, overlap_w).float().to(x.device) / overlap_w) - if w < w_max-1: - x[:,:,:,:,-overlap_w:] = x[:,:,:,:,-overlap_w:] * (1 - torch.arange(0, overlap_w).float().to(x.device) / overlap_w) + x[:, :, :, :, 0:overlap_w] = x[:, :, :, :, 0:overlap_w] * ( + torch.arange(0, overlap_w).float().to(x.device) / overlap_w + ) + if w < w_max - 1: + x[:, :, :, :, -overlap_w:] = x[:, :, :, :, -overlap_w:] * ( + 1 - torch.arange(0, overlap_w).float().to(x.device) / overlap_w + ) return x diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 99389f3ab8f1..ad8f5f43c512 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -565,7 +565,12 @@ def combine_time_height_width(freqs_t, freqs_h, freqs_w): def get_3d_rotary_pos_embed_allegro( - embed_dim, crops_coords, grid_size, temporal_size, interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0), theta: int = 10000 + embed_dim, + crops_coords, + grid_size, + temporal_size, + interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0), + theta: int = 10000, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: # TODO(aryan): docs start, stop = crops_coords @@ -581,10 +586,16 @@ def get_3d_rotary_pos_embed_allegro( dim_w = embed_dim // 3 # Temporal frequencies - freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t / interpolation_scale_t, theta=theta, use_real=True, repeat_interleave_real=False) + freqs_t = get_1d_rotary_pos_embed( + dim_t, grid_t / interpolation_scale_t, theta=theta, use_real=True, repeat_interleave_real=False + ) # Spatial frequencies for height and width - freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h / interpolation_scale_h, theta=theta, use_real=True, repeat_interleave_real=False) - freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w / interpolation_scale_w, theta=theta, use_real=True, repeat_interleave_real=False) + freqs_h = get_1d_rotary_pos_embed( + dim_h, grid_h / interpolation_scale_h, theta=theta, use_real=True, repeat_interleave_real=False + ) + freqs_w = get_1d_rotary_pos_embed( + dim_w, grid_w / interpolation_scale_w, theta=theta, use_real=True, repeat_interleave_real=False + ) return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w @@ -773,7 +784,7 @@ def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions): def apply_1d_rope(tokens, pos, cos, sin): cos = F.embedding(pos, cos)[:, None, :, :] sin = F.embedding(pos, sin)[:, None, :, :] - x1, x2 = tokens[..., : tokens.shape[-1] // 2], tokens[..., tokens.shape[-1] // 2:] + x1, x2 = tokens[..., : tokens.shape[-1] // 2], tokens[..., tokens.shape[-1] // 2 :] tokens_rotated = torch.cat((-x2, x1), dim=-1) return (tokens.float() * cos + tokens_rotated.float() * sin).to(tokens.dtype) diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 86a13cc5d582..14fde251c601 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -13,35 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json -import os -from dataclasses import dataclass -from functools import partial -from importlib import import_module -from typing import Any, Callable, Dict, Optional, Tuple - -import numpy as np +from typing import Optional, Tuple + import torch -import collections +import torch.nn as nn import torch.nn.functional as F -from torch.nn.attention import SDPBackend, sdpa_kernel +from einops import rearrange + from ...configuration_utils import ConfigMixin, register_to_config -from ..activations import GEGLU, GELU, ApproximateGELU +from ...utils import logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import FeedForward from ..attention_processor import ( - Attention, AllegroAttnProcessor2_0, + Attention, ) -from ..embeddings import PixArtAlphaTextProjection, SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps, PatchEmbed +from ..embeddings import PixArtAlphaTextProjection +from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNorm, AdaLayerNormZero -from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_xformers_available -from ...utils.torch_utils import maybe_allow_in_graph -from einops import rearrange, repeat -import torch.nn as nn from ..normalization import AllegroAdaLayerNormSingle -from ..modeling_outputs import Transformer2DModelOutput -from ..attention import FeedForward -from ...utils import logging + logger = logging.get_logger(__name__) @@ -51,7 +42,7 @@ class PatchEmbed2D(nn.Module): def __init__( self, - num_frames=1, + num_frames=1, height=224, width=224, patch_size_t=1, @@ -61,7 +52,7 @@ def __init__( layer_norm=False, flatten=True, bias=True, - use_abs_pos=False, + use_abs_pos=False, ): super().__init__() self.use_abs_pos = use_abs_pos @@ -83,7 +74,7 @@ def forward(self, latent): b, _, _, _, _ = latent.shape video_latent = None - latent = rearrange(latent, 'b c t h w -> (b t) c h w') + latent = rearrange(latent, "b c t h w -> (b t) c h w") latent = self.proj(latent) if self.flatten: @@ -91,7 +82,7 @@ def forward(self, latent): if self.layer_norm: latent = self.norm(latent) - latent = rearrange(latent, '(b t) n c -> b (t n) c', b=b) + latent = rearrange(latent, "(b t) n c -> b (t n) c", b=b) video_latent = latent return video_latent @@ -167,7 +158,7 @@ def forward( temb: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb = None, + image_rotary_emb=None, ) -> torch.Tensor: # 0. Self-Attention batch_size = hidden_states.shape[0] @@ -178,7 +169,7 @@ def forward( norm_hidden_states = self.norm1(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa norm_hidden_states = norm_hidden_states.squeeze(1) - + attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=None, @@ -249,46 +240,45 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin): Configure if the `TransformerBlocks` attention should contain a bias parameter. """ -# { -# "_class_name": "AllegroTransformer3DModel", -# "_diffusers_version": "0.30.3", -# "_name_or_path": "/cpfs/data/user/larrytsai/Projects/Yi-VG/allegro/transformer", -# "activation_fn": "gelu-approximate", -# "attention_bias": true, -# "attention_head_dim": 96, -# "ca_attention_mode": "xformers", -# "caption_channels": 4096, -# "cross_attention_dim": 2304, -# "double_self_attention": false, -# "downsampler": null, -# "dropout": 0.0, -# "in_channels": 4, -# "interpolation_scale_h": 2.0, -# "interpolation_scale_t": 2.2, -# "interpolation_scale_w": 2.0, -# "model_max_length": 300, -# "norm_elementwise_affine": false, -# "norm_eps": 1e-06, -# "norm_type": "ada_norm_single", -# "num_attention_heads": 24, -# "num_embeds_ada_norm": 1000, -# "num_layers": 32, -# "only_cross_attention": false, -# "out_channels": 4, -# "patch_size": 2, -# "patch_size_t": 1, -# "sa_attention_mode": "flash", -# "sample_size": [ -# 90, -# 160 -# ], -# "sample_size_t": 22, -# "upcast_attention": false, -# "use_additional_conditions": null, -# "use_linear_projection": false, -# "use_rope": true -# } - + # { + # "_class_name": "AllegroTransformer3DModel", + # "_diffusers_version": "0.30.3", + # "_name_or_path": "/cpfs/data/user/larrytsai/Projects/Yi-VG/allegro/transformer", + # "activation_fn": "gelu-approximate", + # "attention_bias": true, + # "attention_head_dim": 96, + # "ca_attention_mode": "xformers", + # "caption_channels": 4096, + # "cross_attention_dim": 2304, + # "double_self_attention": false, + # "downsampler": null, + # "dropout": 0.0, + # "in_channels": 4, + # "interpolation_scale_h": 2.0, + # "interpolation_scale_t": 2.2, + # "interpolation_scale_w": 2.0, + # "model_max_length": 300, + # "norm_elementwise_affine": false, + # "norm_eps": 1e-06, + # "norm_type": "ada_norm_single", + # "num_attention_heads": 24, + # "num_embeds_ada_norm": 1000, + # "num_layers": 32, + # "only_cross_attention": false, + # "out_channels": 4, + # "patch_size": 2, + # "patch_size_t": 1, + # "sa_attention_mode": "flash", + # "sample_size": [ + # 90, + # 160 + # ], + # "sample_size_t": 22, + # "upcast_attention": false, + # "use_additional_conditions": null, + # "use_linear_projection": false, + # "use_rope": true + # } @register_to_config def __init__( @@ -318,15 +308,19 @@ def __init__( model_max_length: int = 300, ): super().__init__() - + self.inner_dim = num_attention_heads * attention_head_dim - + interpolation_scale_t = ( - interpolation_scale_t if interpolation_scale_t is not None else ((sample_frames - 1) // 16 + 1) if sample_frames % 2 == 1 else sample_frames // 16 + interpolation_scale_t + if interpolation_scale_t is not None + else ((sample_frames - 1) // 16 + 1) + if sample_frames % 2 == 1 + else sample_frames // 16 ) interpolation_scale_h = interpolation_scale_h if interpolation_scale_h is not None else sample_height / 30 interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40 - + # 1. Patch embedding self.pos_embed = PatchEmbed2D( height=sample_height, @@ -365,10 +359,8 @@ def __init__( self.adaln_single = AllegroAdaLayerNormSingle(self.inner_dim, use_additional_conditions=False) # 5. Caption projection - self.caption_projection = PixArtAlphaTextProjection( - in_features=caption_channels, hidden_size=self.inner_dim - ) - + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim) + self.gradient_checkpointing = False def _set_gradient_checkpointing(self, module, value=False): @@ -413,23 +405,31 @@ def forward( attention_mask = F.max_pool3d(attention_mask, kernel_size=(p_t, p, p), stride=(p_t, p, p)) attention_mask = attention_mask.flatten(1).view(batch_size, 1, -1) - attention_mask = (1 - attention_mask.bool().to(hidden_states.dtype)) * -10000.0 if attention_mask.numel() > 0 else None + attention_mask = ( + (1 - attention_mask.bool().to(hidden_states.dtype)) * -10000.0 if attention_mask.numel() > 0 else None + ) # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # b, 1+use_image_num, l -> a video with images # b, 1, l -> only images encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 - encoder_attention_mask = rearrange(encoder_attention_mask, 'b 1 l -> (b 1) 1 l') if encoder_attention_mask.numel() > 0 else None + encoder_attention_mask = ( + rearrange(encoder_attention_mask, "b 1 l -> (b 1) 1 l") if encoder_attention_mask.numel() > 0 else None + ) # 1. Input post_patch_num_frames = num_frames // self.config.patch_size_temporal post_patch_height = height // self.config.patch_size post_patch_width = width // self.config.patch_size - timestep, embedded_timestep = self.adaln_single(timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype) - - hidden_states = self.pos_embed(hidden_states) # TODO(aryan): remove dtype conversion here and move to pipeline if needed + timestep, embedded_timestep = self.adaln_single( + timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + hidden_states = self.pos_embed( + hidden_states + ) # TODO(aryan): remove dtype conversion here and move to pipeline if needed encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1]) @@ -455,7 +455,9 @@ def forward( hidden_states = hidden_states.squeeze(1) # unpatchify - hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1) + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1 + ) hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) output = hidden_states.reshape(batch_size, -1, num_frames, height, width) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 6a0a71275012..b73b2b9e7f84 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -19,11 +19,13 @@ import re import urllib.parse as ul from typing import Callable, List, Optional, Tuple, Union + import torch -from dataclasses import dataclass -from transformers import T5EncoderModel, T5Tokenizer import tqdm +from transformers import T5EncoderModel, T5Tokenizer +from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro +from ...models.embeddings import get_3d_rotary_pos_embed_allegro from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -32,13 +34,11 @@ is_ftfy_available, logging, replace_example_docstring, - BaseOutput ) from ...utils.torch_utils import randn_tensor -from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro -from ...models.embeddings import get_3d_rotary_pos_embed_allegro -from .pipeline_output import AllegroPipelineOutput from ...video_processor import VideoProcessor +from .pipeline_output import AllegroPipelineOutput + logger = logging.get_logger(__name__) @@ -55,13 +55,16 @@ >>> import torch >>> # You can replace the your_path_to_model with your own path. - >>> pipe = AllegroPipeline.from_pretrained(your_path_to_model, torch_dtype=torch.float16, trust_remote_code=True) + >>> pipe = AllegroPipeline.from_pretrained( + ... your_path_to_model, torch_dtype=torch.float16, trust_remote_code=True + ... ) >>> prompt = "A small cactus with a happy face in the Sahara desert." >>> image = pipe(prompt).video[0] ``` """ + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -148,8 +151,21 @@ class AllegroPipeline(DiffusionPipeline): scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. """ + bad_punct_regex = re.compile( - r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" ) # noqa _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"] @@ -219,7 +235,6 @@ def encode_prompt( If `True`, the function will preprocess and clean the provided caption before encoding. max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. """ - embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None if device is None: device = self._execution_device @@ -339,7 +354,7 @@ def prepare_extra_step_kwargs(self, generator, eta): def check_inputs( self, prompt, - num_frames, + num_frames, height, width, negative_prompt, @@ -349,7 +364,6 @@ def check_inputs( prompt_attention_mask=None, negative_prompt_attention_mask=None, ): - if num_frames <= 0: raise ValueError(f"`num_frames` have to be positive but is {num_frames}.") if height % 8 != 0 or width % 8 != 0: @@ -406,7 +420,6 @@ def check_inputs( f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" f" {negative_prompt_attention_mask.shape}." ) - # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing def _text_preprocessing(self, text, clean_caption=False): @@ -549,7 +562,7 @@ def _clean_caption(self, caption): caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) caption = re.sub(r"^\.\S+$", "", caption) return caption.strip() - + def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): @@ -558,7 +571,7 @@ def prepare_latents( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - + if num_frames % 2 == 0: num_frames = math.ceil(num_frames / self.vae_scale_factor_temporal) else: @@ -586,7 +599,7 @@ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: frames = self.vae.decode(latents).sample frames = frames.permute(0, 2, 1, 3, 4) # [batch_size, channels, num_frames, height, width] return frames - + def _prepare_rotary_positional_embeddings( self, batch_size: int, @@ -612,7 +625,11 @@ def _prepare_rotary_positional_embeddings( crops_coords=grid_crops_coords, grid_size=(grid_height, grid_width), temporal_size=num_frames, - interpolation_scale=(self.transformer.config.interpolation_scale_t, self.transformer.config.interpolation_scale_h, self.transformer.config.interpolation_scale_w) + interpolation_scale=( + self.transformer.config.interpolation_scale_t, + self.transformer.config.interpolation_scale_h, + self.transformer.config.interpolation_scale_w, + ), ) grid_t = torch.from_numpy(grid_t).to(device=device, dtype=torch.long) @@ -738,7 +755,7 @@ def __call__( self.check_inputs( prompt, - num_frames, + num_frames, height, width, negative_prompt, @@ -796,7 +813,7 @@ def __call__( latents = self.prepare_latents( batch_size * num_images_per_prompt, latent_channels, - num_frames, + num_frames, height, width, prompt_embeds.dtype, @@ -809,14 +826,15 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Prepare rotary embeddings - image_rotary_emb = self._prepare_rotary_positional_embeddings(batch_size, height, width, latents.size(2), device) + image_rotary_emb = self._prepare_rotary_positional_embeddings( + batch_size, height, width, latents.size(2), device + ) # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) progress_wrap = tqdm.tqdm if verbose else (lambda x: x) for i, t in progress_wrap(list(enumerate(timesteps))): - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -839,15 +857,15 @@ def __call__( prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d if prompt_attention_mask.ndim == 2: prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l - + # prepare attention_mask. # b c t h w -> b t h w attention_mask = torch.ones_like(latent_model_input)[:, 0] - + # predict noise model_output noise_pred = self.transformer( latent_model_input, - attention_mask=attention_mask, + attention_mask=attention_mask, encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, timestep=current_timestep, diff --git a/src/diffusers/pipelines/allegro/pipeline_output.py b/src/diffusers/pipelines/allegro/pipeline_output.py index ed8ca1862540..6a721783ca86 100644 --- a/src/diffusers/pipelines/allegro/pipeline_output.py +++ b/src/diffusers/pipelines/allegro/pipeline_output.py @@ -2,8 +2,8 @@ from typing import List, Union import numpy as np -import torch import PIL +import torch from diffusers.utils import BaseOutput From 4f1653c2659cc095973dfb83f0b0e29318db53c2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 22 Oct 2024 02:25:00 +0200 Subject: [PATCH 06/33] refactor part 4; modeling tests --- .../transformers/transformer_allegro.py | 46 +---------- .../pipelines/allegro/pipeline_allegro.py | 2 - .../test_models_transformer_allegro.py | 79 +++++++++++++++++++ tests/pipelines/allegro/__init__.py | 0 tests/pipelines/allegro/test_allegro.py | 0 5 files changed, 81 insertions(+), 46 deletions(-) create mode 100644 tests/models/transformers/test_models_transformer_allegro.py create mode 100644 tests/pipelines/allegro/__init__.py create mode 100644 tests/pipelines/allegro/test_allegro.py diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 14fde251c601..cedc9bedfa65 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -239,47 +239,7 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin): attention_bias (`bool`, *optional*): Configure if the `TransformerBlocks` attention should contain a bias parameter. """ - - # { - # "_class_name": "AllegroTransformer3DModel", - # "_diffusers_version": "0.30.3", - # "_name_or_path": "/cpfs/data/user/larrytsai/Projects/Yi-VG/allegro/transformer", - # "activation_fn": "gelu-approximate", - # "attention_bias": true, - # "attention_head_dim": 96, - # "ca_attention_mode": "xformers", - # "caption_channels": 4096, - # "cross_attention_dim": 2304, - # "double_self_attention": false, - # "downsampler": null, - # "dropout": 0.0, - # "in_channels": 4, - # "interpolation_scale_h": 2.0, - # "interpolation_scale_t": 2.2, - # "interpolation_scale_w": 2.0, - # "model_max_length": 300, - # "norm_elementwise_affine": false, - # "norm_eps": 1e-06, - # "norm_type": "ada_norm_single", - # "num_attention_heads": 24, - # "num_embeds_ada_norm": 1000, - # "num_layers": 32, - # "only_cross_attention": false, - # "out_channels": 4, - # "patch_size": 2, - # "patch_size_t": 1, - # "sa_attention_mode": "flash", - # "sample_size": [ - # 90, - # 160 - # ], - # "sample_size_t": 22, - # "upcast_attention": false, - # "use_additional_conditions": null, - # "use_linear_projection": false, - # "use_rope": true - # } - + @register_to_config def __init__( self, @@ -304,8 +264,6 @@ def __init__( interpolation_scale_h: float = 2.0, interpolation_scale_w: float = 2.0, interpolation_scale_t: float = 2.2, - use_rotary_positional_embeddings: bool = True, - model_max_length: int = 300, ): super().__init__() @@ -369,8 +327,8 @@ def _set_gradient_checkpointing(self, module, value=False): def forward( self, hidden_states: torch.Tensor, - timestep: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index b73b2b9e7f84..4a1bdf139dab 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -193,7 +193,6 @@ def __init__( self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], @@ -207,7 +206,6 @@ def encode_prompt( negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, clean_caption: bool = False, max_sequence_length: int = 300, - **kwargs, ): r""" Encodes the prompt into text encoder hidden states. diff --git a/tests/models/transformers/test_models_transformer_allegro.py b/tests/models/transformers/test_models_transformer_allegro.py new file mode 100644 index 000000000000..6e23649176c2 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_allegro.py @@ -0,0 +1,79 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import AllegroTransformer3DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class AllegroTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = AllegroTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + num_frames = 8 + height = 8 + width = 8 + embedding_dim = 16 + sequence_length = 16 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim // 2)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + } + + @property + def input_shape(self): + return (4, 8, 8, 8) + + @property + def output_shape(self): + return (4, 8, 8, 8) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings. + "num_attention_heads": 2, + "attention_head_dim": 8, + "in_channels": 4, + "out_channels": 4, + "num_layers": 1, + "cross_attention_dim": 16, + "sample_width": 8, + "sample_height": 8, + "sample_frames": 8, + "caption_channels": 8, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict diff --git a/tests/pipelines/allegro/__init__.py b/tests/pipelines/allegro/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py new file mode 100644 index 000000000000..e69de29bb2d1 From 412cd7cf848716da8aae261baa6b4ddca585dac8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 22 Oct 2024 02:26:55 +0200 Subject: [PATCH 07/33] make style --- src/diffusers/models/transformers/transformer_allegro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index cedc9bedfa65..e0324e4495ed 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -239,7 +239,7 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin): attention_bias (`bool`, *optional*): Configure if the `TransformerBlocks` attention should contain a bias parameter. """ - + @register_to_config def __init__( self, From 8f9ffa8f615c327c80390a2cb2f9a675a3421492 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 22 Oct 2024 02:59:12 +0200 Subject: [PATCH 08/33] refactor part 5 --- .../autoencoders/autoencoder_kl_allegro.py | 15 ++-- .../transformers/transformer_allegro.py | 81 +++---------------- .../pipelines/allegro/pipeline_allegro.py | 24 +----- 3 files changed, 20 insertions(+), 100 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index 33e4cfbb1b35..dab92736694d 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -18,7 +18,6 @@ import torch import torch.nn as nn -from einops import rearrange from ...configuration_utils import ConfigMixin, register_to_config from ..attention_processor import Attention, SpatialNorm @@ -114,7 +113,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.conv1(hidden_states) if self.up_sample: - hidden_states = rearrange(hidden_states, "b (d c) f h w -> b c (f d) h w", d=2) + hidden_states = hidden_states.unflatten(1, (2, -1)).permute(0, 2, 3, 1, 4, 5).flatten(2, 3) hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2) hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2) @@ -858,10 +857,10 @@ def encode( ) out_video_cube[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean_blend - ## final conv - out_video_cube = rearrange(out_video_cube, "b c n h w -> (b n) c h w") + # final conv + out_video_cube = out_video_cube.permute(0, 2, 1, 3, 4).flatten(0, 1) out_video_cube = self.quant_conv(out_video_cube) - out_video_cube = rearrange(out_video_cube, "(b n) c h w -> b c n h w", b=B) + out_video_cube = out_video_cube.unflatten(0, (B, -1)).permute(0, 2, 1, 3, 4) posterior = DiagonalGaussianDistribution(out_video_cube) @@ -885,9 +884,9 @@ def decode( B, C, N, H, W = input_latents.shape ## post quant conv (a mapping) - input_latents = rearrange(input_latents, "b c n h w -> (b n) c h w") + input_latents = input_latents.permute(0, 2, 1, 3, 4).flatten(0, 1) input_latents = self.post_quant_conv(input_latents) - input_latents = rearrange(input_latents, "(b n) c h w -> b c n h w", b=B) + input_latents = input_latents.unflatten(0, (B, -1)).permute(0, 2, 1, 3, 4) ## out tensor shape out_n = math.floor((N - IN_KERNEL[0]) / IN_STRIDE[0]) + 1 @@ -947,7 +946,7 @@ def decode( ) out_video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend - out_video = rearrange(out_video, "b c t h w -> b t c h w").contiguous() + out_video = out_video.permute(0, 2, 1, 3, 4).contiguous() decoded = out_video if not return_dict: diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index e0324e4495ed..8e26df15a271 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -18,17 +18,13 @@ import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward -from ..attention_processor import ( - AllegroAttnProcessor2_0, - Attention, -) -from ..embeddings import PixArtAlphaTextProjection +from ..attention_processor import AllegroAttnProcessor2_0, Attention +from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AllegroAdaLayerNormSingle @@ -37,57 +33,6 @@ logger = logging.get_logger(__name__) -class PatchEmbed2D(nn.Module): - """2D Image to Patch Embedding""" - - def __init__( - self, - num_frames=1, - height=224, - width=224, - patch_size_t=1, - patch_size=16, - in_channels=3, - embed_dim=768, - layer_norm=False, - flatten=True, - bias=True, - use_abs_pos=False, - ): - super().__init__() - self.use_abs_pos = use_abs_pos - self.flatten = flatten - self.layer_norm = layer_norm - - self.proj = nn.Conv2d( - in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias - ) - if layer_norm: - self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) - else: - self.norm = None - - self.patch_size_t = patch_size_t - self.patch_size = patch_size - - def forward(self, latent): - b, _, _, _, _ = latent.shape - video_latent = None - - latent = rearrange(latent, "b c t h w -> (b t) c h w") - - latent = self.proj(latent) - if self.flatten: - latent = latent.flatten(2).transpose(1, 2) # BT C H W -> BT N C - if self.layer_norm: - latent = self.norm(latent) - - latent = rearrange(latent, "(b t) n c -> b (t n) c", b=b) - video_latent = latent - - return video_latent - - @maybe_allow_in_graph class AllegroTransformerBlock(nn.Module): r""" @@ -280,13 +225,13 @@ def __init__( interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40 # 1. Patch embedding - self.pos_embed = PatchEmbed2D( + self.pos_embed = PatchEmbed( height=sample_height, width=sample_width, patch_size=patch_size, in_channels=in_channels, embed_dim=self.inner_dim, - # pos_embed_type=None, + pos_embed_type=None, ) # 2. Transformer blocks @@ -327,8 +272,8 @@ def _set_gradient_checkpointing(self, module, value=False): def forward( self, hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - timestep: Optional[torch.LongTensor] = None, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, @@ -368,13 +313,9 @@ def forward( ) # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: - # b, 1+use_image_num, l -> a video with images - # b, 1, l -> only images + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 - encoder_attention_mask = ( - rearrange(encoder_attention_mask, "b 1 l -> (b 1) 1 l") if encoder_attention_mask.numel() > 0 else None - ) + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 1. Input post_patch_num_frames = num_frames // self.config.patch_size_temporal @@ -385,9 +326,9 @@ def forward( timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype ) - hidden_states = self.pos_embed( - hidden_states - ) # TODO(aryan): remove dtype conversion here and move to pipeline if needed + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.pos_embed(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1]) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 4a1bdf139dab..d4de3dae9206 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -836,25 +836,11 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - current_timestep = t - if not torch.is_tensor(current_timestep): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = latent_model_input.device.type == "mps" - if isinstance(current_timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) - elif len(current_timestep.shape) == 0: - current_timestep = current_timestep[None].to(latent_model_input.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - current_timestep = current_timestep.expand(latent_model_input.shape[0]) + timestep = t.expand(latent_model_input.shape[0]) if prompt_embeds.ndim == 3: prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d - if prompt_attention_mask.ndim == 2: - prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l # prepare attention_mask. # b c t h w -> b t h w @@ -866,7 +852,7 @@ def __call__( attention_mask=attention_mask, encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, - timestep=current_timestep, + timestep=timestep, image_rotary_emb=image_rotary_emb, return_dict=False, )[0] @@ -876,12 +862,6 @@ def __call__( noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # learned sigma - if latent_channels == self.transformer.config.out_channels // 2: - noise_pred = noise_pred.chunk(2, dim=1)[0] - else: - noise_pred = noise_pred - # compute previous image: x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] From c76dc5a0a3e8a57dfd0251a395e6eab803fda765 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 22 Oct 2024 03:12:36 +0200 Subject: [PATCH 09/33] refactor part 6 --- .../pipelines/allegro/pipeline_allegro.py | 161 ++++++++++-------- .../pipelines/cogvideo/pipeline_cogvideox.py | 2 +- 2 files changed, 92 insertions(+), 71 deletions(-) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index d4de3dae9206..f77e27d35ea2 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -18,12 +18,12 @@ import math import re import urllib.parse as ul -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch -import tqdm from transformers import T5EncoderModel, T5Tokenizer +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro from ...models.embeddings import get_3d_rotary_pos_embed_allegro from ...pipelines.pipeline_utils import DiffusionPipeline @@ -171,6 +171,12 @@ class AllegroPipeline(DiffusionPipeline): _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"] model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + def __init__( self, tokenizer: T5Tokenizer, @@ -198,7 +204,7 @@ def encode_prompt( prompt: Union[str, List[str]], do_classifier_free_guidance: bool = True, negative_prompt: str = "", - num_images_per_prompt: int = 1, + num_videos_per_prompt: int = 1, device: Optional[torch.device] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, @@ -286,10 +292,10 @@ def encode_prompt( bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) - prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -320,11 +326,11 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) - negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1) else: negative_prompt_embeds = None negative_prompt_attention_mask = None @@ -355,8 +361,8 @@ def check_inputs( num_frames, height, width, - negative_prompt, - callback_steps, + callback_on_step_end_tensor_inputs, + negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, prompt_attention_mask=None, @@ -367,12 +373,11 @@ def check_inputs( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: @@ -606,20 +611,16 @@ def _prepare_rotary_positional_embeddings( num_frames: int, device: torch.device, ): - attention_head_dim = 96 - vae_scale_factor_spatial = 8 - patch_size = 2 - - grid_height = height // (vae_scale_factor_spatial * patch_size) - grid_width = width // (vae_scale_factor_spatial * patch_size) - base_size_width = 1280 // (vae_scale_factor_spatial * patch_size) - base_size_height = 720 // (vae_scale_factor_spatial * patch_size) + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 1280 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) grid_crops_coords = get_resize_crop_region_for_grid( (grid_height, grid_width), base_size_width, base_size_height ) freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w = get_3d_rotary_pos_embed_allegro( - embed_dim=attention_head_dim, + embed_dim=self.transformer.config.attention_head_dim, crops_coords=grid_crops_coords, grid_size=(grid_height, grid_width), temporal_size=num_frames, @@ -653,10 +654,10 @@ def __call__( num_inference_steps: int = 100, timesteps: List[int] = None, guidance_scale: float = 7.5, - num_images_per_prompt: Optional[int] = 1, num_frames: Optional[int] = None, height: Optional[int] = None, width: Optional[int] = None, + num_videos_per_prompt: int = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, @@ -666,11 +667,12 @@ def __call__( negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], clean_caption: bool = True, max_sequence_length: int = 300, - verbose: bool = True, ) -> Union[AllegroPipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -746,6 +748,12 @@ def __call__( If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + num_videos_per_prompt = 1 + # 1. Check inputs. Raise error if not correct num_frames = num_frames or self.transformer.config.sample_size_t * self.vae_scale_factor_temporal height = height or self.transformer.config.sample_size[0] * self.vae_scale_factor_spatial @@ -756,13 +764,15 @@ def __call__( num_frames, height, width, + callback_on_step_end_tensor_inputs, negative_prompt, - callback_steps, prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask, ) + self._guidance_scale = guidance_scale + self._interrupt = False # 2. Default height and width to transformer if prompt is not None and isinstance(prompt, str): @@ -789,7 +799,7 @@ def __call__( prompt, do_classifier_free_guidance, negative_prompt=negative_prompt, - num_images_per_prompt=num_images_per_prompt, + num_videos_per_prompt=num_videos_per_prompt, device=device, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, @@ -809,7 +819,7 @@ def __call__( # 5. Prepare latents. latent_channels = self.transformer.config.in_channels latents = self.prepare_latents( - batch_size * num_images_per_prompt, + batch_size * num_videos_per_prompt, latent_channels, num_frames, height, @@ -831,45 +841,56 @@ def __call__( # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - progress_wrap = tqdm.tqdm if verbose else (lambda x: x) - for i, t in progress_wrap(list(enumerate(timesteps))): - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]) - - if prompt_embeds.ndim == 3: - prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d - - # prepare attention_mask. - # b c t h w -> b t h w - attention_mask = torch.ones_like(latent_model_input)[:, 0] - - # predict noise model_output - noise_pred = self.transformer( - latent_model_input, - attention_mask=attention_mask, - encoder_hidden_states=prompt_embeds, - encoder_attention_mask=prompt_attention_mask, - timestep=timestep, - image_rotary_emb=image_rotary_emb, - return_dict=False, - )[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute previous image: x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + if prompt_embeds.ndim == 3: + prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d + + # prepare attention_mask. + # b c t h w -> b t h w + attention_mask = torch.ones_like(latent_model_input)[:, 0] + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + attention_mask=attention_mask, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() if not output_type == "latent": latents = latents.to(self.vae.dtype) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 9cb042c9e80c..16150cbd79bf 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -649,8 +649,8 @@ def __call__( height, width, prompt_embeds.dtype, - device, generator, + device, latents, ) From 015cc78bffc1330c258a91d21e5e0633cfffd3bb Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 22 Oct 2024 03:27:32 +0200 Subject: [PATCH 10/33] gradient checkpointing --- .../autoencoders/autoencoder_kl_allegro.py | 62 +++++++++++++++---- .../transformers/transformer_allegro.py | 40 +++++++++--- 2 files changed, 80 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index dab92736694d..09d82a1b0f47 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -512,12 +512,26 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: sample = self.temp_conv_in(sample) sample = sample + residual - # Down blocks - for down_block in self.down_blocks: - sample = down_block(sample) + if self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) - # Mid block - sample = self.mid_block(sample) + return custom_forward + + # Down blocks + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) + + # Mid block + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + else: + # Down blocks + for down_block in self.down_blocks: + sample = down_block(sample) + + # Mid block + sample = self.mid_block(sample) # Post process sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) @@ -625,7 +639,6 @@ def __init__( self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3, 1, 1), padding=(1, 0, 0)) self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) - # TODO(aryan): implement gradient checkpointing self.gradient_checkpointing = False def forward(self, sample: torch.Tensor) -> torch.Tensor: @@ -641,13 +654,34 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - # Mid block - sample = self.mid_block(sample) - sample = sample.to(upscale_dtype) + if self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # Mid block + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + sample + ) + + # Up blocks + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + sample + ) - # Up blocks - for up_block in self.up_blocks: - sample = up_block(sample) + else: + # Mid block + sample = self.mid_block(sample) + sample = sample.to(upscale_dtype) + + # Up blocks + for up_block in self.up_blocks: + sample = up_block(sample) # Post process sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) @@ -783,6 +817,10 @@ def __init__( self.sample_size - self.tile_overlap[1], ) # (16, 112, 192) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)): + module.gradient_checkpointing = value + def encode( self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch_size=1 ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 8e26df15a271..2b7ea524c763 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import logging +from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import AllegroAttnProcessor2_0, Attention @@ -335,14 +335,34 @@ def forward( for i, block in enumerate(self.transformer_blocks): # TODO(aryan): Implement gradient checkpointing - hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=timestep, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - image_rotary_emb=image_rotary_emb, - ) + if self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + timestep, + attention_mask, + encoder_attention_mask, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=timestep, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + image_rotary_emb=image_rotary_emb, + ) # 3. Output shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) From 6b53b8596a5d3ad129fca42de29f30e6aad63f11 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 22 Oct 2024 04:01:11 +0200 Subject: [PATCH 11/33] pipeline tests (broken atm) --- .../test_models_transformer_allegro.py | 6 +- tests/pipelines/allegro/test_allegro.py | 311 ++++++++++++++++++ 2 files changed, 314 insertions(+), 3 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_allegro.py b/tests/models/transformers/test_models_transformer_allegro.py index 6e23649176c2..ad8b7a3824ba 100644 --- a/tests/models/transformers/test_models_transformer_allegro.py +++ b/tests/models/transformers/test_models_transformer_allegro.py @@ -37,7 +37,7 @@ class AllegroTransformerTests(ModelTesterMixin, unittest.TestCase): def dummy_input(self): batch_size = 2 num_channels = 4 - num_frames = 8 + num_frames = 2 height = 8 width = 8 embedding_dim = 16 @@ -55,11 +55,11 @@ def dummy_input(self): @property def input_shape(self): - return (4, 8, 8, 8) + return (4, 2, 8, 8) @property def output_shape(self): - return (4, 8, 8, 8) + return (4, 2, 8, 8) def prepare_init_args_and_inputs_for_common(self): init_dict = { diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py index e69de29bb2d1..daac4e4136b6 100644 --- a/tests/pipelines/allegro/test_allegro.py +++ b/tests/pipelines/allegro/test_allegro.py @@ -0,0 +1,311 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import inspect +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro, DDIMScheduler +from diffusers.utils.testing_utils import ( + enable_full_determinism, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = AllegroPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = AllegroTransformer3DModel( + num_attention_heads=2, + attention_head_dim=12, + in_channels=4, + out_channels=4, + num_layers=1, + cross_attention_dim=32, + sample_width=8, + sample_height=8, + sample_frames=8, + caption_channels=32, + ) + + torch.manual_seed(0) + vae = AutoencoderKLAllegro( + in_channels=3, + out_channels=3, + down_block_types=( + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + ), + up_block_types=( + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + ), + block_out_channels=(8, 8, 8, 8), + latent_channels=4, + layers_per_block=1, + norm_num_groups=2, + temporal_compression_ratio=4, + ) + + torch.manual_seed(0) + scheduler = DDIMScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "dance monkey", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 48, + "width": 48, + "num_frames": 8, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (8, 3, 16, 16)) + expected_video = torch.randn(8, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + # def test_attention_slicing_forward_pass( + # self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + # ): + # if not self.test_attention_slicing: + # return + + # components = self.get_dummy_components() + # pipe = self.pipeline_class(**components) + # for component in pipe.components.values(): + # if hasattr(component, "set_default_attn_processor"): + # component.set_default_attn_processor() + # pipe.to(torch_device) + # pipe.set_progress_bar_config(disable=None) + + # generator_device = "cpu" + # inputs = self.get_dummy_inputs(generator_device) + # output_without_slicing = pipe(**inputs)[0] + + # pipe.enable_attention_slicing(slice_size=1) + # inputs = self.get_dummy_inputs(generator_device) + # output_with_slicing1 = pipe(**inputs)[0] + + # pipe.enable_attention_slicing(slice_size=2) + # inputs = self.get_dummy_inputs(generator_device) + # output_with_slicing2 = pipe(**inputs)[0] + + # if test_max_difference: + # max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + # max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + # self.assertLess( + # max(max_diff1, max_diff2), + # expected_max_diff, + # "Attention slicing should not affect the inference results", + # ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_overlap_factor_height=1 / 12, + tile_overlap_factor_width=1 / 12, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + +@slow +@require_torch_gpu +class AllegroPipelineIntegrationTests(unittest.TestCase): + prompt = "A painting of a squirrel eating a burger." + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_cogvideox(self): + generator = torch.Generator("cpu").manual_seed(0) + + pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.float16) + pipe.enable_model_cpu_offload() + prompt = self.prompt + + videos = pipe( + prompt=prompt, + height=720, + width=1280, + num_frames=88, + generator=generator, + num_inference_steps=2, + output_type="pt", + ).frames + + video = videos[0] + expected_video = torch.randn(1, 88, 720, 1280, 3).numpy() + + max_diff = numpy_cosine_similarity_distance(video, expected_video) + assert max_diff < 1e-3, f"Max diff is too high. got {video}" From f64f2d054cf2db25fc757948ef503a3a8e435638 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 22 Oct 2024 04:01:59 +0200 Subject: [PATCH 12/33] update --- .../autoencoders/autoencoder_kl_allegro.py | 51 +++++++++---------- .../pipelines/allegro/pipeline_allegro.py | 12 +++++ 2 files changed, 36 insertions(+), 27 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index 09d82a1b0f47..6509242fdc09 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -29,7 +29,7 @@ from ..upsampling import Upsample2D -class AllegroTemporalConvBlock(nn.Module): +class AllegroTemporalConvLayer(nn.Module): r""" Temporal convolutional layer that can be used for video (sequence of images) input. Code adapted from: https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 @@ -40,6 +40,7 @@ def __init__( in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0, + norm_num_groups: int = 32, up_sample: bool = False, down_sample: bool = False, stride: int = 1, @@ -55,44 +56,40 @@ def __init__( if down_sample: self.conv1 = nn.Sequential( - nn.GroupNorm(32, in_dim), + nn.GroupNorm(norm_num_groups, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (2, stride, stride), stride=(2, 1, 1), padding=(0, pad_h, pad_w)), ) elif up_sample: self.conv1 = nn.Sequential( - nn.GroupNorm(32, in_dim), + nn.GroupNorm(norm_num_groups, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim * 2, (1, stride, stride), padding=(0, pad_h, pad_w)), ) else: self.conv1 = nn.Sequential( - nn.GroupNorm(32, in_dim), + nn.GroupNorm(norm_num_groups, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)), ) self.conv2 = nn.Sequential( - nn.GroupNorm(32, out_dim), + nn.GroupNorm(norm_num_groups, out_dim), nn.SiLU(), nn.Dropout(dropout), nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)), ) self.conv3 = nn.Sequential( - nn.GroupNorm(32, out_dim), + nn.GroupNorm(norm_num_groups, out_dim), nn.SiLU(), nn.Dropout(dropout), nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)), ) self.conv4 = nn.Sequential( - nn.GroupNorm(32, out_dim), + nn.GroupNorm(norm_num_groups, out_dim), nn.SiLU(), nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)), ) - # zero out the last layer params, so the conv block is identity - nn.init.zeros_(self.conv4[-1].weight) - nn.init.zeros_(self.conv4[-1].bias) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: identity = hidden_states @@ -169,10 +166,11 @@ def __init__( ) ) temp_convs.append( - AllegroTemporalConvBlock( + AllegroTemporalConvLayer( out_channels, out_channels, dropout=0.1, + norm_num_groups=resnet_groups, ) ) @@ -180,8 +178,8 @@ def __init__( self.temp_convs = nn.ModuleList(temp_convs) if add_temp_downsample: - self.temp_convs_down = AllegroTemporalConvBlock( - out_channels, out_channels, dropout=0.1, down_sample=True, stride=3 + self.temp_convs_down = AllegroTemporalConvLayer( + out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, down_sample=True, stride=3 ) self.add_temp_downsample = add_temp_downsample @@ -258,10 +256,11 @@ def __init__( ) ) temp_convs.append( - AllegroTemporalConvBlock( + AllegroTemporalConvLayer( out_channels, out_channels, dropout=0.1, + norm_num_groups=resnet_groups, ) ) @@ -270,8 +269,8 @@ def __init__( self.add_temp_upsample = add_temp_upsample if add_temp_upsample: - self.temp_conv_up = AllegroTemporalConvBlock( - out_channels, out_channels, dropout=0.1, up_sample=True, stride=3 + self.temp_conv_up = AllegroTemporalConvLayer( + out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, up_sample=True, stride=3 ) if self.add_upsample: @@ -336,10 +335,11 @@ def __init__( ) ] temp_convs = [ - AllegroTemporalConvBlock( + AllegroTemporalConvLayer( in_channels, in_channels, dropout=0.1, + norm_num_groups=resnet_groups, ) ] attentions = [] @@ -383,10 +383,11 @@ def __init__( ) temp_convs.append( - AllegroTemporalConvBlock( + AllegroTemporalConvLayer( in_channels, in_channels, dropout=0.1, + norm_num_groups=resnet_groups, ) ) @@ -513,6 +514,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: sample = sample + residual if self.gradient_checkpointing: + def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) @@ -655,6 +657,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: upscale_dtype = next(iter(self.up_blocks.parameters())).dtype if self.gradient_checkpointing: + def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) @@ -662,17 +665,11 @@ def custom_forward(*inputs): return custom_forward # Mid block - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), - sample - ) + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) # Up blocks for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), - sample - ) + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample) else: # Mid block diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index f77e27d35ea2..e1e7e9a0f351 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -645,6 +645,18 @@ def _prepare_rotary_positional_embeddings( return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w) + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( From 2ef6a9e8960c39d1115623d2866685493bb3728a Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 23 Oct 2024 00:55:42 +0200 Subject: [PATCH 13/33] add coauthor Co-Authored-By: Huan Yang From e53dac24210cd27317796c5cc885cdce5079ecfe Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 23 Oct 2024 00:56:19 +0200 Subject: [PATCH 14/33] refactor part 7 --- .../autoencoders/autoencoder_kl_allegro.py | 417 ++++++++++++------ src/diffusers/models/embeddings.py | 52 --- src/diffusers/models/normalization.py | 39 +- .../transformers/transformer_allegro.py | 135 +++--- .../pipelines/allegro/pipeline_allegro.py | 16 +- 5 files changed, 365 insertions(+), 294 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index 6509242fdc09..209e1fa01386 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -20,6 +20,7 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config +from ...utils.accelerate_utils import apply_forward_hook from ..attention_processor import Attention, SpatialNorm from ..autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution from ..downsampling import Downsample2D @@ -89,40 +90,43 @@ def __init__( nn.SiLU(), nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)), ) + + @staticmethod + def _pad_temporal_dim(hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2) + hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2) + return hidden_states - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - identity = hidden_states + def forward(self, hidden_states: torch.Tensor, batch_size: int) -> torch.Tensor: + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) if self.down_sample: - identity = identity[:, :, ::2] + identity = hidden_states[:, :, ::2] elif self.up_sample: - hidden_states_new = torch.cat((hidden_states, hidden_states), dim=2) - hidden_states_new[:, :, 0::2] = hidden_states - hidden_states_new[:, :, 1::2] = hidden_states - identity = hidden_states_new - del hidden_states_new + identity = hidden_states.repeat_interleave(2, dim=2) + else: + identity = hidden_states if self.down_sample or self.up_sample: hidden_states = self.conv1(hidden_states) else: - hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2) - hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2) + hidden_states = self._pad_temporal_dim(hidden_states) hidden_states = self.conv1(hidden_states) if self.up_sample: hidden_states = hidden_states.unflatten(1, (2, -1)).permute(0, 2, 3, 1, 4, 5).flatten(2, 3) - hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2) - hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2) + hidden_states = self._pad_temporal_dim(hidden_states) hidden_states = self.conv2(hidden_states) - hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2) - hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2) + + hidden_states = self._pad_temporal_dim(hidden_states) hidden_states = self.conv3(hidden_states) - hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2) - hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2) + + hidden_states = self._pad_temporal_dim(hidden_states) hidden_states = self.conv4(hidden_states) hidden_states = identity + hidden_states + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) return hidden_states @@ -139,10 +143,10 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - add_temp_downsample=False, - downsample_padding=1, + output_scale_factor: float = 1.0, + spatial_downsample: bool = True, + temporal_downsample: bool = False, + downsample_padding: int = 1, ): super().__init__() @@ -177,13 +181,13 @@ def __init__( self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) - if add_temp_downsample: + if temporal_downsample: self.temp_convs_down = AllegroTemporalConvLayer( out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, down_sample=True, stride=3 ) - self.add_temp_downsample = add_temp_downsample + self.add_temp_downsample = temporal_downsample - if add_downsample: + if spatial_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( @@ -196,22 +200,21 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size = hidden_states.shape[0] + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) for resnet, temp_conv in zip(self.resnets, self.temp_convs): - hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = resnet(hidden_states, temb=None) - hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - hidden_states = temp_conv(hidden_states) + hidden_states = temp_conv(hidden_states, batch_size=batch_size) if self.add_temp_downsample: - hidden_states = self.temp_convs_down(hidden_states) + hidden_states = self.temp_convs_down(hidden_states, batch_size=batch_size) if self.downsamplers is not None: - hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - + + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) return hidden_states @@ -227,13 +230,12 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - add_temp_upsample=False, - temb_channels=None, + output_scale_factor: float = 1.0, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + temb_channels: Optional[int] = None, ): super().__init__() - self.add_upsample = add_upsample resnets = [] temp_convs = [] @@ -267,35 +269,34 @@ def __init__( self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) - self.add_temp_upsample = add_temp_upsample - if add_temp_upsample: + self.add_temp_upsample = temporal_upsample + if temporal_upsample: self.temp_conv_up = AllegroTemporalConvLayer( out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, up_sample=True, stride=3 ) - if self.add_upsample: + if spatial_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size = hidden_states.shape[0] + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) for resnet, temp_conv in zip(self.resnets, self.temp_convs): - hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = resnet(hidden_states, temb=None) - hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - hidden_states = temp_conv(hidden_states) + hidden_states = temp_conv(hidden_states, batch_size=batch_size) if self.add_temp_upsample: - hidden_states = self.temp_conv_up(hidden_states) + hidden_states = self.temp_conv_up(hidden_states, batch_size=batch_size) if self.upsamplers is not None: - hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) - hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - + + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) return hidden_states @@ -312,12 +313,10 @@ def __init__( resnet_groups: int = 32, resnet_pre_norm: bool = True, add_attention: bool = True, - attention_head_dim=1, - output_scale_factor=1.0, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, ): super().__init__() - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - self.add_attention = add_attention # there is always at least one resnet resnets = [ @@ -348,7 +347,7 @@ def __init__( attention_head_dim = in_channels for _ in range(num_layers): - if self.add_attention: + if add_attention: attentions.append( Attention( in_channels, @@ -395,21 +394,20 @@ def __init__( self.temp_convs = nn.ModuleList(temp_convs) self.attentions = nn.ModuleList(attentions) - def forward(self, hidden_states: torch.Tensor): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size = hidden_states.shape[0] hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = self.resnets[0](hidden_states, temb=None) - hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - hidden_states = self.temp_convs[0](hidden_states) - hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + + hidden_states = self.temp_convs[0](hidden_states, batch_size=batch_size) for attn, resnet, temp_conv in zip(self.attentions, self.resnets[1:], self.temp_convs[1:]): hidden_states = attn(hidden_states) hidden_states = resnet(hidden_states, temb=None) - hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - hidden_states = temp_conv(hidden_states) + hidden_states = temp_conv(hidden_states, batch_size=batch_size) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) return hidden_states @@ -424,8 +422,8 @@ def __init__( "AllegroDownBlock3D", "AllegroDownBlock3D", ), - blocks_temp_li=[False, False, False, False], block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + temporal_downsample_blocks: Tuple[bool, ...] = [True, True, False, False], layers_per_block: int = 2, norm_num_groups: int = 32, act_fn: str = "silu", @@ -433,9 +431,6 @@ def __init__( ): super().__init__() - self.layers_per_block = layers_per_block - self.blocks_temp_li = blocks_temp_li - self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], @@ -462,11 +457,11 @@ def __init__( if down_block_type == "AllegroDownBlock3D": down_block = AllegroDownBlock3D( - num_layers=self.layers_per_block, + num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, - add_downsample=not is_final_block, - add_temp_downsample=blocks_temp_li[i], + spatial_downsample=not is_final_block, + temporal_downsample=temporal_downsample_blocks[i], resnet_eps=1e-6, downsample_padding=0, resnet_act_fn=act_fn, @@ -496,7 +491,6 @@ def __init__( conv_out_channels = 2 * out_channels if double_z else out_channels self.temp_conv_out = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0)) - self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) self.gradient_checkpointing = False @@ -508,7 +502,6 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: sample = self.conv_in(sample) sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - residual = sample sample = self.temp_conv_in(sample) sample = sample + residual @@ -539,16 +532,16 @@ def custom_forward(*inputs): sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) sample = self.conv_norm_out(sample) sample = self.conv_act(sample) + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - residual = sample sample = self.temp_conv_out(sample) sample = sample + residual + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) - sample = self.conv_out(sample) + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - return sample @@ -563,7 +556,7 @@ def __init__( "AllegroUpBlock3D", "AllegroUpBlock3D", ), - blocks_temp_li=[False, False, False, False], + temporal_upsample_blocks: Tuple[bool, ...] = [False, True, True, False], block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), layers_per_block: int = 2, norm_num_groups: int = 32, @@ -572,9 +565,6 @@ def __init__( ): super().__init__() - self.layers_per_block = layers_per_block - self.blocks_temp_li = blocks_temp_li - self.conv_in = nn.Conv2d( in_channels, block_out_channels[-1], @@ -613,11 +603,11 @@ def __init__( if up_block_type == "AllegroUpBlock3D": up_block = AllegroUpBlock3D( - num_layers=self.layers_per_block + 1, + num_layers=layers_per_block + 1, in_channels=prev_output_channel, out_channels=output_channel, - add_upsample=not is_final_block, - add_temp_upsample=blocks_temp_li[i], + spatial_upsample=not is_final_block, + temporal_upsample=temporal_upsample_blocks[i], resnet_eps=1e-6, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, @@ -648,8 +638,8 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) sample = self.conv_in(sample) - sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) residual = sample sample = self.temp_conv_in(sample) sample = sample + residual @@ -684,16 +674,16 @@ def custom_forward(*inputs): sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) sample = self.conv_norm_out(sample) sample = self.conv_act(sample) + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - residual = sample sample = self.temp_conv_out(sample) sample = sample + residual sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) sample = self.conv_out(sample) + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - return sample @@ -706,17 +696,30 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin): for all models (such as downloading or saving). Parameters: - in_channels (int, defaults to `3`): Number of channels in the input image. - out_channels (int, defaults to `3`): Number of channels in the output. + in_channels (int, defaults to `3`): + Number of channels in the input image. + out_channels (int, defaults to `3`): + Number of channels in the output. down_block_types (`Tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")`): - Tuple of downsample block types. - up_block_types (`Tuple[str]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`): - Tuple of upsample block types. - block_out_channels (`Tuple[int]`, defaults to `(128, 256, 512, 512)`): - Tuple of block output channels. + Tuple of strings denoting which types of down blocks to use. + up_block_types (`Tuple[str, ...]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`): + Tuple of strings denoting which types of up blocks to use. + block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + Tuple of integers denoting number of output channels in each block. + temporal_downsample_blocks (`Tuple[bool, ...]`, defaults to `(True, True, False, False)`): + Tuple of booleans denoting which blocks to enable temporal downsampling in. + latent_channels (`int`, defaults to `4`): + Number of channels in latents. + layers_per_block (`int`, defaults to `2`): + Number of resnet or attention or temporal convolution layers per down/up block. act_fn (`str`, defaults to `"silu"`): The activation function to use. - sample_size (`int`, *optional*, defaults to `32`): Sample input size. + norm_num_groups (`int`, defaults to `32`): + Number of groups to use in normalization layers. + temporal_compression_ratio (`int`, defaults to `4`): + Ratio by which temporal dimension of samples are compressed. + sample_size (`int`, defaults to `320`): + Default latent size. scaling_factor (`float`, defaults to `0.13235`): The component-wise standard deviation of the trained latent space computed using the first batch of the training set. This is used to scale the latent space to have unit variance when training the diffusion @@ -728,7 +731,6 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin): If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE can be fine-tuned / trained to a lower range without loosing too much precision in which case `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix - TODO(aryan): docs """ _supports_gradient_checkpointing = True @@ -751,30 +753,24 @@ def __init__( "AllegroUpBlock3D", ), block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + temporal_downsample_blocks: Tuple[bool, ...] = (True, True, False, False), + temporal_upsample_blocks: Tuple[bool, ...] = (False, True, True, False), latent_channels: int = 4, layers_per_block: int = 2, act_fn: str = "silu", norm_num_groups: int = 32, temporal_compression_ratio: float = 4, sample_size: int = 320, - scaling_factor: float = 0.13235, + scaling_factor: float = 0.13, force_upcast: bool = True, - tile_overlap: tuple = (120, 80), - chunk_len: int = 24, - t_over: int = 8, - blocks_tempdown_li=[True, True, False, False], - blocks_tempup_li=[False, True, True, False], ) -> None: super().__init__() - self.blocks_tempdown_li = blocks_tempdown_li - self.blocks_tempup_li = blocks_tempup_li - self.encoder = AllegroEncoder3D( in_channels=in_channels, out_channels=latent_channels, down_block_types=down_block_types, - blocks_temp_li=blocks_tempdown_li, + temporal_downsample_blocks=temporal_downsample_blocks, block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, @@ -785,7 +781,7 @@ def __init__( in_channels=latent_channels, out_channels=out_channels, up_block_types=up_block_types, - blocks_temp_li=blocks_tempup_li, + temporal_upsample_blocks=temporal_upsample_blocks, block_out_channels=block_out_channels, layers_per_block=layers_per_block, norm_num_groups=norm_num_groups, @@ -794,40 +790,175 @@ def __init__( self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) + # TODO(aryan): For the 1.0.0 refactor, `temporal_compression_ratio` can be inferred directly and we don't need + # to use a specific parameter here or in other VAEs. + self.use_slicing = False self.use_tiling = False # only relevant if vae tiling is enabled sample_size = sample_size[0] if isinstance(sample_size, (list, tuple)) else sample_size - self.tile_overlap = tile_overlap self.vae_scale_factor = [4, 8, 8] - self.sample_size = sample_size - self.chunk_len = chunk_len - self.t_over = t_over - self.latent_chunk_len = self.chunk_len // 4 - self.latent_t_over = self.t_over // 4 - self.kernel = (self.chunk_len, self.sample_size, self.sample_size) # (24, 256, 256) + # TODO(aryan): refactor tiling implementation + chunk_len = 24 + t_over = 8 + tile_overlap = (120, 80) + + self.latent_chunk_len = chunk_len // 4 + self.latent_t_over = t_over // 4 + self.kernel = (chunk_len, sample_size, sample_size) # (24, 256, 256) self.stride = ( - self.chunk_len - self.t_over, - self.sample_size - self.tile_overlap[0], - self.sample_size - self.tile_overlap[1], + chunk_len - t_over, + sample_size - tile_overlap[0], + sample_size - tile_overlap[1], ) # (16, 112, 192) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)): module.gradient_checkpointing = value + + def enable_tiling( + self, + # tile_sample_min_height: Optional[int] = None, + # tile_sample_min_width: Optional[int] = None, + # tile_overlap_factor_height: Optional[float] = None, + # tile_overlap_factor_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_overlap_factor_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + tile_overlap_factor_width (`int`, *optional*): + The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there + are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + """ + self.use_tiling = True + + # TODO(aryan): refactor tiling implementation + # self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + # self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + # self.tile_latent_min_height = int( + # self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) + # ) + # self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) + # self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height + # self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + # TODO(aryan) + # if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + if self.use_tiling: + return self.tiled_encode(x) + + raise NotImplementedError("Encoding without tiling has not been implemented yet.") + + @apply_forward_hook + def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of videos into latents. + + Args: + x (`torch.Tensor`): + Input batch of videos. + return_dict (`bool`, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) - def encode( - self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch_size=1 - ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + def _decode(self, z: torch.Tensor) -> torch.Tensor: + # TODO(aryan): refactor tiling implementation + # if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height): + if self.use_tiling: + return self.tiled_decode(z) + + raise NotImplementedError("Decoding without tiling has not been implemented yet.") + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of videos. + + Args: + z (`torch.Tensor`): + Input batch of latent vectors. + return_dict (`bool`, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def tiled_encode( + self, x: torch.Tensor + ) -> torch.Tensor: + # TODO(aryan): parameterize this in enable_tiling + local_batch_size = 1 + # TODO(aryan): rewrite to encode and tiled_encode KERNEL = self.kernel STRIDE = self.stride LOCAL_BS = local_batch_size OUT_C = 8 - B, C, N, H, W = input_imgs.shape + B, C, N, H, W = x.shape out_n = math.floor((N - KERNEL[0]) / STRIDE[0]) + 1 out_h = math.floor((H - KERNEL[1]) / STRIDE[1]) + 1 @@ -838,11 +969,11 @@ def encode( out_latent = torch.zeros( (out_n * out_h * out_w, OUT_C, KERNEL[0] // 4, KERNEL[1] // 8, KERNEL[2] // 8), - device=input_imgs.device, - dtype=input_imgs.dtype, + device=x.device, + dtype=x.dtype, ) vae_batch_input = torch.zeros( - (LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_imgs.device, dtype=input_imgs.dtype + (LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=x.device, dtype=x.dtype ) for i in range(out_n): @@ -851,7 +982,7 @@ def encode( n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0] h_start, h_end = j * STRIDE[1], j * STRIDE[1] + KERNEL[1] w_start, w_end = k * STRIDE[2], k * STRIDE[2] + KERNEL[2] - video_cube = input_imgs[:, :, n_start:n_end, h_start:h_end, w_start:w_end] + video_cube = x[:, :, n_start:n_end, h_start:h_end, w_start:w_end] vae_batch_input[num % LOCAL_BS] = video_cube if num % LOCAL_BS == LOCAL_BS - 1 or num == out_n * out_h * out_w - 1: @@ -863,16 +994,16 @@ def encode( out_latent[num - LOCAL_BS + 1 : num + 1] = latent vae_batch_input = torch.zeros( (LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), - device=input_imgs.device, - dtype=input_imgs.dtype, + device=x.device, + dtype=x.dtype, ) num += 1 ## flatten the batched out latent to videos and supress the overlapped parts - B, C, N, H, W = input_imgs.shape + B, C, N, H, W = x.shape out_video_cube = torch.zeros( - (B, OUT_C, N // 4, H // 8, W // 8), device=input_imgs.device, dtype=input_imgs.dtype + (B, OUT_C, N // 4, H // 8, W // 8), device=x.device, dtype=x.dtype ) OUT_KERNEL = KERNEL[0] // 4, KERNEL[1] // 8, KERNEL[2] // 8 OUT_STRIDE = STRIDE[0] // 4, STRIDE[1] // 8, STRIDE[2] // 8 @@ -897,16 +1028,14 @@ def encode( out_video_cube = self.quant_conv(out_video_cube) out_video_cube = out_video_cube.unflatten(0, (B, -1)).permute(0, 2, 1, 3, 4) - posterior = DiagonalGaussianDistribution(out_video_cube) + return out_video_cube - if not return_dict: - return (posterior,) - - return AutoencoderKLOutput(latent_dist=posterior) + def tiled_decode( + self, z: torch.Tensor + ) -> torch.Tensor: + # TODO(aryan): parameterize this in enable_tiling + local_batch_size = 1 - def decode( - self, input_latents: torch.Tensor, return_dict: bool = True, local_batch_size=1 - ) -> Union[DecoderOutput, torch.Tensor]: # TODO(aryan): rewrite to decode and tiled_decode KERNEL = self.kernel STRIDE = self.stride @@ -916,12 +1045,12 @@ def decode( IN_KERNEL = KERNEL[0] // 4, KERNEL[1] // 8, KERNEL[2] // 8 IN_STRIDE = STRIDE[0] // 4, STRIDE[1] // 8, STRIDE[2] // 8 - B, C, N, H, W = input_latents.shape + B, C, N, H, W = z.shape ## post quant conv (a mapping) - input_latents = input_latents.permute(0, 2, 1, 3, 4).flatten(0, 1) - input_latents = self.post_quant_conv(input_latents) - input_latents = input_latents.unflatten(0, (B, -1)).permute(0, 2, 1, 3, 4) + z = z.permute(0, 2, 1, 3, 4).flatten(0, 1) + z = self.post_quant_conv(z) + z = z.unflatten(0, (B, -1)).permute(0, 2, 1, 3, 4) ## out tensor shape out_n = math.floor((N - IN_KERNEL[0]) / IN_STRIDE[0]) + 1 @@ -932,13 +1061,13 @@ def decode( num = 0 decoded_cube = torch.zeros( (out_n * out_h * out_w, OUT_C, KERNEL[0], KERNEL[1], KERNEL[2]), - device=input_latents.device, - dtype=input_latents.dtype, + device=z.device, + dtype=z.dtype, ) vae_batch_input = torch.zeros( (LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), - device=input_latents.device, - dtype=input_latents.dtype, + device=z.device, + dtype=z.dtype, ) for i in range(out_n): for j in range(out_h): @@ -946,7 +1075,7 @@ def decode( n_start, n_end = i * IN_STRIDE[0], i * IN_STRIDE[0] + IN_KERNEL[0] h_start, h_end = j * IN_STRIDE[1], j * IN_STRIDE[1] + IN_KERNEL[1] w_start, w_end = k * IN_STRIDE[2], k * IN_STRIDE[2] + IN_KERNEL[2] - latent_cube = input_latents[:, :, n_start:n_end, h_start:h_end, w_start:w_end] + latent_cube = z[:, :, n_start:n_end, h_start:h_end, w_start:w_end] vae_batch_input[num % LOCAL_BS] = latent_cube if num % LOCAL_BS == LOCAL_BS - 1 or num == out_n * out_h * out_w - 1: latent = self.decoder(vae_batch_input) @@ -957,14 +1086,14 @@ def decode( decoded_cube[num - LOCAL_BS + 1 : num + 1] = latent vae_batch_input = torch.zeros( (LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), - device=input_latents.device, - dtype=input_latents.dtype, + device=z.device, + dtype=z.dtype, ) num += 1 - B, C, N, H, W = input_latents.shape + B, C, N, H, W = z.shape out_video = torch.zeros( - (B, OUT_C, N * 4, H * 8, W * 8), device=input_latents.device, dtype=input_latents.dtype + (B, OUT_C, N * 4, H * 8, W * 8), device=z.device, dtype=z.dtype ) OVERLAP = KERNEL[0] - STRIDE[0], KERNEL[1] - STRIDE[1], KERNEL[2] - STRIDE[2] for i in range(out_n): @@ -983,11 +1112,7 @@ def decode( out_video = out_video.permute(0, 2, 1, 3, 4).contiguous() - decoded = out_video - if not return_dict: - return (decoded,) - - return DecoderOutput(sample=decoded) + return out_video def forward( self, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index ad8f5f43c512..66917dce6107 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1598,58 +1598,6 @@ def forward( return objs -class AllegroCombinedTimestepSizeEmbeddings(nn.Module): - """ - For Allegro. TODO(aryan) - """ - - def __init__(self, embedding_dim: int, size_emb_dim: int, use_additional_conditions: bool = False): - super().__init__() - - self.outdim = size_emb_dim - self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - - self.use_additional_conditions = use_additional_conditions - if use_additional_conditions: - self.use_additional_conditions = True - self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) - self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) - - def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module): - if size.ndim == 1: - size = size[:, None] - - if size.shape[0] != batch_size: - size = size.repeat(batch_size // size.shape[0], 1) - if size.shape[0] != batch_size: - raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.") - - current_batch_size, dims = size.shape[0], size.shape[1] - size = size.reshape(-1) - size_freq = self.additional_condition_proj(size).to(size.dtype) - - size_emb = embedder(size_freq) - size_emb = size_emb.reshape(current_batch_size, dims * self.outdim) - return size_emb - - def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) - - if self.use_additional_conditions: - resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder) - aspect_ratio = self.apply_condition( - aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder - ) - conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1) - else: - conditioning = timesteps_emb - - return conditioning - - class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): """ For PixArt-Alpha. diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index f7cbe58a71c0..324a016b3b40 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -23,9 +23,8 @@ from ..utils import is_torch_version from .activations import get_activation from .embeddings import ( - AllegroCombinedTimestepSizeEmbeddings, CombinedTimestepLabelEmbeddings, - PixArtAlphaCombinedTimestepSizeEmbeddings, + PixArtAlphaCombinedTimestepSizeEmbeddings ) @@ -267,6 +266,7 @@ def forward( hidden_dtype: Optional[torch.dtype] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # No modulation happening here. + added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None} embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) return self.linear(self.silu(embedded_timestep)), embedded_timestep @@ -390,41 +390,6 @@ def forward( return x -class AllegroAdaLayerNormSingle(nn.Module): - r""" - Norm layer adaptive layer norm single (adaLN-single). - - As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). - - Parameters: - embedding_dim (`int`): The size of each embedding vector. - use_additional_conditions (`bool`): To use additional conditions for normalization or not. - """ - - def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): - super().__init__() - - self.emb = AllegroCombinedTimestepSizeEmbeddings( - embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions - ) - - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) - - def forward( - self, - timestep: torch.Tensor, - added_cond_kwargs: Dict[str, torch.Tensor] = None, - batch_size: int = None, - hidden_dtype: Optional[torch.dtype] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # No modulation happening here. - embedded_timestep = self.emb( - timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None - ) - return self.linear(self.silu(embedded_timestep)), embedded_timestep - - class CogView3PlusAdaLayerNormZeroTextImage(nn.Module): r""" Norm layer adaptive layer norm zero (adaLN-Zero). diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 2b7ea524c763..599792da42c7 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -27,7 +27,7 @@ from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AllegroAdaLayerNormSingle +from ..normalization import AdaLayerNormSingle logger = logging.get_logger(__name__) @@ -36,7 +36,30 @@ @maybe_allow_in_graph class AllegroTransformerBlock(nn.Module): r""" - TODO(aryan): docs + Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model. + + Args: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + cross_attention_dim (`int`, defaults to `2304`): + The dimension of the cross attention features. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to be used in feed-forward. + attention_bias (`bool`, defaults to `False`): + Whether or not to use bias in attention projection layers. + only_cross_attention (`bool`, defaults to `False`): + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, defaults to `1e-5`): + Epsilon value for normalization layers. + final_dropout (`bool` defaults to `False`): + Whether to apply a final dropout after the last feed-forward layer. """ def __init__( @@ -48,11 +71,8 @@ def __init__( cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", attention_bias: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, - final_dropout: bool = False, ): super().__init__() @@ -65,8 +85,7 @@ def __init__( dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, + cross_attention_dim=cross_attention_dim, processor=AllegroAttnProcessor2_0(), ) @@ -79,7 +98,6 @@ def __init__( dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, - upcast_attention=upcast_attention, processor=AllegroAttnProcessor2_0(), ) # is self-attn if encoder_hidden_states is none @@ -90,7 +108,6 @@ def __init__( dim, dropout=dropout, activation_fn=activation_fn, - final_dropout=final_dropout, ) # 4. Scale-shift @@ -147,49 +164,63 @@ def forward( ff_output = gate_mlp * ff_output hidden_states = ff_output + hidden_states - - # TODO(aryan): maybe following line is not required - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - return hidden_states class AllegroTransformer3DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True - """ - A 2D Transformer model for image-like data. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - The number of channels in the input and output (specify if the input is **continuous**). - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). - This is fixed during training since it is used to learn a number of position embeddings. - num_vector_embeds (`int`, *optional*): - The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). - Includes the class for the masked latent pixel. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): - The number of diffusion steps used during training. Pass if at least one of the norm_layers is - `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are - added to the hidden states. - - During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. - attention_bias (`bool`, *optional*): - Configure if the `TransformerBlocks` attention should contain a bias parameter. + r""" + A 3D Transformer model for video-like data. + + Args: + patch_size (`int`, defaults to `2`): + The size of spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches to use in the patch embedding layer. + num_attention_heads (`int`, defaults to `24`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `96`): + The number of channels in each head. + in_channels (`int`, defaults to `4`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `4`): + The number of channels in the output. + num_layers (`int`, defaults to `32`): + The number of layers of Transformer blocks to use. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + cross_attention_dim (`int`, defaults to `2304`): + The dimension of the cross attention features. + attention_bias (`bool`, defaults to `True`): + Whether or not to use bias in the attention projection layers. + sample_height (`int`, defaults to `90`): + The height of the input latents. + sample_width (`int`, defaults to `160`): + The width of the input latents. + sample_frames (`int`, defaults to `22`): + The number of frames in the input latents. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether or not to use elementwise affine in normalization layers. + norm_eps (`float`, defaults to `1e-5`): + The epsilon value to use in normalization layers. + caption_channels (`int`, defaults to `4096`): + Number of channels to use for projecting the caption embeddings. + interpolation_scale_h (`float`, defaults to `2.0`): + Scaling factor to apply in 3D positional embeddings across height dimension. + interpolation_scale_w (`float`, defaults to `2.0`): + Scaling factor to apply in 3D positional embeddings across width dimension. + interpolation_scale_t (`float`, defaults to `2.2`): + Scaling factor to apply in 3D positional embeddings across time dimension. """ @register_to_config def __init__( self, patch_size: int = 2, - patch_size_temporal: int = 1, + patch_size_t: int = 1, num_attention_heads: int = 24, attention_head_dim: int = 96, in_channels: int = 4, @@ -202,7 +233,6 @@ def __init__( sample_width: int = 160, sample_frames: int = 22, activation_fn: str = "gelu-approximate", - upcast_attention: bool = False, norm_elementwise_affine: bool = False, norm_eps: float = 1e-6, caption_channels: int = 4096, @@ -245,7 +275,6 @@ def __init__( cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, attention_bias=attention_bias, - upcast_attention=upcast_attention, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, ) @@ -259,7 +288,7 @@ def __init__( self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels) # 4. Timestep embeddings - self.adaln_single = AllegroAdaLayerNormSingle(self.inner_dim, use_additional_conditions=False) + self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False) # 5. Caption projection self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim) @@ -280,9 +309,13 @@ def forward( return_dict: bool = True, ): batch_size, num_channels, num_frames, height, width = hidden_states.shape - p_t = self.config.patch_size_temporal + p_t = self.config.patch_size_t p = self.config.patch_size + post_patch_num_frames = num_frames // self.config.patch_size_t + post_patch_height = height // self.config.patch_size + post_patch_width = width // self.config.patch_size + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. @@ -317,15 +350,12 @@ def forward( encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - # 1. Input - post_patch_num_frames = num_frames // self.config.patch_size_temporal - post_patch_height = height // self.config.patch_size - post_patch_width = width // self.config.patch_size - + # 1. Timestep embeddings timestep, embedded_timestep = self.adaln_single( timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype ) + # 2. Patch embeddings hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = self.pos_embed(hidden_states) hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) @@ -333,6 +363,7 @@ def forward( encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1]) + # 3. Transformer blocks for i, block in enumerate(self.transformer_blocks): # TODO(aryan): Implement gradient checkpointing if self.gradient_checkpointing: @@ -364,16 +395,16 @@ def custom_forward(*inputs): image_rotary_emb=image_rotary_emb, ) - # 3. Output + # 4. Output normalization & projection shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) hidden_states = self.norm_out(hidden_states) - # Modulation + # modulation hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.squeeze(1) - # unpatchify + # 5. Unpatchify hidden_states = hidden_states.reshape( batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1 ) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index e1e7e9a0f351..1c639d5a1c30 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -211,7 +211,7 @@ def encode_prompt( prompt_attention_mask: Optional[torch.FloatTensor] = None, negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, clean_caption: bool = False, - max_sequence_length: int = 300, + max_sequence_length: int = 512, ): r""" Encodes the prompt into text encoder hidden states. @@ -237,7 +237,8 @@ def encode_prompt( string. clean_caption (`bool`, defaults to `False`): If `True`, the function will preprocess and clean the provided caption before encoding. - max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. + max_sequence_length (`int`, defaults to `512`): + Maximum sequence length to use for the prompt. """ if device is None: @@ -684,7 +685,7 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], clean_caption: bool = True, - max_sequence_length: int = 300, + max_sequence_length: int = 512, ) -> Union[AllegroPipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -751,7 +752,8 @@ def __call__( Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created from the raw prompt. - max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + max_sequence_length (`int` defaults to `512`): + Maximum sequence length to use with the `prompt`. Examples: @@ -767,9 +769,9 @@ def __call__( num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct - num_frames = num_frames or self.transformer.config.sample_size_t * self.vae_scale_factor_temporal - height = height or self.transformer.config.sample_size[0] * self.vae_scale_factor_spatial - width = width or self.transformer.config.sample_size[1] * self.vae_scale_factor_spatial + num_frames = num_frames or self.transformer.config.sample_frames * self.vae_scale_factor_temporal + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial self.check_inputs( prompt, From f702af0cde9150fcb8e6b3b9b3459d611f0d6847 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 23 Oct 2024 01:05:25 +0200 Subject: [PATCH 15/33] add docs --- docs/source/en/_toctree.yml | 6 +++ .../en/api/models/allegro_transformer3d.md | 30 +++++++++++++++ .../en/api/models/autoencoderkl_allegro.md | 37 +++++++++++++++++++ docs/source/en/api/pipelines/allegro.md | 34 +++++++++++++++++ 4 files changed, 107 insertions(+) create mode 100644 docs/source/en/api/models/allegro_transformer3d.md create mode 100644 docs/source/en/api/models/autoencoderkl_allegro.md create mode 100644 docs/source/en/api/pipelines/allegro.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 58218c0272bd..4a5201855fbb 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -250,6 +250,8 @@ title: SparseControlNetModel title: ControlNets - sections: + - local: api/models/allegro_transformer3d + title: AllegroTransformer3DModel - local: api/models/aura_flow_transformer2d title: AuraFlowTransformer2DModel - local: api/models/cogvideox_transformer3d @@ -298,6 +300,8 @@ - sections: - local: api/models/autoencoderkl title: AutoencoderKL + - local: api/models/autoencoderkl_allegro + title: AutoencoderKLAllegro - local: api/models/autoencoderkl_cogvideox title: AutoencoderKLCogVideoX - local: api/models/asymmetricautoencoderkl @@ -316,6 +320,8 @@ sections: - local: api/pipelines/overview title: Overview + - local: api/pipelines/allegro + title: Allegro - local: api/pipelines/amused title: aMUSEd - local: api/pipelines/animatediff diff --git a/docs/source/en/api/models/allegro_transformer3d.md b/docs/source/en/api/models/allegro_transformer3d.md new file mode 100644 index 000000000000..e70026fe4bfc --- /dev/null +++ b/docs/source/en/api/models/allegro_transformer3d.md @@ -0,0 +1,30 @@ + + +# AllegroTransformer3DModel + +A Diffusion Transformer model for 3D data from [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI. + +The model can be loaded with the following code snippet. + +```python +from diffusers import AllegroTransformer3DModel + +vae = AllegroTransformer3DModel.from_pretrained("rhymes-ai/Allegro", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda") +``` + +## AllegroTransformer3DModel + +[[autodoc]] AllegroTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/models/autoencoderkl_allegro.md b/docs/source/en/api/models/autoencoderkl_allegro.md new file mode 100644 index 000000000000..fd9d10d5724b --- /dev/null +++ b/docs/source/en/api/models/autoencoderkl_allegro.md @@ -0,0 +1,37 @@ + + +# AutoencoderKLAllegro + +The 3D variational autoencoder (VAE) model with KL loss used in [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI. + +The model can be loaded with the following code snippet. + +```python +from diffusers import AutoencoderKLAllegro + +vae = AutoencoderKLCogVideoX.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32).to("cuda") +``` + +## AutoencoderKLAllegro + +[[autodoc]] AutoencoderKLAllegro + - decode + - encode + - all + +## AutoencoderKLOutput + +[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput + +## DecoderOutput + +[[autodoc]] models.autoencoders.vae.DecoderOutput diff --git a/docs/source/en/api/pipelines/allegro.md b/docs/source/en/api/pipelines/allegro.md new file mode 100644 index 000000000000..e13e339944e5 --- /dev/null +++ b/docs/source/en/api/pipelines/allegro.md @@ -0,0 +1,34 @@ + + +# Allegro + +[Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) from RhymesAI, by Yuan Zhou, Qiuyue Wang, Yuxuan Cai, Huan Yang. + +The abstract from the paper is: + +*Significant advancements have been made in the field of video generation, with the open-source community contributing a wealth of research papers and tools for training high-quality models. However, despite these efforts, the available information and resources remain insufficient for achieving commercial-level performance. In this report, we open the black box and introduce Allegro, an advanced video generation model that excels in both quality and temporal consistency. We also highlight the current limitations in the field and present a comprehensive methodology for training high-performance, commercial-level video generation models, addressing key aspects such as data, model architecture, training pipeline, and evaluation. Our user study shows that Allegro surpasses existing open-source models and most commercial models, ranking just behind Hailuo and Kling. Code: https://github.com/rhymes-ai/Allegro , Model: https://huggingface.co/rhymes-ai/Allegro , Gallery: https://rhymes.ai/allegro_gallery .* + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +## AllegroPipeline + +[[autodoc]] AllegroPipeline + - all + - __call__ + +## AllegroPipelineOutput + +[[autodoc]] pipelines.allegro.pipeline_output.AllegroPipelineOutput From 3d412811c8db454be04f1c11a512ab05a113761d Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 23 Oct 2024 01:09:04 +0200 Subject: [PATCH 16/33] make style --- .../autoencoders/autoencoder_kl_allegro.py | 64 ++++++++----------- src/diffusers/models/normalization.py | 5 +- 2 files changed, 29 insertions(+), 40 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index 209e1fa01386..d76dabe84c8b 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -90,7 +90,7 @@ def __init__( nn.SiLU(), nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)), ) - + @staticmethod def _pad_temporal_dim(hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2) @@ -118,10 +118,10 @@ def forward(self, hidden_states: torch.Tensor, batch_size: int) -> torch.Tensor: hidden_states = self._pad_temporal_dim(hidden_states) hidden_states = self.conv2(hidden_states) - + hidden_states = self._pad_temporal_dim(hidden_states) hidden_states = self.conv3(hidden_states) - + hidden_states = self._pad_temporal_dim(hidden_states) hidden_states = self.conv4(hidden_states) @@ -200,7 +200,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size = hidden_states.shape[0] - + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) for resnet, temp_conv in zip(self.resnets, self.temp_convs): @@ -213,7 +213,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) return hidden_states @@ -282,7 +282,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size = hidden_states.shape[0] - + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) for resnet, temp_conv in zip(self.resnets, self.temp_convs): @@ -295,7 +295,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) - + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) return hidden_states @@ -399,7 +399,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = self.resnets[0](hidden_states, temb=None) - + hidden_states = self.temp_convs[0](hidden_states, batch_size=batch_size) for attn, resnet, temp_conv in zip(self.attentions, self.resnets[1:], self.temp_convs[1:]): @@ -532,15 +532,15 @@ def custom_forward(*inputs): sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) sample = self.conv_norm_out(sample) sample = self.conv_act(sample) - + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) residual = sample sample = self.temp_conv_out(sample) sample = sample + residual - + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) sample = self.conv_out(sample) - + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) return sample @@ -674,7 +674,7 @@ def custom_forward(*inputs): sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) sample = self.conv_norm_out(sample) sample = self.conv_act(sample) - + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) residual = sample sample = self.temp_conv_out(sample) @@ -682,7 +682,7 @@ def custom_forward(*inputs): sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) sample = self.conv_out(sample) - + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) return sample @@ -804,7 +804,7 @@ def __init__( chunk_len = 24 t_over = 8 tile_overlap = (120, 80) - + self.latent_chunk_len = chunk_len // 4 self.latent_t_over = t_over // 4 self.kernel = (chunk_len, sample_size, sample_size) # (24, 256, 256) @@ -817,7 +817,7 @@ def __init__( def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)): module.gradient_checkpointing = value - + def enable_tiling( self, # tile_sample_min_height: Optional[int] = None, @@ -876,17 +876,19 @@ def disable_slicing(self) -> None: decoding in one step. """ self.use_slicing = False - + def _encode(self, x: torch.Tensor) -> torch.Tensor: # TODO(aryan) # if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): if self.use_tiling: return self.tiled_encode(x) - + raise NotImplementedError("Encoding without tiling has not been implemented yet.") - + @apply_forward_hook - def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: r""" Encode a batch of videos into latents. @@ -919,7 +921,7 @@ def _decode(self, z: torch.Tensor) -> torch.Tensor: return self.tiled_decode(z) raise NotImplementedError("Decoding without tiling has not been implemented yet.") - + @apply_forward_hook def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: """ @@ -946,12 +948,10 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp return (decoded,) return DecoderOutput(sample=decoded) - def tiled_encode( - self, x: torch.Tensor - ) -> torch.Tensor: + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: # TODO(aryan): parameterize this in enable_tiling local_batch_size = 1 - + # TODO(aryan): rewrite to encode and tiled_encode KERNEL = self.kernel STRIDE = self.stride @@ -972,9 +972,7 @@ def tiled_encode( device=x.device, dtype=x.dtype, ) - vae_batch_input = torch.zeros( - (LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=x.device, dtype=x.dtype - ) + vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=x.device, dtype=x.dtype) for i in range(out_n): for j in range(out_h): @@ -1002,9 +1000,7 @@ def tiled_encode( ## flatten the batched out latent to videos and supress the overlapped parts B, C, N, H, W = x.shape - out_video_cube = torch.zeros( - (B, OUT_C, N // 4, H // 8, W // 8), device=x.device, dtype=x.dtype - ) + out_video_cube = torch.zeros((B, OUT_C, N // 4, H // 8, W // 8), device=x.device, dtype=x.dtype) OUT_KERNEL = KERNEL[0] // 4, KERNEL[1] // 8, KERNEL[2] // 8 OUT_STRIDE = STRIDE[0] // 4, STRIDE[1] // 8, STRIDE[2] // 8 OVERLAP = OUT_KERNEL[0] - OUT_STRIDE[0], OUT_KERNEL[1] - OUT_STRIDE[1], OUT_KERNEL[2] - OUT_STRIDE[2] @@ -1030,9 +1026,7 @@ def tiled_encode( return out_video_cube - def tiled_decode( - self, z: torch.Tensor - ) -> torch.Tensor: + def tiled_decode(self, z: torch.Tensor) -> torch.Tensor: # TODO(aryan): parameterize this in enable_tiling local_batch_size = 1 @@ -1092,9 +1086,7 @@ def tiled_decode( num += 1 B, C, N, H, W = z.shape - out_video = torch.zeros( - (B, OUT_C, N * 4, H * 8, W * 8), device=z.device, dtype=z.dtype - ) + out_video = torch.zeros((B, OUT_C, N * 4, H * 8, W * 8), device=z.device, dtype=z.dtype) OVERLAP = KERNEL[0] - STRIDE[0], KERNEL[1] - STRIDE[1], KERNEL[2] - STRIDE[2] for i in range(out_n): n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0] diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 324a016b3b40..87dec66935da 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -22,10 +22,7 @@ from ..utils import is_torch_version from .activations import get_activation -from .embeddings import ( - CombinedTimestepLabelEmbeddings, - PixArtAlphaCombinedTimestepSizeEmbeddings -) +from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings class AdaLayerNorm(nn.Module): From 37e8a95f4fd6987ceb6cdf4fda5ed5de87649df7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 23 Oct 2024 01:09:50 +0200 Subject: [PATCH 17/33] add coauthor Co-Authored-By: YiYi Xu From 2c4645c0687d7883dd1308190b50d9e90e7fb79d Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 23 Oct 2024 01:12:33 +0200 Subject: [PATCH 18/33] make fix-copies --- .../pipelines/allegro/pipeline_allegro.py | 35 +++++++++++++------ src/diffusers/utils/dummy_pt_objects.py | 30 ++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 ++++++++ 3 files changed, 69 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 1c639d5a1c30..340f749c48d5 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -71,9 +71,10 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. @@ -86,14 +87,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -104,6 +109,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps @@ -458,14 +473,12 @@ def _clean_caption(self, caption): caption = re.sub("", "person", caption) # urls: caption = re.sub( - r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", - # noqa + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa "", caption, ) # regex for urls caption = re.sub( - r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", - # noqa + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa "", caption, ) # regex for urls @@ -488,13 +501,12 @@ def _clean_caption(self, caption): caption = re.sub(r"[\u3300-\u33ff]+", "", caption) caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) - # caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) ####################################################### # все виды тире / all types of dash --> "-" caption = re.sub( - r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", - # noqa + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa "-", caption, ) @@ -565,6 +577,7 @@ def _clean_caption(self, caption): caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) caption = re.sub(r"^\.\S+$", "", caption) + return caption.strip() def prepare_latents( diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 10d0399a6761..8a87b04a66cb 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class AllegroTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AsymmetricAutoencoderKL(metaclass=DummyObject): _backends = ["torch"] @@ -47,6 +62,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLAllegro(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLCogVideoX(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 9046a4f73533..83d160b08df4 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class AllegroPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AltDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From e26604cd75d318709cb6bbe92da204d16fb668a7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 23 Oct 2024 01:19:06 +0200 Subject: [PATCH 19/33] undo unrelated change --- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 16150cbd79bf..9cb042c9e80c 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -649,8 +649,8 @@ def __call__( height, width, prompt_embeds.dtype, - generator, device, + generator, latents, ) From bb321e7a17c2cf2ca1c675327c492cbe2ee6a024 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 23 Oct 2024 03:59:25 +0200 Subject: [PATCH 20/33] revert changes to embeddings, normalization, transformer --- src/diffusers/models/embeddings.py | 52 +++++++ src/diffusers/models/normalization.py | 42 +++++- .../transformers/transformer_allegro.py | 135 +++++++----------- 3 files changed, 144 insertions(+), 85 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 66917dce6107..ad8f5f43c512 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1598,6 +1598,58 @@ def forward( return objs +class AllegroCombinedTimestepSizeEmbeddings(nn.Module): + """ + For Allegro. TODO(aryan) + """ + + def __init__(self, embedding_dim: int, size_emb_dim: int, use_additional_conditions: bool = False): + super().__init__() + + self.outdim = size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.use_additional_conditions = use_additional_conditions + if use_additional_conditions: + self.use_additional_conditions = True + self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + + def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module): + if size.ndim == 1: + size = size[:, None] + + if size.shape[0] != batch_size: + size = size.repeat(batch_size // size.shape[0], 1) + if size.shape[0] != batch_size: + raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.") + + current_batch_size, dims = size.shape[0], size.shape[1] + size = size.reshape(-1) + size_freq = self.additional_condition_proj(size).to(size.dtype) + + size_emb = embedder(size_freq) + size_emb = size_emb.reshape(current_batch_size, dims * self.outdim) + return size_emb + + def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + if self.use_additional_conditions: + resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder) + aspect_ratio = self.apply_condition( + aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder + ) + conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1) + else: + conditioning = timesteps_emb + + return conditioning + + class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): """ For PixArt-Alpha. diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 87dec66935da..f7cbe58a71c0 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -22,7 +22,11 @@ from ..utils import is_torch_version from .activations import get_activation -from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings +from .embeddings import ( + AllegroCombinedTimestepSizeEmbeddings, + CombinedTimestepLabelEmbeddings, + PixArtAlphaCombinedTimestepSizeEmbeddings, +) class AdaLayerNorm(nn.Module): @@ -263,7 +267,6 @@ def forward( hidden_dtype: Optional[torch.dtype] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # No modulation happening here. - added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None} embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) return self.linear(self.silu(embedded_timestep)), embedded_timestep @@ -387,6 +390,41 @@ def forward( return x +class AllegroAdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): + super().__init__() + + self.emb = AllegroCombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + batch_size: int = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + embedded_timestep = self.emb( + timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None + ) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + class CogView3PlusAdaLayerNormZeroTextImage(nn.Module): r""" Norm layer adaptive layer norm zero (adaLN-Zero). diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 599792da42c7..2b7ea524c763 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -27,7 +27,7 @@ from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle +from ..normalization import AllegroAdaLayerNormSingle logger = logging.get_logger(__name__) @@ -36,30 +36,7 @@ @maybe_allow_in_graph class AllegroTransformerBlock(nn.Module): r""" - Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model. - - Args: - dim (`int`): - The number of channels in the input and output. - num_attention_heads (`int`): - The number of heads to use for multi-head attention. - attention_head_dim (`int`): - The number of channels in each head. - dropout (`float`, defaults to `0.0`): - The dropout probability to use. - cross_attention_dim (`int`, defaults to `2304`): - The dimension of the cross attention features. - activation_fn (`str`, defaults to `"gelu-approximate"`): - Activation function to be used in feed-forward. - attention_bias (`bool`, defaults to `False`): - Whether or not to use bias in attention projection layers. - only_cross_attention (`bool`, defaults to `False`): - norm_elementwise_affine (`bool`, defaults to `True`): - Whether to use learnable elementwise affine parameters for normalization. - norm_eps (`float`, defaults to `1e-5`): - Epsilon value for normalization layers. - final_dropout (`bool` defaults to `False`): - Whether to apply a final dropout after the last feed-forward layer. + TODO(aryan): docs """ def __init__( @@ -71,8 +48,11 @@ def __init__( cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, + final_dropout: bool = False, ): super().__init__() @@ -85,7 +65,8 @@ def __init__( dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, processor=AllegroAttnProcessor2_0(), ) @@ -98,6 +79,7 @@ def __init__( dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, + upcast_attention=upcast_attention, processor=AllegroAttnProcessor2_0(), ) # is self-attn if encoder_hidden_states is none @@ -108,6 +90,7 @@ def __init__( dim, dropout=dropout, activation_fn=activation_fn, + final_dropout=final_dropout, ) # 4. Scale-shift @@ -164,63 +147,49 @@ def forward( ff_output = gate_mlp * ff_output hidden_states = ff_output + hidden_states + + # TODO(aryan): maybe following line is not required + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + return hidden_states class AllegroTransformer3DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True - r""" - A 3D Transformer model for video-like data. - - Args: - patch_size (`int`, defaults to `2`): - The size of spatial patches to use in the patch embedding layer. - patch_size_t (`int`, defaults to `1`): - The size of temporal patches to use in the patch embedding layer. - num_attention_heads (`int`, defaults to `24`): - The number of heads to use for multi-head attention. - attention_head_dim (`int`, defaults to `96`): - The number of channels in each head. - in_channels (`int`, defaults to `4`): - The number of channels in the input. - out_channels (`int`, *optional*, defaults to `4`): - The number of channels in the output. - num_layers (`int`, defaults to `32`): - The number of layers of Transformer blocks to use. - dropout (`float`, defaults to `0.0`): - The dropout probability to use. - cross_attention_dim (`int`, defaults to `2304`): - The dimension of the cross attention features. - attention_bias (`bool`, defaults to `True`): - Whether or not to use bias in the attention projection layers. - sample_height (`int`, defaults to `90`): - The height of the input latents. - sample_width (`int`, defaults to `160`): - The width of the input latents. - sample_frames (`int`, defaults to `22`): - The number of frames in the input latents. - activation_fn (`str`, defaults to `"gelu-approximate"`): - Activation function to use in feed-forward. - norm_elementwise_affine (`bool`, defaults to `True`): - Whether or not to use elementwise affine in normalization layers. - norm_eps (`float`, defaults to `1e-5`): - The epsilon value to use in normalization layers. - caption_channels (`int`, defaults to `4096`): - Number of channels to use for projecting the caption embeddings. - interpolation_scale_h (`float`, defaults to `2.0`): - Scaling factor to apply in 3D positional embeddings across height dimension. - interpolation_scale_w (`float`, defaults to `2.0`): - Scaling factor to apply in 3D positional embeddings across width dimension. - interpolation_scale_t (`float`, defaults to `2.2`): - Scaling factor to apply in 3D positional embeddings across time dimension. + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. """ @register_to_config def __init__( self, patch_size: int = 2, - patch_size_t: int = 1, + patch_size_temporal: int = 1, num_attention_heads: int = 24, attention_head_dim: int = 96, in_channels: int = 4, @@ -233,6 +202,7 @@ def __init__( sample_width: int = 160, sample_frames: int = 22, activation_fn: str = "gelu-approximate", + upcast_attention: bool = False, norm_elementwise_affine: bool = False, norm_eps: float = 1e-6, caption_channels: int = 4096, @@ -275,6 +245,7 @@ def __init__( cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, attention_bias=attention_bias, + upcast_attention=upcast_attention, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, ) @@ -288,7 +259,7 @@ def __init__( self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels) # 4. Timestep embeddings - self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False) + self.adaln_single = AllegroAdaLayerNormSingle(self.inner_dim, use_additional_conditions=False) # 5. Caption projection self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim) @@ -309,13 +280,9 @@ def forward( return_dict: bool = True, ): batch_size, num_channels, num_frames, height, width = hidden_states.shape - p_t = self.config.patch_size_t + p_t = self.config.patch_size_temporal p = self.config.patch_size - post_patch_num_frames = num_frames // self.config.patch_size_t - post_patch_height = height // self.config.patch_size - post_patch_width = width // self.config.patch_size - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. @@ -350,12 +317,15 @@ def forward( encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - # 1. Timestep embeddings + # 1. Input + post_patch_num_frames = num_frames // self.config.patch_size_temporal + post_patch_height = height // self.config.patch_size + post_patch_width = width // self.config.patch_size + timestep, embedded_timestep = self.adaln_single( timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype ) - # 2. Patch embeddings hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = self.pos_embed(hidden_states) hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) @@ -363,7 +333,6 @@ def forward( encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1]) - # 3. Transformer blocks for i, block in enumerate(self.transformer_blocks): # TODO(aryan): Implement gradient checkpointing if self.gradient_checkpointing: @@ -395,16 +364,16 @@ def custom_forward(*inputs): image_rotary_emb=image_rotary_emb, ) - # 4. Output normalization & projection + # 3. Output shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) hidden_states = self.norm_out(hidden_states) - # modulation + # Modulation hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.squeeze(1) - # 5. Unpatchify + # unpatchify hidden_states = hidden_states.reshape( batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1 ) From 174621f34dbbf81a58424ed1caa0b667b6b74cf6 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 23 Oct 2024 04:26:20 +0200 Subject: [PATCH 21/33] refactor part 8 --- src/diffusers/models/embeddings.py | 52 -------- src/diffusers/models/normalization.py | 39 +----- .../transformers/transformer_allegro.py | 126 +++++++++++------- 3 files changed, 82 insertions(+), 135 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index ad8f5f43c512..66917dce6107 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1598,58 +1598,6 @@ def forward( return objs -class AllegroCombinedTimestepSizeEmbeddings(nn.Module): - """ - For Allegro. TODO(aryan) - """ - - def __init__(self, embedding_dim: int, size_emb_dim: int, use_additional_conditions: bool = False): - super().__init__() - - self.outdim = size_emb_dim - self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - - self.use_additional_conditions = use_additional_conditions - if use_additional_conditions: - self.use_additional_conditions = True - self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) - self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) - - def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module): - if size.ndim == 1: - size = size[:, None] - - if size.shape[0] != batch_size: - size = size.repeat(batch_size // size.shape[0], 1) - if size.shape[0] != batch_size: - raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.") - - current_batch_size, dims = size.shape[0], size.shape[1] - size = size.reshape(-1) - size_freq = self.additional_condition_proj(size).to(size.dtype) - - size_emb = embedder(size_freq) - size_emb = size_emb.reshape(current_batch_size, dims * self.outdim) - return size_emb - - def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) - - if self.use_additional_conditions: - resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder) - aspect_ratio = self.apply_condition( - aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder - ) - conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1) - else: - conditioning = timesteps_emb - - return conditioning - - class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): """ For PixArt-Alpha. diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index f7cbe58a71c0..324a016b3b40 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -23,9 +23,8 @@ from ..utils import is_torch_version from .activations import get_activation from .embeddings import ( - AllegroCombinedTimestepSizeEmbeddings, CombinedTimestepLabelEmbeddings, - PixArtAlphaCombinedTimestepSizeEmbeddings, + PixArtAlphaCombinedTimestepSizeEmbeddings ) @@ -267,6 +266,7 @@ def forward( hidden_dtype: Optional[torch.dtype] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # No modulation happening here. + added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None} embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) return self.linear(self.silu(embedded_timestep)), embedded_timestep @@ -390,41 +390,6 @@ def forward( return x -class AllegroAdaLayerNormSingle(nn.Module): - r""" - Norm layer adaptive layer norm single (adaLN-single). - - As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). - - Parameters: - embedding_dim (`int`): The size of each embedding vector. - use_additional_conditions (`bool`): To use additional conditions for normalization or not. - """ - - def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): - super().__init__() - - self.emb = AllegroCombinedTimestepSizeEmbeddings( - embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions - ) - - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) - - def forward( - self, - timestep: torch.Tensor, - added_cond_kwargs: Dict[str, torch.Tensor] = None, - batch_size: int = None, - hidden_dtype: Optional[torch.dtype] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # No modulation happening here. - embedded_timestep = self.emb( - timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None - ) - return self.linear(self.silu(embedded_timestep)), embedded_timestep - - class CogView3PlusAdaLayerNormZeroTextImage(nn.Module): r""" Norm layer adaptive layer norm zero (adaLN-Zero). diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 2b7ea524c763..e1f2f08cfeb4 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -27,7 +27,7 @@ from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AllegroAdaLayerNormSingle +from ..normalization import AdaLayerNormSingle logger = logging.get_logger(__name__) @@ -36,7 +36,29 @@ @maybe_allow_in_graph class AllegroTransformerBlock(nn.Module): r""" - TODO(aryan): docs + Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model. + Args: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + cross_attention_dim (`int`, defaults to `2304`): + The dimension of the cross attention features. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to be used in feed-forward. + attention_bias (`bool`, defaults to `False`): + Whether or not to use bias in attention projection layers. + only_cross_attention (`bool`, defaults to `False`): + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, defaults to `1e-5`): + Epsilon value for normalization layers. + final_dropout (`bool` defaults to `False`): + Whether to apply a final dropout after the last feed-forward layer. """ def __init__( @@ -48,11 +70,8 @@ def __init__( cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", attention_bias: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, - final_dropout: bool = False, ): super().__init__() @@ -65,8 +84,7 @@ def __init__( dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, + cross_attention_dim=None, processor=AllegroAttnProcessor2_0(), ) @@ -79,9 +97,8 @@ def __init__( dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, - upcast_attention=upcast_attention, processor=AllegroAttnProcessor2_0(), - ) # is self-attn if encoder_hidden_states is none + ) # 3. Feed Forward self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) @@ -90,7 +107,6 @@ def __init__( dim, dropout=dropout, activation_fn=activation_fn, - final_dropout=final_dropout, ) # 4. Scale-shift @@ -159,37 +175,55 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True """ - A 2D Transformer model for image-like data. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - The number of channels in the input and output (specify if the input is **continuous**). - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). - This is fixed during training since it is used to learn a number of position embeddings. - num_vector_embeds (`int`, *optional*): - The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). - Includes the class for the masked latent pixel. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): - The number of diffusion steps used during training. Pass if at least one of the norm_layers is - `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are - added to the hidden states. - - During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. - attention_bias (`bool`, *optional*): - Configure if the `TransformerBlocks` attention should contain a bias parameter. + A 3D Transformer model for video-like data. + Args: + patch_size (`int`, defaults to `2`): + The size of spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches to use in the patch embedding layer. + num_attention_heads (`int`, defaults to `24`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `96`): + The number of channels in each head. + in_channels (`int`, defaults to `4`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `4`): + The number of channels in the output. + num_layers (`int`, defaults to `32`): + The number of layers of Transformer blocks to use. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + cross_attention_dim (`int`, defaults to `2304`): + The dimension of the cross attention features. + attention_bias (`bool`, defaults to `True`): + Whether or not to use bias in the attention projection layers. + sample_height (`int`, defaults to `90`): + The height of the input latents. + sample_width (`int`, defaults to `160`): + The width of the input latents. + sample_frames (`int`, defaults to `22`): + The number of frames in the input latents. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether or not to use elementwise affine in normalization layers. + norm_eps (`float`, defaults to `1e-5`): + The epsilon value to use in normalization layers. + caption_channels (`int`, defaults to `4096`): + Number of channels to use for projecting the caption embeddings. + interpolation_scale_h (`float`, defaults to `2.0`): + Scaling factor to apply in 3D positional embeddings across height dimension. + interpolation_scale_w (`float`, defaults to `2.0`): + Scaling factor to apply in 3D positional embeddings across width dimension. + interpolation_scale_t (`float`, defaults to `2.2`): + Scaling factor to apply in 3D positional embeddings across time dimension. """ @register_to_config def __init__( self, patch_size: int = 2, - patch_size_temporal: int = 1, + patch_size_t: int = 1, num_attention_heads: int = 24, attention_head_dim: int = 96, in_channels: int = 4, @@ -202,7 +236,6 @@ def __init__( sample_width: int = 160, sample_frames: int = 22, activation_fn: str = "gelu-approximate", - upcast_attention: bool = False, norm_elementwise_affine: bool = False, norm_eps: float = 1e-6, caption_channels: int = 4096, @@ -245,7 +278,6 @@ def __init__( cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, attention_bias=attention_bias, - upcast_attention=upcast_attention, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, ) @@ -259,7 +291,7 @@ def __init__( self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels) # 4. Timestep embeddings - self.adaln_single = AllegroAdaLayerNormSingle(self.inner_dim, use_additional_conditions=False) + self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False) # 5. Caption projection self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim) @@ -280,9 +312,13 @@ def forward( return_dict: bool = True, ): batch_size, num_channels, num_frames, height, width = hidden_states.shape - p_t = self.config.patch_size_temporal + p_t = self.config.patch_size_t p = self.config.patch_size + post_patch_num_frames = num_frames // self.config.patch_size_temporal + post_patch_height = height // self.config.patch_size + post_patch_width = width // self.config.patch_size + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. @@ -317,15 +353,12 @@ def forward( encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - # 1. Input - post_patch_num_frames = num_frames // self.config.patch_size_temporal - post_patch_height = height // self.config.patch_size - post_patch_width = width // self.config.patch_size - + # 1. Timestep embeddings timestep, embedded_timestep = self.adaln_single( timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype ) + # 2. Patch embeddings hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = self.pos_embed(hidden_states) hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) @@ -333,6 +366,7 @@ def forward( encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1]) + # 3. Transformer blocks for i, block in enumerate(self.transformer_blocks): # TODO(aryan): Implement gradient checkpointing if self.gradient_checkpointing: @@ -364,7 +398,7 @@ def custom_forward(*inputs): image_rotary_emb=image_rotary_emb, ) - # 3. Output + # 4. Output normalization & projection shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) hidden_states = self.norm_out(hidden_states) @@ -373,7 +407,7 @@ def custom_forward(*inputs): hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.squeeze(1) - # unpatchify + # 5. Unpatchify hidden_states = hidden_states.reshape( batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1 ) From 2a82064786aa2ba63e48b94df6fa7d466a97b613 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 23 Oct 2024 04:26:32 +0200 Subject: [PATCH 22/33] make style --- src/diffusers/models/normalization.py | 5 +---- src/diffusers/models/transformers/transformer_allegro.py | 2 ++ 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 324a016b3b40..87dec66935da 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -22,10 +22,7 @@ from ..utils import is_torch_version from .activations import get_activation -from .embeddings import ( - CombinedTimestepLabelEmbeddings, - PixArtAlphaCombinedTimestepSizeEmbeddings -) +from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings class AdaLayerNorm(nn.Module): diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index e1f2f08cfeb4..3d6a8b53172d 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -37,6 +37,7 @@ class AllegroTransformerBlock(nn.Module): r""" Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model. + Args: dim (`int`): The number of channels in the input and output. @@ -176,6 +177,7 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin): """ A 3D Transformer model for video-like data. + Args: patch_size (`int`, defaults to `2`): The size of spatial patches to use in the patch embedding layer. From 762ccd5d2c5c215e664335e36a35bd44322161f5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 23 Oct 2024 08:42:45 +0200 Subject: [PATCH 23/33] refactor part 9 --- .../autoencoders/autoencoder_kl_allegro.py | 292 +++++++----------- 1 file changed, 119 insertions(+), 173 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index d76dabe84c8b..1d2b306a8189 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -796,66 +796,33 @@ def __init__( self.use_slicing = False self.use_tiling = False - # only relevant if vae tiling is enabled - sample_size = sample_size[0] if isinstance(sample_size, (list, tuple)) else sample_size - self.vae_scale_factor = [4, 8, 8] + self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1) + self.tile_overlap_t = 8 + self.tile_overlap_h = 120 + self.tile_overlap_w = 80 + sample_frames = 24 - # TODO(aryan): refactor tiling implementation - chunk_len = 24 - t_over = 8 - tile_overlap = (120, 80) - - self.latent_chunk_len = chunk_len // 4 - self.latent_t_over = t_over // 4 - self.kernel = (chunk_len, sample_size, sample_size) # (24, 256, 256) + self.kernel = (sample_frames, sample_size, sample_size) self.stride = ( - chunk_len - t_over, - sample_size - tile_overlap[0], - sample_size - tile_overlap[1], - ) # (16, 112, 192) + sample_frames - self.tile_overlap_t, + sample_size - self.tile_overlap_h, + sample_size - self.tile_overlap_w, + ) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)): module.gradient_checkpointing = value def enable_tiling( - self, - # tile_sample_min_height: Optional[int] = None, - # tile_sample_min_width: Optional[int] = None, - # tile_overlap_factor_height: Optional[float] = None, - # tile_overlap_factor_width: Optional[float] = None, + self ) -> None: r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images. - - Args: - tile_sample_min_height (`int`, *optional*): - The minimum height required for a sample to be separated into tiles across the height dimension. - tile_sample_min_width (`int`, *optional*): - The minimum width required for a sample to be separated into tiles across the width dimension. - tile_overlap_factor_height (`int`, *optional*): - The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are - no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher - value might cause more tiles to be processed leading to slow down of the decoding process. - tile_overlap_factor_width (`int`, *optional*): - The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there - are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher - value might cause more tiles to be processed leading to slow down of the decoding process. """ self.use_tiling = True - # TODO(aryan): refactor tiling implementation - # self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height - # self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width - # self.tile_latent_min_height = int( - # self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) - # ) - # self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) - # self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height - # self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width - def disable_tiling(self) -> None: r""" Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing @@ -949,162 +916,140 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp return DecoderOutput(sample=decoded) def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: - # TODO(aryan): parameterize this in enable_tiling local_batch_size = 1 + rs = self.spatial_compression_ratio + rt = self.config.temporal_compression_ratio - # TODO(aryan): rewrite to encode and tiled_encode - KERNEL = self.kernel - STRIDE = self.stride - LOCAL_BS = local_batch_size - OUT_C = 8 + batch_size, num_channels, num_frames, height, width = x.shape - B, C, N, H, W = x.shape + output_num_frames = math.floor((num_frames - self.kernel[0]) / self.stride[0]) + 1 + output_height = math.floor((height - self.kernel[1]) / self.stride[1]) + 1 + output_width = math.floor((width - self.kernel[2]) / self.stride[2]) + 1 - out_n = math.floor((N - KERNEL[0]) / STRIDE[0]) + 1 - out_h = math.floor((H - KERNEL[1]) / STRIDE[1]) + 1 - out_w = math.floor((W - KERNEL[2]) / STRIDE[2]) + 1 - - ## cut video into overlapped small cubes and batch forward - num = 0 - - out_latent = torch.zeros( - (out_n * out_h * out_w, OUT_C, KERNEL[0] // 4, KERNEL[1] // 8, KERNEL[2] // 8), - device=x.device, - dtype=x.dtype, + count = 0 + output_latent = x.new_zeros( + (output_num_frames * output_height * output_width, 2 * self.config.latent_channels, self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs) ) - vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=x.device, dtype=x.dtype) - - for i in range(out_n): - for j in range(out_h): - for k in range(out_w): - n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0] - h_start, h_end = j * STRIDE[1], j * STRIDE[1] + KERNEL[1] - w_start, w_end = k * STRIDE[2], k * STRIDE[2] + KERNEL[2] + vae_batch_input = x.new_zeros((local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2])) + + for i in range(output_num_frames): + for j in range(output_height): + for k in range(output_width): + n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0] + h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1] + w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2] + video_cube = x[:, :, n_start:n_end, h_start:h_end, w_start:w_end] - vae_batch_input[num % LOCAL_BS] = video_cube + vae_batch_input[count % local_batch_size] = video_cube - if num % LOCAL_BS == LOCAL_BS - 1 or num == out_n * out_h * out_w - 1: + if count % local_batch_size == local_batch_size - 1 or count == output_num_frames * output_height * output_width - 1: latent = self.encoder(vae_batch_input) - if num == out_n * out_h * out_w - 1 and num % LOCAL_BS != LOCAL_BS - 1: - out_latent[num - num % LOCAL_BS :] = latent[: num % LOCAL_BS + 1] + if count == output_num_frames * output_height * output_width - 1 and count % local_batch_size != local_batch_size - 1: + output_latent[count - count % local_batch_size :] = latent[: count % local_batch_size + 1] else: - out_latent[num - LOCAL_BS + 1 : num + 1] = latent - vae_batch_input = torch.zeros( - (LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), - device=x.device, - dtype=x.dtype, + output_latent[count - local_batch_size + 1 : count + 1] = latent + + vae_batch_input = x.new_zeros( + (local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2]) ) - num += 1 - - ## flatten the batched out latent to videos and supress the overlapped parts - B, C, N, H, W = x.shape - - out_video_cube = torch.zeros((B, OUT_C, N // 4, H // 8, W // 8), device=x.device, dtype=x.dtype) - OUT_KERNEL = KERNEL[0] // 4, KERNEL[1] // 8, KERNEL[2] // 8 - OUT_STRIDE = STRIDE[0] // 4, STRIDE[1] // 8, STRIDE[2] // 8 - OVERLAP = OUT_KERNEL[0] - OUT_STRIDE[0], OUT_KERNEL[1] - OUT_STRIDE[1], OUT_KERNEL[2] - OUT_STRIDE[2] - - for i in range(out_n): - n_start, n_end = i * OUT_STRIDE[0], i * OUT_STRIDE[0] + OUT_KERNEL[0] - for j in range(out_h): - h_start, h_end = j * OUT_STRIDE[1], j * OUT_STRIDE[1] + OUT_KERNEL[1] - for k in range(out_w): - w_start, w_end = k * OUT_STRIDE[2], k * OUT_STRIDE[2] + OUT_KERNEL[2] - latent_mean_blend = prepare_for_blend( - (i, out_n, OVERLAP[0]), - (j, out_h, OVERLAP[1]), - (k, out_w, OVERLAP[2]), - out_latent[i * out_h * out_w + j * out_w + k].unsqueeze(0), + + count += 1 + + latent = x.new_zeros((batch_size, 2 * self.config.latent_channels, num_frames // rt, height // rs, width // rs)) + output_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs + output_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs + output_overlap = output_kernel[0] - output_stride[0], output_kernel[1] - output_stride[1], output_kernel[2] - output_stride[2] + + for i in range(output_num_frames): + n_start, n_end = i * output_stride[0], i * output_stride[0] + output_kernel[0] + for j in range(output_height): + h_start, h_end = j * output_stride[1], j * output_stride[1] + output_kernel[1] + for k in range(output_width): + w_start, w_end = k * output_stride[2], k * output_stride[2] + output_kernel[2] + latent_mean = _prepare_for_blend( + (i, output_num_frames, output_overlap[0]), + (j, output_height, output_overlap[1]), + (k, output_width, output_overlap[2]), + output_latent[i * output_height * output_width + j * output_width + k].unsqueeze(0), ) - out_video_cube[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean_blend + latent[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean - # final conv - out_video_cube = out_video_cube.permute(0, 2, 1, 3, 4).flatten(0, 1) - out_video_cube = self.quant_conv(out_video_cube) - out_video_cube = out_video_cube.unflatten(0, (B, -1)).permute(0, 2, 1, 3, 4) - - return out_video_cube + latent = latent.permute(0, 2, 1, 3, 4).flatten(0, 1) + latent = self.quant_conv(latent) + latent = latent.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return latent def tiled_decode(self, z: torch.Tensor) -> torch.Tensor: - # TODO(aryan): parameterize this in enable_tiling local_batch_size = 1 + rs = self.spatial_compression_ratio + rt = self.config.temporal_compression_ratio - # TODO(aryan): rewrite to decode and tiled_decode - KERNEL = self.kernel - STRIDE = self.stride - - LOCAL_BS = local_batch_size - OUT_C = 3 - IN_KERNEL = KERNEL[0] // 4, KERNEL[1] // 8, KERNEL[2] // 8 - IN_STRIDE = STRIDE[0] // 4, STRIDE[1] // 8, STRIDE[2] // 8 + latent_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs + latent_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs - B, C, N, H, W = z.shape + batch_size, num_channels, num_frames, height, width = z.shape ## post quant conv (a mapping) z = z.permute(0, 2, 1, 3, 4).flatten(0, 1) z = self.post_quant_conv(z) - z = z.unflatten(0, (B, -1)).permute(0, 2, 1, 3, 4) - - ## out tensor shape - out_n = math.floor((N - IN_KERNEL[0]) / IN_STRIDE[0]) + 1 - out_h = math.floor((H - IN_KERNEL[1]) / IN_STRIDE[1]) + 1 - out_w = math.floor((W - IN_KERNEL[2]) / IN_STRIDE[2]) + 1 - - ## cut latent into overlapped small cubes and batch forward - num = 0 - decoded_cube = torch.zeros( - (out_n * out_h * out_w, OUT_C, KERNEL[0], KERNEL[1], KERNEL[2]), - device=z.device, - dtype=z.dtype, + z = z.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + output_num_frames = math.floor((num_frames - latent_kernel[0]) / latent_stride[0]) + 1 + output_height = math.floor((height - latent_kernel[1]) / latent_stride[1]) + 1 + output_width = math.floor((width - latent_kernel[2]) / latent_stride[2]) + 1 + + count = 0 + decoded_videos = z.new_zeros( + (output_num_frames * output_height * output_width, self.config.out_channels, self.kernel[0], self.kernel[1], self.kernel[2]) ) - vae_batch_input = torch.zeros( - (LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), - device=z.device, - dtype=z.dtype, + vae_batch_input = z.new_zeros( + (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2]) ) - for i in range(out_n): - for j in range(out_h): - for k in range(out_w): - n_start, n_end = i * IN_STRIDE[0], i * IN_STRIDE[0] + IN_KERNEL[0] - h_start, h_end = j * IN_STRIDE[1], j * IN_STRIDE[1] + IN_KERNEL[1] - w_start, w_end = k * IN_STRIDE[2], k * IN_STRIDE[2] + IN_KERNEL[2] - latent_cube = z[:, :, n_start:n_end, h_start:h_end, w_start:w_end] - vae_batch_input[num % LOCAL_BS] = latent_cube - if num % LOCAL_BS == LOCAL_BS - 1 or num == out_n * out_h * out_w - 1: - latent = self.decoder(vae_batch_input) - - if num == out_n * out_h * out_w - 1 and num % LOCAL_BS != LOCAL_BS - 1: - decoded_cube[num - num % LOCAL_BS :] = latent[: num % LOCAL_BS + 1] + + for i in range(output_num_frames): + for j in range(output_height): + for k in range(output_width): + n_start, n_end = i * latent_stride[0], i * latent_stride[0] + latent_kernel[0] + h_start, h_end = j * latent_stride[1], j * latent_stride[1] + latent_kernel[1] + w_start, w_end = k * latent_stride[2], k * latent_stride[2] + latent_kernel[2] + + current_latent = z[:, :, n_start:n_end, h_start:h_end, w_start:w_end] + vae_batch_input[count % local_batch_size] = current_latent + + if count % local_batch_size == local_batch_size - 1 or count == output_num_frames * output_height * output_width - 1: + current_video = self.decoder(vae_batch_input) + + if count == output_num_frames * output_height * output_width - 1 and count % local_batch_size != local_batch_size - 1: + decoded_videos[count - count % local_batch_size :] = current_video[: count % local_batch_size + 1] else: - decoded_cube[num - LOCAL_BS + 1 : num + 1] = latent - vae_batch_input = torch.zeros( - (LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), - device=z.device, - dtype=z.dtype, + decoded_videos[count - local_batch_size + 1 : count + 1] = current_video + + vae_batch_input = z.new_zeros( + (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2]) ) - num += 1 - B, C, N, H, W = z.shape - - out_video = torch.zeros((B, OUT_C, N * 4, H * 8, W * 8), device=z.device, dtype=z.dtype) - OVERLAP = KERNEL[0] - STRIDE[0], KERNEL[1] - STRIDE[1], KERNEL[2] - STRIDE[2] - for i in range(out_n): - n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0] - for j in range(out_h): - h_start, h_end = j * STRIDE[1], j * STRIDE[1] + KERNEL[1] - for k in range(out_w): - w_start, w_end = k * STRIDE[2], k * STRIDE[2] + KERNEL[2] - out_video_blend = prepare_for_blend( - (i, out_n, OVERLAP[0]), - (j, out_h, OVERLAP[1]), - (k, out_w, OVERLAP[2]), - decoded_cube[i * out_h * out_w + j * out_w + k].unsqueeze(0), + + count += 1 + + video = z.new_zeros((batch_size, self.config.out_channels, num_frames * rt, height * rs, width * rs)) + video_overlap = self.kernel[0] - self.stride[0], self.kernel[1] - self.stride[1], self.kernel[2] - self.stride[2] + + for i in range(output_num_frames): + n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0] + for j in range(output_height): + h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1] + for k in range(output_width): + w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2] + out_video_blend = _prepare_for_blend( + (i, output_num_frames, video_overlap[0]), + (j, output_height, video_overlap[1]), + (k, output_width, video_overlap[2]), + decoded_videos[i * output_height * output_width + j * output_width + k].unsqueeze(0), ) - out_video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend - - out_video = out_video.permute(0, 2, 1, 3, 4).contiguous() + video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend - return out_video + video = video.permute(0, 2, 1, 3, 4).contiguous() + return video def forward( self, @@ -1143,7 +1088,8 @@ def forward( return DecoderOutput(sample=dec) -def prepare_for_blend(n_param, h_param, w_param, x): +def _prepare_for_blend(n_param, h_param, w_param, x): + # TODO(aryan): refactor n, n_max, overlap_n = n_param h, h_max, overlap_h = h_param w, w_max, overlap_w = w_param From cf5dec1d3f40021e51eeecc643b266b8423d9a5b Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 23 Oct 2024 08:43:01 +0200 Subject: [PATCH 24/33] make style --- .../autoencoders/autoencoder_kl_allegro.py | 76 ++++++++++++++----- 1 file changed, 55 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index 1d2b306a8189..dbdcb5699090 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -813,9 +813,7 @@ def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)): module.gradient_checkpointing = value - def enable_tiling( - self - ) -> None: + def enable_tiling(self) -> None: r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow @@ -928,7 +926,13 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: count = 0 output_latent = x.new_zeros( - (output_num_frames * output_height * output_width, 2 * self.config.latent_channels, self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs) + ( + output_num_frames * output_height * output_width, + 2 * self.config.latent_channels, + self.kernel[0] // rt, + self.kernel[1] // rs, + self.kernel[2] // rs, + ) ) vae_batch_input = x.new_zeros((local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2])) @@ -938,28 +942,40 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0] h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1] w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2] - + video_cube = x[:, :, n_start:n_end, h_start:h_end, w_start:w_end] vae_batch_input[count % local_batch_size] = video_cube - if count % local_batch_size == local_batch_size - 1 or count == output_num_frames * output_height * output_width - 1: + if ( + count % local_batch_size == local_batch_size - 1 + or count == output_num_frames * output_height * output_width - 1 + ): latent = self.encoder(vae_batch_input) - if count == output_num_frames * output_height * output_width - 1 and count % local_batch_size != local_batch_size - 1: + if ( + count == output_num_frames * output_height * output_width - 1 + and count % local_batch_size != local_batch_size - 1 + ): output_latent[count - count % local_batch_size :] = latent[: count % local_batch_size + 1] else: output_latent[count - local_batch_size + 1 : count + 1] = latent - + vae_batch_input = x.new_zeros( (local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2]) ) - + count += 1 - latent = x.new_zeros((batch_size, 2 * self.config.latent_channels, num_frames // rt, height // rs, width // rs)) + latent = x.new_zeros( + (batch_size, 2 * self.config.latent_channels, num_frames // rt, height // rs, width // rs) + ) output_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs output_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs - output_overlap = output_kernel[0] - output_stride[0], output_kernel[1] - output_stride[1], output_kernel[2] - output_stride[2] + output_overlap = ( + output_kernel[0] - output_stride[0], + output_kernel[1] - output_stride[1], + output_kernel[2] - output_stride[2], + ) for i in range(output_num_frames): n_start, n_end = i * output_stride[0], i * output_stride[0] + output_kernel[0] @@ -1001,7 +1017,13 @@ def tiled_decode(self, z: torch.Tensor) -> torch.Tensor: count = 0 decoded_videos = z.new_zeros( - (output_num_frames * output_height * output_width, self.config.out_channels, self.kernel[0], self.kernel[1], self.kernel[2]) + ( + output_num_frames * output_height * output_width, + self.config.out_channels, + self.kernel[0], + self.kernel[1], + self.kernel[2], + ) ) vae_batch_input = z.new_zeros( (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2]) @@ -1013,27 +1035,39 @@ def tiled_decode(self, z: torch.Tensor) -> torch.Tensor: n_start, n_end = i * latent_stride[0], i * latent_stride[0] + latent_kernel[0] h_start, h_end = j * latent_stride[1], j * latent_stride[1] + latent_kernel[1] w_start, w_end = k * latent_stride[2], k * latent_stride[2] + latent_kernel[2] - + current_latent = z[:, :, n_start:n_end, h_start:h_end, w_start:w_end] vae_batch_input[count % local_batch_size] = current_latent - - if count % local_batch_size == local_batch_size - 1 or count == output_num_frames * output_height * output_width - 1: + + if ( + count % local_batch_size == local_batch_size - 1 + or count == output_num_frames * output_height * output_width - 1 + ): current_video = self.decoder(vae_batch_input) - if count == output_num_frames * output_height * output_width - 1 and count % local_batch_size != local_batch_size - 1: - decoded_videos[count - count % local_batch_size :] = current_video[: count % local_batch_size + 1] + if ( + count == output_num_frames * output_height * output_width - 1 + and count % local_batch_size != local_batch_size - 1 + ): + decoded_videos[count - count % local_batch_size :] = current_video[ + : count % local_batch_size + 1 + ] else: decoded_videos[count - local_batch_size + 1 : count + 1] = current_video - + vae_batch_input = z.new_zeros( (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2]) ) - + count += 1 video = z.new_zeros((batch_size, self.config.out_channels, num_frames * rt, height * rs, width * rs)) - video_overlap = self.kernel[0] - self.stride[0], self.kernel[1] - self.stride[1], self.kernel[2] - self.stride[2] - + video_overlap = ( + self.kernel[0] - self.stride[0], + self.kernel[1] - self.stride[1], + self.kernel[2] - self.stride[2], + ) + for i in range(output_num_frames): n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0] for j in range(output_height): From d9eabf843acb2ec8a176d5177b870c5416fa437b Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 23 Oct 2024 11:49:20 +0200 Subject: [PATCH 25/33] fix --- src/diffusers/models/transformers/transformer_allegro.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 3d6a8b53172d..5503ff847797 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -317,9 +317,9 @@ def forward( p_t = self.config.patch_size_t p = self.config.patch_size - post_patch_num_frames = num_frames // self.config.patch_size_temporal - post_patch_height = height // self.config.patch_size - post_patch_width = width // self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. From cf010fc25908a9a5aaae71a3fe843ebd4221ed76 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 23 Oct 2024 11:52:49 +0200 Subject: [PATCH 26/33] apply suggestions from review --- .../autoencoders/autoencoder_kl_allegro.py | 6 +- tests/pipelines/allegro/test_allegro.py | 70 +++++++++---------- 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index dbdcb5699090..4836de7e16ab 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -300,7 +300,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class UNetMidBlock3DConv(nn.Module): +class AllegroMidBlock3DConv(nn.Module): def __init__( self, in_channels: int, @@ -473,7 +473,7 @@ def __init__( self.down_blocks.append(down_block) # mid - self.mid_block = UNetMidBlock3DConv( + self.mid_block = AllegroMidBlock3DConv( in_channels=block_out_channels[-1], resnet_eps=1e-6, resnet_act_fn=act_fn, @@ -581,7 +581,7 @@ def __init__( temb_channels = in_channels if norm_type == "spatial" else None # mid - self.mid_block = UNetMidBlock3DConv( + self.mid_block = AllegroMidBlock3DConv( in_channels=block_out_channels[-1], resnet_eps=1e-6, resnet_act_fn=act_fn, diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py index daac4e4136b6..41305f0eebb8 100644 --- a/tests/pipelines/allegro/test_allegro.py +++ b/tests/pipelines/allegro/test_allegro.py @@ -206,40 +206,40 @@ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) - # def test_attention_slicing_forward_pass( - # self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 - # ): - # if not self.test_attention_slicing: - # return - - # components = self.get_dummy_components() - # pipe = self.pipeline_class(**components) - # for component in pipe.components.values(): - # if hasattr(component, "set_default_attn_processor"): - # component.set_default_attn_processor() - # pipe.to(torch_device) - # pipe.set_progress_bar_config(disable=None) - - # generator_device = "cpu" - # inputs = self.get_dummy_inputs(generator_device) - # output_without_slicing = pipe(**inputs)[0] - - # pipe.enable_attention_slicing(slice_size=1) - # inputs = self.get_dummy_inputs(generator_device) - # output_with_slicing1 = pipe(**inputs)[0] - - # pipe.enable_attention_slicing(slice_size=2) - # inputs = self.get_dummy_inputs(generator_device) - # output_with_slicing2 = pipe(**inputs)[0] - - # if test_max_difference: - # max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() - # max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() - # self.assertLess( - # max(max_diff1, max_diff2), - # expected_max_diff, - # "Attention slicing should not affect the inference results", - # ) + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) def test_vae_tiling(self, expected_diff_max: float = 0.2): generator_device = "cpu" @@ -287,7 +287,7 @@ def tearDown(self): gc.collect() torch.cuda.empty_cache() - def test_cogvideox(self): + def test_allegro(self): generator = torch.Generator("cpu").manual_seed(0) pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.float16) From d44a5c8a830a7061b91c985684c4370d2f37b135 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 24 Oct 2024 01:43:15 +0530 Subject: [PATCH 27/33] Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- src/diffusers/models/attention_processor.py | 2 +- src/diffusers/models/transformers/transformer_allegro.py | 4 ++-- src/diffusers/pipelines/allegro/pipeline_allegro.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index b91a605f0cae..db88ecbbb9d3 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1524,7 +1524,7 @@ def __call__( class AllegroAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is - used in the Allegro model. It applies a s normalization layer and rotary embedding on query and key vector. + used in the Allegro model. It applies a normalization layer and rotary embedding on the query and key vector. """ def __init__(self): diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 5503ff847797..f756399a378a 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -207,9 +207,9 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin): The number of frames in the input latents. activation_fn (`str`, defaults to `"gelu-approximate"`): Activation function to use in feed-forward. - norm_elementwise_affine (`bool`, defaults to `True`): + norm_elementwise_affine (`bool`, defaults to `False`): Whether or not to use elementwise affine in normalization layers. - norm_eps (`float`, defaults to `1e-5`): + norm_eps (`float`, defaults to `1e-6`): The epsilon value to use in normalization layers. caption_channels (`int`, defaults to `4096`): Number of channels to use for projecting the caption embeddings. diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 340f749c48d5..36f56871d6dc 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -717,14 +717,14 @@ def __call__( timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to 7.0): + guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. num_frames: (`int`, *optional*, defaults to 88): The number controls the generated video frames. height (`int`, *optional*, defaults to self.unet.config.sample_size): From b036386b64a8ff51566f0ee50df20de770f0c23b Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 24 Oct 2024 03:52:15 +0200 Subject: [PATCH 28/33] update example --- .../pipelines/allegro/pipeline_allegro.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 36f56871d6dc..afcc2a568670 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -53,14 +53,19 @@ Examples: ```py >>> import torch + >>> from diffusers import AutoencoderKLAllegro, AllegroPipeline + >>> from diffusers.utils import export_to_video - >>> # You can replace the your_path_to_model with your own path. - >>> pipe = AllegroPipeline.from_pretrained( - ... your_path_to_model, torch_dtype=torch.float16, trust_remote_code=True - ... ) + >>> vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32) + >>> pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", vae=vae, torch_dtype=torch.bfloat16).to("cuda") - >>> prompt = "A small cactus with a happy face in the Sahara desert." - >>> image = pipe(prompt).video[0] + >>> prompt = ( + ... "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, " + ... "the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this " + ... "location might be a popular spot for docking fishing boats." + ... ) + >>> video = pipe(prompt, guidance_scale=7.5, max_sequence_length=512).frames[0] + >>> export_to_video(video, "output.mp4", fps=15) ``` """ From 9214f4a3782a74e510eff7e09b59457fe8b63511 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 28 Oct 2024 05:59:23 +0100 Subject: [PATCH 29/33] remove attention mask for self-attention --- src/diffusers/pipelines/allegro/pipeline_allegro.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index afcc2a568670..a8042e75cd14 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -843,6 +843,8 @@ def __call__( if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + if prompt_embeds.ndim == 3: + prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) @@ -884,17 +886,9 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - if prompt_embeds.ndim == 3: - prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d - - # prepare attention_mask. - # b c t h w -> b t h w - attention_mask = torch.ones_like(latent_model_input)[:, 0] - # predict noise model_output noise_pred = self.transformer( - latent_model_input, - attention_mask=attention_mask, + hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, timestep=timestep, From 3354ee180646de5729f6f936a94084cbd93420e2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Oct 2024 06:29:38 +0100 Subject: [PATCH 30/33] update --- .../pipelines/allegro/pipeline_allegro.py | 47 +++++++++---------- tests/pipelines/allegro/test_allegro.py | 38 ++++++++++++--- 2 files changed, 54 insertions(+), 31 deletions(-) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index a8042e75cd14..701893ab9e6b 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -188,7 +188,7 @@ class AllegroPipeline(DiffusionPipeline): + r"]{1,}" ) # noqa - _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"] + _optional_components = [] model_cpu_offload_seq = "text_encoder->transformer->vae" _callback_tensor_inputs = [ @@ -226,10 +226,10 @@ def encode_prompt( negative_prompt: str = "", num_videos_per_prompt: int = 1, device: Optional[torch.device] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - prompt_attention_mask: Optional[torch.FloatTensor] = None, - negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, clean_caption: bool = False, max_sequence_length: int = 512, ): @@ -249,10 +249,10 @@ def encode_prompt( number of images that should be generated per prompt device: (`torch.device`, *optional*): torch device to place the resulting embeddings on - prompt_embeds (`torch.FloatTensor`, *optional*): + prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): + negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" string. clean_caption (`bool`, defaults to `False`): @@ -632,15 +632,11 @@ def _prepare_rotary_positional_embeddings( ): grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_width = 1280 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_height = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - grid_crops_coords = get_resize_crop_region_for_grid( - (grid_height, grid_width), base_size_width, base_size_height - ) + start, stop = (0, 0), (grid_height, grid_width) freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w = get_3d_rotary_pos_embed_allegro( embed_dim=self.transformer.config.attention_head_dim, - crops_coords=grid_crops_coords, + crops_coords=(start, stop), grid_size=(grid_height, grid_width), temporal_size=num_frames, interpolation_scale=( @@ -655,7 +651,7 @@ def _prepare_rotary_positional_embeddings( grid_w = torch.from_numpy(grid_w).to(device=device, dtype=torch.long) pos = torch.cartesian_prod(grid_t, grid_h, grid_w) - pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1).contiguous().expand(3, batch_size, -1) + pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1).contiguous() grid_t, grid_h, grid_w = pos freqs_t = (freqs_t[0].to(device=device), freqs_t[1].to(device=device)) @@ -691,11 +687,11 @@ def __call__( num_videos_per_prompt: int = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - prompt_attention_mask: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback_on_step_end: Optional[ @@ -742,18 +738,18 @@ def __call__( generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): + latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): + prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. - prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + negative_prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for negative text embeddings. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between @@ -762,7 +758,7 @@ def __call__( Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. @@ -874,6 +870,7 @@ def __call__( # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py index 41305f0eebb8..d09fc0488378 100644 --- a/tests/pipelines/allegro/test_allegro.py +++ b/tests/pipelines/allegro/test_allegro.py @@ -18,7 +18,7 @@ import numpy as np import torch -from transformers import AutoTokenizer, T5EncoderModel +from transformers import AutoTokenizer, T5Config, T5EncoderModel from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro, DDIMScheduler from diffusers.utils.testing_utils import ( @@ -62,11 +62,11 @@ def get_dummy_components(self): in_channels=4, out_channels=4, num_layers=1, - cross_attention_dim=32, + cross_attention_dim=24, sample_width=8, sample_height=8, sample_frames=8, - caption_channels=32, + caption_channels=24, ) torch.manual_seed(0) @@ -92,9 +92,25 @@ def get_dummy_components(self): temporal_compression_ratio=4, ) + # TODO(aryan): Only for now, since VAE decoding without tiling is not yet implemented here + vae.enable_tiling() + torch.manual_seed(0) scheduler = DDIMScheduler() - text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + text_encoder_config = T5Config( + **{ + "d_ff": 37, + "d_kv": 8, + "d_model": 24, + "num_decoder_layers": 2, + "num_heads": 4, + "num_layers": 2, + "relative_attention_num_buckets": 8, + "vocab_size": 1103, + } + ) + text_encoder = T5EncoderModel(text_encoder_config) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") components = { @@ -118,8 +134,8 @@ def get_dummy_inputs(self, device, seed=0): "generator": generator, "num_inference_steps": 2, "guidance_scale": 6.0, - "height": 48, - "width": 48, + "height": 16, + "width": 16, "num_frames": 8, "max_sequence_length": 16, "output_type": "pt", @@ -127,6 +143,14 @@ def get_dummy_inputs(self, device, seed=0): return inputs + @unittest.skip("Decoding without tiling is not yet implemented") + def test_save_load_local(self): + pass + + @unittest.skip("Decoding without tiling is not yet implemented") + def test_save_load_optional_components(self): + pass + def test_inference(self): device = "cpu" @@ -241,6 +265,8 @@ def test_attention_slicing_forward_pass( "Attention slicing should not affect the inference results", ) + # TODO(aryan) + @unittest.skip("Decoding without tiling is not yet implemented.") def test_vae_tiling(self, expected_diff_max: float = 0.2): generator_device = "cpu" components = self.get_dummy_components() From 28e57585d13adfcbd1ced94800dedb816900eb4a Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Oct 2024 07:29:35 +0100 Subject: [PATCH 31/33] copied from --- .../pipelines/allegro/pipeline_allegro.py | 19 ++++++++++++------- tests/pipelines/test_pipelines_common.py | 1 + 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 701893ab9e6b..fa4f8e35d1ec 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -30,6 +30,7 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( BACKENDS_MAPPING, + deprecate, is_bs4_available, is_ftfy_available, logging, @@ -219,6 +220,7 @@ def __init__( self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->512, num_images_per_prompt->num_videos_per_prompt def encode_prompt( self, prompt: Union[str, List[str]], @@ -232,6 +234,7 @@ def encode_prompt( negative_prompt_attention_mask: Optional[torch.Tensor] = None, clean_caption: bool = False, max_sequence_length: int = 512, + **kwargs, ): r""" Encodes the prompt into text encoder hidden states. @@ -245,7 +248,7 @@ def encode_prompt( PixArt-Alpha, this should be "". do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not - num_images_per_prompt (`int`, *optional*, defaults to 1): + num_videos_per_prompt (`int`, *optional*, defaults to 1): number of images that should be generated per prompt device: (`torch.device`, *optional*): torch device to place the resulting embeddings on @@ -257,10 +260,13 @@ def encode_prompt( string. clean_caption (`bool`, defaults to `False`): If `True`, the function will preprocess and clean the provided caption before encoding. - max_sequence_length (`int`, defaults to `512`): - Maximum sequence length to use for the prompt. + max_sequence_length (`int`, defaults to 512): Maximum sequence length to use for the prompt. """ + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + if device is None: device = self._execution_device @@ -292,7 +298,7 @@ def encode_prompt( ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" + "The following part of your input was truncated because T5 can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) @@ -320,7 +326,7 @@ def encode_prompt( # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens = [negative_prompt] * batch_size + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( @@ -336,8 +342,7 @@ def encode_prompt( negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=negative_prompt_attention_mask, + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask ) negative_prompt_embeds = negative_prompt_embeds[0] diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 3e6f9d1278e8..295a94c1d2e4 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1103,6 +1103,7 @@ def _test_inference_batch_consistent( logger.setLevel(level=diffusers.logging.WARNING) for batch_size, batched_input in zip(batch_sizes, batched_inputs): + print(batch_size, batched_input) output = pipe(**batched_input) assert len(output[0]) == batch_size From 1ec17d5147d5344cd003c864a59331110423d2b7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Oct 2024 08:15:24 +0100 Subject: [PATCH 32/33] update --- .../pipelines/allegro/pipeline_allegro.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index fa4f8e35d1ec..d2073d8b5e98 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -131,25 +131,6 @@ def retrieve_timesteps( return timesteps, num_inference_steps -# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.get_resize_crop_region_for_grid -def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): - tw = tgt_width - th = tgt_height - h, w = src - r = h / w - if r > (th / tw): - resize_height = th - resize_width = int(round(th / h * w)) - else: - resize_width = tw - resize_height = int(round(tw / w * h)) - - crop_top = int(round((th - resize_height) / 2.0)) - crop_left = int(round((tw - resize_width) / 2.0)) - - return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) - - class AllegroPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using Allegro. From 4d6d4e43a33e7e92852f730136c66b2575446304 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Oct 2024 08:15:44 +0100 Subject: [PATCH 33/33] update --- .../pipelines/allegro/pipeline_allegro.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index d2073d8b5e98..9314960f9618 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -133,14 +133,14 @@ def retrieve_timesteps( class AllegroPipeline(DiffusionPipeline): r""" - Pipeline for text-to-image generation using Allegro. + Pipeline for text-to-video generation using Allegro. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: vae ([`AllegroAutoEncoderKL3D`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. text_encoder ([`T5EncoderModel`]): Frozen text-encoder. PixArt-Alpha uses [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the @@ -149,9 +149,9 @@ class AllegroPipeline(DiffusionPipeline): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). transformer ([`AllegroTransformer3DModel`]): - A text conditioned `AllegroTransformer3DModel` to denoise the encoded image latents. + A text conditioned `AllegroTransformer3DModel` to denoise the encoded video latents. scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. """ bad_punct_regex = re.compile( @@ -692,14 +692,14 @@ def __call__( Args: prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. instead. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass + The prompt or prompts not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_inference_steps (`int`, *optional*, defaults to 100): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the + The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` @@ -708,16 +708,16 @@ def __call__( Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, + usually at the expense of lower video quality. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. num_frames: (`int`, *optional*, defaults to 88): The number controls the generated video frames. height (`int`, *optional*, defaults to self.unet.config.sample_size): - The height in pixels of the generated image. + The height in pixels of the generated video. width (`int`, *optional*, defaults to self.unet.config.sample_size): - The width in pixels of the generated image. + The width in pixels of the generated video. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. @@ -725,8 +725,8 @@ def __call__( One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not @@ -738,7 +738,7 @@ def __call__( negative_prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for negative text embeddings. output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between + The output format of the generate video. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. @@ -758,9 +758,9 @@ def __call__( Examples: Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is - returned where the first element is a list with the generated images + [`~pipelines.allegro.pipeline_output.AllegroPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.allegro.pipeline_output.AllegroPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated videos. """ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):