Skip to content

Commit ab4d05f

Browse files
authored
Merge pull request #194 from njzjz/labeled
support simplify LabeledSystem
2 parents 8a91740 + c2cf160 commit ab4d05f

File tree

1 file changed

+35
-11
lines changed

1 file changed

+35
-11
lines changed

dpgen/simplify/simplify.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,27 @@
3737
detail_file_name_prefix = "details"
3838

3939

40+
def get_system_cls(jdata):
41+
if jdata.get("labeled", False):
42+
return dpdata.LabeledSystem
43+
return dpdata.System
44+
45+
46+
def get_systems(path, jdata):
47+
system = get_system_cls(jdata)
48+
systems = dpdata.MultiSystems(
49+
*[system(os.path.join(path, s), fmt='deepmd/npy') for s in os.listdir(path)])
50+
return systems
51+
52+
4053
def init_pick(iter_index, jdata, mdata):
4154
"""pick up init data from dataset randomly"""
4255
pick_data = jdata['pick_data']
4356
init_pick_number = jdata['init_pick_number']
4457
# use MultiSystems with System
4558
# TODO: support System and LabeledSystem
46-
# TODO: support MultiSystems with LabeledSystem
4759
# TODO: support other format
48-
systems = dpdata.MultiSystems(
49-
*[dpdata.System(os.path.join(pick_data, s), fmt='deepmd/npy') for s in os.listdir(pick_data)])
60+
systems = get_systems(pick_data, jdata)
5061
# label the system
5162
labels = []
5263
for key, system in systems.systems.items():
@@ -135,7 +146,7 @@ def run_model_devi(iter_index, jdata, mdata, dispatcher):
135146
)
136147
# TODO: support 0.x?
137148
command = "{python} -m deepmd test -m {model} -s {system} -n {numb_test} -d {detail_file}".format(
138-
python=mdata['python_path'],
149+
python=mdata['python_test_path'],
139150
model=mm,
140151
system=rest_data_name,
141152
numb_test=data_size,
@@ -197,7 +208,8 @@ def post_model_devi(iter_index, jdata, mdata):
197208
f_std = np.max(f_std, axis=1)
198209
# (n_frame,)
199210

200-
for subsys, e_devi, f_devi in zip(dpdata.System(os.path.join(task, rest_data_name), fmt='deepmd/npy'), e_std, f_std):
211+
system_cls = get_system_cls(jdata)
212+
for subsys, e_devi, f_devi in zip(system_cls(os.path.join(task, rest_data_name), fmt='deepmd/npy'), e_std, f_std):
201213
if (e_devi < e_trust_hi and e_devi >= e_trust_lo) or (f_devi < f_trust_hi and f_devi >= f_trust_lo) :
202214
sys_candinate.append(subsys)
203215
elif (e_devi >= e_trust_hi ) or (f_devi >= f_trust_hi ):
@@ -251,8 +263,14 @@ def make_fp(iter_index, jdata, mdata):
251263
work_path = os.path.join(iter_name, fp_name)
252264
create_path(work_path)
253265
picked_data_path = os.path.join(iter_name, model_devi_name, picked_data_name)
254-
systems = dpdata.MultiSystems(
255-
*[dpdata.System(os.path.join(picked_data_path, s), fmt='deepmd/npy') for s in os.listdir(picked_data_path)])
266+
if jdata.get("labeled", False):
267+
dlog.info("already labeled, skip make_fp and link data directly")
268+
os.symlink(os.path.abspath(picked_data_path), os.path.abspath(
269+
os.path.join(work_path, "task.%03d" % 0)))
270+
os.symlink(os.path.abspath(picked_data_path), os.path.abspath(
271+
os.path.join(work_path, "data.%03d" % 0)))
272+
return
273+
systems = get_systems(picked_data_path, jdata)
256274
fp_style = jdata['fp_style']
257275
if 'user_fp_params' in jdata.keys() :
258276
fp_params = jdata['user_fp_params']
@@ -378,12 +396,18 @@ def run_iter(param_file, machine_file):
378396
make_fp(ii, jdata, mdata)
379397
elif jj == 7:
380398
log_iter("run_fp", ii, jj)
381-
mdata = decide_fp_machine(mdata)
382-
disp = make_dispatcher(mdata['fp_machine'])
383-
run_fp(ii, jdata, mdata, disp)
399+
if jdata.get("labeled", False):
400+
dlog.info("already have labeled data, skip run_fp")
401+
else:
402+
mdata = decide_fp_machine(mdata)
403+
disp = make_dispatcher(mdata['fp_machine'])
404+
run_fp(ii, jdata, mdata, disp)
384405
elif jj == 8:
385406
log_iter("post_fp", ii, jj)
386-
post_fp(ii, jdata)
407+
if jdata.get("labeled", False):
408+
dlog.info("already have labeled data, skip post_fp")
409+
else:
410+
post_fp(ii, jdata)
387411
else:
388412
raise RuntimeError("unknown task %d, something wrong" % jj)
389413
record_iter(record, ii, jj)

0 commit comments

Comments
 (0)