Skip to content

Commit 560d82e

Browse files
authored
feat(jax): reformat nlist in the TF model (#4336)
Reformat the neighbor list in the TF model to convert the dynamic shape to the determined shape so the TF model can accept the neighbor list with a dynamic shape. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a new function to format neighbor lists based on selected neighbors and cutoff radius. - Enhanced deserialization process to incorporate the new formatting function for improved neighbor list handling. - **Tests** - Added a new test suite for the neighbor list formatting function, ensuring its functionality under various scenarios. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 4a9ed88 commit 560d82e

File tree

3 files changed

+169
-2
lines changed

3 files changed

+169
-2
lines changed

deepmd/jax/jax2tf/format_nlist.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import tensorflow as tf
3+
import tensorflow.experimental.numpy as tnp
4+
5+
6+
@tf.function(autograph=True)
7+
def format_nlist(
8+
extended_coord: tnp.ndarray,
9+
nlist: tnp.ndarray,
10+
nsel: int,
11+
rcut: float,
12+
):
13+
"""Format neighbor list.
14+
15+
If nnei == nsel, do nothing;
16+
If nnei < nsel, pad -1;
17+
If nnei > nsel, sort by distance and truncate.
18+
19+
Parameters
20+
----------
21+
extended_coord
22+
The extended coordinates of the atoms.
23+
shape: nf x nall x 3
24+
nlist
25+
The neighbor list.
26+
shape: nf x nloc x nnei
27+
nsel
28+
The number of selected neighbors.
29+
rcut
30+
The cutoff radius.
31+
32+
Returns
33+
-------
34+
nlist
35+
The formatted neighbor list.
36+
shape: nf x nloc x nsel
37+
"""
38+
nlist_shape = tf.shape(nlist)
39+
n_nf, n_nloc, n_nsel = nlist_shape[0], nlist_shape[1], nlist_shape[2]
40+
extended_coord = extended_coord.reshape([n_nf, -1, 3])
41+
42+
if n_nsel < nsel:
43+
# make a copy before revise
44+
ret = tnp.concatenate(
45+
[
46+
nlist,
47+
tnp.full([n_nf, n_nloc, nsel - n_nsel], -1, dtype=nlist.dtype),
48+
],
49+
axis=-1,
50+
)
51+
52+
elif n_nsel > nsel:
53+
# make a copy before revise
54+
m_real_nei = nlist >= 0
55+
ret = tnp.where(m_real_nei, nlist, 0)
56+
coord0 = extended_coord[:, :n_nloc, :]
57+
index = ret.reshape(n_nf, n_nloc * n_nsel, 1)
58+
index = tnp.repeat(index, 3, axis=2)
59+
coord1 = tnp.take_along_axis(extended_coord, index, axis=1)
60+
coord1 = coord1.reshape(n_nf, n_nloc, n_nsel, 3)
61+
rr2 = tnp.sum(tnp.square(coord0[:, :, None, :] - coord1), axis=-1)
62+
rr2 = tnp.where(m_real_nei, rr2, float("inf"))
63+
rr2, ret_mapping = tnp.sort(rr2, axis=-1), tnp.argsort(rr2, axis=-1)
64+
ret = tnp.take_along_axis(ret, ret_mapping, axis=2)
65+
ret = tnp.where(rr2 > rcut * rcut, -1, ret)
66+
ret = ret[..., :nsel]
67+
else: # n_nsel == nsel:
68+
ret = nlist
69+
# do a reshape any way; this will tell the xla the shape without any dynamic shape
70+
ret = tnp.reshape(ret, [n_nf, n_nloc, nsel])
71+
return ret

deepmd/jax/jax2tf/serialization.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
jax2tf,
1111
)
1212

