Skip to content

Commit 26047e9

Browse files
fix condition block dtype mismatch in jit.save and enable 2 unitest
1 parent 7b2476f commit 26047e9

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

deepmd/pd/utils/nlist.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,11 @@ def build_multiple_neighbor_list(
385385
).to(device=nlist.place)
386386
# nb x nloc x nsel
387387
nlist = paddle.concat([nlist, pad], axis=-1)
388-
nsel = nsels[-1]
388+
if paddle.is_tensor(nsel):
389+
nsel = paddle.to_tensor(nsels[-1], dtype=nsel.dtype)
390+
else:
391+
nsel = nsels[-1]
392+
389393
# nb x nall x 3
390394
coord1 = coord.reshape([nb, -1, 3])
391395
nall = coord1.shape[1]

source/tests/pd/model/test_jit.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ def tearDown(self):
115115
JITTest.tearDown(self)
116116

117117

118-
@unittest.skip("var dtype int32/int64 confused in if block")
119118
class TestEnergyModelDPA2(unittest.TestCase, JITTest):
120119
def setUp(self):
121120
input_json = str(Path(__file__).parent / "water/se_atten.json")

source/tests/pd/test_multitask.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,6 @@ def tearDown(self):
183183
shutil.rmtree(f)
184184

185185

186-
@unittest.skip("Paddle do not support MultiTaskSeA.")
187186
class TestMultiTaskSeA(unittest.TestCase, MultiTaskTrainTest):
188187
def setUp(self):
189188
multitask_se_e2_a = deepcopy(multitask_template)

0 commit comments

Comments
 (0)