Skip to content

Commit 4f92efb

Browse files
authored
Merge pull request #98 from deepmodeling/dev-gmy
feat: batch infer
2 parents cad1d4f + b2413d0 commit 4f92efb

File tree

4 files changed

+122
-1
lines changed

4 files changed

+122
-1
lines changed

CITATION.cff

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ authors:
1717
orcid: https://orcid.org/0000-0001-6242-0439
1818
- family-names: Guo
1919
given-names: Mingyu
20-
affiliation: AI for Science Institute, Beijing
20+
affiliation: DP Technology; School of Chemistry, Sun Yat-sen University, Guangzhou
21+
orcid: https://orcid.org/0009-0008-3744-1543
2122
- family-names: Zhang
2223
given-names: Duo
2324
affiliation: AI for Science Institute, Beijing

lambench/models/ase_models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,16 @@ def evaluate(self, task) -> Optional[dict[str, float]]:
156156
self, task.test_data, distance, task.workdir
157157
)
158158
}
159+
elif task.task_name == "batch_inference_efficiency":
160+
from lambench.tasks.calculator.infer_efficiency.infer_efficiency import (
161+
run_batch_infer,
162+
)
163+
warmup_ratio = task.calculator_params.get("warmup_ratio", 0.2)
164+
return {
165+
"metrics": run_batch_infer(
166+
self, task.test_data, warmup_ratio
167+
)
168+
}
159169
else:
160170
raise NotImplementedError(f"Task {task.task_name} is not implemented.")
161171

lambench/tasks/calculator/calculator_tasks.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,8 @@ phonon_mdr:
88
test_data: /bohr/lambench-phonon-y7vk/v1/MDR_PBE_phonon
99
calculator_params:
1010
distance: 0.01
11+
batch_inference_efficiency:
12+
test_data: /bohr/batch-infer-7ipn/v1/batch_infer_confs
13+
calculator_params:
14+
warmup_ratio: 0.2
15+
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from lambench.models.ase_models import ASEModel
2+
from ase import Atoms
3+
from ase.io import read
4+
import logging
5+
import time
6+
import numpy as np
7+
from typing import List, Dict, Tuple
8+
from pathlib import Path
9+
10+
# NOTE(review): calling logging.basicConfig at import time of a library module
# configures the process-wide root logger and, with filemode='w', truncates
# 'infer.log' in the current working directory on every import — presumably
# intentional for standalone benchmark runs, but confirm; the conventional
# pattern is a module-level `logger = logging.getLogger(__name__)` with
# handler configuration left to the application entry point.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filemode='w', filename='infer.log')
12+
13+
def run_batch_infer(
    model: ASEModel,
    test_data: Path,
    warmup_ratio: float
) -> Dict[str, Dict[str, None | float]]:
    """
    Infer for all batches, return average time and success rate for each system.

    Each immediate subdirectory of ``test_data`` is treated as one system and
    timed independently via :func:`run_one_batch_infer`.

    Parameters
    ----------
    model : ASEModel
        Model wrapper whose ``calc`` attribute is an ASE calculator.
    test_data : Path
        Directory containing one subdirectory of ``*.vasp`` files per system.
    warmup_ratio : float
        Fraction of each system's structures excluded from timing as warmup.

    Returns
    -------
    Dict[str, Dict[str, float]]
        Mapping ``system_name -> {"average_time_per_step": float | None,
        "success_rate": float}``; a system that raises gets
        ``(None, 0.0)`` instead of aborting the whole benchmark.
    """
    results: Dict[str, Dict[str, None | float]] = {}
    # Sort subfolders: Path.iterdir() order is filesystem-dependent, and a
    # benchmark should iterate (and report) systems in a reproducible order.
    subfolders = sorted(sub for sub in test_data.iterdir() if sub.is_dir())
    for subfolder in subfolders:
        system_name = subfolder.name
        try:
            batch_result = run_one_batch_infer(model, subfolder, warmup_ratio)
            average_time = batch_result["average_time_per_step"]
            success_rate = batch_result["success_rate"]
            results[system_name] = {
                "average_time_per_step": average_time,
                "success_rate": success_rate,
            }
            # Lazy %-style args: formatting is skipped if INFO is disabled.
            logging.info(
                "Batch inference completed for system %s with average time %s s and success rate %.2f%%",
                system_name,
                average_time,
                success_rate,
            )
        except Exception as e:
            # One failing system must not abort the remaining systems; record
            # it as a total failure and keep going.
            logging.error("Error in batch inference for system %s: %s", system_name, e)
            results[system_name] = {
                "average_time_per_step": None,
                "success_rate": 0.0,
            }
    return results
