Skip to content

Commit 43d1a0d

Browse files
authored
support dp compress (#607)
Set `dp_compress` to `true` in parameters will enable model compression.
1 parent df9095d commit 43d1a0d

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,7 @@ The bold notation of key (such aas **type_map**) means that it's a necessary key
545545
| training_iter0_model_path | list of string | ["/path/to/model0_ckpt/", ...] | The model used to init the first iter training. Number of element should be equal to `numb_models` |
546546
| training_init_model | bool | False | Iteration > 0, the model parameters will be initilized from the model trained at the previous iteration. Iteration == 0, the model parameters will be initialized from `training_iter0_model_path`. |
547547
| **default_training_param** | Dict | | Training parameters for `deepmd-kit` in `00.train`. <br /> You can find instructions from here: (https://github.com/deepmodeling/deepmd-kit)..<br /> |
548+
| dp_compress | bool | false | Use `dp compress` to compress the model. Default is false. |
548549
| *#Exploration*
549550
| **model_devi_dt** | Float | 0.002 (recommend) | Timestep for MD |
550551
| **model_devi_skip** | Integer | 0 | Number of structures skipped for fp in each MD

dpgen/generator/run.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,8 @@ def run_train (iter_index,
514514
commands.append(command)
515515
command = '%s freeze' % train_command
516516
commands.append(command)
517+
if jdata.get("dp_compress", False):
518+
commands.append("%s compress" % train_command)
517519
else:
518520
raise RuntimeError("DP-GEN currently only supports for DeePMD-kit 1.x or 2.x version!" )
519521

@@ -536,6 +538,8 @@ def run_train (iter_index,
536538
]
537539
backward_files = ['frozen_model.pb', 'lcurve.out', 'train.log']
538540
backward_files+= ['model.ckpt.meta', 'model.ckpt.index', 'model.ckpt.data-00000-of-00001', 'checkpoint']
541+
if jdata.get("dp_compress", False):
542+
backward_files.append('frozen_model_compressed.pb')
539543
init_data_sys_ = jdata['init_data_sys']
540544
init_data_sys = []
541545
for ii in init_data_sys_ :
@@ -621,7 +625,11 @@ def post_train (iter_index,
621625
return
622626
# symlink models
623627
for ii in range(numb_models) :
624-
task_file = os.path.join(train_task_fmt % ii, 'frozen_model.pb')
628+
if not jdata.get("dp_compress", False):
629+
model_name = 'frozen_model.pb'
630+
else:
631+
model_name = 'frozen_model_compressed.pb'
632+
task_file = os.path.join(train_task_fmt % ii, model_name)
625633
ofile = os.path.join(work_path, 'graph.%03d.pb' % ii)
626634
if os.path.isfile(ofile) :
627635
os.remove(ofile)

0 commit comments

Comments
 (0)