Skip to content

Commit 6304b66

Browse files
authored
Update simplify.py for api > 1.0, and fix multi-suffixes bug . (#670)
* Update simplify.py Compatibility for api > 1.0 * Update simplify.py fix the bug when multiple suffix applied, dp test obmit the middle ones and keep the last one, which lead to conflicts in dp test outputs. e.g.: dp test request 4 files: details.0.e.out details.1.e.out details.2.e.out details.3.e.out; actually output 1 file: details.e.out * Update simplify.py import function LooseVersion for API > 1.0
1 parent 7d2c1de commit 6304b66

File tree

1 file changed

+28
-8
lines changed

1 file changed

+28
-8
lines changed

dpgen/simplify/simplify.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
from dpgen import dlog
2323
from dpgen import SHORT_CMD
2424
from dpgen.util import sepline
25-
from dpgen.dispatcher.Dispatcher import Dispatcher, make_dispatcher
25+
from distutils.version import LooseVersion
26+
from dpgen.dispatcher.Dispatcher import Dispatcher, _split_tasks, make_dispatcher, make_submission
2627
from dpgen.generator.run import make_train, run_train, post_train, run_fp, post_fp, fp_name, model_devi_name, train_name, train_task_fmt, sys_link_fp_vasp_pp, make_fp_vasp_incar, make_fp_vasp_kp, make_fp_vasp_cp_cvasp, data_system_fmt, model_devi_task_fmt, fp_task_fmt
2728
# TODO: maybe the following functions can be moved to dpgen.util
2829
from dpgen.generator.lib.utils import log_iter, make_iter_name, create_path, record_iter
@@ -245,7 +246,7 @@ def run_model_devi(iter_index, jdata, mdata):
245246
commands = []
246247
detail_file_names = []
247248
for ii, mm in enumerate(task_model_list):
248-
detail_file_name = "{prefix}.{ii}".format(
249+
detail_file_name = "{prefix}-{ii}".format(
249250
prefix=detail_file_name_prefix,
250251
ii=ii,
251252
)
@@ -268,17 +269,36 @@ def run_model_devi(iter_index, jdata, mdata):
268269
forward_files = [rest_data_name]
269270
backward_files = sum([[pf+".e.out", pf+".f.out", pf+".v.out"] for pf in detail_file_names], [])
270271

271-
dispatcher = make_dispatcher(mdata['model_devi_machine'], mdata['model_devi_resources'], work_path, run_tasks, model_devi_group_size)
272-
dispatcher.run_jobs(mdata['model_devi_resources'],
272+
api_version = mdata.get('api_version', '0.9')
273+
if LooseVersion(api_version) < LooseVersion('1.0'):
274+
warnings.warn(f"the dpdispatcher will be updated to new version."
275+
f"And the interface may be changed. Please check the documents for more details")
276+
dispatcher = make_dispatcher(mdata['model_devi_machine'], mdata['model_devi_resources'], work_path, run_tasks, model_devi_group_size)
277+
dispatcher.run_jobs(mdata['model_devi_resources'],
273278
commands,
274279
work_path,
275280
run_tasks,
276281
model_devi_group_size,
277282
model_names,
278283
forward_files,
279284
backward_files,
280-
outlog='model_devi.log',
281-
errlog='model_devi.log')
285+
outlog = 'model_devi.log',
286+
errlog = 'model_devi.log')
287+
288+
elif LooseVersion(api_version) >= LooseVersion('1.0'):
289+
submission = make_submission(
290+
mdata['model_devi_machine'],
291+
mdata['model_devi_resources'],
292+
commands=commands,
293+
work_path=work_path,
294+
run_tasks=run_tasks,
295+
group_size=model_devi_group_size,
296+
forward_common_files=model_names,
297+
forward_files=forward_files,
298+
backward_files=backward_files,
299+
outlog = 'model_devi.log',
300+
errlog = 'model_devi.log')
301+
submission.run_submission()
282302

283303

284304
def post_model_devi(iter_index, jdata, mdata):
@@ -309,13 +329,13 @@ def post_model_devi(iter_index, jdata, mdata):
309329
sys_name = os.path.basename(task).split('.')[1]
310330
all_names.add(sys_name)
311331
# e.out
312-
details_e = glob.glob(os.path.join(task, "{}.*.e.out".format(detail_file_name_prefix)))
332+
details_e = glob.glob(os.path.join(task, "{}-*.e.out".format(detail_file_name_prefix)))
313333
e_all = np.array([np.loadtxt(detail_e, ndmin=2)[:, 1] for detail_e in details_e])
314334
e_std = np.std(e_all, axis=0)
315335
n_frame = e_std.size
316336

317337
# f.out
318-
details_f = glob.glob(os.path.join(task, "{}.*.f.out".format(detail_file_name_prefix)))
338+
details_f = glob.glob(os.path.join(task, "{}-*.f.out".format(detail_file_name_prefix)))
319339
f_all = np.array([np.loadtxt(detail_f, ndmin=2)[:, 3:6].reshape((n_frame, -1, 3)) for detail_f in details_f])
320340
# (n_model, n_frame, n_atom, 3)
321341
f_std = np.std(f_all, axis=0)

0 commit comments

Comments
 (0)