Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions deepmd/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class JAXBackend(Backend):
# | Backend.Feature.NEIGHBOR_STAT
)
"""The features of the backend."""
suffixes: ClassVar[list[str]] = [".jax"]
suffixes: ClassVar[list[str]] = [".hlo", ".jax"]
"""The suffixes of the backend."""

def is_available(self) -> bool:
Expand Down Expand Up @@ -71,7 +71,11 @@ def deep_eval(self) -> type["DeepEvalBackend"]:
type[DeepEvalBackend]
The Deep Eval backend of the backend.
"""
raise NotImplementedError
from deepmd.jax.infer.deep_eval import (
DeepEval,
)

return DeepEval

@property
def neighbor_stat(self) -> type["NeighborStat"]:
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def call(
coord_ext, atype_ext, nlist, self.davg, self.dstd
)
nf, nloc, nnei, _ = rr.shape
sec = xp.asarray(self.sel_cumsum)
sec = self.sel_cumsum

ng = self.neuron[-1]
gr = xp.zeros([nf * nloc, ng, 4], dtype=self.dstd.dtype)
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def save_dp_model(filename: str, model_dict: dict) -> None:
# use UTC+0 time
"time": str(datetime.datetime.now(tz=datetime.timezone.utc)),
}
if filename_extension == ".dp":
if filename_extension in (".dp", ".hlo"):
variable_counter = Counter()
with h5py.File(filename, "w") as f:
model_dict = traverse_model_dict(
Expand Down Expand Up @@ -141,7 +141,7 @@ def load_dp_model(filename: str) -> dict:
The loaded model dict, including meta information.
"""
filename_extension = Path(filename).suffix
if filename_extension == ".dp":
if filename_extension in {".dp", ".hlo"}:
with h5py.File(filename, "r") as f:
model_dict = json.loads(f.attrs["json"])
model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy())
Expand Down
2 changes: 2 additions & 0 deletions deepmd/jax/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from flax import (
nnx,
)
from jax import export as jax_export

jax.config.update("jax_enable_x64", True)

__all__ = [
"jax",
"jnp",
"nnx",
"jax_export",
]
1 change: 1 addition & 0 deletions deepmd/jax/infer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
Loading