Skip to content

Commit 226c8b9

Browse files
committed
Merge branch 'devel' into devel-use_aparam_as_mask
2 parents eee5531 + b8e57f2 commit 226c8b9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+399
-123
lines changed

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
import copy
2+
import math
33
from typing import (
44
Optional,
55
)
66

7+
import array_api_compat
78
import numpy as np
89

910
from deepmd.dpmodel.common import (
1011
NativeOP,
12+
to_numpy_array,
1113
)
1214
from deepmd.dpmodel.output_def import (
1315
FittingOutputDef,
@@ -172,17 +174,18 @@ def forward_common_atomic(
172174
ret_dict["mask"][ff,ii] == 0 indicating the ii-th atom of the ff-th frame is virtual.
173175
174176
"""
177+
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
175178
_, nloc, _ = nlist.shape
176179
atype = extended_atype[:, :nloc]
177180
if self.pair_excl is not None:
178181
pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype)
179182
# exclude neighbors in the nlist
180-
nlist = np.where(pair_mask == 1, nlist, -1)
183+
nlist = xp.where(pair_mask == 1, nlist, -1)
181184

182185
ext_atom_mask = self.make_atom_mask(extended_atype)
183186
ret_dict = self.forward_atomic(
184187
extended_coord,
185-
np.where(ext_atom_mask, extended_atype, 0),
188+
xp.where(ext_atom_mask, extended_atype, 0),
186189
nlist,
187190
mapping=mapping,
188191
fparam=fparam,
@@ -191,13 +194,13 @@ def forward_common_atomic(
191194
ret_dict = self.apply_out_stat(ret_dict, atype)
192195

193196
# nf x nloc
194-
atom_mask = ext_atom_mask[:, :nloc].astype(np.int32)
197+
atom_mask = ext_atom_mask[:, :nloc].astype(xp.int32)
195198
if self.atom_excl is not None:
196199
atom_mask *= self.atom_excl.build_type_exclude_mask(atype)
197200

198201
for kk in ret_dict.keys():
199202
out_shape = ret_dict[kk].shape
200-
out_shape2 = np.prod(out_shape[2:])
203+
out_shape2 = math.prod(out_shape[2:])
201204
ret_dict[kk] = (
202205
ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
203206
* atom_mask[:, :, None]
@@ -232,14 +235,15 @@ def serialize(self) -> dict:
232235
"rcond": self.rcond,
233236
"preset_out_bias": self.preset_out_bias,
234237
"@variables": {
235-
"out_bias": self.out_bias,
236-
"out_std": self.out_std,
238+
"out_bias": to_numpy_array(self.out_bias),
239+
"out_std": to_numpy_array(self.out_std),
237240
},
238241
}
239242

240243
@classmethod
241244
def deserialize(cls, data: dict) -> "BaseAtomicModel":
242-
data = copy.deepcopy(data)
245+
# do not deep copy Descriptor and Fitting class
246+
data = data.copy()
243247
variables = data.pop("@variables")
244248
obj = cls(**data)
245249
for kk in variables.keys():

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,20 @@ def serialize(self) -> dict:
169169
)
170170
return dd
171171

172+
# for subclass overriden
173+
base_descriptor_cls = BaseDescriptor
174+
"""The base descriptor class."""
175+
base_fitting_cls = BaseFitting
176+
"""The base fitting class."""
177+
172178
@classmethod
173179
def deserialize(cls, data) -> "DPAtomicModel":
174180
data = copy.deepcopy(data)
175181
check_version_compatibility(data.pop("@version", 1), 2, 2)
176182
data.pop("@class")
177183
data.pop("type")
178-
descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor"))
179-
fitting_obj = BaseFitting.deserialize(data.pop("fitting"))
184+
descriptor_obj = cls.base_descriptor_cls.deserialize(data.pop("descriptor"))
185+
fitting_obj = cls.base_fitting_cls.deserialize(data.pop("fitting"))
180186
data["descriptor"] = descriptor_obj
181187
data["fitting"] = fitting_obj
182188
obj = super().deserialize(data)

deepmd/dpmodel/model/make_model.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
Optional,
44
)
55

6+
import array_api_compat
67
import numpy as np
78

89
from deepmd.dpmodel.atomic_model.base_atomic_model import (
@@ -75,7 +76,8 @@ def __init__(
7576
else:
7677
self.atomic_model: T_AtomicModel = T_AtomicModel(*args, **kwargs)
7778
self.precision_dict = PRECISION_DICT
78-
self.reverse_precision_dict = RESERVED_PRECISON_DICT
79+
# not supported by flax
80+
# self.reverse_precision_dict = RESERVED_PRECISON_DICT
7981
self.global_np_float_precision = GLOBAL_NP_FLOAT_PRECISION
8082
self.global_ener_float_precision = GLOBAL_ENER_FLOAT_PRECISION
8183

@@ -253,9 +255,7 @@ def input_type_cast(
253255
str,
254256
]:
255257
"""Cast the input data to global float type."""
256-
input_prec = self.reverse_precision_dict[
257-
self.precision_dict[coord.dtype.name]
258-
]
258+
input_prec = RESERVED_PRECISON_DICT[self.precision_dict[coord.dtype.name]]
259259
###
260260
### type checking would not pass jit, convert to coord prec anyway
261261
###
@@ -264,10 +264,7 @@ def input_type_cast(
264264
for vv in [box, fparam, aparam]
265265
]
266266
box, fparam, aparam = _lst
267-
if (
268-
input_prec
269-
== self.reverse_precision_dict[self.global_np_float_precision]
270-
):
267+
if input_prec == RESERVED_PRECISON_DICT[self.global_np_float_precision]:
271268
return coord, box, fparam, aparam, input_prec
272269
else:
273270
pp = self.global_np_float_precision
@@ -286,8 +283,7 @@ def output_type_cast(
286283
) -> dict[str, np.ndarray]:
287284
"""Convert the model output to the input prec."""
288285
do_cast = (
289-
input_prec
290-
!= self.reverse_precision_dict[self.global_np_float_precision]
286+
input_prec != RESERVED_PRECISON_DICT[self.global_np_float_precision]
291287
)
292288
pp = self.precision_dict[input_prec]
293289
odef = self.model_output_def()
@@ -366,17 +362,18 @@ def _format_nlist(
366362
nnei: int,
367363
extra_nlist_sort: bool = False,
368364
):
365+
xp = array_api_compat.array_namespace(extended_coord, nlist)
369366
n_nf, n_nloc, n_nnei = nlist.shape
370367
extended_coord = extended_coord.reshape([n_nf, -1, 3])
371368
nall = extended_coord.shape[1]
372369
rcut = self.get_rcut()
373370

374371
if n_nnei < nnei:
375372
# make a copy before revise
376-
ret = np.concatenate(
373+
ret = xp.concat(
377374
[
378375
nlist,
379-
-1 * np.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype),
376+
-1 * xp.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype),
380377
],
381378
axis=-1,
382379
)
@@ -385,16 +382,16 @@ def _format_nlist(
385382
n_nf, n_nloc, n_nnei = nlist.shape
386383
# make a copy before revise
387384
m_real_nei = nlist >= 0
388-
ret = np.where(m_real_nei, nlist, 0)
385+
ret = xp.where(m_real_nei, nlist, 0)
389386
coord0 = extended_coord[:, :n_nloc, :]
390387
index = ret.reshape(n_nf, n_nloc * n_nnei, 1).repeat(3, axis=2)
391-
coord1 = np.take_along_axis(extended_coord, index, axis=1)
388+
coord1 = xp.take_along_axis(extended_coord, index, axis=1)
392389
coord1 = coord1.reshape(n_nf, n_nloc, n_nnei, 3)
393-
rr = np.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1)
394-
rr = np.where(m_real_nei, rr, float("inf"))
395-
rr, ret_mapping = np.sort(rr, axis=-1), np.argsort(rr, axis=-1)
396-
ret = np.take_along_axis(ret, ret_mapping, axis=2)
397-
ret = np.where(rr > rcut, -1, ret)
390+
rr = xp.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1)
391+
rr = xp.where(m_real_nei, rr, float("inf"))
392+
rr, ret_mapping = xp.sort(rr, axis=-1), xp.argsort(rr, axis=-1)
393+
ret = xp.take_along_axis(ret, ret_mapping, axis=2)
394+
ret = xp.where(rr > rcut, -1, ret)
398395
ret = ret[..., :nnei]
399396
# not extra_nlist_sort and n_nnei <= nnei:
400397
elif n_nnei == nnei:

deepmd/dpmodel/model/transform_output.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22

3+
import array_api_compat
34
import numpy as np
45

56
from deepmd.dpmodel.common import (
@@ -23,6 +24,7 @@ def fit_output_to_model_output(
2324
the model output.
2425
2526
"""
27+
xp = array_api_compat.get_namespace(coord_ext)
2628
model_ret = dict(fit_ret.items())
2729
for kk, vv in fit_ret.items():
2830
vdef = fit_output_def[kk]
@@ -31,7 +33,7 @@ def fit_output_to_model_output(
3133
if vdef.reducible:
3234
kk_redu = get_reduce_name(kk)
3335
# cast to energy prec brefore reduction
34-
model_ret[kk_redu] = np.sum(
36+
model_ret[kk_redu] = xp.sum(
3537
vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
3638
)
3739
if vdef.r_differentiable:
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.jax.common import (
3+
to_jax_array,
4+
)
5+
from deepmd.jax.utils.exclude_mask import (
6+
AtomExcludeMask,
7+
PairExcludeMask,
8+
)
9+
10+
11+
def base_atomic_model_set_attr(name, value):
12+
if name in {"out_bias", "out_std"}:
13+
value = to_jax_array(value)
14+
elif name == "pair_excl" and value is not None:
15+
value = PairExcludeMask(value.ntypes, value.exclude_types)
16+
elif name == "atom_excl" and value is not None:
17+
value = AtomExcludeMask(value.ntypes, value.exclude_types)
18+
return value
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Any,
4+
)
5+
6+
from deepmd.dpmodel.atomic_model.dp_atomic_model import DPAtomicModel as DPAtomicModelDP
7+
from deepmd.jax.atomic_model.base_atomic_model import (
8+
base_atomic_model_set_attr,
9+
)
10+
from deepmd.jax.common import (
11+
flax_module,
12+
)
13+
from deepmd.jax.descriptor.base_descriptor import (
14+
BaseDescriptor,
15+
)
16+
from deepmd.jax.fitting.base_fitting import (
17+
BaseFitting,
18+
)
19+
20+
21+
@flax_module
22+
class DPAtomicModel(DPAtomicModelDP):
23+
base_descriptor_cls = BaseDescriptor
24+
"""The base descriptor class."""
25+
base_fitting_cls = BaseFitting
26+
"""The base fitting class."""
27+
28+
def __setattr__(self, name: str, value: Any) -> None:
29+
value = base_atomic_model_set_attr(name, value)
30+
return super().__setattr__(name, value)

deepmd/jax/descriptor/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,12 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.jax.descriptor.dpa1 import (
3+
DescrptDPA1,
4+
)
5+
from deepmd.jax.descriptor.se_e2_a import (
6+
DescrptSeA,
7+
)
8+
9+
__all__ = [
10+
"DescrptSeA",
11+
"DescrptDPA1",
12+
]
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.dpmodel.descriptor.make_base_descriptor import (
3+
make_base_descriptor,
4+
)
5+
from deepmd.jax.env import (
6+
jnp,
7+
)
8+
9+
BaseDescriptor = make_base_descriptor(jnp.ndarray)

deepmd/jax/descriptor/dpa1.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
flax_module,
1717
to_jax_array,
1818
)
19+
from deepmd.jax.descriptor.base_descriptor import (
20+
BaseDescriptor,
21+
)
1922
from deepmd.jax.utils.exclude_mask import (
2023
PairExcludeMask,
2124
)
@@ -76,6 +79,8 @@ def __setattr__(self, name: str, value: Any) -> None:
7679
return super().__setattr__(name, value)
7780

7881

82+
@BaseDescriptor.register("dpa1")
83+
@BaseDescriptor.register("se_atten")
7984
@flax_module
8085
class DescrptDPA1(DescrptDPA1DP):
8186
def __setattr__(self, name: str, value: Any) -> None:

0 commit comments

Comments
 (0)