Skip to content

Commit 6f01d02

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 39880b9 commit 6f01d02

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -538,11 +538,15 @@ def call(
538538
xp_take_along_axis(flat_map, xp.reshape(nlist, (nframes, -1)), axis=1),
539539
nlist.shape,
540540
)
541-
541+
542542
if self.use_dynamic_sel:
543543
# get graph index
544544
edge_index, angle_index = get_graph_index(
545-
nlist, nlist_mask, a_nlist_mask, nall, use_loc_mapping=self.use_loc_mapping
545+
nlist,
546+
nlist_mask,
547+
a_nlist_mask,
548+
nall,
549+
use_loc_mapping=self.use_loc_mapping,
546550
)
547551
# flat all the tensors
548552
# n_edge x 1

deepmd/dpmodel/utils/network.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1063,7 +1063,9 @@ def get_graph_index(
10631063
n2e_index = n2e_index[xp.astype(nlist_mask, xp.bool)]
10641064

10651065
# node_ext(j) to edge(ij) index_select
1066-
frame_shift = xp.arange(nf, dtype=nlist.dtype) * (nall if not use_loc_mapping else nloc)
1066+
frame_shift = xp.arange(nf, dtype=nlist.dtype) * (
1067+
nall if not use_loc_mapping else nloc
1068+
)
10671069
shifted_nlist = nlist + frame_shift[:, xp.newaxis, xp.newaxis]
10681070
# n_edge
10691071
n_ext2e_index = shifted_nlist[xp.astype(nlist_mask, xp.bool)]

0 commit comments

Comments
 (0)