From 4cb1cbc1ae485d17fc8b94965ef700a379370ee2 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 30 Oct 2024 22:13:05 -0400 Subject: [PATCH 1/5] feat(jax): atomic virial For the frozen model, store two exported functions: one enables do_atomic_virial and the other doesn't. Signed-off-by: Jinzhe Zeng --- deepmd/jax/infer/deep_eval.py | 3 ++ deepmd/jax/model/base_model.py | 60 +++++++++++++++++++--- deepmd/jax/model/hlo.py | 11 +++- deepmd/jax/utils/serialization.py | 51 ++++++++++++------ source/tests/consistent/io/test_io.py | 7 +++ source/tests/consistent/model/common.py | 10 +++- source/tests/consistent/model/test_ener.py | 11 +++- 7 files changed, 128 insertions(+), 25 deletions(-) diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index 76f044a327..a508560f9d 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -93,6 +93,9 @@ def __init__( model_data = load_dp_model(model_file) self.dp = HLO( stablehlo=model_data["@variables"]["stablehlo"].tobytes(), + stablehlo_atomic_virial=model_data["@variables"][ + "stablehlo_atomic_virial" + ].tobytes(), model_def_script=model_data["model_def_script"], **model_data["constants"], ) diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py index 8631c85d16..1e880700a2 100644 --- a/deepmd/jax/model/base_model.py +++ b/deepmd/jax/model/base_model.py @@ -91,17 +91,65 @@ def eval_output( assert vdef.r_differentiable # avr: [nf, *def, nall, 3, 3] avr = jnp.einsum("f...ai,faj->f...aij", ff, extended_coord) + # the correction sums to zero, which does not contribute to global virial + if do_atomic_virial: + + def eval_ce( + cc_ext, + extended_atype, + nlist, + mapping, + fparam, + aparam, + *, + _kk=kk, + _atom_axis=atom_axis - 1, + ): + # atomic_ret[_kk]: [nf, nloc, *def] + atomic_ret = self.atomic_model.forward_common_atomic( + cc_ext[None, ...], + extended_atype[None, ...], + nlist[None, ...], + mapping=mapping[None, ...] if mapping is not None else None, + fparam=fparam[None, ...] if fparam is not None else None, + aparam=aparam[None, ...] if aparam is not None else None, + ) + nloc = nlist.shape[0] + cc_loc = jax.lax.stop_gradient(cc_ext)[:nloc, ...] + cc_loc = jnp.reshape(cc_loc, [nloc, *[1] * def_ndim, 3]) + # [*def, 3] + return jnp.sum( + atomic_ret[_kk][0, ..., None] * cc_loc, axis=_atom_axis + ) + + # extended_virial_corr: [nf, *def, 3, nall, 3] + extended_virial_corr = jax.vmap(jax.jacrev(eval_ce, argnums=0))( + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + ) + # move the first 3 to the last + # [nf, *def, nall, 3, 3] + extended_virial_corr = jnp.transpose( + extended_virial_corr, + [ + 0, + *range(1, def_ndim + 1), + def_ndim + 2, + def_ndim + 3, + def_ndim + 1, + ], + ) + avr += extended_virial_corr + # to [...,3,3] -> [...,9] # avr: [nf, *def, nall, 9] avr = jnp.reshape(avr, [*ff.shape[:-1], 9]) # extended_virial: [nf, nall, *def, 9] extended_virial = jnp.transpose( avr, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2] ) - - # the correction sums to zero, which does not contribute to global virial - # cannot jit - # if do_atomic_virial: - # raise NotImplementedError("Atomic virial is not implemented yet.") - # to [...,3,3] -> [...,9] model_predict[kk_derv_c] = extended_virial return model_predict diff --git a/deepmd/jax/model/hlo.py b/deepmd/jax/model/hlo.py index 010e3d7a5e..2946f8bec7 100644 --- a/deepmd/jax/model/hlo.py +++ b/deepmd/jax/model/hlo.py @@ -45,6 +45,7 @@ class HLO(BaseModel): def __init__( self, stablehlo, + stablehlo_atomic_virial, model_def_script, type_map, rcut, @@ -58,6 +59,9 @@ def __init__( sel, ) -> None: self._call_lower = jax_export.deserialize(stablehlo).call + self._call_lower_atomic_virial = jax_export.deserialize( + stablehlo_atomic_virial + ).call self.stablehlo = stablehlo self.type_map = type_map self.rcut = rcut @@ -170,14 +174,17 @@ def call_lower( aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, ): - return self._call_lower( + if do_atomic_virial: + call_lower = self._call_lower_atomic_virial + else: + call_lower = self._call_lower + return call_lower( extended_coord, extended_atype, nlist, mapping, fparam, aparam, - do_atomic_virial, ) def get_type_map(self) -> list[str]: diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index fcfcc8a610..1e00d73439 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -51,26 +51,47 @@ def deserialize_to_file(model_file: str, data: dict) -> None: model_def_script = data["model_def_script"] call_lower = model.call_lower - nf, nloc, nghost, nfp, nap = jax_export.symbolic_shape( - "nf, nloc, nghost, nfp, nap" - ) - exported = jax_export.export(jax.jit(call_lower))( - jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), # extended_coord - jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype - jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist - jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping - jax.ShapeDtypeStruct((nf, nfp), jnp.float64) - if model.get_dim_fparam() - else None, # fparam - jax.ShapeDtypeStruct((nf, nap), jnp.float64) - if model.get_dim_aparam() - else None, # aparam - False, # do_atomic_virial + nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost") + + def exported_whether_do_atomic_virial(do_atomic_virial): + def call_lower_with_fixed_do_atmic_virial( + coord, atype, nlist, nlist_start, fparam, aparam + ): + return call_lower( + coord, + atype, + nlist, + nlist_start, + fparam, + aparam, + do_atomic_virial=do_atomic_virial, + ) + + return jax_export.export(jax.jit(call_lower_with_fixed_do_atmic_virial))( + jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), + jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), + jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), + jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), + jax.ShapeDtypeStruct((nf, model.get_numb_fparam()), jnp.float64) + if model.get_dim_fparam() + else None, + jax.ShapeDtypeStruct((nf, nloc, model.get_numb_aparam()), jnp.float64) + if model.get_dim_aparam() + else None, + ) + + exported = exported_whether_do_atomic_virial(do_atomic_virial=False) + exported_atomic_virial = exported_whether_do_atomic_virial( + do_atomic_virial=True ) serialized: bytearray = exported.serialize() + serialized_atomic_virial = exported_atomic_virial.serialize() data = data.copy() data.setdefault("@variables", {}) data["@variables"]["stablehlo"] = np.void(serialized) + data["@variables"]["stablehlo_atomic_virial"] = np.void( + serialized_atomic_virial + ) data["constants"] = { "type_map": model.get_type_map(), "rcut": model.get_rcut(), diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index dc0f280d56..178f3115ca 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -151,6 +151,13 @@ def test_deep_eval(self): self.atype, ) rets.append(ret) + ret = deep_eval.eval( + self.coords, + self.box, + self.atype, + do_atomic_virial=True, + ) + rets.append(ret) for ret in rets[1:]: for vv1, vv2 in zip(rets[0], ret): if np.isnan(vv2).all(): diff --git a/source/tests/consistent/model/common.py b/source/tests/consistent/model/common.py index 11940d9bdf..4eeb19b1f0 100644 --- a/source/tests/consistent/model/common.py +++ b/source/tests/consistent/model/common.py @@ -51,7 +51,13 @@ def build_tf_model(self, obj, natoms, coords, atype, box, suffix): {}, suffix=suffix, ) - return [ret["energy"], ret["atom_ener"], ret["force"], ret["virial"]], { + return [ + ret["energy"], + ret["atom_ener"], + ret["force"], + ret["virial"], + ret["atom_virial"], + ], { t_coord: coords, t_type: atype, t_natoms: natoms, @@ -69,6 +75,7 @@ def eval_pt_model(self, pt_obj: Any, natoms, coords, atype, box) -> Any: numpy_to_torch(coords), numpy_to_torch(atype), box=numpy_to_torch(box), + do_atomic_virial=True, ).items() } @@ -83,5 +90,6 @@ def assert_jax_array(arr): numpy_to_jax(coords), numpy_to_jax(atype), box=numpy_to_jax(box), + do_atomic_virial=True, ).items() } diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 2a358ba7e0..8490b36ffe 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -216,6 +216,7 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["energy"].ravel(), SKIP_FLAG, SKIP_FLAG, + SKIP_FLAG, ) elif backend is self.RefBackend.PT: return ( @@ -223,14 +224,22 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["atom_energy"].ravel(), ret["force"].ravel(), ret["virial"].ravel(), + ret["atom_virial"].ravel(), ) elif backend is self.RefBackend.TF: - return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel(), ret[3].ravel()) + return ( + ret[0].ravel(), + ret[1].ravel(), + ret[2].ravel(), + ret[3].ravel(), + ret[4].ravel(), + ) elif backend is self.RefBackend.JAX: return ( ret["energy_redu"].ravel(), ret["energy"].ravel(), ret["energy_derv_r"].ravel(), ret["energy_derv_c_redu"].ravel(), + ret["energy_derv_c"].ravel(), ) raise ValueError(f"Unknown backend: {backend}") From ba7147a6f57e1537c040d1d52bfe8132b08a6931 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 30 Oct 2024 22:56:45 -0400 Subject: [PATCH 2/5] fix typo Signed-off-by: Jinzhe Zeng --- deepmd/jax/utils/serialization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 1e00d73439..fe5fdd1d22 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -54,7 +54,7 @@ def deserialize_to_file(model_file: str, data: dict) -> None: nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost") def exported_whether_do_atomic_virial(do_atomic_virial): - def call_lower_with_fixed_do_atmic_virial( + def call_lower_with_fixed_do_atomic_virial( coord, atype, nlist, nlist_start, fparam, aparam ): return call_lower( @@ -67,7 +67,7 @@ def call_lower_with_fixed_do_atmic_virial( do_atomic_virial=do_atomic_virial, ) - return jax_export.export(jax.jit(call_lower_with_fixed_do_atmic_virial))( + return jax_export.export(jax.jit(call_lower_with_fixed_do_atomic_virial))( jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), From 6e29d505f46a72bb957f9eccb5a1f2ba7c47ac8a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 31 Oct 2024 16:34:13 -0400 Subject: [PATCH 3/5] add comments Signed-off-by: Jinzhe Zeng --- deepmd/jax/utils/serialization.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index fe5fdd1d22..102b588a9e 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -68,16 +68,16 @@ def call_lower_with_fixed_do_atomic_virial( ) return jax_export.export(jax.jit(call_lower_with_fixed_do_atomic_virial))( - jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), - jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), - jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), - jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), - jax.ShapeDtypeStruct((nf, model.get_numb_fparam()), jnp.float64) + jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), # extended_coord + jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype + jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist + jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping + jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64) if model.get_dim_fparam() - else None, - jax.ShapeDtypeStruct((nf, nloc, model.get_numb_aparam()), jnp.float64) + else None, # fparam + jax.ShapeDtypeStruct((nf, nloc, model.get_dim_aparam()), jnp.float64) if model.get_dim_aparam() - else None, + else None, # aparam ) exported = exported_whether_do_atomic_virial(do_atomic_virial=False) From 0fa72ef55e7941b9a055fe3a5575ff33f534e806 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 31 Oct 2024 20:36:47 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/jax/utils/serialization.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 102b588a9e..ec2de3060e 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -68,7 +68,9 @@ def call_lower_with_fixed_do_atomic_virial( ) return jax_export.export(jax.jit(call_lower_with_fixed_do_atomic_virial))( - jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), # extended_coord + jax.ShapeDtypeStruct( + (nf, nloc + nghost, 3), jnp.float64 + ), # extended_coord jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping From 26f790aeaef0dd124f6855e19d2aa008859d5d74 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 31 Oct 2024 17:26:15 -0400 Subject: [PATCH 5/5] fparam/aparam Signed-off-by: Jinzhe Zeng --- source/tests/consistent/io/test_io.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index 908fbe98f0..91cd391322 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -167,6 +167,8 @@ def test_deep_eval(self): self.coords, self.box, self.atype, + fparam=fparam, + aparam=aparam, do_atomic_virial=True, ) rets.append(ret)