@@ -298,7 +298,6 @@ def nlist_distinguish_types(
298298 tmp_atype ,
299299 axis = 2 ,
300300 indices = nlist .masked_fill (mask , 0 ),
301- broadcast = False ,
302301 )
303302 tnlist = tnlist .masked_fill (mask , - 1 )
304303 snsel = tnlist .shape [2 ]
@@ -312,7 +311,11 @@ def nlist_distinguish_types(
312311 paddle .argsort (pick_mask , axis = - 1 , descending = True , stable = True ),
313312 )
314313 # nloc x s(nsel)
315- inlist = paddle .take_along_axis (nlist , axis = 2 , indices = imap , broadcast = False )
314+ inlist = paddle .take_along_axis (
315+ nlist ,
316+ axis = 2 ,
317+ indices = imap ,
318+ )
316319 inlist = inlist .masked_fill (~ (pick_mask .to (paddle .bool )), - 1 )
317320 # nloc x nsel[ii]
318321 ret_nlist .append (paddle .split (inlist , [ss , snsel - ss ], axis = - 1 )[0 ])
@@ -394,7 +397,9 @@ def build_multiple_neighbor_list(
394397 )
395398 # nb x nloc x nsel x 3
396399 coord2 = paddle .take_along_axis (
397- coord1 , axis = 1 , indices = index , broadcast = False
400+ coord1 ,
401+ axis = 1 ,
402+ indices = index ,
398403 ).reshape ([nb , nloc , nsel , 3 ])
399404 # nb x nloc x nsel x 3
400405 diff = coord2 - coord0 [:, :, None , :]
@@ -472,27 +477,27 @@ def extend_coord_with_ghosts(
472477 nbuff = paddle .amax (nbuff , axis = 0 )
473478 nbuff_cpu = nbuff .cpu ()
474479 xi = (
475- paddle .arange (
476- - nbuff_cpu [ 0 ], nbuff_cpu [ 0 ] + 1 , 1 , dtype = env .GLOBAL_PD_FLOAT_PRECISION
480+ paddle .arange (- nbuff_cpu [ 0 ], nbuff_cpu [ 0 ] + 1 , 1 ). to (
481+ dtype = env .GLOBAL_PD_FLOAT_PRECISION
477482 )
478483 # .cpu()
479484 ) # pylint: disable=no-explicit-dtype
480485 yi = (
481- paddle .arange (
482- - nbuff_cpu [ 1 ], nbuff_cpu [ 1 ] + 1 , 1 , dtype = env .GLOBAL_PD_FLOAT_PRECISION
486+ paddle .arange (- nbuff_cpu [ 1 ], nbuff_cpu [ 1 ] + 1 , 1 ). to (
487+ dtype = env .GLOBAL_PD_FLOAT_PRECISION
483488 )
484489 # .cpu()
485490 ) # pylint: disable=no-explicit-dtype
486491 zi = (
487- paddle .arange (
488- - nbuff_cpu [ 2 ], nbuff_cpu [ 2 ] + 1 , 1 , dtype = env .GLOBAL_PD_FLOAT_PRECISION
492+ paddle .arange (- nbuff_cpu [ 2 ], nbuff_cpu [ 2 ] + 1 , 1 ). to (
493+ dtype = env .GLOBAL_PD_FLOAT_PRECISION
489494 )
490495 # .cpu()
491496 ) # pylint: disable=no-explicit-dtype
492497 eye_3 = (
493- paddle .eye (3 , dtype = env . GLOBAL_PD_FLOAT_PRECISION )
498+ paddle .eye (3 )
494499 # .cpu()
495- )
500+ ). to ( dtype = env . GLOBAL_PD_FLOAT_PRECISION )
496501 xyz = xi .reshape ([- 1 , 1 , 1 , 1 ]).astype (eye_3 .dtype ) * eye_3 [0 ]
497502 xyz = xyz + yi .reshape ([1 , - 1 , 1 , 1 ]).astype (eye_3 .dtype ) * eye_3 [1 ]
498503 xyz = xyz + zi .reshape ([1 , 1 , - 1 , 1 ]).astype (eye_3 .dtype ) * eye_3 [2 ]
0 commit comments