22from lambench .metrics .post_process import DIRECT_TASK_WEIGHTS
33from lambench .models .basemodel import BaseLargeAtomModel
44from lambench .databases .direct_predict_table import DirectPredictRecord
5- from lambench .metrics .utils import get_domain_to_direct_task_mapping
6- from lambench .metrics .utils import filter_direct_task_results , exp_average
5+ from lambench .databases .calculator_table import CalculatorRecord
6+ from lambench .metrics .utils import (
7+ get_domain_to_direct_task_mapping ,
8+ aggregated_nve_md_results ,
9+ filter_direct_task_results ,
10+ exp_average ,
11+ )
712from lambench .workflow .entrypoint import gather_models
13+ import numpy as np
814
915
1016def aggregate_domain_results_for_one_model (model : BaseLargeAtomModel ):
@@ -52,6 +58,30 @@ def aggregate_domain_results_for_one_model(model: BaseLargeAtomModel):
5258 return domain_results
5359
5460
def fetch_stability_results(model: BaseLargeAtomModel) -> "float | None":
    """
    Fetch the stability metric of a model from its NVE MD calculator record.

    The metric is ``slope - log(success_rate) / 1000``, where ``slope`` and
    ``success_rate`` are read from the output of ``aggregated_nve_md_results``.
    The logarithmic term penalizes failed simulations; a lower value indicates
    better stability.

    Args:
        model: Model whose ``model_name`` is used to query the ``nve_md``
            record from ``CalculatorRecord``.

    Returns:
        The stability metric, or ``None`` (after logging a warning) when the
        query does not return exactly one record. When ``success_rate`` is
        zero the penalty term is infinite.
    """
    task_results = CalculatorRecord.query(
        model_name=model.model_name, task_name="nve_md"
    )

    # Exactly one record is expected; anything else is reported and skipped
    # so the caller can store ``None`` for this model.
    if len(task_results) != 1:
        logging.warning(
            f"Expected one record for {model.model_name} and nve_md, but got {len(task_results)}"
        )
        return None

    metrics = aggregated_nve_md_results(task_results[0].metrics)
    slope = metrics["slope"]
    success_rate = metrics["success_rate"]

    # log(0) == -inf is the intended penalty for a model with no successful
    # run; silence only the divide-by-zero warning, the value is unchanged.
    with np.errstate(divide="ignore"):
        penalty = np.log(success_rate) / 1000

    return slope - penalty  # to penalize failed simulations
83+
84+
5585def aggregate_domain_results ():
5686 """
5787 This function aggregates the results across models and domains.
@@ -67,10 +97,9 @@ def aggregate_domain_results():
6797 ]
6898
6999 for model in leaderboard_models :
70- results [model .model_name ] = aggregate_domain_results_for_one_model (model )
100+ domain_results = aggregate_domain_results_for_one_model (model )
101+ stability = fetch_stability_results (model )
102+ domain_results ["Stability" ] = stability
103+ results [model .model_name ] = domain_results
71104
72105 return results
73-
74-
75- if __name__ == "__main__" :
76- print (aggregate_domain_results ())
0 commit comments