|
37 | 37 | detail_file_name_prefix = "details" |
38 | 38 |
|
39 | 39 |
|
| 40 | +def get_system_cls(jdata): |
| 41 | + if jdata.get("labeled", False): |
| 42 | + return dpdata.LabeledSystem |
| 43 | + return dpdata.System |
| 44 | + |
| 45 | + |
| 46 | +def get_systems(path, jdata): |
| 47 | + system = get_system_cls(jdata) |
| 48 | + systems = dpdata.MultiSystems( |
| 49 | + *[system(os.path.join(path, s), fmt='deepmd/npy') for s in os.listdir(path)]) |
| 50 | + return systems |
| 51 | + |
| 52 | + |
40 | 53 | def init_pick(iter_index, jdata, mdata): |
41 | 54 | """pick up init data from dataset randomly""" |
42 | 55 | pick_data = jdata['pick_data'] |
43 | 56 | init_pick_number = jdata['init_pick_number'] |
44 | 57 | # use MultiSystems with System |
45 | 58 | # TODO: support System and LabeledSystem |
46 | | - # TODO: support MultiSystems with LabeledSystem |
47 | 59 | # TODO: support other format |
48 | | - systems = dpdata.MultiSystems( |
49 | | - *[dpdata.System(os.path.join(pick_data, s), fmt='deepmd/npy') for s in os.listdir(pick_data)]) |
| 60 | + systems = get_systems(pick_data, jdata) |
50 | 61 | # label the system |
51 | 62 | labels = [] |
52 | 63 | for key, system in systems.systems.items(): |
@@ -135,7 +146,7 @@ def run_model_devi(iter_index, jdata, mdata, dispatcher): |
135 | 146 | ) |
136 | 147 | # TODO: support 0.x? |
137 | 148 | command = "{python} -m deepmd test -m {model} -s {system} -n {numb_test} -d {detail_file}".format( |
138 | | - python=mdata['python_path'], |
| 149 | + python=mdata['python_test_path'], |
139 | 150 | model=mm, |
140 | 151 | system=rest_data_name, |
141 | 152 | numb_test=data_size, |
@@ -197,7 +208,8 @@ def post_model_devi(iter_index, jdata, mdata): |
197 | 208 | f_std = np.max(f_std, axis=1) |
198 | 209 | # (n_frame,) |
199 | 210 |
|
200 | | - for subsys, e_devi, f_devi in zip(dpdata.System(os.path.join(task, rest_data_name), fmt='deepmd/npy'), e_std, f_std): |
| 211 | + system_cls = get_system_cls(jdata) |
| 212 | + for subsys, e_devi, f_devi in zip(system_cls(os.path.join(task, rest_data_name), fmt='deepmd/npy'), e_std, f_std): |
201 | 213 | if (e_devi < e_trust_hi and e_devi >= e_trust_lo) or (f_devi < f_trust_hi and f_devi >= f_trust_lo) : |
202 | 214 | sys_candinate.append(subsys) |
203 | 215 | elif (e_devi >= e_trust_hi ) or (f_devi >= f_trust_hi ): |
@@ -251,8 +263,14 @@ def make_fp(iter_index, jdata, mdata): |
251 | 263 | work_path = os.path.join(iter_name, fp_name) |
252 | 264 | create_path(work_path) |
253 | 265 | picked_data_path = os.path.join(iter_name, model_devi_name, picked_data_name) |
254 | | - systems = dpdata.MultiSystems( |
255 | | - *[dpdata.System(os.path.join(picked_data_path, s), fmt='deepmd/npy') for s in os.listdir(picked_data_path)]) |
| 266 | + if jdata.get("labeled", False): |
| 267 | + dlog.info("already labeled, skip make_fp and link data directly") |
| 268 | + os.symlink(os.path.abspath(picked_data_path), os.path.abspath( |
| 269 | + os.path.join(work_path, "task.%03d" % 0))) |
| 270 | + os.symlink(os.path.abspath(picked_data_path), os.path.abspath( |
| 271 | + os.path.join(work_path, "data.%03d" % 0))) |
| 272 | + return |
| 273 | + systems = get_systems(picked_data_path, jdata) |
256 | 274 | fp_style = jdata['fp_style'] |
257 | 275 | if 'user_fp_params' in jdata.keys() : |
258 | 276 | fp_params = jdata['user_fp_params'] |
@@ -378,12 +396,18 @@ def run_iter(param_file, machine_file): |
378 | 396 | make_fp(ii, jdata, mdata) |
379 | 397 | elif jj == 7: |
380 | 398 | log_iter("run_fp", ii, jj) |
381 | | - mdata = decide_fp_machine(mdata) |
382 | | - disp = make_dispatcher(mdata['fp_machine']) |
383 | | - run_fp(ii, jdata, mdata, disp) |
| 399 | + if jdata.get("labeled", False): |
| 400 | + dlog.info("already have labeled data, skip run_fp") |
| 401 | + else: |
| 402 | + mdata = decide_fp_machine(mdata) |
| 403 | + disp = make_dispatcher(mdata['fp_machine']) |
| 404 | + run_fp(ii, jdata, mdata, disp) |
384 | 405 | elif jj == 8: |
385 | 406 | log_iter("post_fp", ii, jj) |
386 | | - post_fp(ii, jdata) |
| 407 | + if jdata.get("labeled", False): |
| 408 | + dlog.info("already have labeled data, skip post_fp") |
| 409 | + else: |
| 410 | + post_fp(ii, jdata) |
387 | 411 | else: |
388 | 412 | raise RuntimeError("unknown task %d, something wrong" % jj) |
389 | 413 | record_iter(record, ii, jj) |
|
0 commit comments