@@ -482,21 +482,23 @@ def patchify_and_embed(
             ).flatten(0, 2)
             all_cap_pos_ids.append(cap_padded_pos_ids)
             # pad mask
+            cap_pad_mask = torch.cat(
+                [
+                    torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
+                    torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
+                ],
+                dim=0,
+            )
             all_cap_pad_mask.append(
-                torch.cat(
-                    [
-                        torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
-                        torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
-                    ],
-                    dim=0,
-                )
+                cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device)
             )
+
             # padded feature
             cap_padded_feat = torch.cat(
                 [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
                 dim=0,
             )
-            all_cap_feats_out.append(cap_padded_feat)
+            all_cap_feats_out.append(cap_padded_feat if cap_padding_len > 0 else cap_feat)

             ### Process Image
             C, F, H, W = image.size()
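
Both the caption path above and the image path in the following hunk apply the same pattern: the padded tensor is still built unconditionally, but the untouched original is appended whenever the padding length is zero, presumably so that zero-length padding tensors never reach downstream consumers. Below is a minimal, self-contained illustration of that pattern under this assumption; pad_tail is a hypothetical helper name, not part of the PR:

import torch

def pad_tail(feat: torch.Tensor, padding_len: int) -> torch.Tensor:
    # Repeat the last row padding_len times and append it, as the hunks above do.
    padded = torch.cat([feat, feat[-1:].repeat(padding_len, 1)], dim=0)
    # With padding_len == 0 both branches hold the same values, but the guard
    # returns the untouched input instead of the freshly concatenated copy.
    return padded if padding_len > 0 else feat

feat = torch.randn(5, 8)
assert torch.equal(pad_tail(feat, 0), feat)   # no padding: the original tensor comes back
assert pad_tail(feat, 3).shape == (8, 8)      # padded: 5 original rows + 3 repeated rows
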
@@ -515,30 +517,35 @@ def patchify_and_embed(
                 start=(cap_ori_len + cap_padding_len + 1, 0, 0),
                 device=device,
             ).flatten(0, 2)
-            image_padding_pos_ids = (
-                self.create_coordinate_grid(
-                    size=(1, 1, 1),
-                    start=(0, 0, 0),
-                    device=device,
-                )
-                .flatten(0, 2)
-                .repeat(image_padding_len, 1)
+            image_padded_pos_ids = torch.cat(
+                [
+                    image_ori_pos_ids,
+                    self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
+                    .flatten(0, 2)
+                    .repeat(image_padding_len, 1),
+                ],
+                dim=0,
             )
-            image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
-            all_image_pos_ids.append(image_padded_pos_ids)
+            all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids)
             # pad mask
+            image_pad_mask = torch.cat(
+                [
+                    torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
+                    torch.ones((image_padding_len,), dtype=torch.bool, device=device),
+                ],
+                dim=0,
+            )
             all_image_pad_mask.append(
-                torch.cat(
-                    [
-                        torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
-                        torch.ones((image_padding_len,), dtype=torch.bool, device=device),
-                    ],
-                    dim=0,
-                )
+                image_pad_mask
+                if image_padding_len > 0
+                else torch.zeros((image_ori_len,), dtype=torch.bool, device=device)
             )
             # padded feature
-            image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
-            all_image_out.append(image_padded_feat)
+            image_padded_feat = torch.cat(
+                [image, image[-1:].repeat(image_padding_len, 1)],
+                dim=0,
+            )
+            all_image_out.append(image_padded_feat if image_padding_len > 0 else image)

         return (
             all_image_out,
@@ -588,10 +595,13 @@ def forward(
         adaln_input = t.type_as(x)
         x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
         x = list(x.split(x_item_seqlens, dim=0))
-        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
+        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))

         x = pad_sequence(x, batch_first=True, padding_value=0.0)
         x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
+        # Make the length match x explicit to satisfy Dynamo's symbolic shape inference and avoid compilation errors
+        x_freqs_cis = x_freqs_cis[:, : x.shape[1]]
+
         x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
         for i, seq_len in enumerate(x_item_seqlens):
             x_attn_mask[i, :seq_len] = 1
@@ -605,17 +615,21 @@ def forward(

         # cap embed & refine
         cap_item_seqlens = [len(_) for _ in cap_feats]
-        assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
         cap_max_item_seqlen = max(cap_item_seqlens)

         cap_feats = torch.cat(cap_feats, dim=0)
         cap_feats = self.cap_embedder(cap_feats)
         cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
         cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
-        cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
+        cap_freqs_cis = list(
+            self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
+        )

         cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
         cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
+        # Make the length match cap_feats explicit to satisfy Dynamo's symbolic shape inference and avoid compilation errors
+        cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]
+
         cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
         for i, seq_len in enumerate(cap_item_seqlens):
             cap_attn_mask[i, :seq_len] = 1
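
The two forward() hunks split the RoPE frequencies by the per-item pos-id lengths instead of the feature lengths, then truncate them to the features' padded length after pad_sequence. Below is a minimal sketch of why that slice is needed, assuming (as the split change suggests) that a pos-id tensor can be longer than its feature tensor; the shapes are made up purely for illustration:

import torch
from torch.nn.utils.rnn import pad_sequence

feats = [torch.randn(5, 8), torch.randn(3, 8)]   # per-item feature lengths: 5 and 3
freqs = [torch.randn(6, 8), torch.randn(3, 8)]   # per-item rope lengths: 6 and 3 (first is longer)

x = pad_sequence(feats, batch_first=True)        # shape (2, 5, 8)
f = pad_sequence(freqs, batch_first=True)        # shape (2, 6, 8): one step longer than x

f = f[:, : x.shape[1]]                           # shape (2, 5, 8), aligned with x again
assert f.shape[1] == x.shape[1]

Per the comments in the diff, making this length match explicit is also what keeps Dynamo's symbolic shape inference from raising compilation errors when the model is compiled.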