Skip to content

Commit 6280a5a

Browse files
fix bugs and enable more pd unitests
1 parent 26047e9 commit 6280a5a

File tree

11 files changed

+159
-254
lines changed

11 files changed

+159
-254
lines changed

deepmd/pd/infer/deep_eval.py

Lines changed: 135 additions & 230 deletions
Large diffs are not rendered by default.

deepmd/pd/utils/env.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
set_default_nthreads,
1616
)
1717

18-
log = logging.getLogger(__name__)
19-
2018
SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False)
2119
try:
2220
# only linux
@@ -87,6 +85,7 @@ def enable_prim(enable: bool = True):
8785

8886
core.set_prim_eager_enabled(True)
8987
core._set_prim_all_enabled(True)
88+
log = logging.getLogger(__name__)
9089
log.info("Enable prim in eager and static mode.")
9190

9291

source/tests/pd/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def eval_model(
135135
if has_spin:
136136
input_dict["spin"] = batch_spin
137137
batch_output = model(**input_dict)
138+
# 'atom_energy', 'energy', 'force', 'virial', 'mask'
138139
if isinstance(batch_output, tuple):
139140
batch_output = batch_output[0]
140141
if not return_tensor:

source/tests/pd/model/test_forward_lower.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
import numpy as np
66
import paddle
77

8-
from deepmd.pd.infer.deep_eval import (
9-
eval_model,
10-
)
118
from deepmd.pd.model.model import (
129
get_model,
1310
)
@@ -22,6 +19,9 @@
2219
from ...seed import (
2320
GLOBAL_SEED,
2421
)
22+
from ..common import (
23+
eval_model,
24+
)
2525
from .test_permutation import ( # model_dpau,
2626
model_dpa1,
2727
model_dpa2,

source/tests/pd/model/test_make_hessian_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def ff(xx):
137137
)
138138

139139

140-
@unittest.skip("TODO")
140+
@unittest.skip("Skip temporarily")
141141
class TestDPModel(unittest.TestCase, HessianTest):
142142
def setUp(self):
143143
paddle.seed(2)

source/tests/pd/model/test_permutation_denoise.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
import numpy as np
66
import paddle
77

8-
from deepmd.pd.infer.deep_eval import (
9-
eval_model,
10-
)
118
from deepmd.pd.model.model import (
129
get_model,
1310
)
@@ -18,6 +15,9 @@
1815
from ...seed import (
1916
GLOBAL_SEED,
2017
)
18+
from ..common import (
19+
eval_model,
20+
)
2121
from .test_permutation import ( # model_dpau,
2222
model_dpa1,
2323
model_dpa2,

source/tests/pd/model/test_rot_denoise.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
import numpy as np
66
import paddle
77

8-
from deepmd.pd.infer.deep_eval import (
9-
eval_model,
10-
)
118
from deepmd.pd.model.model import (
129
get_model,
1310
)
@@ -18,6 +15,9 @@
1815
from ...seed import (
1916
GLOBAL_SEED,
2017
)
18+
from ..common import (
19+
eval_model,
20+
)
2121
from .test_permutation_denoise import (
2222
model_dpa1,
2323
model_dpa2,

source/tests/pd/model/test_smooth.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
import numpy as np
66
import paddle
77

8-
from deepmd.pd.infer.deep_eval import (
9-
eval_model,
10-
)
118
from deepmd.pd.model.model import (
129
get_model,
1310
)
@@ -18,6 +15,9 @@
1815
from ...seed import (
1916
GLOBAL_SEED,
2017
)
18+
from ..common import (
19+
eval_model,
20+
)
2121
from .test_permutation import ( # model_dpau,
2222
model_dos,
2323
model_dpa1,

source/tests/pd/model/test_smooth_denoise.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
import numpy as np
66
import paddle
77

8-
from deepmd.pd.infer.deep_eval import (
9-
eval_model,
10-
)
118
from deepmd.pd.model.model import (
129
get_model,
1310
)
@@ -18,6 +15,9 @@
1815
from ...seed import (
1916
GLOBAL_SEED,
2017
)
18+
from ..common import (
19+
eval_model,
20+
)
2121
from .test_permutation_denoise import (
2222
model_dpa2,
2323
)

source/tests/pd/model/test_trans_denoise.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
import numpy as np
66
import paddle
77

8-
from deepmd.pd.infer.deep_eval import (
9-
eval_model,
10-
)
118
from deepmd.pd.model.model import (
129
get_model,
1310
)
@@ -18,6 +15,9 @@
1815
from ...seed import (
1916
GLOBAL_SEED,
2017
)
18+
from ..common import (
19+
eval_model,
20+
)
2121
from .test_permutation_denoise import (
2222
model_dpa1,
2323
model_dpa2,

0 commit comments

Comments
 (0)