Skip to content

Commit b1ea0e9

Browse files
Yi-FanLipre-commit-ci[bot]njzjz
authored
Bugfix for pimd: sorting model_devi files (#1470)
The previous regular expression did not sort the multiple model_devi.out files in pimd correctly. As a result, the correspondence between the model_devi files and the trajectory files was wrong. Since the wrong configurations were picked, the model_devi accuracy cannot be improved to a very high level when pimd is used. This PR fixes the bug. --------- Signed-off-by: Yifan Li李一帆 <yifanl0716@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent aefcbc2 commit b1ea0e9

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

dpgen/generator/run.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2168,10 +2168,11 @@ def _read_model_devi_file(
21682168
model_devi_merge_traj: bool = False,
21692169
):
21702170
model_devi_files = glob.glob(os.path.join(task_path, "model_devi*.out"))
2171-
model_devi_files_sorted = sorted(
2172-
model_devi_files, key=lambda x: int(re.search(r"(\d+)", x).group(1))
2173-
)
2174-
if len(model_devi_files_sorted) > 1:
2171+
if len(model_devi_files) > 1:
2172+
model_devi_files_sorted = sorted(
2173+
model_devi_files,
2174+
key=lambda x: int(re.search(r"model_devi(\d+)\.out", x).group(1)),
2175+
)
21752176
with open(model_devi_files_sorted[0]) as f:
21762177
first_line = f.readline()
21772178
if not (first_line.startswith("#")):

tests/generator/test_make_md.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import glob
33
import json
44
import os
5+
import re
56
import shutil
67
import sys
78
import unittest
@@ -274,22 +275,36 @@ def test_read_model_devi_file_pimd(self):
274275
if os.path.isdir(path):
275276
shutil.rmtree(path)
276277
os.makedirs(path, exist_ok=True)
278+
path = os.path.join(path, "iter.000000/01.model_devi/task.000.000000")
277279
os.makedirs(os.path.join(path, "traj"), exist_ok=True)
278280
for i in range(4):
279281
for j in range(0, 5, 2):
280-
with open(os.path.join(path, f"traj/{j}.lammpstrj{i+1}"), "a"):
281-
pass
282+
with open(os.path.join(path, f"traj/{j}.lammpstrj{i+1}"), "a") as fp:
283+
fp.write(f"{i} {j}\n")
282284
model_devi_array = np.zeros([3, 7])
283285
model_devi_array[:, 0] = np.array([0, 2, 4])
286+
model_devi_total_array = np.zeros([12, 7])
287+
total_steps = np.array([0, 2, 4, 5, 7, 9, 10, 12, 14, 15, 17, 19])
288+
model_devi_total_array[:, 0] = total_steps
284289
for i in range(4):
290+
model_devi_array[:, 4] = 0.1 * (i + 1) * np.arange(1, 4)
291+
model_devi_total_array[i * 3 : (i + 1) * 3, 4] = model_devi_array[:, 4]
285292
np.savetxt(
286-
os.path.join(path, f"model_devi{i+1}.out"), model_devi_array, fmt="%d"
293+
os.path.join(path, f"model_devi{i+1}.out"),
294+
model_devi_array,
295+
fmt="%.12e",
287296
)
288297
_read_model_devi_file(path)
289-
model_devi_total_array = np.zeros([12, 7])
290-
total_steps = np.array([0, 2, 4, 5, 7, 9, 10, 12, 14, 15, 17, 19])
291-
model_devi_total_array[:, 0] = total_steps
292298
model_devi_out = np.loadtxt(os.path.join(path, "model_devi.out"))
299+
traj_files = glob.glob(os.path.join(path, "traj/*.lammpstrj"))
300+
traj_files = sorted(
301+
traj_files, key=lambda x: int(re.search(r"(\d+).lammpstrj", x).group(1))
302+
)
303+
for idx, traj in enumerate(traj_files):
304+
traj_content = np.loadtxt(traj)
305+
ibead = idx // 3
306+
istep = (idx % 3) * 2
307+
np.testing.assert_array_almost_equal(traj_content, np.array([ibead, istep]))
293308
np.testing.assert_array_almost_equal(model_devi_out, model_devi_total_array)
294309
for istep in total_steps:
295310
self.assertTrue(

0 commit comments

Comments
 (0)