@@ -359,7 +359,7 @@ def preprocess_conditions(
359359 self ,
360360 video : Optional [List [PipelineImageInput ]] = None ,
361361 mask : Optional [List [PipelineImageInput ]] = None ,
362- reference_images : Optional [List [PipelineImageInput ]] = None ,
362+ reference_images : Optional [Union [ PIL . Image . Image , List [PIL . Image . Image ], List [ List [ PIL . Image . Image ]] ]] = None ,
363363 batch_size : int = 1 ,
364364 height : int = 480 ,
365365 width : int = 832 ,
@@ -513,8 +513,9 @@ def prepare_video_latents(
513513 reference_image = reference_image [None , :, None , :, :] # [1, C, 1, H, W]
514514 reference_latent = retrieve_latents (self .vae .encode (reference_image ), generator , sample_mode = "argmax" )
515515 reference_latent = ((reference_latent .float () - latents_mean ) * latents_std ).to (vae_dtype )
516- reference_latent = torch .cat ([reference_latent , torch .zeros_like (reference_latent )], dim = 1 )
517- latent = torch .cat ([reference_latent .squeeze (0 ), latent ], dim = 1 ) # Concat across frame dimension
516+ reference_latent = reference_latent .squeeze (0 ) # [C, 1, H, W]
517+ reference_latent = torch .cat ([reference_latent , torch .zeros_like (reference_latent )], dim = 0 )
518+ latent = torch .cat ([reference_latent .squeeze (0 ), latent ], dim = 1 )
518519 latent_list .append (latent )
519520 return torch .stack (latent_list )
520521
@@ -811,6 +812,7 @@ def __call__(
811812 torch .float32 ,
812813 device ,
813814 )
815+ num_reference_images = len (reference_images [0 ])
814816
815817 conditioning_latents = self .prepare_video_latents (video , mask , reference_images , generator , device )
816818 mask = self .prepare_masks (mask , reference_images , generator )
@@ -823,7 +825,7 @@ def __call__(
823825 num_channels_latents ,
824826 height ,
825827 width ,
826- num_frames ,
828+ num_frames + num_reference_images * self . vae_scale_factor_temporal ,
827829 torch .float32 ,
828830 device ,
829831 generator ,
@@ -893,6 +895,8 @@ def __call__(
893895 self ._current_timestep = None
894896
895897 if not output_type == "latent" :
898+ print (latents .shape , num_reference_images )
899+ latents = latents [:, :, num_reference_images :]
896900 latents = latents .to (vae_dtype )
897901 latents_mean = (
898902 torch .tensor (self .vae .config .latents_mean )
0 commit comments