Skip to content

Commit 6de10d0

Browse files
authored
model_devi | sort models in default&support user-defined order (#610)
* model_devi | sort models in default&support user-defined order current behavior: 1)the list of models written in input.lammps for model_devi is created by "glob" without "sort" --> may give different (random) orders in different machine environments --> would randomly alter the model being used to sample configurations (always use the first model written in imput.lammps) ---> could lead to unexpected performance, such as, when using together with "model_devi_activation_func" that allows four models to be nonequivalent. 2)when preparing input.lammps from iuser-provided template, the line begin with "pair_style deepmd" will be overwritten by dpgen, thus overwrites the user defined order of models ---> could lead to unexpected performance, such as, when using together with "model_devi_activation_func" that allows four models to be nonequivalent and users indeed expected a specific order of models. changes: 1)sorted the list of models in default, thus always use graph.000.pb to sample configurations; 2)check weather user writes the full line of begin with "pair_style deepmd" (by checking the length), if yes, leave it be; if not, overwrites with the default settings (use graph.000.pb to sample); besides, the original error trigger is retained if key words "pair_style deepmd" are not provided in the template. * adjust space/tab * adjust annotation
1 parent 7692835 commit 6de10d0

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

dpgen/generator/run.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,7 @@ def make_model_devi (iter_index,
804804
iter_name = make_iter_name(iter_index)
805805
train_path = os.path.join(iter_name, train_name)
806806
train_path = os.path.abspath(train_path)
807-
models = glob.glob(os.path.join(train_path, "graph*pb"))
807+
models = sorted(glob.glob(os.path.join(train_path, "graph*pb")))
808808
work_path = os.path.join(iter_name, model_devi_name)
809809
create_path(work_path)
810810
for mm in models :
@@ -890,7 +890,7 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems):
890890
iter_name = make_iter_name(iter_index)
891891
train_path = os.path.join(iter_name, train_name)
892892
train_path = os.path.abspath(train_path)
893-
models = glob.glob(os.path.join(train_path, "graph*pb"))
893+
models = sorted(glob.glob(os.path.join(train_path, "graph*pb")))
894894
task_model_list = []
895895
for ii in models:
896896
task_model_list.append(os.path.join('..', os.path.basename(ii)))
@@ -947,7 +947,23 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems):
947947
# revise input of lammps
948948
with open('input.lammps') as fp:
949949
lmp_lines = fp.readlines()
950-
lmp_lines = revise_lmp_input_model(lmp_lines, task_model_list, trj_freq, deepmd_version = deepmd_version)
950+
# only revise the line "pair_style deepmd" if the user has not written the full line (checked by then length of the line)
951+
template_has_pair_deepmd=1
952+
for line_idx,line_context in enumerate(lmp_lines):
953+
if (line_context[0] != "#") and ("pair_style" in line_context) and ("deepmd" in line_context):
954+
template_has_pair_deepmd=0
955+
template_pair_deepmd_idx=line_idx
956+
if template_has_pair_deepmd == 0:
957+
if LooseVersion(deepmd_version) < LooseVersion('1'):
958+
if len(lmp_lines[template_pair_deepmd_idx].split()) != (len(models) + len(["pair_style","deepmd","10", "model_devi.out"])):
959+
lmp_lines = revise_lmp_input_model(lmp_lines, task_model_list, trj_freq, deepmd_version = deepmd_version)
960+
else:
961+
if len(lmp_lines[template_pair_deepmd_idx].split()) != (len(models) + len(["pair_style","deepmd","out_freq", "10", "out_file", "model_devi.out"])):
962+
lmp_lines = revise_lmp_input_model(lmp_lines, task_model_list, trj_freq, deepmd_version = deepmd_version)
963+
#use revise_lmp_input_model to raise error message if "part_style" or "deepmd" not found
964+
else:
965+
lmp_lines = revise_lmp_input_model(lmp_lines, task_model_list, trj_freq, deepmd_version = deepmd_version)
966+
951967
lmp_lines = revise_lmp_input_dump(lmp_lines, trj_freq)
952968
lmp_lines = revise_by_keys(
953969
lmp_lines, total_rev_keys[:total_num_lmp], total_rev_item[:total_num_lmp]

0 commit comments

Comments
 (0)