Skip to content

Commit dbdb9b9

Browse files
authored
feat: dpmodel energy loss & consistent tests (#4531)
Fix #4105. Fix #4429. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced a new energy loss calculation framework with support for multiple machine learning backends. - Added serialization and deserialization capabilities for loss modules. - Added a new class `EnergyLoss` for computing energy-related loss metrics. - **Documentation** - Added SPDX license identifiers to multiple files. - Included docstrings for new classes and methods. - **Tests** - Implemented comprehensive test suite for energy loss functions across different platforms (TensorFlow, PyTorch, Paddle, JAX). - Introduced a new test class `TestEner` for evaluating energy loss functions. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 03819a0 commit dbdb9b9

File tree

11 files changed

+931
-0
lines changed

11 files changed

+931
-0
lines changed

deepmd/dpmodel/loss/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later

deepmd/dpmodel/loss/ener.py

Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Optional,
4+
)
5+
6+
import array_api_compat
7+
import numpy as np
8+
9+
from deepmd.dpmodel.loss.loss import (
10+
Loss,
11+
)
12+
from deepmd.utils.data import (
13+
DataRequirementItem,
14+
)
15+
from deepmd.utils.version import (
16+
check_version_compatibility,
17+
)
18+
19+
20+
class EnergyLoss(Loss):
    """Energy loss for the dpmodel (array-API / NumPy) backend.

    The total loss is a weighted sum of L2 terms for the total energy,
    forces, virial, atomic energies, atom-prefactored forces, and
    generalized forces.  Each term's prefactor is linearly interpolated
    between its ``start_pref_*`` and ``limit_pref_*`` value according to
    the ratio of the current learning rate to ``starter_learning_rate``.

    Parameters
    ----------
    starter_learning_rate : float
        Learning rate at the start of training; used to compute the
        interpolation ratio of the prefactors.
    start_pref_e, limit_pref_e : float
        Start/limit prefactor of the total-energy loss.
    start_pref_f, limit_pref_f : float
        Start/limit prefactor of the force loss.
    start_pref_v, limit_pref_v : float
        Start/limit prefactor of the virial loss.
    start_pref_ae, limit_pref_ae : float
        Start/limit prefactor of the atomic-energy loss.
    start_pref_pf, limit_pref_pf : float
        Start/limit prefactor of the atom-prefactored force loss.
    relative_f : float, optional
        If given, the force error of each atom is divided by the norm of
        the labeled force plus ``relative_f`` (a relative force loss).
    enable_atom_ener_coeff : bool
        If True, the total energy is recomputed as ``sum_i nu_i E_i``
        using per-atom coefficients supplied in the labels.
    start_pref_gf, limit_pref_gf : float
        Start/limit prefactor of the generalized-force loss.
    numb_generalized_coord : int
        Number of generalized coordinates; must be >= 1 when the
        generalized-force loss is enabled.
    **kwargs
        Ignored; accepted for interface compatibility.
    """

    def __init__(
        self,
        starter_learning_rate: float,
        start_pref_e: float = 0.02,
        limit_pref_e: float = 1.00,
        start_pref_f: float = 1000,
        limit_pref_f: float = 1.00,
        start_pref_v: float = 0.0,
        limit_pref_v: float = 0.0,
        start_pref_ae: float = 0.0,
        limit_pref_ae: float = 0.0,
        start_pref_pf: float = 0.0,
        limit_pref_pf: float = 0.0,
        relative_f: Optional[float] = None,
        enable_atom_ener_coeff: bool = False,
        start_pref_gf: float = 0.0,
        limit_pref_gf: float = 0.0,
        numb_generalized_coord: int = 0,
        **kwargs,
    ) -> None:
        self.starter_learning_rate = starter_learning_rate
        self.start_pref_e = start_pref_e
        self.limit_pref_e = limit_pref_e
        self.start_pref_f = start_pref_f
        self.limit_pref_f = limit_pref_f
        self.start_pref_v = start_pref_v
        self.limit_pref_v = limit_pref_v
        self.start_pref_ae = start_pref_ae
        self.limit_pref_ae = limit_pref_ae
        self.start_pref_pf = start_pref_pf
        self.limit_pref_pf = limit_pref_pf
        self.relative_f = relative_f
        self.enable_atom_ener_coeff = enable_atom_ener_coeff
        self.start_pref_gf = start_pref_gf
        self.limit_pref_gf = limit_pref_gf
        self.numb_generalized_coord = numb_generalized_coord
        # A loss term is active iff either its start or limit prefactor is nonzero.
        self.has_e = self.start_pref_e != 0.0 or self.limit_pref_e != 0.0
        self.has_f = self.start_pref_f != 0.0 or self.limit_pref_f != 0.0
        self.has_v = self.start_pref_v != 0.0 or self.limit_pref_v != 0.0
        self.has_ae = self.start_pref_ae != 0.0 or self.limit_pref_ae != 0.0
        self.has_pf = self.start_pref_pf != 0.0 or self.limit_pref_pf != 0.0
        self.has_gf = self.start_pref_gf != 0.0 or self.limit_pref_gf != 0.0
        if self.has_gf and self.numb_generalized_coord < 1:
            raise RuntimeError(
                "When generalized force loss is used, the dimension of generalized coordinates should be larger than 0"
            )

    def call(
        self,
        learning_rate: float,
        natoms: int,
        model_dict: dict[str, np.ndarray],
        label_dict: dict[str, np.ndarray],
    ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
        """Calculate loss from model results and labeled results.

        Parameters
        ----------
        learning_rate : float
            Current learning rate; prefactors are interpolated by
            ``learning_rate / starter_learning_rate``.
        natoms : int
            Number of atoms used to normalize the energy and virial terms.
            NOTE(review): the generalized-force branch indexes ``natoms[0]``
            while the energy term divides by ``natoms`` directly — confirm
            the expected type with the callers.
        model_dict : dict[str, np.ndarray]
            Model predictions: "energy", "force", "virial", "atom_ener".
        label_dict : dict[str, np.ndarray]
            Labels and their "find_*" availability flags; may also hold
            "atom_ener_coeff" and "drdq" depending on configuration.

        Returns
        -------
        tuple[np.ndarray, dict[str, np.ndarray]]
            The scalar total loss and a dict of per-term losses.
        """
        energy = model_dict["energy"]
        force = model_dict["force"]
        virial = model_dict["virial"]
        atom_ener = model_dict["atom_ener"]
        energy_hat = label_dict["energy"]
        force_hat = label_dict["force"]
        virial_hat = label_dict["virial"]
        atom_ener_hat = label_dict["atom_ener"]
        atom_pref = label_dict["atom_pref"]
        find_energy = label_dict["find_energy"]
        find_force = label_dict["find_force"]
        find_virial = label_dict["find_virial"]
        find_atom_ener = label_dict["find_atom_ener"]
        find_atom_pref = label_dict["find_atom_pref"]
        # Resolve the array namespace from all participating arrays so that
        # any array-API-compatible backend works.
        xp = array_api_compat.array_namespace(
            energy,
            force,
            virial,
            atom_ener,
            energy_hat,
            force_hat,
            virial_hat,
            atom_ener_hat,
            atom_pref,
        )

        if self.enable_atom_ener_coeff:
            # when ener_coeff (\nu) is defined, the energy is defined as
            # E = \sum_i \nu_i E_i
            # instead of the sum of atomic energies.
            #
            # A case is that we want to train reaction energy
            # A + B -> C + D
            # E = - E(A) - E(B) + E(C) + E(D)
            # A, B, C, D could be put far away from each other
            atom_ener_coeff = label_dict["atom_ener_coeff"]
            atom_ener_coeff = xp.reshape(atom_ener_coeff, xp.shape(atom_ener))
            energy = xp.sum(atom_ener_coeff * atom_ener, 1)
        # The flattened force difference is shared by several loss terms.
        if self.has_f or self.has_pf or self.relative_f or self.has_gf:
            force_reshape = xp.reshape(force, [-1])
            force_hat_reshape = xp.reshape(force_hat, [-1])
            diff_f = force_hat_reshape - force_reshape
        else:
            diff_f = None

        if self.relative_f is not None:
            # Scale the per-atom force error by |F_label| + relative_f.
            # BUGFIX: the array API has no top-level `norm`; use
            # linalg.vector_norm (the original `xp.norm` would raise
            # AttributeError).
            force_hat_3 = xp.reshape(force_hat, [-1, 3])
            norm_f = (
                xp.reshape(xp.linalg.vector_norm(force_hat_3, axis=1), [-1, 1])
                + self.relative_f
            )
            diff_f_3 = xp.reshape(diff_f, [-1, 3])
            diff_f_3 = diff_f_3 / norm_f
            diff_f = xp.reshape(diff_f_3, [-1])

        atom_norm = 1.0 / natoms
        atom_norm_ener = 1.0 / natoms
        # Interpolate each prefactor between its limit and start value;
        # the find_* flag zeroes the term when the label is absent.
        lr_ratio = learning_rate / self.starter_learning_rate
        pref_e = find_energy * (
            self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * lr_ratio
        )
        pref_f = find_force * (
            self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * lr_ratio
        )
        pref_v = find_virial * (
            self.limit_pref_v + (self.start_pref_v - self.limit_pref_v) * lr_ratio
        )
        pref_ae = find_atom_ener * (
            self.limit_pref_ae + (self.start_pref_ae - self.limit_pref_ae) * lr_ratio
        )
        pref_pf = find_atom_pref * (
            self.limit_pref_pf + (self.start_pref_pf - self.limit_pref_pf) * lr_ratio
        )

        l2_loss = 0
        more_loss = {}
        if self.has_e:
            l2_ener_loss = xp.mean(xp.square(energy - energy_hat))
            l2_loss += atom_norm_ener * (pref_e * l2_ener_loss)
            more_loss["l2_ener_loss"] = self.display_if_exist(l2_ener_loss, find_energy)
        if self.has_f:
            l2_force_loss = xp.mean(xp.square(diff_f))
            l2_loss += pref_f * l2_force_loss
            more_loss["l2_force_loss"] = self.display_if_exist(
                l2_force_loss, find_force
            )
        if self.has_v:
            virial_reshape = xp.reshape(virial, [-1])
            virial_hat_reshape = xp.reshape(virial_hat, [-1])
            l2_virial_loss = xp.mean(
                xp.square(virial_hat_reshape - virial_reshape),
            )
            l2_loss += atom_norm * (pref_v * l2_virial_loss)
            more_loss["l2_virial_loss"] = self.display_if_exist(
                l2_virial_loss, find_virial
            )
        if self.has_ae:
            atom_ener_reshape = xp.reshape(atom_ener, [-1])
            atom_ener_hat_reshape = xp.reshape(atom_ener_hat, [-1])
            l2_atom_ener_loss = xp.mean(
                xp.square(atom_ener_hat_reshape - atom_ener_reshape),
            )
            l2_loss += pref_ae * l2_atom_ener_loss
            more_loss["l2_atom_ener_loss"] = self.display_if_exist(
                l2_atom_ener_loss, find_atom_ener
            )
        if self.has_pf:
            # Per-atom prefactored force loss: each squared component is
            # weighted by the label-supplied atom_pref.
            atom_pref_reshape = xp.reshape(atom_pref, [-1])
            l2_pref_force_loss = xp.mean(
                xp.multiply(xp.square(diff_f), atom_pref_reshape),
            )
            l2_loss += pref_pf * l2_pref_force_loss
            more_loss["l2_pref_force_loss"] = self.display_if_exist(
                l2_pref_force_loss, find_atom_pref
            )
        if self.has_gf:
            # Generalized forces: project Cartesian forces onto generalized
            # coordinates via dR/dQ and compare in that space.
            find_drdq = label_dict["find_drdq"]
            drdq = label_dict["drdq"]
            force_reshape_nframes = xp.reshape(force, [-1, natoms[0] * 3])
            force_hat_reshape_nframes = xp.reshape(force_hat, [-1, natoms[0] * 3])
            drdq_reshape = xp.reshape(
                drdq, [-1, natoms[0] * 3, self.numb_generalized_coord]
            )
            gen_force_hat = xp.einsum(
                "bij,bi->bj", drdq_reshape, force_hat_reshape_nframes
            )
            gen_force = xp.einsum("bij,bi->bj", drdq_reshape, force_reshape_nframes)
            diff_gen_force = gen_force_hat - gen_force
            l2_gen_force_loss = xp.mean(xp.square(diff_gen_force))
            pref_gf = find_drdq * (
                self.limit_pref_gf
                + (self.start_pref_gf - self.limit_pref_gf) * lr_ratio
            )
            l2_loss += pref_gf * l2_gen_force_loss
            more_loss["l2_gen_force_loss"] = self.display_if_exist(
                l2_gen_force_loss, find_drdq
            )

        # Keep the latest results for inspection/logging by callers.
        self.l2_l = l2_loss
        self.l2_more = more_loss
        return l2_loss, more_loss

    @property
    def label_requirement(self) -> list[DataRequirementItem]:
        """Return data label requirements needed for this loss calculation."""
        label_requirement = []
        if self.has_e:
            label_requirement.append(
                DataRequirementItem(
                    "energy",
                    ndof=1,
                    atomic=False,
                    must=False,
                    high_prec=True,
                )
            )
        if self.has_f:
            label_requirement.append(
                DataRequirementItem(
                    "force",
                    ndof=3,
                    atomic=True,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_v:
            label_requirement.append(
                DataRequirementItem(
                    "virial",
                    ndof=9,
                    atomic=False,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_ae:
            label_requirement.append(
                DataRequirementItem(
                    "atom_ener",
                    ndof=1,
                    atomic=True,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_pf:
            label_requirement.append(
                DataRequirementItem(
                    "atom_pref",
                    ndof=1,
                    atomic=True,
                    must=False,
                    high_prec=False,
                    repeat=3,
                )
            )
        # NOTE: has_gf is a bool; the original compared it with `> 0`.
        if self.has_gf:
            label_requirement.append(
                DataRequirementItem(
                    "drdq",
                    ndof=self.numb_generalized_coord * 3,
                    atomic=True,
                    must=False,
                    high_prec=False,
                )
            )
        if self.enable_atom_ener_coeff:
            label_requirement.append(
                DataRequirementItem(
                    "atom_ener_coeff",
                    ndof=1,
                    atomic=True,
                    must=False,
                    high_prec=False,
                    default=1.0,
                )
            )
        return label_requirement

    def serialize(self) -> dict:
        """Serialize the loss module.

        Returns
        -------
        dict
            The serialized loss module
        """
        return {
            "@class": "EnergyLoss",
            "@version": 1,
            "starter_learning_rate": self.starter_learning_rate,
            "start_pref_e": self.start_pref_e,
            "limit_pref_e": self.limit_pref_e,
            "start_pref_f": self.start_pref_f,
            "limit_pref_f": self.limit_pref_f,
            "start_pref_v": self.start_pref_v,
            "limit_pref_v": self.limit_pref_v,
            "start_pref_ae": self.start_pref_ae,
            "limit_pref_ae": self.limit_pref_ae,
            "start_pref_pf": self.start_pref_pf,
            "limit_pref_pf": self.limit_pref_pf,
            "relative_f": self.relative_f,
            "enable_atom_ener_coeff": self.enable_atom_ener_coeff,
            "start_pref_gf": self.start_pref_gf,
            "limit_pref_gf": self.limit_pref_gf,
            "numb_generalized_coord": self.numb_generalized_coord,
        }

    @classmethod
    def deserialize(cls, data: dict) -> "Loss":
        """Deserialize the loss module.

        Parameters
        ----------
        data : dict
            The serialized loss module

        Returns
        -------
        Loss
            The deserialized loss module
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version"), 1, 1)
        data.pop("@class")
        return cls(**data)

0 commit comments

Comments
 (0)