Skip to content

Commit dabedd2

Browse files
authored
fix(jax): calculate virial in call_lower (#4304)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced output of the model by providing a reduced form of the virial tensor, improving usability for further calculations and analyses. - Introduced a new test class, `TestEnerLower`, to evaluate lower-level energy models, excluding TensorFlow functionality. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 9ed0397 commit dabedd2

File tree

2 files changed

+221
-1
lines changed

2 files changed

+221
-1
lines changed

deepmd/jax/model/base_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,4 +152,6 @@ def eval_ce(
152152
avr, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2]
153153
)
154154
model_predict[kk_derv_c] = extended_virial
155+
# [nf, *def, 9]
156+
model_predict[kk_derv_c + "_redu"] = jnp.sum(extended_virial, axis=1)
155157
return model_predict

source/tests/consistent/model/test_ener.py

Lines changed: 219 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,18 @@
66

77
import numpy as np
88

9+
from deepmd.dpmodel.common import (
10+
to_numpy_array,
11+
)
912
from deepmd.dpmodel.model.ener_model import EnergyModel as EnergyModelDP
1013
from deepmd.dpmodel.model.model import get_model as get_model_dp
14+
from deepmd.dpmodel.utils.nlist import (
15+
build_neighbor_list,
16+
extend_coord_with_ghosts,
17+
)
18+
from deepmd.dpmodel.utils.region import (
19+
normalize_coord,
20+
)
1121
from deepmd.env import (
1222
GLOBAL_NP_FLOAT_PRECISION,
1323
)
@@ -27,7 +37,8 @@
2737
if INSTALLED_PT:
2838
from deepmd.pt.model.model import get_model as get_model_pt
2939
from deepmd.pt.model.model.ener_model import EnergyModel as EnergyModelPT
30-
40+
from deepmd.pt.utils.utils import to_numpy_array as torch_to_numpy
41+
from deepmd.pt.utils.utils import to_torch_tensor as numpy_to_torch
3142
else:
3243
EnergyModelPT = None
3344
if INSTALLED_TF:
@@ -39,6 +50,9 @@
3950
)
4051

