Skip to content

Commit 2018d62

Browse files
authored
fix(jax): fix repflows JIT issues (#4775)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Improved compatibility with just-in-time (JIT) compilation in certain scenarios, preventing potential errors during execution. End-user functionality remains unchanged. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
1 parent 8a9fc78 commit 2018d62

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1328,7 +1328,10 @@ def call(
13281328
)
13291329
nb, nloc, nnei = nlist.shape
13301330
nall = node_ebd_ext.shape[1]
1331-
n_edge = int(xp.sum(xp.astype(nlist_mask, xp.int32)))
1331+
# int cannot jit; do not run it when self.use_dynamic_sel == False
1332+
n_edge = (
1333+
int(xp.sum(xp.astype(nlist_mask, xp.int32))) if self.use_dynamic_sel else 0
1334+
)
13321335
node_ebd = node_ebd_ext[:, :nloc, :]
13331336
assert (nb, nloc) == node_ebd.shape[:2]
13341337
if not self.use_dynamic_sel:

0 commit comments

Comments
 (0)