
Commit 4e65d8b

feat(pt): add dpa3 alpha descriptor (deepmodeling#4476)
This PR is an early experimental preview version of DPA3. Significant changes may occur in subsequent updates. Please use with caution.
2 parents 104fc36 + 1309e26 commit 4e65d8b

File tree: 12 files changed (+2478, -31 lines)

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


class RepFlowArgs:
    def __init__(
        self,
        n_dim: int = 128,
        e_dim: int = 64,
        a_dim: int = 64,
        nlayers: int = 6,
        e_rcut: float = 6.0,
        e_rcut_smth: float = 5.0,
        e_sel: int = 120,
        a_rcut: float = 4.0,
        a_rcut_smth: float = 3.5,
        a_sel: int = 20,
        a_compress_rate: int = 0,
        axis_neuron: int = 4,
        update_angle: bool = True,
        update_style: str = "res_residual",
        update_residual: float = 0.1,
        update_residual_init: str = "const",
    ) -> None:
        r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in the DPA3 descriptor.

        Parameters
        ----------
        n_dim : int, optional
            The dimension of the node representation.
        e_dim : int, optional
            The dimension of the edge representation.
        a_dim : int, optional
            The dimension of the angle representation.
        nlayers : int, optional
            Number of repflow layers.
        e_rcut : float, optional
            The edge cut-off radius.
        e_rcut_smth : float, optional
            Where to start smoothing for edges. For example, the 1/r term is smoothed from e_rcut to e_rcut_smth.
        e_sel : int, optional
            Maximum possible number of selected edge neighbors.
        a_rcut : float, optional
            The angle cut-off radius.
        a_rcut_smth : float, optional
            Where to start smoothing for angles. For example, the 1/r term is smoothed from a_rcut to a_rcut_smth.
        a_sel : int, optional
            Maximum possible number of selected angle neighbors.
        a_compress_rate : int, optional
            The compression rate for angular messages. The default value is 0, indicating no compression.
            If a non-zero integer c is provided, the node and edge dimensions will be compressed
            to n_dim/c and e_dim/2c, respectively, within the angular message.
        axis_neuron : int, optional
            The number of dimensions of the submatrix in the symmetrization ops.
        update_angle : bool, optional
            Whether to update the angle rep. If not, only the node and edge reps will be used.
        update_style : str, optional
            Style to update a representation.
            Supported options are:
            -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n)
            -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n)
            -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + rn*u_n)
            where `r1`, `r2` ... `rn` are residual weights defined by `update_residual`
            and `update_residual_init`.
        update_residual : float, optional
            When updating using the residual mode, the initial std of the residual vector weights.
        update_residual_init : str, optional
            When updating using the residual mode, the initialization mode of the residual vector weights.
        """
        self.n_dim = n_dim
        self.e_dim = e_dim
        self.a_dim = a_dim
        self.nlayers = nlayers
        self.e_rcut = e_rcut
        self.e_rcut_smth = e_rcut_smth
        self.e_sel = e_sel
        self.a_rcut = a_rcut
        self.a_rcut_smth = a_rcut_smth
        self.a_sel = a_sel
        self.a_compress_rate = a_compress_rate
        self.axis_neuron = axis_neuron
        self.update_angle = update_angle
        self.update_style = update_style
        self.update_residual = update_residual
        self.update_residual_init = update_residual_init

    def __getitem__(self, key):
        if hasattr(self, key):
            return getattr(self, key)
        else:
            raise KeyError(key)

    def serialize(self) -> dict:
        return {
            "n_dim": self.n_dim,
            "e_dim": self.e_dim,
            "a_dim": self.a_dim,
            "nlayers": self.nlayers,
            "e_rcut": self.e_rcut,
            "e_rcut_smth": self.e_rcut_smth,
            "e_sel": self.e_sel,
            "a_rcut": self.a_rcut,
            "a_rcut_smth": self.a_rcut_smth,
            "a_sel": self.a_sel,
            "a_compress_rate": self.a_compress_rate,
            "axis_neuron": self.axis_neuron,
            "update_angle": self.update_angle,
            "update_style": self.update_style,
            "update_residual": self.update_residual,
            "update_residual_init": self.update_residual_init,
        }

    @classmethod
    def deserialize(cls, data: dict) -> "RepFlowArgs":
        return cls(**data)
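
The class above is self-contained, so it can be exercised on its own. A minimal sketch (not part of the commit, assuming a deepmd-kit checkout that contains the new deepmd/dpmodel/descriptor/dpa3.py module) showing the defaults and the serialize/deserialize round trip:

# Minimal sketch, not part of the commit: exercises RepFlowArgs as defined above.
from deepmd.dpmodel.descriptor.dpa3 import RepFlowArgs

args = RepFlowArgs()  # all defaults, e.g. n_dim=128, e_dim=64, update_style="res_residual"

# __getitem__ forwards to attribute access and raises KeyError for unknown keys.
assert args["n_dim"] == 128
assert args["e_rcut"] == 6.0

# serialize() returns a plain dict of all 16 hyperparameters;
# deserialize() rebuilds an equivalent instance from that dict.
data = args.serialize()
restored = RepFlowArgs.deserialize(data)
assert restored.serialize() == data

The serialized dict is what higher-level descriptor configuration code would carry for the repflow block; deserialize simply re-invokes the constructor with the same keyword arguments.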

deepmd/pt/loss/ener.py

Lines changed: 40 additions & 31 deletions
@@ -187,28 +187,26 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
                 )
                 # more_loss['log_keys'].append('rmse_e')
             else:  # use l1 and for all atoms
+                energy_pred = energy_pred * atom_norm
+                energy_label = energy_label * atom_norm
                 l1_ener_loss = F.l1_loss(
                     energy_pred.reshape(-1),
                     energy_label.reshape(-1),
-                    reduction="sum",
+                    reduction="mean",
                 )
                 loss += pref_e * l1_ener_loss
                 more_loss["mae_e"] = self.display_if_exist(
-                    F.l1_loss(
-                        energy_pred.reshape(-1),
-                        energy_label.reshape(-1),
-                        reduction="mean",
-                    ).detach(),
+                    l1_ener_loss.detach(),
                     find_energy,
                 )
                 # more_loss['log_keys'].append('rmse_e')
-            if mae:
-                mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm
-                more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
-                mae_e_all = torch.mean(torch.abs(energy_pred - energy_label))
-                more_loss["mae_e_all"] = self.display_if_exist(
-                    mae_e_all.detach(), find_energy
-                )
+            # if mae:
+            #     mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm
+            #     more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
+            #     mae_e_all = torch.mean(torch.abs(energy_pred - energy_label))
+            #     more_loss["mae_e_all"] = self.display_if_exist(
+            #         mae_e_all.detach(), find_energy
+            #     )
 
             if (
                 (self.has_f or self.has_pf or self.relative_f or self.has_gf)
@@ -241,17 +239,17 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
                         rmse_f.detach(), find_force
                     )
                 else:
-                    l1_force_loss = F.l1_loss(force_label, force_pred, reduction="none")
+                    l1_force_loss = F.l1_loss(force_label, force_pred, reduction="mean")
                     more_loss["mae_f"] = self.display_if_exist(
-                        l1_force_loss.mean().detach(), find_force
+                        l1_force_loss.detach(), find_force
                     )
-                    l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
+                    # l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
                     loss += (pref_f * l1_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
-                    if mae:
-                        mae_f = torch.mean(torch.abs(diff_f))
-                        more_loss["mae_f"] = self.display_if_exist(
-                            mae_f.detach(), find_force
-                        )
+                    # if mae:
+                    #     mae_f = torch.mean(torch.abs(diff_f))
+                    #     more_loss["mae_f"] = self.display_if_exist(
+                    #         mae_f.detach(), find_force
+                    #     )
 
             if self.has_pf and "atom_pref" in label:
                 atom_pref = label["atom_pref"]
@@ -297,18 +295,29 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
         if self.has_v and "virial" in model_pred and "virial" in label:
             find_virial = label.get("find_virial", 0.0)
             pref_v = pref_v * find_virial
+            virial_label = label["virial"]
+            virial_pred = model_pred["virial"].reshape(-1, 9)
             diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
-            l2_virial_loss = torch.mean(torch.square(diff_v))
-            if not self.inference:
-                more_loss["l2_virial_loss"] = self.display_if_exist(
-                    l2_virial_loss.detach(), find_virial
+            if not self.use_l1_all:
+                l2_virial_loss = torch.mean(torch.square(diff_v))
+                if not self.inference:
+                    more_loss["l2_virial_loss"] = self.display_if_exist(
+                        l2_virial_loss.detach(), find_virial
+                    )
+                loss += atom_norm * (pref_v * l2_virial_loss)
+                rmse_v = l2_virial_loss.sqrt() * atom_norm
+                more_loss["rmse_v"] = self.display_if_exist(
+                    rmse_v.detach(), find_virial
+                )
+            else:
+                l1_virial_loss = F.l1_loss(virial_label, virial_pred, reduction="mean")
+                more_loss["mae_v"] = self.display_if_exist(
+                    l1_virial_loss.detach(), find_virial
                 )
-            loss += atom_norm * (pref_v * l2_virial_loss)
-            rmse_v = l2_virial_loss.sqrt() * atom_norm
-            more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
-            if mae:
-                mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
-                more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)
+                loss += (pref_v * l1_virial_loss).to(GLOBAL_PT_FLOAT_PRECISION)
+            # if mae:
+            #     mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
+            #     more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)
 
         if self.has_ae and "atom_energy" in model_pred and "atom_ener" in label:
             atom_ener = model_pred["atom_energy"]
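
Net effect of the hunks above: in the L1 branches, the energy, force, and virial terms switch from summed to mean-reduced losses, the energy is normalized per atom before the loss is taken, the already-computed L1 tensors are reused for the reported mae_e / mae_f / mae_v metrics instead of being recomputed, and the old mae-only blocks are commented out. A standalone sketch of the new energy reduction (toy tensors, not the loss class in ener.py itself; atom_norm is assumed here to be 1 / natoms, the per-atom normalization this loss applies):

# Standalone sketch of the reworked L1 energy branch; values are illustrative.
import torch
import torch.nn.functional as F

natoms = 64
atom_norm = 1.0 / natoms  # assumed per-atom normalization factor

energy_pred = torch.tensor([[-123.4], [-98.7]])   # toy per-frame predictions
energy_label = torch.tensor([[-123.0], [-99.0]])  # toy per-frame references

# New behavior: normalize per atom, then take a mean-reduced L1 loss ...
energy_pred = energy_pred * atom_norm
energy_label = energy_label * atom_norm
l1_ener_loss = F.l1_loss(
    energy_pred.reshape(-1),
    energy_label.reshape(-1),
    reduction="mean",
)

# ... and report the same tensor as the mae_e metric instead of recomputing it.
mae_e = l1_ener_loss.detach()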

deepmd/pt/model/descriptor/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -13,6 +13,9 @@
 from .dpa2 import (
     DescrptDPA2,
 )
+from .dpa3 import (
+    DescrptDPA3,
+)
 from .env_mat import (
     prod_env_mat,
 )
@@ -49,6 +52,7 @@
     "DescrptBlockSeTTebd",
     "DescrptDPA1",
     "DescrptDPA2",
+    "DescrptDPA3",
     "DescrptHybrid",
     "DescrptSeA",
     "DescrptSeAttenV2",
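
With the re-export above, the new descriptor class is available from the package namespace. A quick import check (assuming a deepmd-kit build that includes this commit):

# Import check only; assumes a deepmd-kit installation containing this commit.
from deepmd.pt.model.descriptor import DescrptDPA3

print(DescrptDPA3.__module__)  # expected: "deepmd.pt.model.descriptor.dpa3"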
