Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def __init__(
model_data = load_dp_model(model_file)
self.dp = HLO(
stablehlo=model_data["@variables"]["stablehlo"].tobytes(),
stablehlo_atomic_virial=model_data["@variables"][
"stablehlo_atomic_virial"
].tobytes(),
model_def_script=model_data["model_def_script"],
**model_data["constants"],
)
Expand Down
60 changes: 54 additions & 6 deletions deepmd/jax/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,65 @@ def eval_output(
assert vdef.r_differentiable
# avr: [nf, *def, nall, 3, 3]
avr = jnp.einsum("f...ai,faj->f...aij", ff, extended_coord)
# the correction sums to zero, which does not contribute to global virial
if do_atomic_virial:

def eval_ce(
cc_ext,
extended_atype,
nlist,
mapping,
fparam,
aparam,
*,
_kk=kk,
_atom_axis=atom_axis - 1,
):
# atomic_ret[_kk]: [nf, nloc, *def]
atomic_ret = self.atomic_model.forward_common_atomic(
cc_ext[None, ...],
extended_atype[None, ...],
nlist[None, ...],
mapping=mapping[None, ...] if mapping is not None else None,
fparam=fparam[None, ...] if fparam is not None else None,
aparam=aparam[None, ...] if aparam is not None else None,
)
nloc = nlist.shape[0]
cc_loc = jax.lax.stop_gradient(cc_ext)[:nloc, ...]
cc_loc = jnp.reshape(cc_loc, [nloc, *[1] * def_ndim, 3])
# [*def, 3]
return jnp.sum(
atomic_ret[_kk][0, ..., None] * cc_loc, axis=_atom_axis
)

# extended_virial_corr: [nf, *def, 3, nall, 3]
extended_virial_corr = jax.vmap(jax.jacrev(eval_ce, argnums=0))(
extended_coord,
extended_atype,
nlist,
mapping,
fparam,
aparam,
)
# move the first 3 to the last
# [nf, *def, nall, 3, 3]
extended_virial_corr = jnp.transpose(
extended_virial_corr,
[
0,
*range(1, def_ndim + 1),
def_ndim + 2,
def_ndim + 3,
def_ndim + 1,
],
)
avr += extended_virial_corr
# to [...,3,3] -> [...,9]
# avr: [nf, *def, nall, 9]
avr = jnp.reshape(avr, [*ff.shape[:-1], 9])
# extended_virial: [nf, nall, *def, 9]
extended_virial = jnp.transpose(
avr, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2]
)

# the correction sums to zero, which does not contribute to global virial
# cannot jit
# if do_atomic_virial:
# raise NotImplementedError("Atomic virial is not implemented yet.")
# to [...,3,3] -> [...,9]
model_predict[kk_derv_c] = extended_virial
return model_predict
11 changes: 9 additions & 2 deletions deepmd/jax/model/hlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
def __init__(
self,
stablehlo,
stablehlo_atomic_virial,
model_def_script,
type_map,
rcut,
Expand All @@ -58,6 +59,9 @@
sel,
) -> None:
self._call_lower = jax_export.deserialize(stablehlo).call
self._call_lower_atomic_virial = jax_export.deserialize(
stablehlo_atomic_virial
).call
self.stablehlo = stablehlo
self.type_map = type_map
self.rcut = rcut
Expand Down Expand Up @@ -170,14 +174,17 @@
aparam: Optional[jnp.ndarray] = None,
do_atomic_virial: bool = False,
):
return self._call_lower(
if do_atomic_virial:
call_lower = self._call_lower_atomic_virial
else:
call_lower = self._call_lower

[Codecov / codecov/patch annotation: added line 180 of deepmd/jax/model/hlo.py was not covered by tests.]
return call_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam,
aparam,
do_atomic_virial,
)

def get_type_map(self) -> list[str]:
Expand Down
51 changes: 36 additions & 15 deletions deepmd/jax/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,47 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
model_def_script = data["model_def_script"]
call_lower = model.call_lower

nf, nloc, nghost, nfp, nap = jax_export.symbolic_shape(
"nf, nloc, nghost, nfp, nap"
)
exported = jax_export.export(jax.jit(call_lower))(
jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), # extended_coord
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype
jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping
jax.ShapeDtypeStruct((nf, nfp), jnp.float64)
if model.get_dim_fparam()
else None, # fparam
jax.ShapeDtypeStruct((nf, nap), jnp.float64)
if model.get_dim_aparam()
else None, # aparam
False, # do_atomic_virial
nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost")

