Skip to content
Merged
2 changes: 1 addition & 1 deletion deepmd/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class JAXBackend(Backend):
| Backend.Feature.NEIGHBOR_STAT
)
"""The features of the backend."""
suffixes: ClassVar[list[str]] = [".hlo", ".jax"]
suffixes: ClassVar[list[str]] = [".hlo", ".jax", ".savedmodel"]
"""The suffixes of the backend."""

def is_available(self) -> bool:
Expand Down
27 changes: 18 additions & 9 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,24 @@
self.output_def = output_def
self.model_path = model_file

model_data = load_dp_model(model_file)
self.dp = HLO(
stablehlo=model_data["@variables"]["stablehlo"].tobytes(),
stablehlo_atomic_virial=model_data["@variables"][
"stablehlo_atomic_virial"
].tobytes(),
model_def_script=model_data["model_def_script"],
**model_data["constants"],
)
if model_file.endswith(".hlo"):
model_data = load_dp_model(model_file)
self.dp = HLO(
stablehlo=model_data["@variables"]["stablehlo"].tobytes(),
stablehlo_atomic_virial=model_data["@variables"][
"stablehlo_atomic_virial"
].tobytes(),
model_def_script=model_data["model_def_script"],
**model_data["constants"],
)
elif model_file.endswith(".savedmodel"):
from deepmd.jax.jax2tf.tfmodel import (

Check warning on line 104 in deepmd/jax/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/infer/deep_eval.py#L103-L104

Added lines #L103 - L104 were not covered by tests
TFModelWrapper,
)

self.dp = TFModelWrapper(model_file)

Check warning on line 108 in deepmd/jax/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/infer/deep_eval.py#L108

Added line #L108 was not covered by tests
else:
raise ValueError("Unsupported file extension")

Check warning on line 110 in deepmd/jax/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/infer/deep_eval.py#L110

Added line #L110 was not covered by tests
self.rcut = self.dp.get_rcut()
self.type_map = self.dp.get_type_map()
if isinstance(auto_batch_size, bool):
Expand Down
11 changes: 11 additions & 0 deletions deepmd/jax/jax2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tensorflow as tf

Check warning on line 2 in deepmd/jax/jax2tf/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/__init__.py#L2

Added line #L2 was not covered by tests

if not tf.executing_eagerly():

Check warning on line 4 in deepmd/jax/jax2tf/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/__init__.py#L4

Added line #L4 was not covered by tests
# TF disallow temporary eager execution
raise RuntimeError(

Check warning on line 6 in deepmd/jax/jax2tf/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/__init__.py#L6

Added line #L6 was not covered by tests
"Unfortunatly, jax2tf (requires eager execution) cannot be used with the "
"TensorFlow backend (disables eager execution). "
"If you are converting a model between different backends, "
"considering converting to the `.dp` format first."
)
172 changes: 172 additions & 0 deletions deepmd/jax/jax2tf/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json

Check warning on line 2 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L2

Added line #L2 was not covered by tests

import tensorflow as tf
from jax.experimental import (

Check warning on line 5 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L4-L5

Added lines #L4 - L5 were not covered by tests
jax2tf,
)

from deepmd.jax.model.base_model import (

Check warning on line 9 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L9

Added line #L9 was not covered by tests
BaseModel,
)


def deserialize_to_file(model_file: str, data: dict) -> None:

Check warning on line 14 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L14

Added line #L14 was not covered by tests
"""Deserialize the dictionary to a model file.

Parameters
----------
model_file : str
The model file to be saved.
data : dict
The dictionary to be deserialized.
"""
if model_file.endswith(".savedmodel"):
model = BaseModel.deserialize(data["model"])
model_def_script = data["model_def_script"]
call_lower = model.call_lower

Check warning on line 27 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L24-L27

Added lines #L24 - L27 were not covered by tests

tf_model = tf.Module()

Check warning on line 29 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L29

Added line #L29 was not covered by tests

def exported_whether_do_atomic_virial(do_atomic_virial):
def call_lower_with_fixed_do_atomic_virial(

Check warning on line 32 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L31-L32

Added lines #L31 - L32 were not covered by tests
coord, atype, nlist, mapping, fparam, aparam
):
return call_lower(

Check warning on line 35 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L35

Added line #L35 was not covered by tests
coord,
atype,
nlist,
mapping,
fparam,
aparam,
do_atomic_virial=do_atomic_virial,
)

return jax2tf.convert(

Check warning on line 45 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L45

Added line #L45 was not covered by tests
call_lower_with_fixed_do_atomic_virial,
polymorphic_shapes=[
"(nf, nloc + nghost, 3)",
"(nf, nloc + nghost)",
f"(nf, nloc, {model.get_nnei()})",
"(nf, nloc + nghost)",
f"(nf, {model.get_dim_fparam()})",
f"(nf, nloc, {model.get_dim_aparam()})",
],
with_gradient=True,
)

