Skip to content

Commit 6f3c323

Browse files
fix
1 parent 1c03d45 commit 6f3c323

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

deepmd/pd/utils/nlist.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,6 @@ def nlist_distinguish_types(
298298
tmp_atype,
299299
axis=2,
300300
indices=nlist.masked_fill(mask, 0),
301-
broadcast=False,
302301
)
303302
tnlist = tnlist.masked_fill(mask, -1)
304303
snsel = tnlist.shape[2]
@@ -312,7 +311,11 @@ def nlist_distinguish_types(
312311
paddle.argsort(pick_mask, axis=-1, descending=True, stable=True),
313312
)
314313
# nloc x s(nsel)
315-
inlist = paddle.take_along_axis(nlist, axis=2, indices=imap, broadcast=False)
314+
inlist = paddle.take_along_axis(
315+
nlist,
316+
axis=2,
317+
indices=imap,
318+
)
316319
inlist = inlist.masked_fill(~(pick_mask.to(paddle.bool)), -1)
317320
# nloc x nsel[ii]
318321
ret_nlist.append(paddle.split(inlist, [ss, snsel - ss], axis=-1)[0])
@@ -394,7 +397,9 @@ def build_multiple_neighbor_list(
394397
)
395398
# nb x nloc x nsel x 3
396399
coord2 = paddle.take_along_axis(
397-
coord1, axis=1, indices=index, broadcast=False
400+
coord1,
401+
axis=1,
402+
indices=index,
398403
).reshape([nb, nloc, nsel, 3])
399404
# nb x nloc x nsel x 3
400405
diff = coord2 - coord0[:, :, None, :]
@@ -472,27 +477,27 @@ def extend_coord_with_ghosts(
472477
nbuff = paddle.amax(nbuff, axis=0)
473478
nbuff_cpu = nbuff.cpu()
474479
xi = (
475-
paddle.arange(
476-
-nbuff_cpu[0], nbuff_cpu[0] + 1, 1, dtype=env.GLOBAL_PD_FLOAT_PRECISION
480+
paddle.arange(-nbuff_cpu[0], nbuff_cpu[0] + 1, 1).to(
481+
dtype=env.GLOBAL_PD_FLOAT_PRECISION
477482
)
478483
# .cpu()
479484
) # pylint: disable=no-explicit-dtype
480485
yi = (
481-
paddle.arange(
482-
-nbuff_cpu[1], nbuff_cpu[1] + 1, 1, dtype=env.GLOBAL_PD_FLOAT_PRECISION
486+
paddle.arange(-nbuff_cpu[1], nbuff_cpu[1] + 1, 1).to(
487+
dtype=env.GLOBAL_PD_FLOAT_PRECISION
483488
)
484489
# .cpu()
485490
) # pylint: disable=no-explicit-dtype
486491
zi = (
487-
paddle.arange(
488-
-nbuff_cpu[2], nbuff_cpu[2] + 1, 1, dtype=env.GLOBAL_PD_FLOAT_PRECISION
492+
paddle.arange(-nbuff_cpu[2], nbuff_cpu[2] + 1, 1).to(
493+
dtype=env.GLOBAL_PD_FLOAT_PRECISION
489494
)
490495
# .cpu()
491496
) # pylint: disable=no-explicit-dtype
492497
eye_3 = (
493-
paddle.eye(3, dtype=env.GLOBAL_PD_FLOAT_PRECISION)
498+
paddle.eye(3)
494499
# .cpu()
495-
)
500+
).to(dtype=env.GLOBAL_PD_FLOAT_PRECISION)
496501
xyz = xi.reshape([-1, 1, 1, 1]).astype(eye_3.dtype) * eye_3[0]
497502
xyz = xyz + yi.reshape([1, -1, 1, 1]).astype(eye_3.dtype) * eye_3[1]
498503
xyz = xyz + zi.reshape([1, 1, -1, 1]).astype(eye_3.dtype) * eye_3[2]

0 commit comments

Comments
 (0)