@@ -1075,7 +1075,7 @@ def forward(
         return recon_faces, total_loss, loss_breakdown

 @save_load(version = __version__)
-class MeshTransformer(Module,PyTorchModelHubMixin):
+class MeshTransformer(Module, PyTorchModelHubMixin):
     @typecheck
     def __init__(
         self,
@@ -1094,12 +1094,13 @@ def __init__(
         cross_attn_num_mem_kv = 4, # needed for preventing nan when dropping out text condition
         dropout = 0.,
         coarse_pre_gateloop_depth = 2,
+        coarse_adaptive_rmsnorm = False,
         fine_pre_gateloop_depth = 2,
         gateloop_use_heinsen = False,
         fine_attn_depth = 2,
         fine_attn_dim_head = 32,
         fine_attn_heads = 8,
-        fine_cross_attend_text = False,
+        fine_cross_attend_text = False, # additional conditioning - fine transformer cross attention to text tokens
         pad_id = -1,
         num_sos_tokens = None,
         condition_on_text = False,
@@ -1177,6 +1178,8 @@ def __init__(
         # main autoregressive attention network
         # attending to a face token

+        self.coarse_adaptive_rmsnorm = coarse_adaptive_rmsnorm
+
         self.decoder = Decoder(
             dim = dim,
             depth = attn_depth,
@@ -1185,6 +1188,8 @@ def __init__(
             attn_flash = flash_attn,
             attn_dropout = dropout,
             ff_dropout = dropout,
+            use_adaptive_rmsnorm = coarse_adaptive_rmsnorm,
+            dim_condition = dim_text,
             cross_attend = condition_on_text,
             cross_attn_dim_context = cross_attn_dim_context,
             cross_attn_num_mem_kv = cross_attn_num_mem_kv,
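
The two kwargs added to the `Decoder` construction above come from x-transformers: `use_adaptive_rmsnorm` makes each block's RMSNorm modulated by an external condition vector of dimension `dim_condition`, here the pooled text embedding. A minimal standalone sketch of that mechanism is below; the constructor kwarg names are taken from this diff, while the tensor shapes and the `condition` forward argument are assumptions about the x-transformers version in use.

```python
import torch
from x_transformers import Decoder

dim, dim_text = 512, 768

# sketch only: mirrors the Decoder construction in the hunk above
decoder = Decoder(
    dim = dim,
    depth = 2,
    use_adaptive_rmsnorm = True,   # norms are modulated by a learned projection of the condition
    dim_condition = dim_text,      # dimension of the pooled conditioning vector
    cross_attend = True,
    cross_attn_dim_context = dim_text
)

face_tokens = torch.randn(2, 16, dim)
text_tokens = torch.randn(2, 32, dim_text)   # per-token text embeddings for cross attention
pooled_text = torch.randn(2, dim_text)       # pooled text embedding for the adaptive norms

out = decoder(
    face_tokens,
    context = text_tokens,
    condition = pooled_text        # same kwarg the commit forwards via attn_context_kwargs
)
```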
@@ -1458,6 +1463,11 @@ def forward_on_codes(
                 context_mask = text_mask
             )

+            if self.coarse_adaptive_rmsnorm:
+                attn_context_kwargs.update(
+                    condition = pooled_text_embed
+                )
+
         # take care of codes that may be flattened

         if codes.ndim > 2:
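
For orientation, a hypothetical usage sketch of the new flag follows. Only `condition_on_text` and `coarse_adaptive_rmsnorm` come from this commit; the autoencoder setup, `dim`, `max_seq_len`, the placeholder mesh tensors, and the `texts` kwarg are assumed from the repository README.

```python
import torch
from meshgpt_pytorch import MeshAutoencoder, MeshTransformer

# assumed setup, following the repository README (not part of this commit)
autoencoder = MeshAutoencoder(num_discrete_coors = 128)

transformer = MeshTransformer(
    autoencoder,
    dim = 512,
    max_seq_len = 768,
    condition_on_text = True,         # required so a pooled text embedding is available as the norm condition
    coarse_adaptive_rmsnorm = True    # new flag introduced by this commit
)

vertices = torch.randn(1, 121, 3)              # placeholder mesh: (batch, num vertices, xyz)
faces    = torch.randint(0, 121, (1, 64, 3))   # placeholder faces: (batch, num faces, 3 vertex indices)

loss = transformer(
    vertices = vertices,
    faces = faces,
    texts = ['a chair']
)
loss.backward()
```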