@@ -70,7 +70,10 @@ def test_consistency(
7070 rtol , atol = get_tols (prec )
7171 if prec == "float64" :
7272 atol = 1e-8 # marginal GPU test cases...
73-
73+ coord_ext = np .concatenate ([self .coord_ext [:1 ], self .coord_ext [:1 ]], axis = 0 )
74+ atype_ext = np .concatenate ([self .atype_ext [:1 ], self .atype_ext [:1 ]], axis = 0 )
75+ nlist = np .concatenate ([self .nlist [:1 ], self .nlist [:1 ]], axis = 0 )
76+ mapping = np .concatenate ([self .mapping [:1 ], self .mapping [:1 ]], axis = 0 )
7477 repflow = RepFlowArgs (
7578 n_dim = 20 ,
7679 e_dim = 10 ,
@@ -108,18 +111,18 @@ def test_consistency(
108111 dd0 .repflows .mean = paddle .to_tensor (davg , dtype = dtype , place = env .DEVICE )
109112 dd0 .repflows .stddev = paddle .to_tensor (dstd , dtype = dtype , place = env .DEVICE )
110113 rd0 , _ , _ , _ , _ = dd0 (
111- paddle .to_tensor (self . coord_ext , dtype = dtype , place = env .DEVICE ),
112- paddle .to_tensor (self . atype_ext , dtype = paddle .int64 , place = env .DEVICE ),
113- paddle .to_tensor (self . nlist , dtype = paddle .int64 , place = env .DEVICE ),
114- paddle .to_tensor (self . mapping , dtype = paddle .int64 , place = env .DEVICE ),
114+ paddle .to_tensor (coord_ext , dtype = dtype , place = env .DEVICE ),
115+ paddle .to_tensor (atype_ext , dtype = paddle .int64 , place = env .DEVICE ),
116+ paddle .to_tensor (nlist , dtype = paddle .int64 , place = env .DEVICE ),
117+ paddle .to_tensor (mapping , dtype = paddle .int64 , place = env .DEVICE ),
115118 )
116119 # serialization
117120 dd1 = DescrptDPA3 .deserialize (dd0 .serialize ())
118121 rd1 , _ , _ , _ , _ = dd1 (
119- paddle .to_tensor (self . coord_ext , dtype = dtype , place = env .DEVICE ),
120- paddle .to_tensor (self . atype_ext , dtype = paddle .int64 , place = env .DEVICE ),
121- paddle .to_tensor (self . nlist , dtype = paddle .int64 , place = env .DEVICE ),
122- paddle .to_tensor (self . mapping , dtype = paddle .int64 , place = env .DEVICE ),
122+ paddle .to_tensor (coord_ext , dtype = dtype , place = env .DEVICE ),
123+ paddle .to_tensor (atype_ext , dtype = paddle .int64 , place = env .DEVICE ),
124+ paddle .to_tensor (nlist , dtype = paddle .int64 , place = env .DEVICE ),
125+ paddle .to_tensor (mapping , dtype = paddle .int64 , place = env .DEVICE ),
123126 )
124127 np .testing .assert_allclose (
125128 rd0 .numpy (),
@@ -130,7 +133,7 @@ def test_consistency(
130133 # dp impl
131134 dd2 = DPDescrptDPA3 .deserialize (dd0 .serialize ())
132135 rd2 , _ , _ , _ , _ = dd2 .call (
133- self . coord_ext , self . atype_ext , self . nlist , self . mapping
136+ coord_ext , atype_ext , nlist , mapping
134137 )
135138 np .testing .assert_allclose (
136139 rd0 .numpy (),
0 commit comments