
Commit afc60a0

a-r-r-o-w authored and yiyixuxu committed
apply suggestions from review; update docs
1 parent e96d5ad commit afc60a0

File tree: 6 files changed (+102, -26 lines)


docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md

Lines changed: 5 additions & 0 deletions
@@ -54,6 +54,11 @@ image = pipe(
 image.save("sd3_hello_world.png")
 ```
 
+**Note:** Stable Diffusion 3.5 can also be run using the SD3 pipeline, and all the optimizations and techniques mentioned here apply to it as well. In total there are three official models in the SD3 family:
+- [`stabilityai/stable-diffusion-3-medium-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers)
+- [`stabilityai/stable-diffusion-3.5-medium-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium-diffusers)
+- [`stabilityai/stable-diffusion-3.5-large-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-diffusers)
+
 ## Memory Optimisations for SD3
 
 SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low-resource hardware.
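Since the note above states that SD 3.5 checkpoints run through the existing SD3 pipeline, here is a minimal sketch (not part of this commit) of what that looks like. The repo id is taken from the list in the diff; the dtype, step count, and guidance scale are illustrative assumptions, not recommended settings.

```python
# Minimal sketch, not part of this commit: an SD 3.5 checkpoint through the SD3 pipeline.
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium-diffusers", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()  # one of the memory optimizations covered below

image = pipe(
    prompt="a photo of a cat holding a sign that says hello world",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd3_5_hello_world.png")
```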

scripts/convert_sd3_to_diffusers.py

Lines changed: 12 additions & 10 deletions
@@ -40,7 +40,7 @@ def swap_scale_shift(weight, dim):
 
 
 def convert_sd3_transformer_checkpoint_to_diffusers(
-    original_state_dict, num_layers, caption_projection_dim, add_attn2_layers, has_qk_norm
+    original_state_dict, num_layers, caption_projection_dim, dual_attention_layers, has_qk_norm
 ):
     converted_state_dict = {}
 
@@ -142,7 +142,7 @@ def convert_sd3_transformer_checkpoint_to_diffusers(
         )
 
         # attn2
-        if i in add_attn2_layers:
+        if i in dual_attention_layers:
             # Q, K, V
             sample_q2, sample_k2, sample_v2 = torch.chunk(
                 original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
@@ -244,14 +244,14 @@ def is_vae_in_checkpoint(original_state_dict):
     )
 
 
-def get_add_attn2_layers(state_dict):
-    add_attn2_layers = []
+def get_attn2_layers(state_dict):
+    attn2_layers = []
     for key in state_dict.keys():
         if "attn2." in key:
             # Extract the layer number from the key
             layer_num = int(key.split(".")[1])
-            add_attn2_layers.append(layer_num)
-    return tuple(sorted(set(add_attn2_layers)))
+            attn2_layers.append(layer_num)
+    return tuple(sorted(set(attn2_layers)))
 
 
 def get_pos_embed_max_size(state_dict):
@@ -284,14 +284,16 @@ def main(args):
         raise ValueError(f"Unsupported dtype: {args.dtype}")
 
     if dtype != original_dtype:
-        print(f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution.")
+        print(
+            f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution."
+        )
 
     num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1  # noqa: C401
 
     caption_projection_dim = get_caption_projection_dim(original_ckpt)
 
     # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
-    add_attn2_layers = get_add_attn2_layers(original_ckpt)
+    attn2_layers = get_attn2_layers(original_ckpt)
 
     # sd3.5 use qk norm("rms_norm")
     has_qk_norm = any("ln_q" in key for key in original_ckpt.keys())
@@ -300,7 +302,7 @@ def main(args):
     pos_embed_max_size = get_pos_embed_max_size(original_ckpt)
 
     converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
-        original_ckpt, num_layers, caption_projection_dim, add_attn2_layers, has_qk_norm
+        original_ckpt, num_layers, caption_projection_dim, attn2_layers, has_qk_norm
     )
 
     with CTX():
@@ -314,7 +316,7 @@ def main(args):
             num_attention_heads=num_layers,
             pos_embed_max_size=pos_embed_max_size,
             qk_norm="rms_norm" if has_qk_norm else None,
-            add_attn2_layers=add_attn2_layers,
+            dual_attention_layers=attn2_layers,
         )
         if is_accelerate_available():
            load_model_dict_into_meta(transformer, converted_transformer_state_dict)
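For clarity, the renamed helper can be read in isolation: it scans the original checkpoint keys for an `attn2.` branch and returns the sorted tuple of joint-block indices that carry the extra self-attention, `()` for SD3.0 and `(0, ..., 12)` for SD3.5 per the inline comment. Below is a standalone toy run; the dummy keys only mimic the `joint_blocks.{i}.x_block.attn2.*` naming used by the script and are not from a real checkpoint.

```python
# Standalone restatement of the renamed helper plus a toy state dict (keys are made up).
def get_attn2_layers(state_dict):
    attn2_layers = []
    for key in state_dict.keys():
        if "attn2." in key:
            # the joint-block index sits right after the "joint_blocks." prefix
            layer_num = int(key.split(".")[1])
            attn2_layers.append(layer_num)
    return tuple(sorted(set(attn2_layers)))


dummy_ckpt = {
    "joint_blocks.0.x_block.attn2.qkv.weight": None,
    "joint_blocks.0.x_block.attn2.proj.weight": None,
    "joint_blocks.3.x_block.attn2.qkv.weight": None,
    "joint_blocks.5.x_block.attn.qkv.weight": None,  # plain attn, ignored
}
print(get_attn2_layers(dummy_ckpt))  # (0, 3); an SD3.0-style checkpoint yields ()
```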

src/diffusers/models/attention.py

Lines changed: 14 additions & 6 deletions
@@ -101,15 +101,21 @@ class JointTransformerBlock(nn.Module):
     """
 
     def __init__(
-        self, dim, num_attention_heads, attention_head_dim, context_pre_only=False, qk_norm=None, add_attn2=False
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        context_pre_only: bool = False,
+        qk_norm: Optional[str] = None,
+        use_dual_attention: bool = False,
     ):
         super().__init__()
 
-        self.add_attn2 = add_attn2
+        self.use_dual_attention = use_dual_attention
         self.context_pre_only = context_pre_only
         context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
 
-        if add_attn2:
+        if use_dual_attention:
             self.norm1 = SD35AdaLayerNormZeroX(dim)
         else:
             self.norm1 = AdaLayerNormZero(dim)
@@ -124,12 +130,14 @@ def __init__(
             raise ValueError(
                 f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
             )
+
         if hasattr(F, "scaled_dot_product_attention"):
             processor = JointAttnProcessor2_0()
         else:
             raise ValueError(
                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."
             )
+
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -144,7 +152,7 @@ def __init__(
             eps=1e-6,
         )
 
-        if add_attn2:
+        if use_dual_attention:
             self.attn2 = Attention(
                 query_dim=dim,
                 cross_attention_dim=None,
@@ -182,7 +190,7 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
     def forward(
         self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
     ):
-        if self.add_attn2:
+        if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                 hidden_states, emb=temb
             )
@@ -205,7 +213,7 @@ def forward(
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
 
-        if self.add_attn2:
+        if self.use_dual_attention:
             attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2
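A hedged sketch of what the renamed flag does at the block level, assuming direct instantiation is acceptable for illustration (in practice the block is built by `SD3Transformer2DModel`). The sizes and the returned `(encoder_hidden_states, hidden_states)` unpacking are assumptions based on this file, not part of the commit.

```python
# Hedged sketch: with use_dual_attention=True the block swaps in SD35AdaLayerNormZeroX
# and adds a second self-attention branch (attn2) gated by gate_msa2.
import torch
from diffusers.models.attention import JointTransformerBlock

block = JointTransformerBlock(
    dim=64,                    # illustrative hidden size
    num_attention_heads=4,
    attention_head_dim=16,
    context_pre_only=False,
    qk_norm="rms_norm",        # SD3.5-style Q/K normalization
    use_dual_attention=True,   # enables the dual-attention path
)

hidden_states = torch.randn(1, 16, 64)          # (batch, image tokens, dim)
encoder_hidden_states = torch.randn(1, 8, 64)   # (batch, text tokens, dim)
temb = torch.randn(1, 64)                       # conditioning embedding

# Assumed return order per this file: updated context stream, then image stream.
encoder_hidden_states, hidden_states = block(
    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
```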

src/diffusers/models/normalization.py

Lines changed: 8 additions & 8 deletions
@@ -99,14 +99,14 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
 
 class SD35AdaLayerNormZeroX(nn.Module):
     r"""
-    Norm layer adaptive layer norm zero (adaLN-Zero).
+    Norm layer adaptive layer norm zero (AdaLN-Zero).
 
     Parameters:
         embedding_dim (`int`): The size of each embedding vector.
         num_embeddings (`int`): The size of the embeddings dictionary.
     """
 
-    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
+    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None:
         super().__init__()
 
         self.silu = nn.SiLU()
@@ -118,17 +118,17 @@ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
 
     def forward(
         self,
-        x: torch.Tensor,
+        hidden_states: torch.Tensor,
         emb: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, ...]:
         emb = self.linear(self.silu(emb))
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk(
             9, dim=1
         )
-        normed_x = self.norm(x)
-        x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None]
-        x2 = normed_x * (1 + scale_msa2[:, None]) + shift_msa2[:, None]
-        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, x2, gate_msa2
+        norm_hidden_states = self.norm(hidden_states)
+        hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None]
+        return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2
 
 
 class AdaLayerNormZero(nn.Module):
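To spell out the renamed contract: the layer projects the conditioning embedding into nine chunks and returns seven tensors, the modulated hidden states plus a second modulated copy and gate for the dual-attention branch. A small usage sketch with illustrative shapes (not part of the commit):

```python
# Usage sketch of SD35AdaLayerNormZeroX.forward; shapes are illustrative.
import torch
from diffusers.models.normalization import SD35AdaLayerNormZeroX

norm = SD35AdaLayerNormZeroX(embedding_dim=64)
hidden_states = torch.randn(2, 16, 64)  # (batch, tokens, dim)
temb = torch.randn(2, 64)               # conditioning embedding, projected to 9 * dim internally

(
    hidden_states,        # modulated input for the primary attention branch
    gate_msa,
    shift_mlp,
    scale_mlp,
    gate_mlp,
    norm_hidden_states2,  # modulated input for the second (dual) attention branch
    gate_msa2,
) = norm(hidden_states, emb=temb)

print(hidden_states.shape, norm_hidden_states2.shape)  # both torch.Size([2, 16, 64])
```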

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 4 additions & 2 deletions
@@ -69,7 +69,9 @@ def __init__(
         pooled_projection_dim: int = 2048,
         out_channels: int = 16,
         pos_embed_max_size: int = 96,
-        add_attn2_layers: Tuple[int, ...] = (),  # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
+        dual_attention_layers: Tuple[
+            int, ...
+        ] = (),  # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
         qk_norm: Optional[str] = None,
     ):
         super().__init__()
@@ -100,7 +102,7 @@ def __init__(
                     attention_head_dim=self.config.attention_head_dim,
                     context_pre_only=i == num_layers - 1,
                     qk_norm=qk_norm,
-                    add_attn2=True if i in add_attn2_layers else False,
+                    use_dual_attention=True if i in dual_attention_layers else False,
                 )
                 for i in range(self.config.num_layers)
             ]
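With the config rename, an SD3.5-style transformer is selected purely by passing block indices to `dual_attention_layers` (typically together with `qk_norm="rms_norm"`), while the empty default keeps SD3.0 behaviour. A tiny configuration sketch, borrowing the sizes from the new test below; real SD3.5 checkpoints use much larger dimensions and `dual_attention_layers=(0, 1, ..., 12)` per the inline comment.

```python
# Tiny configuration sketch, not part of the commit; sizes mirror SD35TransformerTests.
from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel(
    sample_size=32,
    patch_size=1,
    in_channels=4,
    num_layers=2,
    attention_head_dim=8,
    num_attention_heads=4,
    caption_projection_dim=32,
    joint_attention_dim=32,
    pooled_projection_dim=64,
    out_channels=4,
    pos_embed_max_size=96,
    dual_attention_layers=(0,),  # () would build plain SD3.0-style blocks
    qk_norm="rms_norm",
)
print(sum(p.numel() for p in model.parameters()))
```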

tests/models/transformers/test_models_transformer_sd3.py

Lines changed: 59 additions & 0 deletions
@@ -73,6 +73,65 @@ def prepare_init_args_and_inputs_for_common(self):
             "joint_attention_dim": 32,
             "pooled_projection_dim": 64,
             "out_channels": 4,
+            "pos_embed_max_size": 96,
+            "dual_attention_layers": (),
+            "qk_norm": None,
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
+    def test_set_attn_processor_for_determinism(self):
+        pass
+
+
+class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
+    model_class = SD3Transformer2DModel
+    main_input_name = "hidden_states"
+
+    @property
+    def dummy_input(self):
+        batch_size = 2
+        num_channels = 4
+        height = width = embedding_dim = 32
+        pooled_embedding_dim = embedding_dim * 2
+        sequence_length = 154
+
+        hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+        pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
+        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+        return {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "pooled_projections": pooled_prompt_embeds,
+            "timestep": timestep,
+        }
+
+    @property
+    def input_shape(self):
+        return (4, 32, 32)
+
+    @property
+    def output_shape(self):
+        return (4, 32, 32)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "sample_size": 32,
+            "patch_size": 1,
+            "in_channels": 4,
+            "num_layers": 2,
+            "attention_head_dim": 8,
+            "num_attention_heads": 4,
+            "caption_projection_dim": 32,
+            "joint_attention_dim": 32,
+            "pooled_projection_dim": 64,
+            "out_channels": 4,
+            "pos_embed_max_size": 96,
+            "dual_attention_layers": (0,),
+            "qk_norm": "rms_norm",
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
