Skip to content

Commit 51f61cd

Browse files
fix(dpmodel): fix energy loss (#4765)
1. dpmodel has different model output keys; 2. in the current code, energy, force, virial, etc are necessary keys. 3. update more_loss <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Adjusted internal handling and reporting of energy, force, virial, and atomic energy loss metrics for improved clarity and consistency. - **Tests** - Updated test cases to align with the new prediction key structure, ensuring continued accuracy and reliability of energy loss evaluations. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9c17b96 commit 51f61cd

File tree

2 files changed

+64
-65
lines changed

2 files changed

+64
-65
lines changed

deepmd/dpmodel/loss/ener.py

Lines changed: 52 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ def call(
9393
label_dict: dict[str, np.ndarray],
9494
) -> dict[str, np.ndarray]:
9595
"""Calculate loss from model results and labeled results."""
96-
energy = model_dict["energy"]
97-
force = model_dict["force"]
98-
virial = model_dict["virial"]
99-
atom_ener = model_dict["atom_ener"]
96+
energy = model_dict["energy_redu"]
97+
force = model_dict["energy_derv_r"]
98+
virial = model_dict["energy_derv_c_redu"]
99+
atom_ener = model_dict["energy"]
100100
energy_hat = label_dict["energy"]
101101
force_hat = label_dict["force"]
102102
virial_hat = label_dict["virial"]
@@ -177,7 +177,7 @@ def call(
177177
delta=self.huber_delta,
178178
)
179179
loss += pref_e * l_huber_loss
180-
more_loss["l2_ener_loss"] = self.display_if_exist(l2_ener_loss, find_energy)
180+
more_loss["rmse_e"] = self.display_if_exist(l2_ener_loss, find_energy)
181181
if self.has_f:
182182
l2_force_loss = xp.mean(xp.square(diff_f))
183183
if not self.use_huber:
@@ -189,9 +189,7 @@ def call(
189189
delta=self.huber_delta,
190190
)
191191
loss += pref_f * l_huber_loss
192-
more_loss["l2_force_loss"] = self.display_if_exist(
193-
l2_force_loss, find_force
194-
)
192+
more_loss["rmse_f"] = self.display_if_exist(l2_force_loss, find_force)
195193
if self.has_v:
196194
virial_reshape = xp.reshape(virial, [-1])
197195
virial_hat_reshape = xp.reshape(virial_hat, [-1])
@@ -207,9 +205,7 @@ def call(
207205
delta=self.huber_delta,
208206
)
209207
loss += pref_v * l_huber_loss
210-
more_loss["l2_virial_loss"] = self.display_if_exist(
211-
l2_virial_loss, find_virial
212-
)
208+
more_loss["rmse_v"] = self.display_if_exist(l2_virial_loss, find_virial)
213209
if self.has_ae:
214210
atom_ener_reshape = xp.reshape(atom_ener, [-1])
215211
atom_ener_hat_reshape = xp.reshape(atom_ener_hat, [-1])
@@ -225,7 +221,7 @@ def call(
225221
delta=self.huber_delta,
226222
)
227223
loss += pref_ae * l_huber_loss
228-
more_loss["l2_atom_ener_loss"] = self.display_if_exist(
224+
more_loss["rmse_ae"] = self.display_if_exist(
229225
l2_atom_ener_loss, find_atom_ener
230226
)
231227
if self.has_pf:
@@ -234,7 +230,7 @@ def call(
234230
xp.multiply(xp.square(diff_f), atom_pref_reshape),
235231
)
236232
loss += pref_pf * l2_pref_force_loss
237-
more_loss["l2_pref_force_loss"] = self.display_if_exist(
233+
more_loss["rmse_pf"] = self.display_if_exist(
238234
l2_pref_force_loss, find_atom_pref
239235
)
240236
if self.has_gf:
@@ -256,69 +252,63 @@ def call(
256252
+ (self.start_pref_gf - self.limit_pref_gf) * lr_ratio
257253
)
258254
loss += pref_gf * l2_gen_force_loss
259-
more_loss["l2_gen_force_loss"] = self.display_if_exist(
260-
l2_gen_force_loss, find_drdq
261-
)
255+
more_loss["rmse_gf"] = self.display_if_exist(l2_gen_force_loss, find_drdq)
262256

263257
self.l2_l = loss
258+
more_loss["rmse"] = xp.sqrt(loss)
264259
self.l2_more = more_loss
265260
return loss, more_loss
266261

267262
@property
268263
def label_requirement(self) -> list[DataRequirementItem]:
269264
"""Return data label requirements needed for this loss calculation."""
270265
label_requirement = []
271-
if self.has_e:
272-
label_requirement.append(
273-
DataRequirementItem(
274-
"energy",
275-
ndof=1,
276-
atomic=False,
277-
must=False,
278-
high_prec=True,
279-
)
266+
label_requirement.append(
267+
DataRequirementItem(
268+
"energy",
269+
ndof=1,
270+
atomic=False,
271+
must=False,
272+
high_prec=True,
280273
)
281-
if self.has_f:
282-
label_requirement.append(
283-
DataRequirementItem(
284-
"force",
285-
ndof=3,
286-
atomic=True,
287-
must=False,
288-
high_prec=False,
289-
)
274+
)
275+
label_requirement.append(
276+
DataRequirementItem(
277+
"force",
278+
ndof=3,
279+
atomic=True,
280+
must=False,
281+
high_prec=False,
290282
)
291-
if self.has_v:
292-
label_requirement.append(
293-
DataRequirementItem(
294-
"virial",
295-
ndof=9,
296-
atomic=False,
297-
must=False,
298-
high_prec=False,
299-
)
283+
)
284+
label_requirement.append(
285+
DataRequirementItem(
286+
"virial",
287+
ndof=9,
288+
atomic=False,
289+
must=False,
290+
high_prec=False,
300291
)
301-
if self.has_ae:
302-
label_requirement.append(
303-
DataRequirementItem(
304-
"atom_ener",
305-
ndof=1,
306-
atomic=True,
307-
must=False,
308-
high_prec=False,
309-
)
292+
)
293+
label_requirement.append(
294+
DataRequirementItem(
295+
"atom_ener",
296+
ndof=1,
297+
atomic=True,
298+
must=False,
299+
high_prec=False,
310300
)
311-
if self.has_pf:
312-
label_requirement.append(
313-
DataRequirementItem(
314-
"atom_pref",
315-
ndof=1,
316-
atomic=True,
317-
must=False,
318-
high_prec=False,
319-
repeat=3,
320-
)
301+
)
302+
label_requirement.append(
303+
DataRequirementItem(
304+
"atom_pref",
305+
ndof=1,
306+
atomic=True,
307+
must=False,
308+
high_prec=False,
309+
repeat=3,
321310
)
311+
)
322312
if self.has_gf > 0:
323313
label_requirement.append(
324314
DataRequirementItem(

source/tests/consistent/loss/test_ener.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,12 @@ def setUp(self) -> None:
108108
)
109109
),
110110
}
111+
self.predict_dpmodel_style = {
112+
"energy_derv_c_redu": self.predict["virial"],
113+
"energy_derv_r": self.predict["force"],
114+
"energy_redu": self.predict["energy"],
115+
"energy": self.predict["atom_ener"],
116+
}
111117
self.label = {
112118
"energy": rng.random((self.nframes,)),
113119
"force": rng.random((self.nframes, self.natoms, 3)),
@@ -187,12 +193,12 @@ def eval_dp(self, dp_obj: Any) -> Any:
187193
return dp_obj(
188194
self.learning_rate,
189195
self.natoms,
190-
self.predict,
196+
self.predict_dpmodel_style,
191197
self.label,
192198
)
193199

194200
def eval_jax(self, jax_obj: Any) -> Any:
195-
predict = {kk: jnp.asarray(vv) for kk, vv in self.predict.items()}
201+
predict = {kk: jnp.asarray(vv) for kk, vv in self.predict_dpmodel_style.items()}
196202
label = {kk: jnp.asarray(vv) for kk, vv in self.label.items()}
197203

198204
loss, more_loss = jax_obj(
@@ -206,7 +212,10 @@ def eval_jax(self, jax_obj: Any) -> Any:
206212
return loss, more_loss
207213

208214
def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
209-
predict = {kk: array_api_strict.asarray(vv) for kk, vv in self.predict.items()}
215+
predict = {
216+
kk: array_api_strict.asarray(vv)
217+
for kk, vv in self.predict_dpmodel_style.items()
218+
}
210219
label = {kk: array_api_strict.asarray(vv) for kk, vv in self.label.items()}
211220

212221
loss, more_loss = array_api_strict_obj(

0 commit comments

Comments
 (0)