4152
if INSTALLED_JAX:
53+
from deepmd.jax.common import (
54+
to_jax_array,
55+
)
4256
from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX
4357
from deepmd.jax.model.model import get_model as get_model_jax
4458
else:
@@ -243,3 +257,207 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
243257
ret["energy_derv_c"].ravel(),
244258
)
245259
raise ValueError(f"Unknown backend: {backend}")
260+
261+
262+
@parameterized(
263+
(
264+
[],
265+
[[0, 1]],
266+
),
267+
(
268+
[],
269+
[1],
270+
),
271+
)
272+
class TestEnerLower(CommonTest, ModelTest, unittest.TestCase):
273+
@property
274+
def data(self) -> dict:
275+
pair_exclude_types, atom_exclude_types = self.param
276+
return {
277+
"type_map": ["O", "H"],
278+
"pair_exclude_types": pair_exclude_types,
279+
"atom_exclude_types": atom_exclude_types,
280+
"descriptor": {
281+
"type": "se_e2_a",
282+
"sel": [20, 20],
283+
"rcut_smth": 0.50,
284+
"rcut": 6.00,
285+
"neuron": [
286+
3,
287+
6,
288+
],
289+
"resnet_dt": False,
290+
"axis_neuron": 2,
291+
"precision": "float64",
292+
"type_one_side": True,
293+
"seed": 1,
294+
},
295+
"fitting_net": {
296+
"neuron": [
297+
5,
298+
5,
299+
],
300+
"resnet_dt": True,
301+
"precision": "float64",
302+
"seed": 1,
303+
},
304+
}
305+
306+
tf_class = EnergyModelTF
307+
dp_class = EnergyModelDP
308+
pt_class = EnergyModelPT
309+
jax_class = EnergyModelJAX
310+
args = model_args()
311+
312+
def get_reference_backend(self):
313+
"""Get the reference backend.
314+
315+
We need a reference backend that can reproduce forces.
316+
"""
317+
if not self.skip_pt:
318+
return self.RefBackend.PT
319+
if not self.skip_jax:
320+
return self.RefBackend.JAX
321+
if not self.skip_dp:
322+
return self.RefBackend.DP
323+
raise ValueError("No available reference")
324+
325+
@property
326+
def skip_tf(self):
327+
# TF does not have lower interface
328+
return True
329+
330+
@property
331+
def skip_jax(self):
332+
return not INSTALLED_JAX
333+
334+
def pass_data_to_cls(self, cls, data) -> Any:
335+
"""Pass data to the class."""
336+
data = data.copy()
337+
if cls is EnergyModelDP:
338+
return get_model_dp(data)
339+
elif cls is EnergyModelPT:
340+
return get_model_pt(data)
341+
elif cls is EnergyModelJAX:
342+
return get_model_jax(data)
343+
return cls(**data, **self.additional_data)
344+
345+
def setUp(self):
346+
CommonTest.setUp(self)
347+
348+
self.ntypes = 2
349+
coords = np.array(
350+
[
351+
12.83,
352+
2.56,
353+
2.18,
354+
12.09,
355+
2.87,
356+
2.74,
357+
00.25,
358+
3.32,
359+
1.68,
360+
3.36,
361+
3.00,
362+
1.81,
363+
3.51,
364+
2.51,
365+
2.60,
366+
4.27,
367+
3.22,
368+
1.56,
369+
],
370+
dtype=GLOBAL_NP_FLOAT_PRECISION,
371+
).reshape(1, -1, 3)
372+
atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1)
373+
box = np.array(
374+
[13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0],
375+
dtype=GLOBAL_NP_FLOAT_PRECISION,
376+
).reshape(1, 9)
377+
378+
rcut = 6.0
379+
nframes, nloc = atype.shape[:2]
380+
coord_normalized = normalize_coord(
381+
coords.reshape(nframes, nloc, 3),
382+
box.reshape(nframes, 3, 3),
383+
)
384+
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
385+
coord_normalized, atype, box, rcut
386+
)
387+
nlist = build_neighbor_list(
388+
extended_coord,
389+
extended_atype,
390+
nloc,
391+
6.0,
392+
[20, 20],
393+
distinguish_types=True,
394+
)
395+
extended_coord = extended_coord.reshape(nframes, -1, 3)
396+
self.nlist = nlist
397+
self.extended_coord = extended_coord
398+
self.extended_atype = extended_atype
399+
self.mapping = mapping
400+
401+
def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
402+
raise NotImplementedError("no TF in this test")
403+
404+
def eval_dp(self, dp_obj: Any) -> Any:
405+
return dp_obj.call_lower(
406+
self.extended_coord,
407+
self.extended_atype,
408+
self.nlist,
409+
self.mapping,
410+
do_atomic_virial=True,
411+
)
412+
413+
def eval_pt(self, pt_obj: Any) -> Any:
414+
return {
415+
kk: torch_to_numpy(vv)
416+
for kk, vv in pt_obj.forward_lower(
417+
numpy_to_torch(self.extended_coord),
418+
numpy_to_torch(self.extended_atype),
419+
numpy_to_torch(self.nlist),
420+
numpy_to_torch(self.mapping),
421+
do_atomic_virial=True,
422+
).items()
423+
}
424+
425+
def eval_jax(self, jax_obj: Any) -> Any:
426+
return {
427+
kk: to_numpy_array(vv)
428+
for kk, vv in jax_obj.call_lower(
429+
to_jax_array(self.extended_coord),
430+
to_jax_array(self.extended_atype),
431+
to_jax_array(self.nlist),
432+
to_jax_array(self.mapping),
433+
do_atomic_virial=True,
434+
).items()
435+
}
436+
437+
def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
438+
# shape not matched. ravel...
439+
if backend is self.RefBackend.DP:
440+
return (
441+
ret["energy_redu"].ravel(),
442+
ret["energy"].ravel(),
443+
SKIP_FLAG,
444+
SKIP_FLAG,
445+
SKIP_FLAG,
446+
)
447+
elif backend is self.RefBackend.PT:
448+
return (
449+
ret["energy"].ravel(),
450+
ret["atom_energy"].ravel(),
451+
ret["extended_force"].ravel(),
452+
ret["virial"].ravel(),
453+
ret["extended_virial"].ravel(),
454+
)
455+
elif backend is self.RefBackend.JAX:
456+
return (
457+
ret["energy_redu"].ravel(),
458+
ret["energy"].ravel(),
459+
ret["energy_derv_r"].ravel(),
460+
ret["energy_derv_c_redu"].ravel(),
461+
ret["energy_derv_c"].ravel(),
462+
)
463+
raise ValueError(f"Unknown backend: {backend}")

0 commit comments

Comments
 (0)