Skip to content

Commit 5218bae

Browse files
committed
fix for reference images
1 parent 23f6bc1 commit 5218bae

File tree

1 file changed

+8
-4
lines changed

src/diffusers/pipelines/wan/pipeline_wan_vace.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def preprocess_conditions(
359359
self,
360360
video: Optional[List[PipelineImageInput]] = None,
361361
mask: Optional[List[PipelineImageInput]] = None,
362-
reference_images: Optional[List[PipelineImageInput]] = None,
362+
reference_images: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], List[List[PIL.Image.Image]]]] = None,
363363
batch_size: int = 1,
364364
height: int = 480,
365365
width: int = 832,
@@ -513,8 +513,9 @@ def prepare_video_latents(
513513
reference_image = reference_image[None, :, None, :, :] # [1, C, 1, H, W]
514514
reference_latent = retrieve_latents(self.vae.encode(reference_image), generator, sample_mode="argmax")
515515
reference_latent = ((reference_latent.float() - latents_mean) * latents_std).to(vae_dtype)
516-
reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=1)
517-
latent = torch.cat([reference_latent.squeeze(0), latent], dim=1) # Concat across frame dimension
516+
reference_latent = reference_latent.squeeze(0) # [C, 1, H, W]
517+
reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=0)
518+
latent = torch.cat([reference_latent.squeeze(0), latent], dim=1)
518519
latent_list.append(latent)
519520
return torch.stack(latent_list)
520521

@@ -811,6 +812,7 @@ def __call__(
811812
torch.float32,
812813
device,
813814
)
815+
num_reference_images = len(reference_images[0])
814816

815817
conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator, device)
816818
mask = self.prepare_masks(mask, reference_images, generator)
@@ -823,7 +825,7 @@ def __call__(
823825
num_channels_latents,
824826
height,
825827
width,
826-
num_frames,
828+
num_frames + num_reference_images * self.vae_scale_factor_temporal,
827829
torch.float32,
828830
device,
829831
generator,
@@ -893,6 +895,8 @@ def __call__(
893895
self._current_timestep = None
894896

895897
if not output_type == "latent":
898+
print(latents.shape, num_reference_images)
899+
latents = latents[:, :, num_reference_images:]
896900
latents = latents.to(vae_dtype)
897901
latents_mean = (
898902
torch.tensor(self.vae.config.latents_mean)

0 commit comments

Comments (0)