@@ -45,6 +45,7 @@ def __init__(
4545 n_dim : int = 128 ,
4646 e_dim : int = 16 ,
4747 a_dim : int = 64 ,
48+ a_compress_rate : int = 0 ,
4849 axis_neuron : int = 4 ,
4950 update_angle : bool = True , # angle
5051 activation_function : str = "silu" ,
@@ -70,6 +71,12 @@ def __init__(
7071 self .n_dim = n_dim
7172 self .e_dim = e_dim
7273 self .a_dim = a_dim
74+ self .a_compress_rate = a_compress_rate
75+ if a_compress_rate != 0 :
76+ assert a_dim % (2 * a_compress_rate ) == 0 , (
77+ f"For a_compress_rate of { a_compress_rate } , a_dim must be divisible by { 2 * a_compress_rate } . "
78+ f"Currently, a_dim={ a_dim } is not valid."
79+ )
7380 self .axis_neuron = axis_neuron
7481 self .update_angle = update_angle
7582 self .activation_function = activation_function
@@ -167,20 +174,42 @@ def __init__(
167174 )
168175
169176 if self .update_angle :
170- self .angle_dim = self .a_dim + self .n_dim + 2 * self .e_dim
177+ self .angle_dim = self .a_dim
178+ if self .a_compress_rate == 0 :
179+ # angle + node + edge * 2
180+ self .angle_dim += self .n_dim + 2 * self .e_dim
181+ self .a_compress_n_linear = None
182+ self .a_compress_e_linear = None
183+ else :
184+ # angle + node/c + edge/2c * 2
185+ self .angle_dim += 2 * (self .a_dim // self .a_compress_rate )
186+ self .a_compress_n_linear = MLPLayer (
187+ self .n_dim ,
188+ self .a_dim // self .a_compress_rate ,
189+ precision = precision ,
190+ bias = False ,
191+ seed = child_seed (seed , 8 ),
192+ )
193+ self .a_compress_e_linear = MLPLayer (
194+ self .e_dim ,
195+ self .a_dim // (2 * self .a_compress_rate ),
196+ precision = precision ,
197+ bias = False ,
198+ seed = child_seed (seed , 9 ),
199+ )
171200
172201 # edge angle message
173202 self .edge_angle_linear1 = MLPLayer (
174203 self .angle_dim ,
175204 self .e_dim ,
176205 precision = precision ,
177- seed = child_seed (seed , 8 ),
206+ seed = child_seed (seed , 10 ),
178207 )
179208 self .edge_angle_linear2 = MLPLayer (
180209 self .e_dim ,
181210 self .e_dim ,
182211 precision = precision ,
183- seed = child_seed (seed , 9 ),
212+ seed = child_seed (seed , 11 ),
184213 )
185214 if self .update_style == "res_residual" :
186215 self .e_residual .append (
@@ -189,7 +218,7 @@ def __init__(
189218 self .update_residual ,
190219 self .update_residual_init ,
191220 precision = precision ,
192- seed = child_seed (seed , 10 ),
221+ seed = child_seed (seed , 12 ),
193222 )
194223 )
195224
@@ -198,7 +227,7 @@ def __init__(
198227 self .angle_dim ,
199228 self .a_dim ,
200229 precision = precision ,
201- seed = child_seed (seed , 11 ),
230+ seed = child_seed (seed , 13 ),
202231 )
203232 if self .update_style == "res_residual" :
204233 self .a_residual .append (
@@ -207,13 +236,15 @@ def __init__(
207236 self .update_residual ,
208237 self .update_residual_init ,
209238 precision = precision ,
210- seed = child_seed (seed , 12 ),
239+ seed = child_seed (seed , 14 ),
211240 )
212241 )
213242 else :
214243 self .angle_self_linear = None
215244 self .edge_angle_linear1 = None
216245 self .edge_angle_linear2 = None
246+ self .a_compress_n_linear = None
247+ self .a_compress_e_linear = None
217248 self .angle_dim = 0
218249
219250 self .n_residual = nn .ParameterList (self .n_residual )
@@ -448,12 +479,22 @@ def forward(
448479 assert self .edge_angle_linear1 is not None
449480 assert self .edge_angle_linear2 is not None
450481 # get angle info
482+ if self .a_compress_rate != 0 :
483+ assert self .a_compress_n_linear is not None
484+ assert self .a_compress_e_linear is not None
485+ node_ebd_for_angle = self .a_compress_n_linear (node_ebd )
486+ edge_ebd_for_angle = self .a_compress_e_linear (edge_ebd )
487+ else :
488+ node_ebd_for_angle = node_ebd
489+ edge_ebd_for_angle = edge_ebd
490+
451491 # nb x nloc x a_nnei x a_nnei x n_dim
452492 node_for_angle_info = torch .tile (
453- node_ebd .unsqueeze (2 ).unsqueeze (2 ), (1 , 1 , self .a_sel , self .a_sel , 1 )
493+ node_ebd_for_angle .unsqueeze (2 ).unsqueeze (2 ),
494+ (1 , 1 , self .a_sel , self .a_sel , 1 ),
454495 )
455496 # nb x nloc x a_nnei x e_dim
456- edge_for_angle = edge_ebd [:, :, : self .a_sel , :]
497+ edge_for_angle = edge_ebd_for_angle [:, :, : self .a_sel , :]
457498 # nb x nloc x a_nnei x e_dim
458499 edge_for_angle = torch .where (
459500 a_nlist_mask .unsqueeze (- 1 ), edge_for_angle , 0.0
@@ -471,7 +512,7 @@ def forward(
471512 [edge_for_angle_i , edge_for_angle_j ], dim = - 1
472513 )
473514 angle_info_list = [angle_ebd , node_for_angle_info , edge_for_angle_info ]
474- # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2)
515+ # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c)
475516 angle_info = torch .cat (angle_info_list , dim = - 1 )
476517
477518 # edge angle message
@@ -605,6 +646,7 @@ def serialize(self) -> dict:
605646 "n_dim" : self .n_dim ,
606647 "e_dim" : self .e_dim ,
607648 "a_dim" : self .a_dim ,
649+ "a_compress_rate" : self .a_compress_rate ,
608650 "axis_neuron" : self .axis_neuron ,
609651 "activation_function" : self .activation_function ,
610652 "update_angle" : self .update_angle ,
@@ -625,6 +667,13 @@ def serialize(self) -> dict:
625667 "angle_self_linear" : self .angle_self_linear .serialize (),
626668 }
627669 )
670+ if self .a_compress_rate != 0 :
671+ data .update (
672+ {
673+ "a_compress_n_linear" : self .a_compress_n_linear .serialize (),
674+ "a_compress_e_linear" : self .a_compress_e_linear .serialize (),
675+ }
676+ )
628677 if self .update_style == "res_residual" :
629678 data .update (
630679 {
@@ -650,13 +699,16 @@ def deserialize(cls, data: dict) -> "RepFlowLayer":
650699 check_version_compatibility (data .pop ("@version" ), 1 , 1 )
651700 data .pop ("@class" )
652701 update_angle = data ["update_angle" ]
702+ a_compress_rate = data ["a_compress_rate" ]
653703 node_self_mlp = data .pop ("node_self_mlp" )
654704 node_sym_linear = data .pop ("node_sym_linear" )
655705 node_edge_linear = data .pop ("node_edge_linear" )
656706 edge_self_linear = data .pop ("edge_self_linear" )
657707 edge_angle_linear1 = data .pop ("edge_angle_linear1" , None )
658708 edge_angle_linear2 = data .pop ("edge_angle_linear2" , None )
659709 angle_self_linear = data .pop ("angle_self_linear" , None )
710+ a_compress_n_linear = data .pop ("a_compress_n_linear" , None )
711+ a_compress_e_linear = data .pop ("a_compress_e_linear" , None )
660712 update_style = data ["update_style" ]
661713 variables = data .pop ("@variables" , {})
662714 n_residual = variables .get ("n_residual" , data .pop ("n_residual" , []))
@@ -676,6 +728,11 @@ def deserialize(cls, data: dict) -> "RepFlowLayer":
676728 obj .edge_angle_linear1 = MLPLayer .deserialize (edge_angle_linear1 )
677729 obj .edge_angle_linear2 = MLPLayer .deserialize (edge_angle_linear2 )
678730 obj .angle_self_linear = MLPLayer .deserialize (angle_self_linear )
731+ if a_compress_rate != 0 :
732+ assert isinstance (a_compress_n_linear , dict )
733+ assert isinstance (a_compress_e_linear , dict )
734+ obj .a_compress_n_linear = MLPLayer .deserialize (a_compress_n_linear )
735+ obj .a_compress_e_linear = MLPLayer .deserialize (a_compress_e_linear )
679736
680737 if update_style == "res_residual" :
681738 for ii , t in enumerate (obj .n_residual ):
0 commit comments