|
1 | 1 | # SPDX-License-Identifier: LGPL-3.0-or-later |
| 2 | +from typing import ( |
| 3 | + Optional, |
| 4 | +) |
| 5 | + |
2 | 6 | from deepmd.dpmodel.model.base_model import ( |
3 | 7 | make_base_model, |
4 | 8 | ) |
| 9 | +from deepmd.dpmodel.output_def import ( |
| 10 | + get_deriv_name, |
| 11 | + get_reduce_name, |
| 12 | +) |
| 13 | +from deepmd.jax.env import ( |
| 14 | + jax, |
| 15 | + jnp, |
| 16 | +) |
5 | 17 |
|
6 | 18 | BaseModel = make_base_model() |
| 19 | + |
| 20 | + |
| 21 | +def forward_common_atomic( |
| 22 | + self, |
| 23 | + extended_coord: jnp.ndarray, |
| 24 | + extended_atype: jnp.ndarray, |
| 25 | + nlist: jnp.ndarray, |
| 26 | + mapping: Optional[jnp.ndarray] = None, |
| 27 | + fparam: Optional[jnp.ndarray] = None, |
| 28 | + aparam: Optional[jnp.ndarray] = None, |
| 29 | + do_atomic_virial: bool = False, |
| 30 | +): |
| 31 | + atomic_ret = self.atomic_model.forward_common_atomic( |
| 32 | + extended_coord, |
| 33 | + extended_atype, |
| 34 | + nlist, |
| 35 | + mapping=mapping, |
| 36 | + fparam=fparam, |
| 37 | + aparam=aparam, |
| 38 | + ) |
| 39 | + atomic_output_def = self.atomic_output_def() |
| 40 | + model_predict = {} |
| 41 | + for kk, vv in atomic_ret.items(): |
| 42 | + model_predict[kk] = vv |
| 43 | + vdef = atomic_output_def[kk] |
| 44 | + shap = vdef.shape |
| 45 | + atom_axis = -(len(shap) + 1) |
| 46 | + if vdef.reducible: |
| 47 | + kk_redu = get_reduce_name(kk) |
| 48 | + model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis) |
| 49 | + kk_derv_r, kk_derv_c = get_deriv_name(kk) |
| 50 | + if vdef.c_differentiable: |
| 51 | + |
| 52 | + def eval_output( |
| 53 | + cc_ext, |
| 54 | + extended_atype, |
| 55 | + nlist, |
| 56 | + mapping, |
| 57 | + fparam, |
| 58 | + aparam, |
| 59 | + *, |
| 60 | + _kk=kk, |
| 61 | + _atom_axis=atom_axis, |
| 62 | + ): |
| 63 | + atomic_ret = self.atomic_model.forward_common_atomic( |
| 64 | + cc_ext[None, ...], |
| 65 | + extended_atype[None, ...], |
| 66 | + nlist[None, ...], |
| 67 | + mapping=mapping[None, ...] if mapping is not None else None, |
| 68 | + fparam=fparam[None, ...] if fparam is not None else None, |
| 69 | + aparam=aparam[None, ...] if aparam is not None else None, |
| 70 | + ) |
| 71 | + return jnp.sum(atomic_ret[_kk][0], axis=_atom_axis) |
| 72 | + |
| 73 | + # extended_coord: [nf, nall, 3] |
| 74 | + # ff: [nf, *def, nall, 3] |
| 75 | + ff = -jax.vmap(jax.jacrev(eval_output, argnums=0))( |
| 76 | + extended_coord, |
| 77 | + extended_atype, |
| 78 | + nlist, |
| 79 | + mapping, |
| 80 | + fparam, |
| 81 | + aparam, |
| 82 | + ) |
| 83 | + # extended_force: [nf, nall, *def, 3] |
| 84 | + def_ndim = len(vdef.shape) |
| 85 | + extended_force = jnp.transpose( |
| 86 | + ff, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2] |
| 87 | + ) |
| 88 | + |
| 89 | + model_predict[kk_derv_r] = extended_force |
| 90 | + if vdef.c_differentiable: |
| 91 | + assert vdef.r_differentiable |
| 92 | + # avr: [nf, *def, nall, 3, 3] |
| 93 | + avr = jnp.einsum("f...ai,faj->f...aij", ff, extended_coord) |
| 94 | + # avr: [nf, *def, nall, 9] |
| 95 | + avr = jnp.reshape(avr, [*ff.shape[:-1], 9]) |
| 96 | + # extended_virial: [nf, nall, *def, 9] |
| 97 | + extended_virial = jnp.transpose( |
| 98 | + avr, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2] |
| 99 | + ) |
| 100 | + |
| 101 | + # the correction sums to zero, which does not contribute to global virial |
| 102 | + # cannot jit |
| 103 | + # if do_atomic_virial: |
| 104 | + # raise NotImplementedError("Atomic virial is not implemented yet.") |
| 105 | + # to [...,3,3] -> [...,9] |
| 106 | + model_predict[kk_derv_c] = extended_virial |
| 107 | + return model_predict |
0 commit comments