From cff05b95b288bad02d77a53f45aa86b0ed2b743d Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Tue, 26 Aug 2025 23:44:55 +0530 Subject: [PATCH 1/7] initial design --- pytorch_forecasting/base/_base_pkg.py | 206 ++++++++++++++++++++++++++ 1 file changed, 206 insertions(+) create mode 100644 pytorch_forecasting/base/_base_pkg.py diff --git a/pytorch_forecasting/base/_base_pkg.py b/pytorch_forecasting/base/_base_pkg.py new file mode 100644 index 000000000..419b4bd82 --- /dev/null +++ b/pytorch_forecasting/base/_base_pkg.py @@ -0,0 +1,206 @@ +from pathlib import Path +import pickle +from typing import Any, Optional, Union + +from lightning import Trainer +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.core.datamodule import LightningDataModule +import torch +from torch.utils.data import DataLoader + +from pytorch_forecasting.data import TimeSeries +from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 + + +class Base_pkg(_BasePtForecasterV2): + """ + Base model package class acting as a high-level wrapper for the Lightning workflow. + + This class simplifies the user experience by managing model, datamodule, and trainer + configurations, and providing streamlined `fit` and `predict` methods. + + Parameters + ---------- + model_cfg : dict, optional + Model configs for the initialisation of the model. Required if not loading + from a checkpoint. Defaults to {}. + trainer_cfg : dict, optional + Configs to initialise ``lightning.Trainer``. Defaults to {}. + datamodule_cfg : Union[dict, str, Path], optional + Configs to initialise a ``LightningDataModule``. + - If dict, the keys and values are used as configuration parameters. + - If str or Path, it should be a path to a ``.pkl`` file containing + the serialized configuration dictionary. Required for reproducibility + when loading a model for inference. Defaults to {}. + ckpt_path : Union[str, Path], optional + Path to the checkpoint from which to load the model. 
If provided, `model_cfg` + is ignored. Defaults to None. + """ + + def __init__( + self, + model_cfg: Optional[dict[str, Any]] = None, + trainer_cfg: Optional[dict[str, Any]] = None, + datamodule_cfg: Optional[Union[dict[str, Any], str, Path]] = None, + ckpt_path: Optional[Union[str, Path]] = None, + ): + self.model_cfg = model_cfg or {} + self.trainer_cfg = trainer_cfg or {} + self.ckpt_path = Path(ckpt_path) if ckpt_path else None + + if isinstance(datamodule_cfg, (str, Path)): + with open(datamodule_cfg, "rb") as f: + self.datamodule_cfg = pickle.load(f) # noqa : S301 + else: + self.datamodule_cfg = datamodule_cfg or {} + + self.model = None + self.trainer = None + self.datamodule = None + + self._build_model() + + @classmethod + def get_cls(cls): + """Get the underlying model class.""" + raise NotImplementedError("Subclasses must implement `get_cls`.") + + @classmethod + def get_datamodule_cls(cls): + """Get the underlying DataModule class.""" + raise NotImplementedError("Subclasses must implement `get_datamodule_cls`.") + + def _build_model(self): + """Instantiates the model, either from a checkpoint or from config.""" + model_cls = self.get_cls() + if self.ckpt_path: + self.model = model_cls.load_from_checkpoint(self.ckpt_path) + elif self.model_cfg: + self.model = model_cls(**self.model_cfg) + else: + self.model = None + + def _build_datamodule(self, data: TimeSeries) -> LightningDataModule: + """Constructs a DataModule from a D1 layer object.""" + if not self.datamodule_cfg: + raise ValueError("`datamodule_cfg` must be provided to build a datamodule.") + datamodule_cls = self.get_datamodule_cls() + return datamodule_cls(data, **self.datamodule_cfg) + + def _load_dataloader( + self, data: Union[TimeSeries, LightningDataModule, DataLoader] + ) -> DataLoader: + """Converts various data input types into a DataLoader for prediction.""" + if isinstance(data, TimeSeries): # D1 Layer + dm = self._build_datamodule(data) + return dm.predict_dataloader() + elif 
isinstance(data, LightningDataModule): # D2 Layer + return data.predict_dataloader() + elif isinstance(data, DataLoader): + return data + else: + raise TypeError( + f"Unsupported data type for prediction: {type(data).__name__}. " + "Expected TimeSeriesDataSet, LightningDataModule, or DataLoader." + ) + + def fit( + self, + train_data: Union[TimeSeries, LightningDataModule], + # todo: we should create a base data_module for different data_modules + val_data: Optional[Union[TimeSeries, LightningDataModule]] = None, + save_ckpt: bool = True, + ckpt_dir: Union[str, Path] = "checkpoints", + **trainer_fit_kwargs, + ): + """ + Fit the model to the training data. + + Parameters + ---------- + train_data : Union[TimeSeries, LightningDataModule] + Training data (D1 or D2 layer). + val_data : Union[TimeSeries, LightningDataModule], optional + Validation data. + save_ckpt : bool, default=True + If True, save the best model checkpoint and the `datamodule_cfg`. + ckpt_dir : Union[str, Path], default="checkpoints" + Directory to save artifacts. + **trainer_fit_kwargs : + Additional keyword arguments passed to `trainer.fit()`. + + Returns + ------- + Optional[Path] + The path to the best model checkpoint if `save_ckpt=True`, else None. + """ + if self.model is None: + raise RuntimeError( + "Model is not initialized. Provide `model_cfg` or `ckpt_path`." 
+ ) + + if isinstance(train_data, TimeSeries): + self.datamodule = self._build_datamodule(train_data) + else: + self.datamodule = train_data + + callbacks = self.trainer_cfg.get("callbacks", []) + if save_ckpt: + ckpt_dir = Path(ckpt_dir) + ckpt_dir.mkdir(parents=True, exist_ok=True) + checkpoint_cb = ModelCheckpoint( + dirpath=ckpt_dir, + filename="best-{epoch}-{val_loss:.2f}", + save_top_k=1, + monitor="val_loss", + mode="min", + ) + callbacks.append(checkpoint_cb) + + self.trainer = Trainer(**self.trainer_cfg, callbacks=callbacks) + self.trainer.fit(self.model, datamodule=self.datamodule, **trainer_fit_kwargs) + + if save_ckpt: + best_model_path = Path(checkpoint_cb.best_model_path) + dm_cfg_path = best_model_path.parent / "datamodule_cfg.pkl" + with open(dm_cfg_path, "wb") as f: + pickle.dump(self.datamodule_cfg, f) + print(f"Best model saved to: {best_model_path}") + print(f"DataModule config saved to: {dm_cfg_path}") + return best_model_path + return None + + def predict( + self, + data: Union[TimeSeries, LightningDataModule, DataLoader], + **kwargs, + ) -> Union[dict[str, torch.Tensor], None]: + """ + Generate predictions by wrapping the model's predict method. + + This method prepares the data by resolving it into a DataLoader and then + delegates the prediction task to the underlying model's ``.predict()`` method. + + Parameters + ---------- + data : Union[TimeSeries, LightningDataModule, DataLoader] + The data to predict on (D1, D2, or DataLoader). + **kwargs : + Additional keyword arguments passed directly to the model's ``.predict()`` + method. This includes `mode`, `return_info`, `output_dir`, and any + `trainer_kwargs`. + + Returns + ------- + Union[Dict[str, torch.Tensor], None] + A dictionary of prediction tensors, or `None` if `output_dir` is specified + in `**kwargs`. + """ + if self.model is None: + raise RuntimeError( + "Model is not initialized. Provide `model_cfg` or `ckpt_path`." 
+ ) + + dataloader = self._load_dataloader(data) + + return self.model.predict(dataloader, **kwargs) From 13c12b98a689d8cdedcabc8547af60fad1f08e64 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 3 Oct 2025 00:54:20 +0530 Subject: [PATCH 2/7] preliminary design --- pytorch_forecasting/base/_base_pkg.py | 15 ++- pytorch_forecasting/callbacks/__init__.py | 0 pytorch_forecasting/callbacks/predict.py | 112 ++++++++++++++++++ .../models/base/_base_model_v2.py | 68 ++++++++++- 4 files changed, 192 insertions(+), 3 deletions(-) create mode 100644 pytorch_forecasting/callbacks/__init__.py create mode 100644 pytorch_forecasting/callbacks/predict.py diff --git a/pytorch_forecasting/base/_base_pkg.py b/pytorch_forecasting/base/_base_pkg.py index 419b4bd82..7d5ba39f8 100644 --- a/pytorch_forecasting/base/_base_pkg.py +++ b/pytorch_forecasting/base/_base_pkg.py @@ -93,8 +93,10 @@ def _load_dataloader( """Converts various data input types into a DataLoader for prediction.""" if isinstance(data, TimeSeries): # D1 Layer dm = self._build_datamodule(data) + dm.setup(stage="predict") return dm.predict_dataloader() elif isinstance(data, LightningDataModule): # D2 Layer + data.setup(stage="predict") return data.predict_dataloader() elif isinstance(data, DataLoader): return data @@ -160,7 +162,7 @@ def fit( self.trainer = Trainer(**self.trainer_cfg, callbacks=callbacks) self.trainer.fit(self.model, datamodule=self.datamodule, **trainer_fit_kwargs) - if save_ckpt: + if save_ckpt and checkpoint_cb: best_model_path = Path(checkpoint_cb.best_model_path) dm_cfg_path = best_model_path.parent / "datamodule_cfg.pkl" with open(dm_cfg_path, "wb") as f: @@ -173,6 +175,7 @@ def fit( def predict( self, data: Union[TimeSeries, LightningDataModule, DataLoader], + output_dir: Optional[Union[str, Path]] = None, **kwargs, ) -> Union[dict[str, torch.Tensor], None]: """ @@ -202,5 +205,15 @@ def predict( ) dataloader = self._load_dataloader(data) + predictions = self.model.predict(dataloader, **kwargs) 
+ + if output_dir: + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + output_file = output_path / "predictions.pkl" + with open(output_file, "wb") as f: + pickle.dump(predictions, f) + print(f"Predictions saved to {output_file}") + return None return self.model.predict(dataloader, **kwargs) diff --git a/pytorch_forecasting/callbacks/__init__.py b/pytorch_forecasting/callbacks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pytorch_forecasting/callbacks/predict.py b/pytorch_forecasting/callbacks/predict.py new file mode 100644 index 000000000..48e6d3c84 --- /dev/null +++ b/pytorch_forecasting/callbacks/predict.py @@ -0,0 +1,112 @@ +from typing import Any, Optional +from warnings import warn + +from lightning import Trainer +from lightning.pytorch.callbacks import BasePredictionWriter +import torch + +from pytorch_forecasting.models.base._base_model_v2 import BaseModel + + +class PredictCallback(BasePredictionWriter): + """ + Callback to capture predictions and related information internally. + + This callback is used by `BaseModel.predict()` to process raw model outputs + into the desired format (`prediction`, `quantiles`, or `raw`) and collect + any additional requested info (`x`, `y`, `index`, etc.). The results are + collated and stored in memory, accessible via the `.result` property. It does + not write to disk. + + Parameters + ---------- + mode : str + The prediction mode ("prediction", "quantiles", or "raw"). + return_info : list[str], optional + Additional information to return. + **kwargs : + Additional keyword arguments for `to_prediction` or `to_quantiles`. 
+ """ + + def __init__( + self, + mode: str = "prediction", + return_info: Optional[list[str]] = None, + **kwargs, + ): + super().__init__(write_interval="epoch") + self.mode = mode + self.return_info = return_info or [] + self.kwargs = kwargs + self._reset_data() + + def _reset_data(self): + """Clear collected data for a new prediction run.""" + self.predictions = [] + self.info = {key: [] for key in self.return_info} + self._result = None + + def on_predict_batch_end( + self, + trainer: Trainer, + pl_module: BaseModel, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ): + """Process and store predictions for a single batch.""" + x, y = batch + + if self.mode == "raw": + processed_output = outputs + elif self.mode == "prediction": + processed_output = pl_module.to_prediction(outputs, **self.kwargs) + elif self.mode == "quantiles": + processed_output = pl_module.to_quantiles(outputs, **self.kwargs) + else: + raise ValueError(f"Invalid prediction mode: {self.mode}") + + self.predictions.append(processed_output) + + for key in self.return_info: + if key == "x": + self.info[key].append(x) + elif key == "y": + self.info[key].append(y[0]) + elif key == "index": + self.info[key].append(y[1]) + elif key == "decoder_lengths": + self.info[key].append(x["decoder_lengths"]) + else: + warn(f"Unknown return_info key: {key}") + + def on_predict_epoch_end(self, trainer: Trainer, pl_module: "BaseModel"): + """Collate all batch results into final tensors.""" + if self.mode == "raw" and isinstance(self.predictions[0], dict): + keys = self.predictions[0].keys() + collated_preds = { + key: torch.cat([p[key] for p in self.predictions]) for key in keys + } + else: + collated_preds = {"prediction": torch.cat(self.predictions)} + + final_result = collated_preds + + for key, data_list in self.info.items(): + if isinstance(data_list[0], dict): + collated_info = { + k: torch.cat([d[k] for d in data_list]) for k in data_list[0].keys() + } + else: + collated_info = 
torch.cat(data_list) + final_result[key] = collated_info + + self._result = final_result + self._reset_data() + + @property + def result(self) -> dict[str, torch.Tensor]: + if self._result is None: + raise RuntimeError("Prediction results are not yet available.") + return self._result diff --git a/pytorch_forecasting/models/base/_base_model_v2.py b/pytorch_forecasting/models/base/_base_model_v2.py index 8896a5397..31cfc356e 100644 --- a/pytorch_forecasting/models/base/_base_model_v2.py +++ b/pytorch_forecasting/models/base/_base_model_v2.py @@ -5,16 +5,19 @@ ######################################################################################## -from typing import Optional, Union +from typing import Any, Optional, Union from warnings import warn +from lightning import Trainer from lightning.pytorch import LightningModule from lightning.pytorch.utilities.types import STEP_OUTPUT import torch import torch.nn as nn from torch.optim import Optimizer +from torch.utils.data import DataLoader -from pytorch_forecasting.metrics import Metric +from pytorch_forecasting.callbacks.predict import PredictCallback +from pytorch_forecasting.metrics import Metric, MultiLoss from pytorch_forecasting.utils._classproperty import classproperty @@ -91,6 +94,67 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ raise NotImplementedError("Forward method must be implemented by subclass.") + def predict( + self, + dataloader: DataLoader, + mode: str = "prediction", + return_info: Optional[list[str]] = None, + **kwargs, + ) -> dict[str, torch.Tensor]: + """ + Generate predictions for new data using the `lightning.Trainer`. + + Parameters + ---------- + dataloader : DataLoader + The dataloader containing the data to predict on. + mode : str + The prediction mode ("prediction", "quantiles", or "raw"). + return_info : list[str], optional + A list of additional information to return. 
+ **kwargs : + Additional arguments for `to_prediction`/`to_quantiles` or `Trainer`. + + Returns + ------- + dict[str, torch.Tensor] + A dictionary of prediction results. + """ + trainer_kwargs = kwargs.pop("trainer_kwargs", {}) + predict_callback = PredictCallback(mode=mode, return_info=return_info, **kwargs) + + # Ensure callbacks list exists and append the new one + callbacks = trainer_kwargs.get("callbacks", []) + if not isinstance(callbacks, list): + callbacks = [callbacks] + callbacks.append(predict_callback) + trainer_kwargs["callbacks"] = callbacks + + trainer = Trainer(**trainer_kwargs) + trainer.predict(self, dataloaders=dataloader) + + return predict_callback.result + + def to_prediction(self, out: dict[str, Any], **kwargs) -> torch.Tensor: + """Converts raw model output to point forecasts.""" + if isinstance(self.loss, MultiLoss): + # Assuming first loss is the primary one for prediction + return self.loss.losses[0].to_prediction(out["prediction"][0]) + else: + return self.loss.to_prediction(out["prediction"]) + + def to_quantiles(self, out: dict[str, Any], **kwargs) -> torch.Tensor: + """Converts raw model output to quantile forecasts.""" + quantiles = kwargs.get("quantiles") # Allow overriding default quantiles + if isinstance(self.loss, MultiLoss): + # Assuming first loss is the primary one for quantiles + loss = self.loss.losses[0] + q = quantiles or loss.quantiles + return loss.to_quantiles(out["prediction"][0], quantiles=q) + else: + q = quantiles or self.loss.quantiles + return self.loss.to_quantiles(out["prediction"], quantiles=q) + def training_step( self, batch: tuple[dict[str, torch.Tensor]], batch_idx: int ) -> STEP_OUTPUT: From 35e2447d02d4afb1c08b4b881b0202f0c9b8dbb1 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 18 Oct 2025 23:26:38 +0530 Subject: [PATCH 3/7] add predict to all v2 models --- pytorch_forecasting/base/_base_pkg.py | 70 +++++---- pytorch_forecasting/callbacks/predict.py | 23 ++- .../models/base/_base_model_v2.py | 18 
+-- .../models/dlinear/_dlinear_pkg_v2.py | 23 ++- .../models/samformer/_samformer_v2_pkg.py | 92 +++--------- .../_tft_pkg_v2.py | 100 +++---------- .../models/tide/_tide_dsipts/_tide_v2_pkg.py | 95 +++--------- .../models/timexer/_timexer_pkg_v2.py | 95 +++--------- .../tests/test_all_estimators_v2.py | 135 +++++++++--------- 9 files changed, 222 insertions(+), 429 deletions(-) diff --git a/pytorch_forecasting/base/_base_pkg.py b/pytorch_forecasting/base/_base_pkg.py index 7d5ba39f8..340e9e742 100644 --- a/pytorch_forecasting/base/_base_pkg.py +++ b/pytorch_forecasting/base/_base_pkg.py @@ -17,13 +17,13 @@ class Base_pkg(_BasePtForecasterV2): Base model package class acting as a high-level wrapper for the Lightning workflow. This class simplifies the user experience by managing model, datamodule, and trainer - configurations, and providing streamlined `fit` and `predict` methods. + configurations, and providing streamlined ``fit`` and ``predict`` methods. Parameters ---------- model_cfg : dict, optional Model configs for the initialisation of the model. Required if not loading - from a checkpoint. Defaults to {}. + from a checkpoint. Defaults to ``{}``. trainer_cfg : dict, optional Configs to initialise ``lightning.Trainer``. Defaults to {}. datamodule_cfg : Union[dict, str, Path], optional @@ -58,8 +58,6 @@ def __init__( self.trainer = None self.datamodule = None - self._build_model() - @classmethod def get_cls(cls): """Get the underlying model class.""" @@ -70,13 +68,32 @@ def get_datamodule_cls(cls): """Get the underlying DataModule class.""" raise NotImplementedError("Subclasses must implement `get_datamodule_cls`.") - def _build_model(self): + @classmethod + def get_test_data(cls, **kwargs): + """ + Creates and returns D1 TimeSeries dataSet objects for testing. 
+ """ + from pytorch_forecasting.tests._data_scenarios import ( + data_with_covariates_v2, + make_datasets_v2, + ) + + raw_data = data_with_covariates_v2() + + datasets_info = make_datasets_v2(raw_data, **kwargs) + + return { + "train": datasets_info["training_dataset"], + "predict": datasets_info["validation_dataset"], + } + + def _build_model(self, metadata: dict): """Instantiates the model, either from a checkpoint or from config.""" model_cls = self.get_cls() if self.ckpt_path: self.model = model_cls.load_from_checkpoint(self.ckpt_path) elif self.model_cfg: - self.model = model_cls(**self.model_cfg) + self.model = model_cls(**self.model_cfg, metadata=metadata) else: self.model = None @@ -108,9 +125,8 @@ def _load_dataloader( def fit( self, - train_data: Union[TimeSeries, LightningDataModule], + data: Union[TimeSeries, LightningDataModule], # todo: we should create a base data_module for different data_modules - val_data: Optional[Union[TimeSeries, LightningDataModule]] = None, save_ckpt: bool = True, ckpt_dir: Union[str, Path] = "checkpoints", **trainer_fit_kwargs, @@ -120,10 +136,9 @@ def fit( Parameters ---------- - train_data : Union[TimeSeries, LightningDataModule] - Training data (D1 or D2 layer). - val_data : Union[TimeSeries, LightningDataModule], optional - Validation data. + data : Union[TimeSeries, LightningDataModule] + The data to fit on (D1 or D2 layer). This object is responsible + for providing both training and validation data. save_ckpt : bool, default=True If True, save the best model checkpoint and the `datamodule_cfg`. ckpt_dir : Union[str, Path], default="checkpoints" @@ -136,17 +151,22 @@ def fit( Optional[Path] The path to the best model checkpoint if `save_ckpt=True`, else None. """ - if self.model is None: - raise RuntimeError( - "Model is not initialized. Provide `model_cfg` or `ckpt_path`." 
- ) - - if isinstance(train_data, TimeSeries): - self.datamodule = self._build_datamodule(train_data) + if isinstance(data, TimeSeries): + self.datamodule = self._build_datamodule(data) else: - self.datamodule = train_data + self.datamodule = data + self.datamodule.setup(stage="fit") - callbacks = self.trainer_cfg.get("callbacks", []) + if self.model is None: + if not self.model_cfg: + raise RuntimeError( + "`model_cfg` must be provided to train from scratch." + ) + metadata = self.datamodule.metadata + self._build_model(metadata) + + callbacks = self.trainer_cfg.get("callbacks", []).copy() + checkpoint_cb = None if save_ckpt: ckpt_dir = Path(ckpt_dir) ckpt_dir.mkdir(parents=True, exist_ok=True) @@ -158,10 +178,12 @@ def fit( mode="min", ) callbacks.append(checkpoint_cb) + trainer_init_cfg = self.trainer_cfg.copy() + trainer_init_cfg.pop("callbacks", None) - self.trainer = Trainer(**self.trainer_cfg, callbacks=callbacks) - self.trainer.fit(self.model, datamodule=self.datamodule, **trainer_fit_kwargs) + self.trainer = Trainer(**trainer_init_cfg, callbacks=callbacks) + self.trainer.fit(self.model, datamodule=self.datamodule, **trainer_fit_kwargs) if save_ckpt and checkpoint_cb: best_model_path = Path(checkpoint_cb.best_model_path) dm_cfg_path = best_model_path.parent / "datamodule_cfg.pkl" @@ -216,4 +238,4 @@ def predict( print(f"Predictions saved to {output_file}") return None - return self.model.predict(dataloader, **kwargs) + return predictions diff --git a/pytorch_forecasting/callbacks/predict.py b/pytorch_forecasting/callbacks/predict.py index 48e6d3c84..883a92f8b 100644 --- a/pytorch_forecasting/callbacks/predict.py +++ b/pytorch_forecasting/callbacks/predict.py @@ -2,21 +2,19 @@ from warnings import warn from lightning import Trainer +from lightning.pytorch import LightningModule from lightning.pytorch.callbacks import BasePredictionWriter import torch -from pytorch_forecasting.models.base._base_model_v2 import BaseModel - class 
PredictCallback(BasePredictionWriter): """ Callback to capture predictions and related information internally. - This callback is used by `BaseModel.predict()` to process raw model outputs - into the desired format (`prediction`, `quantiles`, or `raw`) and collect - any additional requested info (`x`, `y`, `index`, etc.). The results are - collated and stored in memory, accessible via the `.result` property. It does - not write to disk. + This callback is used by ``BaseModel.predict()`` to process raw model outputs + into the desired format (``prediction``, ``quantiles``, or ``raw``) and collect + any additional requested info (``x``, ``y``, ``index``, etc.). The results are + collated and stored in memory, accessible via the ``.result`` property. Parameters ---------- @@ -40,16 +38,17 @@ def __init__( self.kwargs = kwargs self._reset_data() - def _reset_data(self): + def _reset_data(self, result: bool = True): """Clear collected data for a new prediction run.""" self.predictions = [] self.info = {key: [] for key in self.return_info} - self._result = None + if result: + self._result = None def on_predict_batch_end( self, trainer: Trainer, - pl_module: BaseModel, + pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, @@ -81,7 +80,7 @@ def on_predict_batch_end( else: warn(f"Unknown return_info key: {key}") - def on_predict_epoch_end(self, trainer: Trainer, pl_module: "BaseModel"): + def on_predict_epoch_end(self, trainer: Trainer, pl_module: LightningModule): """Collate all batch results into final tensors.""" if self.mode == "raw" and isinstance(self.predictions[0], dict): keys = self.predictions[0].keys() @@ -103,7 +102,7 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: "BaseModel"): final_result[key] = collated_info self._result = final_result - self._reset_data() + self._reset_data(result=False) @property def result(self) -> dict[str, torch.Tensor]: diff --git a/pytorch_forecasting/models/base/_base_model_v2.py 
b/pytorch_forecasting/models/base/_base_model_v2.py index 31cfc356e..2de069d58 100644 --- a/pytorch_forecasting/models/base/_base_model_v2.py +++ b/pytorch_forecasting/models/base/_base_model_v2.py @@ -137,23 +137,15 @@ def predict( def to_prediction(self, out: dict[str, Any], **kwargs) -> torch.Tensor: """Converts raw model output to point forecasts.""" - if isinstance(self.loss, MultiLoss): - # Assuming first loss is the primary one for prediction - return self.loss.losses[0].to_prediction(out["prediction"][0]) - else: - return self.loss.to_prediction(out["prediction"]) + # todo: add MultiLoss support + return self.loss.to_prediction(out["prediction"]) def to_quantiles(self, out: dict[str, Any], **kwargs) -> torch.Tensor: """Converts raw model output to quantile forecasts.""" + # todo: add MultiLoss support quantiles = kwargs.get("quantiles") # Allow overriding default quantiles - if isinstance(self.loss, MultiLoss): - # Assuming first loss is the primary one for quantiles - loss = self.loss.losses[0] - q = quantiles or loss.quantiles - return loss.to_quantiles(out["prediction"][0], quantiles=q) - else: - q = quantiles or self.loss.quantiles - return self.loss.to_quantiles(out["prediction"], quantiles=q) + q = quantiles or self.loss.quantiles + return self.loss.to_quantiles(out["prediction"], quantiles=q) def training_step( self, batch: tuple[dict[str, torch.Tensor]], batch_idx: int diff --git a/pytorch_forecasting/models/dlinear/_dlinear_pkg_v2.py b/pytorch_forecasting/models/dlinear/_dlinear_pkg_v2.py index bf4fffce5..500446d9d 100644 --- a/pytorch_forecasting/models/dlinear/_dlinear_pkg_v2.py +++ b/pytorch_forecasting/models/dlinear/_dlinear_pkg_v2.py @@ -2,10 +2,10 @@ Packages container for DLinear model. 
""" -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 +from pytorch_forecasting.base._base_pkg import Base_pkg -class DLinear_pkg_v2(_BasePtForecasterV2): +class DLinear_pkg_v2(Base_pkg): """DLinear package container.""" _tags = { @@ -26,6 +26,13 @@ def get_cls(cls): return DLinear + @classmethod + def get_datamodule_cls(cls): + """Get the underlying DataModule class.""" + from pytorch_forecasting.data._tslib_data_module import TslibDataModule + + return TslibDataModule + @classmethod def _get_test_datamodule_from(cls, trainer_kwargs): """Create test dataloaders from trainer_kwargs - following v1 pattern.""" @@ -112,7 +119,7 @@ def get_test_train_params(cls): from pytorch_forecasting.metrics import MAE, MAPE, SMAPE, QuantileLoss - return [ + params = [ {}, dict(moving_avg=25, individual=False, logging_metrics=[SMAPE()]), dict( @@ -125,3 +132,13 @@ def get_test_train_params(cls): logging_metrics=[SMAPE()], ), ] + + default_dm_cfg = {"context_length": 8, "prediction_length": 2} + + for param in params: + current_dm_cfg = param.get("datamodule_cfg", {}) + default_dm_cfg.update(current_dm_cfg) + + param["datamodule_cfg"] = default_dm_cfg + + return params diff --git a/pytorch_forecasting/models/samformer/_samformer_v2_pkg.py b/pytorch_forecasting/models/samformer/_samformer_v2_pkg.py index 36db9340a..2838fcc91 100644 --- a/pytorch_forecasting/models/samformer/_samformer_v2_pkg.py +++ b/pytorch_forecasting/models/samformer/_samformer_v2_pkg.py @@ -2,10 +2,10 @@ Samformer package container. 
""" -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 +from pytorch_forecasting.base._base_pkg import Base_pkg -class Samformer_pkg_v2(_BasePtForecasterV2): +class Samformer_pkg_v2(Base_pkg): """Samformer package container.""" _tags = { @@ -21,83 +21,13 @@ def get_cls(cls): return Samformer @classmethod - def _get_test_datamodule_from(cls, trainer_kwargs): - """Create test dataloaders from trainer_kwargs - following v1 pattern.""" + def get_datamodule_cls(cls): + """Get the underlying DataModule class.""" from pytorch_forecasting.data.data_module import ( EncoderDecoderTimeSeriesDataModule, ) - from pytorch_forecasting.tests._data_scenarios import ( - data_with_covariates_v2, - make_datasets_v2, - ) - - data_with_covariates = data_with_covariates_v2() - - data_loader_default_kwargs = dict( - target="target", - group_ids=["agency_encoded", "sku_encoded"], - add_relative_time_idx=True, - ) - - data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {}) - data_loader_default_kwargs.update(data_loader_kwargs) - datasets_info = make_datasets_v2( - data_with_covariates, **data_loader_default_kwargs - ) - - training_dataset = datasets_info["training_dataset"] - validation_dataset = datasets_info["validation_dataset"] - training_max_time_idx = datasets_info["training_max_time_idx"] - - max_encoder_length = data_loader_kwargs.get("max_encoder_length", 4) - max_prediction_length = data_loader_kwargs.get("max_prediction_length", 3) - add_relative_time_idx = data_loader_kwargs.get("add_relative_time_idx", True) - batch_size = data_loader_kwargs.get("batch_size", 2) - - train_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=training_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - add_relative_time_idx=add_relative_time_idx, - batch_size=batch_size, - train_val_test_split=(0.8, 0.2, 0.0), - ) - - val_datamodule = EncoderDecoderTimeSeriesDataModule( - 
time_series_dataset=validation_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - min_prediction_idx=training_max_time_idx, - add_relative_time_idx=add_relative_time_idx, - batch_size=batch_size, - train_val_test_split=(0.0, 1.0, 0.0), - ) - - test_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=validation_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - min_prediction_idx=training_max_time_idx, - add_relative_time_idx=add_relative_time_idx, - batch_size=1, - train_val_test_split=(0.0, 0.0, 1.0), - ) - - train_datamodule.setup("fit") - val_datamodule.setup("fit") - test_datamodule.setup("test") - - train_dataloader = train_datamodule.train_dataloader() - val_dataloader = val_datamodule.val_dataloader() - test_dataloader = test_datamodule.test_dataloader() - - return { - "train": train_dataloader, - "val": val_dataloader, - "test": test_dataloader, - "data_module": train_datamodule, - } + return EncoderDecoderTimeSeriesDataModule @classmethod def get_test_train_params(cls): @@ -115,7 +45,7 @@ def get_test_train_params(cls): from pytorch_forecasting.metrics import QuantileLoss - return [ + params = [ { # "loss": nn.MSELoss(), "hidden_size": 32, @@ -134,3 +64,13 @@ def get_test_train_params(cls): "use_revin": False, }, ] + + default_dm_cfg = {"max_encoder_length": 4, "max_prediction_length": 3} + + for param in params: + current_dm_cfg = param.get("datamodule_cfg", {}) + default_dm_cfg.update(current_dm_cfg) + + param["datamodule_cfg"] = default_dm_cfg + + return params diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py index 8c95daa6b..d121eba6e 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py @@ -1,9 +1,9 @@ """TFT package container.""" 
-from pytorch_forecasting.models.base import _BasePtForecasterV2 +from pytorch_forecasting.base._base_pkg import Base_pkg -class TFT_pkg_v2(_BasePtForecasterV2): +class TFT_pkg_v2(Base_pkg): """TFT package container.""" _tags = { @@ -23,83 +23,13 @@ def get_cls(cls): return TFT @classmethod - def _get_test_datamodule_from(cls, trainer_kwargs): - """Create test dataloaders from trainer_kwargs - following v1 pattern.""" + def get_datamodule_cls(cls): + """Get the underlying DataModule class.""" from pytorch_forecasting.data.data_module import ( EncoderDecoderTimeSeriesDataModule, ) - from pytorch_forecasting.tests._data_scenarios import ( - data_with_covariates_v2, - make_datasets_v2, - ) - - data_with_covariates = data_with_covariates_v2() - - data_loader_default_kwargs = dict( - target="target", - group_ids=["agency_encoded", "sku_encoded"], - add_relative_time_idx=True, - ) - - data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {}) - data_loader_default_kwargs.update(data_loader_kwargs) - - datasets_info = make_datasets_v2( - data_with_covariates, **data_loader_default_kwargs - ) - - training_dataset = datasets_info["training_dataset"] - validation_dataset = datasets_info["validation_dataset"] - training_max_time_idx = datasets_info["training_max_time_idx"] - - max_encoder_length = data_loader_kwargs.get("max_encoder_length", 4) - max_prediction_length = data_loader_kwargs.get("max_prediction_length", 3) - add_relative_time_idx = data_loader_kwargs.get("add_relative_time_idx", True) - batch_size = data_loader_kwargs.get("batch_size", 2) - train_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=training_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - add_relative_time_idx=add_relative_time_idx, - batch_size=batch_size, - train_val_test_split=(0.8, 0.2, 0.0), - ) - - val_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=validation_dataset, - 
max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - min_prediction_idx=training_max_time_idx, - add_relative_time_idx=add_relative_time_idx, - batch_size=batch_size, - train_val_test_split=(0.0, 1.0, 0.0), - ) - - test_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=validation_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - min_prediction_idx=training_max_time_idx, - add_relative_time_idx=add_relative_time_idx, - batch_size=1, - train_val_test_split=(0.0, 0.0, 1.0), - ) - - train_datamodule.setup("fit") - val_datamodule.setup("fit") - test_datamodule.setup("test") - - train_dataloader = train_datamodule.train_dataloader() - val_dataloader = val_datamodule.val_dataloader() - test_dataloader = test_datamodule.test_dataloader() - - return { - "train": train_dataloader, - "val": val_dataloader, - "test": test_dataloader, - "data_module": train_datamodule, - } + return EncoderDecoderTimeSeriesDataModule @classmethod def get_test_train_params(cls): @@ -113,19 +43,17 @@ def get_test_train_params(cls): `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. 
`create_test_instance` uses the first (or only) dictionary in `params` """ - return [ + params = [ {}, dict( hidden_size=25, attention_head_size=5, ), - dict( - data_loader_kwargs=dict(max_encoder_length=5, max_prediction_length=3) - ), + dict(datamodule_cfg=dict(max_encoder_length=5, max_prediction_length=3)), dict( hidden_size=24, attention_head_size=8, - data_loader_kwargs=dict( + datamodule_cfg=dict( max_encoder_length=5, max_prediction_length=3, add_relative_time_idx=False, @@ -133,7 +61,17 @@ def get_test_train_params(cls): ), dict( hidden_size=12, - data_loader_kwargs=dict(max_encoder_length=7, max_prediction_length=10), + datamodule_cfg=dict(max_encoder_length=7, max_prediction_length=10), ), dict(attention_head_size=2), ] + + default_dm_cfg = {"max_encoder_length": 4, "max_prediction_length": 3} + + for param in params: + current_dm_cfg = param.get("datamodule_cfg", {}) + default_dm_cfg.update(current_dm_cfg) + + param["datamodule_cfg"] = default_dm_cfg + + return params diff --git a/pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2_pkg.py b/pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2_pkg.py index d3cf70454..6b2780053 100644 --- a/pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2_pkg.py +++ b/pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2_pkg.py @@ -1,9 +1,9 @@ """TIDE package container.""" -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 +from pytorch_forecasting.base._base_pkg import Base_pkg -class TIDE_pkg_v2(_BasePtForecasterV2): +class TIDE_pkg_v2(Base_pkg): """TIDE package container.""" _tags = { @@ -19,83 +19,13 @@ def get_cls(cls): return TIDE @classmethod - def _get_test_datamodule_from(cls, trainer_kwargs): - """Create test dataloaders from trainer_kwargs - following v1 pattern.""" + def get_datamodule_cls(cls): + """Get the underlying DataModule class.""" from pytorch_forecasting.data.data_module import ( EncoderDecoderTimeSeriesDataModule, ) - from pytorch_forecasting.tests._data_scenarios 
import ( - data_with_covariates_v2, - make_datasets_v2, - ) - - data_with_covariates = data_with_covariates_v2() - - data_loader_default_kwargs = dict( - target="target", - group_ids=["agency_encoded", "sku_encoded"], - add_relative_time_idx=True, - ) - - data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {}) - data_loader_default_kwargs.update(data_loader_kwargs) - - datasets_info = make_datasets_v2( - data_with_covariates, **data_loader_default_kwargs - ) - - training_dataset = datasets_info["training_dataset"] - validation_dataset = datasets_info["validation_dataset"] - training_max_time_idx = datasets_info["training_max_time_idx"] - - max_encoder_length = data_loader_kwargs.get("max_encoder_length", 4) - max_prediction_length = data_loader_kwargs.get("max_prediction_length", 3) - add_relative_time_idx = data_loader_kwargs.get("add_relative_time_idx", True) - batch_size = data_loader_kwargs.get("batch_size", 2) - - train_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=training_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - add_relative_time_idx=add_relative_time_idx, - batch_size=batch_size, - train_val_test_split=(0.8, 0.2, 0.0), - ) - - val_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=validation_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - min_prediction_idx=training_max_time_idx, - add_relative_time_idx=add_relative_time_idx, - batch_size=batch_size, - train_val_test_split=(0.0, 1.0, 0.0), - ) - - test_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=validation_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - min_prediction_idx=training_max_time_idx, - add_relative_time_idx=add_relative_time_idx, - batch_size=1, - train_val_test_split=(0.0, 0.0, 1.0), - ) - train_datamodule.setup("fit") - val_datamodule.setup("fit") - 
test_datamodule.setup("test") - - train_dataloader = train_datamodule.train_dataloader() - val_dataloader = val_datamodule.val_dataloader() - test_dataloader = test_datamodule.test_dataloader() - - return { - "train": train_dataloader, - "val": val_dataloader, - "test": test_dataloader, - "data_module": train_datamodule, - } + return EncoderDecoderTimeSeriesDataModule @classmethod def get_test_train_params(cls): @@ -111,7 +41,7 @@ def get_test_train_params(cls): """ from pytorch_forecasting.metrics import MAE, MAPE - return [ + params = [ dict( hidden_size=16, d_model=8, @@ -125,7 +55,7 @@ def get_test_train_params(cls): n_add_enc=2, n_add_dec=2, dropout_rate=0.2, - data_loader_kwargs=dict(max_encoder_length=5, max_prediction_length=3), + datamodule_cfg=dict(max_encoder_length=5, max_prediction_length=3), loss=MAE(), ), dict( @@ -134,7 +64,16 @@ def get_test_train_params(cls): n_add_enc=3, n_add_dec=2, dropout_rate=0.1, - data_loader_kwargs=dict(max_encoder_length=4, max_prediction_length=2), + datamodule_cfg=dict(max_encoder_length=4, max_prediction_length=2), loss=MAPE(), ), ] + default_dm_cfg = {"max_encoder_length": 4, "max_prediction_length": 3} + + for param in params: + current_dm_cfg = param.get("datamodule_cfg", {}) + default_dm_cfg.update(current_dm_cfg) + + param["datamodule_cfg"] = default_dm_cfg + + return params diff --git a/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py b/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py index a0e4b8aa7..edff00024 100644 --- a/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py +++ b/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py @@ -2,10 +2,10 @@ Metadata container for TimeXer v2. 
""" -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 +from pytorch_forecasting.base._base_pkg import Base_pkg -class TimeXer_pkg_v2(_BasePtForecasterV2): +class TimeXer_pkg_v2(Base_pkg): """TimeXer metadata container.""" _tags = { @@ -25,77 +25,11 @@ def get_cls(cls): return TimeXer @classmethod - def _get_test_datamodule_from(cls, trainer_kwargs): - """Create test dataloaders from trainer_kwargs - following v1 pattern.""" + def get_datamodule_cls(cls): + """Get the underlying DataModule class.""" from pytorch_forecasting.data._tslib_data_module import TslibDataModule - from pytorch_forecasting.tests._data_scenarios import ( - data_with_covariates_v2, - make_datasets_v2, - ) - data_with_covariates = data_with_covariates_v2() - - data_loader_default_kwargs = dict( - target="target", - group_ids=["agency_encoded", "sku_encoded"], - add_relative_time_idx=True, - ) - - data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {}) - data_loader_default_kwargs.update(data_loader_kwargs) - - datasets_info = make_datasets_v2( - data_with_covariates, **data_loader_default_kwargs - ) - - training_dataset = datasets_info["training_dataset"] - validation_dataset = datasets_info["validation_dataset"] - - context_length = data_loader_kwargs.get("context_length", 12) - prediction_length = data_loader_kwargs.get("prediction_length", 4) - batch_size = data_loader_kwargs.get("batch_size", 2) - - train_datamodule = TslibDataModule( - time_series_dataset=training_dataset, - context_length=context_length, - prediction_length=prediction_length, - add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True), - batch_size=batch_size, - train_val_test_split=(0.8, 0.2, 0.0), - ) - - val_datamodule = TslibDataModule( - time_series_dataset=validation_dataset, - context_length=context_length, - prediction_length=prediction_length, - add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True), - batch_size=batch_size, - 
train_val_test_split=(0.0, 1.0, 0.0), - ) - - test_datamodule = TslibDataModule( - time_series_dataset=validation_dataset, - context_length=context_length, - prediction_length=prediction_length, - add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True), - batch_size=1, - train_val_test_split=(0.0, 0.0, 1.0), - ) - - train_datamodule.setup("fit") - val_datamodule.setup("fit") - test_datamodule.setup("test") - - train_dataloader = train_datamodule.train_dataloader() - val_dataloader = val_datamodule.val_dataloader() - test_dataloader = test_datamodule.test_dataloader() - - return { - "train": train_dataloader, - "val": val_dataloader, - "test": test_dataloader, - "data_module": train_datamodule, - } + return TslibDataModule @classmethod def get_test_train_params(cls): @@ -111,17 +45,17 @@ def get_test_train_params(cls): """ from pytorch_forecasting.metrics import QuantileLoss - return [ + params = [ {}, dict( hidden_size=64, n_heads=4, ), - dict(data_loader_kwargs=dict(context_length=12, prediction_length=3)), + dict(datamodule_cfg=dict(context_length=12, prediction_length=3)), dict( hidden_size=32, n_heads=2, - data_loader_kwargs=dict( + datamodule_cfg=dict( context_length=12, prediction_length=3, add_relative_time_idx=False, @@ -130,7 +64,7 @@ def get_test_train_params(cls): dict( hidden_size=128, patch_length=12, - data_loader_kwargs=dict(context_length=16, prediction_length=4), + datamodule_cfg=dict(context_length=16, prediction_length=4), ), dict( n_heads=2, @@ -156,10 +90,19 @@ def get_test_train_params(cls): factor=2, activation="relu", dropout=0.05, - data_loader_kwargs=dict( + datamodule_cfg=dict( context_length=16, prediction_length=4, ), loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]), ), ] + default_dm_cfg = {"context_length": 12, "prediction_length": 4} + + for param in params: + current_dm_cfg = param.get("datamodule_cfg", {}) + default_dm_cfg.update(current_dm_cfg) + + param["datamodule_cfg"] = default_dm_cfg + + return params diff 
--git a/pytorch_forecasting/tests/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_estimators_v2.py index 9c28c5d0a..da511e153 100644 --- a/pytorch_forecasting/tests/test_all_estimators_v2.py +++ b/pytorch_forecasting/tests/test_all_estimators_v2.py @@ -1,6 +1,7 @@ """Automated tests based on the skbase test suite template.""" import shutil +from typing import Any, Optional import lightning.pytorch as pl from lightning.pytorch.callbacks import EarlyStopping @@ -8,6 +9,8 @@ import torch import torch.nn as nn +from pytorch_forecasting.base._base_pkg import Base_pkg +from pytorch_forecasting.data import TimeSeries from pytorch_forecasting.metrics import SMAPE from pytorch_forecasting.tests.test_all_estimators import ( EstimatorFixtureGenerator, @@ -20,77 +23,68 @@ def _integration( - estimator_cls, - dataloaders, - tmp_path, - data_loader_kwargs={}, - clip_target: bool = False, - trainer_kwargs=None, + estimator_cls: type[Base_pkg], + test_data: dict[str, TimeSeries], + model_cfg: dict[str, Any], + datamodule_cfg: dict[str, Any], + tmp_path: str, + trainer_cfg: Optional[dict[str, Any]] = None, **kwargs, ): - train_dataloader = dataloaders["train"] - val_dataloader = dataloaders["val"] - test_dataloader = dataloaders["test"] + train_data = test_data["train"] + predict_data = test_data["predict"] - early_stop_callback = EarlyStopping( - monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min" - ) + default_model_cfg = {"loss": SMAPE()} + + default_datamodule_cfg = { + "train_val_test_split": (0.8, 0.2), + "add_relative_time_idx": True, + "batch_size": 2, + } logger = TensorBoardLogger(tmp_path) - if trainer_kwargs is None: - trainer_kwargs = {} - trainer = pl.Trainer( - max_epochs=3, - gradient_clip_val=0.1, - callbacks=[early_stop_callback], - enable_checkpointing=True, - default_root_dir=tmp_path, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - logger=logger, - **trainer_kwargs, - ) - training_data_module = 
dataloaders.get("data_module") - metadata = training_data_module.metadata - - assert isinstance( - metadata, dict - ), f"Expected metadata to be dict, got {type(metadata)}" - - if "loss" in kwargs: - loss = kwargs["loss"] - kwargs.pop("loss") - else: - loss = SMAPE() - - net = estimator_cls( - metadata=metadata, - loss=loss, - **kwargs, + default_trainer_cfg = { + "max_epochs": 3, + "gradient_clip_val": 0.1, + "enable_checkpointing": True, + "default_root_dir": tmp_path, + "limit_train_batches": 2, + "limit_val_batches": 2, + "logger": logger, + } + default_model_cfg.update(model_cfg) + default_datamodule_cfg.update(datamodule_cfg) + if trainer_cfg is not None: + default_trainer_cfg.update(trainer_cfg) + + pkg = estimator_cls( + model_cfg=default_model_cfg, + trainer_cfg=default_trainer_cfg, + datamodule_cfg=default_datamodule_cfg, ) - trainer.fit( - net, - train_dataloaders=train_dataloader, - val_dataloaders=val_dataloader, - ) - test_outputs = trainer.test(net, dataloaders=test_dataloader) - assert len(test_outputs) > 0 - - # todo: add the predict pipeline and make this test cleaner - x, y = next(iter(test_dataloader)) - net.eval() - with torch.no_grad(): - output = net(x) - net.train() - prediction = output["prediction"] - n_dims = prediction.ndim - assert n_dims == 3, ( - f"Prediction output must be 3D, but got {n_dims}D tensor " - f"with shape {output.shape}" + pkg.fit(train_data) + + predictions = pkg.predict( + predict_data, + mode="raw", ) + assert predictions is not None + assert isinstance(predictions, dict) + assert "prediction" in predictions + + pred_tensor = predictions["prediction"] + assert isinstance(pred_tensor, torch.Tensor) + assert pred_tensor.ndim == 3, f"Prediction must be 3D, got {pred_tensor.ndim}D" + + expected_pred_len = datamodule_cfg.get("prediction_length") + if expected_pred_len: + assert pred_tensor.shape[1] == expected_pred_len, ( + f"Pred length mismatch: expected {expected_pred_len}, " + f"got {pred_tensor.shape[1]}" + ) + 
shutil.rmtree(tmp_path, ignore_errors=True) @@ -111,10 +105,19 @@ def test_integration( trainer_kwargs, tmp_path, ): - object_class = object_pkg.get_cls() - dataloaders = object_pkg._get_test_datamodule_from(trainer_kwargs) - - _integration(object_class, dataloaders, tmp_path, **trainer_kwargs) + params_copy = trainer_kwargs.copy() + datamodule_cfg = params_copy.pop("datamodule_cfg", {}) + model_cfg = params_copy + + test_data = object_pkg.get_test_data(**datamodule_cfg) + + _integration( + estimator_cls=object_pkg, + test_data=test_data, + model_cfg=model_cfg, + datamodule_cfg=datamodule_cfg, + tmp_path=tmp_path, + ) def test_pkg_linkage(self, object_pkg, object_class): """Test that the package is linked correctly.""" From 882caba53245bcdfe33d672ca40df78d32244e97 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 18 Oct 2025 23:33:28 +0530 Subject: [PATCH 4/7] update base model --- pytorch_forecasting/models/base/_base_model_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/models/base/_base_model_v2.py b/pytorch_forecasting/models/base/_base_model_v2.py index 2de069d58..5b2de1d39 100644 --- a/pytorch_forecasting/models/base/_base_model_v2.py +++ b/pytorch_forecasting/models/base/_base_model_v2.py @@ -143,7 +143,7 @@ def to_prediction(self, out: dict[str, Any], **kwargs) -> torch.Tensor: def to_quantiles(self, out: dict[str, Any], **kwargs) -> torch.Tensor: """Converts raw model output to quantile forecasts.""" # todo: add MultiLoss support - quantiles = kwargs.get("quantiles") # Allow overriding default quantiles + quantiles = kwargs.get("quantiles") q = quantiles or self.loss.quantiles return self.loss.to_quantiles(out["prediction"], quantiles=q) From c53f88193986a4d528d042a58be4b6a646bdf622 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 18 Oct 2025 23:35:51 +0530 Subject: [PATCH 5/7] update base model --- pytorch_forecasting/base/_base_pkg.py | 2 +- pytorch_forecasting/models/base/_base_model_v2.py | 1 - 2 
files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_forecasting/base/_base_pkg.py b/pytorch_forecasting/base/_base_pkg.py index 340e9e742..98b19dba1 100644 --- a/pytorch_forecasting/base/_base_pkg.py +++ b/pytorch_forecasting/base/_base_pkg.py @@ -69,7 +69,7 @@ def get_datamodule_cls(cls): raise NotImplementedError("Subclasses must implement `get_datamodule_cls`.") @classmethod - def get_test_data(cls, **kwargs): + def get_test_dataset_from(cls, **kwargs): """ Creates and returns D1 TimeSeries dataSet objects for testing. """ diff --git a/pytorch_forecasting/models/base/_base_model_v2.py b/pytorch_forecasting/models/base/_base_model_v2.py index 5b2de1d39..b919b6282 100644 --- a/pytorch_forecasting/models/base/_base_model_v2.py +++ b/pytorch_forecasting/models/base/_base_model_v2.py @@ -123,7 +123,6 @@ def predict( trainer_kwargs = kwargs.pop("trainer_kwargs", {}) predict_callback = PredictCallback(mode=mode, return_info=return_info, **kwargs) - # Ensure callbacks list exists and append the new one callbacks = trainer_kwargs.get("callbacks", []) if not isinstance(callbacks, list): callbacks = [callbacks] From f8d06fe749e8d8ae99f4252e20c903a1997bffd3 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 19 Oct 2025 00:04:24 +0530 Subject: [PATCH 6/7] update test_integration --- pytorch_forecasting/tests/test_all_estimators_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/tests/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_estimators_v2.py index da511e153..86b819fc8 100644 --- a/pytorch_forecasting/tests/test_all_estimators_v2.py +++ b/pytorch_forecasting/tests/test_all_estimators_v2.py @@ -109,7 +109,7 @@ def test_integration( datamodule_cfg = params_copy.pop("datamodule_cfg", {}) model_cfg = params_copy - test_data = object_pkg.get_test_data(**datamodule_cfg) + test_data = object_pkg.get_test_dataset_from(**datamodule_cfg) _integration( estimator_cls=object_pkg, From 
5024fcb9ee661b5324d954be5faa11b114eef14b Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 2 Nov 2025 01:56:59 +0530 Subject: [PATCH 7/7] handle kwargs --- pytorch_forecasting/callbacks/predict.py | 8 ++--- .../models/base/_base_model_v2.py | 29 +++++++++++++------ 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/pytorch_forecasting/callbacks/predict.py b/pytorch_forecasting/callbacks/predict.py index 883a92f8b..0d4dab719 100644 --- a/pytorch_forecasting/callbacks/predict.py +++ b/pytorch_forecasting/callbacks/predict.py @@ -30,12 +30,12 @@ def __init__( self, mode: str = "prediction", return_info: Optional[list[str]] = None, - **kwargs, + mode_kwargs: dict[str, Any] = None, ): super().__init__(write_interval="epoch") self.mode = mode self.return_info = return_info or [] - self.kwargs = kwargs + self.mode_kwargs = mode_kwargs or {} self._reset_data() def _reset_data(self, result: bool = True): @@ -60,9 +60,9 @@ def on_predict_batch_end( if self.mode == "raw": processed_output = outputs elif self.mode == "prediction": - processed_output = pl_module.to_prediction(outputs, **self.kwargs) + processed_output = pl_module.to_prediction(outputs, **self.mode_kwargs) elif self.mode == "quantiles": - processed_output = pl_module.to_quantiles(outputs, **self.kwargs) + processed_output = pl_module.to_quantiles(outputs, **self.mode_kwargs) else: raise ValueError(f"Invalid prediction mode: {self.mode}") diff --git a/pytorch_forecasting/models/base/_base_model_v2.py b/pytorch_forecasting/models/base/_base_model_v2.py index b919b6282..e0affe943 100644 --- a/pytorch_forecasting/models/base/_base_model_v2.py +++ b/pytorch_forecasting/models/base/_base_model_v2.py @@ -99,7 +99,8 @@ def predict( dataloader: DataLoader, mode: str = "prediction", return_info: Optional[list[str]] = None, - **kwargs, + mode_kwargs: dict[str, Any] = None, + trainer_kwargs: dict[str, Any] = None, ) -> dict[str, torch.Tensor]: """ Generate predictions for new data using the 
`lightning.Trainer`. @@ -112,16 +113,20 @@ def predict( The prediction mode ("prediction", "quantiles", or "raw"). return_info : list[str], optional A list of additional information to return. - **kwargs : - Additional arguments for `to_prediction`/`to_quantiles` or `Trainer`. + mode_kwargs : dict[str, Any] + Additional arguments for `to_prediction`/`to_quantiles`. + trainer_kwargs: dict[str, Any] + Additional arguments for `Trainer`. Returns ------- dict[str, torch.Tensor] A dictionary of prediction results. """ - trainer_kwargs = kwargs.pop("trainer_kwargs", {}) - predict_callback = PredictCallback(mode=mode, return_info=return_info, **kwargs) + trainer_kwargs = trainer_kwargs or {} + predict_callback = PredictCallback( + mode=mode, return_info=return_info, mode_kwargs=mode_kwargs + ) callbacks = trainer_kwargs.get("callbacks", []) if not isinstance(callbacks, list): @@ -137,14 +142,20 @@ def predict( def to_prediction(self, out: dict[str, Any], **kwargs) -> torch.Tensor: """Converts raw model output to point forecasts.""" # todo: add MultiLoss support - return self.loss.to_prediction(out["prediction"]) + try: + out = self.loss.to_prediction(out["prediction"], **kwargs) + except TypeError: # in case passed kwargs do not exist + out = self.loss.to_prediction(out["prediction"]) + return out def to_quantiles(self, out: dict[str, Any], **kwargs) -> torch.Tensor: """Converts raw model output to quantile forecasts.""" # todo: add MultiLoss support - quantiles = kwargs.get("quantiles") - q = quantiles or self.loss.quantiles - return self.loss.to_quantiles(out["prediction"], quantiles=q) + try: + out = self.loss.to_quantiles(out["prediction"], **kwargs) + except TypeError: # in case passed kwargs do not exist + out = self.loss.to_quantiles(out["prediction"]) + return out def training_step( self, batch: tuple[dict[str, torch.Tensor]], batch_idx: int