Commit cfddf35

Merge branch 'main' into integrations/wan2.2-s2v

2 parents 87f0c39 + 6156cf8 commit cfddf35
47 files changed: +6014 −153 lines

docs/source/en/_toctree.yml

Lines changed: 6 additions & 0 deletions
@@ -359,6 +359,8 @@
   title: HunyuanDiT2DModel
  - local: api/models/hunyuanimage_transformer_2d
   title: HunyuanImageTransformer2DModel
+ - local: api/models/hunyuan_video15_transformer_3d
+  title: HunyuanVideo15Transformer3DModel
  - local: api/models/hunyuan_video_transformer_3d
   title: HunyuanVideoTransformer3DModel
  - local: api/models/latte_transformer3d

@@ -433,6 +435,8 @@
   title: AutoencoderKLHunyuanImageRefiner
  - local: api/models/autoencoder_kl_hunyuan_video
   title: AutoencoderKLHunyuanVideo
+ - local: api/models/autoencoder_kl_hunyuan_video15
+  title: AutoencoderKLHunyuanVideo15
  - local: api/models/autoencoderkl_ltx_video
   title: AutoencoderKLLTXVideo
  - local: api/models/autoencoderkl_magvit

@@ -652,6 +656,8 @@
   title: Framepack
  - local: api/pipelines/hunyuan_video
   title: HunyuanVideo
+ - local: api/pipelines/hunyuan_video15
+  title: HunyuanVideo1.5
  - local: api/pipelines/i2vgenxl
   title: I2VGen-XL
  - local: api/pipelines/kandinsky5_video
Lines changed: 36 additions & 0 deletions
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLHunyuanVideo15

The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo1.5](https://github.com/Tencent/HunyuanVideo1-1.5) by Tencent.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import AutoencoderKLHunyuanVideo15

vae = AutoencoderKLHunyuanVideo15.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v", subfolder="vae", torch_dtype=torch.float32
)

# make sure to enable tiling to avoid OOM
vae.enable_tiling()
```
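
As a quick reference, here is a minimal sketch of an encode/decode round trip, assuming the standard `AutoencoderKL`-style interface (`encode(...).latent_dist`, `decode(...).sample`); the input shape and frame count below are illustrative only:

```python
import torch

from diffusers import AutoencoderKLHunyuanVideo15

vae = AutoencoderKLHunyuanVideo15.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v", subfolder="vae", torch_dtype=torch.float32
)
vae.enable_tiling()

# Illustrative dummy video: batch of 1, 3 channels, 17 frames, 256x256 (assumed layout: B, C, T, H, W).
video = torch.randn(1, 3, 17, 256, 256)

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
    reconstructed = vae.decode(latents).sample
```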

## AutoencoderKLHunyuanVideo15

[[autodoc]] AutoencoderKLHunyuanVideo15
  - decode
  - encode
  - all

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
Lines changed: 30 additions & 0 deletions
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# HunyuanVideo15Transformer3DModel

A Diffusion Transformer model for 3D video-like data used in [HunyuanVideo1.5](https://github.com/Tencent/HunyuanVideo1-1.5).

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import HunyuanVideo15Transformer3DModel

transformer = HunyuanVideo15Transformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v", subfolder="transformer", torch_dtype=torch.bfloat16
)
```
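
The loaded transformer can then be handed to a pipeline. A minimal sketch, assuming the usual Diffusers pattern of overriding a component at `from_pretrained` time (the pipeline class matches the text-to-video docs added in this commit):

```python
import torch

from diffusers import HunyuanVideo15Pipeline, HunyuanVideo15Transformer3DModel

transformer = HunyuanVideo15Transformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Load the remaining components from the same checkpoint and swap in the transformer.
pipeline = HunyuanVideo15Pipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
```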

## HunyuanVideo15Transformer3DModel

[[autodoc]] HunyuanVideo15Transformer3DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
Lines changed: 120 additions & 0 deletions
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# HunyuanVideo-1.5

HunyuanVideo-1.5 is a lightweight yet powerful video generation model that achieves state-of-the-art visual quality and motion coherence with only 8.3 billion parameters, enabling efficient inference on consumer-grade GPUs. This achievement is built upon several key components, including meticulous data curation, an advanced DiT architecture with selective and sliding tile attention (SSTA), enhanced bilingual understanding through glyph-aware text encoding, progressive pre-training and post-training, and an efficient video super-resolution network. Leveraging these designs, we developed a unified framework capable of high-quality text-to-video and image-to-video generation across multiple durations and resolutions. Extensive experiments demonstrate that this compact and proficient model establishes a new state-of-the-art among open-source models.

You can find all the original HunyuanVideo-1.5 checkpoints under the [Tencent](https://huggingface.co/tencent) organization.

> [!TIP]
> Click on the HunyuanVideo-1.5 models in the right sidebar for more examples of video generation tasks.
>
> The examples below use a checkpoint from [hunyuanvideo-community](https://huggingface.co/hunyuanvideo-community) because the weights are stored in a layout compatible with Diffusers.

The example below demonstrates how to generate a video optimized for memory or inference speed.

<hfoptions id="usage">
<hfoption id="memory">

Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.

```py
import torch
from diffusers import AutoModel, HunyuanVideo15Pipeline
from diffusers.utils import export_to_video

pipeline = HunyuanVideo15Pipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v",
    torch_dtype=torch.bfloat16,
)

# model offloading and VAE tiling
pipeline.enable_model_cpu_offload()
pipeline.vae.enable_tiling()

prompt = "A fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys."
video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
export_to_video(video, "output.mp4", fps=15)
```

</hfoption>
</hfoptions>
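
For faster inference, one option not shown in this commit is compiling the transformer with `torch.compile`; a hedged sketch follows (the compile settings are illustrative starting points and `fullgraph=True` may not hold for every attention backend):

```py
import torch
from diffusers import HunyuanVideo15Pipeline
from diffusers.utils import export_to_video

pipeline = HunyuanVideo15Pipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Compile the transformer; the first call is slow while kernels are traced and compiled.
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)

prompt = "A fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys."
video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
export_to_video(video, "output.mp4", fps=15)
```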
## Notes

- HunyuanVideo1.5 uses attention masks with variable-length sequences. For best performance, we recommend using an attention backend that handles padding efficiently:

  - **H100/H800:** `_flash_3_hub` or `_flash_varlen_3`
  - **A100/A800/RTX 4090:** `flash_hub` or `flash_varlen`
  - **Other GPUs:** `sage_hub`

  Refer to the [Attention backends](../../optimization/attention_backends) guide for more details about using a different backend.

  ```py
  pipe.transformer.set_attention_backend("flash_hub")  # or your preferred backend
  ```

- [`HunyuanVideo15Pipeline`] uses a guider and does not accept a `guidance_scale` argument at runtime.

  You can inspect the default guider configuration with `pipe.guider`:

  ```py
  >>> pipe.guider
  ClassifierFreeGuidance {
    "_class_name": "ClassifierFreeGuidance",
    "_diffusers_version": "0.36.0.dev0",
    "enabled": true,
    "guidance_rescale": 0.0,
    "guidance_scale": 6.0,
    "start": 0.0,
    "stop": 1.0,
    "use_original_formulation": false
  }

  State:
    step: None
    num_inference_steps: None
    timestep: None
    count_prepared: 0
    enabled: True
    num_conditions: 2
  ```

  To update the guider configuration, run `pipe.guider = pipe.guider.new(...)`:

  ```py
  pipe.guider = pipe.guider.new(guidance_scale=5.0)
  ```

  Read more about guiders [here](../../modular_diffusers/guiders).
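
- [`HunyuanVideo15ImageToVideoPipeline`] (documented below) handles image-to-video. A minimal, hedged sketch — the checkpoint id and the `image` argument follow the usual Diffusers image-to-video convention and are assumptions, not values taken from this commit:

  ```py
  import torch

  from diffusers import HunyuanVideo15ImageToVideoPipeline
  from diffusers.utils import export_to_video, load_image

  # Hypothetical i2v checkpoint id, mirroring the t2v naming used above.
  pipeline = HunyuanVideo15ImageToVideoPipeline.from_pretrained(
      "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_i2v",
      torch_dtype=torch.bfloat16,
  )
  pipeline.enable_model_cpu_offload()
  pipeline.vae.enable_tiling()

  image = load_image("path/to/first_frame.png")  # conditioning image (placeholder path)
  prompt = "The subject slowly turns toward the camera."
  video = pipeline(prompt=prompt, image=image, num_frames=61, num_inference_steps=30).frames[0]
  export_to_video(video, "output_i2v.mp4", fps=15)
  ```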

## HunyuanVideo15Pipeline

[[autodoc]] HunyuanVideo15Pipeline
  - all
  - __call__

## HunyuanVideo15ImageToVideoPipeline

[[autodoc]] HunyuanVideo15ImageToVideoPipeline
  - all
  - __call__

## HunyuanVideo15PipelineOutput

[[autodoc]] pipelines.hunyuan_video1_5.pipeline_output.HunyuanVideo15PipelineOutput

docs/source/en/modular_diffusers/guiders.md

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ Change the [`~ComponentSpec.default_creation_method`] to `from_pretrained` and u
  ```py
  guider_spec = t2i_pipeline.get_component_spec("guider")
  guider_spec.default_creation_method="from_pretrained"
- guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
+ guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
  guider_spec.subfolder="pag_guider"
  pag_guider = guider_spec.load()
  t2i_pipeline.update_components(guider=pag_guider)

docs/source/en/modular_diffusers/modular_pipeline.md

Lines changed: 2 additions & 2 deletions
@@ -313,14 +313,14 @@ unet_spec
  ComponentSpec(
    name='unet',
    type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
-   repo='RunDiffusion/Juggernaut-XL-v9',
+   pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
    subfolder='unet',
    variant='fp16',
    default_creation_method='from_pretrained'
  )

  # modify to load from a different repository
- unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
+ unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"

  # load component with modified spec
  unet = unet_spec.load(torch_dtype=torch.float16)

docs/source/zh/modular_diffusers/guiders.md

Lines changed: 1 addition & 1 deletion
@@ -157,7 +157,7 @@ guider.push_to_hub("YiYiXu/modular-loader-t2i-guider", subfolder="pag_guider")
  ```py
  guider_spec = t2i_pipeline.get_component_spec("guider")
  guider_spec.default_creation_method="from_pretrained"
- guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
+ guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
  guider_spec.subfolder="pag_guider"
  pag_guider = guider_spec.load()
  t2i_pipeline.update_components(guider=pag_guider)

docs/source/zh/modular_diffusers/modular_pipeline.md

Lines changed: 2 additions & 2 deletions
@@ -313,14 +313,14 @@ unet_spec
  ComponentSpec(
    name='unet',
    type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
-   repo='RunDiffusion/Juggernaut-XL-v9',
+   pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
    subfolder='unet',
    variant='fp16',
    default_creation_method='from_pretrained'
  )

  # modify to load from a different repository
- unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
+ unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"

  # load component with modified spec
  unet = unet_spec.load(torch_dtype=torch.float16)

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 61 additions & 13 deletions
@@ -37,7 +37,7 @@
  from huggingface_hub import create_repo, upload_folder
  from packaging import version
  from peft import LoraConfig
- from peft.utils import get_peft_model_state_dict
+ from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
  from torchvision import transforms
  from tqdm.auto import tqdm
  from transformers import CLIPTextModel, CLIPTokenizer

@@ -46,7 +46,12 @@
  from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
  from diffusers.optimization import get_scheduler
  from diffusers.training_utils import cast_training_params, compute_snr
- from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+ from diffusers.utils import (
+     check_min_version,
+     convert_state_dict_to_diffusers,
+     convert_unet_state_dict_to_peft,
+     is_wandb_available,
+ )
  from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
  from diffusers.utils.import_utils import is_xformers_available
  from diffusers.utils.torch_utils import is_compiled_module

@@ -708,6 +713,56 @@ def collate_fn(examples):
      num_workers=args.dataloader_num_workers,
  )

+ def save_model_hook(models, weights, output_dir):
+     if accelerator.is_main_process:
+         unet_lora_layers_to_save = None
+
+         for model in models:
+             if isinstance(model, type(unwrap_model(unet))):
+                 unet_lora_layers_to_save = get_peft_model_state_dict(model)
+             else:
+                 raise ValueError(f"Unexpected save model: {model.__class__}")
+
+             # make sure to pop weight so that corresponding model is not saved again
+             weights.pop()
+
+         StableDiffusionPipeline.save_lora_weights(
+             save_directory=output_dir,
+             unet_lora_layers=unet_lora_layers_to_save,
+             safe_serialization=True,
+         )
+
+ def load_model_hook(models, input_dir):
+     unet_ = None
+
+     while len(models) > 0:
+         model = models.pop()
+         if isinstance(model, type(unwrap_model(unet))):
+             unet_ = model
+         else:
+             raise ValueError(f"unexpected save model: {model.__class__}")
+
+     # returns a tuple of state dictionary and network alphas
+     lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
+
+     unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+     unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+     incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+
+     if incompatible_keys is not None:
+         # check only for unexpected keys
+         unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+         # throw warning if some unexpected keys are found and continue loading
+         if unexpected_keys:
+             logger.warning(
+                 f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                 f" {unexpected_keys}. "
+             )
+
+     # Make sure the trainable params are in float32
+     if args.mixed_precision in ["fp16"]:
+         cast_training_params([unet_], dtype=torch.float32)
+
  # Scheduler and math around the number of training steps.
  # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
  num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes

@@ -732,6 +787,10 @@ def collate_fn(examples):
      unet, optimizer, train_dataloader, lr_scheduler
  )

+ # Register the hooks for efficient saving and loading of LoRA weights
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
  # We need to recalculate our total training steps as the size of the training dataloader may have changed.
  num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
  if args.max_train_steps is None:

@@ -906,17 +965,6 @@ def collate_fn(examples):
      save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
      accelerator.save_state(save_path)

- unwrapped_unet = unwrap_model(unet)
- unet_lora_state_dict = convert_state_dict_to_diffusers(
-     get_peft_model_state_dict(unwrapped_unet)
- )
-
- StableDiffusionPipeline.save_lora_weights(
-     save_directory=save_path,
-     unet_lora_layers=unet_lora_state_dict,
-     safe_serialization=True,
- )
-
  logger.info(f"Saved state to {save_path}")

  logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}