6262from dpgen .generator .lib .ele_temp import NBandsEsti
6363from dpgen .remote .decide_machine import convert_mdata
6464from dpgen .dispatcher .Dispatcher import make_submission
65- from dpgen .util import sepline , expand_sys_str , normalize
65+ from dpgen .util import sepline , expand_sys_str , normalize , convert_training_data_to_hdf5
6666from dpgen import ROOT_PATH
6767from pymatgen .io .vasp import Incar ,Kpoints ,Potcar
6868from dpgen .auto_test .lib .vasp import make_kspacing_kpoints
@@ -385,7 +385,7 @@ def make_train (iter_index,
385385 jinput ['loss' ]['start_pref_f' ] = training_reuse_start_pref_f
386386 jinput ['learning_rate' ]['start_lr' ] = training_reuse_start_lr
387387
388-
388+ input_files = []
389389 for ii in range (numb_models ) :
390390 task_path = os .path .join (work_path , train_task_fmt % ii )
391391 create_path (task_path )
@@ -429,6 +429,7 @@ def make_train (iter_index,
429429 # dump the input.json
430430 with open (os .path .join (task_path , train_input_file ), 'w' ) as outfile :
431431 json .dump (jinput , outfile , indent = 4 )
432+ input_files .append (os .path .join (task_path , train_input_file ))
432433
433434 # link old models
434435 if iter_index > 0 :
@@ -454,7 +455,9 @@ def make_train (iter_index,
454455 _link_old_models (work_path , old_model_files , ii )
455456 # Copy user defined forward files
456457 symlink_user_forward_files (mdata = mdata , task_type = "train" , work_path = work_path )
457-
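458+ # If 'one_h5' is set in jdata, convert the training data referenced by the dumped input files into a single "data.hdf5" under work_path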
459+ if jdata .get ('one_h5' , False ):
460+ convert_training_data_to_hdf5 (input_files , os .path .join (work_path , "data.hdf5" ))
458461
459462
460463def _link_old_models (work_path , old_model_files , ii ):
@@ -568,24 +571,28 @@ def run_train (iter_index,
568571 backward_files += ['model.ckpt.meta' , 'model.ckpt.index' , 'model.ckpt.data-00000-of-00001' , 'checkpoint' ]
569572 if jdata .get ("dp_compress" , False ):
570573 backward_files .append ('frozen_model_compressed.pb' )
571- init_data_sys_ = jdata ['init_data_sys' ]
572- init_data_sys = []
573- for ii in init_data_sys_ :
574- init_data_sys .append (os .path .join ('data.init' , ii ))
575- trans_comm_data = []
576- cwd = os .getcwd ()
577- os .chdir (work_path )
578- fp_data = glob .glob (os .path .join ('data.iters' , 'iter.*' , '02.fp' , 'data.*' ))
579- for ii in itertools .chain (init_data_sys , fp_data ) :
580- sys_paths = expand_sys_str (ii )
581- for single_sys in sys_paths :
582- if "#" not in single_sys :
583- trans_comm_data += glob .glob (os .path .join (single_sys , 'set.*' ))
584- trans_comm_data += glob .glob (os .path .join (single_sys , 'type*.raw' ))
585- trans_comm_data += glob .glob (os .path .join (single_sys , 'nopbc' ))
586- else :
587- # H5 file
588- trans_comm_data .append (single_sys .split ("#" )[0 ])
574+ if not jdata .get ('one_h5' , False ):
575+ init_data_sys_ = jdata ['init_data_sys' ]
576+ init_data_sys = []
577+ for ii in init_data_sys_ :
578+ init_data_sys .append (os .path .join ('data.init' , ii ))
579+ trans_comm_data = []
580+ cwd = os .getcwd ()
581+ os .chdir (work_path )
582+ fp_data = glob .glob (os .path .join ('data.iters' , 'iter.*' , '02.fp' , 'data.*' ))
583+ for ii in itertools .chain (init_data_sys , fp_data ) :
584+ sys_paths = expand_sys_str (ii )
585+ for single_sys in sys_paths :
586+ if "#" not in single_sys :
587+ trans_comm_data += glob .glob (os .path .join (single_sys , 'set.*' ))
588+ trans_comm_data += glob .glob (os .path .join (single_sys , 'type*.raw' ))
589+ trans_comm_data += glob .glob (os .path .join (single_sys , 'nopbc' ))
590+ else :
591+ # H5 file: the entry contains '#'; keep only the part before '#' (the file itself)
592+ trans_comm_data .append (single_sys .split ("#" )[0 ])
593+ else :
594+ cwd = os .getcwd ()
595+ trans_comm_data = ["data.hdf5" ]
589596 # remove duplicated files
590597 trans_comm_data = list (set (trans_comm_data ))
591598 os .chdir (cwd )
0 commit comments