diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index e277c0708f..b8645f9bd4 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -71,6 +71,12 @@ jobs: uv pip install torch-runstats torch_dftd uv pip install --no-deps nequip==0.5.6 + - name: Install torchsim dependencies (Python 3.12+ only) + if: matrix.python-version == '3.12' + run: | + micromamba activate a2 + uv pip install .[torchsim] + - name: Install pymatgen from master if triggered by pymatgen repo dispatch if: github.event_name == 'repository_dispatch' && github.event.action == 'pymatgen-ci-trigger' run: | diff --git a/.gitignore b/.gitignore index 3bb2f9dca7..74ab53a90f 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ develop-eggs .installed.cfg lib lib64 +uv.lock # Installer logs pip-log.txt diff --git a/pyproject.toml b/pyproject.toml index c57b824ffb..2c7031449e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,9 @@ forcefields = [ "sevenn>=0.9.3", "deepmd-kit>=2.1.4", ] +torchsim = [ + "torch-sim-atomistic==0.4.1; python_version >= '3.12'" +] approxneb = ["pymatgen-analysis-diffusion>=2024.7.15"] ase = ["ase>=3.26.0"] ase-ext = ["tblite>=0.3.0; platform_system=='Linux'"] diff --git a/src/atomate2/torchsim/__init__.py b/src/atomate2/torchsim/__init__.py new file mode 100644 index 0000000000..111fdf7d76 --- /dev/null +++ b/src/atomate2/torchsim/__init__.py @@ -0,0 +1,9 @@ +"""TorchSim module for atomate2.""" + +from atomate2.torchsim.core import ( + TorchSimIntegrateMaker, + TorchSimOptimizeMaker, + TorchSimStaticMaker, +) + +__all__ = ["TorchSimIntegrateMaker", "TorchSimOptimizeMaker", "TorchSimStaticMaker"] diff --git a/src/atomate2/torchsim/core.py b/src/atomate2/torchsim/core.py new file mode 100644 index 0000000000..2e5b6ab8c0 --- /dev/null +++ b/src/atomate2/torchsim/core.py @@ -0,0 +1,635 @@ +"""Core module for TorchSim makers in atomate2.""" + +from __future__ import annotations + +import time +from copy import deepcopy +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import torch_sim as ts +from jobflow import Maker, Response, job +from torch_sim.autobatching import BinningAutoBatcher, InFlightAutoBatcher + +from atomate2.torchsim.schema import ( + CONVERGENCE_FN_REGISTRY, + PROPERTY_FN_REGISTRY, + AutobatcherDetails, + ConvergenceFn, + PropertyFn, + TaskType, + TorchSimCalculation, + TorchSimModelType, + TorchSimTaskDoc, + TrajectoryReporterDetails, +) + +if TYPE_CHECKING: + from collections.abc import Callable + + from pymatgen.core import Structure + from torch_sim.models.interface import ModelInterface + from torch_sim.optimizers import Optimizer + from torch_sim.trajectory import TrajectoryReporter + + +def torchsim_job(method: Callable) -> job: + """Decorate the ``make`` method of TorchSim job makers. + + This is a thin wrapper around :obj:`~jobflow.core.job.Job` that configures common + settings for all TorchSim jobs. Namely, configures the output schema to be a + :obj:`.TorchSimTaskDoc`. + + Parameters + ---------- + method : callable + A TorchSim maker's make method. This should not be specified directly and is + implied by the decorator. + + Returns + ------- + callable + A decorated version of the make function that will generate jobs. 
+ """ + return job(method, output_schema=TorchSimTaskDoc) + + +def process_trajectory_reporter_dict( + trajectory_reporter_dict: dict[str, Any] | None, +) -> tuple[TrajectoryReporter | None, TrajectoryReporterDetails | None]: + """Process the input dict into a TrajectoryReporter and details dictionary. + + Parameters + ---------- + trajectory_reporter_dict : dict[str, Any] | None + Dictionary configuration for the trajectory reporter. + + Returns + ------- + tuple[TrajectoryReporter | None, TrajectoryReporterDetails | None] + The trajectory reporter instance and its details dictionary. + """ + if trajectory_reporter_dict is None: + return None, None + trajectory_reporter_dict = deepcopy(trajectory_reporter_dict) + + prop_calculators = trajectory_reporter_dict.pop("prop_calculators", {}) + + # Convert prop_calculators to PropertyFn types and get functions + prop_calculators_typed: dict[int, list[PropertyFn]] = { + i: [PropertyFn(prop) if isinstance(prop, str) else prop for prop in props] + for i, props in prop_calculators.items() + } + prop_calculators_functions = { + i: {prop: PROPERTY_FN_REGISTRY[prop] for prop in props} + for i, props in prop_calculators_typed.items() + } + + trajectory_reporter = ts.TrajectoryReporter( + **trajectory_reporter_dict, prop_calculators=prop_calculators_functions + ) + + trajectory_reporter.filenames = [ + Path(p).resolve() for p in trajectory_reporter_dict.get("filenames", []) + ] + + reporter_details = TrajectoryReporterDetails( + state_frequency=trajectory_reporter.state_frequency, + trajectory_kwargs=trajectory_reporter.trajectory_kwargs, + prop_calculators=prop_calculators_typed, + state_kwargs=trajectory_reporter.state_kwargs, + metadata=trajectory_reporter.metadata, + filenames=trajectory_reporter.filenames, + ) + return trajectory_reporter, reporter_details + + +def _get_autobatcher_details( + autobatcher: InFlightAutoBatcher | BinningAutoBatcher, +) -> AutobatcherDetails: + """Extract the metadata of an autobatcher. + + Parameters + ---------- + autobatcher : InFlightAutoBatcher | BinningAutoBatcher + The autobatcher to convert. + + Returns + ------- + AutobatcherDetails + Dictionary representation of the autobatcher. + """ + return AutobatcherDetails( + autobatcher=type(autobatcher).__name__, # type: ignore[arg-type] + memory_scales_with=autobatcher.memory_scales_with, # type: ignore[arg-type] + max_memory_scaler=autobatcher.max_memory_scaler, + max_atoms_to_try=autobatcher.max_atoms_to_try, + memory_scaling_factor=autobatcher.memory_scaling_factor, + max_iterations=( + autobatcher.max_iterations + if isinstance(autobatcher, InFlightAutoBatcher) + else None + ), + max_memory_padding=autobatcher.max_memory_padding, + ) + + +def process_in_flight_autobatcher_dict( + structures: list[Structure], + model: ModelInterface, + autobatcher_dict: dict[str, Any] | bool, + max_iterations: int, +) -> tuple[InFlightAutoBatcher | bool, AutobatcherDetails | None]: + """Process the input dict into a InFlightAutoBatcher and details dictionary. + + Parameters + ---------- + structures : list[Structure] + List of pymatgen Structures. + model : ModelInterface + The model interface. + autobatcher_dict : dict[str, Any] | bool + Dictionary configuration for the autobatcher or a boolean. + max_iterations : int + Maximum number of iterations. + + Returns + ------- + tuple[InFlightAutoBatcher | bool, AutobatcherDetails | None] + The autobatcher instance (or False) and its details dictionary. 
+ """ + if isinstance(autobatcher_dict, bool): + # False means no autobatcher + if not autobatcher_dict: + return False, None + # otherwise, configure the autobatcher, with the private runners method + state = ts.initialize_state(structures, model.device, model.dtype) + autobatcher = ts.runners._configure_in_flight_autobatcher( # noqa: SLF001 + state, model, autobatcher=autobatcher_dict, max_iterations=max_iterations + ) + else: + autobatcher = InFlightAutoBatcher(model=model, **autobatcher_dict) + + autobatcher_details = _get_autobatcher_details(autobatcher) + return autobatcher, autobatcher_details + + +def process_binning_autobatcher_dict( + structures: list[Structure], + model: ModelInterface, + autobatcher_dict: dict[str, Any] | bool, +) -> tuple[BinningAutoBatcher | bool, AutobatcherDetails | None]: + """Process the input dict into a BinningAutoBatcher and details dictionary. + + Parameters + ---------- + structures : list[Structure] + List of pymatgen Structures. + model : ModelInterface + The model interface. + autobatcher_dict : dict[str, Any] | bool + Dictionary configuration for the autobatcher or a boolean. + + Returns + ------- + tuple[BinningAutoBatcher | bool, AutobatcherDetails | None] + The autobatcher instance (or False) and its details dictionary. + """ + if isinstance(autobatcher_dict, bool): + # otherwise, configure the autobatcher, with the private runners method + state = ts.initialize_state(structures, model.device, model.dtype) + autobatcher = ts.runners._configure_batches_iterator( # noqa: SLF001 + state, model, autobatcher=autobatcher_dict + ) + # list means no autobatcher + if isinstance(autobatcher, list): + return False, None + else: + # pop max_iterations if present + autobatcher_dict = deepcopy(autobatcher_dict) + autobatcher_dict.pop("max_iterations", None) + autobatcher = BinningAutoBatcher(model=model, **autobatcher_dict) + + autobatcher_details = _get_autobatcher_details(autobatcher) + return autobatcher, autobatcher_details + + +def pick_model( + model_type: TorchSimModelType, model_path: str | Path, **model_kwargs: Any +) -> ModelInterface: + """Pick and instantiate a model based on the model type. + + Parameters + ---------- + model_type : TorchSimModelType + The type of model to instantiate. + model_path : str | Path + Path to the model file or checkpoint. + **model_kwargs : Any + Additional keyword arguments to pass to the model constructor. + + Returns + ------- + ModelInterface + The instantiated model. + + Raises + ------ + ValueError + If an invalid model type is provided. 
+ """ + if model_type == TorchSimModelType.FAIRCHEMV1: + from torch_sim.models.fairchem_legacy import FairChemV1Model + + return FairChemV1Model(model=model_path, **model_kwargs) + if model_type == TorchSimModelType.FAIRCHEM: + from torch_sim.models.fairchem import FairChemModel + + return FairChemModel(model=model_path, **model_kwargs) + if model_type == TorchSimModelType.GRAPHPESWRAPPER: + from torch_sim.models.graphpes import GraphPESWrapper + + return GraphPESWrapper(model=model_path, **model_kwargs) + if model_type == TorchSimModelType.MACE: + from torch_sim.models.mace import MaceModel + + return MaceModel(model=model_path, **model_kwargs) + if model_type == TorchSimModelType.MATTERSIM: + from torch_sim.models.mattersim import MatterSimModel + + return MatterSimModel(model=model_path, **model_kwargs) + if model_type == TorchSimModelType.METATOMIC: + from torch_sim.models.metatomic import MetatomicModel + + return MetatomicModel(model=model_path, **model_kwargs) + if model_type == TorchSimModelType.NEQUIPFRAMEWORK: + from torch_sim.models.nequip_framework import NequIPFrameworkModel + + return NequIPFrameworkModel(model=model_path, **model_kwargs) + if model_type == TorchSimModelType.ORB: + from torch_sim.models.orb import OrbModel + + return OrbModel(model=model_path, **model_kwargs) + if model_type == TorchSimModelType.SEVENNET: + from torch_sim.models.sevennet import SevenNetModel + + return SevenNetModel(model=model_path, **model_kwargs) + if model_type == TorchSimModelType.LENNARD_JONES: + from torch_sim.models.lennard_jones import LennardJonesModel + + return LennardJonesModel(**model_kwargs) + + raise ValueError(f"Invalid model type: {model_type}") + + +@dataclass +class TorchSimOptimizeMaker(Maker): + """A maker class for performing optimization using TorchSim. + + Parameters + ---------- + name : str + The name of the job. + model : tuple[ModelType, str | Path] + The model to use for optimization. A tuple of (model_type, model_path). + optimizer : Optimizer + The TorchSim optimizer to use. + convergence_fn : ConvergenceFn | None + The convergence function type to use. + convergence_fn_kwargs : dict | None + Keyword arguments for the convergence function. + trajectory_reporter_dict : dict | None + Dictionary configuration for the trajectory reporter. + autobatcher_dict : dict | None + Dictionary configuration for the autobatcher. + max_steps : int + Maximum number of optimization steps. + steps_between_swaps : int + Number of steps between system swaps. + init_kwargs : dict | None + Additional initialization keyword arguments. + optimizer_kwargs : dict | None + Keyword arguments for the optimizer. + tags : list[str] | None + Tags for the job. + """ + + optimizer: Optimizer + model_type: TorchSimModelType + model_path: str | Path + model_kwargs: dict[str, Any] = field(default_factory=dict) + name: str = "torchsim optimize" + convergence_fn: ConvergenceFn = ConvergenceFn.FORCE # type: ignore[assignment] + convergence_fn_kwargs: dict | None = None + trajectory_reporter_dict: dict | None = None + autobatcher_dict: dict | bool = False + max_steps: int = 10_000 + steps_between_swaps: int = 5 + init_kwargs: dict | None = None + optimizer_kwargs: dict | None = None + tags: list[str] | None = None + + @torchsim_job + def make( + self, structures: list[Structure], prev_task: TorchSimTaskDoc | None = None + ) -> Response: + """Run a TorchSim optimization calculation. + + Parameters + ---------- + structures : list[Structure] + List of pymatgen Structures to optimize. 
+ prev_task : TorchSimTaskDoc | None + Previous task document if continuing from a previous calculation. + + Returns + ------- + Response + A response object containing the output task document. + """ + model = pick_model(self.model_type, self.model_path, **self.model_kwargs) + + convergence_fn_obj = CONVERGENCE_FN_REGISTRY[self.convergence_fn]( + **(self.convergence_fn_kwargs or {}) + ) + + # Configure trajectory reporter + trajectory_reporter, trajectory_reporter_details = ( + process_trajectory_reporter_dict(self.trajectory_reporter_dict) + ) + + # Configure autobatcher + max_iterations = self.max_steps // self.steps_between_swaps + autobatcher, autobatcher_details = process_in_flight_autobatcher_dict( + structures, + model, + autobatcher_dict=self.autobatcher_dict, + max_iterations=max_iterations, + ) + + optimizer_kwargs = self.optimizer_kwargs or {} + + start_time = time.time() + state = ts.optimize( + system=structures, + model=model, + optimizer=self.optimizer, + convergence_fn=convergence_fn_obj, + trajectory_reporter=trajectory_reporter, + autobatcher=autobatcher, + max_steps=self.max_steps, + steps_between_swaps=self.steps_between_swaps, + init_kwargs=self.init_kwargs, + **optimizer_kwargs, + ) + elapsed_time = time.time() - start_time + + final_structures = state.to_structures() + + # Create calculation object + calculation = TorchSimCalculation( + initial_structures=structures, + structures=final_structures, + trajectory_reporter=trajectory_reporter_details, + autobatcher=autobatcher_details, + model=self.model_type, + model_path=str(Path(self.model_path).resolve()), + task_type=TaskType.STRUCTURE_OPTIMIZATION, + optimizer=self.optimizer, + max_steps=self.max_steps, + steps_between_swaps=self.steps_between_swaps, + init_kwargs=self.init_kwargs or {}, + optimizer_kwargs=optimizer_kwargs, + ) + + # Create task document + task_doc = TorchSimTaskDoc( + structures=final_structures, + calcs_reversed=( + [calculation] + ([prev_task.calcs_reversed] if prev_task else []) + ), + time_elapsed=elapsed_time, + ) + + return Response(output=task_doc) + + +@dataclass +class TorchSimIntegrateMaker(Maker): + """A maker class for performing molecular dynamics using TorchSim. + + Parameters + ---------- + name : str + The name of the job. + model_type : TorchSimModelType + The type of model to use. + model_path : str | Path + Path to the model file or checkpoint. + integrator : Integrator + The TorchSim integrator to use. + n_steps : int + Number of integration steps to perform. + temperature : float | list[float] + Temperature(s) for the simulation in Kelvin. + timestep : float + Timestep for the integration in femtoseconds. + model_kwargs : dict[str, Any] + Keyword arguments for the model. + trajectory_reporter_dict : dict | None + Dictionary configuration for the trajectory reporter. + autobatcher_dict : dict | bool + Dictionary configuration for the autobatcher. + integrator_kwargs : dict | None + Keyword arguments for the integrator. + tags : list[str] | None + Tags for the job. 
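# A trimmed-down usage sketch for TorchSimOptimizeMaker, following the
# accompanying tests (Lennard-Jones argon relaxed with the FIRE optimizer).
import torch_sim as ts
from ase.build import bulk
from jobflow import run_locally
from pymatgen.io.ase import AseAtomsAdaptor

from atomate2.torchsim import TorchSimOptimizeMaker
from atomate2.torchsim.schema import ConvergenceFn, TorchSimModelType

structure = AseAtomsAdaptor.get_structure(bulk("Ar", "fcc", a=5.26, cubic=True))
opt_maker = TorchSimOptimizeMaker(
    model_type=TorchSimModelType.LENNARD_JONES,
    model_path="",
    optimizer=ts.Optimizer.fire,
    convergence_fn=ConvergenceFn.FORCE,
    max_steps=500,
    model_kwargs={"sigma": 3.405, "epsilon": 0.0104, "compute_stress": True},
)
relax_job = opt_maker.make([structure])
responses = run_locally(relax_job, ensure_success=True)
task_doc = responses[relax_job.uuid][1].output  # a TorchSimTaskDoc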
+ """ + + model_type: TorchSimModelType + model_path: str | Path + integrator: Any # Integrator type from torch_sim + n_steps: int + temperature: float | list[float] + timestep: float + name: str = "torchsim integrate" + model_kwargs: dict[str, Any] = field(default_factory=dict) + trajectory_reporter_dict: dict | None = None + autobatcher_dict: dict | bool = False + integrator_kwargs: dict | None = None + tags: list[str] | None = None + + @torchsim_job + def make( + self, structures: list[Structure], prev_task: TorchSimTaskDoc | None = None + ) -> Response: + """Run a TorchSim molecular dynamics calculation. + + Parameters + ---------- + structures : list[Structure] + List of pymatgen Structures to simulate. + prev_task : TorchSimTaskDoc | None + Previous task document if continuing from a previous calculation. + + Returns + ------- + Response + A response object containing the output task document. + """ + model = pick_model(self.model_type, self.model_path, **self.model_kwargs) + + # Configure trajectory reporter + trajectory_reporter, trajectory_reporter_details = ( + process_trajectory_reporter_dict(self.trajectory_reporter_dict) + ) + + # Configure autobatcher + autobatcher, autobatcher_details = process_binning_autobatcher_dict( + structures, model, autobatcher_dict=self.autobatcher_dict + ) + + integrator_kwargs = self.integrator_kwargs or {} + + start_time = time.time() + state = ts.integrate( + system=structures, + model=model, + integrator=self.integrator, + n_steps=self.n_steps, + temperature=self.temperature, + timestep=self.timestep, + trajectory_reporter=trajectory_reporter, + autobatcher=autobatcher, + **integrator_kwargs, + ) + elapsed_time = time.time() - start_time + + final_structures = state.to_structures() + + # Create calculation object + calculation = TorchSimCalculation( + initial_structures=structures, + structures=final_structures, + trajectory_reporter=trajectory_reporter_details, + autobatcher=autobatcher_details, + model=self.model_type, + model_path=str(Path(self.model_path).resolve()), + task_type=TaskType.MOLECULAR_DYNAMICS, + integrator=self.integrator, + n_steps=self.n_steps, + temperature=self.temperature, + timestep=self.timestep, + integrator_kwargs=integrator_kwargs, + ) + + # Create task document + task_doc = TorchSimTaskDoc( + structures=final_structures, + calcs_reversed=( + [calculation] + ([prev_task.calcs_reversed] if prev_task else []) + ), + time_elapsed=elapsed_time, + ) + + return Response(output=task_doc) + + +@dataclass +class TorchSimStaticMaker(Maker): + """A maker class for performing static calculations using TorchSim. + + Parameters + ---------- + name : str + The name of the job. + model_type : TorchSimModelType + The type of model to use. + model_path : str | Path + Path to the model file or checkpoint. + model_kwargs : dict[str, Any] + Keyword arguments for the model. + trajectory_reporter_dict : dict | None + Dictionary configuration for the trajectory reporter. + autobatcher_dict : dict | bool + Dictionary configuration for the autobatcher. + tags : list[str] | None + Tags for the job. + """ + + model_type: TorchSimModelType + model_path: str | Path + name: str = "torchsim static" + model_kwargs: dict[str, Any] = field(default_factory=dict) + trajectory_reporter_dict: dict | None = None + autobatcher_dict: dict | bool = False + tags: list[str] | None = None + + @torchsim_job + def make( + self, structures: list[Structure], prev_task: TorchSimTaskDoc | None = None + ) -> Response: + """Run a TorchSim static calculation. 
+ + Parameters + ---------- + structures : list[Structure] + List of pymatgen Structures to calculate properties for. + prev_task : TorchSimTaskDoc | None + Previous task document if continuing from a previous calculation. + + Returns + ------- + Response + A response object containing the output task document. + """ + model = pick_model(self.model_type, self.model_path, **self.model_kwargs) + + # Configure trajectory reporter + trajectory_reporter, trajectory_reporter_details = ( + process_trajectory_reporter_dict(self.trajectory_reporter_dict) + ) + + # Configure autobatcher + autobatcher, autobatcher_details = process_binning_autobatcher_dict( + structures, model, autobatcher_dict=self.autobatcher_dict + ) + + start_time = time.time() + all_properties = ts.static( + system=structures, + model=model, + trajectory_reporter=trajectory_reporter, + autobatcher=autobatcher, + ) + elapsed_time = time.time() - start_time + + # Convert tensors to lists + all_properties_numpy = [ + {name: t.tolist() for name, t in prop_dict.items()} + for prop_dict in all_properties + ] + + # Create calculation object + calculation = TorchSimCalculation( + initial_structures=structures, + structures=structures, + trajectory_reporter=trajectory_reporter_details, + autobatcher=autobatcher_details, + model=self.model_type, + model_path=str(Path(self.model_path).resolve()), + task_type=TaskType.STATIC, + all_properties=all_properties_numpy, + ) + + # Create task document + task_doc = TorchSimTaskDoc( + structures=structures, + calcs_reversed=( + [calculation] + ([prev_task.calcs_reversed] if prev_task else []) + ), + time_elapsed=elapsed_time, + ) + + return Response(output=task_doc) diff --git a/src/atomate2/torchsim/schema.py b/src/atomate2/torchsim/schema.py new file mode 100644 index 0000000000..0fb0dfc91b --- /dev/null +++ b/src/atomate2/torchsim/schema.py @@ -0,0 +1,246 @@ +"""Schemas for TorchSim tasks.""" + +from __future__ import annotations + +import pathlib # noqa: TC003 +from enum import StrEnum # type: ignore[attr-defined] +from typing import TYPE_CHECKING, Any, Literal + +import torch_sim as ts +from pydantic import BaseModel, Field +from pymatgen.core import Structure # noqa: TC002 +from torch_sim.integrators import Integrator # noqa: TC002 +from torch_sim.optimizers import Optimizer # noqa: TC002 + +if TYPE_CHECKING: + from collections.abc import Callable + + +class TorchSimModelType(StrEnum): # type: ignore[attr-defined] + """Enum for model types.""" + + FAIRCHEMV1 = "FairChemV1Model" + FAIRCHEM = "FairChemModel" + GRAPHPESWRAPPER = "GraphPESWrapper" + MACE = "MaceModel" + MATTERSIM = "MatterSimModel" + METATOMIC = "MetatomicModel" + NEQUIPFRAMEWORK = "NequIPFrameworkModel" + ORB = "OrbModel" + SEVENNET = "SevenNetModel" + LENNARD_JONES = "LennardJonesModel" + + +class ConvergenceFn(StrEnum): # type: ignore[attr-defined] + """Enum for convergence function types.""" + + ENERGY = "energy" + FORCE = "force" + + +CONVERGENCE_FN_REGISTRY: dict[str, Callable] = { + "energy": ts.generate_energy_convergence_fn, + "force": ts.generate_force_convergence_fn, +} + + +class PropertyFn(StrEnum): + """Registry for property calculation functions. + + Because we are not able to pass live python functions through + workflow serialization, it is necessary to have an alternative + mechanism. While the functions included here are quite basic, + this gives users a place to patch in their own functions while + maintaining compatibility. 
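# A single-point sketch with TorchSimStaticMaker; per-structure properties end
# up on the task document as plain lists.
from atomate2.torchsim import TorchSimStaticMaker
from atomate2.torchsim.schema import TorchSimModelType

static_maker = TorchSimStaticMaker(
    model_type=TorchSimModelType.LENNARD_JONES,
    model_path="",
    model_kwargs={"sigma": 3.405, "epsilon": 0.0104, "compute_stress": True},
)
static_job = static_maker.make([structure])
# After running, task_doc.calcs_reversed[0].all_properties holds one
# {property_name: nested_list} dict per input structure.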
+ """ + + POTENTIAL_ENERGY = "potential_energy" + FORCES = "forces" + STRESS = "stress" + KINETIC_ENERGY = "kinetic_energy" + TEMPERATURE = "temperature" + + +class TaskType(StrEnum): # type: ignore[attr-defined] + """Enum for TorchSim task types.""" + + STATIC = "Static" + STRUCTURE_OPTIMIZATION = "Structure Optimization" + MOLECULAR_DYNAMICS = "Molecular Dynamics" + + +PROPERTY_FN_REGISTRY: dict[str, Callable] = { + "potential_energy": lambda state: state.energy, + "forces": lambda state: state.forces, + "stress": lambda state: state.stress, + "kinetic_energy": lambda state: ts.calc_kinetic_energy( + velocities=state.velocities, masses=state.masses + ), + "temperature": lambda state: state.calc_temperature(), +} + + +class TrajectoryReporterDetails(BaseModel): + """Details for a TorchSim trajectory reporter. + + Stores configuration and metadata for trajectory reporting. + """ + + state_frequency: int = Field( + ..., description="Frequency at which states are reported." + ) + + trajectory_kwargs: dict[str, Any] = Field( + default_factory=dict, + description=("Keyword arguments for trajectory reporter initialization."), + ) + + prop_calculators: dict[int, list[PropertyFn]] = Field( + default_factory=dict, + description=("Property calculators to apply at specific frequencies."), + ) + + state_kwargs: dict[str, Any] = Field( + default_factory=dict, + description="Keyword arguments for state reporting.", + ) + + metadata: dict[str, str] | None = Field( + None, description="Optional metadata for the trajectory reporter." + ) + + filenames: list[str | pathlib.Path] | None = Field( + None, description="List of output filenames for trajectory data." + ) + + +class AutobatcherDetails(BaseModel): + """Details for a TorchSim autobatcher configuration.""" + + autobatcher: Literal["BinningAutoBatcher", "InFlightAutoBatcher"] = Field( + ..., description="The type of autobatcher to use." + ) + + memory_scales_with: Literal["n_atoms", "n_atoms_x_density"] = Field( + ..., description="How memory scales with system size." + ) + + max_memory_scaler: float | None = Field( + None, description="Maximum memory scaling factor." + ) + + max_atoms_to_try: int | None = Field( + None, description="Maximum number of atoms to try in batching." + ) + + memory_scaling_factor: float | None = Field( + None, description="Factor for memory scaling calculations." + ) + + max_iterations: int | None = Field( + None, description="Maximum number of autobatching iterations." + ) + + max_memory_padding: float | None = Field( + None, description="Maximum padding for memory allocation." + ) + + +class TorchSimCalculation(BaseModel): + """Schema for TorchSim calculation tasks. + + This schema supports three task types: Static, Structure Optimization, + and Molecular Dynamics. Different fields are populated depending on the task_type. + """ + + # Common fields (always present) + initial_structures: list[Structure] = Field( + ..., description="List of initial structures for the calculation." + ) + + structures: list[Structure] = Field( + ..., description="List of final structures from the calculation." + ) + + trajectory_reporter: TrajectoryReporterDetails | None = Field( + None, description="Configuration for the trajectory reporter." + ) + + autobatcher: AutobatcherDetails | None = Field( + None, description="Configuration for the autobatcher." + ) + + model: TorchSimModelType = Field( + ..., description="Name of the model used for the calculation." 
+ ) + + model_path: str = Field(..., description="Path to the model file.") + + task_type: TaskType = Field( + ..., + description="Type of calculation performed (Static, Structure Optimization, " + "or Molecular Dynamics).", + ) + + # Optimization-specific fields (populated when task_type == STRUCTURE_OPTIMIZATION) + optimizer: Optimizer | None = Field( + None, description="The TorchSim optimizer instance used for optimization." + ) + + max_steps: int | None = Field( + None, description="Maximum number of optimization steps to perform." + ) + + steps_between_swaps: int | None = Field( + None, description="Number of steps between system swaps in the optimizer." + ) + + init_kwargs: dict[str, Any] | None = Field( + None, description="Additional keyword arguments for initialization." + ) + + optimizer_kwargs: dict[str, Any] | None = Field( + None, description="Keyword arguments for the optimizer configuration." + ) + + # MD-specific fields (populated when task_type == MOLECULAR_DYNAMICS) + integrator: Integrator | None = Field( + None, description="The TorchSim integrator instance used for MD simulation." + ) + + n_steps: int | None = Field( + None, description="Number of integration steps to perform." + ) + + temperature: float | list[float] | None = Field( + None, description="Temperature(s) for the simulation in Kelvin." + ) + + timestep: float | None = Field( + None, description="Timestep for the integration in femtoseconds." + ) + + integrator_kwargs: dict[str, Any] | None = Field( + None, description="Keyword arguments for the integrator configuration." + ) + + # Static calculation-specific fields (populated when task_type == STATIC) + all_properties: list[dict[str, list]] | None = Field( + None, description="List of calculated properties for each structure." + ) + + +class TorchSimTaskDoc(BaseModel): + """Base schema for TorchSim tasks.""" + + structures: list[Structure] = Field( + ..., description="List of final structures from the calculation." + ) + + calcs_reversed: list[TorchSimCalculation] = Field( + ..., description="List of calculations for the task." + ) + + time_elapsed: float = Field( + ..., description="Time elapsed for the calculation in seconds." 
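# A sketch of reading results back from a completed task document (`task_doc`
# stands for the job output retrieved from the store).
calc = task_doc.calcs_reversed[0]  # most recent calculation first
final_structures = task_doc.structures  # pymatgen Structures
print(calc.task_type, calc.model, calc.model_path)
if calc.trajectory_reporter is not None:
    print(calc.trajectory_reporter.filenames)  # where trajectories were written
print(f"wall time: {task_doc.time_elapsed:.1f} s")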
+ ) diff --git a/tests/torchsim/test_core.py b/tests/torchsim/test_core.py new file mode 100644 index 0000000000..08fe57febd --- /dev/null +++ b/tests/torchsim/test_core.py @@ -0,0 +1,305 @@ +"""Tests for TorchSim core makers.""" +# ruff: noqa: E402 + +from __future__ import annotations + +from pathlib import Path + +import pytest + +ts = pytest.importorskip("torch_sim") + +from ase.build import bulk +from jobflow import run_locally +from mace.calculators.foundations_models import download_mace_mp_checkpoint +from pymatgen.core import Structure +from pymatgen.io.ase import AseAtomsAdaptor + +from atomate2.torchsim.core import ( + TorchSimIntegrateMaker, + TorchSimOptimizeMaker, + TorchSimStaticMaker, +) +from atomate2.torchsim.schema import ConvergenceFn, TorchSimModelType + + +@pytest.fixture +def mace_model_path(): + """Download and return path to MACE model checkpoint.""" + return Path(download_mace_mp_checkpoint("small")) + + +@pytest.fixture +def ar_structure() -> Structure: + """Create a face-centered cubic (FCC) Argon structure.""" + atoms = bulk("Ar", "fcc", a=5.26, cubic=True) + return AseAtomsAdaptor.get_structure(atoms) + + +@pytest.fixture +def fe_structure() -> Structure: + """Create crystalline iron using ASE.""" + atoms = bulk("Fe", "fcc", a=5.26, cubic=True) + return AseAtomsAdaptor.get_structure(atoms) + + +def test_relax_job_comprehensive(ar_structure: Structure, tmp_path) -> None: + """Test TSOptimizeMaker with all kwargs. + + Includes trajectory reporter and autobatcher. + """ + # Perturb the structure to make optimization meaningful + perturbed_structure = ar_structure.copy() + perturbed_structure.translate_sites( + list(range(len(perturbed_structure))), [0.01, 0.01, 0.01] + ) + + n_systems = 2 + trajectory_reporter_dict = { + "filenames": [tmp_path / f"relax_{i}.h5md" for i in range(n_systems)], + "state_frequency": 5, + "prop_calculators": {1: ["potential_energy"]}, + } + + # Create autobatcher + autobatcher_dict = False + + maker = TorchSimOptimizeMaker( + model_type=TorchSimModelType.LENNARD_JONES, + model_path="", + optimizer=ts.Optimizer.fire, + convergence_fn=ConvergenceFn.FORCE, + trajectory_reporter_dict=trajectory_reporter_dict, + autobatcher_dict=autobatcher_dict, + max_steps=500, + steps_between_swaps=10, + init_kwargs={"cell_filter": ts.CellFilter.unit}, + model_kwargs={"sigma": 3.405, "epsilon": 0.0104, "compute_stress": True}, + ) + + job = maker.make([perturbed_structure] * n_systems) + response_dict = run_locally(job, ensure_success=True, root_dir=tmp_path) + result = list(response_dict.values())[-1][1].output + + # Validate result structure (TSTaskDoc) + assert hasattr(result, "structures") + assert hasattr(result, "calcs_reversed") + assert hasattr(result, "time_elapsed") + + # Check structures list output + assert isinstance(result.structures, list) + assert len(result.structures) == n_systems + assert isinstance(result.structures[0], Structure) + + # Check calculation details + assert len(result.calcs_reversed) == 1 + calc = result.calcs_reversed[0] + + # Check model name + assert calc.model == TorchSimModelType.LENNARD_JONES + assert calc.model_path is not None + + # Check optimizer + assert calc.optimizer == ts.Optimizer.fire + + # Check trajectory reporter details + assert calc.trajectory_reporter is not None + assert calc.trajectory_reporter.state_frequency == 5 + assert hasattr(calc.trajectory_reporter, "prop_calculators") + assert all(Path(f).is_file() for f in calc.trajectory_reporter.filenames) + + # Check autobatcher details + assert 
calc.autobatcher is None + + # Check other parameters + assert calc.max_steps == 500 + assert calc.steps_between_swaps == 10 + assert calc.init_kwargs["cell_filter"] == ts.CellFilter.unit + + # Check time elapsed + assert result.time_elapsed > 0 + + +def test_relax_job_mace( + ar_structure: Structure, mace_model_path: str, tmp_path +) -> None: + """Test TSOptimizeMaker with MACE model. + + Includes trajectory reporter and autobatcher. + """ + # Perturb the structure to make optimization meaningful + perturbed_structure = ar_structure.copy() + perturbed_structure.translate_sites( + list(range(len(perturbed_structure))), [0.01, 0.01, 0.01] + ) + + n_systems = 2 + trajectory_reporter_dict = { + "filenames": [tmp_path / f"relax_{i}.h5md" for i in range(n_systems)], + "state_frequency": 5, + "prop_calculators": {1: ["potential_energy"]}, + } + + autobatcher_dict = {"memory_scales_with": "n_atoms", "max_memory_scaler": 260} + + maker = TorchSimOptimizeMaker( + model_type=TorchSimModelType.MACE, + model_path=mace_model_path, + optimizer=ts.Optimizer.fire, + convergence_fn=ConvergenceFn.FORCE, + trajectory_reporter_dict=trajectory_reporter_dict, + autobatcher_dict=autobatcher_dict, + max_steps=500, + steps_between_swaps=10, + init_kwargs={"cell_filter": ts.CellFilter.unit}, + ) + + job = maker.make([perturbed_structure] * n_systems) + response_dict = run_locally(job, ensure_success=True, root_dir=tmp_path) + result = list(response_dict.values())[-1][1].output + + # Validate result structure + assert hasattr(result, "structures") + assert len(result.structures) == n_systems + assert len(result.calcs_reversed) == 1 + + calc = result.calcs_reversed[0] + assert calc.model == TorchSimModelType.MACE + assert calc.autobatcher is not None + assert calc.autobatcher.memory_scales_with == "n_atoms" + + +def test_md_job_comprehensive(ar_structure: Structure, tmp_path) -> None: + """Test TSIntegrateMaker with all kwargs. + + Includes trajectory reporter and autobatcher. 
+ """ + n_systems = 2 + trajectory_reporter_dict = { + "filenames": [tmp_path / f"md_{i}.h5md" for i in range(n_systems)], + "state_frequency": 2, + "prop_calculators": {1: ["potential_energy", "kinetic_energy", "temperature"]}, + } + + # Create autobatcher + autobatcher_dict = False + + maker = TorchSimIntegrateMaker( + model_type=TorchSimModelType.LENNARD_JONES, + model_path="", + integrator=ts.Integrator.nvt_langevin, + n_steps=20, + temperature=300.0, + timestep=0.001, + trajectory_reporter_dict=trajectory_reporter_dict, + autobatcher_dict=autobatcher_dict, + model_kwargs={"sigma": 3.405, "epsilon": 0.0104, "compute_stress": True}, + ) + + job = maker.make([ar_structure] * n_systems) + response_dict = run_locally(job, ensure_success=True, root_dir=tmp_path) + result = list(response_dict.values())[-1][1].output + + # Validate result structure (TSTaskDoc) + assert hasattr(result, "structures") + assert hasattr(result, "calcs_reversed") + assert hasattr(result, "time_elapsed") + + # Check structures list output + assert isinstance(result.structures, list) + assert len(result.structures) == n_systems + assert isinstance(result.structures[0], Structure) + + # Check calculation details + assert len(result.calcs_reversed) == 1 + calc = result.calcs_reversed[0] + + # Check model name + assert calc.model == TorchSimModelType.LENNARD_JONES + assert calc.model_path is not None + + # Check integrator + assert calc.integrator == ts.Integrator.nvt_langevin + + # Check MD parameters + assert calc.n_steps == 20 + assert calc.temperature == 300.0 + assert calc.timestep == 0.001 + + # Check trajectory reporter details + assert calc.trajectory_reporter is not None + assert calc.trajectory_reporter.state_frequency == 2 + assert hasattr(calc.trajectory_reporter, "prop_calculators") + assert all(Path(f).is_file() for f in calc.trajectory_reporter.filenames) + + # Check autobatcher details + assert calc.autobatcher is None + + # Check time elapsed + assert result.time_elapsed > 0 + + +def test_static_job_comprehensive(ar_structure: Structure, tmp_path) -> None: + """Test TSStaticMaker with all kwargs. + + Includes trajectory reporter and autobatcher. 
+ """ + n_systems = 2 + trajectory_reporter_dict = { + "filenames": [tmp_path / f"static_{i}.h5md" for i in range(n_systems)], + "state_frequency": 1, + "prop_calculators": {1: ["potential_energy"]}, + "state_kwargs": {"save_forces": True}, + } + + # Create autobatcher + autobatcher_dict = False + + maker = TorchSimStaticMaker( + model_type=TorchSimModelType.LENNARD_JONES, + model_path="", + trajectory_reporter_dict=trajectory_reporter_dict, + autobatcher_dict=autobatcher_dict, + model_kwargs={"sigma": 3.405, "epsilon": 0.0104, "compute_stress": True}, + ) + + job = maker.make([ar_structure] * n_systems) + response_dict = run_locally(job, ensure_success=True, root_dir=tmp_path) + result = list(response_dict.values())[-1][1].output + + # Validate result structure (TSTaskDoc) + assert hasattr(result, "structures") + assert hasattr(result, "calcs_reversed") + assert hasattr(result, "time_elapsed") + + # Check structures list output + assert isinstance(result.structures, list) + assert len(result.structures) == n_systems + assert isinstance(result.structures[0], Structure) + + # Check calculation details + assert len(result.calcs_reversed) == 1 + calc = result.calcs_reversed[0] + + # Check model name + assert calc.model == TorchSimModelType.LENNARD_JONES + assert calc.model_path is not None + + # Check trajectory reporter details + assert calc.trajectory_reporter is not None + assert calc.trajectory_reporter.state_frequency == 1 + assert hasattr(calc.trajectory_reporter, "prop_calculators") + assert hasattr(calc.trajectory_reporter, "state_kwargs") + assert calc.trajectory_reporter.state_kwargs["save_forces"] is True + assert all(Path(f).is_file() for f in calc.trajectory_reporter.filenames) + + # Check autobatcher details + assert calc.autobatcher is None + + # Check that all_properties is present + assert hasattr(calc, "all_properties") + assert isinstance(calc.all_properties, list) + assert len(calc.all_properties) == n_systems + + # Check time elapsed + assert result.time_elapsed > 0