Skip to content

Commit f460493

Browse files
caic99iProzdpre-commit-ci[bot]coderabbitai[bot]
authored
feat: add use_loc_mapping (#4772)
`node_ebd_ext` contains embedding on expanded atoms, which might be large for a large cut-off. Current implementation do matmul first, then gather tensor by the neighbors. This introduces saving `sub_node_ext_update` of size nf * nall * ndim in each repflow layer, where nall might be multiple times larger than nloc. This PR do gathering first, then compute matmul. After this PR, the peak memory size is of O(nlayer * nf * nloc), unrelated to `nall`. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Summary by CodeRabbit - **New Features** - Introduced an optional setting to enable or disable local atom index mapping during descriptor computation, providing more flexibility in non-parallel inference scenarios. - **Bug Fixes** - Enhanced test coverage and consistency checks to ensure descriptor outputs remain equivalent whether local mapping is enabled or not. - **Documentation** - Improved parameter descriptions and comments to clarify the behavior and shape of key variables related to local mapping. - **Tests** - Added new tests to validate the functional equivalence of descriptor outputs with and without local mapping. - Updated existing tests to support and verify the new local mapping option. - **Style** - Improved assertion handling in tests by specifying absolute tolerance for numerical comparisons. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Chun Cai <amoycaic@gmail.com> Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com> Co-authored-by: Duo <50307526+iProzd@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent 2018d62 commit f460493

File tree

15 files changed

+430
-47
lines changed

15 files changed

+430
-47
lines changed

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,9 @@ class DescrptDPA3(NativeOP, BaseDescriptor):
277277
Whether to use electronic configuration type embedding.
278278
use_tebd_bias : bool, Optional
279279
Whether to use bias in the type embedding layer.
280+
use_loc_mapping : bool, Optional
281+
Whether to use local atom index mapping in training or non-parallel inference.
282+
When True, local indexing and mapping are applied to neighbor lists and embeddings during descriptor computation.
280283
type_map : list[str], Optional
281284
A list of strings. Give the name to each type of atoms.
282285
"""
@@ -296,6 +299,7 @@ def __init__(
296299
seed: Optional[Union[int, list[int]]] = None,
297300
use_econf_tebd: bool = False,
298301
use_tebd_bias: bool = False,
302+
use_loc_mapping: bool = True,
299303
type_map: Optional[list[str]] = None,
300304
) -> None:
301305
super().__init__()
@@ -342,6 +346,7 @@ def init_subclass_params(sub_data, sub_class):
342346
use_exp_switch=self.repflow_args.use_exp_switch,
343347
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
344348
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
349+
use_loc_mapping=use_loc_mapping,
345350
exclude_types=exclude_types,
346351
env_protection=env_protection,
347352
precision=precision,
@@ -350,6 +355,7 @@ def init_subclass_params(sub_data, sub_class):
350355

351356
self.use_econf_tebd = use_econf_tebd
352357
self.use_tebd_bias = use_tebd_bias
358+
self.use_loc_mapping = use_loc_mapping
353359
self.type_map = type_map
354360
self.tebd_dim = self.repflow_args.n_dim
355361
self.type_embedding = TypeEmbedNet(
@@ -548,10 +554,16 @@ def call(
548554
nall = xp.reshape(coord_ext, (nframes, -1)).shape[1] // 3
549555

550556
type_embedding = self.type_embedding.call()
551-
node_ebd_ext = xp.reshape(
552-
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
553-
(nframes, nall, self.tebd_dim),
554-
)
557+
if self.use_loc_mapping:
558+
node_ebd_ext = xp.reshape(
559+
xp.take(type_embedding, xp.reshape(atype_ext[:, :nloc], [-1]), axis=0),
560+
(nframes, nloc, self.tebd_dim),
561+
)
562+
else:
563+
node_ebd_ext = xp.reshape(
564+
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
565+
(nframes, nall, self.tebd_dim),
566+
)
555567
node_ebd_inp = node_ebd_ext[:, :nloc, :]
556568
# repflows
557569
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
@@ -570,7 +582,7 @@ def serialize(self) -> dict:
570582
data = {
571583
"@class": "Descriptor",
572584
"type": "dpa3",
573-
"@version": 1,
585+
"@version": 2,
574586
"ntypes": self.ntypes,
575587
"repflow_args": self.repflow_args.serialize(),
576588
"concat_output_tebd": self.concat_output_tebd,
@@ -581,6 +593,7 @@ def serialize(self) -> dict:
581593
"trainable": self.trainable,
582594
"use_econf_tebd": self.use_econf_tebd,
583595
"use_tebd_bias": self.use_tebd_bias,
596+
"use_loc_mapping": self.use_loc_mapping,
584597
"type_map": self.type_map,
585598
"type_embedding": self.type_embedding.serialize(),
586599
}
@@ -605,7 +618,7 @@ def serialize(self) -> dict:
605618
def deserialize(cls, data: dict) -> "DescrptDPA3":
606619
data = data.copy()
607620
version = data.pop("@version")
608-
check_version_compatibility(version, 1, 1)
621+
check_version_compatibility(version, 2, 1)
609622
data.pop("@class")
610623
data.pop("type")
611624
repflow_variable = data.pop("repflow_variable").copy()

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ class DescrptBlockRepflows(NativeOP, DescriptorBlock):
148148
In the dynamic selection case, neighbor-scale normalization will use `e_sel / sel_reduce_factor`
149149
or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values,
150150
accommodating larger selection numbers.
151+
use_loc_mapping : bool, optional
152+
Whether to use local atom index mapping in training or non-parallel inference.
153+
When True, local indexing and mapping are applied to neighbor lists and embeddings during descriptor computation.
151154
ntypes : int
152155
Number of element types
153156
activation_function : str, optional
@@ -200,6 +203,7 @@ def __init__(
200203
use_exp_switch: bool = False,
201204
use_dynamic_sel: bool = False,
202205
sel_reduce_factor: float = 10.0,
206+
use_loc_mapping: bool = True,
203207
seed: Optional[Union[int, list[int]]] = None,
204208
) -> None:
205209
super().__init__()
@@ -234,6 +238,7 @@ def __init__(
234238
self.edge_init_use_dist = edge_init_use_dist
235239
self.use_exp_switch = use_exp_switch
236240
self.use_dynamic_sel = use_dynamic_sel
241+
self.use_loc_mapping = use_loc_mapping
237242
self.sel_reduce_factor = sel_reduce_factor
238243
if self.use_dynamic_sel and not self.smooth_edge_update:
239244
raise NotImplementedError(
@@ -540,10 +545,22 @@ def call(
540545
cosine_ij, (nframes, nloc, self.a_sel, self.a_sel, 1)
541546
) / (xp.pi**0.5)
542547

548+
if self.use_loc_mapping:
549+
assert mapping is not None
550+
flat_map = xp.reshape(mapping, (nframes, -1))
551+
nlist = xp.reshape(
552+
xp_take_along_axis(flat_map, xp.reshape(nlist, (nframes, -1)), axis=1),
553+
nlist.shape,
554+
)
555+
543556
if self.use_dynamic_sel:
544557
# get graph index
545558
edge_index, angle_index = get_graph_index(
546-
nlist, nlist_mask, a_nlist_mask, nall
559+
nlist,
560+
nlist_mask,
561+
a_nlist_mask,
562+
nall,
563+
use_loc_mapping=self.use_loc_mapping,
547564
)
548565
# flat all the tensors
549566
# n_edge x 1
@@ -577,7 +594,11 @@ def call(
577594
for idx, ll in enumerate(self.layers):
578595
# node_ebd: nb x nloc x n_dim
579596
# node_ebd_ext: nb x nall x n_dim
580-
node_ebd_ext = xp_take_along_axis(node_ebd, mapping, axis=1)
597+
node_ebd_ext = (
598+
node_ebd
599+
if self.use_loc_mapping
600+
else xp_take_along_axis(node_ebd, mapping, axis=1)
601+
)
581602
node_ebd, edge_ebd, angle_ebd = ll.call(
582603
node_ebd_ext,
583604
edge_ebd,
@@ -684,6 +705,7 @@ def serialize(self):
684705
"smooth_edge_update": self.smooth_edge_update,
685706
"use_dynamic_sel": self.use_dynamic_sel,
686707
"sel_reduce_factor": self.sel_reduce_factor,
708+
"use_loc_mapping": self.use_loc_mapping,
687709
# variables
688710
"edge_embd": self.edge_embd.serialize(),
689711
"angle_embd": self.angle_embd.serialize(),

deepmd/dpmodel/utils/network.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,6 +1015,7 @@ def get_graph_index(
10151015
nlist_mask: np.ndarray,
10161016
a_nlist_mask: np.ndarray,
10171017
nall: int,
1018+
use_loc_mapping: bool = True,
10181019
):
10191020
"""
10201021
Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`.
@@ -1029,6 +1030,9 @@ def get_graph_index(
10291030
Masks of the neighbor list for angle. real nei 1 otherwise 0
10301031
nall
10311032
The number of extended atoms.
1033+
use_loc_mapping
1034+
Whether to use local atom index mapping in training or non-parallel inference.
1035+
When True, local indexing and mapping are applied to neighbor lists and embeddings during descriptor computation.
10321036
10331037
Returns
10341038
-------
@@ -1069,7 +1073,9 @@ def get_graph_index(
10691073
n2e_index = n2e_index[xp.astype(nlist_mask, xp.bool)]
10701074

10711075
# node_ext(j) to edge(ij) index_select
1072-
frame_shift = xp.arange(nf, dtype=nlist.dtype) * nall
1076+
frame_shift = xp.arange(nf, dtype=nlist.dtype) * (
1077+
nall if not use_loc_mapping else nloc
1078+
)
10731079
shifted_nlist = nlist + frame_shift[:, xp.newaxis, xp.newaxis]
10741080
# n_edge
10751081
n_ext2e_index = shifted_nlist[xp.astype(nlist_mask, xp.bool)]

deepmd/pd/model/descriptor/dpa3.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ class DescrptDPA3(BaseDescriptor, paddle.nn.Layer):
8989
Whether to use electronic configuration type embedding.
9090
use_tebd_bias : bool, Optional
9191
Whether to use bias in the type embedding layer.
92+
use_loc_mapping : bool, Optional
93+
Whether to use local atom index mapping in training or non-parallel inference.
94+
Not supported yet in Paddle.
9295
type_map : list[str], Optional
9396
A list of strings. Give the name to each type of atoms.
9497
"""
@@ -108,6 +111,7 @@ def __init__(
108111
seed: Optional[Union[int, list[int]]] = None,
109112
use_econf_tebd: bool = False,
110113
use_tebd_bias: bool = False,
114+
use_loc_mapping: bool = False,
111115
type_map: Optional[list[str]] = None,
112116
) -> None:
113117
super().__init__()
@@ -152,6 +156,7 @@ def init_subclass_params(sub_data, sub_class):
152156
smooth_edge_update=self.repflow_args.smooth_edge_update,
153157
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
154158
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
159+
use_loc_mapping=use_loc_mapping,
155160
exclude_types=exclude_types,
156161
env_protection=env_protection,
157162
precision=precision,
@@ -160,6 +165,7 @@ def init_subclass_params(sub_data, sub_class):
160165

161166
self.use_econf_tebd = use_econf_tebd
162167
self.use_tebd_bias = use_tebd_bias
168+
self.use_loc_mapping = use_loc_mapping
163169
self.type_map = type_map
164170
self.tebd_dim = self.repflow_args.n_dim
165171
self.type_embedding = TypeEmbedNet(
@@ -370,7 +376,7 @@ def serialize(self) -> dict:
370376
data = {
371377
"@class": "Descriptor",
372378
"type": "dpa3",
373-
"@version": 1,
379+
"@version": 2,
374380
"ntypes": self.ntypes,
375381
"repflow_args": self.repflow_args.serialize(),
376382
"concat_output_tebd": self.concat_output_tebd,
@@ -381,6 +387,7 @@ def serialize(self) -> dict:
381387
"trainable": self.trainable,
382388
"use_econf_tebd": self.use_econf_tebd,
383389
"use_tebd_bias": self.use_tebd_bias,
390+
"use_loc_mapping": self.use_loc_mapping,
384391
"type_map": self.type_map,
385392
"type_embedding": self.type_embedding.embedding.serialize(),
386393
}
@@ -405,7 +412,7 @@ def serialize(self) -> dict:
405412
def deserialize(cls, data: dict) -> "DescrptDPA3":
406413
data = data.copy()
407414
version = data.pop("@version")
408-
check_version_compatibility(version, 1, 1)
415+
check_version_compatibility(version, 2, 1)
409416
data.pop("@class")
410417
data.pop("type")
411418
repflow_variable = data.pop("repflow_variable").copy()

deepmd/pd/model/descriptor/repflows.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ class DescrptBlockRepflows(DescriptorBlock):
112112
optim_update : bool, optional
113113
Whether to enable the optimized update method.
114114
Uses a more efficient process when enabled. Defaults to True
115+
use_loc_mapping : bool, Optional
116+
Whether to use local atom index mapping in training or non-parallel inference.
117+
Not supported yet in Paddle.
115118
ntypes : int
116119
Number of element types
117120
activation_function : str, optional
@@ -161,6 +164,7 @@ def __init__(
161164
smooth_edge_update: bool = False,
162165
use_dynamic_sel: bool = False,
163166
sel_reduce_factor: float = 10.0,
167+
use_loc_mapping: bool = False,
164168
optim_update: bool = True,
165169
seed: Optional[Union[int, list[int]]] = None,
166170
) -> None:
@@ -196,6 +200,8 @@ def __init__(
196200
self.use_dynamic_sel = use_dynamic_sel # not supported yet
197201
self.sel_reduce_factor = sel_reduce_factor
198202
assert not self.use_dynamic_sel, "Dynamic selection is not supported yet."
203+
self.use_loc_mapping = use_loc_mapping
204+
assert not self.use_loc_mapping, "Local mapping is not supported yet."
199205

200206
self.n_dim = n_dim
201207
self.e_dim = e_dim

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ class DescrptDPA3(BaseDescriptor, torch.nn.Module):
8989
Whether to use electronic configuration type embedding.
9090
use_tebd_bias : bool, Optional
9191
Whether to use bias in the type embedding layer.
92+
use_loc_mapping : bool, Optional
93+
Whether to use local atom index mapping in training or non-parallel inference.
94+
When True, local indexing and mapping are applied to neighbor lists and embeddings during descriptor computation.
9295
type_map : list[str], Optional
9396
A list of strings. Give the name to each type of atoms.
9497
"""
@@ -108,6 +111,7 @@ def __init__(
108111
seed: Optional[Union[int, list[int]]] = None,
109112
use_econf_tebd: bool = False,
110113
use_tebd_bias: bool = False,
114+
use_loc_mapping: bool = True,
111115
type_map: Optional[list[str]] = None,
112116
) -> None:
113117
super().__init__()
@@ -154,13 +158,15 @@ def init_subclass_params(sub_data, sub_class):
154158
use_exp_switch=self.repflow_args.use_exp_switch,
155159
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
156160
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
161+
use_loc_mapping=use_loc_mapping,
157162
exclude_types=exclude_types,
158163
env_protection=env_protection,
159164
precision=precision,
160165
seed=child_seed(seed, 1),
161166
)
162167

163168
self.use_econf_tebd = use_econf_tebd
169+
self.use_loc_mapping = use_loc_mapping
164170
self.use_tebd_bias = use_tebd_bias
165171
self.type_map = type_map
166172
self.tebd_dim = self.repflow_args.n_dim
@@ -366,7 +372,7 @@ def serialize(self) -> dict:
366372
data = {
367373
"@class": "Descriptor",
368374
"type": "dpa3",
369-
"@version": 1,
375+
"@version": 2,
370376
"ntypes": self.ntypes,
371377
"repflow_args": self.repflow_args.serialize(),
372378
"concat_output_tebd": self.concat_output_tebd,
@@ -377,6 +383,7 @@ def serialize(self) -> dict:
377383
"trainable": self.trainable,
378384
"use_econf_tebd": self.use_econf_tebd,
379385
"use_tebd_bias": self.use_tebd_bias,
386+
"use_loc_mapping": self.use_loc_mapping,
380387
"type_map": self.type_map,
381388
"type_embedding": self.type_embedding.embedding.serialize(),
382389
}
@@ -401,7 +408,7 @@ def serialize(self) -> dict:
401408
def deserialize(cls, data: dict) -> "DescrptDPA3":
402409
data = data.copy()
403410
version = data.pop("@version")
404-
check_version_compatibility(version, 1, 1)
411+
check_version_compatibility(version, 2, 1)
405412
data.pop("@class")
406413
data.pop("type")
407414
repflow_variable = data.pop("repflow_variable").copy()
@@ -470,12 +477,16 @@ def forward(
470477
The smooth switch function. shape: nf x nloc x nnei
471478
472479
"""
480+
parallel_mode = comm_dict is not None
473481
# cast the input to internal precsion
474482
extended_coord = extended_coord.to(dtype=self.prec)
475483
nframes, nloc, nnei = nlist.shape
476484
nall = extended_coord.view(nframes, -1).shape[1] // 3
477485

478-
node_ebd_ext = self.type_embedding(extended_atype)
486+
if not parallel_mode and self.use_loc_mapping:
487+
node_ebd_ext = self.type_embedding(extended_atype[:, :nloc])
488+
else:
489+
node_ebd_ext = self.type_embedding(extended_atype)
479490
node_ebd_inp = node_ebd_ext[:, :nloc, :]
480491
# repflows
481492
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ def optim_edge_update_dynamic(
684684

685685
def forward(
686686
self,
687-
node_ebd_ext: torch.Tensor, # nf x nall x n_dim
687+
node_ebd_ext: torch.Tensor, # nf x nall x n_dim [OR] nf x nloc x n_dim when not parallel_mode
688688
edge_ebd: torch.Tensor, # nf x nloc x nnei x e_dim
689689
h2: torch.Tensor, # nf x nloc x nnei x 3
690690
angle_ebd: torch.Tensor, # nf x nloc x a_nnei x a_nnei x a_dim

0 commit comments

Comments
 (0)