Skip to content

Commit e9efb10

Browse files
authored
Merge pull request #348 from deepmodeling/damp
feat: implement dft damping for ase calc
2 parents 7e514fc + 2e2407f commit e9efb10

File tree

8 files changed

+42
-21
lines changed

8 files changed

+42
-21
lines changed

lambench/metrics/direct_task_weights.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ MD22:
3535
energy_weight: 1.0
3636
force_weight: 1.0
3737
virial_weight: null
38-
energy_std: 0.007941836149915322
39-
force_std: 1.1391327961625524
38+
energy_std: 0.008959619353114803
39+
force_std: 1.1964522496892305
4040
virial_std: null
4141
REANN_CO2_Ni100:
4242
domain: Catalysis

lambench/metrics/results/metadata.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
},
5454
"MD22": {
5555
"DISPLAY_NAME": "MD22",
56-
"DESCRIPTION": "Dataset containing MD trajectories of the 42-atom tetrapeptide Ac-Ala3-NHMe from the MD22 benchmark set. Calculations were performed using FHI-aims and i-Pi software at the DFT-PBE+MBD level of theory. Trajectories were sampled at temperatures between 400-500 K at 1 fs resolution. [https://www.science.org/doi/10.1126/sciadv.adf0873]",
56+
"DESCRIPTION": "Dataset containing MD trajectories of the 42-atom tetrapeptide Ac-Ala3-NHMe from the MD22 benchmark set. Calculations were performed using FHI-aims and i-Pi software at the DFT-PBE+MBD level of theory. The dataset was relabeled using Gaussian with PBE/6-31G(d). Trajectories were sampled at temperatures between 400-500 K at 1 fs resolution. [https://www.science.org/doi/10.1126/sciadv.adf0873]",
5757
"domain": "Molecules",
5858
"energy_rmse": {
5959
"DISPLAY_NAME": "E RMSE (meV)",

lambench/models/ase_models.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
from functools import cached_property
44
from pathlib import Path
5-
from typing import Callable, Optional
5+
from typing import Callable, Literal, Optional
66

77
import dpdata
88
import numpy as np
@@ -12,10 +12,12 @@
1212
)
1313
from ase import Atoms
1414
from ase.calculators.calculator import Calculator
15+
from ase.calculators.mixing import SumCalculator
1516
from ase.constraints import FixSymmetry
1617
from ase.filters import FrechetCellFilter
1718
from ase.io import write
1819
from ase.optimize import FIRE
20+
from dftd3.ase import DFTD3
1921
from tqdm import tqdm
2022

2123
from lambench.models.basemodel import BaseLargeAtomModel
@@ -179,7 +181,7 @@ def evaluate(
179181
import torch
180182

181183
torch.set_default_dtype(torch.float32)
182-
return self.run_ase_dptest(self, task.test_data)
184+
return self.run_ase_dptest(self, task.test_data, task.dispersion_correction)
183185
elif isinstance(task, CalculatorTask):
184186
if task.task_name == "nve_md":
185187
from lambench.tasks.calculator.nve_md.nve_md import (
@@ -265,7 +267,12 @@ def evaluate(
265267
)
266268

267269
@staticmethod
268-
def run_ase_dptest(model: ASEModel, test_data: Path) -> dict:
270+
def run_ase_dptest(
271+
model: ASEModel,
272+
test_data: Path,
273+
dispersion_correction: Literal["d3bj", "d3zero"] | None = None,
274+
# check all supported levels at dftd3.qcschema._available_levels
275+
) -> dict:
269276
# Add fparam for charge and spin multiplicity if needed
270277
datatype = DataType(
271278
"fparam",
@@ -277,6 +284,10 @@ def run_ase_dptest(model: ASEModel, test_data: Path) -> dict:
277284
dpdata.LabeledSystem.register_data_type(datatype)
278285

279286
calc = model.calc
287+
if dispersion_correction:
288+
calc = SumCalculator(
289+
[calc, DFTD3(method="PBE", dispersion_correction=dispersion_correction)]
290+
)
280291

281292
energy_err = []
282293
energy_pre = []

lambench/tasks/direct/direct_tasks.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import ClassVar
2+
from typing import ClassVar, Literal
33
from lambench.tasks.base_task import BaseTask
44
from lambench.databases.direct_predict_table import DirectPredictRecord
55

@@ -12,6 +12,8 @@ class DirectPredictTask(BaseTask):
1212

1313
record_type: ClassVar = DirectPredictRecord
1414
task_config: ClassVar = Path(__file__).parent / "direct_tasks.yml"
15+
dispersion_correction: Literal["d3bj", "d3zero"] | None = None
1516

1617
def __init__(self, task_name: str, **kwargs):
1718
super().__init__(task_name=task_name, test_data=kwargs["test_data"])
19+
self.dispersion_correction = kwargs.get("dispersion_correction")

lambench/tasks/direct/direct_tasks.yml

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,29 @@
11
ANI:
2-
test_data: "/bohr/lambench-ood-zwtr/v3/LAMBench-TestData-v3/ANI"
2+
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/ANI"
33
HEA25_S:
4-
test_data: "/bohr/lambench-ood-zwtr/v3/LAMBench-TestData-v3/HEA25S"
4+
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HEA25S"
55
HEA25_bulk:
6-
test_data: "/bohr/lambench-ood-zwtr/v3/LAMBench-TestData-v3/HEA25"
6+
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HEA25"
77
MoS2:
8-
test_data: "/bohr/lambench-ood-zwtr/v3/LAMBench-TestData-v3/MoS2"
8+
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/MoS2"
99
MD22:
10-
test_data: "/bohr/lambench-ood-zwtr/v3/LAMBench-TestData-v3/MD22"
10+
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/MD22"
1111
REANN_CO2_Ni100:
12-
test_data: "/bohr/lambench-ood-zwtr/v3/LAMBench-TestData-v3/REANN_CO2_Ni100"
12+
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/REANN_CO2_Ni100"
1313
NequIP_NC_2022:
14-
test_data: "/bohr/lambench-ood-zwtr/v3/LAMBench-TestData-v3/NequIP_NC_2022"
14+
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/NequIP_NC_2022"
1515
AIMD-Chig:
16-
test_data: "/bohr/lambench-ood-zwtr/v3/LAMBench-TestData-v3/AIMD_chig"
16+
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/AIMD_chig"
1717
Cu_MgO_catalysts:
18-
test_data: "/bohr/lambench-ood-zwtr/v3/LAMBench-TestData-v3/Cu_MgO_CO2"
18+
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/Cu_MgO_CO2"
19+
dispersion_correction: d3zero
1920
Si_ZEO22:
20-
test_data: "/bohr/lambench-ood-zwtr/v3/LAMBench-TestData-v3/Si_ZEO22"
21+
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/Si_ZEO22"
22+
dispersion_correction: d3bj
2123
HPt_NC_2022:
22-
test_data: "/bohr/lambench-ood-zwtr/v3/LAMBench-TestData-v3/HPt_NC2022"
24+
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HPt_NC2022"
2325
Ca_batteries_CM2021:
24-
test_data: "/bohr/lambench-ood-zwtr/v3/LAMBench-TestData-v3/Ca_batteries"
26+
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/Ca_batteries"
2527
## DEPRECATED
2628
# Collision:
2729
# test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/Collision"

lambench/workflow/dflow.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
from dflow.plugins.bohrium import BohriumDatasetsArtifact, create_job_group
1313
from dflow.plugins.dispatcher import DispatcherExecutor
1414
from dflow.python import OP, Artifact, PythonOPTemplate
15+
1516
import dpdata
17+
import dftd3
18+
import cffi
19+
import pycparser
1620

1721
import lambench
1822
from lambench.models.basemodel import BaseLargeAtomModel
@@ -60,7 +64,8 @@ def submit_tasks_dflow(
6064
image=model.virtualenv,
6165
envs={k: v for k, v in os.environ.items() if k.startswith("MYSQL")},
6266
python_packages=[
63-
Path(package.__path__[0]) for package in [lambench, dpdata]
67+
Path(package.__path__[0])
68+
for package in [lambench, dpdata, dftd3, cffi, pycparser]
6469
],
6570
),
6671
parameters={

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ dependencies = [
1515
"tqdm",
1616
"dpdata @ git+https://github.com/deepmodeling/dpdata.git@devel#egg=dpdata",
1717
"pandas",
18+
"dftd3"
1819
]
1920

2021
authors = [

tests/metrics/test_visualization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def test_aggregate_ood_results_for_one_model(
1414
model.show_calculator_task = False
1515
aggregator = ResultsFetcher()
1616
result = aggregator.aggregate_ood_results_for_one_model(model=model)
17-
np.testing.assert_almost_equal(result["Molecules"], 0.234724350, decimal=5)
17+
np.testing.assert_almost_equal(result["Molecules"], desired=0.22748765, decimal=5)
1818
np.testing.assert_almost_equal(result["Inorganic Materials"], 0.2972349, decimal=5)
1919
assert result["Catalysis"] is None
2020
with caplog.at_level(logging.WARNING):

0 commit comments

Comments
 (0)