@@ -187,28 +187,26 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
187187 )
188188 # more_loss['log_keys'].append('rmse_e')
189189 else : # use l1 and for all atoms
190+ energy_pred = energy_pred * atom_norm
191+ energy_label = energy_label * atom_norm
190192 l1_ener_loss = F .l1_loss (
191193 energy_pred .reshape (- 1 ),
192194 energy_label .reshape (- 1 ),
193- reduction = "sum " ,
195+ reduction = "mean " ,
194196 )
195197 loss += pref_e * l1_ener_loss
196198 more_loss ["mae_e" ] = self .display_if_exist (
197- F .l1_loss (
198- energy_pred .reshape (- 1 ),
199- energy_label .reshape (- 1 ),
200- reduction = "mean" ,
201- ).detach (),
199+ l1_ener_loss .detach (),
202200 find_energy ,
203201 )
204202 # more_loss['log_keys'].append('rmse_e')
205- if mae :
206- mae_e = torch .mean (torch .abs (energy_pred - energy_label )) * atom_norm
207- more_loss ["mae_e" ] = self .display_if_exist (mae_e .detach (), find_energy )
208- mae_e_all = torch .mean (torch .abs (energy_pred - energy_label ))
209- more_loss ["mae_e_all" ] = self .display_if_exist (
210- mae_e_all .detach (), find_energy
211- )
203+ # if mae:
204+ # mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm
205+ # more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
206+ # mae_e_all = torch.mean(torch.abs(energy_pred - energy_label))
207+ # more_loss["mae_e_all"] = self.display_if_exist(
208+ # mae_e_all.detach(), find_energy
209+ # )
212210
213211 if (
214212 (self .has_f or self .has_pf or self .relative_f or self .has_gf )
@@ -241,17 +239,17 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
241239 rmse_f .detach (), find_force
242240 )
243241 else :
244- l1_force_loss = F .l1_loss (force_label , force_pred , reduction = "none " )
242+ l1_force_loss = F .l1_loss (force_label , force_pred , reduction = "mean " )
245243 more_loss ["mae_f" ] = self .display_if_exist (
246- l1_force_loss .mean (). detach (), find_force
244+ l1_force_loss .detach (), find_force
247245 )
248- l1_force_loss = l1_force_loss .sum (- 1 ).mean (- 1 ).sum ()
246+ # l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
249247 loss += (pref_f * l1_force_loss ).to (GLOBAL_PT_FLOAT_PRECISION )
250- if mae :
251- mae_f = torch .mean (torch .abs (diff_f ))
252- more_loss ["mae_f" ] = self .display_if_exist (
253- mae_f .detach (), find_force
254- )
248+ # if mae:
249+ # mae_f = torch.mean(torch.abs(diff_f))
250+ # more_loss["mae_f"] = self.display_if_exist(
251+ # mae_f.detach(), find_force
252+ # )
255253
256254 if self .has_pf and "atom_pref" in label :
257255 atom_pref = label ["atom_pref" ]
@@ -297,18 +295,29 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
297295 if self .has_v and "virial" in model_pred and "virial" in label :
298296 find_virial = label .get ("find_virial" , 0.0 )
299297 pref_v = pref_v * find_virial
298+ virial_label = label ["virial" ]
299+ virial_pred = model_pred ["virial" ].reshape (- 1 , 9 )
300300 diff_v = label ["virial" ] - model_pred ["virial" ].reshape (- 1 , 9 )
301- l2_virial_loss = torch .mean (torch .square (diff_v ))
302- if not self .inference :
303- more_loss ["l2_virial_loss" ] = self .display_if_exist (
304- l2_virial_loss .detach (), find_virial
301+ if not self .use_l1_all :
302+ l2_virial_loss = torch .mean (torch .square (diff_v ))
303+ if not self .inference :
304+ more_loss ["l2_virial_loss" ] = self .display_if_exist (
305+ l2_virial_loss .detach (), find_virial
306+ )
307+ loss += atom_norm * (pref_v * l2_virial_loss )
308+ rmse_v = l2_virial_loss .sqrt () * atom_norm
309+ more_loss ["rmse_v" ] = self .display_if_exist (
310+ rmse_v .detach (), find_virial
311+ )
312+ else :
313+ l1_virial_loss = F .l1_loss (virial_label , virial_pred , reduction = "mean" )
314+ more_loss ["mae_v" ] = self .display_if_exist (
315+ l1_virial_loss .detach (), find_virial
305316 )
306- loss += atom_norm * (pref_v * l2_virial_loss )
307- rmse_v = l2_virial_loss .sqrt () * atom_norm
308- more_loss ["rmse_v" ] = self .display_if_exist (rmse_v .detach (), find_virial )
309- if mae :
310- mae_v = torch .mean (torch .abs (diff_v )) * atom_norm
311- more_loss ["mae_v" ] = self .display_if_exist (mae_v .detach (), find_virial )
317+ loss += (pref_v * l1_virial_loss ).to (GLOBAL_PT_FLOAT_PRECISION )
318+ # if mae:
319+ # mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
320+ # more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)
312321
313322 if self .has_ae and "atom_energy" in model_pred and "atom_ener" in label :
314323 atom_ener = model_pred ["atom_energy" ]
0 commit comments