13+
from deepmd.jax.jax2tf.format_nlist import (
14+
format_nlist,
15+
)
1316
from deepmd.jax.jax2tf.make_model import (
1417
model_call_from_call_lower,
1518
)
@@ -76,7 +79,7 @@ def call_lower_with_fixed_do_atomic_virial(
7679
input_signature=[
7780
tf.TensorSpec([None, None, 3], tf.float64),
7881
tf.TensorSpec([None, None], tf.int32),
79-
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
82+
tf.TensorSpec([None, None, None], tf.int64),
8083
tf.TensorSpec([None, None], tf.int64),
8184
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
8285
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
@@ -85,6 +88,7 @@ def call_lower_with_fixed_do_atomic_virial(
8588
def call_lower_without_atomic_virial(
8689
coord, atype, nlist, mapping, fparam, aparam
8790
):
91+
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
8892
return tf.cond(
8993
tf.shape(coord)[1] == tf.shape(nlist)[1],
9094
lambda: exported_whether_do_atomic_virial(
@@ -102,13 +106,14 @@ def call_lower_without_atomic_virial(
102106
input_signature=[
103107
tf.TensorSpec([None, None, 3], tf.float64),
104108
tf.TensorSpec([None, None], tf.int32),
105-
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
109+
tf.TensorSpec([None, None, None], tf.int64),
106110
tf.TensorSpec([None, None], tf.int64),
107111
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
108112
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
109113
],
110114
)
111115
def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam):
116+
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
112117
return tf.cond(
113118
tf.shape(coord)[1] == tf.shape(nlist)[1],
114119
lambda: exported_whether_do_atomic_virial(
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import tensorflow as tf
3+
import tensorflow.experimental.numpy as tnp
4+
5+
from deepmd.jax.jax2tf.format_nlist import (
6+
format_nlist,
7+
)
8+
from deepmd.jax.jax2tf.nlist import (
9+
build_neighbor_list,
10+
extend_coord_with_ghosts,
11+
)
12+
13+
GLOBAL_SEED = 20241110
14+
15+
16+
class TestFormatNlist(tf.test.TestCase):
17+
def setUp(self):
18+
self.nf = 3
19+
self.nloc = 3
20+
self.ns = 5 * 5 * 3
21+
self.nall = self.ns * self.nloc
22+
self.cell = tnp.array(
23+
[[[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]]], dtype=tnp.float64
24+
)
25+
self.icoord = tnp.array(
26+
[[[0.035, 0.062, 0.064], [0.085, 0.058, 0.021], [0.537, 0.553, 0.124]]],
27+
dtype=tnp.float64,
28+
)
29+
self.atype = tnp.array([[1, 0, 1]], dtype=tnp.int32)
30+
self.nsel = [10, 10]
31+
self.rcut = 1.01
32+
33+
self.ecoord, self.eatype, mapping = extend_coord_with_ghosts(
34+
self.icoord, self.atype, self.cell, self.rcut
35+
)
36+
self.nlist = build_neighbor_list(
37+
self.ecoord,
38+
self.eatype,
39+
self.nloc,
40+
self.rcut,
41+
sum(self.nsel),
42+
distinguish_types=False,
43+
)
44+
45+
def test_format_nlist_equal(self):
46+
nlist = format_nlist(self.ecoord, self.nlist, sum(self.nsel), self.rcut)
47+
self.assertAllEqual(nlist, self.nlist)
48+
49+
def test_format_nlist_less(self):
50+
nlist = build_neighbor_list(
51+
self.ecoord,
52+
self.eatype,
53+
self.nloc,
54+
self.rcut,
55+
sum(self.nsel) - 5,
56+
distinguish_types=False,
57+
)
58+
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
59+
self.assertAllEqual(nlist, self.nlist)
60+
61+
def test_format_nlist_large(self):
62+
nlist = build_neighbor_list(
63+
self.ecoord,
64+
self.eatype,
65+
self.nloc,
66+
self.rcut,
67+
sum(self.nsel) + 5,
68+
distinguish_types=False,
69+
)
70+
# random shuffle
71+
shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2]))
72+
nlist = tnp.take(nlist, shuffle_idx, axis=2)
73+
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
74+
# we only need to ensure the result is correct, no need to check the order
75+
self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1))
76+
77+
def test_format_nlist_larger_rcut(self):
78+
nlist = build_neighbor_list(
79+
self.ecoord,
80+
self.eatype,
81+
self.nloc,
82+
self.rcut * 2,
83+
40,
84+
distinguish_types=False,
85+
)
86+
# random shuffle
87+
shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2]))
88+
nlist = tnp.take(nlist, shuffle_idx, axis=2)
89+
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
90+
# we only need to ensure the result is correct, no need to check the order
91+
self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1))

0 commit comments

Comments
 (0)