Skip to content

Commit 78b8f01

Browse files
authored
perf: change order of element-wise op in edge angle update calculations (#4677)
This PR changes the order of element-wise multiply when calculating `weighted_edge_angle_update`. The largest matrix should be calculated last to avoid saving large intermediate results and unnecessary broadcast. I've tested this PR on OMat with 9 DPA-3 layers and batch size=auto:512. | Metric | Before | After | Improvement | |------------------------|----------|----------|-------------| | Peak Memory | 25.0G | 21.4G | -15% | | Speed (per 100 steps) | 31.27s | 28.9s | +7.7% | Since this is an element-wise multiply, changing the order of arguments should not affect the result. The correctness is verified by `torch.allclose`. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Refactor** - Improved internal calculation order for weighted updates to enhance code clarity and maintainability, while ensuring the functionality remains consistent. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 0918b22 commit 78b8f01

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,9 +1117,9 @@ def call(
11171117

11181118
# nb x nloc x a_nnei x a_nnei x e_dim
11191119
weighted_edge_angle_update = (
1120-
edge_angle_update
1121-
* a_sw[:, :, :, xp.newaxis, xp.newaxis]
1120+
a_sw[:, :, :, xp.newaxis, xp.newaxis]
11221121
* a_sw[:, :, xp.newaxis, :, xp.newaxis]
1122+
* edge_angle_update
11231123
)
11241124
# nb x nloc x a_nnei x e_dim
11251125
reduced_edge_angle_update = xp.sum(weighted_edge_angle_update, axis=-2) / (

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -698,9 +698,9 @@ def forward(
698698

699699
# nb x nloc x a_nnei x a_nnei x e_dim
700700
weighted_edge_angle_update = (
701-
edge_angle_update
702-
* a_sw[:, :, :, None, None]
701+
a_sw[:, :, :, None, None]
703702
* a_sw[:, :, None, :, None]
703+
* edge_angle_update
704704
)
705705
# nb x nloc x a_nnei x e_dim
706706
reduced_edge_angle_update = torch.sum(

0 commit comments

Comments
 (0)