
Commit 804f5cc

Merge pull request #3 from zRzRzRzRzRzRzR/main
update
2 parents 36b1682 + 7b100ce commit 804f5cc

93 files changed: 4,140 additions and 1,928 deletions

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -600,6 +600,8 @@
     title: Attention Processor
   - local: api/activations
     title: Custom activation functions
+  - local: api/cache
+    title: Caching methods
   - local: api/normalization
     title: Custom normalization layers
   - local: api/utilities
```

docs/source/en/api/cache.md

Lines changed: 49 additions & 0 deletions
```diff
@@ -0,0 +1,49 @@
+<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License. -->
+
+# Caching methods
+
+## Pyramid Attention Broadcast
+
+[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
+
+Pyramid Attention Broadcast (PAB) speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. Attention states change little from one step to the next: the differences are largest in the spatial attention blocks, smaller in the temporal attention blocks, and smallest in the cross-attention blocks. Cross-attention computations can therefore be skipped most aggressively, followed by temporal and then spatial attention. Combined with other techniques such as sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
+
+Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
+
+```python
+import torch
+from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
+
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+# Increasing `spatial_attention_timestep_skip_range[0]` or decreasing
+# `spatial_attention_timestep_skip_range[1]` narrows the interval in which pyramid attention
+# broadcast is active, leading to slower inference. However, a very wide interval can also
+# degrade the quality of generated videos.
+config = PyramidAttentionBroadcastConfig(
+    spatial_attention_block_skip_range=2,
+    spatial_attention_timestep_skip_range=(100, 800),
+    current_timestep_callback=lambda: pipe.current_timestep,
+)
+pipe.transformer.enable_cache(config)
+```
+
+### CacheMixin
+
+[[autodoc]] CacheMixin
+
+### PyramidAttentionBroadcastConfig
+
+[[autodoc]] PyramidAttentionBroadcastConfig
+
+[[autodoc]] apply_pyramid_attention_broadcast
```
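
For context, a minimal end-to-end sketch of generating a video once the cache is enabled, assuming the standard CogVideoX text-to-video flow; the prompt, step count, and `export_to_video` call are illustrative additions, not part of this commit:

```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Reuse cached spatial attention states every 2nd step while the timestep is inside (100, 800).
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)

# Denoising proceeds as usual; attention states are broadcast across the skipped steps.
video = pipe(prompt="A panda strumming a guitar in a bamboo forest", num_inference_steps=50).frames[0]
export_to_video(video, "pab_output.mp4", fps=8)
```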

examples/community/matryoshka.py

Lines changed: 5 additions & 74 deletions
```diff
@@ -80,7 +80,6 @@
     USE_PEFT_BACKEND,
     BaseOutput,
     deprecate,
-    is_torch_version,
     is_torch_xla_available,
     logging,
     replace_example_docstring,
@@ -869,23 +868,7 @@ def forward(
 
         for i, (resnet, attn) in enumerate(blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1030,17 +1013,6 @@ def forward(
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1049,12 +1021,7 @@ def custom_forward(*inputs):
                     encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = attn(
                     hidden_states,
@@ -1192,23 +1159,7 @@ def forward(
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1282,10 +1233,6 @@ def __init__(
             ]
         )
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1365,27 +1312,15 @@ def forward(
         # Blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
@@ -2724,10 +2659,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
 
```
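The recurring change in `examples/community/matryoshka.py` replaces the inline `create_custom_forward` / `torch.utils.checkpoint.checkpoint` boilerplate with a single call to `self._gradient_checkpointing_func`. Below is a minimal sketch of that pattern; the mixin and block classes are hypothetical stand-ins, and the helper body is an assumption rather than the actual diffusers implementation.

```python
from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint


class CheckpointingMixin(nn.Module):
    """Hypothetical base class sketching the helper the refactor relies on."""

    gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True
        # Bind torch.utils.checkpoint.checkpoint once; use_reentrant=False selects the
        # non-reentrant implementation recommended by PyTorch.
        self._gradient_checkpointing_func = partial(
            torch.utils.checkpoint.checkpoint, use_reentrant=False
        )


class TinyBlock(CheckpointingMixin):
    def __init__(self):
        super().__init__()
        self.resnet = nn.Linear(8, 8)

    def forward(self, hidden_states, temb):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # One helper call replaces the per-block create_custom_forward closure.
            return self._gradient_checkpointing_func(self.resnet, hidden_states + temb)
        return self.resnet(hidden_states + temb)


block = TinyBlock()
block.enable_gradient_checkpointing()
out = block(torch.randn(2, 8, requires_grad=True), torch.randn(2, 8))
out.sum().backward()  # activations of self.resnet are recomputed here instead of stored
```

Keeping the checkpointing policy in one helper is what lets each forward path collapse to a single line, which is the shape of the `+` lines in the hunks above.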