Skip to content

Commit 1fbe4c1

Browse files
fix: merge get_np_precision to get_xp_precision (#4867)
`get_xp_precision` is more general and has more precisions (including bfloat16). <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Improved internal handling of precision settings for numerical operations, streamlining how precision is determined. No changes to user-facing functionality. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 98fb397 commit 1fbe4c1

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

deepmd/common.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
)
3636

3737
__all__ = [
38+
"GLOBAL_NP_FLOAT_PRECISION",
3839
"VALID_ACTIVATION",
3940
"VALID_PRECISION",
4041
"expand_sys_str",
@@ -249,16 +250,11 @@ def get_np_precision(precision: "_PRECISION") -> np.dtype:
249250
RuntimeError
250251
if string is invalid
251252
"""
252-
if precision == "default":
253-
return GLOBAL_NP_FLOAT_PRECISION
254-
elif precision == "float16":
255-
return np.float16
256-
elif precision == "float32":
257-
return np.float32
258-
elif precision == "float64":
259-
return np.float64
260-
else:
261-
raise RuntimeError(f"{precision} is not a valid precision")
253+
from deepmd.dpmodel.common import (
254+
get_xp_precision,
255+
)
256+
257+
return get_xp_precision(np, precision)
262258

263259

264260
def symlink_prefix_files(old_prefix: str, new_prefix: str) -> None:

0 commit comments

Comments
 (0)