Skip to content

Commit eb9e71d

Browse files
authored
feat(jax): Hessian (#4649)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced enhanced Hessian fitting capabilities that extend model outputs to include second-order derivative information. - Integrated Hessian computations into the output transformation workflow for more detailed analytical results. - **Tests** - Updated test suites to conditionally import modules based on Python version, ensuring compatibility with the JAX library. - Adjusted precision level in tests for finite differences to improve accuracy of comparisons. - Added a new environment variable for memory allocation handling in test configurations. - Introduced a new function for scatter summation specifically for JAX arrays. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 4b78a1c commit eb9e71d

File tree

7 files changed

+438
-14
lines changed

7 files changed

+438
-14
lines changed

.github/workflows/test_cuda.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ jobs:
6464
CUDA_VISIBLE_DEVICES: 0
6565
# See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
6666
XLA_PYTHON_CLIENT_PREALLOCATE: false
67+
XLA_PYTHON_CLIENT_ALLOCATOR: platform
6768
- name: Convert models
6869
run: source/tests/infer/convert-models.sh
6970
- name: Download libtorch

deepmd/dpmodel/array_api.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"""Utilities for the array API."""
33

44
import array_api_compat
5+
import numpy as np
56
from packaging.version import (
67
Version,
78
)
@@ -73,3 +74,21 @@ def xp_take_along_axis(arr, indices, axis):
7374
out = xp.take(arr, indices)
7475
out = xp.reshape(out, shape)
7576
return xp_swapaxes(out, axis, -1)
77+
78+
79+
def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray:
80+
"""Reduces all values from the src tensor to the indices specified in the index tensor."""
81+
# jax only
82+
if array_api_compat.is_jax_array(input):
83+
from deepmd.jax.common import (
84+
scatter_sum,
85+
)
86+
87+
return scatter_sum(
88+
input,
89+
dim,
90+
index,
91+
src,
92+
)
93+
else:
94+
raise NotImplementedError("Only JAX arrays are supported.")

deepmd/dpmodel/model/ener_model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from copy import (
3+
deepcopy,
4+
)
5+
26
from deepmd.dpmodel.atomic_model import (
37
DPEnergyAtomicModel,
48
)
59
from deepmd.dpmodel.model.base_model import (
610
BaseModel,
711
)
12+
from deepmd.dpmodel.output_def import (
13+
FittingOutputDef,
14+
)
815

916
from .dp_model import (
1017
DPModelCommon,
@@ -25,3 +32,15 @@ def __init__(
2532
) -> None:
2633
DPModelCommon.__init__(self)
2734
DPEnergyModel_.__init__(self, *args, **kwargs)
35+
self._enable_hessian = False
36+
self.hess_fitting_def = None
37+
38+
def enable_hessian(self):
39+
self.hess_fitting_def = deepcopy(self.atomic_output_def())
40+
self.hess_fitting_def["energy"].r_hessian = True
41+
self._enable_hessian = True
42+
43+
def atomic_output_def(self) -> FittingOutputDef:
44+
if self._enable_hessian:
45+
return self.hess_fitting_def
46+
return super().atomic_output_def()

deepmd/dpmodel/model/transform_output.py

Lines changed: 91 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import array_api_compat
44
import numpy as np
55

6+
from deepmd.dpmodel.array_api import (
7+
xp_scatter_sum,
8+
)
69
from deepmd.dpmodel.common import (
710
GLOBAL_ENER_FLOAT_PRECISION,
811
)
@@ -11,6 +14,7 @@
1114
ModelOutputDef,
1215
OutputVariableDef,
1316
get_deriv_name,
17+
get_hessian_name,
1418
get_reduce_name,
1519
)
1620

@@ -81,6 +85,7 @@ def communicate_extended_output(
8185
8286
"""
8387
xp = array_api_compat.get_namespace(mapping)
88+
mapping_ = mapping
8489
new_ret = {}
8590
for kk in model_output_def.keys_outp():
8691
vv = model_ret[kk]
@@ -98,24 +103,96 @@ def communicate_extended_output(
98103
mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
99104
mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
100105
force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype)
101-
# jax only
102-
if array_api_compat.is_jax_array(force):
103-
from deepmd.jax.common import (
104-
scatter_sum,
105-
)
106-
107-
force = scatter_sum(
108-
force,
109-
1,
110-
mapping,
111-
model_ret[kk_derv_r],
112-
)
113-
else:
114-
raise NotImplementedError("Only JAX arrays are supported.")
106+
force = xp_scatter_sum(
107+
force,
108+
1,
109+
mapping,
110+
model_ret[kk_derv_r],
111+
)
115112
new_ret[kk_derv_r] = force
116113
else:
117114
# name holders
118115
new_ret[kk_derv_r] = None
116+
if vdef.r_hessian:
117+
kk_hess = get_hessian_name(kk)
118+
if model_ret[kk_hess] is not None:
119+
# [nf, *def, nall, 3, nall, 3]
120+
hess_ = model_ret[kk_hess]
121+
def_ndim = len(vdef.shape)
122+
# [nf, nall1, nall2, *def, 3(1), 3(2)]
123+
hess_1 = xp.permute_dims(
124+
hess_,
125+
(
126+
0,
127+
def_ndim + 1,
128+
def_ndim + 3,
129+
*range(1, def_ndim + 1),
130+
def_ndim + 2,
131+
def_ndim + 4,
132+
),
133+
)
134+
nall = hess_1.shape[1]
135+
# (1) -> [nf, nloc1, nall2, *def, 3(1), 3(2)]
136+
hessian1 = xp.zeros(
137+
[*vldims, nall, *vdef.shape, 3, 3], dtype=vv.dtype
138+
)
139+
mapping_hess = xp.reshape(
140+
mapping_, (mldims + [1] * (len(vdef.shape) + 3))
141+
)
142+
mapping_hess = xp.tile(
143+
mapping_hess,
144+
[1] * len(mldims) + [nall, *vdef.shape, 3, 3],
145+
)
146+
hessian1 = xp_scatter_sum(
147+
hessian1,
148+
1,
149+
mapping_hess,
150+
hess_1,
151+
)
152+
# [nf, nall2, nloc1, *def, 3(1), 3(2)]
153+
hessian1 = xp.permute_dims(
154+
hessian1,
155+
(0, 2, 1, *range(3, def_ndim + 5)),
156+
)
157+
nloc = hessian1.shape[2]
158+
# (2) -> [nf, nloc2, nloc1, *def, 3(1), 3(2)]
159+
hessian = xp.zeros(
160+
[*vldims, nloc, *vdef.shape, 3, 3], dtype=vv.dtype
161+
)
162+
mapping_hess = xp.reshape(
163+
mapping_, (mldims + [1] * (len(vdef.shape) + 3))
164+
)
165+
mapping_hess = xp.tile(
166+
mapping_hess,
167+
[1] * len(mldims) + [nloc, *vdef.shape, 3, 3],
168+
)
169+
hessian = xp_scatter_sum(
170+
hessian,
171+
1,
172+
mapping_hess,
173+
hessian1,
174+
)
175+
# -> [nf, *def, nloc1, 3(1), nloc2, 3(2)]
176+
hessian = xp.permute_dims(
177+
hessian,
178+
(
179+
0,
180+
*range(3, def_ndim + 3),
181+
2,
182+
def_ndim + 3,
183+
1,
184+
def_ndim + 4,
185+
),
186+
)
187+
# -> [nf, *def nloc1 * 3, nloc2 * 3]
188+
hessian = xp.reshape(
189+
hessian,
190+
(hessian.shape[0], *vdef.shape, nloc * 3, nloc * 3),
191+
)
192+
193+
new_ret[kk_hess] = hessian
194+
else:
195+
new_ret[kk_hess] = None
119196
if vdef.c_differentiable:
120197
assert vdef.r_differentiable
121198
if model_ret[kk_derv_c] is not None:

deepmd/jax/model/base_model.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
)
99
from deepmd.dpmodel.output_def import (
1010
get_deriv_name,
11+
get_hessian_name,
1112
get_reduce_name,
1213
)
1314
from deepmd.jax.env import (
@@ -87,6 +88,18 @@ def eval_output(
8788
)
8889

8990
model_predict[kk_derv_r] = extended_force
91+
if vdef.r_hessian:
92+
# [nf, *def, nall, 3, nall, 3]
93+
hessian = jax.vmap(jax.hessian(eval_output, argnums=0))(
94+
extended_coord,
95+
extended_atype,
96+
nlist,
97+
mapping,
98+
fparam,
99+
aparam,
100+
)
101+
kk_hessian = get_hessian_name(kk)
102+
model_predict[kk_hessian] = hessian
90103
if vdef.c_differentiable:
91104
assert vdef.r_differentiable
92105
# avr: [nf, *def, nall, 3, 3]
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import sys
3+
import unittest
4+
5+
import numpy as np
6+
7+
from deepmd.dpmodel.common import (
8+
to_numpy_array,
9+
)
10+
11+
if sys.version_info >= (3, 10):
12+
from deepmd.jax.common import (
13+
to_jax_array,
14+
)
15+
from deepmd.jax.descriptor.se_e2_a import (
16+
DescrptSeA,
17+
)
18+
from deepmd.jax.env import (
19+
jnp,
20+
)
21+
from deepmd.jax.fitting.fitting import (
22+
EnergyFittingNet,
23+
)
24+
from deepmd.jax.model.ener_model import (
25+
EnergyModel,
26+
)
27+
28+
dtype = jnp.float64
29+
30+
31+
@unittest.skipIf(
32+
sys.version_info < (3, 10),
33+
"JAX requires Python 3.10 or later",
34+
)
35+
class TestCaseSingleFrameWithoutNlist:
36+
def setUp(self) -> None:
37+
# nloc == 3, nall == 4
38+
self.nloc = 3
39+
self.nf, self.nt = 1, 2
40+
self.coord = np.array(
41+
[
42+
[0, 0, 0],
43+
[0, 1, 0],
44+
[0, 0, 1],
45+
],
46+
dtype=np.float64,
47+
).reshape([1, self.nloc * 3])
48+
self.atype = np.array([0, 0, 1], dtype=int).reshape([1, self.nloc])
49+
self.cell = 2.0 * np.eye(3).reshape([1, 9])
50+
# sel = [5, 2]
51+
self.sel = [16, 8]
52+
self.sel_mix = [24]
53+
self.natoms = [3, 3, 2, 1]
54+
self.rcut = 2.2
55+
self.rcut_smth = 0.4
56+
self.atol = 1e-12
57+
58+
59+
@unittest.skipIf(
60+
sys.version_info < (3, 10),
61+
"JAX requires Python 3.10 or later",
62+
)
63+
class TestEnergyHessianModel(unittest.TestCase, TestCaseSingleFrameWithoutNlist):
64+
def setUp(self):
65+
TestCaseSingleFrameWithoutNlist.setUp(self)
66+
67+
def test_self_consistency(self):
68+
ds = DescrptSeA(
69+
self.rcut,
70+
self.rcut_smth,
71+
self.sel,
72+
)
73+
ft = EnergyFittingNet(
74+
self.nt,
75+
ds.get_dim_out(),
76+
mixed_types=ds.mixed_types(),
77+
)
78+
type_map = ["foo", "bar"]
79+
md0 = EnergyModel(ds, ft, type_map=type_map)
80+
md1 = EnergyModel.deserialize(md0.serialize())
81+
md0.enable_hessian()
82+
md1.enable_hessian()
83+
args = [to_jax_array(ii) for ii in [self.coord, self.atype, self.cell]]
84+
ret0 = md0.call(*args)
85+
ret1 = md1.call(*args)
86+
np.testing.assert_allclose(
87+
to_numpy_array(ret0["energy"]),
88+
to_numpy_array(ret1["energy"]),
89+
atol=self.atol,
90+
)
91+
np.testing.assert_allclose(
92+
to_numpy_array(ret0["energy_redu"]),
93+
to_numpy_array(ret1["energy_redu"]),
94+
atol=self.atol,
95+
)
96+
np.testing.assert_allclose(
97+
to_numpy_array(ret0["energy_derv_r"]),
98+
to_numpy_array(ret1["energy_derv_r"]),
99+
atol=self.atol,
100+
)
101+
np.testing.assert_allclose(
102+
to_numpy_array(ret0["energy_derv_c_redu"]),
103+
to_numpy_array(ret1["energy_derv_c_redu"]),
104+
atol=self.atol,
105+
)
106+
np.testing.assert_allclose(
107+
to_numpy_array(ret0["energy_derv_r_derv_r"]),
108+
to_numpy_array(ret1["energy_derv_r_derv_r"]),
109+
atol=self.atol,
110+
)
111+
ret0 = md0.call(*args, do_atomic_virial=True)
112+
ret1 = md1.call(*args, do_atomic_virial=True)
113+
np.testing.assert_allclose(
114+
to_numpy_array(ret0["energy_derv_c"]),
115+
to_numpy_array(ret1["energy_derv_c"]),
116+
atol=self.atol,
117+
)

0 commit comments

Comments
 (0)