diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index c1967fb0da..b60076c68c 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -93,6 +93,9 @@ def __init__( model_data = load_dp_model(model_file) self.dp = HLO( stablehlo=model_data["@variables"]["stablehlo"].tobytes(), + stablehlo_atomic_virial=model_data["@variables"][ + "stablehlo_atomic_virial" + ].tobytes(), model_def_script=model_data["model_def_script"], **model_data["constants"], ) diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py index 8631c85d16..1e880700a2 100644 --- a/deepmd/jax/model/base_model.py +++ b/deepmd/jax/model/base_model.py @@ -91,17 +91,65 @@ def eval_output( assert vdef.r_differentiable # avr: [nf, *def, nall, 3, 3] avr = jnp.einsum("f...ai,faj->f...aij", ff, extended_coord) + # the correction sums to zero, which does not contribute to global virial + if do_atomic_virial: + + def eval_ce( + cc_ext, + extended_atype, + nlist, + mapping, + fparam, + aparam, + *, + _kk=kk, + _atom_axis=atom_axis - 1, + ): + # atomic_ret[_kk]: [nf, nloc, *def] + atomic_ret = self.atomic_model.forward_common_atomic( + cc_ext[None, ...], + extended_atype[None, ...], + nlist[None, ...], + mapping=mapping[None, ...] if mapping is not None else None, + fparam=fparam[None, ...] if fparam is not None else None, + aparam=aparam[None, ...] if aparam is not None else None, + ) + nloc = nlist.shape[0] + cc_loc = jax.lax.stop_gradient(cc_ext)[:nloc, ...] 
+ cc_loc = jnp.reshape(cc_loc, [nloc, *[1] * def_ndim, 3]) + # [*def, 3] + return jnp.sum( + atomic_ret[_kk][0, ..., None] * cc_loc, axis=_atom_axis + ) + + # extended_virial_corr: [nf, *def, 3, nall, 3] + extended_virial_corr = jax.vmap(jax.jacrev(eval_ce, argnums=0))( + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + ) + # move the first 3 to the last + # [nf, *def, nall, 3, 3] + extended_virial_corr = jnp.transpose( + extended_virial_corr, + [ + 0, + *range(1, def_ndim + 1), + def_ndim + 2, + def_ndim + 3, + def_ndim + 1, + ], + ) + avr += extended_virial_corr + # to [...,3,3] -> [...,9] # avr: [nf, *def, nall, 9] avr = jnp.reshape(avr, [*ff.shape[:-1], 9]) # extended_virial: [nf, nall, *def, 9] extended_virial = jnp.transpose( avr, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2] ) - - # the correction sums to zero, which does not contribute to global virial - # cannot jit - # if do_atomic_virial: - # raise NotImplementedError("Atomic virial is not implemented yet.") - # to [...,3,3] -> [...,9] model_predict[kk_derv_c] = extended_virial return model_predict diff --git a/deepmd/jax/model/hlo.py b/deepmd/jax/model/hlo.py index 010e3d7a5e..2946f8bec7 100644 --- a/deepmd/jax/model/hlo.py +++ b/deepmd/jax/model/hlo.py @@ -45,6 +45,7 @@ class HLO(BaseModel): def __init__( self, stablehlo, + stablehlo_atomic_virial, model_def_script, type_map, rcut, @@ -58,6 +59,9 @@ def __init__( sel, ) -> None: self._call_lower = jax_export.deserialize(stablehlo).call + self._call_lower_atomic_virial = jax_export.deserialize( + stablehlo_atomic_virial + ).call self.stablehlo = stablehlo self.type_map = type_map self.rcut = rcut @@ -170,14 +174,17 @@ def call_lower( aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, ): - return self._call_lower( + if do_atomic_virial: + call_lower = self._call_lower_atomic_virial + else: + call_lower = self._call_lower + return call_lower( extended_coord, extended_atype, nlist, mapping, fparam, 
aparam, - do_atomic_virial, ) def get_type_map(self) -> list[str]: diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index a7d57523e2..ec2de3060e 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -52,23 +52,48 @@ def deserialize_to_file(model_file: str, data: dict) -> None: call_lower = model.call_lower nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost") - exported = jax_export.export(jax.jit(call_lower))( - jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), # extended_coord - jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype - jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist - jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping - jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64) - if model.get_dim_fparam() - else None, # fparam - jax.ShapeDtypeStruct((nf, nloc, model.get_dim_aparam()), jnp.float64) - if model.get_dim_aparam() - else None, # aparam - False, # do_atomic_virial + + def exported_whether_do_atomic_virial(do_atomic_virial): + def call_lower_with_fixed_do_atomic_virial( + coord, atype, nlist, mapping, fparam, aparam + ): + return call_lower( + coord, + atype, + nlist, + mapping, + fparam, + aparam, + do_atomic_virial=do_atomic_virial, + ) + + return jax_export.export(jax.jit(call_lower_with_fixed_do_atomic_virial))( + jax.ShapeDtypeStruct( + (nf, nloc + nghost, 3), jnp.float64 + ), # extended_coord + jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype + jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist + jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping + jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64) + if model.get_dim_fparam() + else None, # fparam + jax.ShapeDtypeStruct((nf, nloc, model.get_dim_aparam()), jnp.float64) + if model.get_dim_aparam() + else None, # aparam + ) + + exported = 
exported_whether_do_atomic_virial(do_atomic_virial=False) + exported_atomic_virial = exported_whether_do_atomic_virial( + do_atomic_virial=True ) serialized: bytearray = exported.serialize() + serialized_atomic_virial = exported_atomic_virial.serialize() data = data.copy() data.setdefault("@variables", {}) data["@variables"]["stablehlo"] = np.void(serialized) + data["@variables"]["stablehlo_atomic_virial"] = np.void( + serialized_atomic_virial + ) data["constants"] = { "type_map": model.get_type_map(), "rcut": model.get_rcut(), diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index af26c41694..91cd391322 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -163,6 +163,15 @@ def test_deep_eval(self): aparam=aparam, ) rets.append(ret) + ret = deep_eval.eval( + self.coords, + self.box, + self.atype, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + ) + rets.append(ret) for ret in rets[1:]: for vv1, vv2 in zip(rets[0], ret): if np.isnan(vv2).all(): diff --git a/source/tests/consistent/model/common.py b/source/tests/consistent/model/common.py index 11940d9bdf..4eeb19b1f0 100644 --- a/source/tests/consistent/model/common.py +++ b/source/tests/consistent/model/common.py @@ -51,7 +51,13 @@ def build_tf_model(self, obj, natoms, coords, atype, box, suffix): {}, suffix=suffix, ) - return [ret["energy"], ret["atom_ener"], ret["force"], ret["virial"]], { + return [ + ret["energy"], + ret["atom_ener"], + ret["force"], + ret["virial"], + ret["atom_virial"], + ], { t_coord: coords, t_type: atype, t_natoms: natoms, @@ -69,6 +75,7 @@ def eval_pt_model(self, pt_obj: Any, natoms, coords, atype, box) -> Any: numpy_to_torch(coords), numpy_to_torch(atype), box=numpy_to_torch(box), + do_atomic_virial=True, ).items() } @@ -83,5 +90,6 @@ def assert_jax_array(arr): numpy_to_jax(coords), numpy_to_jax(atype), box=numpy_to_jax(box), + do_atomic_virial=True, ).items() } diff --git 
a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 2a358ba7e0..8490b36ffe 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -216,6 +216,7 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["energy"].ravel(), SKIP_FLAG, SKIP_FLAG, + SKIP_FLAG, ) elif backend is self.RefBackend.PT: return ( @@ -223,14 +224,22 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["atom_energy"].ravel(), ret["force"].ravel(), ret["virial"].ravel(), + ret["atom_virial"].ravel(), ) elif backend is self.RefBackend.TF: - return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel(), ret[3].ravel()) + return ( + ret[0].ravel(), + ret[1].ravel(), + ret[2].ravel(), + ret[3].ravel(), + ret[4].ravel(), + ) elif backend is self.RefBackend.JAX: return ( ret["energy_redu"].ravel(), ret["energy"].ravel(), ret["energy_derv_r"].ravel(), ret["energy_derv_c_redu"].ravel(), + ret["energy_derv_c"].ravel(), ) raise ValueError(f"Unknown backend: {backend}")