33 Optional ,
44)
55
6+ import array_api_compat
67import numpy as np
78
89from deepmd .dpmodel .atomic_model .base_atomic_model import (
@@ -75,7 +76,8 @@ def __init__(
7576 else :
7677 self .atomic_model : T_AtomicModel = T_AtomicModel (* args , ** kwargs )
7778 self .precision_dict = PRECISION_DICT
78- self .reverse_precision_dict = RESERVED_PRECISON_DICT
79+ # not supported by flax
80+ # self.reverse_precision_dict = RESERVED_PRECISON_DICT
7981 self .global_np_float_precision = GLOBAL_NP_FLOAT_PRECISION
8082 self .global_ener_float_precision = GLOBAL_ENER_FLOAT_PRECISION
8183
@@ -253,9 +255,7 @@ def input_type_cast(
253255 str ,
254256 ]:
255257 """Cast the input data to global float type."""
256- input_prec = self .reverse_precision_dict [
257- self .precision_dict [coord .dtype .name ]
258- ]
258+ input_prec = RESERVED_PRECISON_DICT [self .precision_dict [coord .dtype .name ]]
259259 ###
260260 ### type checking would not pass jit, convert to coord prec anyway
261261 ###
@@ -264,10 +264,7 @@ def input_type_cast(
264264 for vv in [box , fparam , aparam ]
265265 ]
266266 box , fparam , aparam = _lst
267- if (
268- input_prec
269- == self .reverse_precision_dict [self .global_np_float_precision ]
270- ):
267+ if input_prec == RESERVED_PRECISON_DICT [self .global_np_float_precision ]:
271268 return coord , box , fparam , aparam , input_prec
272269 else :
273270 pp = self .global_np_float_precision
@@ -286,8 +283,7 @@ def output_type_cast(
286283 ) -> dict [str , np .ndarray ]:
287284 """Convert the model output to the input prec."""
288285 do_cast = (
289- input_prec
290- != self .reverse_precision_dict [self .global_np_float_precision ]
286+ input_prec != RESERVED_PRECISON_DICT [self .global_np_float_precision ]
291287 )
292288 pp = self .precision_dict [input_prec ]
293289 odef = self .model_output_def ()
@@ -366,17 +362,18 @@ def _format_nlist(
366362 nnei : int ,
367363 extra_nlist_sort : bool = False ,
368364 ):
365+ xp = array_api_compat .array_namespace (extended_coord , nlist )
369366 n_nf , n_nloc , n_nnei = nlist .shape
370367 extended_coord = extended_coord .reshape ([n_nf , - 1 , 3 ])
371368 nall = extended_coord .shape [1 ]
372369 rcut = self .get_rcut ()
373370
374371 if n_nnei < nnei :
375372 # make a copy before revise
376- ret = np . concatenate (
373+ ret = xp . concat (
377374 [
378375 nlist ,
379- - 1 * np .ones ([n_nf , n_nloc , nnei - n_nnei ], dtype = nlist .dtype ),
376+ - 1 * xp .ones ([n_nf , n_nloc , nnei - n_nnei ], dtype = nlist .dtype ),
380377 ],
381378 axis = - 1 ,
382379 )
@@ -385,16 +382,16 @@ def _format_nlist(
385382 n_nf , n_nloc , n_nnei = nlist .shape
386383 # make a copy before revise
387384 m_real_nei = nlist >= 0
388- ret = np .where (m_real_nei , nlist , 0 )
385+ ret = xp .where (m_real_nei , nlist , 0 )
389386 coord0 = extended_coord [:, :n_nloc , :]
390387 index = ret .reshape (n_nf , n_nloc * n_nnei , 1 ).repeat (3 , axis = 2 )
391- coord1 = np .take_along_axis (extended_coord , index , axis = 1 )
388+ coord1 = xp .take_along_axis (extended_coord , index , axis = 1 )
392389 coord1 = coord1 .reshape (n_nf , n_nloc , n_nnei , 3 )
393- rr = np .linalg .norm (coord0 [:, :, None , :] - coord1 , axis = - 1 )
394- rr = np .where (m_real_nei , rr , float ("inf" ))
395- rr , ret_mapping = np .sort (rr , axis = - 1 ), np .argsort (rr , axis = - 1 )
396- ret = np .take_along_axis (ret , ret_mapping , axis = 2 )
397- ret = np .where (rr > rcut , - 1 , ret )
390+ rr = xp .linalg .norm (coord0 [:, :, None , :] - coord1 , axis = - 1 )
391+ rr = xp .where (m_real_nei , rr , float ("inf" ))
392+ rr , ret_mapping = xp .sort (rr , axis = - 1 ), xp .argsort (rr , axis = - 1 )
393+ ret = xp .take_along_axis (ret , ret_mapping , axis = 2 )
394+ ret = xp .where (rr > rcut , - 1 , ret )
398395 ret = ret [..., :nnei ]
399396 # not extra_nlist_sort and n_nnei <= nnei:
400397 elif n_nnei == nnei :
0 commit comments