From b02915b5cd9776575cac1d3892b39e1cb7255abe Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Wed, 6 Nov 2024 22:10:53 +0800 Subject: [PATCH 01/36] CogVideoX1_1PatchEmbed test --- scripts/convert_cogvideox_to_diffusers.py | 2 +- src/diffusers/models/embeddings.py | 117 ++++++++++++++++++ .../transformers/cogvideox_transformer_3d.py | 11 +- .../pipelines/cogvideo/pipeline_cogvideox.py | 13 +- 4 files changed, 129 insertions(+), 14 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index 4343eaf34038..cc4e407d5450 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -241,7 +241,7 @@ def get_args(): if args.vae_ckpt_path is not None: vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype) - text_encoder_id = "google/t5-v1_1-xxl" + text_encoder_id = "/share/home/zyx/Models/CogVideoX1.1-5B-SAT/t5-v1_1-xxl" tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 7cbd958e1d6e..5cf6e3988d50 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -17,6 +17,7 @@ import numpy as np import torch import torch.nn.functional as F +from einops import rearrange from torch import nn from ..utils import deprecate @@ -333,6 +334,122 @@ def forward(self, x, freqs_cis): freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0), ) +class CogVideoX1_1PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + embed_dim: int = 1920, + text_embed_dim: int = 4096, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 81, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + use_positional_embeddings: bool = True, + use_learned_positional_embeddings: bool = True, + ) -> None: + super().__init__() + + # Adjust patch_size to handle three dimensions + self.patch_size = (patch_size, patch_size, patch_size) # (depth, height, width) + self.embed_dim = embed_dim + self.sample_height = sample_height + self.sample_width = sample_width + self.sample_frames = sample_frames + self.temporal_compression_ratio = temporal_compression_ratio + self.max_text_seq_length = max_text_seq_length + self.spatial_interpolation_scale = spatial_interpolation_scale + self.temporal_interpolation_scale = temporal_interpolation_scale + self.use_positional_embeddings = use_positional_embeddings + self.use_learned_positional_embeddings = use_learned_positional_embeddings + + # Use Linear layer for projection + self.proj = nn.Linear(in_channels * (patch_size ** 3), embed_dim) + self.text_proj = nn.Linear(text_embed_dim, embed_dim) + + if use_positional_embeddings or use_learned_positional_embeddings: + persistent = use_learned_positional_embeddings + pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames) + self.register_buffer("pos_embedding", pos_embedding, persistent=persistent) + + def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor: + post_patch_height = sample_height // self.patch_size[1] + post_patch_width = sample_width // self.patch_size[2] + post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1 + num_patches = post_patch_height * post_patch_width * post_time_compression_frames + + pos_embedding = get_3d_sincos_pos_embed( + self.embed_dim, + (post_patch_width, post_patch_height), + post_time_compression_frames, + self.spatial_interpolation_scale, + self.temporal_interpolation_scale, + ) + pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1) + joint_pos_embedding = torch.zeros(1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False) + joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding) + + return joint_pos_embedding + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): + """ + Args: + text_embeds (torch.Tensor): Input text embeddings of shape (batch_size, seq_length, embedding_dim). + image_embeds (torch.Tensor): Input image embeddings of shape (batch_size, num_frames, channels, height, width). + """ + text_embeds = self.text_proj(text_embeds) + first_frame = image_embeds[:, 0:1, :, :, :] + duplicated_first_frame = first_frame.repeat(1, 2, 1, 1, 1) # (batch, 2, channels, height, width) + # Copy the first frames, for t_patch + image_embeds = torch.cat([duplicated_first_frame, image_embeds[:, 1:, :, :, :]], dim=1) + batch, num_frames, channels, height, width = image_embeds.shape + image_embeds = image_embeds.permute(0, 2, 1, 3, 4).contiguous() + image_embeds = image_embeds.view(batch, channels, -1).permute(0, 2, 1) + + rope_patch_t = num_frames // self.patch_size[0] + rope_patch_h = height // self.patch_size[1] + rope_patch_w = width // self.patch_size[2] + + image_embeds = image_embeds.view( + batch, + rope_patch_t, self.patch_size[0], + rope_patch_h, self.patch_size[1], + rope_patch_w, self.patch_size[2], + channels + ) + image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + image_embeds = image_embeds.view(batch, rope_patch_t * rope_patch_h * rope_patch_w, -1) + image_embeds = self.proj(image_embeds) + # Concatenate text and image embeddings + embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous() + + # Add positional embeddings if applicable + if self.use_positional_embeddings or self.use_learned_positional_embeddings: + if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height): + raise ValueError( + "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'." + "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues." + ) + + pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + if ( + self.sample_height != height + or self.sample_width != width + or self.sample_frames != pre_time_compression_frames + ): + pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames) + pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype) + else: + pos_embedding = self.pos_embedding + + embeds = embeds + pos_embedding + + return embeds + class CogVideoXPatchEmbed(nn.Module): def __init__( diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 821da6d032d5..6477c91c87a1 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -24,7 +24,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 -from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps +from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, CogVideoX1_1PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero @@ -249,12 +249,13 @@ def __init__( ) # 1. Patch embedding - self.patch_embed = CogVideoXPatchEmbed( + # self.patch_embed = CogVideoXPatchEmbed( + self.patch_embed = CogVideoX1_1PatchEmbed( patch_size=patch_size, in_channels=in_channels, embed_dim=inner_dim, text_embed_dim=text_embed_dim, - bias=True, + # bias=True, sample_width=sample_width, sample_height=sample_height, sample_frames=sample_frames, @@ -298,7 +299,7 @@ def __init__( norm_eps=norm_eps, chunk_dim=1, ) - self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * patch_size * out_channels) # For CogVideoX1.1-5B self.gradient_checkpointing = False @@ -504,4 +505,4 @@ def custom_forward(*inputs): if not return_dict: return (output,) - return Transformer2DModelOutput(sample=output) + return Transformer2DModelOutput(sample=output) \ No newline at end of file diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 9cb042c9e80c..619e7d389985 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -442,8 +442,11 @@ def _prepare_rotary_positional_embeddings( ) -> Tuple[torch.Tensor, torch.Tensor]: grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + # TODO: Here, compatibility is needed for both the CogVideoX-5B and CogVideoX1.1-5B models. + # CogVideoX1.0 is 720 X 480 and CogVideoX1.1-5B T2V is 768 * 1360, CogVideoX1.1-5B I2V use with image + base_size_width = 768 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 1360 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) grid_crops_coords = get_resize_crop_region_for_grid( (grid_height, grid_width), base_size_width, base_size_height @@ -583,11 +586,6 @@ def __call__( `tuple`. When returning a tuple, the first element is a list with the generated images. """ - if num_frames > 49: - raise ValueError( - "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." - ) - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -679,7 +677,6 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - # predict noise model_output noise_pred = self.transformer( hidden_states=latent_model_input, From 87535d6a0e3c0c1ffc13c2a3bf524494f9445481 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Wed, 6 Nov 2024 22:13:23 +0800 Subject: [PATCH 02/36] 1360 * 768 --- .../models/transformers/cogvideox_transformer_3d.py | 4 +++- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 6477c91c87a1..29ee28835832 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -249,13 +249,15 @@ def __init__( ) # 1. Patch embedding + #TODO: different git push --set-upstream origin cogvideox1.1-5b + # self.patch_embed = CogVideoXPatchEmbed( self.patch_embed = CogVideoX1_1PatchEmbed( patch_size=patch_size, in_channels=in_channels, embed_dim=inner_dim, text_embed_dim=text_embed_dim, - # bias=True, + # bias=True, # Only using in CogVideoX-5B sample_width=sample_width, sample_height=sample_height, sample_frames=sample_frames, diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 619e7d389985..6708332ca4e0 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -444,9 +444,9 @@ def _prepare_rotary_positional_embeddings( grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) # TODO: Here, compatibility is needed for both the CogVideoX-5B and CogVideoX1.1-5B models. - # CogVideoX1.0 is 720 X 480 and CogVideoX1.1-5B T2V is 768 * 1360, CogVideoX1.1-5B I2V use with image - base_size_width = 768 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_height = 1360 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + # CogVideoX1.0 is 720 X 480 and CogVideoX1.1-5B T2V is 1360 * 768, CogVideoX1.1-5B I2V use with image + base_size_width = 1360 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 768 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) grid_crops_coords = get_resize_crop_region_for_grid( (grid_height, grid_width), base_size_width, base_size_height From b033aada3b1b15e7056e83ae6a7f706b4921d621 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 8 Nov 2024 23:08:55 +0100 Subject: [PATCH 03/36] refactor --- scripts/convert_cogvideox_to_diffusers.py | 38 ++++- src/diffusers/models/embeddings.py | 152 ++++-------------- .../transformers/cogvideox_transformer_3d.py | 47 ++++-- .../pipelines/cogvideo/pipeline_cogvideox.py | 22 ++- 4 files changed, 110 insertions(+), 149 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index cc4e407d5450..ff8ff556832e 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -140,6 +140,7 @@ def convert_transformer( use_rotary_positional_embeddings: bool, i2v: bool, dtype: torch.dtype, + init_kwargs: Dict[str, Any] ): PREFIX_KEY = "model.diffusion_model." @@ -150,6 +151,7 @@ def convert_transformer( num_attention_heads=num_attention_heads, use_rotary_positional_embeddings=use_rotary_positional_embeddings, use_learned_positional_embeddings=i2v, + **init_kwargs, ).to(dtype=dtype) for key in list(original_state_dict.keys()): @@ -163,6 +165,7 @@ def convert_transformer( if special_key not in key: continue handler_fn_inplace(key, original_state_dict) + transformer.load_state_dict(original_state_dict, strict=True) return transformer @@ -187,6 +190,34 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype): return vae +def get_init_kwargs(version: str): + if version == "1.0": + vae_scale_factor_spatial = 8 + init_kwargs = { + "patch_size": 2, + "patch_size_t": None, + "patch_bias": True, + "sample_height": 480 // vae_scale_factor_spatial, + "sample_width": 720 // vae_scale_factor_spatial, + "sample_frames": 49, + } + + elif version == "1.5": + vae_scale_factor_spatial = 8 + init_kwargs = { + "patch_size": 2, + "patch_size_t": 2, + "patch_bias": False, + "sample_height": 768 // vae_scale_factor_spatial, + "sample_width": 1360 // vae_scale_factor_spatial, + "sample_frames": 81, + } + else: + raise ValueError("Unsupported version of CogVideoX.") + + return init_kwargs + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( @@ -214,7 +245,8 @@ def get_args(): parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE") # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0 parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE") - parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16") + parser.add_argument("--i2v", action="store_true", default=False, help="Whether the model to be converted is the Image-to-Video version of CogVideoX.") + parser.add_argument("--version", choices=["1.0", "1.5"], default="1.0", help="Which version of CogVideoX to use for initializing default modeling parameters.") return parser.parse_args() @@ -230,6 +262,7 @@ def get_args(): dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32 if args.transformer_ckpt_path is not None: + init_kwargs = get_init_kwargs(args.version) transformer = convert_transformer( args.transformer_ckpt_path, args.num_layers, @@ -237,11 +270,12 @@ def get_args(): args.use_rotary_positional_embeddings, args.i2v, dtype, + init_kwargs, ) if args.vae_ckpt_path is not None: vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype) - text_encoder_id = "/share/home/zyx/Models/CogVideoX1.1-5B-SAT/t5-v1_1-xxl" + text_encoder_id = "google/t5-v1_1-xxl" tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 5cf6e3988d50..f349f03b2a60 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -334,127 +334,12 @@ def forward(self, x, freqs_cis): freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0), ) -class CogVideoX1_1PatchEmbed(nn.Module): - def __init__( - self, - patch_size: int = 2, - in_channels: int = 16, - embed_dim: int = 1920, - text_embed_dim: int = 4096, - sample_width: int = 90, - sample_height: int = 60, - sample_frames: int = 81, - temporal_compression_ratio: int = 4, - max_text_seq_length: int = 226, - spatial_interpolation_scale: float = 1.875, - temporal_interpolation_scale: float = 1.0, - use_positional_embeddings: bool = True, - use_learned_positional_embeddings: bool = True, - ) -> None: - super().__init__() - - # Adjust patch_size to handle three dimensions - self.patch_size = (patch_size, patch_size, patch_size) # (depth, height, width) - self.embed_dim = embed_dim - self.sample_height = sample_height - self.sample_width = sample_width - self.sample_frames = sample_frames - self.temporal_compression_ratio = temporal_compression_ratio - self.max_text_seq_length = max_text_seq_length - self.spatial_interpolation_scale = spatial_interpolation_scale - self.temporal_interpolation_scale = temporal_interpolation_scale - self.use_positional_embeddings = use_positional_embeddings - self.use_learned_positional_embeddings = use_learned_positional_embeddings - - # Use Linear layer for projection - self.proj = nn.Linear(in_channels * (patch_size ** 3), embed_dim) - self.text_proj = nn.Linear(text_embed_dim, embed_dim) - - if use_positional_embeddings or use_learned_positional_embeddings: - persistent = use_learned_positional_embeddings - pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames) - self.register_buffer("pos_embedding", pos_embedding, persistent=persistent) - - def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor: - post_patch_height = sample_height // self.patch_size[1] - post_patch_width = sample_width // self.patch_size[2] - post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1 - num_patches = post_patch_height * post_patch_width * post_time_compression_frames - - pos_embedding = get_3d_sincos_pos_embed( - self.embed_dim, - (post_patch_width, post_patch_height), - post_time_compression_frames, - self.spatial_interpolation_scale, - self.temporal_interpolation_scale, - ) - pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1) - joint_pos_embedding = torch.zeros(1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False) - joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding) - - return joint_pos_embedding - - def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): - """ - Args: - text_embeds (torch.Tensor): Input text embeddings of shape (batch_size, seq_length, embedding_dim). - image_embeds (torch.Tensor): Input image embeddings of shape (batch_size, num_frames, channels, height, width). - """ - text_embeds = self.text_proj(text_embeds) - first_frame = image_embeds[:, 0:1, :, :, :] - duplicated_first_frame = first_frame.repeat(1, 2, 1, 1, 1) # (batch, 2, channels, height, width) - # Copy the first frames, for t_patch - image_embeds = torch.cat([duplicated_first_frame, image_embeds[:, 1:, :, :, :]], dim=1) - batch, num_frames, channels, height, width = image_embeds.shape - image_embeds = image_embeds.permute(0, 2, 1, 3, 4).contiguous() - image_embeds = image_embeds.view(batch, channels, -1).permute(0, 2, 1) - - rope_patch_t = num_frames // self.patch_size[0] - rope_patch_h = height // self.patch_size[1] - rope_patch_w = width // self.patch_size[2] - - image_embeds = image_embeds.view( - batch, - rope_patch_t, self.patch_size[0], - rope_patch_h, self.patch_size[1], - rope_patch_w, self.patch_size[2], - channels - ) - image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() - image_embeds = image_embeds.view(batch, rope_patch_t * rope_patch_h * rope_patch_w, -1) - image_embeds = self.proj(image_embeds) - # Concatenate text and image embeddings - embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous() - - # Add positional embeddings if applicable - if self.use_positional_embeddings or self.use_learned_positional_embeddings: - if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height): - raise ValueError( - "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'." - "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues." - ) - - pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 - - if ( - self.sample_height != height - or self.sample_width != width - or self.sample_frames != pre_time_compression_frames - ): - pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames) - pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype) - else: - pos_embedding = self.pos_embedding - - embeds = embeds + pos_embedding - - return embeds - class CogVideoXPatchEmbed(nn.Module): def __init__( self, patch_size: int = 2, + patch_size_t: Optional[int] = None, in_channels: int = 16, embed_dim: int = 1920, text_embed_dim: int = 4096, @@ -472,6 +357,7 @@ def __init__( super().__init__() self.patch_size = patch_size + self.patch_size_t = patch_size_t self.embed_dim = embed_dim self.sample_height = sample_height self.sample_width = sample_width @@ -483,9 +369,15 @@ def __init__( self.use_positional_embeddings = use_positional_embeddings self.use_learned_positional_embeddings = use_learned_positional_embeddings - self.proj = nn.Conv2d( - in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias - ) + if patch_size_t is None: + # CogVideoX 1.0 checkpoints + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + else: + # CogVideoX 1.5 checkpoints + self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim) + self.text_proj = nn.Linear(text_embed_dim, embed_dim) if use_positional_embeddings or use_learned_positional_embeddings: @@ -524,12 +416,22 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): """ text_embeds = self.text_proj(text_embeds) - batch, num_frames, channels, height, width = image_embeds.shape - image_embeds = image_embeds.reshape(-1, channels, height, width) - image_embeds = self.proj(image_embeds) - image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:]) - image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] - image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] + batch_size, num_frames, channels, height, width = image_embeds.shape + + if self.patch_size_t is None: + image_embeds = image_embeds.reshape(-1, channels, height, width) + image_embeds = self.proj(image_embeds) + image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:]) + image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] + image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] + else: + p = self.patch_size + p_t = self.patch_size_t + + image_embeds = image_embeds.permute(0, 1, 3, 4, 2) + image_embeds = image_embeds.reshape(batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels) + image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3) + image_embeds = self.proj(image_embeds) embeds = torch.cat( [text_embeds, image_embeds], dim=1 diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 29ee28835832..85cf7c04dbe0 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -24,7 +24,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 -from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, CogVideoX1_1PatchEmbed +from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero @@ -227,6 +227,7 @@ def __init__( sample_height: int = 60, sample_frames: int = 49, patch_size: int = 2, + patch_size_t: int = 2, temporal_compression_ratio: int = 4, max_text_seq_length: int = 226, activation_fn: str = "gelu-approximate", @@ -237,6 +238,7 @@ def __init__( temporal_interpolation_scale: float = 1.0, use_rotary_positional_embeddings: bool = False, use_learned_positional_embeddings: bool = False, + patch_bias: bool = True, ): super().__init__() inner_dim = num_attention_heads * attention_head_dim @@ -249,15 +251,13 @@ def __init__( ) # 1. Patch embedding - #TODO: different git push --set-upstream origin cogvideox1.1-5b - - # self.patch_embed = CogVideoXPatchEmbed( - self.patch_embed = CogVideoX1_1PatchEmbed( + self.patch_embed = CogVideoXPatchEmbed( patch_size=patch_size, + patch_size_t=patch_size_t, in_channels=in_channels, embed_dim=inner_dim, text_embed_dim=text_embed_dim, - # bias=True, # Only using in CogVideoX-5B + bias=patch_bias, sample_width=sample_width, sample_height=sample_height, sample_frames=sample_frames, @@ -301,7 +301,15 @@ def __init__( norm_eps=norm_eps, chunk_dim=1, ) - self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * patch_size * out_channels) # For CogVideoX1.1-5B + + if patch_size_t is None: + # For CogVideox 1.0 + output_dim = patch_size * patch_size * out_channels + else: + # For CogVideoX 1.5 + output_dim = patch_size * patch_size * patch_size_t * out_channels + + self.proj_out = nn.Linear(inner_dim, output_dim) self.gradient_checkpointing = False @@ -446,6 +454,16 @@ def forward( emb = self.time_embedding(t_emb, timestep_cond) # 2. Patch embedding + p = self.config.patch_size + p_t = self.config.patch_size_t + + # We know that the hidden states height and width will always be divisible by patch_size. + # But, the number of frames may not be divisible by patch_size_t. So, we pad with the beginning frames. + if p_t is not None: + remaining_frames = p_t - num_frames % p_t + first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1) + hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1) + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) hidden_states = self.embedding_dropout(hidden_states) @@ -494,12 +512,13 @@ def custom_forward(*inputs): hidden_states = self.proj_out(hidden_states) # 5. Unpatchify - # Note: we use `-1` instead of `channels`: - # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) - # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) - p = self.config.patch_size - output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) - output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + if p_t is None: + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + else: + output = hidden_states.reshape(batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p) + output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) + output = output[:, remaining_frames:] if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer @@ -507,4 +526,4 @@ def custom_forward(*inputs): if not return_dict: return (output,) - return Transformer2DModelOutput(sample=output) \ No newline at end of file + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 6708332ca4e0..50486b19b54f 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -443,11 +443,13 @@ def _prepare_rotary_positional_embeddings( grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - # TODO: Here, compatibility is needed for both the CogVideoX-5B and CogVideoX1.1-5B models. - # CogVideoX1.0 is 720 X 480 and CogVideoX1.1-5B T2V is 1360 * 768, CogVideoX1.1-5B I2V use with image - base_size_width = 1360 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_height = 768 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t or 1 + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + base_num_frames = (num_frames + p_t - 1) // p_t + grid_crops_coords = get_resize_crop_region_for_grid( (grid_height, grid_width), base_size_width, base_size_height ) @@ -455,7 +457,7 @@ def _prepare_rotary_positional_embeddings( embed_dim=self.transformer.config.attention_head_dim, crops_coords=grid_crops_coords, grid_size=(grid_height, grid_width), - temporal_size=num_frames, + temporal_size=base_num_frames, ) freqs_cos = freqs_cos.to(device=device) @@ -484,9 +486,9 @@ def __call__( self, prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 480, - width: int = 720, - num_frames: int = 49, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, guidance_scale: float = 6, @@ -589,6 +591,10 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = num_frames or self.transformer.config.sample_frames + num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct From 67cb3735845f5ccff2539e6d96684887c9dde9d5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 8 Nov 2024 23:09:18 +0100 Subject: [PATCH 04/36] make style --- scripts/convert_cogvideox_to_diffusers.py | 22 ++++++++++++++----- src/diffusers/models/embeddings.py | 7 +++--- .../transformers/cogvideox_transformer_3d.py | 6 +++-- .../pipelines/cogvideo/pipeline_cogvideox.py | 2 +- 4 files changed, 25 insertions(+), 12 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index ff8ff556832e..0aad301751f1 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -140,7 +140,7 @@ def convert_transformer( use_rotary_positional_embeddings: bool, i2v: bool, dtype: torch.dtype, - init_kwargs: Dict[str, Any] + init_kwargs: Dict[str, Any], ): PREFIX_KEY = "model.diffusion_model." @@ -165,7 +165,7 @@ def convert_transformer( if special_key not in key: continue handler_fn_inplace(key, original_state_dict) - + transformer.load_state_dict(original_state_dict, strict=True) return transformer @@ -201,7 +201,7 @@ def get_init_kwargs(version: str): "sample_width": 720 // vae_scale_factor_spatial, "sample_frames": 49, } - + elif version == "1.5": vae_scale_factor_spatial = 8 init_kwargs = { @@ -214,7 +214,7 @@ def get_init_kwargs(version: str): } else: raise ValueError("Unsupported version of CogVideoX.") - + return init_kwargs @@ -245,8 +245,18 @@ def get_args(): parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE") # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0 parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE") - parser.add_argument("--i2v", action="store_true", default=False, help="Whether the model to be converted is the Image-to-Video version of CogVideoX.") - parser.add_argument("--version", choices=["1.0", "1.5"], default="1.0", help="Which version of CogVideoX to use for initializing default modeling parameters.") + parser.add_argument( + "--i2v", + action="store_true", + default=False, + help="Whether the model to be converted is the Image-to-Video version of CogVideoX.", + ) + parser.add_argument( + "--version", + choices=["1.0", "1.5"], + default="1.0", + help="Which version of CogVideoX to use for initializing default modeling parameters.", + ) return parser.parse_args() diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index f349f03b2a60..b3212e43a61f 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -17,7 +17,6 @@ import numpy as np import torch import torch.nn.functional as F -from einops import rearrange from torch import nn from ..utils import deprecate @@ -377,7 +376,7 @@ def __init__( else: # CogVideoX 1.5 checkpoints self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim) - + self.text_proj = nn.Linear(text_embed_dim, embed_dim) if use_positional_embeddings or use_learned_positional_embeddings: @@ -429,7 +428,9 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): p_t = self.patch_size_t image_embeds = image_embeds.permute(0, 1, 3, 4, 2) - image_embeds = image_embeds.reshape(batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels) + image_embeds = image_embeds.reshape( + batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels + ) image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3) image_embeds = self.proj(image_embeds) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 85cf7c04dbe0..84bc8ef5145c 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -308,7 +308,7 @@ def __init__( else: # For CogVideoX 1.5 output_dim = patch_size * patch_size * patch_size_t * out_channels - + self.proj_out = nn.Linear(inner_dim, output_dim) self.gradient_checkpointing = False @@ -516,7 +516,9 @@ def custom_forward(*inputs): output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) else: - output = hidden_states.reshape(batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p) + output = hidden_states.reshape( + batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p + ) output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) output = output[:, remaining_frames:] diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 50486b19b54f..17831c13847b 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -449,7 +449,7 @@ def _prepare_rotary_positional_embeddings( base_size_width = self.transformer.config.sample_width // p base_size_height = self.transformer.config.sample_height // p base_num_frames = (num_frames + p_t - 1) // p_t - + grid_crops_coords = get_resize_crop_region_for_grid( (grid_height, grid_width), base_size_width, base_size_height ) From e48184389f50caf7292a890823c7a03bb14b80cf Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 8 Nov 2024 23:40:39 +0100 Subject: [PATCH 05/36] update docs --- docs/source/en/api/pipelines/cogvideox.md | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index f0f4fd37e6d5..7afec984fdd6 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -29,16 +29,18 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM). -There are two models available that can be used with the text-to-video and video-to-video CogVideoX pipelines: -- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b): The recommended dtype for running this model is `fp16`. -- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b): The recommended dtype for running this model is `bf16`. +There are three official models available that can be used with the text-to-video and video-to-video CogVideoX pipelines: +- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b): The recommended dtype for running this model is `torch.float16`. +- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b): The recommended dtype for running this model is `torch.bfloat16`. +- [`THUDM/CogVideoX-1.5-5b`](https://huggingface.co/THUDM/CogVideoX-1.5-5b): The recommended dtype for running this mdoel is `torch.bfloat16`. There is one model available that can be used with the image-to-video CogVideoX pipeline: -- [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`. +- [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `torch.bfloat16`. +- [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V): The recommended dtype for running this mdoel is `torch.bfloat16`. There are two models that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team): -- [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `bf16`. -- [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `bf16`. +- [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `torch.bfloat16`. +- [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `torch.bfloat16`. ## Inference From 9edddc1da8232a344080ceab608d5238890e8b92 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 8 Nov 2024 23:43:54 +0100 Subject: [PATCH 06/36] add modeling tests for cogvideox 1.5 --- .../test_models_transformer_cogvideox.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/models/transformers/test_models_transformer_cogvideox.py b/tests/models/transformers/test_models_transformer_cogvideox.py index 1342577f0114..e0350ef8dd99 100644 --- a/tests/models/transformers/test_models_transformer_cogvideox.py +++ b/tests/models/transformers/test_models_transformer_cogvideox.py @@ -76,6 +76,7 @@ def prepare_init_args_and_inputs_for_common(self): "sample_height": 8, "sample_frames": 8, "patch_size": 2, + "patch_size_t": None, "temporal_compression_ratio": 4, "max_text_seq_length": 8, } @@ -85,3 +86,63 @@ def prepare_init_args_and_inputs_for_common(self): def test_gradient_checkpointing_is_applied(self): expected_set = {"CogVideoXTransformer3DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class CogVideoX1_5TransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = CogVideoXTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + num_frames = 1 + height = 8 + width = 8 + embedding_dim = 8 + sequence_length = 8 + + hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + } + + @property + def input_shape(self): + return (1, 4, 8, 8) + + @property + def output_shape(self): + return (1, 4, 8, 8) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings. + "num_attention_heads": 2, + "attention_head_dim": 8, + "in_channels": 4, + "out_channels": 4, + "time_embed_dim": 2, + "text_embed_dim": 8, + "num_layers": 1, + "sample_width": 8, + "sample_height": 8, + "sample_frames": 8, + "patch_size": 2, + "patch_size_t": 2, + "temporal_compression_ratio": 4, + "max_text_seq_length": 8, + "use_rotary_positional_embeddings": True, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"CogVideoXTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) From ea56788ea43c81b129b543172dd30fa7983b4685 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 8 Nov 2024 23:44:18 +0100 Subject: [PATCH 07/36] update --- scripts/convert_cogvideox_to_diffusers.py | 17 +++++++++++------ .../pipelines/cogvideo/pipeline_cogvideox.py | 1 + 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index 0aad301751f1..a4d4f2481834 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -233,6 +233,12 @@ def get_args(): parser.add_argument( "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory" ) + parser.add_argument( + "--typecast_text_encoder", + action="store_true", + default=False, + help="Whether or not to apply fp16/bf16 precision to text_encoder", + ) # For CogVideoX-2B, num_layers is 30. For 5B, it is 42 parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks") # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48 @@ -283,12 +289,16 @@ def get_args(): init_kwargs, ) if args.vae_ckpt_path is not None: - vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype) + # Keep VAE in float32 for better quality + vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, torch.float32) text_encoder_id = "google/t5-v1_1-xxl" tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) + if args.typecast_text_encoder: + text_encoder = text_encoder.to(dtype=dtype) + # Apparently, the conversion does not work anymore without this :shrug: for param in text_encoder.parameters(): param.data = param.data.contiguous() @@ -320,11 +330,6 @@ def get_args(): scheduler=scheduler, ) - if args.fp16: - pipe = pipe.to(dtype=torch.float16) - if args.bf16: - pipe = pipe.to(dtype=torch.bfloat16) - # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird # for users to specify variant when the default is not fp32 and they want to run with the correct default (which # is either fp16/bf16 here). diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 17831c13847b..241df8233075 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -683,6 +683,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) + # predict noise model_output noise_pred = self.transformer( hidden_states=latent_model_input, From d833f72f80de328b1c830e53da392a69b7f2d423 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 8 Nov 2024 23:45:32 +0100 Subject: [PATCH 08/36] make fix-copies --- .../cogvideo/pipeline_cogvideox_fun_control.py | 11 ++++++++--- .../cogvideo/pipeline_cogvideox_image2video.py | 11 ++++++++--- .../cogvideo/pipeline_cogvideox_video2video.py | 11 ++++++++--- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 3655075bd519..814afb540f5a 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -488,8 +488,13 @@ def _prepare_rotary_positional_embeddings( ) -> Tuple[torch.Tensor, torch.Tensor]: grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t or 1 + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + base_num_frames = (num_frames + p_t - 1) // p_t grid_crops_coords = get_resize_crop_region_for_grid( (grid_height, grid_width), base_size_width, base_size_height @@ -498,7 +503,7 @@ def _prepare_rotary_positional_embeddings( embed_dim=self.transformer.config.attention_head_dim, crops_coords=grid_crops_coords, grid_size=(grid_height, grid_width), - temporal_size=num_frames, + temporal_size=base_num_frames, ) freqs_cos = freqs_cos.to(device=device) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 783dae569bec..8cbd343a60cd 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -522,8 +522,13 @@ def _prepare_rotary_positional_embeddings( ) -> Tuple[torch.Tensor, torch.Tensor]: grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t or 1 + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + base_num_frames = (num_frames + p_t - 1) // p_t grid_crops_coords = get_resize_crop_region_for_grid( (grid_height, grid_width), base_size_width, base_size_height @@ -532,7 +537,7 @@ def _prepare_rotary_positional_embeddings( embed_dim=self.transformer.config.attention_head_dim, crops_coords=grid_crops_coords, grid_size=(grid_height, grid_width), - temporal_size=num_frames, + temporal_size=base_num_frames, ) freqs_cos = freqs_cos.to(device=device) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index e1e816eca16d..118b1064b21e 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -518,8 +518,13 @@ def _prepare_rotary_positional_embeddings( ) -> Tuple[torch.Tensor, torch.Tensor]: grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t or 1 + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + base_num_frames = (num_frames + p_t - 1) // p_t grid_crops_coords = get_resize_crop_region_for_grid( (grid_height, grid_width), base_size_width, base_size_height @@ -528,7 +533,7 @@ def _prepare_rotary_positional_embeddings( embed_dim=self.transformer.config.attention_head_dim, crops_coords=grid_crops_coords, grid_size=(grid_height, grid_width), - temporal_size=num_frames, + temporal_size=base_num_frames, ) freqs_cos = freqs_cos.to(device=device) From b87b07e6f4467018d1c941f8b4e13c5516f1b0da Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Sat, 9 Nov 2024 15:14:40 +0800 Subject: [PATCH 09/36] add ofs embed(for convert) --- scripts/convert_cogvideox_to_diffusers.py | 7 +++++-- .../models/transformers/cogvideox_transformer_3d.py | 8 +++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index a4d4f2481834..bd7d3064ce92 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -80,6 +80,8 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]): "post_attn1_layernorm": "norm2.norm", "time_embed.0": "time_embedding.linear_1", "time_embed.2": "time_embedding.linear_2", + "ofs_embed.0": "ofs_embedding.linear_1", + "ofs_embed.2": "ofs_embedding.linear_2", "mixins.patch_embed": "patch_embed", "mixins.final_layer.norm_final": "norm_out.norm", "mixins.final_layer.linear": "proj_out", @@ -150,7 +152,8 @@ def convert_transformer( num_layers=num_layers, num_attention_heads=num_attention_heads, use_rotary_positional_embeddings=use_rotary_positional_embeddings, - use_learned_positional_embeddings=i2v, + ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None, # CogVideoX1.5-5B-I2V + use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None, # CogVideoX-5B-I2V **init_kwargs, ).to(dtype=dtype) @@ -210,7 +213,7 @@ def get_init_kwargs(version: str): "patch_bias": False, "sample_height": 768 // vae_scale_factor_spatial, "sample_width": 1360 // vae_scale_factor_spatial, - "sample_frames": 81, + "sample_frames": 81, # TODO: Need Test with 161 for 10 seconds } else: raise ValueError("Unsupported version of CogVideoX.") diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 5a08c9a75cfb..92858f848561 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -219,6 +219,7 @@ def __init__( flip_sin_to_cos: bool = True, freq_shift: int = 0, time_embed_dim: int = 512, + ofs_embed_dim: Optional[int] = 512, text_embed_dim: int = 4096, num_layers: int = 30, dropout: float = 0.0, @@ -270,10 +271,15 @@ def __init__( ) self.embedding_dropout = nn.Dropout(dropout) - # 2. Time embeddings + # 2. Time embeddings and ofs embedding(Only CogVideoX1.5-5B I2V have) + self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) + if ofs_embed_dim: + self.ofs_embedding = TimestepEmbedding(ofs_embed_dim, ofs_embed_dim, timestep_activation_fn) # same as time embeddings, for ofs + self.ofs_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) + # 3. Define spatio-temporal transformers blocks self.transformer_blocks = nn.ModuleList( [ From e254bcb4063b63a4eb8bef5415dcfe0f056c4496 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Sat, 9 Nov 2024 15:34:13 +0800 Subject: [PATCH 10/36] add ofs embed(for convert) --- .../transformers/cogvideox_transformer_3d.py | 14 ++++++++++---- .../cogvideo/pipeline_cogvideox_image2video.py | 5 ----- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 92858f848561..55b65a097e63 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -170,6 +170,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): Whether to flip the sin to cos in the time embedding. time_embed_dim (`int`, defaults to `512`): Output dimension of timestep embeddings. + ofs_embed_dim (`int`, defaults to `512`): + scaling factor in the VAE process for the Image-to-Video (I2V) transformation in CogVideoX1.5-5B. text_embed_dim (`int`, defaults to `4096`): Input dimension of text embeddings from the text encoder. num_layers (`int`, defaults to `30`): @@ -177,7 +179,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): dropout (`float`, defaults to `0.0`): The dropout probability to use. attention_bias (`bool`, defaults to `True`): - Whether or not to use bias in the attention projection layers. + Whether to use bias in the attention projection layers. sample_width (`int`, defaults to `90`): The width of the input latents. sample_height (`int`, defaults to `60`): @@ -198,7 +200,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): timestep_activation_fn (`str`, defaults to `"silu"`): Activation function to use when generating the timestep embeddings. norm_elementwise_affine (`bool`, defaults to `True`): - Whether or not to use elementwise affine in normalization layers. + Whether to use elementwise affine in normalization layers. norm_eps (`float`, defaults to `1e-5`): The epsilon value to use in normalization layers. spatial_interpolation_scale (`float`, defaults to `1.875`): @@ -219,7 +221,7 @@ def __init__( flip_sin_to_cos: bool = True, freq_shift: int = 0, time_embed_dim: int = 512, - ofs_embed_dim: Optional[int] = 512, + ofs_embed_dim: Optional[int] = None, text_embed_dim: int = 4096, num_layers: int = 30, dropout: float = 0.0, @@ -276,9 +278,10 @@ def __init__( self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) + self.ofs_embedding = None + if ofs_embed_dim: self.ofs_embedding = TimestepEmbedding(ofs_embed_dim, ofs_embed_dim, timestep_activation_fn) # same as time embeddings, for ofs - self.ofs_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) # 3. Define spatio-temporal transformers blocks self.transformer_blocks = nn.ModuleList( @@ -458,6 +461,9 @@ def forward( # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=hidden_states.dtype) emb = self.time_embedding(t_emb, timestep_cond) + if self.ofs_embedding is not None: + emb_ofs = self.ofs_embedding(emb, timestep_cond) + emb = emb + emb_ofs # 2. Patch embedding p = self.config.patch_size diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 8cbd343a60cd..71e0a7798717 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -671,11 +671,6 @@ def __call__( `tuple`. When returning a tuple, the first element is a list with the generated images. """ - if num_frames > 49: - raise ValueError( - "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." - ) - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs From be80dbf1187496940acfb17526eea5ada3306c6e Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Sun, 10 Nov 2024 17:50:54 +0800 Subject: [PATCH 11/36] more resolution for cogvideox1.5-5b-i2v --- .../pipelines/cogvideo/pipeline_cogvideox.py | 1 + .../pipeline_cogvideox_image2video.py | 22 +++++++++++++++---- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 241df8233075..4ebf81e278f4 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -734,6 +734,7 @@ def __call__( progress_bar.update() if not output_type == "latent": + breakpoint() video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 71e0a7798717..1d5270925708 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -567,8 +567,8 @@ def __call__( image: PipelineImageInput, prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 480, - width: int = 720, + height: int = 768, + width: int = 1360, num_frames: int = 49, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, @@ -675,7 +675,6 @@ def __call__( callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs num_videos_per_prompt = 1 - # 1. Check inputs. Raise error if not correct self.check_inputs( image=image, @@ -726,6 +725,22 @@ def __call__( self._num_timesteps = len(timesteps) # 5. Prepare latents + # TODO: Only CogVideoX1.5-5B-I2V can use this method. Need to Change + def adjust_resolution_to_divisible(image_height, image_width, tgt_height, tgt_width, divisor=16): + # Step 1: Compare image dimensions with target dimensions + if image_height > tgt_height: + image_height = tgt_height + if image_width > tgt_width: + image_width = tgt_width + + # Step 2: Ensure height and width are divisible by the divisor + image_height = (image_height // divisor) * divisor + image_width = (image_width // divisor) * divisor + return image_height, image_width + + image_width, image_height = image.size[-2:] + + height, width = adjust_resolution_to_divisible(image_height, image_width, height, width) image = self.video_processor.preprocess(image, height=height, width=width).to( device, dtype=prompt_embeds.dtype ) @@ -746,7 +761,6 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7. Create rotary embeds if required image_rotary_emb = ( self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) From b94c7047ffc0d7f9cc58ee075d82151940f7c5a0 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 10 Nov 2024 21:37:24 +0100 Subject: [PATCH 12/36] use even number of latent frames only --- scripts/convert_cogvideox_to_diffusers.py | 2 +- .../transformers/cogvideox_transformer_3d.py | 14 +++----------- .../pipelines/cogvideo/pipeline_cogvideox.py | 7 ++++++- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index bd7d3064ce92..ca2cba598ebf 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -213,7 +213,7 @@ def get_init_kwargs(version: str): "patch_bias": False, "sample_height": 768 // vae_scale_factor_spatial, "sample_width": 1360 // vae_scale_factor_spatial, - "sample_frames": 81, # TODO: Need Test with 161 for 10 seconds + "sample_frames": 85, } else: raise ValueError("Unsupported version of CogVideoX.") diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 55b65a097e63..12a26c202e36 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -466,16 +466,6 @@ def forward( emb = emb + emb_ofs # 2. Patch embedding - p = self.config.patch_size - p_t = self.config.patch_size_t - - # We know that the hidden states height and width will always be divisible by patch_size. - # But, the number of frames may not be divisible by patch_size_t. So, we pad with the beginning frames. - if p_t is not None: - remaining_frames = p_t - num_frames % p_t - first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1) - hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1) - hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) hidden_states = self.embedding_dropout(hidden_states) @@ -524,6 +514,9 @@ def custom_forward(*inputs): hidden_states = self.proj_out(hidden_states) # 5. Unpatchify + p = self.config.patch_size + p_t = self.config.patch_size_t + if p_t is None: output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) @@ -532,7 +525,6 @@ def custom_forward(*inputs): batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p ) output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) - output = output[:, remaining_frames:] if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 4ebf81e278f4..8f4e71a05494 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -368,12 +368,12 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs def check_inputs( self, prompt, height, width, + num_frames, negative_prompt, callback_on_step_end_tensor_inputs, prompt_embeds=None, @@ -382,6 +382,10 @@ def check_inputs( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + if self.transformer.config.patch_size_t is not None and latent_frames % self.transformer.config.patch_size_t != 0: + raise ValueError(f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}.") + if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -602,6 +606,7 @@ def __call__( prompt, height, width, + num_frames, negative_prompt, callback_on_step_end_tensor_inputs, prompt_embeds, From 048a5f02d96f6529e255ed2cdc2310af7307ecaa Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 10 Nov 2024 21:50:00 +0100 Subject: [PATCH 13/36] update pipeline implementations --- .../pipeline_cogvideox_fun_control.py | 21 +++++++++++++------ .../pipeline_cogvideox_image2video.py | 15 +++++++++++-- .../pipeline_cogvideox_video2video.py | 14 +++++++++++-- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 814afb540f5a..5ddbbd6c35ed 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -412,6 +412,7 @@ def check_inputs( prompt, height, width, + num_frames, negative_prompt, callback_on_step_end_tensor_inputs, prompt_embeds=None, @@ -421,6 +422,10 @@ def check_inputs( ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + if self.transformer.config.patch_size_t is not None and latent_frames % self.transformer.config.patch_size_t != 0: + raise ValueError(f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}.") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -533,8 +538,8 @@ def __call__( prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, control_video: Optional[List[Image.Image]] = None, - height: int = 480, - width: int = 720, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, guidance_scale: float = 6, @@ -638,7 +643,14 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + if control_video is not None and isinstance(control_video[0], Image.Image): + control_video = [control_video] + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2) + num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct @@ -646,6 +658,7 @@ def __call__( prompt, height, width, + num_frames, negative_prompt, callback_on_step_end_tensor_inputs, prompt_embeds, @@ -665,9 +678,6 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - if control_video is not None and isinstance(control_video[0], Image.Image): - control_video = [control_video] - device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) @@ -695,7 +705,6 @@ def __call__( # 5. Prepare latents. latent_channels = self.transformer.config.in_channels // 2 - num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2) latents = self.prepare_latents( batch_size * num_videos_per_prompt, latent_channels, diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 1d5270925708..d8b48a976b31 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -440,6 +440,7 @@ def check_inputs( prompt, height, width, + num_frames, negative_prompt, callback_on_step_end_tensor_inputs, latents=None, @@ -459,6 +460,10 @@ def check_inputs( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + if self.transformer.config.patch_size_t is not None and latent_frames % self.transformer.config.patch_size_t != 0: + raise ValueError(f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}.") + if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -567,8 +572,8 @@ def __call__( image: PipelineImageInput, prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 768, - width: int = 1360, + height: Optional[int] = None, + width: Optional[int] = None, num_frames: int = 49, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, @@ -674,12 +679,18 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = num_frames or self.transformer.config.sample_frames + num_videos_per_prompt = 1 + # 1. Check inputs. Raise error if not correct self.check_inputs( image=image, prompt=prompt, height=height, + num_frames=num_frames, width=width, negative_prompt=negative_prompt, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index 118b1064b21e..4b42fa6e1e10 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -438,6 +438,7 @@ def check_inputs( prompt, height, width, + num_frames, strength, negative_prompt, callback_on_step_end_tensor_inputs, @@ -448,6 +449,10 @@ def check_inputs( ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + if self.transformer.config.patch_size_t is not None and latent_frames % self.transformer.config.patch_size_t != 0: + raise ValueError(f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}.") if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") @@ -563,8 +568,8 @@ def __call__( video: List[Image.Image] = None, prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 480, - width: int = 720, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, strength: float = 0.8, @@ -667,6 +672,10 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = len(video) if latents is None else latents.size(1) + num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct @@ -674,6 +683,7 @@ def __call__( prompt=prompt, height=height, width=width, + num_frames=num_frames, strength=strength, negative_prompt=negative_prompt, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, From 0c98aad98d14fe95a314aa173e6bde8d959ce630 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 10 Nov 2024 21:50:29 +0100 Subject: [PATCH 14/36] make style --- .../transformers/cogvideox_transformer_3d.py | 6 ++++-- .../pipelines/cogvideo/pipeline_cogvideox.py | 9 +++++++-- .../cogvideo/pipeline_cogvideox_fun_control.py | 15 ++++++++++----- .../cogvideo/pipeline_cogvideox_image2video.py | 11 ++++++++--- .../cogvideo/pipeline_cogvideox_video2video.py | 13 +++++++++---- 5 files changed, 38 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 12a26c202e36..d43a5a562069 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -281,7 +281,9 @@ def __init__( self.ofs_embedding = None if ofs_embed_dim: - self.ofs_embedding = TimestepEmbedding(ofs_embed_dim, ofs_embed_dim, timestep_activation_fn) # same as time embeddings, for ofs + self.ofs_embedding = TimestepEmbedding( + ofs_embed_dim, ofs_embed_dim, timestep_activation_fn + ) # same as time embeddings, for ofs # 3. Define spatio-temporal transformers blocks self.transformer_blocks = nn.ModuleList( @@ -516,7 +518,7 @@ def custom_forward(*inputs): # 5. Unpatchify p = self.config.patch_size p_t = self.config.patch_size_t - + if p_t is None: output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 8f4e71a05494..5e0ad64f82a6 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -383,8 +383,13 @@ def check_inputs( raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - if self.transformer.config.patch_size_t is not None and latent_frames % self.transformer.config.patch_size_t != 0: - raise ValueError(f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}.") + if ( + self.transformer.config.patch_size_t is not None + and latent_frames % self.transformer.config.patch_size_t != 0 + ): + raise ValueError( + f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 5ddbbd6c35ed..58fc94ab9fd5 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -422,10 +422,15 @@ def check_inputs( ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - if self.transformer.config.patch_size_t is not None and latent_frames % self.transformer.config.patch_size_t != 0: - raise ValueError(f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}.") + if ( + self.transformer.config.patch_size_t is not None + and latent_frames % self.transformer.config.patch_size_t != 0 + ): + raise ValueError( + f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -643,14 +648,14 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - + if control_video is not None and isinstance(control_video[0], Image.Image): control_video = [control_video] height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2) - + num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index d8b48a976b31..677b41bd775f 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -461,8 +461,13 @@ def check_inputs( raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - if self.transformer.config.patch_size_t is not None and latent_frames % self.transformer.config.patch_size_t != 0: - raise ValueError(f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}.") + if ( + self.transformer.config.patch_size_t is not None + and latent_frames % self.transformer.config.patch_size_t != 0 + ): + raise ValueError( + f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -682,7 +687,7 @@ def __call__( height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial num_frames = num_frames or self.transformer.config.sample_frames - + num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index 4b42fa6e1e10..674c6e5d55e6 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -449,10 +449,15 @@ def check_inputs( ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - if self.transformer.config.patch_size_t is not None and latent_frames % self.transformer.config.patch_size_t != 0: - raise ValueError(f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}.") + if ( + self.transformer.config.patch_size_t is not None + and latent_frames % self.transformer.config.patch_size_t != 0 + ): + raise ValueError( + f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}." + ) if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") @@ -675,7 +680,7 @@ def __call__( height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial num_frames = len(video) if latents is None else latents.size(1) - + num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct From 7a1b579d9368d579de3a878febbe8c4ebad48fb7 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Mon, 11 Nov 2024 19:24:09 +0800 Subject: [PATCH 15/36] set patch_size_t as None by default --- .../transformers/cogvideox_transformer_3d.py | 2 +- .../pipelines/cogvideo/pipeline_cogvideox.py | 15 ++------------- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index d43a5a562069..1fe80c40a865 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -230,7 +230,7 @@ def __init__( sample_height: int = 60, sample_frames: int = 49, patch_size: int = 2, - patch_size_t: int = 2, + patch_size_t: Optional[int] = None, temporal_compression_ratio: int = 4, max_text_seq_length: int = 226, activation_fn: str = "gelu-approximate", diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 5e0ad64f82a6..8afa0a80779e 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -368,12 +368,12 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs def check_inputs( self, prompt, height, width, - num_frames, negative_prompt, callback_on_step_end_tensor_inputs, prompt_embeds=None, @@ -382,15 +382,6 @@ def check_inputs( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - if ( - self.transformer.config.patch_size_t is not None - and latent_frames % self.transformer.config.patch_size_t != 0 - ): - raise ValueError( - f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}." - ) - if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -611,7 +602,6 @@ def __call__( prompt, height, width, - num_frames, negative_prompt, callback_on_step_end_tensor_inputs, prompt_embeds, @@ -744,8 +734,7 @@ def __call__( progress_bar.update() if not output_type == "latent": - breakpoint() - video = self.decode_latents(latents) + video = self.decode_latents(latents[:,1:]) video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: video = latents From 27441fc2daf79c7838b43ff280e4e61c7639a40d Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Mon, 11 Nov 2024 21:20:53 +0800 Subject: [PATCH 16/36] #skip frames 0 --- .../pipelines/cogvideo/pipeline_cogvideox.py | 12 ++++++- .../pipeline_cogvideox_image2video.py | 34 ++++++++++++++----- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 8afa0a80779e..e8a252e751a7 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -334,6 +334,10 @@ def prepare_latents( width // self.vae_scale_factor_spatial, ) + # For CogVideoX1.5, the latent should add 1 for padding (Not use) + if self.transformer.config.patch_size_t is not None: + shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:] + if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: @@ -734,7 +738,13 @@ def __call__( progress_bar.update() if not output_type == "latent": - video = self.decode_latents(latents[:,1:]) + # Calculate the number of start frames based on the size of the second dimension of latents + num_latent_frames = latents.size(1) # Get the size of the second dimension + # (81 - 1) / 4 + 1 = 21 and latents is 22, so the first frames will be 22 - 1 = 1, and we will skip frames 0 + start_frames = num_latent_frames - ((num_frames - 1) // self.vae_scale_factor_temporal + 1) + + # Slice latents starting from start_frames + video = self.decode_latents(latents[:, start_frames:]) video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: video = latents diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 677b41bd775f..35830252c198 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -367,6 +367,10 @@ def prepare_latents( width // self.vae_scale_factor_spatial, ) + # For CogVideoX1.5, the latent should add 1 for padding (Not use) + if self.transformer.config.patch_size_t is not None: + shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:] + image = image.unsqueeze(2) # [B, C, F, H, W] if isinstance(generator, list): @@ -386,9 +390,15 @@ def prepare_latents( height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial, ) + latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype) image_latents = torch.cat([image_latents, latent_padding], dim=1) + # Select the first frame along the second dimension + if self.transformer.config.patch_size_t is not None: + first_frame = image_latents[:, : image_latents.size(1) % self.transformer.config.patch_size_t, ...] + image_latents = torch.cat([first_frame, image_latents], dim=1) + if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: @@ -460,14 +470,14 @@ def check_inputs( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - if ( - self.transformer.config.patch_size_t is not None - and latent_frames % self.transformer.config.patch_size_t != 0 - ): - raise ValueError( - f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}." - ) + # latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + # if ( + # self.transformer.config.patch_size_t is not None + # and latent_frames % self.transformer.config.patch_size_t != 0 + # ): + # raise ValueError( + # f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}." + # ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -853,7 +863,13 @@ def adjust_resolution_to_divisible(image_height, image_width, tgt_height, tgt_wi progress_bar.update() if not output_type == "latent": - video = self.decode_latents(latents) + # Calculate the number of start frames based on the size of the second dimension of latents + num_latent_frames = latents.size(1) # Get the size of the second dimension + # (81 - 1) / 4 + 1 = 21 and latents is 22, so the first frames will be 22 - 1 = 1, and we will skip frames 0 + start_frames = num_latent_frames - ((num_frames - 1) // self.vae_scale_factor_temporal + 1) + + # Slice latents starting from start_frames + video = self.decode_latents(latents[:, start_frames:]) video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: video = latents From 7a15767657f0c42d0d922a0444687f2853a660c6 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 11 Nov 2024 15:38:51 +0100 Subject: [PATCH 17/36] refactor --- scripts/convert_cogvideox_to_diffusers.py | 2 +- .../pipelines/cogvideo/pipeline_cogvideox.py | 25 +++++----- .../pipeline_cogvideox_fun_control.py | 23 +++++----- .../pipeline_cogvideox_image2video.py | 46 +++++-------------- .../pipeline_cogvideox_video2video.py | 21 ++++----- 5 files changed, 47 insertions(+), 70 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index ca2cba598ebf..a29aa4e978d1 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -213,7 +213,7 @@ def get_init_kwargs(version: str): "patch_bias": False, "sample_height": 768 // vae_scale_factor_spatial, "sample_width": 1360 // vae_scale_factor_spatial, - "sample_frames": 85, + "sample_frames": 81, } else: raise ValueError("Unsupported version of CogVideoX.") diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index e8a252e751a7..313b753443bb 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -334,10 +334,6 @@ def prepare_latents( width // self.vae_scale_factor_spatial, ) - # For CogVideoX1.5, the latent should add 1 for padding (Not use) - if self.transformer.config.patch_size_t is not None: - shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:] - if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: @@ -648,7 +644,16 @@ def __call__( timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) self._num_timesteps = len(timesteps) - # 5. Prepare latents. + # 5. Prepare latents + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + additional_frames = 0 + if patch_size_t is not None and latent_frames % patch_size_t != 0: + additional_frames = patch_size_t - latent_frames % patch_size_t + num_frames += additional_frames * self.vae_scale_factor_temporal + latent_channels = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, @@ -738,13 +743,9 @@ def __call__( progress_bar.update() if not output_type == "latent": - # Calculate the number of start frames based on the size of the second dimension of latents - num_latent_frames = latents.size(1) # Get the size of the second dimension - # (81 - 1) / 4 + 1 = 21 and latents is 22, so the first frames will be 22 - 1 = 1, and we will skip frames 0 - start_frames = num_latent_frames - ((num_frames - 1) // self.vae_scale_factor_temporal + 1) - - # Slice latents starting from start_frames - video = self.decode_latents(latents[:, start_frames:]) + # Discard any padding frames that were added for CogVideoX 1.5 + latents = latents[:, additional_frames:] + video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: video = latents diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 58fc94ab9fd5..aeca4abc8d23 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -412,7 +412,6 @@ def check_inputs( prompt, height, width, - num_frames, negative_prompt, callback_on_step_end_tensor_inputs, prompt_embeds=None, @@ -423,15 +422,6 @@ def check_inputs( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - if ( - self.transformer.config.patch_size_t is not None - and latent_frames % self.transformer.config.patch_size_t != 0 - ): - raise ValueError( - f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}." - ) - if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -663,7 +653,6 @@ def __call__( prompt, height, width, - num_frames, negative_prompt, callback_on_step_end_tensor_inputs, prompt_embeds, @@ -708,7 +697,17 @@ def __call__( timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) self._num_timesteps = len(timesteps) - # 5. Prepare latents. + # 5. Prepare latents + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + if patch_size_t is not None and latent_frames % patch_size_t != 0: + raise ValueError( + f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video " + f"contains {latent_frames=}, which is not divisible." + ) + latent_channels = self.transformer.config.in_channels // 2 latents = self.prepare_latents( batch_size * num_videos_per_prompt, diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 35830252c198..fc217b9381a3 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -450,7 +450,6 @@ def check_inputs( prompt, height, width, - num_frames, negative_prompt, callback_on_step_end_tensor_inputs, latents=None, @@ -470,15 +469,6 @@ def check_inputs( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - # latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - # if ( - # self.transformer.config.patch_size_t is not None - # and latent_frames % self.transformer.config.patch_size_t != 0 - # ): - # raise ValueError( - # f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}." - # ) - if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -705,7 +695,6 @@ def __call__( image=image, prompt=prompt, height=height, - num_frames=num_frames, width=width, negative_prompt=negative_prompt, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, @@ -751,22 +740,15 @@ def __call__( self._num_timesteps = len(timesteps) # 5. Prepare latents - # TODO: Only CogVideoX1.5-5B-I2V can use this method. Need to Change - def adjust_resolution_to_divisible(image_height, image_width, tgt_height, tgt_width, divisor=16): - # Step 1: Compare image dimensions with target dimensions - if image_height > tgt_height: - image_height = tgt_height - if image_width > tgt_width: - image_width = tgt_width - - # Step 2: Ensure height and width are divisible by the divisor - image_height = (image_height // divisor) * divisor - image_width = (image_width // divisor) * divisor - return image_height, image_width - - image_width, image_height = image.size[-2:] - - height, width = adjust_resolution_to_divisible(image_height, image_width, height, width) + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + additional_frames = 0 + if patch_size_t is not None and latent_frames % patch_size_t != 0: + additional_frames = patch_size_t - latent_frames % patch_size_t + num_frames += additional_frames * self.vae_scale_factor_temporal + image = self.video_processor.preprocess(image, height=height, width=width).to( device, dtype=prompt_embeds.dtype ) @@ -863,13 +845,9 @@ def adjust_resolution_to_divisible(image_height, image_width, tgt_height, tgt_wi progress_bar.update() if not output_type == "latent": - # Calculate the number of start frames based on the size of the second dimension of latents - num_latent_frames = latents.size(1) # Get the size of the second dimension - # (81 - 1) / 4 + 1 = 21 and latents is 22, so the first frames will be 22 - 1 = 1, and we will skip frames 0 - start_frames = num_latent_frames - ((num_frames - 1) // self.vae_scale_factor_temporal + 1) - - # Slice latents starting from start_frames - video = self.decode_latents(latents[:, start_frames:]) + # Discard any padding frames that were added for CogVideoX 1.5 + latents = latents[:, additional_frames:] + video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: video = latents diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index 674c6e5d55e6..c9ba2ba96c9a 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -438,7 +438,6 @@ def check_inputs( prompt, height, width, - num_frames, strength, negative_prompt, callback_on_step_end_tensor_inputs, @@ -450,15 +449,6 @@ def check_inputs( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - if ( - self.transformer.config.patch_size_t is not None - and latent_frames % self.transformer.config.patch_size_t != 0 - ): - raise ValueError( - f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}." - ) - if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") @@ -688,7 +678,6 @@ def __call__( prompt=prompt, height=height, width=width, - num_frames=num_frames, strength=strength, negative_prompt=negative_prompt, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, @@ -737,6 +726,16 @@ def __call__( self._num_timesteps = len(timesteps) # 5. Prepare latents + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + if patch_size_t is not None and latent_frames % patch_size_t != 0: + raise ValueError( + f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video " + f"contains {latent_frames=}, which is not divisible." + ) + if latents is None: video = self.video_processor.preprocess_video(video, height=height, width=width) video = video.to(device=device, dtype=prompt_embeds.dtype) From e2a88cb43e809de94dc0166457bab93665630664 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 11 Nov 2024 15:39:31 +0100 Subject: [PATCH 18/36] make style --- .../pipelines/cogvideo/pipeline_cogvideox_fun_control.py | 2 +- .../pipelines/cogvideo/pipeline_cogvideox_image2video.py | 2 +- .../pipelines/cogvideo/pipeline_cogvideox_video2video.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index aeca4abc8d23..4838335dc856 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -707,7 +707,7 @@ def __call__( f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video " f"contains {latent_frames=}, which is not divisible." ) - + latent_channels = self.transformer.config.in_channels // 2 latents = self.prepare_latents( batch_size * num_videos_per_prompt, diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index fc217b9381a3..88f1ece5de5f 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -748,7 +748,7 @@ def __call__( if patch_size_t is not None and latent_frames % patch_size_t != 0: additional_frames = patch_size_t - latent_frames % patch_size_t num_frames += additional_frames * self.vae_scale_factor_temporal - + image = self.video_processor.preprocess(image, height=height, width=width).to( device, dtype=prompt_embeds.dtype ) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index c9ba2ba96c9a..6af0ab4e115b 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -735,7 +735,7 @@ def __call__( f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video " f"contains {latent_frames=}, which is not divisible." ) - + if latents is None: video = self.video_processor.preprocess_video(video, height=height, width=width) video = video.to(device=device, dtype=prompt_embeds.dtype) From 8966cb0cdedd0dba3c55d1d626bcf0a5b744869f Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 11 Nov 2024 15:45:30 +0100 Subject: [PATCH 19/36] update docs --- docs/source/en/api/pipelines/cogvideox.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 7afec984fdd6..01e5bd0aaa5f 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -38,6 +38,11 @@ There is one model available that can be used with the image-to-video CogVideoX - [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `torch.bfloat16`. - [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V): The recommended dtype for running this mdoel is `torch.bfloat16`. +For the CogVideoX 1.5 series of models, note that: +- Text-to-video works best at `1360 x 768` resolution because it is trained with that specific resolution +- Image-to-video works for multiple resolutions. Width can vary from `256` to `1360`, and height can vary from `256` to `768`. Note that the height/width must be divisible by `16`. +- Both T2V and I2V models support generation with `81` and `161` frames and work best at this value. It is recommended to export videos at 16 FPS. + There are two models that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team): - [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `torch.bfloat16`. - [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `torch.bfloat16`. From f2213e8ab104a96c8c9706ec8b3e1a6471dad493 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 11 Nov 2024 16:24:01 +0100 Subject: [PATCH 20/36] fix ofs_embed --- .../models/transformers/cogvideox_transformer_3d.py | 11 ++++++++--- .../cogvideo/pipeline_cogvideox_image2video.py | 5 +++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 1fe80c40a865..40e9e9fcbaab 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -278,9 +278,10 @@ def __init__( self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) + self.ofs_proj = None self.ofs_embedding = None - if ofs_embed_dim: + self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift) self.ofs_embedding = TimestepEmbedding( ofs_embed_dim, ofs_embed_dim, timestep_activation_fn ) # same as time embeddings, for ofs @@ -433,6 +434,7 @@ def forward( encoder_hidden_states: torch.Tensor, timestep: Union[int, float, torch.LongTensor], timestep_cond: Optional[torch.Tensor] = None, + ofs: Optional[Union[int, float, torch.LongTensor]] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, @@ -463,9 +465,12 @@ def forward( # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=hidden_states.dtype) emb = self.time_embedding(t_emb, timestep_cond) + if self.ofs_embedding is not None: - emb_ofs = self.ofs_embedding(emb, timestep_cond) - emb = emb + emb_ofs + ofs_emb = self.ofs_proj(ofs) + ofs_emb = ofs_emb.to(dtype=hidden_states.dtype) + ofs_emb = self.ofs_embedding(ofs_emb) + emb = emb + ofs_emb # 2. Patch embedding hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 88f1ece5de5f..4eaf5845f0a6 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -769,6 +769,7 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 7. Create rotary embeds if required image_rotary_emb = ( self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) @@ -776,6 +777,9 @@ def __call__( else None ) + # 8. Create ofs embeds if required + ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0) + # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -800,6 +804,7 @@ def __call__( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, timestep=timestep, + ofs=ofs_emb, image_rotary_emb=image_rotary_emb, attention_kwargs=attention_kwargs, return_dict=False, From 8b2823265b10686f7c8b1a932267f83ab662e9e4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 11 Nov 2024 19:32:56 +0100 Subject: [PATCH 21/36] update docs --- docs/source/en/api/pipelines/cogvideox.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 01e5bd0aaa5f..3d6df1883397 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -40,7 +40,7 @@ There is one model available that can be used with the image-to-video CogVideoX For the CogVideoX 1.5 series of models, note that: - Text-to-video works best at `1360 x 768` resolution because it is trained with that specific resolution -- Image-to-video works for multiple resolutions. Width can vary from `256` to `1360`, and height can vary from `256` to `768`. Note that the height/width must be divisible by `16`. +- Image-to-video works for multiple resolutions. Width can vary from `768` to `1360`, and height must be `768`. Note that the height/width must be divisible by `16`. - Both T2V and I2V models support generation with `81` and `161` frames and work best at this value. It is recommended to export videos at 16 FPS. There are two models that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team): From 3587317a36821bcbd951f0bdb83563745e86a303 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 11 Nov 2024 19:33:33 +0100 Subject: [PATCH 22/36] invert_scale_latents --- scripts/convert_cogvideox_to_diffusers.py | 14 +++++++++----- .../autoencoders/autoencoder_kl_cogvideox.py | 1 + .../cogvideo/pipeline_cogvideox_image2video.py | 8 +++++++- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index a29aa4e978d1..930a4921339f 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -173,9 +173,13 @@ def convert_transformer( return transformer -def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype): +def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype): + init_kwargs = {"scaling_factor": scaling_factor} + if args.version == "1.5": + init_kwargs.update({"invert_scale_latents": True}) + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) - vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype) + vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype) for key in list(original_state_dict.keys()): new_key = key[:] @@ -193,7 +197,7 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype): return vae -def get_init_kwargs(version: str): +def get_transformer_init_kwargs(version: str): if version == "1.0": vae_scale_factor_spatial = 8 init_kwargs = { @@ -281,7 +285,7 @@ def get_args(): dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32 if args.transformer_ckpt_path is not None: - init_kwargs = get_init_kwargs(args.version) + init_kwargs = get_transformer_init_kwargs(args.version) transformer = convert_transformer( args.transformer_ckpt_path, args.num_layers, @@ -293,7 +297,7 @@ def get_args(): ) if args.vae_ckpt_path is not None: # Keep VAE in float32 for better quality - vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, torch.float32) + vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32) text_encoder_id = "google/t5-v1_1-xxl" tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index d9ee15062daf..fbcb964392f9 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -1057,6 +1057,7 @@ def __init__( force_upcast: float = True, use_quant_conv: bool = False, use_post_quant_conv: bool = False, + invert_scale_latents: bool = False, ): super().__init__() diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 4eaf5845f0a6..4abd28a16599 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -381,7 +381,13 @@ def prepare_latents( image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] - image_latents = self.vae_scaling_factor_image * image_latents + + if not self.vae.config.invert_scale_latents: + image_latents = self.vae_scaling_factor_image * image_latents + else: + # This is awkward but required because the CogVideoX team forgot to multiply the + # scaling factor during training :) + image_latents = 1 / self.vae_scaling_factor_image * image_latents padding_shape = ( batch_size, From 17957d00ee05995b144df66d1a4b156e5e45650d Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 11 Nov 2024 19:35:09 +0100 Subject: [PATCH 23/36] update --- scripts/convert_cogvideox_to_diffusers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index 930a4921339f..6b4d64b19507 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -175,7 +175,7 @@ def convert_transformer( def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype): init_kwargs = {"scaling_factor": scaling_factor} - if args.version == "1.5": + if version == "1.5": init_kwargs.update({"invert_scale_latents": True}) original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) From 25a9e1c567f86bf6de538891d9f07c7f155e70af Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 15 Nov 2024 00:22:16 +0100 Subject: [PATCH 24/36] fix --- src/diffusers/models/embeddings.py | 38 ++++++++++++++--- .../pipeline_cogvideox_image2video.py | 41 +++++++++++++------ 2 files changed, 60 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index b3212e43a61f..80775d477c0d 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -517,7 +517,14 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tens def get_3d_rotary_pos_embed( - embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True + embed_dim, + crops_coords, + grid_size, + temporal_size, + theta: int = 10000, + use_real: bool = True, + grid_type: str = "linspace", + max_size: Optional[Tuple[int, int]] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ RoPE for video tokens with 3D structure. @@ -533,17 +540,30 @@ def get_3d_rotary_pos_embed( The size of the temporal dimension. theta (`float`): Scaling factor for frequency computation. + grid_type (`str`): + Whether to use "linspace" or "slice" to compute grids. Returns: `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. """ if use_real is not True: raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed") - start, stop = crops_coords - grid_size_h, grid_size_w = grid_size - grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) - grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) - grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + + if grid_type == "linspace": + start, stop = crops_coords + grid_size_h, grid_size_w = grid_size + grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) + grid_t = np.arange(temporal_size, dtype=np.float32) + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + elif grid_type == "slice": + max_h, max_w = max_size + grid_size_h, grid_size_w = grid_size + grid_h = np.arange(max_h, dtype=np.float32) + grid_w = np.arange(max_w, dtype=np.float32) + grid_t = np.arange(temporal_size, dtype=np.float32) + else: + raise ValueError("Invalid value passed for `grid_type`.") # Compute dimensions for each axis dim_t = embed_dim // 4 @@ -579,6 +599,12 @@ def combine_time_height_width(freqs_t, freqs_h, freqs_w): t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w + + if grid_type == "slice": + t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size] + h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h] + w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w] + cos = combine_time_height_width(t_cos, h_cos, w_cos) sin = combine_time_height_width(t_sin, h_sin, w_sin) return cos, sin diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 4abd28a16599..2b9212941ec4 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -540,21 +540,36 @@ def _prepare_rotary_positional_embeddings( grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) p = self.transformer.config.patch_size - p_t = self.transformer.config.patch_size_t or 1 + p_t = self.transformer.config.patch_size_t - base_size_width = self.transformer.config.sample_width // p - base_size_height = self.transformer.config.sample_height // p - base_num_frames = (num_frames + p_t - 1) // p_t + if p_t is None: + # CogVideoX 1.0 I2V + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p - grid_crops_coords = get_resize_crop_region_for_grid( - (grid_height, grid_width), base_size_width, base_size_height - ) - freqs_cos, freqs_sin = get_3d_rotary_pos_embed( - embed_dim=self.transformer.config.attention_head_dim, - crops_coords=grid_crops_coords, - grid_size=(grid_height, grid_width), - temporal_size=base_num_frames, - ) + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + else: + # CogVideoX 1.5 I2V + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + base_num_frames = (num_frames + p_t - 1) // p_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + ) freqs_cos = freqs_cos.to(device=device) freqs_sin = freqs_sin.to(device=device) From 7990958e5042de65ff0d779d96f63bde43c622fc Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 15 Nov 2024 05:06:12 +0530 Subject: [PATCH 25/36] Update docs/source/en/api/pipelines/cogvideox.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/pipelines/cogvideox.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 3d6df1883397..5f250b1ab108 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -29,10 +29,12 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM). -There are three official models available that can be used with the text-to-video and video-to-video CogVideoX pipelines: -- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b): The recommended dtype for running this model is `torch.float16`. -- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b): The recommended dtype for running this model is `torch.bfloat16`. -- [`THUDM/CogVideoX-1.5-5b`](https://huggingface.co/THUDM/CogVideoX-1.5-5b): The recommended dtype for running this mdoel is `torch.bfloat16`. +There are three official CogVideoX checkpoints for text-to-video and video-to-video. +| checkpoints | recommended inference dtype | +|---|---| +| [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b) | torch.float16 | +| [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b) | torch.bfloat16 | +| [`THUDM/CogVideoX-1.5-5b`](https://huggingface.co/THUDM/CogVideoX-1.5-5b) | torch.bfloat16 | There is one model available that can be used with the image-to-video CogVideoX pipeline: - [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `torch.bfloat16`. From 2c3b78d6d1a448dc0860e794434eba2f4a3ed957 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 15 Nov 2024 05:06:23 +0530 Subject: [PATCH 26/36] Update docs/source/en/api/pipelines/cogvideox.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/pipelines/cogvideox.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 5f250b1ab108..a6140a298d7e 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -40,10 +40,10 @@ There is one model available that can be used with the image-to-video CogVideoX - [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `torch.bfloat16`. - [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V): The recommended dtype for running this mdoel is `torch.bfloat16`. -For the CogVideoX 1.5 series of models, note that: -- Text-to-video works best at `1360 x 768` resolution because it is trained with that specific resolution -- Image-to-video works for multiple resolutions. Width can vary from `768` to `1360`, and height must be `768`. Note that the height/width must be divisible by `16`. -- Both T2V and I2V models support generation with `81` and `161` frames and work best at this value. It is recommended to export videos at 16 FPS. +For the CogVideoX 1.5 series: +- Text-to-video (T2V) works best at a resolution of 1360x768 because it was trained with that specific resolution. +- Image-to-video (I2V) works for multiple resolutions. The width can vary from 768 to 1360, but the height must be 768. The height/width must be divisible by 16. +- Both T2V and I2V models support generation with 81 and 161 frames and work best at this value. Exporting videos at 16 FPS is recommended. There are two models that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team): - [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `torch.bfloat16`. From e063e9d8225f2898a78d7180fd920b6c6bffc7f4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 15 Nov 2024 05:06:37 +0530 Subject: [PATCH 27/36] Update docs/source/en/api/pipelines/cogvideox.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/pipelines/cogvideox.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index a6140a298d7e..81d24baa277a 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -36,7 +36,11 @@ There are three official CogVideoX checkpoints for text-to-video and video-to-vi | [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b) | torch.bfloat16 | | [`THUDM/CogVideoX-1.5-5b`](https://huggingface.co/THUDM/CogVideoX-1.5-5b) | torch.bfloat16 | -There is one model available that can be used with the image-to-video CogVideoX pipeline: +There are two official CogVideoX checkpoints available for image-to-video. +| checkpoints | recommended inference dtype | +|---|---| +| [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V) | torch.bfloat16 | +| [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V) | torch.bfloat16 | - [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `torch.bfloat16`. - [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V): The recommended dtype for running this mdoel is `torch.bfloat16`. From f054c4407a1d0b894d42e6b5a849d171440f66ef Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 15 Nov 2024 05:06:50 +0530 Subject: [PATCH 28/36] Update docs/source/en/api/pipelines/cogvideox.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/pipelines/cogvideox.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 81d24baa277a..6570e1d1c16e 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -49,7 +49,11 @@ For the CogVideoX 1.5 series: - Image-to-video (I2V) works for multiple resolutions. The width can vary from 768 to 1360, but the height must be 768. The height/width must be divisible by 16. - Both T2V and I2V models support generation with 81 and 161 frames and work best at this value. Exporting videos at 16 FPS is recommended. -There are two models that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team): +There are two official CogVideoX checkpoints that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team). +| checkpoints | recommended inference dtype | +|---|---| +| [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose) | torch.bfloat16 | +| [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose) | torch.bfloat16 | - [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `torch.bfloat16`. - [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `torch.bfloat16`. From 3849caef94e9e8cdad6a1cbdb2aabd8fb809fe16 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 15 Nov 2024 05:07:04 +0530 Subject: [PATCH 29/36] Update src/diffusers/models/transformers/cogvideox_transformer_3d.py --- src/diffusers/models/transformers/cogvideox_transformer_3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 40e9e9fcbaab..b47d439774cc 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -171,7 +171,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): time_embed_dim (`int`, defaults to `512`): Output dimension of timestep embeddings. ofs_embed_dim (`int`, defaults to `512`): - scaling factor in the VAE process for the Image-to-Video (I2V) transformation in CogVideoX1.5-5B. + Output dimension of "ofs" embeddings used in CogVideoX-5b-I2B in version 1.5 text_embed_dim (`int`, defaults to `4096`): Input dimension of text embeddings from the text encoder. num_layers (`int`, defaults to `30`): From 4d14abbb494d90aa88ef1228c3121bf571ee1a4e Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 15 Nov 2024 00:37:46 +0100 Subject: [PATCH 30/36] update conversion script --- scripts/convert_cogvideox_to_diffusers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index 6b4d64b19507..7eeed240c4de 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -215,8 +215,8 @@ def get_transformer_init_kwargs(version: str): "patch_size": 2, "patch_size_t": 2, "patch_bias": False, - "sample_height": 768 // vae_scale_factor_spatial, - "sample_width": 1360 // vae_scale_factor_spatial, + "sample_height": 300, + "sample_width": 300, "sample_frames": 81, } else: From 9c846ebf79cc02f09c9b70e4069386f4c4ec5144 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 15 Nov 2024 00:38:21 +0100 Subject: [PATCH 31/36] remove copied from --- .../pipelines/cogvideo/pipeline_cogvideox_image2video.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 2b9212941ec4..6fa8731dc99e 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -528,7 +528,6 @@ def unfuse_qkv_projections(self) -> None: self.transformer.unfuse_qkv_projections() self.fusing_transformer = False - # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings def _prepare_rotary_positional_embeddings( self, height: int, From 9ef66d1f31aa22d61b7971d43ddfcbe0c9a26dd1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 15 Nov 2024 00:52:11 +0100 Subject: [PATCH 32/36] fix test --- tests/models/transformers/test_models_transformer_cogvideox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/transformers/test_models_transformer_cogvideox.py b/tests/models/transformers/test_models_transformer_cogvideox.py index e0350ef8dd99..4c13b54e0620 100644 --- a/tests/models/transformers/test_models_transformer_cogvideox.py +++ b/tests/models/transformers/test_models_transformer_cogvideox.py @@ -97,7 +97,7 @@ class CogVideoX1_5TransformerTests(ModelTesterMixin, unittest.TestCase): def dummy_input(self): batch_size = 2 num_channels = 4 - num_frames = 1 + num_frames = 2 height = 8 width = 8 embedding_dim = 8 From 23abe7bc48a1a5bbc9b4d4858ffe7a250ce04274 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 17 Nov 2024 12:38:13 +0530 Subject: [PATCH 33/36] Update docs/source/en/api/pipelines/cogvideox.md --- docs/source/en/api/pipelines/cogvideox.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 6570e1d1c16e..52ea947d8c8b 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -34,7 +34,7 @@ There are three official CogVideoX checkpoints for text-to-video and video-to-vi |---|---| | [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b) | torch.float16 | | [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b) | torch.bfloat16 | -| [`THUDM/CogVideoX-1.5-5b`](https://huggingface.co/THUDM/CogVideoX-1.5-5b) | torch.bfloat16 | +| [`THUDM/CogVideoX1.5-5b`](https://huggingface.co/THUDM/CogVideoX1.5-5b) | torch.bfloat16 | There are two official CogVideoX checkpoints available for image-to-video. | checkpoints | recommended inference dtype | From f47516d8f1397b0bb9053185a61401a68279f2e1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 17 Nov 2024 12:38:22 +0530 Subject: [PATCH 34/36] Update docs/source/en/api/pipelines/cogvideox.md --- docs/source/en/api/pipelines/cogvideox.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 52ea947d8c8b..9c60f5e69414 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -42,7 +42,7 @@ There are two official CogVideoX checkpoints available for image-to-video. | [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V) | torch.bfloat16 | | [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V) | torch.bfloat16 | - [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `torch.bfloat16`. -- [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V): The recommended dtype for running this mdoel is `torch.bfloat16`. +- [`THUDM/CogVideoX1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX1.5-5b-I2V): The recommended dtype for running this mdoel is `torch.bfloat16`. For the CogVideoX 1.5 series: - Text-to-video (T2V) works best at a resolution of 1360x768 because it was trained with that specific resolution. From 4a4df63f985a478ecb7e87f361e6184ec65859f3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 17 Nov 2024 13:28:34 +0530 Subject: [PATCH 35/36] Update docs/source/en/api/pipelines/cogvideox.md --- docs/source/en/api/pipelines/cogvideox.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 9c60f5e69414..aca325dac816 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -54,8 +54,6 @@ There are two official CogVideoX checkpoints that support pose controllable gene |---|---| | [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose) | torch.bfloat16 | | [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose) | torch.bfloat16 | -- [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `torch.bfloat16`. -- [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `torch.bfloat16`. ## Inference From ea166f85ad0090d182ec5f0e24123d5b8e9aca57 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 17 Nov 2024 13:28:43 +0530 Subject: [PATCH 36/36] Update docs/source/en/api/pipelines/cogvideox.md --- docs/source/en/api/pipelines/cogvideox.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index aca325dac816..40320896881c 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -41,8 +41,6 @@ There are two official CogVideoX checkpoints available for image-to-video. |---|---| | [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V) | torch.bfloat16 | | [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V) | torch.bfloat16 | -- [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `torch.bfloat16`. -- [`THUDM/CogVideoX1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX1.5-5b-I2V): The recommended dtype for running this mdoel is `torch.bfloat16`. For the CogVideoX 1.5 series: - Text-to-video (T2V) works best at a resolution of 1360x768 because it was trained with that specific resolution.