Skip to content

Commit 91ebe34

Browse files
feat(pd): support dpa2/dpa3 C++ inference (#4870)
Support DPA2/DPA3 C++ inference with lammps > [!NOTE] > The intermediate representation (IR) of PaddlePaddle's computation graph does not support the `Dict[str, Tensor]` data type. Considering the special nature of `comm_dict`, we replaced it with `List[Tensor]`, which allowed us to successfully run inference for DPA2 and DPA3. - DPA2 <img width="896" height="680" alt="4c116161d304da381ffeed968857be1f" src="https://github.com/user-attachments/assets/c4bc178d-4a14-43f4-8d08-946eb9bdf3d3" /> - DPA3 <img width="891" height="682" alt="863dfb7c8f285149cfdaf6aa4c6849fc" src="https://github.com/user-attachments/assets/b74ea212-9055-4863-85fb-18cbace3c60e" /> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added support for distributed computation and message passing in PaddlePaddle backend, including a custom operator for efficient border data exchange. * Enabled optional retrieval of intermediate network outputs during evaluation. * Introduced input statistics computation for model fitting. * Added dynamic loading of Paddle custom operator library and integration in model descriptor and repformer layers. * **Improvements** * Enhanced model freezing and input signature flexibility with dynamic batch sizes and expanded communication tensor support. * Improved tensor operations for better device handling, explicit broadcasting, and type casting consistency. * Refined handling of atomic virial outputs and descriptor communication. * Added support for spin and non-spin modes in parallel descriptor computations using the custom operator. * Simplified device placement in tensor creation and removed redundant explicit device transfers. * Improved input validation and type annotations across models and descriptors. * Replaced ad-hoc logging in C++ API with a structured logger for better debug output. * Updated parameter documentation for improved clarity. 
* **Bug Fixes** * Fixed tensor padding, masking, and device placement issues for improved robustness in distributed and parallel scenarios. * Corrected handling of neighbor list masking and indexing with explicit tensor fills to avoid broadcast errors. * **Build/Chores** * Added new build scripts and CMake configurations for PaddlePaddle custom operator integration. * Improved support for PaddlePaddle model detection and operator library loading in both Python and C++ APIs. * Implemented full message passing setup in C++ API replacing previous exceptions. * Added Python setup script for Paddle custom operator extension build. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3be0755 commit 91ebe34

32 files changed

+1484
-250
lines changed

deepmd/pd/cxx_op.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import importlib
3+
from types import (
4+
ModuleType,
5+
)
6+
7+
8+
def load_library(module_name: str) -> tuple[bool, ModuleType]:
9+
"""Load OP library and return the module if success.
10+
11+
Parameters
12+
----------
13+
module_name : str
14+
Name of the module
15+
16+
Returns
17+
-------
18+
bool
19+
Whether the library is loaded successfully
20+
ModuleType
21+
loaded custom operator module
22+
"""
23+
if importlib.util.find_spec(module_name) is not None:
24+
module = importlib.import_module(module_name)
25+
return True, module
26+
27+
return False, None
28+
29+
30+
ENABLE_CUSTOMIZED_OP, paddle_ops_deepmd = load_library("deepmd_op_pd")
31+
32+
__all__ = [
33+
"ENABLE_CUSTOMIZED_OP",
34+
"paddle_ops_deepmd",
35+
]

deepmd/pd/entrypoints/main.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -368,12 +368,13 @@ def freeze(
368368
model.forward = paddle.jit.to_static(
369369
model.forward,
370370
input_spec=[
371-
InputSpec([1, -1, 3], dtype="float64", name="coord"), # coord
372-
InputSpec([1, -1], dtype="int64", name="atype"), # atype
373-
InputSpec([1, 9], dtype="float64", name="box"), # box
371+
InputSpec([-1, -1, 3], dtype="float64", name="coord"), # coord
372+
InputSpec([-1, -1], dtype="int64", name="atype"), # atype
373+
InputSpec([-1, 9], dtype="float64", name="box"), # box
374374
None, # fparam
375375
None, # aparam
376-
True, # do_atomic_virial
376+
# InputSpec([], dtype="bool", name="do_atomic_virial"), # do_atomic_virial
377+
False, # do_atomic_virial
377378
],
378379
full_graph=True,
379380
)
@@ -388,14 +389,23 @@ def freeze(
388389
model.forward_lower = paddle.jit.to_static(
389390
model.forward_lower,
390391
input_spec=[
391-
InputSpec([1, -1, 3], dtype="float64", name="coord"), # extended_coord
392-
InputSpec([1, -1], dtype="int32", name="atype"), # extended_atype
393-
InputSpec([1, -1, -1], dtype="int32", name="nlist"), # nlist
394-
InputSpec([1, -1], dtype="int64", name="mapping"), # mapping
392+
InputSpec([-1, -1, 3], dtype="float64", name="coord"), # extended_coord
393+
InputSpec([-1, -1], dtype="int32", name="atype"), # extended_atype
394+
InputSpec([-1, -1, -1], dtype="int32", name="nlist"), # nlist
395+
InputSpec([-1, -1], dtype="int64", name="mapping"), # mapping
395396
None, # fparam
396397
None, # aparam
397-
True, # do_atomic_virial
398-
None, # comm_dict
398+
# InputSpec([], dtype="bool", name="do_atomic_virial"), # do_atomic_virial
399+
False, # do_atomic_virial
400+
(
401+
InputSpec([-1], "int64", name="send_list"),
402+
InputSpec([-1], "int32", name="send_proc"),
403+
InputSpec([-1], "int32", name="recv_proc"),
404+
InputSpec([-1], "int32", name="send_num"),
405+
InputSpec([-1], "int32", name="recv_num"),
406+
InputSpec([-1], "int64", name="communicator"),
407+
# InputSpec([1], "int64", name="has_spin"),
408+
), # comm_dict
399409
],
400410
full_graph=True,
401411
)

deepmd/pd/model/atomic_model/base_atomic_model.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22

3-
import copy
43
import logging
54
from typing import (
65
Callable,
6+
NoReturn,
77
Optional,
88
Union,
99
)
@@ -79,16 +79,18 @@ def __init__(
7979
pair_exclude_types: list[tuple[int, int]] = [],
8080
rcond: Optional[float] = None,
8181
preset_out_bias: Optional[dict[str, np.ndarray]] = None,
82-
):
82+
data_stat_protect: float = 1e-2,
83+
) -> None:
8384
paddle.nn.Layer.__init__(self)
8485
BaseAtomicModel_.__init__(self)
8586
self.type_map = type_map
8687
self.reinit_atom_exclude(atom_exclude_types)
8788
self.reinit_pair_exclude(pair_exclude_types)
8889
self.rcond = rcond
8990
self.preset_out_bias = preset_out_bias
91+
self.data_stat_protect = data_stat_protect
9092

91-
def init_out_stat(self):
93+
def init_out_stat(self) -> None:
9294
"""Initialize the output bias."""
9395
ntypes = self.get_ntypes()
9496
self.bias_keys: list[str] = list(self.fitting_output_def().keys())
@@ -104,7 +106,7 @@ def init_out_stat(self):
104106
def set_out_bias(self, out_bias: paddle.Tensor) -> None:
105107
self.out_bias = out_bias
106108

107-
def __setitem__(self, key, value):
109+
def __setitem__(self, key, value) -> None:
108110
if key in ["out_bias"]:
109111
self.out_bias = value
110112
elif key in ["out_std"]:
@@ -124,10 +126,20 @@ def get_type_map(self) -> list[str]:
124126
"""Get the type map."""
125127
return self.type_map
126128

129+
def get_compute_stats_distinguish_types(self) -> bool:
130+
"""Get whether the fitting net computes stats which are not distinguished between different types of atoms."""
131+
return True
132+
133+
def get_intensive(self) -> bool:
134+
"""Whether the fitting property is intensive."""
135+
return False
136+
127137
def reinit_atom_exclude(
128138
self,
129-
exclude_types: list[int] = [],
130-
):
139+
exclude_types: Optional[list[int]] = None,
140+
) -> None:
141+
if exclude_types is None:
142+
exclude_types = []
131143
self.atom_exclude_types = exclude_types
132144
if exclude_types == []:
133145
self.atom_excl = None
@@ -137,7 +149,7 @@ def reinit_atom_exclude(
137149
def reinit_pair_exclude(
138150
self,
139151
exclude_types: list[tuple[int, int]] = [],
140-
):
152+
) -> None:
141153
self.pair_exclude_types = exclude_types
142154
if exclude_types == []:
143155
self.pair_excl = None
@@ -191,7 +203,7 @@ def forward_common_atomic(
191203
mapping: Optional[paddle.Tensor] = None,
192204
fparam: Optional[paddle.Tensor] = None,
193205
aparam: Optional[paddle.Tensor] = None,
194-
comm_dict: Optional[dict[str, paddle.Tensor]] = None,
206+
comm_dict: Optional[list[paddle.Tensor]] = None,
195207
) -> dict[str, paddle.Tensor]:
196208
"""Common interface for atomic inference.
197209
@@ -232,7 +244,7 @@ def forward_common_atomic(
232244
if self.pair_excl is not None:
233245
pair_mask = self.pair_excl(nlist, extended_atype)
234246
# exclude neighbors in the nlist
235-
nlist = paddle.where(pair_mask == 1, nlist, -1)
247+
nlist = paddle.where(pair_mask == 1, nlist, paddle.full_like(nlist, -1))
236248

237249
ext_atom_mask = self.make_atom_mask(extended_atype)
238250
ret_dict = self.forward_atomic(
@@ -274,7 +286,7 @@ def forward(
274286
mapping: Optional[paddle.Tensor] = None,
275287
fparam: Optional[paddle.Tensor] = None,
276288
aparam: Optional[paddle.Tensor] = None,
277-
comm_dict: Optional[dict[str, paddle.Tensor]] = None,
289+
comm_dict: Optional[list[paddle.Tensor]] = None,
278290
) -> dict[str, paddle.Tensor]:
279291
return self.forward_common_atomic(
280292
extended_coord,
@@ -332,7 +344,7 @@ def serialize(self) -> dict:
332344

333345
@classmethod
334346
def deserialize(cls, data: dict) -> "BaseAtomicModel":
335-
data = copy.deepcopy(data)
347+
data = data.copy()
336348
variables = data.pop("@variables", None)
337349
variables = (
338350
{"out_bias": None, "out_std": None} if variables is None else variables
@@ -354,7 +366,7 @@ def compute_or_load_stat(
354366
self,
355367
merged: Union[Callable[[], list[dict]], list[dict]],
356368
stat_file_path: Optional[DPPath] = None,
357-
):
369+
) -> NoReturn:
358370
"""
359371
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
360372
@@ -377,7 +389,7 @@ def compute_or_load_out_stat(
377389
self,
378390
merged: Union[Callable[[], list[dict]], list[dict]],
379391
stat_file_path: Optional[DPPath] = None,
380-
):
392+
) -> None:
381393
"""
382394
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
383395
@@ -457,7 +469,6 @@ def change_out_bias(
457469
model_forward=self._get_forward_wrapper_func(),
458470
rcond=self.rcond,
459471
preset_bias=self.preset_out_bias,
460-
atomic_output=self.atomic_output_def(),
461472
)
462473
self._store_out_stat(delta_bias, out_std, add=True)
463474
elif bias_adjust_mode == "set-by-statistic":
@@ -468,7 +479,8 @@ def change_out_bias(
468479
stat_file_path=stat_file_path,
469480
rcond=self.rcond,
470481
preset_bias=self.preset_out_bias,
471-
atomic_output=self.atomic_output_def(),
482+
stats_distinguish_types=self.get_compute_stats_distinguish_types(),
483+
intensive=self.get_intensive(),
472484
)
473485
self._store_out_stat(bias_out, std_out)
474486
else:
@@ -544,7 +556,7 @@ def _store_out_stat(
544556
out_bias: dict[str, paddle.Tensor],
545557
out_std: dict[str, paddle.Tensor],
546558
add: bool = False,
547-
):
559+
) -> None:
548560
ntypes = self.get_ntypes()
549561
out_bias_data = paddle.clone(self.out_bias)
550562
out_std_data = paddle.clone(self.out_std)

deepmd/pd/model/atomic_model/dp_atomic_model.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def __init__(
6262
self.fitting_net = fitting
6363
super().init_out_stat()
6464
self.enable_eval_descriptor_hook = False
65+
self.enable_eval_fitting_last_layer_hook = False
6566
self.eval_descriptor_list = []
67+
self.eval_fitting_last_layer_list = []
6668

6769
# register 'type_map' as buffer
6870
def _string_to_array(s: str) -> list[int]:
@@ -112,16 +114,29 @@ def _string_to_array(s: str) -> list[int]:
112114
self.buffer_aparam_nall.name = "buffer_aparam_nall"
113115

114116
eval_descriptor_list: list[paddle.Tensor]
117+
eval_fitting_last_layer_list: list[paddle.Tensor]
115118

116119
def set_eval_descriptor_hook(self, enable: bool) -> None:
117120
"""Set the hook for evaluating descriptor and clear the cache for descriptor list."""
118121
self.enable_eval_descriptor_hook = enable
119-
self.eval_descriptor_list = []
122+
# = [] does not work; See #4533
123+
self.eval_descriptor_list.clear()
120124

121125
def eval_descriptor(self) -> paddle.Tensor:
122126
"""Evaluate the descriptor."""
123127
return paddle.concat(self.eval_descriptor_list)
124128

129+
def set_eval_fitting_last_layer_hook(self, enable: bool) -> None:
130+
"""Set the hook for evaluating fitting last layer output and clear the cache for fitting last layer output list."""
131+
self.enable_eval_fitting_last_layer_hook = enable
132+
self.fitting_net.set_return_middle_output(enable)
133+
# = [] does not work; See #4533
134+
self.eval_fitting_last_layer_list.clear()
135+
136+
def eval_fitting_last_layer(self) -> paddle.Tensor:
137+
"""Evaluate the fitting last layer output."""
138+
return paddle.concat(self.eval_fitting_last_layer_list)
139+
125140
def fitting_output_def(self) -> FittingOutputDef:
126141
"""Get the output def of the fitting net."""
127142
return (
@@ -250,7 +265,7 @@ def forward_atomic(
250265
mapping: Optional[paddle.Tensor] = None,
251266
fparam: Optional[paddle.Tensor] = None,
252267
aparam: Optional[paddle.Tensor] = None,
253-
comm_dict: Optional[dict[str, paddle.Tensor]] = None,
268+
comm_dict: Optional[list[paddle.Tensor]] = None,
254269
) -> dict[str, paddle.Tensor]:
255270
"""Return atomic prediction.
256271
@@ -288,7 +303,7 @@ def forward_atomic(
288303
)
289304
assert descriptor is not None
290305
if self.enable_eval_descriptor_hook:
291-
self.eval_descriptor_list.append(descriptor)
306+
self.eval_descriptor_list.append(descriptor.detach())
292307
# energy, force
293308
fit_ret = self.fitting_net(
294309
descriptor,
@@ -299,6 +314,13 @@ def forward_atomic(
299314
fparam=fparam,
300315
aparam=aparam,
301316
)
317+
if self.enable_eval_fitting_last_layer_hook:
318+
assert "middle_output" in fit_ret, (
319+
"eval_fitting_last_layer not supported for this fitting net!"
320+
)
321+
self.eval_fitting_last_layer_list.append(
322+
fit_ret.pop("middle_output").detach()
323+
)
302324
return fit_ret
303325

304326
def get_out_bias(self) -> paddle.Tensor:
@@ -343,6 +365,9 @@ def wrapped_sampler():
343365
return sampled
344366

345367
self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
368+
self.fitting_net.compute_input_stats(
369+
wrapped_sampler, protection=self.data_stat_protect
370+
)
346371
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
347372

348373
def get_dim_fparam(self) -> int:

deepmd/pd/model/descriptor/dpa1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ def forward(
596596
extended_atype: paddle.Tensor,
597597
nlist: paddle.Tensor,
598598
mapping: Optional[paddle.Tensor] = None,
599-
comm_dict: Optional[dict[str, paddle.Tensor]] = None,
599+
comm_dict: Optional[list[paddle.Tensor]] = None,
600600
):
601601
"""Compute the descriptor.
602602

deepmd/pd/model/descriptor/dpa2.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ def forward(
712712
extended_atype: paddle.Tensor,
713713
nlist: paddle.Tensor,
714714
mapping: Optional[paddle.Tensor] = None,
715-
comm_dict: Optional[dict[str, paddle.Tensor]] = None,
715+
comm_dict: Optional[list[paddle.Tensor]] = None,
716716
):
717717
"""Compute the descriptor.
718718
@@ -747,7 +747,7 @@ def forward(
747747
748748
"""
749749
# cast the input to internal precsion
750-
extended_coord = extended_coord.to(dtype=self.prec)
750+
extended_coord = extended_coord.astype(dtype=self.prec)
751751

752752
use_three_body = self.use_three_body
753753
nframes, nloc, nnei = nlist.shape
@@ -798,14 +798,15 @@ def forward(
798798
assert self.tebd_transform is not None
799799
g1 = g1 + self.tebd_transform(g1_inp)
800800
# mapping g1
801-
if comm_dict is None:
802-
assert mapping is not None
801+
if comm_dict is None or len(comm_dict) == 0:
802+
if paddle.in_dynamic_mode():
803+
assert mapping is not None
803804
mapping_ext = (
804805
mapping.reshape([nframes, nall])
805806
.unsqueeze(-1)
806807
.expand([-1, -1, g1.shape[-1]])
807808
)
808-
g1_ext = paddle.take_along_axis(g1, mapping_ext, 1)
809+
g1_ext = paddle.take_along_axis(g1, mapping_ext, 1, broadcast=False)
809810
g1 = g1_ext
810811
# repformer
811812
g1, g2, h2, rot_mat, sw = self.repformers(
@@ -823,11 +824,11 @@ def forward(
823824
if self.concat_output_tebd:
824825
g1 = paddle.concat([g1, g1_inp], axis=-1)
825826
return (
826-
g1.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
827-
rot_mat.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
828-
g2.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
829-
h2.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
830-
sw.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
827+
g1.astype(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
828+
rot_mat.astype(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
829+
g2.astype(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
830+
h2.astype(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
831+
sw.astype(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
831832
)
832833

833834
@classmethod

deepmd/pd/model/descriptor/dpa3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def forward(
457457
extended_atype: paddle.Tensor,
458458
nlist: paddle.Tensor,
459459
mapping: Optional[paddle.Tensor] = None,
460-
comm_dict: Optional[dict[str, paddle.Tensor]] = None,
460+
comm_dict: Optional[list[paddle.Tensor]] = None,
461461
):
462462
"""Compute the descriptor.
463463

0 commit comments

Comments
 (0)