Skip to content

Commit 3939786

Browse files
njzjzcoderabbitai[bot]pre-commit-ci[bot]
authored
feat(jax/array-api): dpa1 (#4160)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Updated method for converting input to NumPy arrays, enhancing performance and compatibility with array-like structures. - Simplified handling of weight, bias, and identity variables for improved compatibility with array backends. - Introduced new network classes and enhanced network management functionalities. - Added support for the new `array_api_strict` backend in testing. - **Bug Fixes** - Fixed serialization process to ensure accurate conversion of weights and biases. - **Tests** - Added tests to validate the new functionalities and ensure compatibility across various backends, including JAX and Array API Strict. - **Chores** - Continued improvements to project structure and dependencies for better maintainability. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9a15bc0 commit 3939786

File tree

29 files changed

+1022
-173
lines changed

29 files changed

+1022
-173
lines changed

deepmd/dpmodel/array_api.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
"""Utilities for the array API."""
33

4+
import array_api_compat
5+
46

57
def support_array_api(version: str) -> callable:
68
"""Mark a function as supporting the specific version of the array API.
@@ -27,3 +29,41 @@ def set_version(func: callable) -> callable:
2729
return func
2830

2931
return set_version
32+
33+
34+
# array api adds take_along_axis in https://github.com/data-apis/array-api/pull/816
35+
# but it hasn't been released yet
36+
# below is a pure Python implementation of take_along_axis
37+
# https://github.com/data-apis/array-api/issues/177#issuecomment-2093630595
38+
def xp_swapaxes(a, axis1, axis2):
39+
xp = array_api_compat.array_namespace(a)
40+
axes = list(range(a.ndim))
41+
axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
42+
a = xp.permute_dims(a, axes)
43+
return a
44+
45+
46+
def xp_take_along_axis(arr, indices, axis):
47+
xp = array_api_compat.array_namespace(arr)
48+
arr = xp_swapaxes(arr, axis, -1)
49+
indices = xp_swapaxes(indices, axis, -1)
50+
51+
m = arr.shape[-1]
52+
n = indices.shape[-1]
53+
54+
shape = list(arr.shape)
55+
shape.pop(-1)
56+
shape = [*shape, n]
57+
58+
arr = xp.reshape(arr, (-1,))
59+
if n != 0:
60+
indices = xp.reshape(indices, (-1, n))
61+
else:
62+
indices = xp.reshape(indices, (0, 0))
63+
64+
offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis]
65+
indices = xp.reshape(offset + indices, (-1,))
66+
67+
out = xp.take(arr, indices)
68+
out = xp.reshape(out, shape)
69+
return xp_swapaxes(out, axis, -1)

0 commit comments

Comments
 (0)