Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ jobs:
env:
NUM_WORKERS: 0
DP_TEST_TF2_ONLY: 1
DP_DTYPE_PROMOTION_STRICT: 1
if: matrix.group == 1
- run: mv .test_durations .test_durations_${{ matrix.group }}
- name: Upload partial durations
Expand Down
15 changes: 8 additions & 7 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,18 +201,19 @@ def forward_common_atomic(
ret_dict = self.apply_out_stat(ret_dict, atype)

# nf x nloc
atom_mask = ext_atom_mask[:, :nloc].astype(xp.int32)
atom_mask = ext_atom_mask[:, :nloc]
if self.atom_excl is not None:
atom_mask *= self.atom_excl.build_type_exclude_mask(atype)
atom_mask = xp.logical_and(
atom_mask, self.atom_excl.build_type_exclude_mask(atype)
)

for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
out_shape2 = math.prod(out_shape[2:])
ret_dict[kk] = (
ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
* atom_mask[:, :, None]
).reshape(out_shape)
ret_dict["mask"] = atom_mask
tmp_arr = ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
tmp_arr = xp.where(atom_mask[:, :, None], tmp_arr, xp.zeros_like(tmp_arr))
ret_dict[kk] = xp.reshape(tmp_arr, out_shape)
ret_dict["mask"] = xp.astype(atom_mask, xp.int32)

return ret_dict

Expand Down
104 changes: 104 additions & 0 deletions deepmd/dpmodel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
ABC,
abstractmethod,
)
from functools import (
wraps,
)
from typing import (
Any,
Callable,
Optional,
overload,
)

import array_api_compat
Expand Down Expand Up @@ -116,6 +121,105 @@
return np.from_dlpack(x)


def cast_precision(func: Callable[..., Any]) -> Callable[..., Any]:
"""A decorator that casts and casts back the input
and output tensor of a method.
The decorator should be used on an instance method.
The decorator will do the following thing:
(1) It casts input arrays from the global precision
to precision defined by property `precision`.
(2) It casts output arrays from `precision` to
the global precision.
(3) It checks inputs and outputs and only casts when
input or output is an array and its dtype matches
the global precision and `precision`, respectively.
If it does not match (e.g. it is an integer), the decorator
will do nothing on it.
The decorator supports the array API.
Returns
-------
Callable
a decorator that casts and casts back the input and
output array of a method
Examples
--------
>>> class A:
... def __init__(self):
... self.precision = "float32"
...
... @cast_precision
... def f(x: Array, y: Array) -> Array:
... return x**2 + y
"""

@wraps(func)
def wrapper(self, *args, **kwargs):
# only convert tensors
returned_tensor = func(
self,
*[safe_cast_array(vv, "global", self.precision) for vv in args],
**{
kk: safe_cast_array(vv, "global", self.precision)
for kk, vv in kwargs.items()
},
)
if isinstance(returned_tensor, tuple):
return tuple(
safe_cast_array(vv, self.precision, "global") for vv in returned_tensor
)
elif isinstance(returned_tensor, dict):
return {
kk: safe_cast_array(vv, self.precision, "global")
for kk, vv in returned_tensor.items()
}
else:
return safe_cast_array(returned_tensor, self.precision, "global")

return wrapper


@overload
def safe_cast_array(
input: np.ndarray, from_precision: str, to_precision: str
) -> np.ndarray: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
@overload
def safe_cast_array(input: None, from_precision: str, to_precision: str) -> None: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
def safe_cast_array(
input: Optional[np.ndarray], from_precision: str, to_precision: str
) -> Optional[np.ndarray]:
"""Convert an array from a precision to another precision.
If input is not an array or without the specific precision, the method will not
cast it.
Array API is supported.
Parameters
----------
input : np.ndarray or None
Input array
from_precision : str
Array data type that is casted from
to_precision : str
Array data type that casts to
Returns
-------
np.ndarray or None
casted array
"""
if array_api_compat.is_array_api_obj(input):
xp = array_api_compat.array_namespace(input)
if input.dtype == get_xp_precision(xp, from_precision):
return xp.astype(input, get_xp_precision(xp, to_precision))
return input


