@@ -1259,14 +1259,20 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
12591259 `torch.Tensor`:
12601260 The latent representation of the encoded videos.
12611261 """
1262+
12621263 _ , _ , num_frames , height , width = x .shape
1263- latent_height = height // self .spatial_compression_ratio
1264- latent_width = width // self .spatial_compression_ratio
1264+ encode_spatial_compression_ratio = self .spatial_compression_ratio
1265+ if self .config .patch_size is not None :
1266+ assert encode_spatial_compression_ratio % self .config .patch_size == 0
1267+ encode_spatial_compression_ratio = self .spatial_compression_ratio // self .config .patch_size
12651268
1266- tile_latent_min_height = self .tile_sample_min_height // self .spatial_compression_ratio
1267- tile_latent_min_width = self .tile_sample_min_width // self .spatial_compression_ratio
1268- tile_latent_stride_height = self .tile_sample_stride_height // self .spatial_compression_ratio
1269- tile_latent_stride_width = self .tile_sample_stride_width // self .spatial_compression_ratio
1269+ latent_height = height // encode_spatial_compression_ratio
1270+ latent_width = width // encode_spatial_compression_ratio
1271+
1272+ tile_latent_min_height = self .tile_sample_min_height // encode_spatial_compression_ratio
1273+ tile_latent_min_width = self .tile_sample_min_width // encode_spatial_compression_ratio
1274+ tile_latent_stride_height = self .tile_sample_stride_height // encode_spatial_compression_ratio
1275+ tile_latent_stride_width = self .tile_sample_stride_width // encode_spatial_compression_ratio
12701276
12711277 blend_height = tile_latent_min_height - tile_latent_stride_height
12721278 blend_width = tile_latent_min_width - tile_latent_stride_width
0 commit comments