Skip to content

Commit 6e815a2

Browse files
njzjz, wanghan-iapcm, coderabbitai[bot]
authored
fix(dpmodel): fix precision (#4343)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Introduced a new environment variable `DP_DTYPE_PROMOTION_STRICT` to enhance precision handling in TensorFlow tests.
  - Added a decorator `@cast_precision` to several descriptor classes, improving precision management during computations.
  - Updated JAX configuration to enable strict dtype promotion based on the new environment variable.
  - Enhanced serialization and deserialization processes to include precision attributes across multiple classes.
- **Bug Fixes**
  - Enhanced type handling and input processing in the `GeneralFitting` class for better output predictions.
  - Improved handling of atomic contributions and exclusions in the `BaseAtomicModel` class.
  - Addressed potential type mismatches during matrix operations in the `NativeLayer` class.
- **Chores**
  - Updated caching mechanisms in the testing workflow to ensure unique keys based on run parameters.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent 058e066 commit 6e815a2

File tree

16 files changed

+175
-34
lines changed

16 files changed

+175
-34
lines changed

.github/workflows/test_python.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ jobs:
6262
env:
6363
NUM_WORKERS: 0
6464
DP_TEST_TF2_ONLY: 1
65+
DP_DTYPE_PROMOTION_STRICT: 1
6566
if: matrix.group == 1
6667
- run: mv .test_durations .test_durations_${{ matrix.group }}
6768
- name: Upload partial durations

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,18 +201,19 @@ def forward_common_atomic(
201201
ret_dict = self.apply_out_stat(ret_dict, atype)
202202

203203
# nf x nloc
204-
atom_mask = ext_atom_mask[:, :nloc].astype(xp.int32)
204+
atom_mask = ext_atom_mask[:, :nloc]
205205
if self.atom_excl is not None:
206-
atom_mask *= self.atom_excl.build_type_exclude_mask(atype)
206+
atom_mask = xp.logical_and(
207+
atom_mask, self.atom_excl.build_type_exclude_mask(atype)
208+
)
207209

208210
for kk in ret_dict.keys():
209211
out_shape = ret_dict[kk].shape
210212
out_shape2 = math.prod(out_shape[2:])
211-
ret_dict[kk] = (
212-
ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
213-
* atom_mask[:, :, None]
214-
).reshape(out_shape)
215-
ret_dict["mask"] = atom_mask
213+
tmp_arr = ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
214+
tmp_arr = xp.where(atom_mask[:, :, None], tmp_arr, xp.zeros_like(tmp_arr))
215+
ret_dict[kk] = xp.reshape(tmp_arr, out_shape)
216+
ret_dict["mask"] = xp.astype(atom_mask, xp.int32)
216217

217218
return ret_dict
218219

deepmd/dpmodel/common.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,14 @@
33
ABC,
44
abstractmethod,
55
)
6+
from functools import (
7+
wraps,
8+
)
69
from typing import (
710
Any,
11+
Callable,
812
Optional,
13+
overload,
914
)
1015

1116
import array_api_compat
@@ -116,6 +121,105 @@ def to_numpy_array(x: Any) -> Optional[np.ndarray]:
116121
return np.from_dlpack(x)
117122

118123

124+
def cast_precision(func: Callable[..., Any]) -> Callable[..., Any]:
125+
"""A decorator that casts and casts back the input
126+
and output tensor of a method.
127+
128+
The decorator should be used on an instance method.
129+
130+
The decorator will do the following thing:
131+
(1) It casts input arrays from the global precision
132+
to precision defined by property `precision`.
133+
(2) It casts output arrays from `precision` to
134+
the global precision.
135+
(3) It checks inputs and outputs and only casts when
136+
input or output is an array and its dtype matches
137+
the global precision and `precision`, respectively.
138+
If it does not match (e.g. it is an integer), the decorator
139+
will do nothing on it.
140+
141+
The decorator supports the array API.
142+
143+
Returns
144+
-------
145+
Callable
146+
a decorator that casts and casts back the input and
147+
output array of a method
148+
149+
Examples
150+
--------
151+
>>> class A:
152+
... def __init__(self):
153+
... self.precision = "float32"
154+
...
155+
... @cast_precision
156+
... def f(x: Array, y: Array) -> Array:
157+
... return x**2 + y
158+
"""
159+
160+
@wraps(func)
161+
def wrapper(self, *args, **kwargs):
162+
# only convert tensors
163+
returned_tensor = func(
164+
self,
165+
*[safe_cast_array(vv, "global", self.precision) for vv in args],
166+
**{
167+
kk: safe_cast_array(vv, "global", self.precision)
168+
for kk, vv in kwargs.items()
169+
},
170+
)
171+
if isinstance(returned_tensor, tuple):
172+
return tuple(
173+
safe_cast_array(vv, self.precision, "global") for vv in returned_tensor
174+
)
175+
elif isinstance(returned_tensor, dict):
176+
return {
177+
kk: safe_cast_array(vv, self.precision, "global")
178+
for kk, vv in returned_tensor.items()
179+
}
180+
else:
181+
return safe_cast_array(returned_tensor, self.precision, "global")
182+
183+
return wrapper
184+
185+
186+
@overload
187+
def safe_cast_array(
188+
input: np.ndarray, from_precision: str, to_precision: str
189+
) -> np.ndarray: ...
190+
@overload
191+
def safe_cast_array(input: None, from_precision: str, to_precision: str) -> None: ...
192+
def safe_cast_array(
193+
input: Optional[np.ndarray], from_precision: str, to_precision: str
194+
) -> Optional[np.ndarray]:
195+
"""Convert an array from a precision to another precision.
196+
197+
If input is not an array or without the specific precision, the method will not
198+
cast it.
199+
200+
Array API is supported.
201+
202+
Parameters
203+
----------
204+
input : np.ndarray or None
205+
Input array
206+
from_precision : str
207+
Array data type that is casted from
208+
to_precision : str
209+
Array data type that casts to
210+
211+
Returns
212+
-------
213+
np.ndarray or None
214+
casted array
215+
"""
216+
if array_api_compat.is_array_api_obj(input):
217+
xp = array_api_compat.array_namespace(input)
218+
if input.dtype == get_xp_precision(xp, from_precision):
219+
return xp.astype(input, get_xp_precision(xp, to_precision))
220+
return input
221+
222+
119223
__all__ = [
120224
"GLOBAL_NP_FLOAT_PRECISION",
121225
"GLOBAL_ENER_FLOAT_PRECISION",

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
xp_take_along_axis,
2121
)
2222
from deepmd.dpmodel.common import (
23+
cast_precision,
2324
to_numpy_array,
2425
)
2526
from deepmd.dpmodel.utils import (
@@ -330,6 +331,7 @@ def __init__(
330331
self.tebd_dim = tebd_dim
331332
self.concat_output_tebd = concat_output_tebd
332333
self.trainable = trainable
334+
self.precision = precision
333335

334336
def get_rcut(self) -> float:
335337
"""Returns the cut-off radius."""
@@ -451,6 +453,7 @@ def change_type_map(
451453
obj["davg"] = obj["davg"][remap_index]
452454
obj["dstd"] = obj["dstd"][remap_index]
453455

456+
@cast_precision
454457
def call(
455458
self,
456459
coord_ext,

deepmd/dpmodel/descriptor/dpa2.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
xp_take_along_axis,
1616
)
1717
from deepmd.dpmodel.common import (
18+
cast_precision,
1819
to_numpy_array,
1920
)
2021
from deepmd.dpmodel.utils import (
@@ -595,6 +596,7 @@ def init_subclass_params(sub_data, sub_class):
595596
self.rcut = self.repinit.get_rcut()
596597
self.ntypes = ntypes
597598
self.sel = self.repinit.sel
599+
self.precision = precision
598600

599601
def get_rcut(self) -> float:
600602
"""Returns the cut-off radius."""
@@ -760,6 +762,7 @@ def get_stat_mean_and_stddev(self) -> tuple[list[np.ndarray], list[np.ndarray]]:
760762
stddev_list.append(self.repinit_three_body.stddev)
761763
return mean_list, stddev_list
762764

765+
@cast_precision
763766
def call(
764767
self,
765768
coord_ext: np.ndarray,

deepmd/dpmodel/descriptor/se_e2_a.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
NativeOP,
1717
)
1818
from deepmd.dpmodel.common import (
19+
cast_precision,
1920
to_numpy_array,
2021
)
2122
from deepmd.dpmodel.utils import (
@@ -30,9 +31,6 @@
3031
from deepmd.dpmodel.utils.update_sel import (
3132
UpdateSel,
3233
)
33-
from deepmd.env import (
34-
GLOBAL_NP_FLOAT_PRECISION,
35-
)
3634
from deepmd.utils.data_system import (
3735
DeepmdDataSystem,
3836
)
@@ -343,6 +341,7 @@ def reinit_exclude(
343341
self.exclude_types = exclude_types
344342
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)
345343

344+
@cast_precision
346345
def call(
347346
self,
348347
coord_ext,
@@ -418,9 +417,7 @@ def call(
418417
# nf x nloc x ng x ng1
419418
grrg = np.einsum("flid,fljd->flij", gr, gr1)
420419
# nf x nloc x (ng x ng1)
421-
grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron).astype(
422-
GLOBAL_NP_FLOAT_PRECISION
423-
)
420+
grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron)
424421
return grrg, gr[..., 1:], None, None, ww
425422

426423
def serialize(self) -> dict:
@@ -509,6 +506,7 @@ def update_sel(
509506

510507

511508
class DescrptSeAArrayAPI(DescrptSeA):
509+
@cast_precision
512510
def call(
513511
self,
514512
coord_ext,
@@ -588,7 +586,5 @@ def call(
588586
# grrg = xp.einsum("flid,fljd->flij", gr, gr1)
589587
grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)
590588
# nf x nloc x (ng x ng1)
591-
grrg = xp.astype(
592-
xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)), input_dtype
593-
)
589+
grrg = xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron))
594590
return grrg, gr[..., 1:], None, None, ww

deepmd/dpmodel/descriptor/se_r.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
NativeOP,
1616
)
1717
from deepmd.dpmodel.common import (
18+
cast_precision,
1819
get_xp_precision,
1920
to_numpy_array,
2021
)
@@ -292,6 +293,7 @@ def cal_g(
292293
gg = self.embeddings[(ll,)].call(ss)
293294
return gg
294295

296+
@cast_precision
295297
def call(
296298
self,
297299
coord_ext,
@@ -355,7 +357,6 @@ def call(
355357
res_rescale = 1.0 / 5.0
356358
res = xyz_scatter * res_rescale
357359
res = xp.reshape(res, (nf, nloc, ng))
358-
res = xp.astype(res, get_xp_precision(xp, "global"))
359360
return res, None, None, None, ww
360361

361362
def serialize(self) -> dict:

deepmd/dpmodel/descriptor/se_t.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
NativeOP,
1616
)
1717
from deepmd.dpmodel.common import (
18+
cast_precision,
1819
get_xp_precision,
1920
to_numpy_array,
2021
)
@@ -267,6 +268,7 @@ def reinit_exclude(
267268
self.exclude_types = exclude_types
268269
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)
269270

271+
@cast_precision
270272
def call(
271273
self,
272274
coord_ext,
@@ -320,7 +322,6 @@ def call(
320322
# we don't require atype is the same in all frames
321323
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
322324
rr = xp.reshape(rr, (nf * nloc, nnei, 4))
323-
rr = xp.astype(rr, get_xp_precision(xp, self.precision))
324325

325326
for embedding_idx in itertools.product(
326327
range(self.ntypes), repeat=self.embeddings.ndim
@@ -352,7 +353,6 @@ def call(
352353
result += res_ij
353354
# nf x nloc x ng
354355
result = xp.reshape(result, (nf, nloc, ng))
355-
result = xp.astype(result, get_xp_precision(xp, "global"))
356356
return result, None, None, None, ww
357357

358358
def serialize(self) -> dict:

deepmd/dpmodel/descriptor/se_t_tebd.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
xp_take_along_axis,
1818
)
1919
from deepmd.dpmodel.common import (
20-
get_xp_precision,
20+
cast_precision,
2121
to_numpy_array,
2222
)
2323
from deepmd.dpmodel.utils import (
@@ -169,6 +169,7 @@ def __init__(
169169
self.tebd_dim = tebd_dim
170170
self.concat_output_tebd = concat_output_tebd
171171
self.trainable = trainable
172+
self.precision = precision
172173

173174
def get_rcut(self) -> float:
174175
"""Returns the cut-off radius."""
@@ -290,6 +291,7 @@ def change_type_map(
290291
obj["davg"] = obj["davg"][remap_index]
291292
obj["dstd"] = obj["dstd"][remap_index]
292293

294+
@cast_precision
293295
def call(
294296
self,
295297
coord_ext,
@@ -744,7 +746,6 @@ def call(
744746
res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei))
745747
# nf x nl x ng
746748
result = xp.reshape(res_ij, (nf, nloc, self.filter_neuron[-1]))
747-
result = xp.astype(result, get_xp_precision(xp, "global"))
748749
return (
749750
result,
750751
None,

deepmd/dpmodel/fitting/dipole_fitting.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
from deepmd.dpmodel import (
1212
DEFAULT_PRECISION,
1313
)
14+
from deepmd.dpmodel.common import (
15+
cast_precision,
16+
)
1417
from deepmd.dpmodel.fitting.base_fitting import (
1518
BaseFitting,
1619
)
@@ -174,6 +177,7 @@ def output_def(self):
174177
]
175178
)
176179

180+
@cast_precision
177181
def call(
178182
self,
179183
descriptor: np.ndarray,

0 commit comments

Comments
 (0)