diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py
index e5c0557851..05e2dc9d61 100644
--- a/deepmd/dpmodel/array_api.py
+++ b/deepmd/dpmodel/array_api.py
@@ -2,6 +2,8 @@
 """Utilities for the array API."""
 
 import array_api_compat
+import array_api_extra as xpx
+import numpy as np
 from packaging.version import (
     Version,
 )
@@ -48,7 +50,11 @@ def xp_swapaxes(a, axis1, axis2):
 
 def xp_take_along_axis(arr, indices, axis):
     xp = array_api_compat.array_namespace(arr)
-    if Version(xp.__array_api_version__) >= Version("2024.12"):
+    if (
+        Version(xp.__array_api_version__) >= Version("2024.12")
+        or array_api_compat.is_numpy_array(arr)
+        or array_api_compat.is_jax_array(arr)
+    ):
         # see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39
         return xp.take_along_axis(arr, indices, axis=axis)
     arr = xp_swapaxes(arr, axis, -1)
@@ -73,3 +79,35 @@ def xp_take_along_axis(arr, indices, axis):
     out = xp.take(arr, indices)
     out = xp.reshape(out, shape)
     return xp_swapaxes(out, axis, -1)
+
+
+def xp_ravel(input: np.ndarray) -> np.ndarray:
+    """Flattens the input tensor."""
+    xp = array_api_compat.array_namespace(input)
+    return xp.reshape(input, (-1,))
+
+
+def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray:
+    """Reduces all values from the src tensor to the indices specified in the index tensor.
+
+    Values in ``src`` whose ``index`` entries address the same element of
+    ``input`` along ``dim`` are summed together.
+    """
+    xp = array_api_compat.array_namespace(input)
+    idx = xp.arange(input.size, dtype=xp.int64)
+    idx = xp.reshape(idx, input.shape)
+    new_idx = xp_take_along_axis(idx, index, axis=dim)
+    new_idx = xp_ravel(new_idx)
+    shape = input.shape
+    input = xp_ravel(input)
+    src = xp_ravel(src)
+    if array_api_compat.is_numpy_array(input):
+        # NumPy fancy-index assignment keeps only one update per repeated
+        # index; np.add.at accumulates every one of them.
+        input = np.array(input)
+        np.add.at(input, new_idx, src)
+    else:
+        # JAX's .at[].add accumulates repeated indices; other backends reached
+        # through array-api-extra may not (see its documentation).
+        input = xpx.at(input, new_idx).add(src)
+    return xp.reshape(input, shape)
diff --git a/deepmd/dpmodel/model/transform_output.py b/deepmd/dpmodel/model/transform_output.py
index af1429ce25..f0081c6d65 100644
--- a/deepmd/dpmodel/model/transform_output.py
+++ b/deepmd/dpmodel/model/transform_output.py
@@ -3,6 +3,9 @@
 import array_api_compat
 import numpy as np
 
+from deepmd.dpmodel.array_api import (
+    xp_scatter_sum,
+)
 from deepmd.dpmodel.common import (
     GLOBAL_ENER_FLOAT_PRECISION,
 )
@@ -98,20 +101,12 @@ def communicate_extended_output(
                     mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
                     mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
                     force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype)
-                    # jax only
-                    if array_api_compat.is_jax_array(force):
-                        from deepmd.jax.common import (
-                            scatter_sum,
-                        )
-
-                        force = scatter_sum(
-                            force,
-                            1,
-                            mapping,
-                            model_ret[kk_derv_r],
-                        )
-                    else:
-                        raise NotImplementedError("Only JAX arrays are supported.")
+                    force = xp_scatter_sum(
+                        force,
+                        1,
+                        mapping,
+                        model_ret[kk_derv_r],
+                    )
                     new_ret[kk_derv_r] = force
                 else:
                     # name holders
@@ -127,20 +122,12 @@ def communicate_extended_output(
                         vldims + derv_c_ext_dims,
                         dtype=vv.dtype,
                     )
-                    # jax only
-                    if array_api_compat.is_jax_array(virial):
-                        from deepmd.jax.common import (
-                            scatter_sum,
-                        )
-
-                        virial = scatter_sum(
-                            virial,
-                            1,
-                            mapping,
-                            model_ret[kk_derv_c],
-                        )
-                    else:
-                        raise NotImplementedError("Only JAX arrays are supported.")
+                    virial = xp_scatter_sum(
+                        virial,
+                        1,
+                        mapping,
+                        model_ret[kk_derv_c],
+                    )
                     new_ret[kk_derv_c] = virial
                     new_ret[kk_derv_c + "_redu"] = xp.sum(new_ret[kk_derv_c], axis=1)
                 else:
diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py
index 59f36d11ad..f372e97eb5 100644
--- a/deepmd/jax/common.py
+++ b/deepmd/jax/common.py
@@ -95,13 +95,3 @@ def __dlpack__(self, *args, **kwargs):
 
     def __dlpack_device__(self, *args, **kwargs):
         return self.value.__dlpack_device__(*args, **kwargs)
-
-
-def scatter_sum(input, dim, index: jnp.ndarray, src: jnp.ndarray) -> jnp.ndarray:
-    """Reduces all values from the src tensor to the indices specified in the index tensor."""
-    idx = jnp.arange(input.size, dtype=jnp.int64).reshape(input.shape)
-    new_idx = jnp.take_along_axis(idx, index, axis=dim).ravel()
-    shape = input.shape
-    input = input.ravel()
-    input = input.at[new_idx].add(src.ravel())
-    return input.reshape(shape)
diff --git a/pyproject.toml b/pyproject.toml
index fd0c76839b..dbcf4eaab1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,6 +53,7 @@ dependencies = [
     'ml_dtypes',
     'mendeleev',
     'array-api-compat',
+    'array-api-extra>=0.5.0',
 ]
 requires-python = ">=3.9"
 keywords = ["deepmd"]