Skip to content

Commit 05ba1bf

Browse files
authored
fix(array-api): fix xp.where errors (#4624)
`xp.where` always requires a bool array as its first input, but previously, the array-api-strict package didn't require it. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Refactor** - Enhanced the internal filtering logic by standardizing type handling for exclusion conditions, ensuring more reliable and consistent operations across the system. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
1 parent 918d4de commit 05ba1bf

File tree

4 files changed

+4
-0
lines changed

4 files changed

+4
-0
lines changed

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,7 @@ def call(
899899
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
900900
# nfnl x nnei
901901
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
902+
exclude_mask = xp.astype(exclude_mask, xp.bool)
902903
# nfnl x nnei
903904
nlist = xp.reshape(nlist, (nf * nloc, nnei))
904905
nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1))

deepmd/dpmodel/descriptor/repformers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ def call(
393393
):
394394
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
395395
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
396+
exclude_mask = xp.astype(exclude_mask, xp.bool)
396397
nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1))
397398
# nf x nloc x nnei x 4
398399
dmatrix, diff, sw = self.env_mat.call(

deepmd/dpmodel/descriptor/se_t_tebd.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,7 @@ def call(
682682
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
683683
# nfnl x nnei
684684
nlist = xp.reshape(nlist, (nf * nloc, nnei))
685+
exclude_mask = xp.astype(exclude_mask, xp.bool)
685686
nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1))
686687
# nfnl x nnei
687688
nlist_mask = nlist != -1

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,7 @@ def _call_common(
488488
)
489489
# nf x nloc
490490
exclude_mask = self.emask.build_type_exclude_mask(atype)
491+
exclude_mask = xp.astype(exclude_mask, xp.bool)
491492
# nf x nloc x nod
492493
outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs))
493494
return {self.var_name: outs}

0 commit comments

Comments
 (0)