
Commit 63cc460

Authored by HydrogenSulfate, iProzd, Chengqian-Zhang, pre-commit-ci[bot], and ChiahsinChu
pd: support dpa3 with paddle backend (#4701)
support dpa3 with paddle backend (eager mode)

### 1. Training curve

![pt_vs_pd](https://github.com/user-attachments/assets/22f4681e-a464-41cb-9b1b-16ed20112563)

### 2. Accuracy

<details>
<summary>torch</summary>

![image](https://github.com/user-attachments/assets/c737ed30-0108-43f3-9d0e-7ae289db1498)

</details>

<details>
<summary>paddle (slightly better than torch)</summary>

![image](https://github.com/user-attachments/assets/5e75abb8-f3be-46ce-adc1-eaf453cedcba)

</details>

### 3. Main modifications in this PR

1. Added the DPA-3 code and related modules for the Paddle backend.
2. Added the `EnergyHessianStdLoss` module for the Paddle backend.
3. Discovered that Paddle's `ParameterList` does not support assigning tensors with the equals sign; support for this was added upstream at <PaddlePaddle/Paddle#72190>. For version compatibility, however, deepmd still uses `paddle.assign` for these assignments (a short sketch of the workaround follows the commit message).
4. Fixed an issue in `env_mat_stat.py` where the return type was `Tensor` instead of `float`.
5. `SiLUT` used numpy-family APIs that do not accept `paddle.Tensor`, so they were replaced with Paddle's native APIs. Additionally, to temporarily bypass issues with dynamic-to-static control flow, the if-else branch in `SiLUT.forward` was reduced to a single branch.

## Summary by CodeRabbit

- **New Features**
  - Introduced a new descriptor, DPA3, for advanced molecular simulations, including its integration and public availability.
  - Added support for a new graph-based neural network layer and descriptor block for RepFlow calculations.
  - Enabled Hessian loss computation for enhanced training capabilities.
  - Added a new learning rate utility.
- **Bug Fixes**
  - Improved tensor shape handling and assignments for better compatibility and stability.
- **Tests**
  - Added comprehensive tests for the new DPA3 descriptor, including consistency, JIT, and multitask scenarios.
  - Expanded test coverage for model permutation and smoothness with DPA3.
  - Enhanced tests for DPA2 with CINN compiler support.
- **Refactor**
  - Standardized tensor shape definitions and updated method signatures for improved clarity and type safety.
- **Chores**
  - Updated public interfaces to include new features and descriptors.

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com>
Co-authored-by: Duo <50307526+iProzd@users.noreply.github.com>
Co-authored-by: root <2000011006@stu.pku.edu.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Chenqqian Zhang <100290172+Chengqian-Zhang@users.noreply.github.com>
Co-authored-by: Jia-Xin Zhu <53895049+ChiahsinChu@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com>
Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
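Regarding item 3 above: a minimal sketch of the `paddle.assign` workaround, assuming a Paddle version without the upstream `ParameterList` fix (the shapes and names here are hypothetical):

```python
import paddle
from paddle import nn

# Hypothetical ParameterList holding two small parameters.
params = nn.ParameterList(
    [paddle.create_parameter(shape=[3], dtype="float32") for _ in range(2)]
)

new_value = paddle.ones([3], dtype="float32")
# params[0] = new_value  # plain assignment: unsupported before PaddlePaddle/Paddle#72190
paddle.assign(new_value, output=params[0])  # copy the value in place instead
```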
1 parent a1777b7 · commit 63cc460

29 files changed: +2814 −147 lines

deepmd/pd/loss/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -1,12 +1,14 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from .ener import (
+    EnergyHessianStdLoss,
     EnergyStdLoss,
 )
 from .loss import (
     TaskLoss,
 )

 __all__ = [
+    "EnergyHessianStdLoss",
     "EnergyStdLoss",
     "TaskLoss",
 ]
```

deepmd/pd/loss/ener.py

Lines changed: 86 additions & 15 deletions

```diff
@@ -56,7 +56,7 @@ def __init__(
         use_huber=False,
         huber_delta=0.01,
         **kwargs,
-    ):
+    ) -> None:
         r"""Construct a layer to compute loss on energy, force and virial.

         Parameters
@@ -287,9 +287,9 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
                     rmse_f.detach(), find_force
                 )
             else:
-                l1_force_loss = F.l1_loss(force_label, force_pred, reduction="none")
+                l1_force_loss = F.l1_loss(force_label, force_pred, reduction="mean")
                 more_loss["mae_f"] = self.display_if_exist(
-                    l1_force_loss.mean().detach(), find_force
+                    l1_force_loss.detach(), find_force
                 )
                 l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
             loss += (pref_f * l1_force_loss).to(GLOBAL_PD_FLOAT_PRECISION)
```
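A quick sanity check of the `reduction` change above, as a sketch with hypothetical shapes: a `"mean"`-reduced L1 loss is the same scalar as the mean of the unreduced (`"none"`) elementwise loss, so the reported `mae_f` is unchanged by the rewrite:

```python
import paddle
import paddle.nn.functional as F

pred = paddle.randn([2, 5, 3])   # (nframes, natoms, 3); sizes are hypothetical
label = paddle.randn([2, 5, 3])

# reduction="none" keeps the elementwise loss; its mean equals reduction="mean".
mae_none = F.l1_loss(label, pred, reduction="none").mean()
mae_mean = F.l1_loss(label, pred, reduction="mean")
assert paddle.allclose(mae_none, mae_mean)
```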
```diff
@@ -324,20 +324,19 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
                 drdq_reshape = drdq.reshape(
                     [-1, natoms * 3, self.numb_generalized_coord]
                 )
+                gen_force_label = paddle.einsum(
+                    "bij,bi->bj", drdq_reshape, force_label_reshape_nframes
+                )
+                # gen_force_label = (
+                #     drdq_reshape * force_label_reshape_nframes.unsqueeze(-1)
+                # ).sum([-2])

-                # gen_force_label = paddle.einsum(
-                #     "bij,bi->bj", drdq_reshape, force_label_reshape_nframes
-                # )
-                gen_force_label = (
-                    drdq_reshape * force_label_reshape_nframes.unsqueeze(-1)
-                ).sum([-2])
-
-                # gen_force = paddle.einsum(
-                #     "bij,bi->bj", drdq_reshape, force_reshape_nframes
-                # )
-                gen_force = (drdq_reshape * force_reshape_nframes.unsqueeze(-1)).sum(
-                    [-2]
+                gen_force = paddle.einsum(
+                    "bij,bi->bj", drdq_reshape, force_reshape_nframes
                 )
+                # gen_force = (drdq_reshape * force_reshape_nframes.unsqueeze(-1)).sum(
+                #     [-2]
+                # )

                 diff_gen_force = gen_force_label - gen_force
                 l2_gen_force_loss = paddle.square(diff_gen_force).mean()
```
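The einsum swap above is an equivalence rather than a behavior change; a minimal check with hypothetical sizes:

```python
import paddle

b, i, j = 4, 9, 6  # nframes, natoms * 3, numb_generalized_coord (hypothetical)
drdq = paddle.randn([b, i, j])
force = paddle.randn([b, i])

# "bij,bi->bj" contracts over i, matching multiply-broadcast-then-sum.
via_einsum = paddle.einsum("bij,bi->bj", drdq, force)
via_broadcast = (drdq * force.unsqueeze(-1)).sum(axis=-2)
assert paddle.allclose(via_einsum, via_broadcast, atol=1e-6)
```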
```diff
@@ -534,3 +533,75 @@ def deserialize(cls, data: dict) -> "TaskLoss":
         check_version_compatibility(data.pop("@version"), 2, 1)
         data.pop("@class")
         return cls(**data)
+
+
+class EnergyHessianStdLoss(EnergyStdLoss):
+    def __init__(
+        self,
+        start_pref_h=0.0,
+        limit_pref_h=0.0,
+        **kwargs,
+    ):
+        r"""Enable the layer to compute loss on hessian.
+
+        Parameters
+        ----------
+        start_pref_h : float
+            The prefactor of hessian loss at the start of the training.
+        limit_pref_h : float
+            The prefactor of hessian loss at the end of the training.
+        **kwargs
+            Other keyword arguments.
+        """
+        super().__init__(**kwargs)
+        self.has_h = (start_pref_h != 0.0 and limit_pref_h != 0.0) or self.inference
+
+        self.start_pref_h = start_pref_h
+        self.limit_pref_h = limit_pref_h
+
+    def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
+        model_pred, loss, more_loss = super().forward(
+            input_dict, model, label, natoms, learning_rate, mae=mae
+        )
+        coef = learning_rate / self.starter_learning_rate
+        pref_h = self.limit_pref_h + (self.start_pref_h - self.limit_pref_h) * coef
+
+        if self.has_h and "hessian" in model_pred and "hessian" in label:
+            find_hessian = label.get("find_hessian", 0.0)
+            pref_h = pref_h * find_hessian
+            diff_h = label["hessian"].reshape(
+                [-1],
+            ) - model_pred["hessian"].reshape(
+                [-1],
+            )
+            l2_hessian_loss = paddle.mean(paddle.square(diff_h))
+            if not self.inference:
+                more_loss["l2_hessian_loss"] = self.display_if_exist(
+                    l2_hessian_loss.detach(), find_hessian
+                )
+            loss += pref_h * l2_hessian_loss
+            rmse_h = l2_hessian_loss.sqrt()
+            more_loss["rmse_h"] = self.display_if_exist(rmse_h.detach(), find_hessian)
+            if mae:
+                mae_h = paddle.mean(paddle.abs(diff_h))
+                more_loss["mae_h"] = self.display_if_exist(mae_h.detach(), find_hessian)
+
+        if not self.inference:
+            more_loss["rmse"] = paddle.sqrt(loss.detach())
+        return model_pred, loss, more_loss
+
+    @property
+    def label_requirement(self) -> list[DataRequirementItem]:
+        """Add hessian label requirement needed for this loss calculation."""
+        label_requirement = super().label_requirement
+        if self.has_h:
+            label_requirement.append(
+                DataRequirementItem(
+                    "hessian",
+                    ndof=1,  # 9=3*3 --> 3N*3N=ndof*natoms*natoms
+                    atomic=True,
+                    must=False,
+                    high_prec=False,
+                )
+            )
+        return label_requirement
```
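Like the other prefactors in `EnergyStdLoss`, `pref_h` is interpolated linearly in the ratio of the current to the starting learning rate. A standalone sketch of that schedule (the numbers below are hypothetical):

```python
def hessian_prefactor(start_pref_h: float, limit_pref_h: float, lr: float, lr0: float) -> float:
    # Mirrors forward() above: pref_h == start_pref_h when lr == lr0,
    # and approaches limit_pref_h as the learning rate decays toward 0.
    coef = lr / lr0
    return limit_pref_h + (start_pref_h - limit_pref_h) * coef

print(hessian_prefactor(1.0, 10.0, lr=1e-3, lr0=1e-3))  # 1.0 at the start
print(hessian_prefactor(1.0, 10.0, lr=1e-5, lr0=1e-3))  # 9.91 near the end
```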

deepmd/pd/model/descriptor/__init__.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -12,6 +12,9 @@
 from .dpa2 import (
     DescrptDPA2,
 )
+from .dpa3 import (
+    DescrptDPA3,
+)
 from .env_mat import (
     prod_env_mat,
 )
@@ -39,6 +42,7 @@
     "DescrptBlockSeTTebd",
     "DescrptDPA1",
     "DescrptDPA2",
+    "DescrptDPA3",
     "DescrptSeA",
     "DescrptSeAttenV2",
     "DescrptSeTTebd",
```
deepmd/pd/model/descriptor/descriptor.py

Lines changed: 20 additions & 4 deletions

```diff
@@ -6,12 +6,16 @@
 )
 from typing import (
     Callable,
+    NoReturn,
     Optional,
     Union,
 )

 import paddle

+from deepmd.pd.model.network.network import (
+    TypeEmbedNet,
+)
 from deepmd.pd.utils import (
     env,
 )
@@ -99,7 +103,7 @@ def compute_input_stats(
         self,
         merged: Union[Callable[[], list[dict]], list[dict]],
         path: Optional[DPPath] = None,
-    ):
+    ) -> NoReturn:
         """
         Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.

@@ -122,7 +126,7 @@ def get_stats(self) -> dict[str, StatItem]:
         """Get the statistics of the descriptor."""
         raise NotImplementedError

-    def share_params(self, base_class, shared_level, resume=False):
+    def share_params(self, base_class, shared_level, resume=False) -> None:
         """
         Share the parameters of self to the base_class with shared_level during multitask training.
         If not start from checkpoint (resume is False),
@@ -134,7 +138,10 @@ def share_params(self, base_class, shared_level, resume=False):
         if shared_level == 0:
             # link buffers
             if hasattr(self, "mean"):
-                if not resume:
+                if not resume and (
+                    not getattr(self, "set_stddev_constant", False)
+                    or not getattr(self, "set_davg_zero", False)
+                ):
                     # in case of change params during resume
                     base_env = EnvMatStatSe(base_class)
                     base_env.stats = base_class.stats
@@ -172,6 +179,7 @@ def forward(
         extended_atype: paddle.Tensor,
         extended_atype_embd: Optional[paddle.Tensor] = None,
         mapping: Optional[paddle.Tensor] = None,
+        type_embedding: Optional[paddle.Tensor] = None,
     ):
         """Calculate DescriptorBlock."""
         pass
@@ -185,7 +193,15 @@ def need_sorted_nlist_for_lower(self) -> bool:
         """Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""


-def extend_descrpt_stat(des, type_map, des_with_stat=None):
+def make_default_type_embedding(
+    ntypes,
+):
+    aux = {}
+    aux["tebd_dim"] = 8
+    return TypeEmbedNet(ntypes, aux["tebd_dim"]), aux
+
+
+def extend_descrpt_stat(des, type_map, des_with_stat=None) -> None:
     r"""
     Extend the statistics of a descriptor block with types from newly provided `type_map`.

```

deepmd/pd/model/descriptor/dpa1.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -584,7 +584,7 @@ def enable_compression(
             The overflow check frequency
         """
         # do some checks before the model compression process
-        raise NotImplementedError("Model compression is not supported in paddle yet.")
+        raise ValueError("Compression is already enabled.")

     def forward(
         self,
```

deepmd/pd/model/descriptor/dpa2.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -408,7 +408,9 @@ def share_params(self, base_class, shared_level, resume=False) -> None:
         # shared_level: 1
         # share all parameters in type_embedding
         elif shared_level == 1:
-            self._modules["type_embedding"] = base_class._modules["type_embedding"]
+            self._sub_layers["type_embedding"] = base_class._sub_layers[
+                "type_embedding"
+            ]
         # Other shared levels
         else:
             raise NotImplementedError
@@ -899,4 +901,4 @@ def enable_compression(
             The overflow check frequency
         """
         # do some checks before the model compression process
-        raise NotImplementedError("enable_compression is not implemented yet")
+        raise ValueError("Compression is already enabled.")
```
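The `_modules` to `_sub_layers` rename in the first hunk matters because `paddle.nn.Layer` registers child layers under `_sub_layers`, whereas `_modules` is the `torch.nn.Module` attribute; a minimal illustration:

```python
import paddle
from paddle import nn

class Tiny(nn.Layer):
    def __init__(self) -> None:
        super().__init__()
        self.type_embedding = nn.Embedding(4, 8)

layer = Tiny()
print("type_embedding" in layer._sub_layers)  # True: Paddle's sublayer registry
print(hasattr(layer, "_modules"))             # False: that attribute is torch-only
```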
