@@ -271,6 +271,8 @@ class DescrptDPA3(NativeOP, BaseDescriptor):
271271 Whether to use electronic configuration type embedding.
272272 use_tebd_bias : bool, Optional
273273 Whether to use bias in the type embedding layer.
274+ use_loc_mapping : bool, Optional
275+ Whether to use local atom index mapping in non-parallel inference.
274276 type_map : list[str], Optional
275277 A list of strings. Give the name to each type of atoms.
276278 """
@@ -290,6 +292,7 @@ def __init__(
290292 seed : Optional [Union [int , list [int ]]] = None ,
291293 use_econf_tebd : bool = False ,
292294 use_tebd_bias : bool = False ,
295+ use_loc_mapping : bool = True ,
293296 type_map : Optional [list [str ]] = None ,
294297 ) -> None :
295298 super ().__init__ ()
@@ -335,6 +338,7 @@ def init_subclass_params(sub_data, sub_class):
335338 use_exp_switch = self .repflow_args .use_exp_switch ,
336339 use_dynamic_sel = self .repflow_args .use_dynamic_sel ,
337340 sel_reduce_factor = self .repflow_args .sel_reduce_factor ,
341+ use_loc_mapping = use_loc_mapping ,
338342 exclude_types = exclude_types ,
339343 env_protection = env_protection ,
340344 precision = precision ,
@@ -343,6 +347,7 @@ def init_subclass_params(sub_data, sub_class):
343347
344348 self .use_econf_tebd = use_econf_tebd
345349 self .use_tebd_bias = use_tebd_bias
350+ self .use_loc_mapping = use_loc_mapping
346351 self .type_map = type_map
347352 self .tebd_dim = self .repflow_args .n_dim
348353 self .type_embedding = TypeEmbedNet (
@@ -541,10 +546,16 @@ def call(
541546 nall = xp .reshape (coord_ext , (nframes , - 1 )).shape [1 ] // 3
542547
543548 type_embedding = self .type_embedding .call ()
544- node_ebd_ext = xp .reshape (
545- xp .take (type_embedding , xp .reshape (atype_ext , [- 1 ]), axis = 0 ),
546- (nframes , nall , self .tebd_dim ),
547- )
549+ if self .use_loc_mapping :
550+ node_ebd_ext = xp .reshape (
551+ xp .take (type_embedding , xp .reshape (atype_ext [:, :nloc ], [- 1 ]), axis = 0 ),
552+ (nframes , nloc , self .tebd_dim ),
553+ )
554+ else :
555+ node_ebd_ext = xp .reshape (
556+ xp .take (type_embedding , xp .reshape (atype_ext , [- 1 ]), axis = 0 ),
557+ (nframes , nall , self .tebd_dim ),
558+ )
548559 node_ebd_inp = node_ebd_ext [:, :nloc , :]
549560 # repflows
550561 node_ebd , edge_ebd , h2 , rot_mat , sw = self .repflows (
0 commit comments