 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...models.attention import FeedForward
-from ...models.attention_processor import Attention
-from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous
-from ...utils import logging
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import FeedForward
+from ..attention_processor import Attention
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
-from ...loaders import PeftAdapterMixin
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous
 from ..cache_utils import CacheMixin
 
 
@@ -124,7 +124,7 @@ def __call__(
         attn: Attention,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
@@ -157,7 +157,7 @@ def __call__(
                 key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
             )
 
-        # 4. Attention and Attention Mask
+        # 4. Attention
         if attention_mask is not None:
             text_attention_mask = attention_mask.float().to(query.device)
             actual_text_seq_length = text_attention_mask.size(1)
@@ -167,7 +167,9 @@ def __call__(
             attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
             attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
 
-        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
 
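
The mask construction in the hunk above turns a per-token text padding mask into a pairwise attention mask via an outer product, so a query/key pair is kept only when both positions are real tokens (image tokens always count as real). A minimal standalone sketch of that idea; the batch size, token counts, and the way padding is extended over image tokens are illustrative assumptions, since the lines that build `new_attention_mask` are not part of this hunk:

import torch

# assumed toy sizes: 2 prompts, 5 text slots (some padded), 3 image tokens
text_pad = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.float32)
image_ones = torch.ones(2, 3)                            # image tokens are never padded
full_mask = torch.cat([text_pad, image_ones], dim=1)     # (batch, text + image)

# outer product per sample: entry (i, j) > 0 only when tokens i and j are both real
pair = full_mask.unsqueeze(2) @ full_mask.unsqueeze(1)   # (batch, seq, seq)
attn_mask = (pair > 0).unsqueeze(1)                      # (batch, 1, seq, seq), broadcast over heads in SDPA
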
@@ -246,8 +248,8 @@ def forward(
             1 + c_scale_mlp.unsqueeze(1)
         ) + c_shift_mlp.unsqueeze(1)
 
-        ff_output = self.ff(norm_hidden_states, **kwargs)
-        ff_output_context = self.ff(norm_encoder_hidden_states, **kwargs)
+        ff_output = self.ff(norm_hidden_states)
+        ff_output_context = self.ff(norm_encoder_hidden_states)
         hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
         encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
 
@@ -258,30 +260,34 @@ class CogView4RotaryPosEmbed(nn.Module):
     def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None:
         super().__init__()
 
+        self.dim = dim
         self.patch_size = patch_size
         self.rope_axes_dim = rope_axes_dim
-
-        dim_h, dim_w = dim // 2, dim // 2
-        h_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h))
-        w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w))
-        h_seq = torch.arange(self.rope_axes_dim[0])
-        w_seq = torch.arange(self.rope_axes_dim[1])
-        self.freqs_h = torch.outer(h_seq, h_inv_freq)
-        self.freqs_w = torch.outer(w_seq, w_inv_freq)
+        self.theta = theta
 
     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, num_channels, height, width = hidden_states.shape
         height, width = height // self.patch_size, width // self.patch_size
 
-        h_idx = torch.arange(height)
-        w_idx = torch.arange(width)
+        dim_h, dim_w = self.dim // 2, self.dim // 2
+        h_inv_freq = 1.0 / (
+            self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
+        )
+        w_inv_freq = 1.0 / (
+            self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
+        )
+        h_seq = torch.arange(self.rope_axes_dim[0])
+        w_seq = torch.arange(self.rope_axes_dim[1])
+        freqs_h = torch.outer(h_seq, h_inv_freq)
+        freqs_w = torch.outer(w_seq, w_inv_freq)
+
+        h_idx = torch.arange(height, device=freqs_h.device)
+        w_idx = torch.arange(width, device=freqs_w.device)
         inner_h_idx = h_idx * self.rope_axes_dim[0] // height
         inner_w_idx = w_idx * self.rope_axes_dim[1] // width
 
-        self.freqs_h = self.freqs_h.to(hidden_states.device)
-        self.freqs_w = self.freqs_w.to(hidden_states.device)
-        freqs_h = self.freqs_h[inner_h_idx]
-        freqs_w = self.freqs_w[inner_w_idx]
+        freqs_h = freqs_h[inner_h_idx]
+        freqs_w = freqs_w[inner_w_idx]
 
         # Create position matrices for height and width
         # [height, 1, dim//4] and [1, width, dim//4]
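
In the rewritten `forward` above, the rotary frequency tables are rebuilt on each call and the per-image grid indices are rescaled into the fixed `rope_axes_dim` range before indexing them, so images of different resolutions share one table layout. A small numeric sketch of the height axis only, with illustrative values for `rope_axes_dim[0]`, the patched height, and the per-axis rotary dim:

import torch

rope_axis, height = 256, 64          # assumed table size and patched image height
theta, dim_h = 10000.0, 64           # assumed rope base and per-axis rotary dim

h_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32) / dim_h))
freqs_h = torch.outer(torch.arange(rope_axis, dtype=torch.float32), h_inv_freq)  # (256, dim_h // 2)

h_idx = torch.arange(height)
inner_h_idx = h_idx * rope_axis // height   # 0, 4, 8, ..., 252: spread evenly over the table rows
freqs_for_image = freqs_h[inner_h_idx]      # (64, dim_h // 2) rows actually used for this image
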
@@ -393,10 +399,26 @@ def forward(
         original_size: torch.Tensor,
         target_size: torch.Tensor,
         crop_coords: torch.Tensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
         batch_size, num_channels, height, width = hidden_states.shape
 
         # 1. RoPE
@@ -431,6 +453,10 @@ def forward(
         hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
         output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
 
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
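
Taken together, these changes add PEFT/LoRA support to the transformer: `forward` now accepts `attention_kwargs`, pops a `scale` entry from it, and brackets the computation with `scale_lora_layers` / `unscale_lora_layers` when the PEFT backend is active. A minimal usage sketch from the pipeline side, assuming the CogView4 pipeline forwards `attention_kwargs` to this model and that LoRA loading is wired up for it (the LoRA checkpoint id is a placeholder):

import torch
from diffusers import CogView4Pipeline

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("your-org/your-cogview4-lora")  # hypothetical LoRA checkpoint

# "scale" is popped from attention_kwargs inside the transformer's forward and, under the
# PEFT backend, applied to every LoRA layer for the duration of the call
image = pipe("a lighthouse at dusk, watercolor", attention_kwargs={"scale": 0.8}).images[0]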