42+
43+
def run_one_batch_infer(
    model: ASEModel,
    test_data: Path,
    warmup_ratio: float
) -> Dict[str, float]:
    """
    Infer for one batch, return averaged time and success rate, starting timing at warmup_ratio.

    Parameters
    ----------
    model : ASEModel
        Model wrapper whose ``calc`` attribute is an ASE calculator.
    test_data : Path
        Directory containing the system's ``*.vasp`` structure files.
    warmup_ratio : float
        Fraction of structures (by position in the sorted file list) treated
        as warmup and excluded from the timing average.

    Returns
    -------
    Dict[str, float]
        ``{"average_time_per_step": float, "success_rate": float}``; the
        average is ``nan`` when no timed step succeeded.
    """
    # Sort the files: glob() order is filesystem-dependent, and the warmup
    # split below (`i >= start_index`) depends on file order — without
    # sorting, the timed subset (and thus the benchmark result) would differ
    # between machines and runs.
    test_files = sorted(test_data.glob("*.vasp"))
    test_atoms = [read(file) for file in test_files]
    start_index = int(len(test_atoms) * warmup_ratio)
    total_time = 0.0
    valid_steps = 0
    successful_inferences = 0
    total_inferences = len(test_atoms)

    # Voigt order (xx, yy, zz, yz, xz, xy) -> symmetric 3x3 index pairs.
    voigt_pairs = ((0, 0), (1, 1), (2, 2), (1, 2), (0, 2), (0, 1))

    for i, atoms in enumerate(test_atoms):
        atoms.calc = model.calc
        start = time.time()
        try:
            # Return values are intentionally unused: the calls force the
            # calculator to evaluate energy/forces/stress so the timing
            # reflects the full cost of one inference step.
            atoms.get_potential_energy()
            atoms.get_forces()
            stress = atoms.get_stress()
            volume = atoms.get_volume()
            stress_tensor = np.zeros((3, 3))
            for voigt_index, (row, col) in enumerate(voigt_pairs):
                stress_tensor[row, col] = stress[voigt_index]
                stress_tensor[col, row] = stress[voigt_index]
            virial = -stress_tensor * volume  # noqa: F841 -- kept to mirror production post-processing cost
            successful_inferences += 1
        except Exception as e:
            # A failed structure contributes neither time nor a valid step;
            # it only lowers the success rate.
            logging.error("Error in inference for %s: %s", str(atoms.symbols), e)
            continue

        end = time.time()
        elapsed_time = end - start

        # Only steps past the warmup window count toward the average.
        if i >= start_index:
            total_time += elapsed_time
            valid_steps += 1

        logging.info("Inference completed for system %s in %s s", str(atoms.symbols), elapsed_time)

    if valid_steps > 0:
        average_time_per_step = total_time / valid_steps
    else:
        # No successful post-warmup step: average is undefined.
        average_time_per_step = np.nan

    if total_inferences > 0:
        success_rate = (successful_inferences / total_inferences) * 100
    else:
        success_rate = 0.0

    return {
        "average_time_per_step": average_time_per_step,
        "success_rate": success_rate,
    }

0 commit comments

Comments
 (0)