Skip to content

Commit 0c6c568

Browse files
pd: Ignore if branch of 0-size (#4617)
The support for higher-order differentiation and complex control flow in `paddle.jit` static graphs is not very strong. Therefore, branches related to 0-size are ignored to avoid accuracy issues during training. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Improved the consistency and stability of key computational processes across the application, ensuring that data processing occurs reliably under all conditions. - Enhanced compatibility with the underlying execution framework, leading to more predictable outcomes even in edge-case scenarios. - Adjusted control flow for handling empty coordinate tensors, allowing calculations to proceed under specific conditions. - Modified conditions for executing operations on variables, ensuring compatibility with the framework's operational modes. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: HydrogenSulfate <490868991@qq.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 80d445b commit 0c6c568

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

deepmd/pd/model/descriptor/se_a.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
prod_env_mat,
2020
)
2121
from deepmd.pd.utils import (
22+
decomp,
2223
env,
2324
)
2425
from deepmd.pd.utils.env import (
@@ -744,7 +745,8 @@ def forward(
744745
"Compressed environment is not implemented yet."
745746
)
746747
else:
747-
if rr.numel() > 0:
748+
# NOTE: control flow with double backward is not supported well yet by paddle.jit
749+
if not paddle.framework.in_dynamic_mode() or decomp.numel(rr) > 0:
748750
rr = rr * mm.unsqueeze(2).astype(rr.dtype)
749751
ss = rr[:, :, :1]
750752
if self.compress:

deepmd/pd/model/network/layernorm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ def forward(
100100
yy: paddle.Tensor
101101
The output.
102102
"""
103-
# if xx.numel() > 0:
104-
if decomp.numel(xx):
103+
# NOTE: control flow with double backward is not supported well yet by paddle.jit
104+
if not paddle.framework.in_dynamic_mode() or decomp.numel(xx) > 0:
105105
variance, mean = (
106106
paddle.var(xx, axis=-1, unbiased=False, keepdim=True),
107107
paddle.mean(xx, axis=-1, keepdim=True),

deepmd/pd/utils/nlist.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import paddle
88

99
from deepmd.pd.utils import (
10+
decomp,
1011
env,
1112
)
1213
from deepmd.pd.utils.region import (
@@ -97,7 +98,9 @@ def build_neighbor_list(
9798
nall = coord.shape[1] // 3
9899
# fill virtual atoms with large coords so they are not neighbors of any
99100
# real atom.
100-
if coord.numel() > 0:
101+
102+
# NOTE: control flow with double backward is not supported well yet by paddle.jit
103+
if not paddle.framework.in_dynamic_mode() or decomp.numel(coord) > 0:
101104
xmax = paddle.max(coord) + 2.0 * rcut
102105
else:
103106
xmax = paddle.zeros([], dtype=coord.dtype).to(device=coord.place) + 2.0 * rcut
@@ -240,7 +243,8 @@ def build_directional_neighbor_list(
240243
nall_neig = coord_neig.shape[1] // 3
241244
# fill virtual atoms with large coords so they are not neighbors of any
242245
# real atom.
243-
if coord_neig.numel() > 0:
246+
# NOTE: control flow with double backward is not supported well yet by paddle.jit
247+
if not paddle.framework.in_dynamic_mode() or decomp.numel(coord_neig) > 0:
244248
xmax = paddle.max(coord_cntl) + 2.0 * rcut
245249
else:
246250
xmax = (

0 commit comments

Comments
 (0)