4 changes: 2 additions & 2 deletions lambench/metrics/utils.py
@@ -4,7 +4,7 @@
 import lambench
 from pathlib import Path
 from collections import defaultdict
-from lambench.workflow.entrypoint import gather_models
+from lambench.workflow.entrypoint import gather_model_params, gather_model
 from datetime import datetime

 #############################
@@ -13,7 +13,7 @@


 def get_leaderboard_models(timestamp: Optional[datetime] = None) -> list:
-    models = gather_models()
+    models = [gather_model(param, "") for param in gather_model_params()]
     if timestamp is not None:
         models = [
             model for model in models if model.model_metadata.date_added <= timestamp
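For reference, a minimal sketch of how the refactored pair replaces the old gather_models() call (names from this PR; the empty string selects the default materials domain, and "some-model" is a hypothetical config entry):

from lambench.workflow.entrypoint import gather_model_params, gather_model

# Instantiate every configured model with the default (materials) domain,
# exactly as get_leaderboard_models() now does.
models = [gather_model(param, "") for param in gather_model_params()]

# Or instantiate one named model for the molecules domain.
params = gather_model_params(model_names=["some-model"])  # hypothetical name
molecules_model = gather_model(params[0], "molecules")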
35 changes: 26 additions & 9 deletions lambench/models/ase_models.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 import logging
-from functools import cached_property
 from pathlib import Path
 from typing import Callable, Literal, Optional

@@ -17,6 +16,7 @@
 from ase.filters import FrechetCellFilter
 from ase.io import write
 from ase.optimize import FIRE
+from ase.calculators.emt import EMT
 from dftd3.ase import DFTD3
 from tqdm import tqdm

@@ -80,8 +80,8 @@ class ASEModel(BaseLargeAtomModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

-    @cached_property
-    def calc(self, head=None) -> Calculator:
+    @property
+    def calc(self) -> Calculator:
         """ASE Calculator with the model loaded."""
         calculator_dispatch = {
             "MACE": self._init_mace_calculator,
@@ -96,10 +96,18 @@ def calc(self, head=None) -> Calculator:
         }

         if self.model_family not in calculator_dispatch:
-            raise ValueError(f"Model {self.model_name} is not supported by ASEModel")
+            logging.warning(
+                f"Model {self.model_name} is not supported by ASEModel, using EMT as default calculator."
+            )
+            return EMT()

         return calculator_dispatch[self.model_family]()

+    @calc.setter
+    def calc(self, value: Calculator):
+        logging.warning("Overriding the default calculator.")
+        self._calc = value
+
     def _init_mace_calculator(self) -> Calculator:
         from mace.calculators import mace_mp

@@ -139,7 +147,10 @@ def _init_uma_calculator(self) -> Calculator:
         from fairchem.core import FAIRChemCalculator

         predictor = load_predict_unit(self.model_path, device="cuda")
-        return FAIRChemCalculator(predictor, task_name="omat")
+        if self.model_domain == "molecules":
+            return FAIRChemCalculator(predictor, task_name="omol")
+        else:
+            return FAIRChemCalculator(predictor, task_name="omat")

     def _init_mattersim_calculator(self) -> Calculator:
         from mattersim.forcefield import MatterSimCalculator
@@ -149,10 +160,16 @@ def _init_mattersim_calculator(self) -> Calculator:
     def _init_dp_calculator(self) -> Calculator:
         from deepmd.calculator import DP

-        return DP(
-            model=self.model_path,
-            head="MP_traj_v024_alldata_mixu",
-        )
+        if self.supports_omol and self.model_domain == "molecules":
+            return DP(
+                model=self.model_path,
+                head="OMol25",
+            )
+        else:
+            return DP(
+                model=self.model_path,
+                head="MP_traj_v024_alldata_mixu",
+            )

     def _init_grace_calculator(self) -> Calculator:
         from tensorpotential.calculator import grace_fm
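A small sketch of the new fallback behavior (the construction mirrors the test fixture later in this diff; the metadata values are placeholders). Note that the setter stores into self._calc, which the getter as written never reads, so an override only matters to code that accesses _calc directly; the tests below rely on the EMT fallback instead:

from lambench.models.ase_models import ASEModel

model = ASEModel(
    model_name="toy",            # placeholder values throughout
    model_family="TEST",         # not a key in calculator_dispatch
    model_type="ASE",
    model_path="",
    virtualenv="",
    model_metadata={
        "pretty_name": "toy",
        "date_added": "2023-10-01",
        "extra_content": "",
        "num_parameters": 0,
        "packages": {},
    },
)

calc = model.calc  # logs a warning and returns an EMT() instance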
4 changes: 4 additions & 0 deletions lambench/models/basemodel.py
@@ -42,6 +42,8 @@ class BaseLargeAtomModel(BaseModel):
         show_finetune_task (bool): Flag indicating if the finetune task should be displayed or executed. Default is False.
         show_calculator_task (bool): Flag indicating if the calculator task should be displayed or executed. Default is False.
         skip_tasks (list[SkipTaskType]): List of task types that should be skipped during evaluation.
+        supports_omol (bool): Flag indicating whether the model was trained on OMol25.
+        model_domain (Optional[str]): The model head or task_name to use for models with multiple domains. Default is None, which refers to the `materials` head (typically MPTrj).
     Methods:
         evaluate(task) -> dict[str, float]:
             Abstract method for evaluating the model on a given task. Implementations should return
@@ -58,6 +60,8 @@ class BaseLargeAtomModel(BaseModel):
     show_finetune_task: bool = False
     show_calculator_task: bool = False
     skip_tasks: list[SkipTaskType] = []
+    supports_omol: bool = False
+    model_domain: Optional[str] = None

     @abstractmethod
     def evaluate(self, task) -> dict[str, float]:
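Since BaseLargeAtomModel subclasses BaseModel with pydantic-style field declarations, the two new fields can come straight from a models_config.yml entry; model_domain is normally left unset there and injected later by gather_model(). A hypothetical entry, written as the Python dict it parses into:

model_param = {
    "model_name": "dp-omol-demo",  # hypothetical model
    "model_type": "DP",
    "supports_omol": True,         # trained on OMol25
    # "model_domain" is not set in the config; gather_model() writes it
    # into this dict before the model is constructed.
}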
56 changes: 34 additions & 22 deletions lambench/workflow/entrypoint.py
@@ -16,35 +16,39 @@
 MODELS = Path(lambench.__file__).parent / "models/models_config.yml"


-def gather_models(
+def gather_model_params(
     model_names: Optional[list[str]] = None,
-) -> list[BaseLargeAtomModel]:
+) -> list[dict]:
     """
-    Gather models from the models_config.yml file.
+    Gather model parameters from the models_config.yml file for selected models.
     """

-    models = []
+    model_params = []
     with open(MODELS, "r") as f:
         model_config: list[dict] = yaml.safe_load(f)
     for model_param in model_config:
         if model_names and model_param["model_name"] not in model_names:
             continue
-        if model_param["model_type"] == "DP":
-            models.append(DPModel(**model_param))
-        elif model_param["model_type"] == "ASE":
-            models.append(ASEModel(**model_param))
-        else:
-            raise ValueError(
-                f"Model type {model_param['model_type']} is not supported."
-            )
-    return models
+        model_params.append(model_param)
+
+    return model_params
+
+
+def gather_model(model_param: dict, model_domain: str) -> BaseLargeAtomModel:
+    model_param["model_domain"] = model_domain
+    if model_param["model_type"] == "DP":
+        return DPModel(**model_param)
+    elif model_param["model_type"] == "ASE":
+        return ASEModel(**model_param)
+    else:
+        raise ValueError(f"Model type {model_param['model_type']} is not supported.")


 job_list: TypeAlias = list[tuple[BaseTask, BaseLargeAtomModel]]


 def gather_task_type(
-    models: list[BaseLargeAtomModel],
+    model_params: list[dict],
     task_class: Type[BaseTask],
     task_names: Optional[list[str]] = None,
 ) -> job_list:
@@ -54,18 +58,26 @@ def gather_task_type(
     tasks = []
     with open(task_class.task_config, "r") as f:
         task_configs: dict[str, dict] = yaml.safe_load(f)
-    for model in models:
-        if not hasattr(model, "_finetune") and issubclass(
+    for model_param in model_params:
+        if model_param["model_type"] != "DP" and issubclass(
             task_class, PropertyFinetuneTask
         ):
             continue  # Regular ASEModel does not support PropertyFinetuneTask
         for task_name, task_params in task_configs.items():
             if (task_names and task_name not in task_names) or task_class.__name__ in (
-                model.skip_tasks
+                model_param["skip_tasks"]
             ):
                 continue
             task = task_class(task_name=task_name, **task_params)
-            if not task.exist(model.model_name):
+            if not task.exist(model_param["model_name"]):
+                # model_domain = task.domain if task.domain else ""  # in the future we may have tasks with a specific domain.
+
+                # Currently we only need to distinguish direct tasks for molecules and materials, due to the OMol25 training set.
+                if task_name in []:  # to be added in a separate PR.
+                    model_domain = "molecules"
+                else:
+                    model_domain = "materials"
+                model = gather_model(model_param, model_domain)
                 tasks.append((task, model))
     return tasks

@@ -77,18 +89,18 @@ def gather_jobs(
 ) -> job_list:
     jobs: job_list = []

-    models = gather_models(model_names)
-    if not models:
+    model_params = gather_model_params(model_names)
+    if not model_params:
         logging.warning("No models found, skipping task gathering.")
         return jobs

-    logging.info(f"Found {len(models)} models, gathering tasks.")
+    logging.info(f"Found {len(model_params)} models, gathering tasks.")
     for task_class in BaseTask.__subclasses__():
         if task_types and task_class.__name__ not in task_types:
             continue
         jobs.extend(
             gather_task_type(
-                models=models, task_class=task_class, task_names=task_names
+                model_params=model_params, task_class=task_class, task_names=task_names
             )
         )

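One consequence of the new flow worth noting: gather_model() writes the domain into the shared parameter dict before constructing each model, so a fresh model instance is built per (task, model) pair while the dict itself is reused across calls. A minimal sketch (the model name is a placeholder):

from lambench.workflow.entrypoint import gather_model_params, gather_model

param = gather_model_params(model_names=["test_model"])[0]  # placeholder name
materials_model = gather_model(param, "materials")
molecules_model = gather_model(param, "molecules")
# param["model_domain"] is now "molecules": each call overwrites the
# domain in place before instantiating the model.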
16 changes: 4 additions & 12 deletions tests/tasks/calculator/test_nve_md.py
@@ -5,7 +5,6 @@
 from lambench.metrics.utils import aggregated_nve_md_results
 import pytest
 from ase import Atoms
-from ase.calculators.emt import EMT
 from lambench.models.ase_models import ASEModel
 import numpy as np

Review comment (Copilot AI, Oct 29, 2025), on the ASEModel import line: "Import of 'Calculator' is not used."
Suggested change:
from lambench.models.ase_models import ASEModel

@@ -19,13 +18,7 @@ def setup_testing_data():


 @pytest.fixture
-def setup_calculator():
-    """Fixture to provide an ASE calculator (EMT)."""
-    return EMT()
-
-
-@pytest.fixture
-def setup_model(setup_calculator):
+def setup_model():
     """Fixture to provide an ASE model."""
     ase_models = ASEModel(
         model_family="TEST",
@@ -39,15 +32,14 @@ def setup_model(setup_calculator):
         },
         virtualenv="",
     )
-    ase_models.calc = setup_calculator
     return ase_models


-def test_nve_simulation_metrics(setup_testing_data, setup_calculator):
+def test_nve_simulation_metrics(setup_testing_data, setup_model):
     """Test NVE simulation metrics for std, and steps."""
     result = nve_simulation_single(
         setup_testing_data,
-        setup_calculator,
+        setup_model.calc,
         timestep=1.0,
         num_steps=100,
         temperature_K=300,
@@ -58,7 +50,7 @@ def test_nve_simulation_metrics(setup_testing_data, setup_calculator):
     assert isinstance(result["slope"], float), "Slope should be a float."


-def test_nve_simulation_crash_handling(setup_testing_data, setup_calculator):
+def test_nve_simulation_crash_handling(setup_testing_data):
     """Test crash handling by simulating an intentional crash."""
     atoms = setup_testing_data

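These tests still get a working calculator because model_family="TEST" falls through the new dispatch table in ASEModel.calc and returns EMT(). A hypothetical extra check that would make this explicit (reusing the setup_model fixture above):

from ase.calculators.emt import EMT

def test_default_calculator_is_emt(setup_model):
    # "TEST" is not in calculator_dispatch, so the property logs a
    # warning and falls back to EMT.
    assert isinstance(setup_model.calc, EMT)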
19 changes: 9 additions & 10 deletions tests/workflow/test_entrypoint.py
@@ -1,26 +1,25 @@
-from lambench.models.dp_models import DPModel
 from lambench.tasks import PropertyFinetuneTask
 import pytest
 from lambench.workflow.entrypoint import gather_task_type
 from unittest.mock import MagicMock


 def _create_dp_model(skip_tasks=[]):
-    return DPModel(
-        model_name="test_model",
-        model_family="test_family",
-        model_type="DP",
-        model_path="test_path",
-        virtualenv="test_env",
-        model_metadata={
+    return {
+        "model_name": "test_model",
+        "model_family": "test_family",
+        "model_type": "DP",
+        "model_path": "test_path",
+        "virtualenv": "test_env",
+        "model_metadata": {
             "pretty_name": "test",
             "date_added": "2023-10-01",
             "extra_content": "test",
             "num_parameters": 1000,
             "packages": {"torch": "2.0.0"},
         },
-        skip_tasks=skip_tasks,
-    )
+        "skip_tasks": skip_tasks,
+    }


 @pytest.fixture