# Save a function that can take scalar inputs.
# We need to explicit set the function name, so C++ can find it.
@tf.function(

Check warning on line 60 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L60

Added line #L60 was not covered by tests
autograph=False,
input_signature=[
tf.TensorSpec([None, None, 3], tf.float64),
tf.TensorSpec([None, None], tf.int32),
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
],
)
def call_lower_without_atomic_virial(

Check warning on line 71 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L71

Added line #L71 was not covered by tests
coord, atype, nlist, mapping, fparam, aparam
):
return exported_whether_do_atomic_virial(do_atomic_virial=False)(

Check warning on line 74 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L74

Added line #L74 was not covered by tests
coord, atype, nlist, mapping, fparam, aparam
)

tf_model.call_lower = call_lower_without_atomic_virial

Check warning on line 78 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L78

Added line #L78 was not covered by tests

@tf.function(

Check warning on line 80 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L80

Added line #L80 was not covered by tests
autograph=False,
input_signature=[
tf.TensorSpec([None, None, 3], tf.float64),
tf.TensorSpec([None, None], tf.int32),
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
],
)
def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam):
return exported_whether_do_atomic_virial(do_atomic_virial=True)(

Check warning on line 92 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L91-L92

Added lines #L91 - L92 were not covered by tests
coord, atype, nlist, mapping, fparam, aparam
)

tf_model.call_lower_atomic_virial = call_lower_with_atomic_virial

Check warning on line 96 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L96

Added line #L96 was not covered by tests

# set functions to export other attributes
@tf.function
def get_type_map():
return tf.constant(model.get_type_map(), dtype=tf.string)

Check warning on line 101 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L99-L101

Added lines #L99 - L101 were not covered by tests

tf_model.get_type_map = get_type_map

Check warning on line 103 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L103

Added line #L103 was not covered by tests

@tf.function
def get_rcut():
return tf.constant(model.get_rcut(), dtype=tf.double)

Check warning on line 107 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L105-L107

Added lines #L105 - L107 were not covered by tests

tf_model.get_rcut = get_rcut

Check warning on line 109 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L109

Added line #L109 was not covered by tests

@tf.function
def get_dim_fparam():
return tf.constant(model.get_dim_fparam(), dtype=tf.int64)

Check warning on line 113 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L111-L113

Added lines #L111 - L113 were not covered by tests

tf_model.get_dim_fparam = get_dim_fparam

Check warning on line 115 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L115

Added line #L115 was not covered by tests

@tf.function
def get_dim_aparam():
return tf.constant(model.get_dim_aparam(), dtype=tf.int64)

Check warning on line 119 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L117-L119

Added lines #L117 - L119 were not covered by tests

tf_model.get_dim_aparam = get_dim_aparam

Check warning on line 121 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L121

Added line #L121 was not covered by tests

@tf.function
def get_sel_type():
return tf.constant(model.get_sel_type(), dtype=tf.int64)

Check warning on line 125 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L123-L125

Added lines #L123 - L125 were not covered by tests

tf_model.get_sel_type = get_sel_type

Check warning on line 127 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L127

Added line #L127 was not covered by tests

@tf.function
def is_aparam_nall():
return tf.constant(model.is_aparam_nall(), dtype=tf.bool)

Check warning on line 131 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L129-L131

Added lines #L129 - L131 were not covered by tests

tf_model.is_aparam_nall = is_aparam_nall

Check warning on line 133 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L133

Added line #L133 was not covered by tests

@tf.function
def model_output_type():
return tf.constant(model.model_output_type(), dtype=tf.string)

Check warning on line 137 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L135-L137

Added lines #L135 - L137 were not covered by tests

tf_model.model_output_type = model_output_type

Check warning on line 139 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L139

Added line #L139 was not covered by tests

@tf.function
def mixed_types():
return tf.constant(model.mixed_types(), dtype=tf.bool)

Check warning on line 143 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L141-L143

Added lines #L141 - L143 were not covered by tests

tf_model.mixed_types = mixed_types

Check warning on line 145 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L145

Added line #L145 was not covered by tests

if model.get_min_nbor_dist() is not None:

Check warning on line 147 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L147

Added line #L147 was not covered by tests

@tf.function
def get_min_nbor_dist():
return tf.constant(model.get_min_nbor_dist(), dtype=tf.double)

Check warning on line 151 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L149-L151

Added lines #L149 - L151 were not covered by tests

tf_model.get_min_nbor_dist = get_min_nbor_dist

Check warning on line 153 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L153

Added line #L153 was not covered by tests

@tf.function
def get_sel():
return tf.constant(model.get_sel(), dtype=tf.int64)

Check warning on line 157 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L155-L157

Added lines #L155 - L157 were not covered by tests

tf_model.get_sel = get_sel

Check warning on line 159 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L159

Added line #L159 was not covered by tests

@tf.function
def get_model_def_script():
return tf.constant(

Check warning on line 163 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L161-L163

Added lines #L161 - L163 were not covered by tests
json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string
)

tf_model.get_model_def_script = get_model_def_script
tf.saved_model.save(

Check warning on line 168 in deepmd/jax/jax2tf/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/serialization.py#L167-L168

Added lines #L167 - L168 were not covered by tests
tf_model,
model_file,
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
Loading