Commit cd00ba6

jerry2102 and lirui.926 authored
fix spatial compression ratio error for AutoEncoderKLWan doing tiled encode (#12753)
fix spatial compression ratio compute error for AutoEncoderKLWan

Co-authored-by: lirui.926 <lirui.926@bytedance.com>
1 parent 2842c14 commit cd00ba6

1 file changed, +12 -6 lines changed

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 12 additions & 6 deletions
@@ -1259,14 +1259,20 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
             `torch.Tensor`:
                 The latent representation of the encoded videos.
         """
+
         _, _, num_frames, height, width = x.shape
-        latent_height = height // self.spatial_compression_ratio
-        latent_width = width // self.spatial_compression_ratio
+        encode_spatial_compression_ratio = self.spatial_compression_ratio
+        if self.config.patch_size is not None:
+            assert encode_spatial_compression_ratio % self.config.patch_size == 0
+            encode_spatial_compression_ratio = self.spatial_compression_ratio // self.config.patch_size
 
-        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
-        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
-        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
-        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+        latent_height = height // encode_spatial_compression_ratio
+        latent_width = width // encode_spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // encode_spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // encode_spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // encode_spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // encode_spatial_compression_ratio
 
         blend_height = tile_latent_min_height - tile_latent_stride_height
         blend_width = tile_latent_min_width - tile_latent_stride_width
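
To see what the change does to the tile arithmetic, here is a minimal standalone sketch that mirrors the integer math in the diff without importing diffusers. The helper names `encode_ratio` and `tile_latent_sizes` and the example values (ratio 16, patch size 2, 256-pixel tiles with a 192-pixel stride) are assumptions made for illustration, not values taken from the library or from this commit. The sketch raises `ValueError` where the commit uses `assert`; that is a choice of the sketch, not of the patch.

```python
from typing import Optional, Tuple


def encode_ratio(spatial_compression_ratio: int, patch_size: Optional[int]) -> int:
    """Hypothetical helper: ratio used for tiled encode, as computed in the diff."""
    if patch_size is None:
        # No patchification configured: the ratio is used unchanged.
        return spatial_compression_ratio
    if spatial_compression_ratio % patch_size != 0:
        raise ValueError("spatial_compression_ratio must be divisible by patch_size")
    return spatial_compression_ratio // patch_size


def tile_latent_sizes(tile_min: int, tile_stride: int, ratio: int) -> Tuple[int, int, int]:
    """Hypothetical helper: latent tile size, latent stride, and blend extent."""
    latent_min = tile_min // ratio
    latent_stride = tile_stride // ratio
    blend = latent_min - latent_stride
    return latent_min, latent_stride, blend


# Assumed example values, chosen only to make the arithmetic visible.
ratio_before = 16                       # self.spatial_compression_ratio used directly
ratio_after = encode_ratio(16, 2)       # divided by patch_size, as in the patch

before = tile_latent_sizes(256, 192, ratio_before)
after = tile_latent_sizes(256, 192, ratio_after)
print(before, after)
```

With these assumed values, the latent tile size, stride, and blend extent all double once the ratio is divided by the patch size, which is the quantity the tiling and blending code depends on.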

0 commit comments
