Skip to content

Commit c5ad841

Browse files
1azyking, coderabbitai[bot], anyangml, wanghan-iapcm, njzjz
authored
feat(pt): train with energy Hessian (#4169)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Introduced support for Hessian calculations across various components, enhancing the model's capabilities.
  - Added a new loss function for Hessian, allowing for more comprehensive training scenarios.
  - New JSON configuration files for multi-task and single-task learning models.
  - Enhanced output handling to include Hessian data in model evaluations.
  - Added new methods and properties to support Hessian in several classes and modules.
- **Bug Fixes**
  - Improved handling of output shapes and results related to Hessian data.
- **Documentation**
  - Updated documentation to include new Hessian properties and training guidelines.
  - Added sections detailing Hessian configurations and requirements in the training documentation.
- **Tests**
  - Added unit tests for the new Hessian-related functionalities to ensure consistency and correctness.
  - Enhanced existing test cases to incorporate Hessian data handling and validation.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Anchor Yu <91590308+1azyking@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: anyangml <anyangpeng.ca@gmail.com>
Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent bf79cc6 commit c5ad841

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

50 files changed

+1083
-45
lines changed

deepmd/calculator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def calculate(
130130
cell = None
131131
symbols = self.atoms.get_chemical_symbols()
132132
atype = [self.type_dict[k] for k in symbols]
133-
e, f, v = self.dp.eval(coords=coord, cells=cell, atom_types=atype)
133+
e, f, v = self.dp.eval(coords=coord, cells=cell, atom_types=atype)[:3]
134134
self.results["energy"] = e[0][0]
135135
# see https://gitlab.com/ase/ase/-/merge_requests/2485
136136
self.results["free_energy"] = e[0][0]

deepmd/dpmodel/infer/deep_eval.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,9 @@ def _get_output_shape(self, odef, nframes, natoms):
383383
# Something wrong here?
384384
# return [nframes, *shape, natoms, 1]
385385
return [nframes, natoms, *odef.shape, 1]
386+
elif odef.category == OutputVariableCategory.DERV_R_DERV_R:
387+
# hessian
388+
return [nframes, 3 * natoms, 3 * natoms]
386389
else:
387390
raise RuntimeError("unknown category")
388391

deepmd/driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def label(self, data: dict) -> dict:
6767
cell = data["cells"].reshape((nframes, 9))
6868
else:
6969
cell = None
70-
e, f, v = self.dp.eval(coord, cell, atype)
70+
e, f, v = self.dp.eval(coords=coord, cells=cell, atom_types=atype)[:3]
7171
data = data.copy()
7272
data["energies"] = e.reshape((nframes,))
7373
data["forces"] = f.reshape((nframes, natoms, 3))

deepmd/entrypoints/test.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,8 @@ def test_ener(
303303
if dp.has_spin:
304304
data.add("spin", 3, atomic=True, must=True, high_prec=False)
305305
data.add("force_mag", 3, atomic=True, must=False, high_prec=False)
306+
if dp.has_hessian:
307+
data.add("hessian", 1, atomic=True, must=True, high_prec=False)
306308

307309
test_data = data.get_test()
308310
mixed_type = data.mixed_type
@@ -352,6 +354,9 @@ def test_ener(
352354
energy = energy.reshape([numb_test, 1])
353355
force = force.reshape([numb_test, -1])
354356
virial = virial.reshape([numb_test, 9])
357+
if dp.has_hessian:
358+
hessian = ret[3]
359+
hessian = hessian.reshape([numb_test, -1])
355360
if has_atom_ener:
356361
ae = ret[3]
357362
av = ret[4]
@@ -415,6 +420,10 @@ def test_ener(
415420
rmse_ea = rmse_e / natoms
416421
mae_va = mae_v / natoms
417422
rmse_va = rmse_v / natoms
423+
if dp.has_hessian:
424+
diff_h = hessian - test_data["hessian"][:numb_test]
425+
mae_h = mae(diff_h)
426+
rmse_h = rmse(diff_h)
418427
if has_atom_ener:
419428
diff_ae = test_data["atom_ener"][:numb_test].reshape([-1]) - ae.reshape([-1])
420429
mae_ae = mae(diff_ae)
@@ -447,6 +456,9 @@ def test_ener(
447456
if has_atom_ener:
448457
log.info(f"Atomic ener MAE : {mae_ae:e} eV")
449458
log.info(f"Atomic ener RMSE : {rmse_ae:e} eV")
459+
if dp.has_hessian:
460+
log.info(f"Hessian MAE : {mae_h:e} eV/A^2")
461+
log.info(f"Hessian RMSE : {rmse_h:e} eV/A^2")
450462

451463
if detail_file is not None:
452464
detail_path = Path(detail_file)
@@ -530,8 +542,24 @@ def test_ener(
530542
"pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz",
531543
append=append_detail,
532544
)
545+
if dp.has_hessian:
546+
data_h = test_data["hessian"][:numb_test].reshape(-1, 1)
547+
pred_h = hessian.reshape(-1, 1)
548+
h = np.concatenate(
549+
(
550+
data_h,
551+
pred_h,
552+
),
553+
axis=1,
554+
)
555+
save_txt_file(
556+
detail_path.with_suffix(".h.out"),
557+
h,
558+
header=f"{system}: data_h pred_h (3Na*3Na matrix in row-major order)",
559+
append=append_detail,
560+
)
533561
if not out_put_spin:
534-
return {
562+
dict_to_return = {
535563
"mae_e": (mae_e, energy.size),
536564
"mae_ea": (mae_ea, energy.size),
537565
"mae_f": (mae_f, force.size),
@@ -544,7 +572,7 @@ def test_ener(
544572
"rmse_va": (rmse_va, virial.size),
545573
}
546574
else:
547-
return {
575+
dict_to_return = {
548576
"mae_e": (mae_e, energy.size),
549577
"mae_ea": (mae_ea, energy.size),
550578
"mae_fr": (mae_fr, force_r.size),
@@ -558,6 +586,10 @@ def test_ener(
558586
"rmse_v": (rmse_v, virial.size),
559587
"rmse_va": (rmse_va, virial.size),
560588
}
589+
if dp.has_hessian:
590+
dict_to_return["mae_h"] = (mae_h, hessian.size)
591+
dict_to_return["rmse_h"] = (rmse_h, hessian.size)
592+
return dict_to_return
561593

562594

563595
def print_ener_sys_avg(avg: dict[str, float]) -> None:
@@ -584,6 +616,9 @@ def print_ener_sys_avg(avg: dict[str, float]) -> None:
584616
log.info(f"Virial RMSE : {avg['rmse_v']:e} eV")
585617
log.info(f"Virial MAE/Natoms : {avg['mae_va']:e} eV")
586618
log.info(f"Virial RMSE/Natoms : {avg['rmse_va']:e} eV")
619+
if "rmse_h" in avg.keys():
620+
log.info(f"Hessian MAE : {avg['mae_h']:e} eV/A^2")
621+
log.info(f"Hessian RMSE : {avg['rmse_h']:e} eV/A^2")
587622

588623

589624
def test_dos(

deepmd/infer/deep_eval.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ class DeepEvalBackend(ABC):
7575
# old models in v1
7676
"global_polar": "global_polar",
7777
"wfc": "wfc",
78+
"energy_derv_r_derv_r": "hessian",
7879
}
7980

8081
@abstractmethod
@@ -274,6 +275,10 @@ def get_has_spin(self) -> bool:
274275
"""Check if the model has spin atom types."""
275276
return False
276277

278+
def get_has_hessian(self):
279+
"""Check if the model has hessian."""
280+
return False
281+
277282
def get_var_name(self) -> str:
278283
"""Get the name of the fitting property."""
279284
raise NotImplementedError
@@ -543,6 +548,11 @@ def has_spin(self) -> bool:
543548
"""Check if the model has spin."""
544549
return self.deep_eval.get_has_spin()
545550

551+
@property
552+
def has_hessian(self) -> bool:
553+
"""Check if the model has hessian."""
554+
return self.deep_eval.get_has_hessian()
555+
546556
def get_ntypes_spin(self) -> int:
547557
"""Get the number of spin atom types of this model. Only used in old implement."""
548558
return self.deep_eval.get_ntypes_spin()

deepmd/infer/deep_pot.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def output_def(self) -> ModelOutputDef:
6464
r_differentiable=True,
6565
c_differentiable=True,
6666
atomic=True,
67+
r_hessian=True,
6768
),
6869
]
6970
)
@@ -99,7 +100,10 @@ def eval(
99100
aparam: Optional[np.ndarray],
100101
mixed_type: bool,
101102
**kwargs: Any,
102-
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
103+
) -> Union[
104+
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
105+
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
106+
]:
103107
pass
104108

105109
@overload
@@ -113,7 +117,10 @@ def eval(
113117
aparam: Optional[np.ndarray],
114118
mixed_type: bool,
115119
**kwargs: Any,
116-
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
120+
) -> Union[
121+
tuple[np.ndarray, np.ndarray, np.ndarray],
122+
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
123+
]:
117124
pass
118125

119126
@overload
@@ -179,6 +186,8 @@ def eval(
179186
atomic_virial
180187
The atomic virial of the system, in shape (nframes, natoms, 9). Only returned
181188
when atomic is True.
189+
hessian
190+
The Hessian matrix of the system, in shape (nframes, 3 * natoms, 3 * natoms). Returned when available.
182191
"""
183192
# This method has been used by:
184193
# documentation python.md
@@ -239,6 +248,11 @@ def eval(
239248
force_mag = results["energy_derv_r_mag"].reshape(nframes, natoms, 3)
240249
mask_mag = results["mask_mag"].reshape(nframes, natoms, 1)
241250
result = (*list(result), force_mag, mask_mag)
251+
if self.deep_eval.get_has_hessian():
252+
hessian = results["energy_derv_r_derv_r"].reshape(
253+
nframes, 3 * natoms, 3 * natoms
254+
)
255+
result = (*list(result), hessian)
242256
return result
243257

244258

deepmd/jax/infer/deep_eval.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,9 @@ def _get_output_shape(self, odef, nframes, natoms):
411411
elif odef.category == OutputVariableCategory.OUT:
412412
# atom_energy, atom_tensor
413413
return [nframes, natoms, *odef.shape, 1]
414+
elif odef.category == OutputVariableCategory.DERV_R_DERV_R:
415+
# hessian
416+
return [nframes, 3 * natoms, 3 * natoms]
414417
else:
415418
raise RuntimeError("unknown category")
416419

deepmd/pt/infer/deep_eval.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ def __init__(
130130
] = state_dict[item].clone()
131131
state_dict = state_dict_head
132132
model = get_model(self.input_param).to(DEVICE)
133-
model = torch.jit.script(model)
133+
if not self.input_param.get("hessian_mode"):
134+
model = torch.jit.script(model)
134135
self.dp = ModelWrapper(model)
135136
self.dp.load_state_dict(state_dict)
136137
elif str(self.model_path).endswith(".pth"):
@@ -160,6 +161,7 @@ def __init__(
160161
self._has_spin = getattr(self.dp.model["Default"], "has_spin", False)
161162
if callable(self._has_spin):
162163
self._has_spin = self._has_spin()
164+
self._has_hessian = self.model_def_script.get("hessian_mode", False)
163165

164166
def get_rcut(self) -> float:
165167
"""Get the cutoff radius of this model."""
@@ -243,6 +245,10 @@ def get_has_spin(self):
243245
"""Check if the model has spin atom types."""
244246
return self._has_spin
245247

248+
def get_has_hessian(self):
249+
"""Check if the model has hessian."""
250+
return self._has_hessian
251+
246252
def eval(
247253
self,
248254
coords: np.ndarray,
@@ -348,6 +354,7 @@ def _get_request_defs(self, atomic: bool) -> list[OutputVariableDef]:
348354
OutputVariableCategory.REDU,
349355
OutputVariableCategory.DERV_R,
350356
OutputVariableCategory.DERV_C_REDU,
357+
OutputVariableCategory.DERV_R_DERV_R,
351358
)
352359
]
353360

@@ -577,6 +584,9 @@ def _get_output_shape(self, odef, nframes, natoms):
577584
# Something wrong here?
578585
# return [nframes, *shape, natoms, 1]
579586
return [nframes, natoms, *odef.shape, 1]
587+
elif odef.category == OutputVariableCategory.DERV_R_DERV_R:
588+
return [nframes, 3 * natoms, 3 * natoms]
589+
# return [nframes, *odef.shape, 3 * natoms, 3 * natoms]
580590
else:
581591
raise RuntimeError("unknown category")
582592

deepmd/pt/infer/inference.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ def __init__(
5555
] = state_dict[item].clone()
5656
state_dict = state_dict_head
5757

58+
model_params.pop(
59+
"hessian_mode", None
60+
) # wrapper Hessian to Energy model due to JIT limit
5861
self.model_params = deepcopy(model_params)
5962
self.model = get_model(model_params).to(DEVICE)
6063

deepmd/pt/loss/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
DOSLoss,
77
)
88
from .ener import (
9+
EnergyHessianStdLoss,
910
EnergyStdLoss,
1011
)
1112
from .ener_spin import (
@@ -24,6 +25,7 @@
2425
__all__ = [
2526
"DOSLoss",
2627
"DenoiseLoss",
28+
"EnergyHessianStdLoss",
2729
"EnergySpinLoss",
2830
"EnergyStdLoss",
2931
"PropertyLoss",

0 commit comments

Comments
 (0)