Skip to content

Commit 11ce228

Browse files
amcadmusHan Wang
andauthored
Adaptive trust levels (#495)
* support adaptive trust level * update README * fix bugs in readme * adaptive lower trust level support percentage of total number of frames Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent 9108224 commit 11ce228

File tree

2 files changed

+196
-54
lines changed

2 files changed

+196
-54
lines changed

README.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -549,8 +549,13 @@ The bold notation of key (such as **type_map**) means that it's a necessary key
549549
| **model_devi_skip** | Integer | 0 | Number of structures skipped for fp in each MD
550550
| **model_devi_f_trust_lo** | Float | 0.05 | Lower bound of forces for the selection.
551551
| **model_devi_f_trust_hi** | Float | 0.15 | Upper bound of forces for the selection
552-
| **model_devi_e_trust_lo** | Float | 1e10 | Lower bound of energies for the selection. Recommend to set them a high number, since forces provide more precise information. Special cases such as energy minimization may need this. |
553-
| **model_devi_e_trust_hi** | Float | 1e10 | Upper bound of energies for the selection. |
552+
| **model_devi_v_trust_lo** | Float | 1e10 | Lower bound of virial for the selection. Should be used with DeePMD-kit v2.x |
553+
| **model_devi_v_trust_hi** | Float | 1e10 | Upper bound of virial for the selection. Should be used with DeePMD-kit v2.x |
554+
| model_devi_adapt_trust_lo | Boolean | False | Adaptively determines the lower trust levels of force and virial. This option should be used together with `model_devi_numb_candi_f`, `model_devi_numb_candi_v` and optionally with `model_devi_perc_candi_f` and `model_devi_perc_candi_v`. `dpgen` will make two sets: 1. From the frames with force model deviation lower than `model_devi_f_trust_hi`, select the `max(model_devi_numb_candi_f, model_devi_perc_candi_f*0.01*n_frames)` frames with the largest force model deviation. 2. From the frames with virial model deviation lower than `model_devi_v_trust_hi`, select the `max(model_devi_numb_candi_v, model_devi_perc_candi_v*0.01*n_frames)` frames with the largest virial model deviation. Here `n_frames` is the number of frames that are not classified as failed, and `model_devi_perc_candi_f` / `model_devi_perc_candi_v` are given in percent. The union of the two sets is used as the candidate dataset. |
555+
| model_devi_numb_candi_f | Int | 10 | See `model_devi_adapt_trust_lo`.|
556+
| model_devi_numb_candi_v | Int | 0 | See `model_devi_adapt_trust_lo`.|
557+
| model_devi_perc_candi_f | Float | 0.0 | See `model_devi_adapt_trust_lo`.|
558+
| model_devi_perc_candi_v | Float | 0.0 | See `model_devi_adapt_trust_lo`.|
554559
| **model_devi_clean_traj** | Boolean | true | Deciding whether to clean traj folders in MD since they are too large. |
555560
| **model_devi_nopbc** | Boolean | False | Assume open boundary condition in MD simulations. |
556561
| model_devi_activation_func | List of list of string | [["tanh","tanh"],["tanh","gelu"],["gelu","tanh"],["gelu","gelu"]] | Set activation functions for models, length of the List should be the same as `numb_models`, and two elements in the list of string respectively assign activation functions to the embedding and fitting nets within each model. *Backward compatibility*: the original "List of String" format is still supported, where embedding and fitting nets of one model use the same activation function, and the length of the List should be the same as `numb_models`|

dpgen/generator/run.py

Lines changed: 189 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import scipy.constants as pc
2929
from collections import Counter
3030
from distutils.version import LooseVersion
31+
from typing import List
3132
from numpy.linalg import norm
3233
from dpgen import dlog
3334
from dpgen import SHORT_CMD
@@ -1321,11 +1322,169 @@ def check_bad_box(conf_name,
13211322
raise RuntimeError('unknow key', key)
13221323
return is_bad
13231324

1325+
1326+
def _select_by_model_devi_standard(
        modd_system_task: List[str],
        f_trust_lo: float,
        f_trust_hi: float,
        v_trust_lo: float,
        v_trust_hi: float,
        cluster_cutoff: float,
        model_devi_skip: int = 0,
        detailed_report_make_fp: bool = True,
):
    """
    Classify MD frames of one system with fixed lower/upper trust levels.

    Every task folder in `modd_system_task` must contain a `model_devi.out`
    file.  Column 0 of that file is the time step of the frame, column 1 is
    compared against the virial trust levels and column 4 against the force
    trust levels.  When `cluster_cutoff` is not None, the columns from index
    7 on are compared element-wise against the force trust levels instead
    (presumably per-atom force deviations -- confirm against the writer of
    `model_devi.out`).

    modd_system_task        model deviation tasks belonging to one system
    f_trust_lo              lower trust level of f
    f_trust_hi              upper trust level of f
    v_trust_lo              lower trust level of v
    v_trust_hi              upper trust level of v
    cluster_cutoff          None for whole-frame selection, otherwise the
                            per-column (cluster) selection is used
    model_devi_skip         frames whose time step is smaller are ignored
    detailed_report_make_fp if False, the accurate and failed frames are only
                            counted, not collected

    returns
    fp_rest_accurate        the accurate set (empty if not detailed report)
    fp_candidate            the candidate set
    fp_rest_failed          the failed set (empty if not detailed report)
    counter                 counters, number of elements in the sets

    Entries are `[task, step]`, or `[task, step, column_index]` in cluster
    mode.  Raises RuntimeError if a frame matches none of the three classes.
    """
    fp_candidate = []
    # Always create the two lists: they are returned unconditionally, so
    # leaving them undefined when detailed_report_make_fp is False would
    # raise UnboundLocalError at the return statement.
    fp_rest_accurate = []
    fp_rest_failed = []
    counter = Counter()
    counter['candidate'] = 0
    counter['failed'] = 0
    counter['accurate'] = 0
    for tt in modd_system_task:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            all_conf = np.loadtxt(os.path.join(tt, 'model_devi.out'))
            if all_conf.ndim == 1 and all_conf.size > 0:
                # a file holding a single frame is parsed as a 1-D row;
                # promote it so that the row loop below still works
                all_conf = all_conf[np.newaxis, :]
            for ii in range(all_conf.shape[0]):
                if all_conf[ii][0] < model_devi_skip:
                    continue
                # cc: time step of the frame
                cc = int(all_conf[ii][0])
                if cluster_cutoff is None:
                    devi_v = all_conf[ii][1]
                    devi_f = all_conf[ii][4]
                    if (devi_v < v_trust_hi and devi_v >= v_trust_lo) or \
                       (devi_f < f_trust_hi and devi_f >= f_trust_lo):
                        fp_candidate.append([tt, cc])
                        counter['candidate'] += 1
                    elif (devi_v >= v_trust_hi) or (devi_f >= f_trust_hi):
                        if detailed_report_make_fp:
                            fp_rest_failed.append([tt, cc])
                        counter['failed'] += 1
                    elif (devi_v < v_trust_lo and devi_f < f_trust_lo):
                        if detailed_report_make_fp:
                            fp_rest_accurate.append([tt, cc])
                        counter['accurate'] += 1
                    else:
                        raise RuntimeError('md traj %s frame %d with f devi %f does not belong to either accurate, candidiate and failed, it should not happen' % (tt, ii, all_conf[ii][4]))
                else:
                    devi_cols = all_conf[ii][7:]
                    idx_candidate = np.where(np.logical_and(devi_cols < f_trust_hi, devi_cols >= f_trust_lo))[0]
                    for jj in idx_candidate:
                        fp_candidate.append([tt, cc, jj])
                    counter['candidate'] += len(idx_candidate)
                    idx_rest_accurate = np.where(devi_cols < f_trust_lo)[0]
                    if detailed_report_make_fp:
                        for jj in idx_rest_accurate:
                            fp_rest_accurate.append([tt, cc, jj])
                    counter['accurate'] += len(idx_rest_accurate)
                    idx_rest_failed = np.where(devi_cols >= f_trust_hi)[0]
                    if detailed_report_make_fp:
                        for jj in idx_rest_failed:
                            fp_rest_failed.append([tt, cc, jj])
                    counter['failed'] += len(idx_rest_failed)

    return fp_rest_accurate, fp_candidate, fp_rest_failed, counter
1385+
1386+
1387+
1388+
def _select_by_model_devi_adaptive_trust_low(
        modd_system_task: List[str],
        f_trust_hi: float,
        numb_candi_f: int,
        perc_candi_f: float,
        v_trust_hi: float,
        numb_candi_v: int,
        perc_candi_v: float,
        model_devi_skip: int = 0
):
    """
    Classify MD frames of one system, adapting the lower trust levels so
    that a requested number of candidates is selected.

    Frames whose f deviation exceeds `f_trust_hi` or whose v deviation
    exceeds `v_trust_hi` are failed.  Among the remaining frames, the
    `max(numb_candi_f, perc_candi_f * 0.01 * n)` frames with the largest f
    deviation and the `max(numb_candi_v, perc_candi_v * 0.01 * n)` frames
    with the largest v deviation are candidates (n: number of non-failed
    frames); the union of both selections is the candidate set, everything
    else is accurate.

    modd_system_task    model deviation tasks belonging to one system
    f_trust_hi          upper trust level of f
    numb_candi_f        number of candidate due to the f model deviation
    perc_candi_f        percentage of candidate due to the f model deviation
    v_trust_hi          upper trust level of v
    numb_candi_v        number of candidate due to the v model deviation
    perc_candi_v        percentage of candidate due to the v model deviation
    model_devi_skip     frames whose time step is smaller are ignored

    returns
    accur               the accurate set, sorted list of [task, step]
    candi               the candidate set, sorted list of [task, step]
    failed              the failed set, list of [task, step]
    counter             counters, number of elements in the sets
    f_trust_lo          adapted trust level of f
    v_trust_lo          adapted trust level of v
    """
    # columns of model_devi.out used for the v and the f deviation
    idx_v = 1
    idx_f = 4
    accur = set()
    candi = set()
    failed = []
    coll_v = []
    coll_f = []
    for tt in modd_system_task:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model_devi = np.loadtxt(os.path.join(tt, 'model_devi.out'))
            if model_devi.ndim == 1 and model_devi.size > 0:
                # a file holding a single frame is parsed as a 1-D row;
                # promote it so that the row loop below still works
                model_devi = model_devi[np.newaxis, :]
            for ii in range(model_devi.shape[0]):
                if model_devi[ii][0] < model_devi_skip:
                    continue
                cc = int(model_devi[ii][0])
                # tt: name of task folder
                # cc: time step of the frame
                md_v = model_devi[ii][idx_v]
                md_f = model_devi[ii][idx_f]
                # NOTE: strict comparison, i.e. a deviation exactly equal to
                # the upper trust level is not treated as failed here
                if md_f > f_trust_hi or md_v > v_trust_hi:
                    failed.append([tt, cc])
                else:
                    coll_v.append([md_v, tt, cc])
                    coll_f.append([md_f, tt, cc])
                    # now accur takes all non-failed frames,
                    # will be subtracted by candidate later
                    accur.add((tt, cc))
    # sort ascending by deviation, the largest deviations end up last
    coll_v.sort()
    coll_f.sort()
    assert(len(coll_v) == len(coll_f))
    # calculate numbers, the percentages are relative to non-failed frames
    numb_candi_v = max(numb_candi_v, int(perc_candi_v * 0.01 * len(coll_v)))
    numb_candi_f = max(numb_candi_f, int(perc_candi_f * 0.01 * len(coll_f)))
    # adjust number of candidate, cannot select more frames than available
    numb_candi_v = min(numb_candi_v, len(coll_v))
    numb_candi_f = min(numb_candi_f, len(coll_f))
    # compute trust lo: the smallest deviation that is still selected, or
    # the upper trust level if nothing is selected
    if numb_candi_v == 0:
        v_trust_lo = v_trust_hi
    else:
        v_trust_lo = coll_v[-numb_candi_v][0]
    if numb_candi_f == 0:
        f_trust_lo = f_trust_hi
    else:
        f_trust_lo = coll_f[-numb_candi_f][0]
    # add to candidate set; explicit start index because [-0:] would
    # select everything
    candi.update(tuple(rec[1:]) for rec in coll_v[len(coll_v) - numb_candi_v:])
    candi.update(tuple(rec[1:]) for rec in coll_f[len(coll_f) - numb_candi_f:])
    # accurate set is subtracted by the candidate set
    accur = accur - candi
    # convert to list; sorted because the iteration order of a set of
    # strings changes between interpreter runs (hash randomization), which
    # would make the result irreproducible
    candi = [list(item) for item in sorted(candi)]
    accur = [list(item) for item in sorted(accur)]
    # counters
    counter = Counter()
    counter['candidate'] = len(candi)
    counter['failed'] = len(failed)
    counter['accurate'] = len(accur)

    return accur, candi, failed, counter, f_trust_lo, v_trust_lo
1481+
1482+
13241483
def _make_fp_vasp_inner (modd_path,
13251484
work_path,
13261485
model_devi_skip,
1327-
e_trust_lo,
1328-
e_trust_hi,
1486+
v_trust_lo,
1487+
v_trust_hi,
13291488
f_trust_lo,
13301489
f_trust_hi,
13311490
fp_task_min,
@@ -1352,63 +1511,41 @@ def _make_fp_vasp_inner (modd_path,
13521511

13531512
fp_tasks = []
13541513
cluster_cutoff = jdata['cluster_cutoff'] if jdata.get('use_clusters', False) else None
1514+
model_devi_adapt_trust_lo = jdata.get('model_devi_adapt_trust_lo', False)
13551515
# skip save *.out if detailed_report_make_fp is False, default is True
13561516
detailed_report_make_fp = jdata.get("detailed_report_make_fp", True)
13571517
# skip bad box criteria
13581518
skip_bad_box = jdata.get('fp_skip_bad_box')
13591519
# skip discrete structure in cluster
13601520
fp_cluster_vacuum = jdata.get('fp_cluster_vacuum',None)
13611521
for ss in system_index :
1362-
fp_candidate = []
1363-
if detailed_report_make_fp:
1364-
fp_rest_accurate = []
1365-
fp_rest_failed = []
13661522
modd_system_glob = os.path.join(modd_path, 'task.' + ss + '.*')
13671523
modd_system_task = glob.glob(modd_system_glob)
13681524
modd_system_task.sort()
1369-
cc = 0
1370-
counter = Counter()
1371-
counter['candidate'] = 0
1372-
counter['failed'] = 0
1373-
counter['accurate'] = 0
1374-
for tt in modd_system_task :
1375-
with warnings.catch_warnings():
1376-
warnings.simplefilter("ignore")
1377-
all_conf = np.loadtxt(os.path.join(tt, 'model_devi.out'))
1378-
for ii in range(all_conf.shape[0]) :
1379-
if all_conf[ii][0] < model_devi_skip :
1380-
continue
1381-
cc = int(all_conf[ii][0])
1382-
if cluster_cutoff is None:
1383-
if (all_conf[ii][1] < e_trust_hi and all_conf[ii][1] >= e_trust_lo) or \
1384-
(all_conf[ii][4] < f_trust_hi and all_conf[ii][4] >= f_trust_lo) :
1385-
fp_candidate.append([tt, cc])
1386-
counter['candidate'] += 1
1387-
elif (all_conf[ii][1] >= e_trust_hi ) or (all_conf[ii][4] >= f_trust_hi ):
1388-
if detailed_report_make_fp:
1389-
fp_rest_failed.append([tt, cc])
1390-
counter['failed'] += 1
1391-
elif (all_conf[ii][1] < e_trust_lo and all_conf[ii][4] < f_trust_lo ):
1392-
if detailed_report_make_fp:
1393-
fp_rest_accurate.append([tt, cc])
1394-
counter['accurate'] += 1
1395-
else :
1396-
raise RuntimeError('md traj %s frame %d with f devi %f does not belong to either accurate, candidiate and failed, it should not happen' % (tt, ii, all_conf[ii][4]))
1397-
else:
1398-
idx_candidate = np.where(np.logical_and(all_conf[ii][7:] < f_trust_hi, all_conf[ii][7:] >= f_trust_lo))[0]
1399-
for jj in idx_candidate:
1400-
fp_candidate.append([tt, cc, jj])
1401-
counter['candidate'] += len(idx_candidate)
1402-
idx_rest_accurate = np.where(all_conf[ii][7:] < f_trust_lo)[0]
1403-
if detailed_report_make_fp:
1404-
for jj in idx_rest_accurate:
1405-
fp_rest_accurate.append([tt, cc, jj])
1406-
counter['accurate'] += len(idx_rest_accurate)
1407-
idx_rest_failed = np.where(all_conf[ii][7:] >= f_trust_hi)[0]
1408-
if detailed_report_make_fp:
1409-
for jj in idx_rest_failed:
1410-
fp_rest_failed.append([tt, cc, jj])
1411-
counter['failed'] += len(idx_rest_failed)
1525+
1526+
# assumed e -> v
1527+
if not model_devi_adapt_trust_lo:
1528+
fp_rest_accurate, fp_candidate, fp_rest_failed, counter \
1529+
= _select_by_model_devi_standard(
1530+
modd_system_task,
1531+
f_trust_lo, f_trust_hi,
1532+
v_trust_lo, v_trust_hi,
1533+
cluster_cutoff,
1534+
model_devi_skip,
1535+
detailed_report_make_fp = detailed_report_make_fp)
1536+
else:
1537+
numb_candi_f = jdata.get('model_devi_numb_candi_f', 10)
1538+
numb_candi_v = jdata.get('model_devi_numb_candi_v', 0)
1539+
perc_candi_f = jdata.get('model_devi_perc_candi_f', 0.)
1540+
perc_candi_v = jdata.get('model_devi_perc_candi_v', 0.)
1541+
fp_rest_accurate, fp_candidate, fp_rest_failed, counter, f_trust_lo_ad, v_trust_lo_ad \
1542+
= _select_by_model_devi_adaptive_trust_low(
1543+
modd_system_task,
1544+
f_trust_hi, numb_candi_f, perc_candi_f,
1545+
v_trust_hi, numb_candi_v, perc_candi_v,
1546+
model_devi_skip = model_devi_skip)
1547+
dlog.info("system {0:s} {1:9s} : f_trust_lo {2:6.3f} v_trust_lo {3:6.3f}".format(ss, 'adapted', f_trust_lo_ad, v_trust_lo_ad))
1548+
14121549
# print a report
14131550
fp_sum = sum(counter.values())
14141551
for cc_key, cc_value in counter.items():
@@ -1768,8 +1905,8 @@ def _make_fp_vasp_configs(iter_index,
17681905
jdata):
17691906
fp_task_max = jdata['fp_task_max']
17701907
model_devi_skip = jdata['model_devi_skip']
1771-
e_trust_lo = 1e+10
1772-
e_trust_hi = 1e+10
1908+
v_trust_lo = jdata.get('model_devi_v_trust_lo', 1e10)
1909+
v_trust_hi = jdata.get('model_devi_v_trust_hi', 1e10)
17731910
f_trust_lo = jdata['model_devi_f_trust_lo']
17741911
f_trust_hi = jdata['model_devi_f_trust_hi']
17751912
type_map = jdata['type_map']
@@ -1787,7 +1924,7 @@ def _make_fp_vasp_configs(iter_index,
17871924
# make configs
17881925
fp_tasks = _make_fp_vasp_inner(modd_path, work_path,
17891926
model_devi_skip,
1790-
e_trust_lo, e_trust_hi,
1927+
v_trust_lo, v_trust_hi,
17911928
f_trust_lo, f_trust_hi,
17921929
task_min, fp_task_max,
17931930
[],

0 commit comments

Comments
 (0)