Skip to content

Commit 87426d5

Browse files
authored
perf: reschedule plus op (#4688)
Similar changes of #4677 Brings +5% speed up compared with #4687 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Refactor** - Adjusted the order of operations in update calculations to enhance clarity while maintaining the same functional outcomes. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 52f8ece commit 87426d5

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -834,11 +834,12 @@ def optim_angle_update(
834834
)
835835

836836
result_update = (
837-
sub_angle_update
837+
bias
838838
+ sub_node_update[:, :, xp.newaxis, xp.newaxis, :]
839839
+ sub_edge_update_ij[:, :, xp.newaxis, :, :]
840840
+ sub_edge_update_ik[:, :, :, xp.newaxis, :]
841-
) + bias
841+
+ sub_angle_update
842+
)
842843
return result_update
843844

844845
def optim_edge_update(
@@ -882,8 +883,11 @@ def optim_edge_update(
882883
)
883884

884885
result_update = (
885-
sub_edge_update + sub_node_ext_update + sub_node_update[:, :, xp.newaxis, :]
886-
) + bias
886+
bias
887+
+ sub_node_update[:, :, xp.newaxis, :]
888+
+ sub_edge_update
889+
+ sub_node_ext_update
890+
)
887891
return result_update
888892

889893
def call(

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -435,11 +435,12 @@ def optim_angle_update(
435435
)
436436

437437
result_update = (
438-
sub_angle_update
438+
bias
439439
+ sub_node_update[:, :, None, None, :]
440440
+ sub_edge_update_ij[:, :, None, :, :]
441441
+ sub_edge_update_ik[:, :, :, None, :]
442-
) + bias
442+
+ sub_angle_update
443+
)
443444
return result_update
444445

445446
def optim_edge_update(
@@ -482,8 +483,11 @@ def optim_edge_update(
482483
)
483484

484485
result_update = (
485-
sub_edge_update + sub_node_ext_update + sub_node_update[:, :, None, :]
486-
) + bias
486+
bias
487+
+ sub_node_update[:, :, None, :]
488+
+ sub_edge_update
489+
+ sub_node_ext_update
490+
)
487491
return result_update
488492

489493
def forward(

0 commit comments

Comments
 (0)