@@ -152,7 +152,6 @@ def init_subclass_params(sub_data, sub_class):
152152 e_dim = self .repflow_args .e_dim ,
153153 a_dim = self .repflow_args .a_dim ,
154154 axis_neuron = self .repflow_args .axis_neuron ,
155- node_has_conv = self .repflow_args .node_has_conv ,
156155 update_angle = self .repflow_args .update_angle ,
157156 activation_function = self .activation_function ,
158157 update_style = self .repflow_args .update_style ,
@@ -299,7 +298,7 @@ def change_type_map(
299298 extend_descrpt_stat (
300299 repflow ,
301300 type_map ,
302- des_with_stat = model_with_new_type_stat .repflow
301+ des_with_stat = model_with_new_type_stat .repflows
303302 if model_with_new_type_stat is not None
304303 else None ,
305304 )
@@ -380,6 +379,7 @@ def serialize(self) -> dict:
380379 }
381380 repflow_variable = {
382381 "edge_embd" : repflows .edge_embd .serialize (),
382+ "angle_embd" : repflows .angle_embd .serialize (),
383383 "repflow_layers" : [layer .serialize () for layer in repflows .layers ],
384384 "env_mat" : DPEnvMat (repflows .rcut , repflows .rcut_smth ).serialize (),
385385 "@variables" : {
@@ -417,6 +417,9 @@ def t_cvt(xx):
417417 env_mat = repflow_variable .pop ("env_mat" )
418418 repflow_layers = repflow_variable .pop ("repflow_layers" )
419419 obj .repflows .edge_embd = MLPLayer .deserialize (repflow_variable .pop ("edge_embd" ))
420+ obj .repflows .angle_embd = MLPLayer .deserialize (
421+ repflow_variable .pop ("angle_embd" )
422+ )
420423 obj .repflows ["davg" ] = t_cvt (statistic_repflows ["davg" ])
421424 obj .repflows ["dstd" ] = t_cvt (statistic_repflows ["dstd" ])
422425 obj .repflows .layers = torch .nn .ModuleList (
@@ -449,12 +452,12 @@ def forward(
449452
450453 Returns
451454 -------
452- node_embd
455+ node_ebd
453456 The output descriptor. shape: nf x nloc x n_dim (or n_dim + tebd_dim)
454457 rot_mat
455458 The rotationally equivariant and permutationally invariant single particle
456459 representation. shape: nf x nloc x e_dim x 3
457- edge_embd
460+ edge_ebd
458461 The edge embedding.
459462 shape: nf x nloc x nnei x e_dim
460463 h2
@@ -469,23 +472,23 @@ def forward(
469472 nframes , nloc , nnei = nlist .shape
470473 nall = extended_coord .view (nframes , - 1 ).shape [1 ] // 3
471474
472- node_embd_ext = self .type_embedding (extended_atype )
473- node_embd_inp = node_embd_ext [:, :nloc , :]
475+ node_ebd_ext = self .type_embedding (extended_atype )
476+ node_ebd_inp = node_ebd_ext [:, :nloc , :]
474477 # repflows
475- node_embd , edge_embd , h2 , rot_mat , sw = self .repflows (
478+ node_ebd , edge_ebd , h2 , rot_mat , sw = self .repflows (
476479 nlist ,
477480 extended_coord ,
478481 extended_atype ,
479- node_embd_ext ,
482+ node_ebd_ext ,
480483 mapping ,
481484 comm_dict = comm_dict ,
482485 )
483486 if self .concat_output_tebd :
484- node_embd = torch .cat ([node_embd , node_embd_inp ], dim = - 1 )
487+ node_ebd = torch .cat ([node_ebd , node_ebd_inp ], dim = - 1 )
485488 return (
486- node_embd .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
489+ node_ebd .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
487490 rot_mat .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
488- edge_embd .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
491+ edge_ebd .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
489492 h2 .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
490493 sw .to (dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
491494 )
0 commit comments