
Commit 159361d

feat(jax): force & virial (#4251)
Summary by CodeRabbit (release notes):

**New Features**
- Introduced new `forward_common_atomic` methods in multiple classes to enhance atomic model predictions and derivative calculations.
- Added a new function `get_leading_dims` for better handling of output dimensions.
- Added a new function `scatter_sum` for performing reduction operations on tensors.
- Updated test methods to handle results flexibly with the new `SKIP_FLAG` variable.

**Bug Fixes**
- Improved numerical stability by handling very small values appropriately.

**Tests**
- Expanded test outputs to include additional data, such as forces and virials, for more comprehensive testing.
- Enhanced backend handling in tests to accommodate new return values based on backend availability.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent dd36e6c commit 159361d

File tree: 10 files changed, +284 -17 lines


deepmd/dpmodel/model/make_model.py

Lines changed: 25 additions & 5 deletions

@@ -222,22 +222,42 @@ def call_lower(
                 extended_coord, fparam=fparam, aparam=aparam
             )
             del extended_coord, fparam, aparam
-            atomic_ret = self.atomic_model.forward_common_atomic(
+            model_predict = self.forward_common_atomic(
                 cc_ext,
                 extended_atype,
                 nlist,
                 mapping=mapping,
                 fparam=fp,
                 aparam=ap,
+                do_atomic_virial=do_atomic_virial,
+            )
+            model_predict = self.output_type_cast(model_predict, input_prec)
+            return model_predict
+
+        def forward_common_atomic(
+            self,
+            extended_coord: np.ndarray,
+            extended_atype: np.ndarray,
+            nlist: np.ndarray,
+            mapping: Optional[np.ndarray] = None,
+            fparam: Optional[np.ndarray] = None,
+            aparam: Optional[np.ndarray] = None,
+            do_atomic_virial: bool = False,
+        ):
+            atomic_ret = self.atomic_model.forward_common_atomic(
+                extended_coord,
+                extended_atype,
+                nlist,
+                mapping=mapping,
+                fparam=fparam,
+                aparam=aparam,
             )
-            model_predict = fit_output_to_model_output(
+            return fit_output_to_model_output(
                 atomic_ret,
                 self.atomic_output_def(),
-                cc_ext,
+                extended_coord,
                 do_atomic_virial=do_atomic_virial,
             )
-            model_predict = self.output_type_cast(model_predict, input_prec)
-            return model_predict
 
         forward_lower = call_lower

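The value of this refactor is the override hook it creates: `call_lower` keeps the input/output precision casting, while the new `forward_common_atomic` owns the atomic forward pass plus the fit-to-model-output transform, so a backend can replace just that piece. A condensed sketch of the resulting control flow, with hypothetical stand-in classes (not the actual deepmd API):

```python
import numpy as np

class ModelLike:
    """Illustrative stand-in for the generated model class."""

    def call_lower(self, cc_ext, extended_atype, nlist):
        # Shared entry point; the real code also casts precision here.
        return self.forward_common_atomic(cc_ext, extended_atype, nlist)

    def forward_common_atomic(self, cc_ext, extended_atype, nlist):
        # Reference path: plain atomic forward pass, no derivatives.
        nf, nall = cc_ext.shape[:2]
        return {"energy": np.zeros((nf, nall, 1))}

class JaxModelLike(ModelLike):
    def forward_common_atomic(self, cc_ext, extended_atype, nlist):
        # A backend override can attach derivative outputs (cf. deepmd/jax below).
        ret = super().forward_common_atomic(cc_ext, extended_atype, nlist)
        ret["energy_derv_r"] = np.zeros((*cc_ext.shape[:2], 1, 3))
        return ret
```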
deepmd/dpmodel/model/transform_output.py

Lines changed: 78 additions & 6 deletions

@@ -9,6 +9,7 @@
 from deepmd.dpmodel.output_def import (
     FittingOutputDef,
     ModelOutputDef,
+    OutputVariableDef,
     get_deriv_name,
     get_reduce_name,
 )
@@ -47,6 +48,28 @@ def fit_output_to_model_output(
     return model_ret
 
 
+def get_leading_dims(
+    vv: np.ndarray,
+    vdef: OutputVariableDef,
+):
+    """Get the dimensions of nf x nloc.
+
+    Parameters
+    ----------
+    vv : np.ndarray
+        The input array from which to compute the leading dimensions.
+    vdef : OutputVariableDef
+        The output variable definition containing the shape to exclude from `vv`.
+
+    Returns
+    -------
+    list
+        A list of leading dimensions of `vv`, excluding the last `len(vdef.shape)` dimensions.
+    """
+    vshape = vv.shape
+    return list(vshape[: (len(vshape) - len(vdef.shape))])
+
+
 def communicate_extended_output(
     model_ret: dict[str, np.ndarray],
     model_output_def: ModelOutputDef,
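A quick illustration of the new helper (shapes invented for the example; assumes it is imported from `deepmd.dpmodel.model.transform_output`): only `vdef.shape` is consulted, so a minimal stand-in object is enough.

```python
import numpy as np
from deepmd.dpmodel.model.transform_output import get_leading_dims

class _VDef:
    """Minimal stand-in for OutputVariableDef; only .shape is read here."""
    shape = [1]

vv = np.zeros((2, 192, 1))  # nf x nloc x 1, e.g. per-atom energies
print(get_leading_dims(vv, _VDef()))  # -> [2, 192]
```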
@@ -57,6 +80,7 @@ def communicate_extended_output(
     local and ghost (extended) atoms to local atoms.
 
     """
+    xp = array_api_compat.get_namespace(mapping)
     new_ret = {}
     for kk in model_output_def.keys_outp():
         vv = model_ret[kk]
@@ -65,15 +89,63 @@ def communicate_extended_output(
         if vdef.reducible:
             kk_redu = get_reduce_name(kk)
             new_ret[kk_redu] = model_ret[kk_redu]
+            kk_derv_r, kk_derv_c = get_deriv_name(kk)
+            mldims = list(mapping.shape)
+            vldims = get_leading_dims(vv, vdef)
             if vdef.r_differentiable:
-                kk_derv_r, kk_derv_c = get_deriv_name(kk)
-                # name holders
-                new_ret[kk_derv_r] = None
+                if model_ret[kk_derv_r] is not None:
+                    derv_r_ext_dims = list(vdef.shape) + [3]  # noqa:RUF005
+                    mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
+                    mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
+                    force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype)
+                    # jax only
+                    if array_api_compat.is_jax_array(force):
+                        from deepmd.jax.common import (
+                            scatter_sum,
+                        )
+
+                        force = scatter_sum(
+                            force,
+                            1,
+                            mapping,
+                            model_ret[kk_derv_r],
+                        )
+                    else:
+                        raise NotImplementedError("Only JAX arrays are supported.")
+                    new_ret[kk_derv_r] = force
+                else:
+                    # name holders
+                    new_ret[kk_derv_r] = None
             if vdef.c_differentiable:
                 assert vdef.r_differentiable
-                kk_derv_r, kk_derv_c = get_deriv_name(kk)
-                new_ret[kk_derv_c] = None
-                new_ret[kk_derv_c + "_redu"] = None
+                if model_ret[kk_derv_c] is not None:
+                    derv_c_ext_dims = list(vdef.shape) + [9]  # noqa:RUF005
+                    mapping = xp.tile(
+                        mapping, [1] * (len(mldims) + len(vdef.shape)) + [3]
+                    )
+                    virial = xp.zeros(
+                        vldims + derv_c_ext_dims,
+                        dtype=vv.dtype,
+                    )
+                    # jax only
+                    if array_api_compat.is_jax_array(virial):
+                        from deepmd.jax.common import (
+                            scatter_sum,
+                        )
+
+                        virial = scatter_sum(
+                            virial,
+                            1,
+                            mapping,
+                            model_ret[kk_derv_c],
+                        )
+                    else:
+                        raise NotImplementedError("Only JAX arrays are supported.")
+                    new_ret[kk_derv_c] = virial
+                    new_ret[kk_derv_c + "_redu"] = xp.sum(new_ret[kk_derv_c], axis=1)
+                else:
+                    new_ret[kk_derv_c] = None
+                    new_ret[kk_derv_c + "_redu"] = None
         if not do_atomic_virial:
             # pop atomic virial, because it is not correctly calculated.
             new_ret.pop(kk_derv_c)

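What the scatter above computes, in a reduced sketch (toy sizes; assumes the `scatter_sum` added to `deepmd.jax.common` in this commit is importable): `mapping[f, a]` is the local atom owning extended atom `a`, so the scatter-add folds every ghost-atom contribution back onto its local image.

```python
import jax.numpy as jnp
from deepmd.jax.common import scatter_sum

nf, nloc, nall = 1, 2, 4
mapping = jnp.array([[0, 1, 0, 1]])  # [nf, nall]: ghosts 2, 3 map to locals 0, 1
extended_force = jnp.ones((nf, nall, 3))  # [nf, nall, 3]
# broadcast the map over the trailing xyz axis, as the real code does via tile
index = jnp.tile(mapping[..., None], (1, 1, 3))

force = scatter_sum(jnp.zeros((nf, nloc, 3)), 1, index, extended_force)
print(force)  # every local atom gets its own + one ghost contribution -> all 2.0
```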
deepmd/dpmodel/utils/env_mat.py

Lines changed: 3 additions & 1 deletion

@@ -61,7 +61,9 @@ def _make_env_mat(
     # nf x nloc x nnei x 3
     diff = coord_r - coord_l
     # nf x nloc x nnei
-    length = xp.linalg.vector_norm(diff, axis=-1, keepdims=True)
+    # the grad of JAX vector_norm is NaN at x=0
+    diff_ = xp.where(xp.abs(diff) < 1e-30, xp.full_like(diff, 1e-30), diff)
+    length = xp.linalg.vector_norm(diff_, axis=-1, keepdims=True)
     # for index 0 nloc atom
     length = length + xp.astype(~xp.expand_dims(mask, axis=-1), length.dtype)
     t0 = 1 / (length + protection)

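The reason for the clamp: padded neighbor slots yield `diff == 0`, and reverse-mode differentiation of a vector norm at the origin produces NaN in JAX. A minimal reproduction of the failure and of the fix used above (assuming a recent JAX that provides `jnp.linalg.vector_norm`):

```python
import jax
import jax.numpy as jnp

norm = lambda x: jnp.linalg.vector_norm(x)
print(jax.grad(norm)(jnp.zeros(3)))  # [nan nan nan]: d|x|/dx = x/|x| is 0/0

def safe_norm(x):
    # same trick as _make_env_mat: nudge exact zeros before taking the norm
    x_ = jnp.where(jnp.abs(x) < 1e-30, jnp.full_like(x, 1e-30), x)
    return jnp.linalg.vector_norm(x_)

print(jax.grad(safe_norm)(jnp.zeros(3)))  # finite (zeros)
```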
deepmd/jax/common.py

Lines changed: 10 additions & 0 deletions

@@ -95,3 +95,13 @@ def __dlpack__(self, *args, **kwargs):
 
     def __dlpack_device__(self, *args, **kwargs):
         return self.value.__dlpack_device__(*args, **kwargs)
+
+
+def scatter_sum(input, dim, index: jnp.ndarray, src: jnp.ndarray) -> jnp.ndarray:
+    """Reduces all values from the src tensor to the indices specified in the index tensor."""
+    idx = jnp.arange(input.size, dtype=jnp.int64).reshape(input.shape)
+    new_idx = jnp.take_along_axis(idx, index, axis=dim).ravel()
+    shape = input.shape
+    input = input.ravel()
+    input = input.at[new_idx].add(src.ravel())
+    return input.reshape(shape)

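`scatter_sum` mirrors the semantics of PyTorch's `Tensor.scatter_add_`: `index` has the same shape as `src`, and each `src` element is accumulated into `input` at the position whose `dim`-th coordinate is replaced by the corresponding `index` value (duplicate indices add up via `.at[...].add`). A small sanity check; note the `int64` index dtype relies on `jax_enable_x64`, which `deepmd.jax.env` switches on:

```python
import jax.numpy as jnp
from deepmd.jax.common import scatter_sum

base = jnp.zeros(4)
index = jnp.array([0, 1, 1, 3])
src = jnp.array([1.0, 2.0, 3.0, 4.0])
print(scatter_sum(base, 0, index, src))  # [1. 5. 0. 4.]
```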
deepmd/jax/env.py

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@
 )
 
 jax.config.update("jax_enable_x64", True)
+# jax.config.update("jax_debug_nans", True)
 
 __all__ = [
     "jax",

deepmd/jax/model/base_model.py

Lines changed: 101 additions & 0 deletions

@@ -1,6 +1,107 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Optional,
+)
+
 from deepmd.dpmodel.model.base_model import (
     make_base_model,
 )
+from deepmd.dpmodel.output_def import (
+    get_deriv_name,
+    get_reduce_name,
+)
+from deepmd.jax.env import (
+    jax,
+    jnp,
+)
 
 BaseModel = make_base_model()
+
+
+def forward_common_atomic(
+    self,
+    extended_coord: jnp.ndarray,
+    extended_atype: jnp.ndarray,
+    nlist: jnp.ndarray,
+    mapping: Optional[jnp.ndarray] = None,
+    fparam: Optional[jnp.ndarray] = None,
+    aparam: Optional[jnp.ndarray] = None,
+    do_atomic_virial: bool = False,
+):
+    atomic_ret = self.atomic_model.forward_common_atomic(
+        extended_coord,
+        extended_atype,
+        nlist,
+        mapping=mapping,
+        fparam=fparam,
+        aparam=aparam,
+    )
+    atomic_output_def = self.atomic_output_def()
+    model_predict = {}
+    for kk, vv in atomic_ret.items():
+        model_predict[kk] = vv
+        vdef = atomic_output_def[kk]
+        shap = vdef.shape
+        atom_axis = -(len(shap) + 1)
+        if vdef.reducible:
+            kk_redu = get_reduce_name(kk)
+            model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis)
+            kk_derv_r, kk_derv_c = get_deriv_name(kk)
+            if vdef.r_differentiable:
+
+                def eval_output(
+                    cc_ext,
+                    extended_atype,
+                    nlist,
+                    mapping,
+                    fparam,
+                    aparam,
+                    *,
+                    _kk=kk,
+                    _atom_axis=atom_axis,
+                ):
+                    atomic_ret = self.atomic_model.forward_common_atomic(
+                        cc_ext[None, ...],
+                        extended_atype[None, ...],
+                        nlist[None, ...],
+                        mapping=mapping[None, ...] if mapping is not None else None,
+                        fparam=fparam[None, ...] if fparam is not None else None,
+                        aparam=aparam[None, ...] if aparam is not None else None,
+                    )
+                    return jnp.sum(atomic_ret[_kk][0], axis=_atom_axis)
+
+                # extended_coord: [nf, nall, 3]
+                # ff: [nf, *def, nall, 3]
+                ff = -jax.vmap(jax.jacrev(eval_output, argnums=0))(
+                    extended_coord,
+                    extended_atype,
+                    nlist,
+                    mapping,
+                    fparam,
+                    aparam,
+                )
+                # extended_force: [nf, nall, *def, 3]
+                def_ndim = len(vdef.shape)
+                extended_force = jnp.transpose(
+                    ff, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2]
+                )
+
+                model_predict[kk_derv_r] = extended_force
+            if vdef.c_differentiable:
+                assert vdef.r_differentiable
+                # avr: [nf, *def, nall, 3, 3]
+                avr = jnp.einsum("f...ai,faj->f...aij", ff, extended_coord)
+                # avr: [nf, *def, nall, 9]
+                avr = jnp.reshape(avr, [*ff.shape[:-1], 9])
+                # extended_virial: [nf, nall, *def, 9]
+                extended_virial = jnp.transpose(
+                    avr, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2]
+                )
+
+                # the correction sums to zero, which does not contribute to global virial
+                # cannot jit
+                # if do_atomic_virial:
+                #     raise NotImplementedError("Atomic virial is not implemented yet.")
+                # to [...,3,3] -> [...,9]
+                model_predict[kk_derv_c] = extended_virial
+    return model_predict

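The pattern above is plain reverse-mode autodiff: per-frame forces are minus the Jacobian of the reduced output with respect to the extended coordinates, and the atomic virial is the outer product of force and coordinate. A toy version with a quadratic stand-in energy (not the deepmd atomic model; the `*def` dimensions are dropped, so the einsum loses its `...`):

```python
import jax
import jax.numpy as jnp

def energy(coord):  # [nall, 3] -> scalar; one frame's reduced output
    return jnp.sum(coord**2)

coord = jnp.ones((2, 5, 3))                     # [nf, nall, 3]
ff = -jax.vmap(jax.jacrev(energy))(coord)       # force: [nf, nall, 3]
avr = jnp.einsum("fai,faj->faij", ff, coord)    # [nf, nall, 3, 3]
virial = jnp.reshape(avr, (*ff.shape[:-1], 9))  # [nf, nall, 9]
print(ff[0, 0], virial.shape)                   # [-2. -2. -2.] (2, 5, 9)
```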
deepmd/jax/model/ener_model.py

Lines changed: 26 additions & 0 deletions

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from typing import (
     Any,
+    Optional,
 )
 
 from deepmd.dpmodel.model import EnergyModel as EnergyModelDP
@@ -10,8 +11,12 @@
 from deepmd.jax.common import (
     flax_module,
 )
+from deepmd.jax.env import (
+    jnp,
+)
 from deepmd.jax.model.base_model import (
     BaseModel,
+    forward_common_atomic,
 )
 
 
@@ -22,3 +27,24 @@ def __setattr__(self, name: str, value: Any) -> None:
         if name == "atomic_model":
             value = DPAtomicModel.deserialize(value.serialize())
         return super().__setattr__(name, value)
+
+    def forward_common_atomic(
+        self,
+        extended_coord: jnp.ndarray,
+        extended_atype: jnp.ndarray,
+        nlist: jnp.ndarray,
+        mapping: Optional[jnp.ndarray] = None,
+        fparam: Optional[jnp.ndarray] = None,
+        aparam: Optional[jnp.ndarray] = None,
+        do_atomic_virial: bool = False,
+    ):
+        return forward_common_atomic(
+            self,
+            extended_coord,
+            extended_atype,
+            nlist,
+            mapping=mapping,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )

source/tests/consistent/common.py

Lines changed: 4 additions & 0 deletions

@@ -69,6 +69,8 @@
     "INSTALLED_ARRAY_API_STRICT",
 ]
 
+SKIP_FLAG = object()
+
 
 class CommonTest(ABC):
     data: ClassVar[dict]
@@ -362,6 +364,8 @@ def test_dp_consistent_with_ref(self):
         data2 = dp_obj.serialize()
         np.testing.assert_equal(data1, data2)
         for rr1, rr2 in zip(ret1, ret2):
+            if rr1 is SKIP_FLAG or rr2 is SKIP_FLAG:
+                continue
             np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
             assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"

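A sketch of how a backend's `extract_ret` might use the sentinel (the slot layout follows the TF change below, but this exact hook-up is illustrative, not taken from this commit): backends that cannot yet produce forces or virials return `SKIP_FLAG` in those positions, and `test_dp_consistent_with_ref` now skips the comparison instead of failing.

```python
def extract_ret(self, ret, backend):
    # TF now returns [energy, atom_ener, force, virial]; a backend without
    # derivative outputs pads those slots with SKIP_FLAG.
    if backend is self.RefBackend.TF:
        return ret[0], ret[1], ret[2], ret[3]
    return ret["energy_redu"], ret["energy"], SKIP_FLAG, SKIP_FLAG
```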
source/tests/consistent/model/common.py

Lines changed: 1 addition & 1 deletion

@@ -51,7 +51,7 @@ def build_tf_model(self, obj, natoms, coords, atype, box, suffix):
             {},
             suffix=suffix,
         )
-        return [ret["energy"], ret["atom_ener"]], {
+        return [ret["energy"], ret["atom_ener"], ret["force"], ret["virial"]], {
             t_coord: coords,
             t_type: atype,
             t_natoms: natoms,
