Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,22 +222,42 @@ def call_lower(
extended_coord, fparam=fparam, aparam=aparam
)
del extended_coord, fparam, aparam
atomic_ret = self.atomic_model.forward_common_atomic(
model_predict = self.forward_common_atomic(
cc_ext,
extended_atype,
nlist,
mapping=mapping,
fparam=fp,
aparam=ap,
do_atomic_virial=do_atomic_virial,
)
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict

def forward_common_atomic(
    self,
    extended_coord: np.ndarray,
    extended_atype: np.ndarray,
    nlist: np.ndarray,
    mapping: Optional[np.ndarray] = None,
    fparam: Optional[np.ndarray] = None,
    aparam: Optional[np.ndarray] = None,
    do_atomic_virial: bool = False,
):
    """Run the atomic model and convert its fitting output to model output.

    Parameters
    ----------
    extended_coord
        Coordinates of extended (local + ghost) atoms.
        # assumes nf x nall x 3 layout — TODO confirm against call_lower
    extended_atype
        Atomic types of extended atoms.
    nlist
        Neighbor list.
    mapping
        Index mapping from extended to local atoms, may be None.
    fparam
        Frame parameters, may be None.
    aparam
        Atomic parameters, may be None.
    do_atomic_virial
        Whether to compute the per-atom virial contribution.

    Returns
    -------
    dict
        Model predictions produced by ``fit_output_to_model_output``.
    """
    atomic_ret = self.atomic_model.forward_common_atomic(
        extended_coord,
        extended_atype,
        nlist,
        mapping=mapping,
        fparam=fparam,
        aparam=aparam,
    )
    # Precision casting is the caller's job (see call_lower); here we only
    # translate fitting-network output into the model-level output layout.
    return fit_output_to_model_output(
        atomic_ret,
        self.atomic_output_def(),
        extended_coord,
        do_atomic_virial=do_atomic_virial,
    )

forward_lower = call_lower

Expand Down
89 changes: 83 additions & 6 deletions deepmd/dpmodel/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from deepmd.dpmodel.output_def import (
FittingOutputDef,
ModelOutputDef,
OutputVariableDef,
get_deriv_name,
get_reduce_name,
)
Expand Down Expand Up @@ -47,6 +48,28 @@ def fit_output_to_model_output(
return model_ret


def get_leading_dims(
    vv: np.ndarray,
    vdef: OutputVariableDef,
):
    """Get the dimensions of nf x nloc.

    Parameters
    ----------
    vv : np.ndarray
        The input array from which to compute the leading dimensions.
    vdef : OutputVariableDef
        The output variable definition containing the shape to exclude from `vv`.

    Returns
    -------
    list
        A list of leading dimensions of `vv`, excluding the last `len(vdef.shape)` dimensions.
    """
    # Everything in front of the per-variable trailing shape is a leading
    # (frame / atom) dimension.
    n_trailing = len(vdef.shape)
    full_shape = list(vv.shape)
    if n_trailing == 0:
        return full_shape
    return full_shape[:-n_trailing]


def communicate_extended_output(
model_ret: dict[str, np.ndarray],
model_output_def: ModelOutputDef,
Expand All @@ -57,6 +80,7 @@ def communicate_extended_output(
local and ghost (extended) atoms to local atoms.

"""
xp = array_api_compat.get_namespace(mapping)
new_ret = {}
for kk in model_output_def.keys_outp():
vv = model_ret[kk]
Expand All @@ -65,15 +89,68 @@ def communicate_extended_output(
if vdef.reducible:
kk_redu = get_reduce_name(kk)
new_ret[kk_redu] = model_ret[kk_redu]
kk_derv_r, kk_derv_c = get_deriv_name(kk)
mldims = list(mapping.shape)
vldims = get_leading_dims(vv, vdef)
if vdef.r_differentiable:
kk_derv_r, kk_derv_c = get_deriv_name(kk)
# name holders
new_ret[kk_derv_r] = None
if model_ret[kk_derv_r] is not None:
derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005
mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
force = xp.zeros(
vldims + derv_r_ext_dims, dtype=vv.dtype, device=vv.device
)
# jax only
if array_api_compat.is_jax_array(force):
from deepmd.jax.env import (
jnp,
)

f_idx = xp.arange(force.size, dtype=xp.int64).reshape(
force.shape
)
new_idx = jnp.take_along_axis(f_idx, mapping, axis=1).ravel()
f_shape = force.shape
force = force.ravel()
force = force.at[new_idx].add(model_ret[kk_derv_r].ravel())
force = force.reshape(f_shape)
else:
raise NotImplementedError("Only JAX arrays are supported.")
new_ret[kk_derv_r] = force
else:
# name holders
new_ret[kk_derv_r] = None
if vdef.c_differentiable:
assert vdef.r_differentiable
kk_derv_r, kk_derv_c = get_deriv_name(kk)
new_ret[kk_derv_c] = None
new_ret[kk_derv_c + "_redu"] = None
if model_ret[kk_derv_c] is not None:
derv_c_ext_dims = list(vdef.shape) + [9] # noqa:RUF005
mapping = xp.tile(
mapping, [1] * (len(mldims) + len(vdef.shape)) + [3]
)
virial = xp.zeros(
vldims + derv_c_ext_dims, dtype=vv.dtype, device=vv.device
)
# jax only
if array_api_compat.is_jax_array(virial):
from deepmd.jax.env import (
jnp,
)

v_idx = xp.arange(virial.size, dtype=xp.int64).reshape(
virial.shape
)
new_idx = jnp.take_along_axis(v_idx, mapping, axis=1).ravel()
v_shape = virial.shape
virial = virial.ravel()
virial = virial.at[new_idx].add(model_ret[kk_derv_c].ravel())
virial = virial.reshape(v_shape)
else:
raise NotImplementedError("Only JAX arrays are supported.")
new_ret[kk_derv_c] = virial
new_ret[kk_derv_c + "_redu"] = xp.sum(new_ret[kk_derv_c], axis=1)
else:
new_ret[kk_derv_c] = None
new_ret[kk_derv_c + "_redu"] = None
if not do_atomic_virial:
# pop atomic virial, because it is not correctly calculated.
new_ret.pop(kk_derv_c)
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/utils/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def _make_env_mat(
# nf x nloc x nnei x 3
diff = coord_r - coord_l
# nf x nloc x nnei
length = xp.linalg.vector_norm(diff, axis=-1, keepdims=True)
# the grad of JAX vector_norm is NaN at x=0
diff_ = xp.where(xp.abs(diff) < 1e-30, xp.full_like(diff, 1e-30), diff)
length = xp.linalg.vector_norm(diff_, axis=-1, keepdims=True)
# for index 0 nloc atom
length = length + xp.astype(~xp.expand_dims(mask, axis=-1), length.dtype)
t0 = 1 / (length + protection)
Expand Down
1 change: 1 addition & 0 deletions deepmd/jax/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)

jax.config.update("jax_enable_x64", True)
# jax.config.update("jax_debug_nans", True)

__all__ = [
"jax",
Expand Down
95 changes: 95 additions & 0 deletions deepmd/jax/model/base_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,101 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Optional,
)

from deepmd.dpmodel.model.base_model import (
make_base_model,
)
from deepmd.dpmodel.output_def import (
get_deriv_name,
get_reduce_name,
)
from deepmd.jax.env import (
jax,
jnp,
)

BaseModel = make_base_model()


def forward_common_atomic(
    self,
    extended_coord: jnp.ndarray,
    extended_atype: jnp.ndarray,
    nlist: jnp.ndarray,
    mapping: Optional[jnp.ndarray] = None,
    fparam: Optional[jnp.ndarray] = None,
    aparam: Optional[jnp.ndarray] = None,
    do_atomic_virial: bool = False,
):
    """Evaluate the atomic model and derive reduced / differentiated outputs.

    For every reducible output variable the per-atom values are summed over
    the atom axis.  For r-differentiable variables the force (negative
    coordinate gradient) is computed component-wise with ``jax.grad``; for
    c-differentiable variables the per-atom virial ``f_i (x) r_i`` is built
    from the same gradients.

    Parameters
    ----------
    extended_coord
        Coordinates of extended atoms.
        # assumes nf x nall x 3 — TODO confirm (reshape below relies on the
        # last axis being 3)
    extended_atype
        Atomic types of extended atoms.
    nlist
        Neighbor list.
    mapping
        Extended-to-local index mapping, may be None.
    fparam
        Frame parameters, may be None.
    aparam
        Atomic parameters, may be None.
    do_atomic_virial
        Request per-atom virial; currently not implemented.

    Returns
    -------
    dict
        Atomic outputs plus their reduced and derivative counterparts.

    Raises
    ------
    NotImplementedError
        If ``do_atomic_virial`` is requested.
    """
    atomic_ret = self.atomic_model.forward_common_atomic(
        extended_coord,
        extended_atype,
        nlist,
        mapping=mapping,
        fparam=fparam,
        aparam=aparam,
    )
    atomic_output_def = self.atomic_output_def()
    model_predict = {}
    for kk, vv in atomic_ret.items():
        model_predict[kk] = vv
        vdef = atomic_output_def[kk]
        shap = vdef.shape
        # the atom axis sits just before the per-variable trailing shape
        atom_axis = -(len(shap) + 1)
        if vdef.reducible:
            kk_redu = get_reduce_name(kk)
            model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis)
            kk_derv_r, kk_derv_c = get_deriv_name(kk)
            # BUGFIX: force computation is gated on r_differentiable, not
            # c_differentiable — the inner virial branch below already
            # asserts r_differentiable, so the old gate would silently skip
            # forces for r-only-differentiable outputs.
            if vdef.r_differentiable:
                # number of scalar components of this output variable
                size = 1
                for ii in vdef.shape:
                    size *= ii

                split_ff = []
                split_vv = []
                for ss in range(size):
                    # closure over ss is safe: eval_output is traced within
                    # this very iteration by jax.grad/vmap below
                    def eval_output(
                        cc_ext, extended_atype, nlist, mapping, fparam, aparam
                    ):
                        atomic_ret = self.atomic_model.forward_common_atomic(
                            cc_ext[None, ...],
                            extended_atype[None, ...],
                            nlist[None, ...],
                            mapping=mapping[None, ...] if mapping is not None else None,
                            fparam=fparam[None, ...] if fparam is not None else None,
                            aparam=aparam[None, ...] if aparam is not None else None,
                        )
                        # scalar: reduced output component ss of one frame
                        return jnp.sum(atomic_ret[kk][0], axis=atom_axis)[ss]

                    # force = -d(reduced output)/d(coordinates), vmapped over frames
                    ffi = -jax.vmap(jax.grad(eval_output, argnums=0))(
                        extended_coord,
                        extended_atype,
                        nlist,
                        mapping,
                        fparam,
                        aparam,
                    )
                    # per-atom virial contribution: outer product f_i r_i^T
                    aviri = ffi[..., None] @ extended_coord[..., None, :]
                    ffi = ffi[..., None, :]
                    split_ff.append(ffi)
                    aviri = aviri[..., None, :]
                    split_vv.append(aviri)
                out_lead_shape = list(extended_coord.shape[:-1]) + vdef.shape
                extended_force = jnp.concat(split_ff, axis=-2).reshape(
                    *out_lead_shape, 3
                )

                model_predict[kk_derv_r] = extended_force
                if vdef.c_differentiable:
                    assert vdef.r_differentiable
                    extended_virial = jnp.concat(split_vv, axis=-2).reshape(
                        *out_lead_shape, 9
                    )
                    # the correction sums to zero, which does not contribute to global virial
                    if do_atomic_virial:
                        raise NotImplementedError(
                            "Atomic virial is not implemented yet."
                        )
                    # to [...,3,3] -> [...,9]
                    model_predict[kk_derv_c] = extended_virial
    return model_predict
26 changes: 26 additions & 0 deletions deepmd/jax/model/ener_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Optional,
)

from deepmd.dpmodel.model import EnergyModel as EnergyModelDP
Expand All @@ -10,8 +11,12 @@
from deepmd.jax.common import (
flax_module,
)
from deepmd.jax.env import (
jnp,
)
from deepmd.jax.model.base_model import (
BaseModel,
forward_common_atomic,
)


Expand All @@ -22,3 +27,24 @@ def __setattr__(self, name: str, value: Any) -> None:
if name == "atomic_model":
value = DPAtomicModel.deserialize(value.serialize())
return super().__setattr__(name, value)

    def forward_common_atomic(
        self,
        extended_coord: jnp.ndarray,
        extended_atype: jnp.ndarray,
        nlist: jnp.ndarray,
        mapping: Optional[jnp.ndarray] = None,
        fparam: Optional[jnp.ndarray] = None,
        aparam: Optional[jnp.ndarray] = None,
        do_atomic_virial: bool = False,
    ):
        """Delegate to the module-level ``forward_common_atomic`` helper.

        The bare name below resolves to the function imported from
        ``deepmd.jax.model.base_model`` (class scope is not part of the
        runtime name lookup inside a method), which evaluates the atomic
        model and derives reduced outputs, forces and virials via JAX
        autodiff.  All arguments are passed through unchanged.
        """
        return forward_common_atomic(
            self,
            extended_coord,
            extended_atype,
            nlist,
            mapping=mapping,
            fparam=fparam,
            aparam=aparam,
            do_atomic_virial=do_atomic_virial,
        )
4 changes: 4 additions & 0 deletions source/tests/consistent/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@
"INSTALLED_ARRAY_API_STRICT",
]

SKIP_FLAG = object()


class CommonTest(ABC):
data: ClassVar[dict]
Expand Down Expand Up @@ -362,6 +364,8 @@ def test_dp_consistent_with_ref(self):
data2 = dp_obj.serialize()
np.testing.assert_equal(data1, data2)
for rr1, rr2 in zip(ret1, ret2):
if rr1 is SKIP_FLAG or rr2 is SKIP_FLAG:
continue
np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"

Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def build_tf_model(self, obj, natoms, coords, atype, box, suffix):
{},
suffix=suffix,
)
return [ret["energy"], ret["atom_ener"]], {
return [ret["energy"], ret["atom_ener"], ret["force"], ret["virial"]], {
t_coord: coords,
t_type: atype,
t_natoms: natoms,
Expand Down
Loading
Loading