
Commit cf70d94

perf: use torch.split in place of slicing ops in repflow (#4687)
Benchmark results show an overhead introduced by slicing operators in the current implementation. This PR replaces the per-tensor slicing with a single `torch.split` op, bringing a 6.7% speed-up while improving code readability. Tested on OMat with 9 DPA-3 layers and batch size=auto:512.

## Summary by CodeRabbit

- **Refactor**
  - Streamlined the internal logic for processing numerical components to reduce complexity.
  - Enhanced internal validation checks related to the `bias` variable to boost overall system robustness and maintainability.
  - Updated method signatures for `optim_angle_update` and `optim_edge_update` to improve clarity and usability.

---------

Signed-off-by: Chun Cai <amoycaic@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a1b5089 commit cf70d94

2 files changed (+60, −96 lines)
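For readers skimming the diff, the core change is mechanical: instead of slicing the stacked weight matrix once per sub-block, a single `torch.split` call yields all the row blocks at once. A minimal standalone sketch (the dimension sizes and variable names here are illustrative, not taken from the model):

```python
import torch

node_dim, edge_dim = 64, 32
# A stacked weight matrix holding [node | node_ext | edge] row blocks.
matrix = torch.randn(2 * node_dim + edge_dim, 128)

# Before: one slicing op per sub-block.
node_w = matrix[:node_dim]
node_ext_w = matrix[node_dim : 2 * node_dim]
edge_w = matrix[2 * node_dim :]

# After: a single torch.split call; each piece is a view over the same storage.
node_w2, node_ext_w2, edge_w2 = torch.split(matrix, [node_dim, node_dim, edge_dim])

assert torch.equal(node_w, node_w2)
assert torch.equal(node_ext_w, node_ext_w2)
assert torch.equal(edge_w, edge_w2)
```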

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 28 additions & 39 deletions
```diff
@@ -796,42 +796,36 @@ def optim_angle_update(
         feat: str = "edge",
     ) -> np.ndarray:
         xp = array_api_compat.array_namespace(angle_ebd, node_ebd, edge_ebd)
-        angle_dim = angle_ebd.shape[-1]
-        node_dim = node_ebd.shape[-1]
-        edge_dim = edge_ebd.shape[-1]
-        sub_angle_idx = (0, angle_dim)
-        sub_node_idx = (angle_dim, angle_dim + node_dim)
-        sub_edge_idx_ij = (angle_dim + node_dim, angle_dim + node_dim + edge_dim)
-        sub_edge_idx_ik = (
-            angle_dim + node_dim + edge_dim,
-            angle_dim + node_dim + 2 * edge_dim,
-        )

         if feat == "edge":
+            assert self.edge_angle_linear1 is not None
             matrix, bias = self.edge_angle_linear1.w, self.edge_angle_linear1.b
         elif feat == "angle":
+            assert self.angle_self_linear is not None
             matrix, bias = self.angle_self_linear.w, self.angle_self_linear.b
         else:
             raise NotImplementedError
+        assert bias is not None
+
+        angle_dim = angle_ebd.shape[-1]
+        node_dim = node_ebd.shape[-1]
+        edge_dim = edge_ebd.shape[-1]
         assert angle_dim + node_dim + 2 * edge_dim == matrix.shape[0]
+        # Array API does not provide a way to split the array
+        sub_angle = matrix[:angle_dim, ...]  # angle_dim
+        sub_node = matrix[angle_dim : angle_dim + node_dim, ...]  # node_dim
+        sub_edge_ij = matrix[
+            angle_dim + node_dim : angle_dim + node_dim + edge_dim, ...
+        ]  # edge_dim
+        sub_edge_ik = matrix[angle_dim + node_dim + edge_dim :, ...]  # edge_dim

         # nf * nloc * a_sel * a_sel * angle_dim
-        sub_angle_update = xp.matmul(
-            angle_ebd, matrix[sub_angle_idx[0] : sub_angle_idx[1], :]
-        )
-
+        sub_angle_update = xp.matmul(angle_ebd, sub_angle)
         # nf * nloc * angle_dim
-        sub_node_update = xp.matmul(
-            node_ebd, matrix[sub_node_idx[0] : sub_node_idx[1], :]
-        )
-
+        sub_node_update = xp.matmul(node_ebd, sub_node)
         # nf * nloc * a_nnei * angle_dim
-        sub_edge_update_ij = xp.matmul(
-            edge_ebd, matrix[sub_edge_idx_ij[0] : sub_edge_idx_ij[1], :]
-        )
-        sub_edge_update_ik = xp.matmul(
-            edge_ebd, matrix[sub_edge_idx_ik[0] : sub_edge_idx_ik[1], :]
-        )
+        sub_edge_update_ij = xp.matmul(edge_ebd, sub_edge_ij)
+        sub_edge_update_ik = xp.matmul(edge_ebd, sub_edge_ik)

         result_update = (
             bias
@@ -851,36 +845,31 @@ def optim_edge_update(
         feat: str = "node",
     ) -> np.ndarray:
         xp = array_api_compat.array_namespace(node_ebd, node_ebd_ext, edge_ebd, nlist)
-        node_dim = node_ebd.shape[-1]
-        edge_dim = edge_ebd.shape[-1]
-        sub_node_idx = (0, node_dim)
-        sub_node_ext_idx = (node_dim, 2 * node_dim)
-        sub_edge_idx = (2 * node_dim, 2 * node_dim + edge_dim)

         if feat == "node":
             matrix, bias = self.node_edge_linear.w, self.node_edge_linear.b
         elif feat == "edge":
             matrix, bias = self.edge_self_linear.w, self.edge_self_linear.b
         else:
             raise NotImplementedError
-        assert 2 * node_dim + edge_dim == matrix.shape[0]
+        node_dim = node_ebd.shape[-1]
+        edge_dim = edge_ebd.shape[-1]
+        assert node_dim * 2 + edge_dim == matrix.shape[0]
+        # Array API does not provide a way to split the array
+        node = matrix[:node_dim, ...]  # node_dim
+        node_ext = matrix[node_dim : 2 * node_dim, ...]  # node_dim
+        edge = matrix[2 * node_dim : 2 * node_dim + edge_dim, ...]  # edge_dim

         # nf * nloc * node/edge_dim
-        sub_node_update = xp.matmul(
-            node_ebd, matrix[sub_node_idx[0] : sub_node_idx[1], :]
-        )
+        sub_node_update = xp.matmul(node_ebd, node)

         # nf * nall * node/edge_dim
-        sub_node_ext_update = xp.matmul(
-            node_ebd_ext, matrix[sub_node_ext_idx[0] : sub_node_ext_idx[1], :]
-        )
+        sub_node_ext_update = xp.matmul(node_ebd_ext, node_ext)
         # nf * nloc * nnei * node/edge_dim
         sub_node_ext_update = _make_nei_g1(sub_node_ext_update, nlist)

         # nf * nloc * nnei * node/edge_dim
-        sub_edge_update = xp.matmul(
-            edge_ebd, matrix[sub_edge_idx[0] : sub_edge_idx[1], :]
-        )
+        sub_edge_update = xp.matmul(edge_ebd, edge)

         result_update = (
             bias
```
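As the in-code comment notes, the array API standard has no split primitive, so this backend keeps explicit slices, merely hoisted out of the matmuls. Plain NumPy does offer `np.split` (taking cut indices rather than section lengths); a small sketch of the equivalence, with illustrative shapes:

```python
import numpy as np

node_dim, edge_dim = 4, 3
matrix = np.arange((2 * node_dim + edge_dim) * 5.0).reshape(-1, 5)

# Manual slicing, as in the array-API-compatible code path above.
node = matrix[:node_dim, ...]
node_ext = matrix[node_dim : 2 * node_dim, ...]
edge = matrix[2 * node_dim :, ...]

# np.split takes cut indices, not section lengths like torch.split.
node2, node_ext2, edge2 = np.split(matrix, [node_dim, 2 * node_dim])
for a, b in [(node, node2), (node_ext, node_ext2), (edge, edge2)]:
    assert np.array_equal(a, b)
```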

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 32 additions & 57 deletions
```diff
@@ -397,48 +397,37 @@ def optim_angle_update(
         edge_ebd: torch.Tensor,
         feat: str = "edge",
     ) -> torch.Tensor:
-        angle_dim = angle_ebd.shape[-1]
-        node_dim = node_ebd.shape[-1]
-        edge_dim = edge_ebd.shape[-1]
-        sub_angle_idx = (0, angle_dim)
-        sub_node_idx = (angle_dim, angle_dim + node_dim)
-        sub_edge_idx_ij = (angle_dim + node_dim, angle_dim + node_dim + edge_dim)
-        sub_edge_idx_ik = (
-            angle_dim + node_dim + edge_dim,
-            angle_dim + node_dim + 2 * edge_dim,
-        )
-
         if feat == "edge":
+            assert self.edge_angle_linear1 is not None
             matrix, bias = self.edge_angle_linear1.matrix, self.edge_angle_linear1.bias
         elif feat == "angle":
+            assert self.angle_self_linear is not None
             matrix, bias = self.angle_self_linear.matrix, self.angle_self_linear.bias
         else:
             raise NotImplementedError
-        assert angle_dim + node_dim + 2 * edge_dim == matrix.size()[0]
+        assert bias is not None

-        # nf * nloc * a_sel * a_sel * angle_dim
-        sub_angle_update = torch.matmul(
-            angle_ebd, matrix[sub_angle_idx[0] : sub_angle_idx[1]]
+        angle_dim = angle_ebd.shape[-1]
+        node_dim = node_ebd.shape[-1]
+        edge_dim = edge_ebd.shape[-1]
+        # angle_dim, node_dim, edge_dim, edge_dim
+        sub_angle, sub_node, sub_edge_ij, sub_edge_ik = torch.split(
+            matrix, [angle_dim, node_dim, edge_dim, edge_dim]
         )

+        # nf * nloc * a_sel * a_sel * angle_dim
+        sub_angle_update = torch.matmul(angle_ebd, sub_angle)
         # nf * nloc * angle_dim
-        sub_node_update = torch.matmul(
-            node_ebd, matrix[sub_node_idx[0] : sub_node_idx[1]]
-        )
-
+        sub_node_update = torch.matmul(node_ebd, sub_node)
         # nf * nloc * a_nnei * angle_dim
-        sub_edge_update_ij = torch.matmul(
-            edge_ebd, matrix[sub_edge_idx_ij[0] : sub_edge_idx_ij[1]]
-        )
-        sub_edge_update_ik = torch.matmul(
-            edge_ebd, matrix[sub_edge_idx_ik[0] : sub_edge_idx_ik[1]]
-        )
+        sub_edge_update_ij = torch.matmul(edge_ebd, sub_edge_ij)
+        sub_edge_update_ik = torch.matmul(edge_ebd, sub_edge_ik)

         result_update = (
             bias
-            + sub_node_update[:, :, None, None, :]
-            + sub_edge_update_ij[:, :, None, :, :]
-            + sub_edge_update_ik[:, :, :, None, :]
+            + sub_node_update.unsqueeze(2).unsqueeze(3)
+            + sub_edge_update_ij.unsqueeze(2)
+            + sub_edge_update_ik.unsqueeze(3)
             + sub_angle_update
         )
         return result_update
```
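A side change in this hunk replaces `None`-indexing with `unsqueeze` for the broadcast axes; the two spellings are equivalent, as this sketch with illustrative shapes shows:

```python
import torch

# Illustrative shapes: nf=2, nloc=3, angle_dim=5.
sub_node_update = torch.randn(2, 3, 5)

# None-indexing (old) and unsqueeze (new) insert the same singleton axes.
a = sub_node_update[:, :, None, None, :]       # (2, 3, 1, 1, 5)
b = sub_node_update.unsqueeze(2).unsqueeze(3)  # (2, 3, 1, 1, 5)
assert a.shape == b.shape and torch.equal(a, b)
```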
```diff
@@ -451,42 +440,30 @@ def optim_edge_update(
         nlist: torch.Tensor,
         feat: str = "node",
     ) -> torch.Tensor:
-        node_dim = node_ebd.shape[-1]
-        edge_dim = edge_ebd.shape[-1]
-        sub_node_idx = (0, node_dim)
-        sub_node_ext_idx = (node_dim, 2 * node_dim)
-        sub_edge_idx = (2 * node_dim, 2 * node_dim + edge_dim)
-
         if feat == "node":
             matrix, bias = self.node_edge_linear.matrix, self.node_edge_linear.bias
         elif feat == "edge":
             matrix, bias = self.edge_self_linear.matrix, self.edge_self_linear.bias
         else:
             raise NotImplementedError
-        assert 2 * node_dim + edge_dim == matrix.size()[0]
+        assert bias is not None

-        # nf * nloc * node/edge_dim
-        sub_node_update = torch.matmul(
-            node_ebd, matrix[sub_node_idx[0] : sub_node_idx[1]]
-        )
+        node_dim = node_ebd.shape[-1]
+        edge_dim = edge_ebd.shape[-1]
+        # node_dim, node_dim, edge_dim
+        node, node_ext, edge = torch.split(matrix, [node_dim, node_dim, edge_dim])

+        # nf * nloc * node/edge_dim
+        sub_node_update = torch.matmul(node_ebd, node)
         # nf * nall * node/edge_dim
-        sub_node_ext_update = torch.matmul(
-            node_ebd_ext, matrix[sub_node_ext_idx[0] : sub_node_ext_idx[1]]
-        )
+        sub_node_ext_update = torch.matmul(node_ebd_ext, node_ext)
         # nf * nloc * nnei * node/edge_dim
         sub_node_ext_update = _make_nei_g1(sub_node_ext_update, nlist)
-
         # nf * nloc * nnei * node/edge_dim
-        sub_edge_update = torch.matmul(
-            edge_ebd, matrix[sub_edge_idx[0] : sub_edge_idx[1]]
-        )
+        sub_edge_update = torch.matmul(edge_ebd, edge)

         result_update = (
-            bias
-            + sub_node_update[:, :, None, :]
-            + sub_edge_update
-            + sub_node_ext_update
+            bias + sub_node_update.unsqueeze(2) + sub_edge_update + sub_node_ext_update
         )
         return result_update
```
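Why splitting the weight matrix is sound: multiplying each input by its row block of `matrix` computes the same thing as one fused matmul on concatenated inputs. In the layer itself the three inputs live on different shapes (`nloc`, `nall`, `nloc x nnei`), so the products are formed separately and broadcast-added; the underlying identity, sketched with illustrative shapes:

```python
import torch

node_dim, edge_dim, out_dim, batch = 4, 3, 6, 5
matrix = torch.randn(2 * node_dim + edge_dim, out_dim)
node_ebd = torch.randn(batch, node_dim)
node_ext_ebd = torch.randn(batch, node_dim)
edge_ebd = torch.randn(batch, edge_dim)

# Per-block products against the split weights...
node_w, node_ext_w, edge_w = torch.split(matrix, [node_dim, node_dim, edge_dim])
per_block = node_ebd @ node_w + node_ext_ebd @ node_ext_w + edge_ebd @ edge_w

# ...equal one matmul of the concatenated inputs against the stacked matrix.
fused = torch.cat([node_ebd, node_ext_ebd, edge_ebd], dim=-1) @ matrix
assert torch.allclose(per_block, fused, atol=1e-5)
```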

```diff
@@ -614,7 +591,7 @@ def forward(
                 nb, nloc, self.n_multi_edge_message, self.n_dim
             )
             for head_index in range(self.n_multi_edge_message):
-                n_update_list.append(node_edge_update_mul_head[:, :, head_index, :])
+                n_update_list.append(node_edge_update_mul_head[..., head_index, :])
         else:
             n_update_list.append(node_edge_update)
         # update node_ebd
@@ -649,14 +626,14 @@ def forward(
                 edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd)
             else:
                 # use the first a_compress_dim dim for node and edge
-                node_ebd_for_angle = node_ebd[:, :, : self.n_a_compress_dim]
-                edge_ebd_for_angle = edge_ebd[:, :, :, : self.e_a_compress_dim]
+                node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim]
+                edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim]
         else:
             node_ebd_for_angle = node_ebd
             edge_ebd_for_angle = edge_ebd

         # nb x nloc x a_nnei x e_dim
-        edge_for_angle = edge_ebd_for_angle[:, :, : self.a_sel, :]
+        edge_for_angle = edge_ebd_for_angle[..., : self.a_sel, :]
         # nb x nloc x a_nnei x e_dim
         edge_for_angle = torch.where(
             a_nlist_mask.unsqueeze(-1), edge_for_angle, 0.0
@@ -704,9 +681,7 @@ def forward(

         # nb x nloc x a_nnei x a_nnei x e_dim
         weighted_edge_angle_update = (
-            a_sw[:, :, :, None, None]
-            * a_sw[:, :, None, :, None]
-            * edge_angle_update
+            a_sw[..., None, None] * a_sw[..., None, :, None] * edge_angle_update
         )
         # nb x nloc x a_nnei x e_dim
         reduced_edge_angle_update = torch.sum(
```
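The remaining `forward` tweaks swap explicit leading-axis slices for `...`, which is purely cosmetic; both spellings select the same elements. A quick sketch with illustrative shapes:

```python
import torch

# Illustrative: nb=2, nloc=3, nnei=7, e_dim=5; keep the first a_sel=4 neighbors.
edge_ebd = torch.randn(2, 3, 7, 5)
a_sel = 4
assert torch.equal(edge_ebd[:, :, :a_sel, :], edge_ebd[..., :a_sel, :])

# Same for broadcasting the switch function in the angle update.
a_sw = torch.randn(2, 3, 7)
old = a_sw[:, :, :, None, None] * a_sw[:, :, None, :, None]
new = a_sw[..., None, None] * a_sw[..., None, :, None]
assert torch.equal(old, new)
```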
