Skip to content

Commit afd4746

Browse files
Merge branch 'devel' into add_paddle_backend
2 parents cbc9c65 + d165fee commit afd4746

38 files changed

+1500
-247
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ For more information, check the [documentation](https://deepmd.readthedocs.io/).
1919

2020
### Highlighted features
2121

22-
- **interfaced with multiple backends**, including TensorFlow, PyTorch and Paddle, the most popular deep learning frameworks, making the training process highly automatic and efficient.
22+
- **interfaced with multiple backends**, including TensorFlow, PyTorch, JAX and Paddle, the most popular deep learning frameworks, making the training process highly automatic and efficient.
2323
- **interfaced with high-performance classical MD and quantum (path-integral) MD packages**, including LAMMPS, i-PI, AMBER, CP2K, GROMACS, OpenMM, and ABACUS.
2424
- **implements the Deep Potential series models**, which have been successfully applied to finite and extended systems, including organic molecules, metals, semiconductors, insulators, etc.
2525
- **implements MPI and GPU supports**, making it highly efficient for high-performance parallel and distributed computing.
@@ -72,7 +72,7 @@ See [our latest paper](https://doi.org/10.1063/5.0155600) for details of all fea
7272

7373
#### v3
7474

75-
- Multiple backends supported. Add PyTorch and Paddle backend.
75+
- Multiple backends supported. Add PyTorch, JAX and Paddle backends.
7676
- The DPA-2 model.
7777

7878
## Install and use DeePMD-kit

deepmd/backend/jax.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ class JAXBackend(Backend):
3333
"""The formal name of the backend."""
3434
features: ClassVar[Backend.Feature] = (
3535
Backend.Feature.IO
36-
# Backend.Feature.ENTRY_POINT
37-
# | Backend.Feature.DEEP_EVAL
38-
# | Backend.Feature.NEIGHBOR_STAT
36+
| Backend.Feature.ENTRY_POINT
37+
| Backend.Feature.DEEP_EVAL
38+
| Backend.Feature.NEIGHBOR_STAT
3939
)
4040
"""The features of the backend."""
41-
suffixes: ClassVar[list[str]] = [".jax"]
41+
suffixes: ClassVar[list[str]] = [".hlo", ".jax"]
4242
"""The suffixes of the backend."""
4343

4444
def is_available(self) -> bool:
@@ -71,7 +71,11 @@ def deep_eval(self) -> type["DeepEvalBackend"]:
7171
type[DeepEvalBackend]
7272
The Deep Eval backend of the backend.
7373
"""
74-
raise NotImplementedError
74+
from deepmd.jax.infer.deep_eval import (
75+
DeepEval,
76+
)
77+
78+
return DeepEval
7579

7680
@property
7781
def neighbor_stat(self) -> type["NeighborStat"]:
@@ -82,7 +86,11 @@ def neighbor_stat(self) -> type["NeighborStat"]:
8286
type[NeighborStat]
8387
The neighbor statistics of the backend.
8488
"""
85-
raise NotImplementedError
89+
from deepmd.jax.utils.neighbor_stat import (
90+
NeighborStat,
91+
)
92+
93+
return NeighborStat
8694

8795
@property
8896
def serialize_hook(self) -> Callable[[str], dict]:

deepmd/dpmodel/descriptor/se_e2_a.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def call(
555555
coord_ext, atype_ext, nlist, self.davg, self.dstd
556556
)
557557
nf, nloc, nnei, _ = rr.shape
558-
sec = xp.asarray(self.sel_cumsum)
558+
sec = self.sel_cumsum
559559

560560
ng = self.neuron[-1]
561561
gr = xp.zeros([nf * nloc, ng, 4], dtype=self.dstd.dtype)

deepmd/dpmodel/model/make_model.py

Lines changed: 124 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
from typing import (
3+
Callable,
34
Optional,
45
)
56

@@ -39,6 +40,95 @@
3940
)
4041

4142

43+
def model_call_from_call_lower(
    *,  # enforce keyword-only arguments
    call_lower: Callable[
        [
            np.ndarray,
            np.ndarray,
            np.ndarray,
            Optional[np.ndarray],
            Optional[np.ndarray],
            bool,
        ],
        dict[str, np.ndarray],
    ],
    rcut: float,
    sel: list[int],
    mixed_types: bool,
    model_output_def: ModelOutputDef,
    coord: np.ndarray,
    atype: np.ndarray,
    box: Optional[np.ndarray] = None,
    fparam: Optional[np.ndarray] = None,
    aparam: Optional[np.ndarray] = None,
    do_atomic_virial: bool = False,
):
    """Return model prediction from lower interface.

    The local system is extended with ghost atoms, a neighbor list is built
    on the extended system, ``call_lower`` is evaluated on it, and the
    extended result is communicated back to the local atoms.

    Parameters
    ----------
    call_lower
        The lower-interface callable. It is invoked with the extended
        coordinates, the extended atom types, the neighbor list and the
        mapping as positional arguments, plus the keywords ``fparam``,
        ``aparam`` and ``do_atomic_virial``.
    rcut
        The cut-off passed to the ghost extension and to the neighbor
        list construction.
    sel
        The selection passed to the neighbor list construction.
    mixed_types
        The neighbor list is built with ``distinguish_types=not mixed_types``.
    model_output_def
        The output definition used when communicating the extended output.
    coord
        The coordinates of the atoms.
        shape: nf x (nloc x 3)
    atype
        The type of atoms. shape: nf x nloc
    box
        The simulation box. shape: nf x 9
        When it is None, the coordinates are copied without normalization.
    fparam
        frame parameter. nf x ndf
    aparam
        atomic parameter. nf x nloc x nda
    do_atomic_virial
        Whether to calculate the atomic virial.

    Returns
    -------
    ret_dict
        The result dict of type dict[str,np.ndarray].
        The keys are defined by the `ModelOutputDef`.

    """
    nframes, nloc = atype.shape[:2]

    # Without a box there is nothing to normalize against; work on a copy.
    if box is None:
        wrapped_coord = coord.copy()
    else:
        wrapped_coord = normalize_coord(
            coord.reshape(nframes, nloc, 3),
            box.reshape(nframes, 3, 3),
        )

    ext_coord, ext_atype, ext_mapping = extend_coord_with_ghosts(
        wrapped_coord, atype, box, rcut
    )
    neighbor_list = build_neighbor_list(
        ext_coord,
        ext_atype,
        nloc,
        rcut,
        sel,
        distinguish_types=not mixed_types,
    )

    # The lower interface takes the extended coordinates as nf x nall x 3.
    lower_ret = call_lower(
        ext_coord.reshape(nframes, -1, 3),
        ext_atype,
        neighbor_list,
        ext_mapping,
        fparam=fparam,
        aparam=aparam,
        do_atomic_virial=do_atomic_virial,
    )
    return communicate_extended_output(
        lower_ret,
        model_output_def,
        ext_mapping,
        do_atomic_virial=do_atomic_virial,
    )
130+
131+
42132
def make_model(T_AtomicModel: type[BaseAtomicModel]):
43133
"""Make a model as a derived class of an atomic model.
44134
@@ -130,45 +220,23 @@ def call(
130220
The keys are defined by the `ModelOutputDef`.
131221
132222
"""
133-
nframes, nloc = atype.shape[:2]
134223
cc, bb, fp, ap, input_prec = self.input_type_cast(
135224
coord, box=box, fparam=fparam, aparam=aparam
136225
)
137226
del coord, box, fparam, aparam
138-
if bb is not None:
139-
coord_normalized = normalize_coord(
140-
cc.reshape(nframes, nloc, 3),
141-
bb.reshape(nframes, 3, 3),
142-
)
143-
else:
144-
coord_normalized = cc.copy()
145-
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
146-
coord_normalized, atype, bb, self.get_rcut()
147-
)
148-
nlist = build_neighbor_list(
149-
extended_coord,
150-
extended_atype,
151-
nloc,
152-
self.get_rcut(),
153-
self.get_sel(),
154-
distinguish_types=not self.mixed_types(),
155-
)
156-
extended_coord = extended_coord.reshape(nframes, -1, 3)
157-
model_predict_lower = self.call_lower(
158-
extended_coord,
159-
extended_atype,
160-
nlist,
161-
mapping,
227+
model_predict = model_call_from_call_lower(
228+
call_lower=self.call_lower,
229+
rcut=self.get_rcut(),
230+
sel=self.get_sel(),
231+
mixed_types=self.mixed_types(),
232+
model_output_def=self.model_output_def(),
233+
coord=cc,
234+
atype=atype,
235+
box=bb,
162236
fparam=fp,
163237
aparam=ap,
164238
do_atomic_virial=do_atomic_virial,
165239
)
166-
model_predict = communicate_extended_output(
167-
model_predict_lower,
168-
self.model_output_def(),
169-
mapping,
170-
do_atomic_virial=do_atomic_virial,
171-
)
172240
model_predict = self.output_type_cast(model_predict, input_prec)
173241
return model_predict
174242

@@ -222,22 +290,42 @@ def call_lower(
222290
extended_coord, fparam=fparam, aparam=aparam
223291
)
224292
del extended_coord, fparam, aparam
225-
atomic_ret = self.atomic_model.forward_common_atomic(
293+
model_predict = self.forward_common_atomic(
226294
cc_ext,
227295
extended_atype,
228296
nlist,
229297
mapping=mapping,
230298
fparam=fp,
231299
aparam=ap,
300+
do_atomic_virial=do_atomic_virial,
232301
)
233-
model_predict = fit_output_to_model_output(
302+
model_predict = self.output_type_cast(model_predict, input_prec)
303+
return model_predict
304+
305+
def forward_common_atomic(
306+
self,
307+
extended_coord: np.ndarray,
308+
extended_atype: np.ndarray,
309+
nlist: np.ndarray,
310+
mapping: Optional[np.ndarray] = None,
311+
fparam: Optional[np.ndarray] = None,
312+
aparam: Optional[np.ndarray] = None,
313+
do_atomic_virial: bool = False,
314+
):
315+
atomic_ret = self.atomic_model.forward_common_atomic(
316+
extended_coord,
317+
extended_atype,
318+
nlist,
319+
mapping=mapping,
320+
fparam=fparam,
321+
aparam=aparam,
322+
)
323+
return fit_output_to_model_output(
234324
atomic_ret,
235325
self.atomic_output_def(),
236-
cc_ext,
326+
extended_coord,
237327
do_atomic_virial=do_atomic_virial,
238328
)
239-
model_predict = self.output_type_cast(model_predict, input_prec)
240-
return model_predict
241329

242330
forward_lower = call_lower
243331

deepmd/dpmodel/model/transform_output.py

Lines changed: 78 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from deepmd.dpmodel.output_def import (
1010
FittingOutputDef,
1111
ModelOutputDef,
12+
OutputVariableDef,
1213
get_deriv_name,
1314
get_reduce_name,
1415
)
@@ -47,6 +48,28 @@ def fit_output_to_model_output(
4748
return model_ret
4849

4950

51+
def get_leading_dims(
    vv: np.ndarray,
    vdef: "OutputVariableDef",
) -> list[int]:
    """Get the leading dimensions of `vv` that precede the variable's own shape.

    Parameters
    ----------
    vv : np.ndarray
        The input array from which to compute the leading dimensions.
    vdef : OutputVariableDef
        The output variable definition containing the shape to exclude from `vv`.

    Returns
    -------
    list
        A list of leading dimensions of `vv`, excluding the last `len(vdef.shape)` dimensions.

    Raises
    ------
    ValueError
        If `vv` has fewer dimensions than `vdef.shape`.
    """
    vshape = vv.shape
    n_leading = len(vshape) - len(vdef.shape)
    # A negative stop would make the slice count from the end and silently
    # return a wrong prefix, so reject rank-deficient input explicitly.
    if n_leading < 0:
        raise ValueError(
            f"array of rank {len(vshape)} cannot end with a variable shape "
            f"of rank {len(vdef.shape)}"
        )
    return list(vshape[:n_leading])
71+
72+
5073
def communicate_extended_output(
5174
model_ret: dict[str, np.ndarray],
5275
model_output_def: ModelOutputDef,
@@ -57,6 +80,7 @@ def communicate_extended_output(
5780
local and ghost (extended) atoms to local atoms.
5881
5982
"""
83+
xp = array_api_compat.get_namespace(mapping)
6084
new_ret = {}
6185
for kk in model_output_def.keys_outp():
6286
vv = model_ret[kk]
@@ -65,15 +89,63 @@ def communicate_extended_output(
6589
if vdef.reducible:
6690
kk_redu = get_reduce_name(kk)
6791
new_ret[kk_redu] = model_ret[kk_redu]
92+
kk_derv_r, kk_derv_c = get_deriv_name(kk)
93+
mldims = list(mapping.shape)
94+
vldims = get_leading_dims(vv, vdef)
6895
if vdef.r_differentiable:
69-
kk_derv_r, kk_derv_c = get_deriv_name(kk)
70-
# name holders
71-
new_ret[kk_derv_r] = None
96+
if model_ret[kk_derv_r] is not None:
97+
derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005
98+
mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
99+
mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
100+
force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype)
101+
# jax only
102+
if array_api_compat.is_jax_array(force):
103+
from deepmd.jax.common import (
104+
scatter_sum,
105+
)
106+
107+
force = scatter_sum(
108+
force,
109+
1,
110+
mapping,
111+
model_ret[kk_derv_r],
112+
)
113+
else:
114+
raise NotImplementedError("Only JAX arrays are supported.")
115+
new_ret[kk_derv_r] = force
116+
else:
117+
# name holders
118+
new_ret[kk_derv_r] = None
72119
if vdef.c_differentiable:
73120
assert vdef.r_differentiable
74-
kk_derv_r, kk_derv_c = get_deriv_name(kk)
75-
new_ret[kk_derv_c] = None
76-
new_ret[kk_derv_c + "_redu"] = None
121+
if model_ret[kk_derv_c] is not None:
122+
derv_c_ext_dims = list(vdef.shape) + [9] # noqa:RUF005
123+
mapping = xp.tile(
124+
mapping, [1] * (len(mldims) + len(vdef.shape)) + [3]
125+
)
126+
virial = xp.zeros(
127+
vldims + derv_c_ext_dims,
128+
dtype=vv.dtype,
129+
)
130+
# jax only
131+
if array_api_compat.is_jax_array(virial):
132+
from deepmd.jax.common import (
133+
scatter_sum,
134+
)
135+
136+
virial = scatter_sum(
137+
virial,
138+
1,
139+
mapping,
140+
model_ret[kk_derv_c],
141+
)
142+
else:
143+
raise NotImplementedError("Only JAX arrays are supported.")
144+
new_ret[kk_derv_c] = virial
145+
new_ret[kk_derv_c + "_redu"] = xp.sum(new_ret[kk_derv_c], axis=1)
146+
else:
147+
new_ret[kk_derv_c] = None
148+
new_ret[kk_derv_c + "_redu"] = None
77149
if not do_atomic_virial:
78150
# pop atomic virial, because it is not correctly calculated.
79151
new_ret.pop(kk_derv_c)

0 commit comments

Comments
 (0)