Skip to content

Commit 20a60c6

Browse files
committed
add mae
1 parent fe6a92e commit 20a60c6

File tree

2 files changed

+46
-31
lines changed

2 files changed

+46
-31
lines changed

deepmd/pt/loss/ener.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -187,28 +187,26 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
187187
)
188188
# more_loss['log_keys'].append('rmse_e')
189189
else: # use l1 and for all atoms
190+
energy_pred = energy_pred * atom_norm
191+
energy_label = energy_label * atom_norm
190192
l1_ener_loss = F.l1_loss(
191193
energy_pred.reshape(-1),
192194
energy_label.reshape(-1),
193-
reduction="sum",
195+
reduction="mean",
194196
)
195197
loss += pref_e * l1_ener_loss
196198
more_loss["mae_e"] = self.display_if_exist(
197-
F.l1_loss(
198-
energy_pred.reshape(-1),
199-
energy_label.reshape(-1),
200-
reduction="mean",
201-
).detach(),
199+
l1_ener_loss.detach(),
202200
find_energy,
203201
)
204202
# more_loss['log_keys'].append('rmse_e')
205-
if mae:
206-
mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm
207-
more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
208-
mae_e_all = torch.mean(torch.abs(energy_pred - energy_label))
209-
more_loss["mae_e_all"] = self.display_if_exist(
210-
mae_e_all.detach(), find_energy
211-
)
203+
# if mae:
204+
# mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm
205+
# more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
206+
# mae_e_all = torch.mean(torch.abs(energy_pred - energy_label))
207+
# more_loss["mae_e_all"] = self.display_if_exist(
208+
# mae_e_all.detach(), find_energy
209+
# )
212210

213211
if (
214212
(self.has_f or self.has_pf or self.relative_f or self.has_gf)
@@ -241,17 +239,17 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
241239
rmse_f.detach(), find_force
242240
)
243241
else:
244-
l1_force_loss = F.l1_loss(force_label, force_pred, reduction="none")
242+
l1_force_loss = F.l1_loss(force_label, force_pred, reduction="mean")
245243
more_loss["mae_f"] = self.display_if_exist(
246-
l1_force_loss.mean().detach(), find_force
244+
l1_force_loss.detach(), find_force
247245
)
248-
l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
246+
# l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
249247
loss += (pref_f * l1_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
250-
if mae:
251-
mae_f = torch.mean(torch.abs(diff_f))
252-
more_loss["mae_f"] = self.display_if_exist(
253-
mae_f.detach(), find_force
254-
)
248+
# if mae:
249+
# mae_f = torch.mean(torch.abs(diff_f))
250+
# more_loss["mae_f"] = self.display_if_exist(
251+
# mae_f.detach(), find_force
252+
# )
255253

256254
if self.has_pf and "atom_pref" in label:
257255
atom_pref = label["atom_pref"]
@@ -297,18 +295,29 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
297295
if self.has_v and "virial" in model_pred and "virial" in label:
298296
find_virial = label.get("find_virial", 0.0)
299297
pref_v = pref_v * find_virial
298+
virial_label = label["virial"]
299+
virial_pred = model_pred["virial"].reshape(-1, 9)
300300
diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
301-
l2_virial_loss = torch.mean(torch.square(diff_v))
302-
if not self.inference:
303-
more_loss["l2_virial_loss"] = self.display_if_exist(
304-
l2_virial_loss.detach(), find_virial
301+
if not self.use_l1_all:
302+
l2_virial_loss = torch.mean(torch.square(diff_v))
303+
if not self.inference:
304+
more_loss["l2_virial_loss"] = self.display_if_exist(
305+
l2_virial_loss.detach(), find_virial
306+
)
307+
loss += atom_norm * (pref_v * l2_virial_loss)
308+
rmse_v = l2_virial_loss.sqrt() * atom_norm
309+
more_loss["rmse_v"] = self.display_if_exist(
310+
rmse_v.detach(), find_virial
311+
)
312+
else:
313+
l1_virial_loss = F.l1_loss(virial_label, virial_pred, reduction="mean")
314+
more_loss["mae_v"] = self.display_if_exist(
315+
l1_virial_loss.detach(), find_virial
305316
)
306-
loss += atom_norm * (pref_v * l2_virial_loss)
307-
rmse_v = l2_virial_loss.sqrt() * atom_norm
308-
more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
309-
if mae:
310-
mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
311-
more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)
317+
loss += (pref_v * l1_virial_loss).to(GLOBAL_PT_FLOAT_PRECISION)
318+
# if mae:
319+
# mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
320+
# more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)
312321

313322
if self.has_ae and "atom_energy" in model_pred and "atom_ener" in label:
314323
atom_ener = model_pred["atom_energy"]

deepmd/utils/argcheck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2346,6 +2346,12 @@ def loss_ener():
23462346
doc_relative_f = "If provided, relative force error will be used in the loss. The difference of force will be normalized by the magnitude of the force in the label with a shift given by `relative_f`, i.e. DF_i / ( || F || + relative_f ) with DF denoting the difference between prediction and label and || F || denoting the L2 norm of the label."
23472347
doc_enable_atom_ener_coeff = "If true, the energy will be computed as \\sum_i c_i E_i. c_i should be provided by file atom_ener_coeff.npy in each data system, otherwise it's 1."
23482348
return [
2349+
Argument(
2350+
"use_l1_all",
2351+
bool,
2352+
optional=True,
2353+
default=False,
2354+
),
23492355
Argument(
23502356
"start_pref_e",
23512357
[float, int],

0 commit comments

Comments
 (0)