Skip to content

Commit 39880b9

Browse files
committed
Align dpa3 numpy backend with PT loc mapping
1 parent 598c2da commit 39880b9

File tree

3 files changed

+37
-7
lines changed

3 files changed

+37
-7
lines changed

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,8 @@ class DescrptDPA3(NativeOP, BaseDescriptor):
271271
Whether to use electronic configuration type embedding.
272272
use_tebd_bias : bool, Optional
273273
Whether to use bias in the type embedding layer.
274+
use_loc_mapping : bool, Optional
275+
Whether to use local atom index mapping in non-parallel inference.
274276
type_map : list[str], Optional
275277
A list of strings. Give the name to each type of atoms.
276278
"""
@@ -290,6 +292,7 @@ def __init__(
290292
seed: Optional[Union[int, list[int]]] = None,
291293
use_econf_tebd: bool = False,
292294
use_tebd_bias: bool = False,
295+
use_loc_mapping: bool = True,
293296
type_map: Optional[list[str]] = None,
294297
) -> None:
295298
super().__init__()
@@ -335,6 +338,7 @@ def init_subclass_params(sub_data, sub_class):
335338
use_exp_switch=self.repflow_args.use_exp_switch,
336339
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
337340
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
341+
use_loc_mapping=use_loc_mapping,
338342
exclude_types=exclude_types,
339343
env_protection=env_protection,
340344
precision=precision,
@@ -343,6 +347,7 @@ def init_subclass_params(sub_data, sub_class):
343347

344348
self.use_econf_tebd = use_econf_tebd
345349
self.use_tebd_bias = use_tebd_bias
350+
self.use_loc_mapping = use_loc_mapping
346351
self.type_map = type_map
347352
self.tebd_dim = self.repflow_args.n_dim
348353
self.type_embedding = TypeEmbedNet(
@@ -541,10 +546,16 @@ def call(
541546
nall = xp.reshape(coord_ext, (nframes, -1)).shape[1] // 3
542547

543548
type_embedding = self.type_embedding.call()
544-
node_ebd_ext = xp.reshape(
545-
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
546-
(nframes, nall, self.tebd_dim),
547-
)
549+
if self.use_loc_mapping:
550+
node_ebd_ext = xp.reshape(
551+
xp.take(type_embedding, xp.reshape(atype_ext[:, :nloc], [-1]), axis=0),
552+
(nframes, nloc, self.tebd_dim),
553+
)
554+
else:
555+
node_ebd_ext = xp.reshape(
556+
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
557+
(nframes, nall, self.tebd_dim),
558+
)
548559
node_ebd_inp = node_ebd_ext[:, :nloc, :]
549560
# repflows
550561
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ class DescrptBlockRepflows(NativeOP, DescriptorBlock):
145145
In the dynamic selection case, neighbor-scale normalization will use `e_sel / sel_reduce_factor`
146146
or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values,
147147
accommodating larger selection numbers.
148+
use_loc_mapping : bool, optional
149+
Whether to use local atom index mapping in non-parallel inference.
148150
ntypes : int
149151
Number of element types
150152
activation_function : str, optional
@@ -196,6 +198,7 @@ def __init__(
196198
use_exp_switch: bool = False,
197199
use_dynamic_sel: bool = False,
198200
sel_reduce_factor: float = 10.0,
201+
use_loc_mapping: bool = True,
199202
seed: Optional[Union[int, list[int]]] = None,
200203
) -> None:
201204
super().__init__()
@@ -229,6 +232,7 @@ def __init__(
229232
self.smooth_edge_update = smooth_edge_update
230233
self.use_exp_switch = use_exp_switch
231234
self.use_dynamic_sel = use_dynamic_sel
235+
self.use_loc_mapping = use_loc_mapping
232236
self.sel_reduce_factor = sel_reduce_factor
233237
if self.use_dynamic_sel and not self.smooth_edge_update:
234238
raise NotImplementedError(
@@ -527,10 +531,18 @@ def call(
527531
cosine_ij, (nframes, nloc, self.a_sel, self.a_sel, 1)
528532
) / (xp.pi**0.5)
529533

534+
if self.use_loc_mapping:
535+
assert mapping is not None
536+
flat_map = xp.reshape(mapping, (nframes, -1))
537+
nlist = xp.reshape(
538+
xp_take_along_axis(flat_map, xp.reshape(nlist, (nframes, -1)), axis=1),
539+
nlist.shape,
540+
)
541+
530542
if self.use_dynamic_sel:
531543
# get graph index
532544
edge_index, angle_index = get_graph_index(
533-
nlist, nlist_mask, a_nlist_mask, nall
545+
nlist, nlist_mask, a_nlist_mask, nall, use_loc_mapping=self.use_loc_mapping
534546
)
535547
# flat all the tensors
536548
# n_edge x 1
@@ -561,7 +573,11 @@ def call(
561573
for idx, ll in enumerate(self.layers):
562574
# node_ebd: nb x nloc x n_dim
563575
# node_ebd_ext: nb x nall x n_dim
564-
node_ebd_ext = xp_take_along_axis(node_ebd, mapping, axis=1)
576+
node_ebd_ext = (
577+
node_ebd
578+
if self.use_loc_mapping
579+
else xp_take_along_axis(node_ebd, mapping, axis=1)
580+
)
565581
node_ebd, edge_ebd, angle_ebd = ll.call(
566582
node_ebd_ext,
567583
edge_ebd,

deepmd/dpmodel/utils/network.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,7 @@ def get_graph_index(
10061006
nlist_mask: np.ndarray,
10071007
a_nlist_mask: np.ndarray,
10081008
nall: int,
1009+
use_loc_mapping: bool = True,
10091010
):
10101011
"""
10111012
Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`.
@@ -1020,6 +1021,8 @@ def get_graph_index(
10201021
Masks of the neighbor list for angle. real nei 1 otherwise 0
10211022
nall
10221023
The number of extended atoms.
1024+
use_loc_mapping
1025+
Whether to use local atom index mapping in non-parallel inference.
10231026
10241027
Returns
10251028
-------
@@ -1060,7 +1063,7 @@ def get_graph_index(
10601063
n2e_index = n2e_index[xp.astype(nlist_mask, xp.bool)]
10611064

10621065
# node_ext(j) to edge(ij) index_select
1063-
frame_shift = xp.arange(nf, dtype=nlist.dtype) * nall
1066+
frame_shift = xp.arange(nf, dtype=nlist.dtype) * (nall if not use_loc_mapping else nloc)
10641067
shifted_nlist = nlist + frame_shift[:, xp.newaxis, xp.newaxis]
10651068
# n_edge
10661069
n_ext2e_index = shifted_nlist[xp.astype(nlist_mask, xp.bool)]

0 commit comments

Comments
 (0)