
Commit c774f45

resolve conflicts with huggingface#10981
1 parent 98a2417 commit c774f45

1 file changed (+52, -26 lines)

src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 52 additions & 26 deletions
@@ -12,21 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...models.attention import FeedForward
-from ...models.attention_processor import Attention
-from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous
-from ...utils import logging
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import FeedForward
+from ..attention_processor import Attention
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
-from ...loaders import PeftAdapterMixin
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous
 from ..cache_utils import CacheMixin
 
 
@@ -124,7 +124,7 @@ def __call__(
         attn: Attention,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
@@ -157,7 +157,7 @@ def __call__(
                 key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
             )
 
-        # 4. Attention and Attention Mask
+        # 4. Attention
         if attention_mask is not None:
             text_attention_mask = attention_mask.float().to(query.device)
             actual_text_seq_length = text_attention_mask.size(1)
@@ -167,7 +167,9 @@ def __call__(
             attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
             attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
 
-        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
 
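The two hunks above only reformat how the padded text mask reaches `F.scaled_dot_product_attention`: the per-token text mask is turned into a full [batch, 1, seq, seq] mask via an outer product. A minimal standalone sketch of that pattern follows; it is not the diffusers code itself, the extension of the mask over the image tokens happens in unchanged lines and is assumed here, and a boolean mask is used for clarity where the hunk casts to `query.dtype`.

```python
import torch
import torch.nn.functional as F

batch, text_len, image_len, heads, head_dim = 2, 8, 16, 4, 32
total_len = text_len + image_len

text_mask = torch.ones(batch, text_len)
text_mask[1, 5:] = 0  # second sample only has 5 valid text tokens

# Assumed step: extend the per-token text mask with ones over the image tokens.
joint_mask = torch.cat([text_mask, torch.ones(batch, image_len)], dim=1).unsqueeze(2)

# Outer product: position (i, j) is attendable only if both tokens i and j are valid.
mask_matrix = joint_mask @ joint_mask.transpose(1, 2)  # [batch, total_len, total_len]
attn_mask = (mask_matrix > 0).unsqueeze(1)             # [batch, 1, total_len, total_len]

query = key = value = torch.randn(batch, heads, total_len, head_dim)
# Note: query rows belonging to fully padded tokens are entirely masked and come
# out as NaN; such positions are normally discarded downstream.
out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 4, 24, 32])
```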

@@ -246,8 +248,8 @@ def forward(
             1 + c_scale_mlp.unsqueeze(1)
         ) + c_shift_mlp.unsqueeze(1)
 
-        ff_output = self.ff(norm_hidden_states, **kwargs)
-        ff_output_context = self.ff(norm_encoder_hidden_states, **kwargs)
+        ff_output = self.ff(norm_hidden_states)
+        ff_output_context = self.ff(norm_encoder_hidden_states)
         hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
         encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
 
@@ -258,30 +260,34 @@ class CogView4RotaryPosEmbed(nn.Module):
     def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None:
         super().__init__()
 
+        self.dim = dim
         self.patch_size = patch_size
         self.rope_axes_dim = rope_axes_dim
-
-        dim_h, dim_w = dim // 2, dim // 2
-        h_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h))
-        w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w))
-        h_seq = torch.arange(self.rope_axes_dim[0])
-        w_seq = torch.arange(self.rope_axes_dim[1])
-        self.freqs_h = torch.outer(h_seq, h_inv_freq)
-        self.freqs_w = torch.outer(w_seq, w_inv_freq)
+        self.theta = theta
 
     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, num_channels, height, width = hidden_states.shape
         height, width = height // self.patch_size, width // self.patch_size
 
-        h_idx = torch.arange(height)
-        w_idx = torch.arange(width)
+        dim_h, dim_w = self.dim // 2, self.dim // 2
+        h_inv_freq = 1.0 / (
+            self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
+        )
+        w_inv_freq = 1.0 / (
+            self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
+        )
+        h_seq = torch.arange(self.rope_axes_dim[0])
+        w_seq = torch.arange(self.rope_axes_dim[1])
+        freqs_h = torch.outer(h_seq, h_inv_freq)
+        freqs_w = torch.outer(w_seq, w_inv_freq)
+
+        h_idx = torch.arange(height, device=freqs_h.device)
+        w_idx = torch.arange(width, device=freqs_w.device)
         inner_h_idx = h_idx * self.rope_axes_dim[0] // height
         inner_w_idx = w_idx * self.rope_axes_dim[1] // width
 
-        self.freqs_h = self.freqs_h.to(hidden_states.device)
-        self.freqs_w = self.freqs_w.to(hidden_states.device)
-        freqs_h = self.freqs_h[inner_h_idx]
-        freqs_w = self.freqs_w[inner_w_idx]
+        freqs_h = freqs_h[inner_h_idx]
+        freqs_w = freqs_w[inner_w_idx]
 
         # Create position matrices for height and width
         # [height, 1, dim//4] and [1, width, dim//4]
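This hunk moves the rotary-frequency tables out of `__init__`: instead of caching `self.freqs_h`/`self.freqs_w` as module attributes and copying them to the right device on every call, the tables are rebuilt inside `forward` from the stored `dim`, `rope_axes_dim`, and `theta`, which appears to be the motivation for the change. A rough standalone sketch of what gets computed; the helper name and the example sizes are illustrative, not part of the library.

```python
import torch

def cogview4_rope_freqs(dim: int, rope_axes_dim, height: int, width: int, theta: float = 10000.0):
    # Per-axis inverse frequencies, half of `dim` for each spatial axis.
    dim_h, dim_w = dim // 2, dim // 2
    h_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: dim_h // 2] / dim_h))
    w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: dim_w // 2] / dim_w))

    # Full tables over the configured axes, then index them with the positions of
    # the current latent grid, so any resolution maps onto the same tables.
    freqs_h = torch.outer(torch.arange(rope_axes_dim[0]), h_inv_freq)
    freqs_w = torch.outer(torch.arange(rope_axes_dim[1]), w_inv_freq)
    inner_h_idx = torch.arange(height) * rope_axes_dim[0] // height
    inner_w_idx = torch.arange(width) * rope_axes_dim[1] // width
    return freqs_h[inner_h_idx], freqs_w[inner_w_idx]

# e.g. a 32x32 latent patch grid with dim=128 and axes (256, 256)
fh, fw = cogview4_rope_freqs(128, (256, 256), 32, 32)
print(fh.shape, fw.shape)  # torch.Size([32, 32]) torch.Size([32, 32])
```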
@@ -393,10 +399,26 @@ def forward(
         original_size: torch.Tensor,
         target_size: torch.Tensor,
         crop_coords: torch.Tensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
         batch_size, num_channels, height, width = hidden_states.shape
 
         # 1. RoPE
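The new `attention_kwargs` argument carries an optional LoRA `scale`: it is popped off a copy of the dict (so the caller's dict is not mutated), defaults to 1.0, and is applied through `scale_lora_layers` when the PEFT backend is available; the non-PEFT branch only warns that `scale` is ineffective. A minimal, self-contained sketch of that pop-with-default pattern; the helper is hypothetical, not a diffusers API.

```python
from typing import Any, Dict, Optional, Tuple

def pop_lora_scale(attention_kwargs: Optional[Dict[str, Any]]) -> Tuple[Optional[Dict[str, Any]], float]:
    # Hypothetical helper mirroring the hunk above: work on a copy so the
    # caller's dict is not mutated, and fall back to 1.0 when no scale is given.
    if attention_kwargs is None:
        return None, 1.0
    attention_kwargs = attention_kwargs.copy()
    lora_scale = attention_kwargs.pop("scale", 1.0)
    return attention_kwargs, lora_scale

print(pop_lora_scale({"scale": 0.7}))  # ({}, 0.7)
print(pop_lora_scale(None))            # (None, 1.0)
```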
@@ -431,6 +453,10 @@ def forward(
         hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
         output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
 
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
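Together with the `scale_lora_layers` call added earlier in `forward`, this hunk brackets the computation: the LoRA layers are scaled on entry and unscaled again before returning, so the requested `lora_scale` only applies for the duration of the call. A rough sketch of that bracketing pattern, assuming only the helpers the new imports at the top of the file pull in from `diffusers.utils`; the wrapper itself is hypothetical.

```python
import torch
from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers

def run_with_lora_scale(model: torch.nn.Module, lora_scale: float, *args, **kwargs):
    # Apply the scale only while this call runs, then remove it again.
    if USE_PEFT_BACKEND:
        scale_lora_layers(model, lora_scale)
    try:
        return model(*args, **kwargs)
    finally:
        if USE_PEFT_BACKEND:
            unscale_lora_layers(model, lora_scale)
```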
