@@ -48,6 +48,7 @@ def __init__(
4848 a_compress_rate : int = 0 ,
4949 a_compress_use_split : bool = False ,
5050 a_compress_e_rate : int = 1 ,
51+ n_multi_edge_message : int = 1 ,
5152 axis_neuron : int = 4 ,
5253 update_angle : bool = True , # angle
5354 activation_function : str = "silu" ,
@@ -79,6 +80,8 @@ def __init__(
7980 f"For a_compress_rate of { a_compress_rate } , a_dim must be divisible by { 2 * a_compress_rate } . "
8081 f"Currently, a_dim={ a_dim } is not valid."
8182 )
83+ self .n_multi_edge_message = n_multi_edge_message
84+ assert self .n_multi_edge_message >= 1 , "n_multi_edge_message must >= 1!"
8285 self .axis_neuron = axis_neuron
8386 self .update_angle = update_angle
8487 self .activation_function = activation_function
@@ -144,20 +147,21 @@ def __init__(
144147 # node edge message
145148 self .node_edge_linear = MLPLayer (
146149 self .edge_info_dim ,
147- n_dim ,
150+ self . n_multi_edge_message * n_dim ,
148151 precision = precision ,
149152 seed = child_seed (seed , 4 ),
150153 )
151154 if self .update_style == "res_residual" :
152- self .n_residual .append (
153- get_residual (
154- n_dim ,
155- self .update_residual ,
156- self .update_residual_init ,
157- precision = precision ,
158- seed = child_seed (seed , 5 ),
155+ for head_index in range (self .n_multi_edge_message ):
156+ self .n_residual .append (
157+ get_residual (
158+ n_dim ,
159+ self .update_residual ,
160+ self .update_residual_init ,
161+ precision = precision ,
162+ seed = child_seed (child_seed (seed , 5 ), head_index ),
163+ )
159164 )
160- )
161165
162166 # edge self message
163167 self .edge_self_linear = MLPLayer (
@@ -479,10 +483,18 @@ def forward(
479483 )
480484
481485 # node edge message
482- # nb x nloc x nnei x n_dim
486+ # nb x nloc x nnei x (h * n_dim)
483487 node_edge_update = self .act (self .node_edge_linear (edge_info )) * sw .unsqueeze (- 1 )
484488 node_edge_update = torch .sum (node_edge_update , dim = - 2 ) / self .nnei
485- n_update_list .append (node_edge_update )
489+ if self .n_multi_edge_message > 1 :
490+ # nb x nloc x nnei x h x n_dim
491+ node_edge_update_mul_head = node_edge_update .view (
492+ nb , nloc , self .n_multi_edge_message , self .n_dim
493+ )
494+ for head_index in range (self .n_multi_edge_message ):
495+ n_update_list .append (node_edge_update_mul_head [:, :, head_index , :])
496+ else :
497+ n_update_list .append (node_edge_update )
486498 # update node_ebd
487499 n_updated = self .list_update (n_update_list , "node" )
488500
@@ -670,6 +682,7 @@ def serialize(self) -> dict:
670682 "a_compress_rate" : self .a_compress_rate ,
671683 "a_compress_e_rate" : self .a_compress_e_rate ,
672684 "a_compress_use_split" : self .a_compress_use_split ,
685+ "n_multi_edge_message" : self .n_multi_edge_message ,
673686 "axis_neuron" : self .axis_neuron ,
674687 "activation_function" : self .activation_function ,
675688 "update_angle" : self .update_angle ,
0 commit comments