Skip to content

Commit 30b1225

Browse files
committed
ut: fix pd compat with test data
1 parent 6f01d02 commit 30b1225

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

source/tests/pd/model/test_dpa3.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ def test_consistency(
7070
rtol, atol = get_tols(prec)
7171
if prec == "float64":
7272
atol = 1e-8 # marginal GPU test cases...
73-
73+
coord_ext = np.concatenate([self.coord_ext[:1], self.coord_ext[:1]], axis=0)
74+
atype_ext = np.concatenate([self.atype_ext[:1], self.atype_ext[:1]], axis=0)
75+
nlist = np.concatenate([self.nlist[:1], self.nlist[:1]], axis=0)
76+
mapping = np.concatenate([self.mapping[:1], self.mapping[:1]], axis=0)
7477
repflow = RepFlowArgs(
7578
n_dim=20,
7679
e_dim=10,
@@ -108,18 +111,18 @@ def test_consistency(
108111
dd0.repflows.mean = paddle.to_tensor(davg, dtype=dtype, place=env.DEVICE)
109112
dd0.repflows.stddev = paddle.to_tensor(dstd, dtype=dtype, place=env.DEVICE)
110113
rd0, _, _, _, _ = dd0(
111-
paddle.to_tensor(self.coord_ext, dtype=dtype, place=env.DEVICE),
112-
paddle.to_tensor(self.atype_ext, dtype=paddle.int64, place=env.DEVICE),
113-
paddle.to_tensor(self.nlist, dtype=paddle.int64, place=env.DEVICE),
114-
paddle.to_tensor(self.mapping, dtype=paddle.int64, place=env.DEVICE),
114+
paddle.to_tensor(coord_ext, dtype=dtype, place=env.DEVICE),
115+
paddle.to_tensor(atype_ext, dtype=paddle.int64, place=env.DEVICE),
116+
paddle.to_tensor(nlist, dtype=paddle.int64, place=env.DEVICE),
117+
paddle.to_tensor(mapping, dtype=paddle.int64, place=env.DEVICE),
115118
)
116119
# serialization
117120
dd1 = DescrptDPA3.deserialize(dd0.serialize())
118121
rd1, _, _, _, _ = dd1(
119-
paddle.to_tensor(self.coord_ext, dtype=dtype, place=env.DEVICE),
120-
paddle.to_tensor(self.atype_ext, dtype=paddle.int64, place=env.DEVICE),
121-
paddle.to_tensor(self.nlist, dtype=paddle.int64, place=env.DEVICE),
122-
paddle.to_tensor(self.mapping, dtype=paddle.int64, place=env.DEVICE),
122+
paddle.to_tensor(coord_ext, dtype=dtype, place=env.DEVICE),
123+
paddle.to_tensor(atype_ext, dtype=paddle.int64, place=env.DEVICE),
124+
paddle.to_tensor(nlist, dtype=paddle.int64, place=env.DEVICE),
125+
paddle.to_tensor(mapping, dtype=paddle.int64, place=env.DEVICE),
123126
)
124127
np.testing.assert_allclose(
125128
rd0.numpy(),
@@ -130,7 +133,7 @@ def test_consistency(
130133
# dp impl
131134
dd2 = DPDescrptDPA3.deserialize(dd0.serialize())
132135
rd2, _, _, _, _ = dd2.call(
133-
self.coord_ext, self.atype_ext, self.nlist, self.mapping
136+
coord_ext, atype_ext, nlist, mapping
134137
)
135138
np.testing.assert_allclose(
136139
rd0.numpy(),

0 commit comments

Comments
 (0)