__all__ = [
"GLOBAL_NP_FLOAT_PRECISION",
"GLOBAL_ENER_FLOAT_PRECISION",
Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
cast_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
Expand Down Expand Up @@ -329,6 +330,7 @@ def __init__(
self.tebd_dim = tebd_dim
self.concat_output_tebd = concat_output_tebd
self.trainable = trainable
self.precision = precision

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -448,6 +450,7 @@ def change_type_map(
obj["davg"] = obj["davg"][remap_index]
obj["dstd"] = obj["dstd"][remap_index]

@cast_precision
def call(
self,
coord_ext,
Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
cast_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
Expand Down Expand Up @@ -594,6 +595,7 @@ def init_subclass_params(sub_data, sub_class):
self.rcut = self.repinit.get_rcut()
self.ntypes = ntypes
self.sel = self.repinit.sel
self.precision = precision

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -757,6 +759,7 @@ def get_stat_mean_and_stddev(self) -> tuple[list[np.ndarray], list[np.ndarray]]:
stddev_list.append(self.repinit_three_body.stddev)
return mean_list, stddev_list

@cast_precision
def call(
self,
coord_ext: np.ndarray,
Expand Down
14 changes: 5 additions & 9 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
NativeOP,
)
from deepmd.dpmodel.common import (
cast_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
Expand All @@ -29,9 +30,6 @@
from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
Expand Down Expand Up @@ -340,6 +338,7 @@ def reinit_exclude(
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -415,9 +414,7 @@ def call(
# nf x nloc x ng x ng1
grrg = np.einsum("flid,fljd->flij", gr, gr1)
# nf x nloc x (ng x ng1)
grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron).astype(
GLOBAL_NP_FLOAT_PRECISION
)
grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron)
return grrg, gr[..., 1:], None, None, ww

def serialize(self) -> dict:
Expand Down Expand Up @@ -506,6 +503,7 @@ def update_sel(


class DescrptSeAArrayAPI(DescrptSeA):
@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -585,7 +583,5 @@ def call(
# grrg = xp.einsum("flid,fljd->flij", gr, gr1)
grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)
# nf x nloc x (ng x ng1)
grrg = xp.astype(
xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)), input_dtype
)
grrg = xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron))
return grrg, gr[..., 1:], None, None, ww
3 changes: 2 additions & 1 deletion deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
NativeOP,
)
from deepmd.dpmodel.common import (
cast_precision,
get_xp_precision,
to_numpy_array,
)
Expand Down Expand Up @@ -289,6 +290,7 @@ def cal_g(
gg = self.embeddings[(ll,)].call(ss)
return gg

@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -352,7 +354,6 @@ def call(
res_rescale = 1.0 / 5.0
res = xyz_scatter * res_rescale
res = xp.reshape(res, (nf, nloc, ng))
res = xp.astype(res, get_xp_precision(xp, "global"))
return res, None, None, None, ww

def serialize(self) -> dict:
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
NativeOP,
)
from deepmd.dpmodel.common import (
cast_precision,
get_xp_precision,
to_numpy_array,
)
Expand Down Expand Up @@ -264,6 +265,7 @@ def reinit_exclude(
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -317,7 +319,6 @@ def call(
# we don't require atype is the same in all frames
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
rr = xp.reshape(rr, (nf * nloc, nnei, 4))
rr = xp.astype(rr, get_xp_precision(xp, self.precision))

for embedding_idx in itertools.product(
range(self.ntypes), repeat=self.embeddings.ndim
Expand Down Expand Up @@ -349,7 +350,6 @@ def call(
result += res_ij
# nf x nloc x ng
result = xp.reshape(result, (nf, nloc, ng))
result = xp.astype(result, get_xp_precision(xp, "global"))
return result, None, None, None, ww

def serialize(self) -> dict:
Expand Down
5 changes: 3 additions & 2 deletions deepmd/dpmodel/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
get_xp_precision,
cast_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
Expand Down Expand Up @@ -168,6 +168,7 @@ def __init__(
self.tebd_dim = tebd_dim
self.concat_output_tebd = concat_output_tebd
self.trainable = trainable
self.precision = precision

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -287,6 +288,7 @@ def change_type_map(
obj["davg"] = obj["davg"][remap_index]
obj["dstd"] = obj["dstd"][remap_index]

@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -741,7 +743,6 @@ def call(
res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei))
# nf x nl x ng
result = xp.reshape(res_ij, (nf, nloc, self.filter_neuron[-1]))
result = xp.astype(result, get_xp_precision(xp, "global"))
return (
result,
None,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from deepmd.dpmodel import (
DEFAULT_PRECISION,
)
from deepmd.dpmodel.common import (
cast_precision,
)
from deepmd.dpmodel.fitting.base_fitting import (
BaseFitting,
)
Expand Down Expand Up @@ -174,6 +177,7 @@ def output_def(self):
]
)

@cast_precision
def call(
self,
descriptor: np.ndarray,
Expand Down
20 changes: 12 additions & 8 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,18 +439,22 @@ def _call_common(
):
assert xx_zeros is not None
atom_property -= self.nets[(type_i,)](xx_zeros)
atom_property = atom_property + self.bias_atom_e[type_i, ...]
atom_property = atom_property * xp.astype(mask, atom_property.dtype)
atom_property = xp.where(
mask, atom_property, xp.zeros_like(atom_property)
)
outs = outs + atom_property # Shape is [nframes, natoms[0], 1]
else:
outs = self.nets[()](xx) + xp.reshape(
xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0),
[nf, nloc, net_dim_out],
)
outs = self.nets[()](xx)
if xx_zeros is not None:
outs -= self.nets[()](xx_zeros)
outs += xp.reshape(
xp.take(
xp.astype(self.bias_atom_e, outs.dtype), xp.reshape(atype, [-1]), axis=0
),
[nf, nloc, net_dim_out],
)
# nf x nloc
exclude_mask = self.emask.build_type_exclude_mask(atype)
# nf x nloc x nod
outs = outs * xp.astype(exclude_mask[:, :, None], outs.dtype)
return {self.var_name: xp.astype(outs, get_xp_precision(xp, "global"))}
outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs))
return {self.var_name: outs}
Loading