def exported_whether_do_atomic_virial(do_atomic_virial):
    """Export ``call_lower`` to StableHLO with ``do_atomic_virial`` baked in.

    ``do_atomic_virial`` is a Python-level control flag, not a traceable
    array argument, so it must be fixed before ``jax.jit``/``jax_export``
    tracing; this helper closes over one concrete value and exports the
    resulting specialization with symbolic batch/atom dimensions
    (``nf``, ``nloc``, ``nghost`` from the enclosing scope).

    Parameters
    ----------
    do_atomic_virial : bool
        Whether the exported graph computes the atomic virial.

    Returns
    -------
    jax_export.Exported
        The serialized-exportable function with shape polymorphism.
    """

    def call_lower_with_fixed_do_atomic_virial(
        coord, atype, nlist, nlist_start, fparam, aparam
    ):
        # Forward all array arguments; only the virial flag is frozen here.
        return call_lower(
            coord,
            atype,
            nlist,
            nlist_start,
            fparam,
            aparam,
            do_atomic_virial=do_atomic_virial,
        )

    return jax_export.export(jax.jit(call_lower_with_fixed_do_atomic_virial))(
        jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64),  # extended_coord
        jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32),  # extended_atype
        jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64),  # nlist
        jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64),  # mapping
        # fparam/aparam are optional: export a placeholder only when the
        # model declares a nonzero dimension for them.
        jax.ShapeDtypeStruct((nf, model.get_numb_fparam()), jnp.float64)
        if model.get_dim_fparam()
        else None,
        jax.ShapeDtypeStruct((nf, nloc, model.get_numb_aparam()), jnp.float64)
        if model.get_dim_aparam()
        else None,
    )

exported = exported_whether_do_atomic_virial(do_atomic_virial=False)
exported_atomic_virial = exported_whether_do_atomic_virial(
do_atomic_virial=True
)
serialized: bytearray = exported.serialize()
serialized_atomic_virial = exported_atomic_virial.serialize()
data = data.copy()
data.setdefault("@variables", {})
data["@variables"]["stablehlo"] = np.void(serialized)
data["@variables"]["stablehlo_atomic_virial"] = np.void(
serialized_atomic_virial
)
data["constants"] = {
"type_map": model.get_type_map(),
"rcut": model.get_rcut(),
Expand Down
7 changes: 7 additions & 0 deletions source/tests/consistent/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ def test_deep_eval(self):
self.atype,
)
rets.append(ret)
ret = deep_eval.eval(
self.coords,
self.box,
self.atype,
do_atomic_virial=True,
)
rets.append(ret)
for ret in rets[1:]:
for vv1, vv2 in zip(rets[0], ret):
if np.isnan(vv2).all():
Expand Down
10 changes: 9 additions & 1 deletion source/tests/consistent/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@ def build_tf_model(self, obj, natoms, coords, atype, box, suffix):
{},
suffix=suffix,
)
return [ret["energy"], ret["atom_ener"], ret["force"], ret["virial"]], {
return [
ret["energy"],
ret["atom_ener"],
ret["force"],
ret["virial"],
ret["atom_virial"],
], {
t_coord: coords,
t_type: atype,
t_natoms: natoms,
Expand All @@ -69,6 +75,7 @@ def eval_pt_model(self, pt_obj: Any, natoms, coords, atype, box) -> Any:
numpy_to_torch(coords),
numpy_to_torch(atype),
box=numpy_to_torch(box),
do_atomic_virial=True,
).items()
}

Expand All @@ -83,5 +90,6 @@ def assert_jax_array(arr):
numpy_to_jax(coords),
numpy_to_jax(atype),
box=numpy_to_jax(box),
do_atomic_virial=True,
).items()
}
11 changes: 10 additions & 1 deletion source/tests/consistent/model/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,21 +216,30 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
ret["energy"].ravel(),
SKIP_FLAG,
SKIP_FLAG,
SKIP_FLAG,
)
elif backend is self.RefBackend.PT:
return (
ret["energy"].ravel(),
ret["atom_energy"].ravel(),
ret["force"].ravel(),
ret["virial"].ravel(),
ret["atom_virial"].ravel(),
)
elif backend is self.RefBackend.TF:
return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel(), ret[3].ravel())
return (
ret[0].ravel(),
ret[1].ravel(),
ret[2].ravel(),
ret[3].ravel(),
ret[4].ravel(),
)
elif backend is self.RefBackend.JAX:
return (
ret["energy_redu"].ravel(),
ret["energy"].ravel(),
ret["energy_derv_r"].ravel(),
ret["energy_derv_c_redu"].ravel(),
ret["energy_derv_c"].ravel(),
)
raise ValueError(f"Unknown backend: {backend}")
Loading