@@ -564,6 +564,31 @@ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
564564 return cos , sin
565565
566566
567+ def get_3d_rotary_pos_embed_allegro (
568+ embed_dim , crops_coords , grid_size , temporal_size , interpolation_scale : Tuple [float , float , float ] = (1.0 , 1.0 , 1.0 ), theta : int = 10000
569+ ) -> Union [torch .Tensor , Tuple [torch .Tensor , torch .Tensor ]]:
570+ # TODO(aryan): docs
571+ start , stop = crops_coords
572+ grid_size_h , grid_size_w = grid_size
573+ interpolation_scale_t , interpolation_scale_h , interpolation_scale_w = interpolation_scale
574+ grid_t = np .linspace (0 , temporal_size , temporal_size , endpoint = False , dtype = np .float32 )
575+ grid_h = np .linspace (start [0 ], stop [0 ], grid_size_h , endpoint = False , dtype = np .float32 )
576+ grid_w = np .linspace (start [1 ], stop [1 ], grid_size_w , endpoint = False , dtype = np .float32 )
577+
578+ # Compute dimensions for each axis
579+ dim_t = embed_dim // 3
580+ dim_h = embed_dim // 3
581+ dim_w = embed_dim // 3
582+
583+ # Temporal frequencies
584+ freqs_t = get_1d_rotary_pos_embed (dim_t , grid_t / interpolation_scale_t , theta = theta , use_real = True , repeat_interleave_real = False )
585+ # Spatial frequencies for height and width
586+ freqs_h = get_1d_rotary_pos_embed (dim_h , grid_h / interpolation_scale_h , theta = theta , use_real = True , repeat_interleave_real = False )
587+ freqs_w = get_1d_rotary_pos_embed (dim_w , grid_w / interpolation_scale_w , theta = theta , use_real = True , repeat_interleave_real = False )
588+
589+ return freqs_t , freqs_h , freqs_w , grid_t , grid_h , grid_w
590+
591+
567592def get_2d_rotary_pos_embed (embed_dim , crops_coords , grid_size , use_real = True ):
568593 """
569594 RoPE for image tokens with 2d structure.
@@ -684,7 +709,7 @@ def get_1d_rotary_pos_embed(
684709 freqs_sin = freqs .sin ().repeat_interleave (2 , dim = 1 ).float () # [S, D]
685710 return freqs_cos , freqs_sin
686711 elif use_real :
687- # stable audio
712+ # stable audio, allegro
688713 freqs_cos = torch .cat ([freqs .cos (), freqs .cos ()], dim = - 1 ).float () # [S, D]
689714 freqs_sin = torch .cat ([freqs .sin (), freqs .sin ()], dim = - 1 ).float () # [S, D]
690715 return freqs_cos , freqs_sin
@@ -743,6 +768,24 @@ def apply_rotary_emb(
743768 return x_out .type_as (x )
744769
745770
771+ def apply_rotary_emb_allegro (x : torch .Tensor , freqs_cis , positions ):
772+ # TODO(aryan): rewrite
773+ def apply_1d_rope (tokens , pos , cos , sin ):
774+ cos = F .embedding (pos , cos )[:, None , :, :]
775+ sin = F .embedding (pos , sin )[:, None , :, :]
776+ x1 , x2 = tokens [..., : tokens .shape [- 1 ] // 2 ], tokens [..., tokens .shape [- 1 ] // 2 :]
777+ tokens_rotated = torch .cat ((- x2 , x1 ), dim = - 1 )
778+ return (tokens .float () * cos + tokens_rotated .float () * sin ).to (tokens .dtype )
779+
780+ (t_cos , t_sin ), (h_cos , h_sin ), (w_cos , w_sin ) = freqs_cis
781+ t , h , w = x .chunk (3 , dim = - 1 )
782+ t = apply_1d_rope (t , positions [0 ], t_cos , t_sin )
783+ h = apply_1d_rope (h , positions [1 ], h_cos , h_sin )
784+ w = apply_1d_rope (w , positions [2 ], w_cos , w_sin )
785+ x = torch .cat ([t , h , w ], dim = - 1 )
786+ return x
787+
788+
746789class FluxPosEmbed (nn .Module ):
747790 # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
748791 def __init__ (self , theta : int , axes_dim : List [int ]):
0 commit comments