Skip to content

Commit 82aaa0d

Browse files
feat(jax): neighbor stat (#4258)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced `NeighborStat` and `NeighborStatOP` classes for enhanced neighbor statistics computation. - Added `AutoBatchSize` class to manage automatic batch sizing in deep learning applications. - **Improvements** - Enhanced `JAXBackend` functionality with implemented properties for neighbor statistics and serialization. - Refactored neighbor counting logic for better clarity and modularity. - **Tests** - Updated unit tests for `neighbor_stat` to support multiple backends (TensorFlow, PyTorch, NumPy, JAX). - Removed outdated test files to streamline testing processes. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent b647547 commit 82aaa0d

File tree

7 files changed

+210
-159
lines changed

7 files changed

+210
-159
lines changed

deepmd/backend/jax.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ class JAXBackend(Backend):
3333
"""The formal name of the backend."""
3434
features: ClassVar[Backend.Feature] = (
3535
Backend.Feature.IO
36-
# Backend.Feature.ENTRY_POINT
36+
| Backend.Feature.ENTRY_POINT
3737
# | Backend.Feature.DEEP_EVAL
38-
# | Backend.Feature.NEIGHBOR_STAT
38+
| Backend.Feature.NEIGHBOR_STAT
3939
)
4040
"""The features of the backend."""
4141
suffixes: ClassVar[list[str]] = [".jax"]
@@ -82,7 +82,11 @@ def neighbor_stat(self) -> type["NeighborStat"]:
8282
type[NeighborStat]
8383
The neighbor statistics of the backend.
8484
"""
85-
raise NotImplementedError
85+
from deepmd.jax.utils.neighbor_stat import (
86+
NeighborStat,
87+
)
88+
89+
return NeighborStat
8690

8791
@property
8892
def serialize_hook(self) -> Callable[[str], dict]:

deepmd/dpmodel/utils/neighbor_stat.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Optional,
77
)
88

9+
import array_api_compat
910
import numpy as np
1011

1112
from deepmd.dpmodel.common import (
@@ -68,42 +69,42 @@ def call(
6869
np.ndarray
6970
The maximal number of neighbors
7071
"""
72+
xp = array_api_compat.array_namespace(coord, atype)
7173
nframes = coord.shape[0]
72-
coord = coord.reshape(nframes, -1, 3)
74+
coord = xp.reshape(coord, (nframes, -1, 3))
7375
nloc = coord.shape[1]
74-
coord = coord.reshape(nframes, nloc * 3)
76+
coord = xp.reshape(coord, (nframes, nloc * 3))
7577
extend_coord, extend_atype, _ = extend_coord_with_ghosts(
7678
coord, atype, cell, self.rcut
7779
)
7880

79-
coord1 = extend_coord.reshape(nframes, -1)
81+
coord1 = xp.reshape(extend_coord, (nframes, -1))
8082
nall = coord1.shape[1] // 3
8183
coord0 = coord1[:, : nloc * 3]
8284
diff = (
83-
coord1.reshape([nframes, -1, 3])[:, None, :, :]
84-
- coord0.reshape([nframes, -1, 3])[:, :, None, :]
85+
xp.reshape(coord1, [nframes, -1, 3])[:, None, :, :]
86+
- xp.reshape(coord0, [nframes, -1, 3])[:, :, None, :]
8587
)
8688
assert list(diff.shape) == [nframes, nloc, nall, 3]
8789
# remove the diagonal elements
88-
mask = np.eye(nloc, nall, dtype=bool)
89-
diff[:, mask] = np.inf
90-
rr2 = np.sum(np.square(diff), axis=-1)
91-
min_rr2 = np.min(rr2, axis=-1)
90+
mask = xp.eye(nloc, nall, dtype=xp.bool)
91+
mask = xp.tile(mask[None, :, :, None], (nframes, 1, 1, 3))
92+
diff = xp.where(mask, xp.full_like(diff, xp.inf), diff)
93+
rr2 = xp.sum(xp.square(diff), axis=-1)
94+
min_rr2 = xp.min(rr2, axis=-1)
9295
# count the number of neighbors
9396
if not self.mixed_types:
9497
mask = rr2 < self.rcut**2
95-
nnei = np.zeros((nframes, nloc, self.ntypes), dtype=int)
98+
nneis = []
9699
for ii in range(self.ntypes):
97-
nnei[:, :, ii] = np.sum(
98-
mask & (extend_atype == ii)[:, None, :], axis=-1
99-
)
100+
nneis.append(xp.sum(mask & (extend_atype == ii)[:, None, :], axis=-1))
101+
nnei = xp.stack(nneis, axis=-1)
100102
else:
101103
mask = rr2 < self.rcut**2
102104
# virtual type (<0) are not counted
103-
nnei = np.sum(mask & (extend_atype >= 0)[:, None, :], axis=-1).reshape(
104-
nframes, nloc, 1
105-
)
106-
max_nnei = np.max(nnei, axis=1)
105+
nnei = xp.sum(mask & (extend_atype >= 0)[:, None, :], axis=-1)
106+
nnei = xp.reshape(nnei, (nframes, nloc, 1))
107+
max_nnei = xp.max(nnei, axis=1)
107108
return min_rr2, max_nnei
108109

109110

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
3+
import jaxlib
4+
5+
from deepmd.jax.env import (
6+
jax,
7+
)
8+
from deepmd.utils.batch_size import AutoBatchSize as AutoBatchSizeBase
9+
10+
11+
class AutoBatchSize(AutoBatchSizeBase):
12+
"""Auto batch size.
13+
14+
Parameters
15+
----------
16+
initial_batch_size : int, default: 1024
17+
initial batch size (number of total atoms) when DP_INFER_BATCH_SIZE
18+
is not set
19+
factor : float, default: 2.
20+
increased factor
21+
22+
"""
23+
24+
def __init__(
25+
self,
26+
initial_batch_size: int = 1024,
27+
factor: float = 2.0,
28+
):
29+
super().__init__(
30+
initial_batch_size=initial_batch_size,
31+
factor=factor,
32+
)
33+
34+
def is_gpu_available(self) -> bool:
35+
"""Check if GPU is available.
36+
37+
Returns
38+
-------
39+
bool
40+
True if GPU is available
41+
"""
42+
return jax.devices()[0].platform == "gpu"
43+
44+
def is_oom_error(self, e: Exception) -> bool:
45+
"""Check if the exception is an OOM error.
46+
47+
Parameters
48+
----------
49+
e : Exception
50+
Exception
51+
"""
52+
# several sources think CUSOLVER_STATUS_INTERNAL_ERROR is another out-of-memory error,
53+
# such as https://github.com/JuliaGPU/CUDA.jl/issues/1924
54+
# (the meaningless error message should be considered as a bug in cusolver)
55+
if isinstance(e, (jaxlib.xla_extension.XlaRuntimeError, ValueError)) and (
56+
"RESOURCE_EXHAUSTED:" in e.args[0]
57+
):
58+
return True
59+
return False

deepmd/jax/utils/neighbor_stat.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from collections.abc import (
3+
Iterator,
4+
)
5+
from typing import (
6+
Optional,
7+
)
8+
9+
import numpy as np
10+
11+
from deepmd.dpmodel.common import (
12+
to_numpy_array,
13+
)
14+
from deepmd.dpmodel.utils.neighbor_stat import (
15+
NeighborStatOP,
16+
)
17+
from deepmd.jax.common import (
18+
to_jax_array,
19+
)
20+
from deepmd.jax.utils.auto_batch_size import (
21+
AutoBatchSize,
22+
)
23+
from deepmd.utils.data_system import (
24+
DeepmdDataSystem,
25+
)
26+
from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat
27+
28+
29+
class NeighborStat(BaseNeighborStat):
30+
"""Neighbor statistics using JAX.
31+
32+
Parameters
33+
----------
34+
ntypes : int
35+
The num of atom types
36+
rcut : float
37+
The cut-off radius
38+
mixed_type : bool, optional, default=False
39+
Treat all types as a single type.
40+
"""
41+
42+
def __init__(
43+
self,
44+
ntypes: int,
45+
rcut: float,
46+
mixed_type: bool = False,
47+
) -> None:
48+
super().__init__(ntypes, rcut, mixed_type)
49+
self.op = NeighborStatOP(ntypes, rcut, mixed_type)
50+
self.auto_batch_size = AutoBatchSize()
51+
52+
def iterator(
53+
self, data: DeepmdDataSystem
54+
) -> Iterator[tuple[np.ndarray, float, str]]:
55+
"""Iterator method for producing neighbor statistics data.
56+
57+
Yields
58+
------
59+
np.ndarray
60+
The maximal number of neighbors
61+
float
62+
The squared minimal distance between two atoms
63+
str
64+
The directory of the data system
65+
"""
66+
for ii in range(len(data.system_dirs)):
67+
for jj in data.data_systems[ii].dirs:
68+
data_set = data.data_systems[ii]
69+
data_set_data = data_set._load_set(jj)
70+
minrr2, max_nnei = self.auto_batch_size.execute_all(
71+
self._execute,
72+
data_set_data["coord"].shape[0],
73+
data_set.get_natoms(),
74+
data_set_data["coord"],
75+
data_set_data["type"],
76+
data_set_data["box"] if data_set.pbc else None,
77+
)
78+
yield np.max(max_nnei, axis=0), np.min(minrr2), jj
79+
80+
def _execute(
81+
self,
82+
coord: np.ndarray,
83+
atype: np.ndarray,
84+
cell: Optional[np.ndarray],
85+
):
86+
"""Execute the operation.
87+
88+
Parameters
89+
----------
90+
coord
91+
The coordinates of atoms.
92+
atype
93+
The atom types.
94+
cell
95+
The cell.
96+
"""
97+
minrr2, max_nnei = self.op(
98+
to_jax_array(coord),
99+
to_jax_array(atype),
100+
to_jax_array(cell),
101+
)
102+
minrr2 = to_numpy_array(minrr2)
103+
max_nnei = to_numpy_array(max_nnei)
104+
return minrr2, max_nnei

source/tests/common/dpmodel/test_neighbor_stat.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

source/tests/pt/test_neighbor_stat.py renamed to source/tests/consistent/test_neighbor_stat.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
from ..seed import (
1313
GLOBAL_SEED,
1414
)
15+
from .common import (
16+
INSTALLED_JAX,
17+
INSTALLED_PT,
18+
INSTALLED_TF,
19+
)
1520

1621

1722
def gen_sys(nframes):
@@ -42,7 +47,7 @@ def setUp(self):
4247
def tearDown(self):
4348
shutil.rmtree("system_0")
4449

45-
def test_neighbor_stat(self):
50+
def run_neighbor_stat(self, backend):
4651
for rcut in (0.0, 1.0, 2.0, 4.0):
4752
for mixed_type in (True, False):
4853
with self.subTest(rcut=rcut, mixed_type=mixed_type):
@@ -52,7 +57,7 @@ def test_neighbor_stat(self):
5257
rcut=rcut,
5358
type_map=["TYPE", "NO_THIS_TYPE"],
5459
mixed_type=mixed_type,
55-
backend="pytorch",
60+
backend=backend,
5661
)
5762
upper = np.ceil(rcut) + 1
5863
X, Y, Z = np.mgrid[-upper:upper, -upper:upper, -upper:upper]
@@ -67,3 +72,18 @@ def test_neighbor_stat(self):
6772
if not mixed_type:
6873
ret.append(0)
6974
np.testing.assert_array_equal(max_nbor_size, ret)
75+
76+
@unittest.skipUnless(INSTALLED_TF, "tensorflow is not installed")
77+
def test_neighbor_stat_tf(self):
78+
self.run_neighbor_stat("tensorflow")
79+
80+
@unittest.skipUnless(INSTALLED_PT, "pytorch is not installed")
81+
def test_neighbor_stat_pt(self):
82+
self.run_neighbor_stat("pytorch")
83+
84+
def test_neighbor_stat_dp(self):
85+
self.run_neighbor_stat("numpy")
86+
87+
@unittest.skipUnless(INSTALLED_JAX, "jax is not installed")
88+
def test_neighbor_stat_jax(self):
89+
self.run_neighbor_stat("jax")

0 commit comments

Comments
 (0)