Commit fb8668b

add an option to merge data to one H5 file (#1119)
Fix #617. Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 3e1891d commit fb8668b
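
In user terms, the change adds a single boolean switch to the DP-GEN parameter file. A minimal sketch of the relevant param.json fragment (surrounding keys are illustrative; only `one_h5` is defined by this commit):

    {
        "numb_models": 4,
        "one_h5": true
    }

With the option enabled, make_train merges all training systems into one data.hdf5 under the train work path, and run_train forwards only that file instead of the individual set.*/raw files.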

File tree

4 files changed (+142, -21 lines):

dpgen/generator/arginfo.py
dpgen/generator/run.py
dpgen/util.py
tests/generator/test_make_train.py

dpgen/generator/arginfo.py

Lines changed: 2 additions & 0 deletions

@@ -80,6 +80,7 @@ def training_args() -> List[Argument]:
     doc_training_reuse_start_pref_f = "The prefactor of force loss at the start of the training." + doc_reusing
     doc_model_devi_activation_func = "The activation function in the model. The shape of list should be (N_models, 2), where 2 represents the embedding and fitting network. This option will override default parameters."
     doc_srtab_file_path = 'The path of the table for the short-range pairwise interaction which is needed when using DP-ZBL potential'
+    doc_one_h5 = "Before training, all of the training data will be merged into one HDF5 file."
 
     return [
         Argument("numb_models", int, optional=False, doc=doc_numb_models),
@@ -100,6 +101,7 @@ def training_args() -> List[Argument]:
         Argument("model_devi_activation_func", [None, list], optional=True, doc=doc_model_devi_activation_func),
         Argument("srtab_file_path",str,optional=True,
                  doc=doc_srtab_file_path),
+        Argument("one_h5", bool, optional=True, default=False, doc=doc_one_h5),
     ]
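Because the option is declared through dargs like every other training argument, its default is applied automatically during normalization. A minimal sketch (hypothetical driver code, not part of this commit; other required training arguments would be supplied as usual):

    from dargs import Argument
    from dpgen.generator.arginfo import training_args

    # Wrap the training arguments and normalize a bare input dict;
    # normalize_value fills in the declared default for "one_h5".
    base = Argument("training", dict, training_args())
    data = base.normalize_value({"numb_models": 4}, trim_pattern="_*")
    print(data["one_h5"])  # -> False unless explicitly enabled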

dpgen/generator/run.py

Lines changed: 28 additions & 21 deletions

@@ -62,7 +62,7 @@
 from dpgen.generator.lib.ele_temp import NBandsEsti
 from dpgen.remote.decide_machine import convert_mdata
 from dpgen.dispatcher.Dispatcher import make_submission
-from dpgen.util import sepline, expand_sys_str, normalize
+from dpgen.util import sepline, expand_sys_str, normalize, convert_training_data_to_hdf5
 from dpgen import ROOT_PATH
 from pymatgen.io.vasp import Incar,Kpoints,Potcar
 from dpgen.auto_test.lib.vasp import make_kspacing_kpoints
@@ -385,7 +385,7 @@ def make_train (iter_index,
         jinput['loss']['start_pref_f'] = training_reuse_start_pref_f
         jinput['learning_rate']['start_lr'] = training_reuse_start_lr
 
-
+    input_files = []
     for ii in range(numb_models) :
         task_path = os.path.join(work_path, train_task_fmt % ii)
         create_path(task_path)
@@ -429,6 +429,7 @@ def make_train (iter_index,
         # dump the input.json
         with open(os.path.join(task_path, train_input_file), 'w') as outfile:
             json.dump(jinput, outfile, indent = 4)
+        input_files.append(os.path.join(task_path, train_input_file))
 
         # link old models
         if iter_index > 0 :
@@ -454,7 +455,9 @@ def make_train (iter_index,
                 _link_old_models(work_path, old_model_files, ii)
     # Copy user defined forward files
     symlink_user_forward_files(mdata=mdata, task_type="train", work_path=work_path)
-
+    # HDF5 format for training data
+    if jdata.get('one_h5', False):
+        convert_training_data_to_hdf5(input_files, os.path.join(work_path, "data.hdf5"))
 
 
 def _link_old_models(work_path, old_model_files, ii):
@@ -568,24 +571,28 @@ def run_train (iter_index,
         backward_files+= ['model.ckpt.meta', 'model.ckpt.index', 'model.ckpt.data-00000-of-00001', 'checkpoint']
     if jdata.get("dp_compress", False):
         backward_files.append('frozen_model_compressed.pb')
-    init_data_sys_ = jdata['init_data_sys']
-    init_data_sys = []
-    for ii in init_data_sys_ :
-        init_data_sys.append(os.path.join('data.init', ii))
-    trans_comm_data = []
-    cwd = os.getcwd()
-    os.chdir(work_path)
-    fp_data = glob.glob(os.path.join('data.iters', 'iter.*', '02.fp', 'data.*'))
-    for ii in itertools.chain(init_data_sys, fp_data) :
-        sys_paths = expand_sys_str(ii)
-        for single_sys in sys_paths:
-            if "#" not in single_sys:
-                trans_comm_data += glob.glob(os.path.join(single_sys, 'set.*'))
-                trans_comm_data += glob.glob(os.path.join(single_sys, 'type*.raw'))
-                trans_comm_data += glob.glob(os.path.join(single_sys, 'nopbc'))
-            else:
-                # H5 file
-                trans_comm_data.append(single_sys.split("#")[0])
+    if not jdata.get('one_h5', False):
+        init_data_sys_ = jdata['init_data_sys']
+        init_data_sys = []
+        for ii in init_data_sys_ :
+            init_data_sys.append(os.path.join('data.init', ii))
+        trans_comm_data = []
+        cwd = os.getcwd()
+        os.chdir(work_path)
+        fp_data = glob.glob(os.path.join('data.iters', 'iter.*', '02.fp', 'data.*'))
+        for ii in itertools.chain(init_data_sys, fp_data) :
+            sys_paths = expand_sys_str(ii)
+            for single_sys in sys_paths:
+                if "#" not in single_sys:
+                    trans_comm_data += glob.glob(os.path.join(single_sys, 'set.*'))
+                    trans_comm_data += glob.glob(os.path.join(single_sys, 'type*.raw'))
+                    trans_comm_data += glob.glob(os.path.join(single_sys, 'nopbc'))
+                else:
+                    # H5 file
+                    trans_comm_data.append(single_sys.split("#")[0])
+    else:
+        cwd = os.getcwd()
+        trans_comm_data = ["data.hdf5"]
     # remove duplicated files
     trans_comm_data = list(set(trans_comm_data))
     os.chdir(cwd)
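
Both branches of run_train rely on the file#group convention that dpdata and DeePMD-kit use for HDF5 paths: the part before "#" names the file to transfer, the part after it names a group inside that file. A minimal sketch of reading one merged system back (group path illustrative, taken from the test below):

    import dpdata

    # "<file>#<group>" addresses a single system inside the merged file.
    s = dpdata.LabeledSystem("data.hdf5#/data.init/deepmd", fmt="deepmd/hdf5")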

dpgen/util.py

Lines changed: 60 additions & 0 deletions

@@ -1,9 +1,12 @@
 #!/usr/bin/env python
 # coding: utf-8
+import json
+import os
 from typing import Union, List
 from pathlib import Path
 
 import h5py
+import dpdata
 from dargs import Argument
 
 from dpgen import dlog
@@ -83,3 +86,60 @@ def normalize(arginfo: Argument, data: dict, strict_check: bool = True) -> dict:
     data = arginfo.normalize_value(data, trim_pattern="_*")
     arginfo.check_value(data, strict=strict_check)
     return data
+
+
+def convert_training_data_to_hdf5(input_files: List[str], h5_file: str):
+    """Convert training data to HDF5 format and update the input files.
+
+    Parameters
+    ----------
+    input_files : list of str
+        DeePMD-kit input file names
+    h5_file : str
+        HDF5 file name
+    """
+    systems = []
+    h5_dir = Path(h5_file).parent.absolute()
+    cwd = Path.cwd().absolute()
+    for ii in input_files:
+        ii = Path(ii)
+        dd = ii.parent.absolute()
+        with open(ii, 'r+') as f:
+            jinput = json.load(f)
+            if 'training_data' in jinput['training']:
+                # v2.0
+                p_sys = jinput['training']['training_data']['systems']
+            else:
+                # v1.x
+                p_sys = jinput['training']['systems']
+            for ii, pp in enumerate(p_sys):
+                if "#" in pp:
+                    # HDF5 file
+                    p1, p2 = pp.split("#")
+                    ff = os.path.normpath(str((dd / p1).absolute().relative_to(cwd)))
+                    pp = ff + "#" + p2
+                    new_pp = os.path.normpath(os.path.relpath(ff, h5_dir)) + "/" + p2
+                else:
+                    pp = os.path.normpath(str((dd / pp).absolute().relative_to(cwd)))
+                    new_pp = os.path.normpath(os.path.relpath(pp, h5_dir))
+                p_sys[ii] = os.path.normpath(os.path.relpath(h5_file, dd)) + "#/" + str(new_pp)
+                systems.append(pp)
+            f.seek(0)
+            json.dump(jinput, f, indent=4)
+    systems = list(set(systems))
+
+    dlog.info("Combining %d training systems to %s...", len(systems), h5_file)
+
+    with h5py.File(h5_file, 'w') as f:
+        for ii in systems:
+            if "#" in ii:
+                p1, p2 = ii.split("#")
+                p1 = os.path.normpath(os.path.relpath(p1, h5_dir))
+                group = f.create_group(str(p1) + "/" + p2)
+                s = dpdata.LabeledSystem(ii, fmt="deepmd/hdf5")
+                s.to("deepmd/hdf5", group)
+            else:
+                pp = os.path.normpath(os.path.relpath(ii, h5_dir))
+                group = f.create_group(str(pp))
+                s = dpdata.LabeledSystem(ii, fmt="deepmd/npy")
+                s.to("deepmd/hdf5", group)

tests/generator/test_make_train.py

Lines changed: 52 additions & 0 deletions

@@ -362,6 +362,58 @@ def test_1_data_v1_h5(self) :
         shutil.rmtree('iter.000000')
         os.remove('data/deepmd.hdf5')
 
+    def test_1_data_v1_one_h5(self) :
+        """Test `one_h5` option."""
+        dpdata.LabeledSystem("data/deepmd", fmt='deepmd/npy').to_deepmd_hdf5('data/deepmd.hdf5')
+        with open (param_file_v1, 'r') as fp :
+            jdata = json.load (fp)
+        jdata.pop('use_ele_temp', None)
+        jdata['init_data_sys'].append('deepmd.hdf5')
+        jdata['init_batch_size'].append('auto')
+        jdata['one_h5'] = True
+        with open (machine_file_v1, 'r') as fp:
+            mdata = json.load (fp)
+        make_train(0, jdata, mdata)
+        # make fake fp results #data == fp_task_min
+        _make_fake_fp(0, 0, jdata['fp_task_min'])
+        # make iter1 train
+        make_train(1, jdata, mdata)
+        # check data is linked
+        self.assertTrue(os.path.isdir(os.path.join('iter.000001', '00.train', 'data.iters', 'iter.000000', '02.fp')))
+        # check models inputs
+        with open(os.path.join('iter.%06d' % 1,
+                               '00.train',
+                               '%03d' % 0,
+                               "input.json")) as fp:
+            jdata0 = json.load(fp)
+        self.assertEqual(jdata0['training']['systems'], [
+            '../data.hdf5#/data.init/deepmd',
+            '../data.hdf5#/data.init/deepmd.hdf5/',
+            '../data.hdf5#/data.iters/iter.000000/02.fp/data.000',
+        ])
+        # test run_train -- confirm transferred files are correct
+        with tempfile.TemporaryDirectory() as remote_root:
+            run_train(1, jdata, {
+                "api_version": "1.0",
+                "train_command": (
+                    "test -f ../data.hdf5"
+                    "&& touch frozen_model.pb lcurve.out model.ckpt.meta model.ckpt.index model.ckpt.data-00000-of-00001 checkpoint"
+                    "&& echo dp"
+                ),
+                "train_machine": {
+                    "batch_type": "shell",
+                    "local_root": "./",
+                    "remote_root": remote_root,
+                    "context_type": "local",
+                },
+                "train_resources": {
+                    "group_size": 1,
+                },
+            })
+
+        # remove testing dirs
+        shutil.rmtree('iter.000001')
+        shutil.rmtree('iter.000000')
 
 if __name__ == '__main__':
     unittest.main()
