44import logging
55import time
66import numpy as np
7- from typing import List , Dict
7+ from typing import List , Dict , Tuple
88from pathlib import Path
99
1010logging .basicConfig (level = logging .INFO , format = '%(asctime)s - %(levelname)s - %(message)s' , filemode = 'w' , filename = 'infer.log' )
1313def run_batch_infer (
1414 model : ASEModel ,
1515 test_data : Path ,
16- timimg_ratio : float
17- ) -> Dict [str , float ]:
16+ warmup_ratio : float
17+ ) -> Dict [str , Dict [ str , float ] ]:
1818 """
19- Infer for all batches
19+ Infer for all batches, return average time and success rate for each system.
2020 """
2121 results = {}
2222 subfolders = [subfolder for subfolder in test_data .iterdir () if subfolder .is_dir ()]
2323 for subfolder in subfolders :
2424 system_name = subfolder .name
2525 try :
26- batch_result = run_one_batch_infer (model , subfolder , timimg_ratio )
26+ batch_result = run_one_batch_infer (model , subfolder , warmup_ratio )
2727 average_time = batch_result ["average_time_per_step" ]
28- results [system_name ] = average_time
29- logging .info (f"Batch inference completed for system { system_name } with average time { average_time } s" )
28+ success_rate = batch_result ["success_rate" ]
29+ results [system_name ] = {
30+ "average_time_per_step" : average_time ,
31+ "success_rate" : success_rate
32+ }
33+ logging .info (f"Batch inference completed for system { system_name } with average time { average_time } s and success rate { success_rate :.2f} %" )
3034 except Exception as e :
3135 logging .error (f"Error in batch inference for system { system_name } : { e } " )
36+ results [system_name ] = {
37+ "average_time_per_step" : None ,
38+ "success_rate" : 0.0
39+ }
3240 return results
3341
3442
3543def run_one_batch_infer (
3644 model : ASEModel ,
3745 test_data : Path ,
38- timimg_ratio : float
46+ warmup_ratio : float
3947) -> Dict [str , float ]:
4048 """
41- Infer for one batch, return averaged time, starting timing at timimg_ratio .
49+ Infer for one batch, return averaged time and success rate , starting timing at warmup_ratio .
4250 """
4351 test_files = list (test_data .glob ("*.vasp" ))
4452 test_atoms = [read (file ) for file in test_files ]
45- start_index = int (len (test_atoms ) * timimg_ratio )
53+ start_index = int (len (test_atoms ) * warmup_ratio )
4654 total_time = 0
4755 valid_steps = 0
56+ successful_inferences = 0
57+ total_inferences = len (test_atoms )
58+
4859 for i , atoms in enumerate (test_atoms ):
4960 atoms .calc = model .calc
5061 start = time .time ()
@@ -64,6 +75,7 @@ def run_one_batch_infer(
6475 stress_tensor [2 , 0 ] = stress [4 ]
6576 stress_tensor [1 , 0 ] = stress [5 ]
6677 virial = - stress_tensor * volume
78+ successful_inferences += 1
6779 except Exception as e :
6880 logging .error (f"Error in inference for { str (atoms .symbols )} : { e } " )
6981 continue
@@ -82,6 +94,12 @@ def run_one_batch_infer(
8294 else :
8395 average_time_per_step = np .nan
8496
97+ if total_inferences > 0 :
98+ success_rate = (successful_inferences / total_inferences ) * 100
99+ else :
100+ success_rate = 0.0
101+
85102 return {
86103 "average_time_per_step" : average_time_per_step ,
87- }
104+ "success_rate" : success_rate
105+ }
0 commit comments