Merged

36 commits
d2a46b7
update for dpa3_dynamic debug code
HydrogenSulfate Jul 20, 2025
8fe0df6
support DPA2/DPA3 inference
HydrogenSulfate Jul 22, 2025
c510815
fix code for view ops
HydrogenSulfate Jul 22, 2025
3effbdb
refine code
HydrogenSulfate Jul 22, 2025
38fb926
refine more code
HydrogenSulfate Jul 22, 2025
4d3f99c
fix typing
HydrogenSulfate Jul 22, 2025
d6a3b1e
update code
HydrogenSulfate Jul 22, 2025
55aeab3
update build_cc_pd.sh
HydrogenSulfate Jul 22, 2025
d967dfc
update file
HydrogenSulfate Jul 23, 2025
94e5b7b
update code
HydrogenSulfate Jul 23, 2025
a1ebcde
fix fetch name
HydrogenSulfate Jul 23, 2025
443230a
fix deeppotpd.cc
HydrogenSulfate Jul 23, 2025
2bd6a3f
Merge branch 'support_comm_2' of https://github.com/HydrogenSulfate/d…
HydrogenSulfate Jul 23, 2025
ef21ec1
update code
HydrogenSulfate Jul 23, 2025
8c106c7
update code for successfully run dpa2 C++ inference
HydrogenSulfate Aug 5, 2025
76d3e8e
update setup.py
HydrogenSulfate Aug 5, 2025
a4eeac4
update build_cc_pd.sh and fix
HydrogenSulfate Aug 5, 2025
038f4fe
fix yaml
HydrogenSulfate Aug 5, 2025
552eff8
Merge remote-tracking branch 'upstream/devel' into support_dpa2/3_inf…
HydrogenSulfate Aug 5, 2025
f36ef86
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2025
9cd2c4a
restore EAGER_COMP_OP_BLACK_LIST
HydrogenSulfate Aug 5, 2025
b0eb30e
clean code
HydrogenSulfate Aug 5, 2025
d20ef07
Merge branch 'support_dpa2/3_infernce' of https://github.com/Hydrogen…
HydrogenSulfate Aug 5, 2025
340596f
fix CMAKE
HydrogenSulfate Aug 5, 2025
75c56b1
fix review
HydrogenSulfate Aug 5, 2025
a02d164
fix
HydrogenSulfate Aug 5, 2025
1c03d45
update paddle inference code
HydrogenSulfate Aug 6, 2025
6f3c323
fix
HydrogenSulfate Aug 6, 2025
6f28b76
update UT files
HydrogenSulfate Aug 6, 2025
cb6d855
fix
HydrogenSulfate Aug 6, 2025
e698253
fix
HydrogenSulfate Aug 6, 2025
c2e337f
fix serialization
HydrogenSulfate Aug 6, 2025
89743ff
fix
HydrogenSulfate Aug 7, 2025
6c4309b
fix
HydrogenSulfate Aug 7, 2025
b929ae4
restore extended key to key
HydrogenSulfate Aug 7, 2025
6417f8c
Merge branch 'devel' into support_dpa2/3_infernce
HydrogenSulfate Aug 7, 2025
35 changes: 35 additions & 0 deletions deepmd/pd/cxx_op.py
@@ -0,0 +1,35 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import importlib.util
from types import (
    ModuleType,
)
from typing import (
    Optional,
)


def load_library(module_name: str) -> tuple[bool, Optional[ModuleType]]:
"""Load OP library and return the module if success.

Parameters
----------
module_name : str
Name of the module

Returns
-------
    bool
        Whether the library was loaded successfully.
    Optional[ModuleType]
        The loaded custom operator module, or None if loading failed.
"""
if importlib.util.find_spec(module_name) is not None:
module = importlib.import_module(module_name)
return True, module

return False, None


ENABLE_CUSTOMIZED_OP, paddle_ops_deepmd = load_library("deepmd_op_pd")

__all__ = [
"ENABLE_CUSTOMIZED_OP",
"paddle_ops_deepmd",
]
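Note: a minimal consumption sketch of this loader; the guard pattern below is an assumption about intended usage, not code from this PR:

import paddle  # ensure the framework is initialized before loading custom ops

from deepmd.pd.cxx_op import ENABLE_CUSTOMIZED_OP, paddle_ops_deepmd

if ENABLE_CUSTOMIZED_OP:
    # The extension module exposes the custom ops registered by deepmd_op_pd.
    print("custom ops available:", dir(paddle_ops_deepmd))
else:
    # Fall back gracefully when the compiled extension is absent.
    print("deepmd_op_pd not installed; running without custom ops")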
28 changes: 21 additions & 7 deletions deepmd/pd/entrypoints/main.py
@@ -373,7 +373,8 @@ def freeze(
InputSpec([1, 9], dtype="float64", name="box"), # box
None, # fparam
None, # aparam
True, # do_atomic_virial
# InputSpec([], dtype="bool", name="do_atomic_virial"), # do_atomic_virial
False, # do_atomic_virial
],
full_graph=True,
)
@@ -388,14 +389,27 @@ def freeze(
model.forward_lower = paddle.jit.to_static(
model.forward_lower,
input_spec=[
InputSpec([1, -1, 3], dtype="float64", name="coord"), # extended_coord
InputSpec([1, -1], dtype="int32", name="atype"), # extended_atype
InputSpec([1, -1, -1], dtype="int32", name="nlist"), # nlist
InputSpec([1, -1], dtype="int64", name="mapping"), # mapping
InputSpec(
[-1, -1, 3], dtype="float64", name="extended_coord"
), # extended_coord
InputSpec(
[-1, -1], dtype="int32", name="extended_atype"
), # extended_atype
InputSpec([-1, -1, -1], dtype="int32", name="nlist"), # nlist
InputSpec([-1, -1], dtype="int64", name="mapping"), # mapping
None, # fparam
None, # aparam
True, # do_atomic_virial
None, # comm_dict
# InputSpec([], dtype="bool", name="do_atomic_virial"), # do_atomic_virial
False, # do_atomic_virial
(
InputSpec([-1], "int64", name="send_list"),
InputSpec([-1], "int32", name="send_proc"),
InputSpec([-1], "int32", name="recv_proc"),
InputSpec([-1], "int32", name="send_num"),
InputSpec([-1], "int32", name="recv_num"),
InputSpec([-1], "int64", name="communicator"),
# InputSpec([1], "int64", name="has_spin"),
), # comm_dict
],
full_graph=True,
)
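Note: replacing the fixed [1, ...] shapes with [-1, ...] InputSpecs lets a single exported program serve any frame/atom count, and the comm_dict tuple appears to hand the C++ side named tensors for inter-process communication. A minimal sketch of the same dynamic-shape export pattern; the toy layer below is illustrative, not from this PR:

import paddle
from paddle.static import InputSpec


class Toy(paddle.nn.Layer):
    def forward(self, coord):
        return coord.sum(axis=-1)


# -1 marks a dimension as dynamic, so the traced program accepts any
# number of frames/atoms instead of being specialized to one shape.
static_fn = paddle.jit.to_static(
    Toy().forward,
    input_spec=[InputSpec([-1, -1, 3], dtype="float64", name="coord")],
    full_graph=True,
)
out = static_fn(paddle.rand([2, 5, 3], dtype="float64"))
print(out.shape)  # [2, 5]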
44 changes: 28 additions & 16 deletions deepmd/pd/model/atomic_model/base_atomic_model.py
@@ -1,9 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import copy
import logging
from typing import (
Callable,
NoReturn,
Optional,
Union,
)
@@ -79,16 +79,18 @@ def __init__(
pair_exclude_types: list[tuple[int, int]] = [],
rcond: Optional[float] = None,
preset_out_bias: Optional[dict[str, np.ndarray]] = None,
):
data_stat_protect: float = 1e-2,
) -> None:
paddle.nn.Layer.__init__(self)
BaseAtomicModel_.__init__(self)
self.type_map = type_map
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)
self.rcond = rcond
self.preset_out_bias = preset_out_bias
self.data_stat_protect = data_stat_protect

def init_out_stat(self):
def init_out_stat(self) -> None:
"""Initialize the output bias."""
ntypes = self.get_ntypes()
self.bias_keys: list[str] = list(self.fitting_output_def().keys())
@@ -104,7 +106,7 @@ def init_out_stat(self):
def set_out_bias(self, out_bias: paddle.Tensor) -> None:
self.out_bias = out_bias

def __setitem__(self, key, value):
def __setitem__(self, key, value) -> None:
if key in ["out_bias"]:
self.out_bias = value
elif key in ["out_std"]:
@@ -124,10 +126,20 @@ def get_type_map(self) -> list[str]:
"""Get the type map."""
return self.type_map

def get_compute_stats_distinguish_types(self) -> bool:
"""Get whether the fitting net computes stats which are not distinguished between different types of atoms."""
return True

def get_intensive(self) -> bool:
"""Whether the fitting property is intensive."""
return False

def reinit_atom_exclude(
self,
exclude_types: list[int] = [],
):
exclude_types: Optional[list[int]] = None,
) -> None:
if exclude_types is None:
exclude_types = []
self.atom_exclude_types = exclude_types
if exclude_types == []:
self.atom_excl = None
@@ -137,7 +149,7 @@ def reinit_atom_exclude(
def reinit_pair_exclude(
self,
exclude_types: list[tuple[int, int]] = [],
):
) -> None:
self.pair_exclude_types = exclude_types
if exclude_types == []:
self.pair_excl = None
@@ -191,7 +203,7 @@ def forward_common_atomic(
mapping: Optional[paddle.Tensor] = None,
fparam: Optional[paddle.Tensor] = None,
aparam: Optional[paddle.Tensor] = None,
comm_dict: Optional[dict[str, paddle.Tensor]] = None,
comm_dict: Optional[list[paddle.Tensor]] = None,
) -> dict[str, paddle.Tensor]:
"""Common interface for atomic inference.

@@ -232,7 +244,7 @@
if self.pair_excl is not None:
pair_mask = self.pair_excl(nlist, extended_atype)
# exclude neighbors in the nlist
nlist = paddle.where(pair_mask == 1, nlist, -1)
nlist = paddle.where(pair_mask == 1, nlist, paddle.full_like(nlist, -1))

ext_atom_mask = self.make_atom_mask(extended_atype)
ret_dict = self.forward_atomic(
@@ -274,7 +286,7 @@ def forward(
mapping: Optional[paddle.Tensor] = None,
fparam: Optional[paddle.Tensor] = None,
aparam: Optional[paddle.Tensor] = None,
comm_dict: Optional[dict[str, paddle.Tensor]] = None,
comm_dict: Optional[list[paddle.Tensor]] = None,
) -> dict[str, paddle.Tensor]:
return self.forward_common_atomic(
extended_coord,
Expand Down Expand Up @@ -332,7 +344,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data: dict) -> "BaseAtomicModel":
data = copy.deepcopy(data)
data = data.copy()
variables = data.pop("@variables", None)
variables = (
{"out_bias": None, "out_std": None} if variables is None else variables
@@ -354,7 +366,7 @@ def compute_or_load_stat(
self,
merged: Union[Callable[[], list[dict]], list[dict]],
stat_file_path: Optional[DPPath] = None,
):
) -> NoReturn:
"""
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.

@@ -377,7 +389,7 @@ def compute_or_load_out_stat(
self,
merged: Union[Callable[[], list[dict]], list[dict]],
stat_file_path: Optional[DPPath] = None,
):
) -> None:
"""
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.

@@ -457,7 +469,6 @@ def change_out_bias(
model_forward=self._get_forward_wrapper_func(),
rcond=self.rcond,
preset_bias=self.preset_out_bias,
atomic_output=self.atomic_output_def(),
)
self._store_out_stat(delta_bias, out_std, add=True)
elif bias_adjust_mode == "set-by-statistic":
@@ -468,7 +479,8 @@
stat_file_path=stat_file_path,
rcond=self.rcond,
preset_bias=self.preset_out_bias,
atomic_output=self.atomic_output_def(),
stats_distinguish_types=self.get_compute_stats_distinguish_types(),
intensive=self.get_intensive(),
)
self._store_out_stat(bias_out, std_out)
else:
@@ -544,7 +556,7 @@ def _store_out_stat(
out_bias: dict[str, paddle.Tensor],
out_std: dict[str, paddle.Tensor],
add: bool = False,
):
) -> None:
ntypes = self.get_ntypes()
out_bias_data = paddle.clone(self.out_bias)
out_std_data = paddle.clone(self.out_std)
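Note: the switch from a Python scalar to paddle.full_like(nlist, -1) gives paddle.where two tensor operands with matching shape and dtype, presumably to keep the op exportable under static graph. A standalone sketch of the masking step with toy data:

import paddle

nlist = paddle.to_tensor([[0, 2, 1], [1, -1, 0]], dtype="int32")
pair_mask = paddle.to_tensor([[1, 0, 1], [1, 1, 0]], dtype="int32")
# Excluded neighbors are overwritten with -1 (the "no neighbor" marker);
# full_like supplies a tensor operand so shapes and dtypes stay explicit.
masked = paddle.where(pair_mask == 1, nlist, paddle.full_like(nlist, -1))
print(masked.numpy())  # [[ 0 -1  1] [ 1 -1 -1]]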
31 changes: 28 additions & 3 deletions deepmd/pd/model/atomic_model/dp_atomic_model.py
@@ -62,7 +62,9 @@ def __init__(
self.fitting_net = fitting
super().init_out_stat()
self.enable_eval_descriptor_hook = False
self.enable_eval_fitting_last_layer_hook = False
self.eval_descriptor_list = []
self.eval_fitting_last_layer_list = []

# register 'type_map' as buffer
def _string_to_array(s: str) -> list[int]:
@@ -112,16 +114,29 @@ def _string_to_array(s: str) -> list[int]:
self.buffer_aparam_nall.name = "buffer_aparam_nall"

eval_descriptor_list: list[paddle.Tensor]
eval_fitting_last_layer_list: list[paddle.Tensor]

def set_eval_descriptor_hook(self, enable: bool) -> None:
"""Set the hook for evaluating descriptor and clear the cache for descriptor list."""
self.enable_eval_descriptor_hook = enable
self.eval_descriptor_list = []
# = [] does not work; See #4533
self.eval_descriptor_list.clear()

def eval_descriptor(self) -> paddle.Tensor:
"""Evaluate the descriptor."""
return paddle.concat(self.eval_descriptor_list)

def set_eval_fitting_last_layer_hook(self, enable: bool) -> None:
"""Set the hook for evaluating fitting last layer output and clear the cache for fitting last layer output list."""
self.enable_eval_fitting_last_layer_hook = enable
self.fitting_net.set_return_middle_output(enable)
# = [] does not work; See #4533
self.eval_fitting_last_layer_list.clear()

def eval_fitting_last_layer(self) -> paddle.Tensor:
"""Evaluate the fitting last layer output."""
return paddle.concat(self.eval_fitting_last_layer_list)

def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
return (
@@ -250,7 +265,7 @@ def forward_atomic(
mapping: Optional[paddle.Tensor] = None,
fparam: Optional[paddle.Tensor] = None,
aparam: Optional[paddle.Tensor] = None,
comm_dict: Optional[dict[str, paddle.Tensor]] = None,
comm_dict: Optional[list[paddle.Tensor]] = None,
) -> dict[str, paddle.Tensor]:
"""Return atomic prediction.

@@ -288,7 +303,7 @@
)
assert descriptor is not None
if self.enable_eval_descriptor_hook:
self.eval_descriptor_list.append(descriptor)
self.eval_descriptor_list.append(descriptor.detach())
# energy, force
fit_ret = self.fitting_net(
descriptor,
Expand All @@ -299,6 +314,13 @@ def forward_atomic(
fparam=fparam,
aparam=aparam,
)
if self.enable_eval_fitting_last_layer_hook:
assert "middle_output" in fit_ret, (
"eval_fitting_last_layer not supported for this fitting net!"
)
self.eval_fitting_last_layer_list.append(
fit_ret.pop("middle_output").detach()
)
return fit_ret

def get_out_bias(self) -> paddle.Tensor:
@@ -343,6 +365,9 @@ def wrapped_sampler():
return sampled

self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
self.fitting_net.compute_input_stats(
wrapped_sampler, protection=self.data_stat_protect
)
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

def get_dim_fparam(self) -> int:
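Note: an illustrative usage sketch of the new fitting-last-layer hook; `model`, `coords`, `atypes`, and `nlist` are assumed to be a constructed DPAtomicModel (with a fitting net that returns "middle_output") and its prepared inputs:

# Assumed names, for illustration only.
model.set_eval_fitting_last_layer_hook(True)   # start caching per-call outputs
_ = model.forward_atomic(coords, atypes, nlist)
last_layer = model.eval_fitting_last_layer()   # concat of cached tensors
model.set_eval_fitting_last_layer_hook(False)  # disable and clear the cache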
2 changes: 1 addition & 1 deletion deepmd/pd/model/descriptor/dpa1.py
@@ -596,7 +596,7 @@ def forward(
extended_atype: paddle.Tensor,
nlist: paddle.Tensor,
mapping: Optional[paddle.Tensor] = None,
comm_dict: Optional[dict[str, paddle.Tensor]] = None,
comm_dict: Optional[list[paddle.Tensor]] = None,
):
"""Compute the descriptor.

Expand Down
21 changes: 11 additions & 10 deletions deepmd/pd/model/descriptor/dpa2.py
@@ -712,7 +712,7 @@ def forward(
extended_atype: paddle.Tensor,
nlist: paddle.Tensor,
mapping: Optional[paddle.Tensor] = None,
comm_dict: Optional[dict[str, paddle.Tensor]] = None,
comm_dict: Optional[list[paddle.Tensor]] = None,
):
"""Compute the descriptor.

@@ -747,7 +747,7 @@

"""
        # cast the input to internal precision
extended_coord = extended_coord.to(dtype=self.prec)
extended_coord = extended_coord.astype(dtype=self.prec)

use_three_body = self.use_three_body
nframes, nloc, nnei = nlist.shape
@@ -798,14 +798,15 @@
assert self.tebd_transform is not None
g1 = g1 + self.tebd_transform(g1_inp)
# mapping g1
if comm_dict is None:
assert mapping is not None
if comm_dict is None or len(comm_dict) == 0:
if paddle.in_dynamic_mode():
assert mapping is not None
mapping_ext = (
mapping.reshape([nframes, nall])
.unsqueeze(-1)
.expand([-1, -1, g1.shape[-1]])
)
g1_ext = paddle.take_along_axis(g1, mapping_ext, 1)
g1_ext = paddle.take_along_axis(g1, mapping_ext, 1, broadcast=False)
g1 = g1_ext
# repformer
g1, g2, h2, rot_mat, sw = self.repformers(
Expand All @@ -823,11 +824,11 @@ def forward(
if self.concat_output_tebd:
g1 = paddle.concat([g1, g1_inp], axis=-1)
return (
g1.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
g2.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
h2.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
sw.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
g1.astype(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
rot_mat.astype(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
g2.astype(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
h2.astype(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
sw.astype(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
)

@classmethod
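Note: the mapping gather above copies per-local-atom features out to all extended (ghost-including) atoms; broadcast=False pins the index shape so the gather stays shape-stable under export. A toy sketch of the same gather with made-up sizes:

import paddle

nframes, nloc, nall, feat = 1, 2, 4, 3
g1 = paddle.arange(nframes * nloc * feat, dtype="float32").reshape(
    [nframes, nloc, feat]
)
# mapping[f, i] is the local atom that extended atom i is a copy of.
mapping = paddle.to_tensor([[0, 1, 0, 1]], dtype="int64")
mapping_ext = (
    mapping.reshape([nframes, nall]).unsqueeze(-1).expand([-1, -1, feat])
)
g1_ext = paddle.take_along_axis(g1, mapping_ext, 1, broadcast=False)
print(g1_ext.shape)  # [1, 4, 3]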
2 changes: 1 addition & 1 deletion deepmd/pd/model/descriptor/dpa3.py
@@ -457,7 +457,7 @@ def forward(
extended_atype: paddle.Tensor,
nlist: paddle.Tensor,
mapping: Optional[paddle.Tensor] = None,
comm_dict: Optional[dict[str, paddle.Tensor]] = None,
comm_dict: Optional[list[paddle.Tensor]] = None,
):
"""Compute the descriptor.

Expand Down