Skip to content

Commit 3cecca4

Browse files
authored
Perf: replace unnecessary torch.split with indexing (#4505)
Some operations only use the first segment of the result tensor of `torch.split`. In this case, all the other segments are created and discarded. This slightly adds an overhead to the training process. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Simplified tensor slicing operations in the `RepformerLayer` class and the `nlist_distinguish_types` function, enhancing readability and performance. - **Documentation** - Updated comments for clarity regarding tensor shapes in the `RepformerLayer` class. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent beeb3d9 commit 3cecca4

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

deepmd/pt/model/descriptor/repformer_layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1003,7 +1003,7 @@ def _cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor:
10031003
# nb x nloc x 3 x ng2
10041004
nb, nloc, _, ng2 = h2g2.shape
10051005
# nb x nloc x 3 x axis
1006-
h2g2m = torch.split(h2g2, axis_neuron, dim=-1)[0]
1006+
h2g2m = h2g2[..., :axis_neuron]
10071007
# nb x nloc x axis x ng2
10081008
g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1)
10091009
# nb x nloc x (axisxng2)

deepmd/pt/utils/nlist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def nlist_distinguish_types(
310310
inlist = torch.gather(nlist, 2, imap)
311311
inlist = inlist.masked_fill(~(pick_mask.to(torch.bool)), -1)
312312
# nloc x nsel[ii]
313-
ret_nlist.append(torch.split(inlist, [ss, snsel - ss], dim=-1)[0])
313+
ret_nlist.append(inlist[..., :ss])
314314
return torch.concat(ret_nlist, dim=-1)
315315

316316

0 commit comments

Comments
 (0)