
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-class WanAttnProcessor2_0:
+def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+    # encoder_hidden_states is only passed for cross-attention
+    if encoder_hidden_states is None:
+        encoder_hidden_states = hidden_states
+
+    if attn.fused_projections:
+        if attn.cross_attention_dim_head is None:
+            # In self-attention layers, we can fuse the entire QKV projection into a single linear
+            query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+        else:
+            # In cross-attention layers, we can only fuse the KV projections into a single linear
+            query = attn.to_q(hidden_states)
+            key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+    else:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+    return query, key, value
+
+
+def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+    if attn.fused_projections:
+        key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+    else:
+        key_img = attn.add_k_proj(encoder_hidden_states_img)
+        value_img = attn.add_v_proj(encoder_hidden_states_img)
+    return key_img, value_img
+
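+# Why chunking works (illustrative sketch only, not executed here): `fuse_projections` below
+# concatenates the per-projection weights along the output dimension, so, approximately,
+#     w_kv = torch.cat([attn.to_k.weight, attn.to_v.weight], dim=0)
+#     b_kv = torch.cat([attn.to_k.bias, attn.to_v.bias], dim=0)
+#     key, value = F.linear(x, w_kv, b_kv).chunk(2, dim=-1)
+# recovers the same tensors as attn.to_k(x) and attn.to_v(x) up to numerical precision.
+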
+
+class WanAttnProcessor:
+    _attention_backend = None
+
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+            raise ImportError(
+                "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+            )

    def __call__(
        self,
-        attn: Attention,
+        attn: "WanAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
-        rotary_emb: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            # 512 is the context length of the text encoder, hardcoded for now
            image_context_length = encoder_hidden_states.shape[1] - 512
            encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
            encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states

-        query = attn.to_q(hidden_states)
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)

-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)

-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))

        if rotary_emb is not None:

-            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
-                dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
-                x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
-                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
-                return x_out.type_as(hidden_states)
+            def apply_rotary_emb(
+                hidden_states: torch.Tensor,
+                freqs_cos: torch.Tensor,
+                freqs_sin: torch.Tensor,
+            ):
+                x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+                cos = freqs_cos[..., 0::2]
+                sin = freqs_sin[..., 1::2]
+                out = torch.empty_like(hidden_states)
+                out[..., 0::2] = x1 * cos - x2 * sin
+                out[..., 1::2] = x1 * sin + x2 * cos
+                return out.type_as(hidden_states)
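+            # The helper above applies RoPE in real arithmetic: for each even/odd pair
+            # (x1, x2) it computes (x1 * cos - x2 * sin, x1 * sin + x2 * cos), the same
+            # rotation the previous complex-valued implementation expressed as
+            # `view_as_complex(x) * freqs`, but without the float64/complex round-trip.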

-            query = apply_rotary_emb(query, rotary_emb)
-            key = apply_rotary_emb(key, rotary_emb)
+            query = apply_rotary_emb(query, *rotary_emb)
+            key = apply_rotary_emb(key, *rotary_emb)

        # I2V task
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
-            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)
-            value_img = attn.add_v_proj(encoder_hidden_states_img)
-
-            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)

-            hidden_states_img = F.scaled_dot_product_attention(
-                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+            key_img = key_img.unflatten(2, (attn.heads, -1))
+            value_img = value_img.unflatten(2, (attn.heads, -1))
+
+            hidden_states_img = dispatch_attention_fn(
+                query,
+                key_img,
+                value_img,
+                attn_mask=None,
+                dropout_p=0.0,
+                is_causal=False,
+                backend=self._attention_backend,
            )
-            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+            hidden_states_img = hidden_states_img.flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)

-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
        )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        if hidden_states_img is not None:
@@ -110,6 +156,119 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
        return hidden_states


+class WanAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
+            "Please use WanAttnProcessor instead."
+        )
+        deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+        return WanAttnProcessor(*args, **kwargs)
+
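+# Note: constructing the deprecated WanAttnProcessor2_0 returns a plain WanAttnProcessor
+# instance (only a deprecation warning is emitted), so existing code that assigns this
+# processor keeps working unchanged.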
+
+class WanAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = WanAttnProcessor
+    _available_processors = [WanAttnProcessor]
+
+    def __init__(
+        self,
+        dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        eps: float = 1e-5,
+        dropout: float = 0.0,
+        added_kv_proj_dim: Optional[int] = None,
+        cross_attention_dim_head: Optional[int] = None,
+        processor=None,
+    ):
+        super().__init__()
+
+        self.inner_dim = dim_head * heads
+        self.heads = heads
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.cross_attention_dim_head = cross_attention_dim_head
+        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_out = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(self.inner_dim, dim, bias=True),
+                torch.nn.Dropout(dropout),
+            ]
+        )
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+        self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+        self.add_k_proj = self.add_v_proj = None
+        if added_kv_proj_dim is not None:
+            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+        self.set_processor(processor)
+
+    def fuse_projections(self):
+        if getattr(self, "fused_projections", False):
+            return
+
+        if self.cross_attention_dim_head is None:
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
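+            # Building the fused linear on the meta device avoids allocating and initializing
+            # a throwaway weight; load_state_dict(..., assign=True) then attaches the
+            # concatenated tensors directly as the module's parameters.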
+            with torch.device("meta"):
+                self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+            self.to_qkv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+        else:
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        if self.added_kv_proj_dim is not None:
+            concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+            concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_added_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        self.fused_projections = True
+
+    @torch.no_grad()
+    def unfuse_projections(self):
+        if not getattr(self, "fused_projections", False):
+            return
+
+        if hasattr(self, "to_qkv"):
+            delattr(self, "to_qkv")
+        if hasattr(self, "to_kv"):
+            delattr(self, "to_kv")
+        if hasattr(self, "to_added_kv"):
+            delattr(self, "to_added_kv")
+
+        self.fused_projections = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
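+# Shape sketch for WanAttention (illustrative comment only; the sizes are arbitrary, not a
+# required configuration):
+#     attn = WanAttention(dim=1536, heads=12, dim_head=128, processor=WanAttnProcessor())
+#     x = torch.randn(1, 1024, 1536)           # (batch, sequence, dim)
+#     out = attn(x)                            # self-attention -> (1, 1024, 1536)
+#     attn.fuse_projections()                  # folds to_q/to_k/to_v into a single to_qkv linear
+#     attn(x)                                  # same result up to numerical precision
+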
+
class WanImageEmbedding(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
        super().__init__()
@@ -217,11 +376,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            dim=1,
        )

-        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
-        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
-        return freqs
+        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
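+        # Returned shapes are (1, num_patches, 1, attention_head_dim), matching the
+        # (batch, seq, heads, head_dim) layout the processor keeps after unflatten
+        # (queries/keys are no longer transposed to (batch, heads, seq, head_dim)).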
+        return freqs_cos, freqs_sin


class WanTransformerBlock(nn.Module):
@@ -239,33 +408,24 @@ def __init__(

        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
-        self.attn1 = Attention(
-            query_dim=dim,
+        self.attn1 = WanAttention(
+            dim=dim,
            heads=num_heads,
-            kv_heads=num_heads,
            dim_head=dim // num_heads,
-            qk_norm=qk_norm,
            eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=None,
+            processor=WanAttnProcessor(),
        )

        # 2. Cross-attention
-        self.attn2 = Attention(
-            query_dim=dim,
+        self.attn2 = WanAttention(
+            dim=dim,
            heads=num_heads,
-            kv_heads=num_heads,
            dim_head=dim // num_heads,
-            qk_norm=qk_norm,
            eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
            added_kv_proj_dim=added_kv_proj_dim,
-            added_proj_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=dim // num_heads,
+            processor=WanAttnProcessor(),
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

@@ -302,12 +462,12 @@ def forward(

        # 1. Self-attention
        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
-        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
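+        # Positional arguments follow WanAttention.forward:
+        # (hidden_states, encoder_hidden_states, attention_mask, rotary_emb); the same
+        # order is used for the cross-attention call below.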
+        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

        # 2. Cross-attention
        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
-        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
        hidden_states = hidden_states + attn_output

        # 3. Feed-forward
@@ -320,7 +480,9 @@ def forward(
        return hidden_states


-class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class WanTransformer3DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
    r"""
    A Transformer model for video-like data used in the Wan model.
