diff --git a/docs/source/_static/custom.css b/docs/source/_static/custom.css
index 70b848a19..18b25e16e 100644
--- a/docs/source/_static/custom.css
+++ b/docs/source/_static/custom.css
@@ -56,3 +56,81 @@ span.highlighted {
.highlight > pre > .s2 {
color: #647db6 !important;
}
+
+/* Model overview table styling */
+.model-overview-table {
+ width: 100%;
+ font-size: 0.9em;
+ border-collapse: collapse;
+ margin: 20px 0;
+}
+
+.model-overview-table th {
+ background-color: #f0f0f0;
+ font-weight: bold;
+ padding: 12px;
+ text-align: left;
+ border-bottom: 2px solid #ddd;
+}
+
+.model-overview-table td {
+ padding: 10px 12px;
+ border-bottom: 1px solid #ddd;
+}
+
+.model-overview-table tr:hover {
+ background-color: #f9f9f9;
+}
+
+.model-overview-table a {
+ color: #647db6;
+ text-decoration: none;
+}
+
+.model-overview-table a:hover {
+ text-decoration: underline;
+ color: #ee4c2c;
+}
+
+#model-overview-container {
+ margin: 20px 0;
+}
+
+#model-filters {
+ margin-bottom: 15px;
+ padding: 10px;
+ background-color: #f9f9f9;
+ border-radius: 4px;
+}
+
+#model-filters label {
+ margin-right: 15px;
+ font-weight: 500;
+}
+
+#model-filters select {
+ margin-left: 5px;
+ padding: 5px 10px;
+ border: 1px solid #ddd;
+ border-radius: 4px;
+ background-color: white;
+}
+
+/* DataTables styling overrides */
+#model-table_wrapper {
+ margin-top: 20px;
+}
+
+#model-table_wrapper .dataTables_filter input {
+ margin-left: 10px;
+ padding: 5px;
+ border: 1px solid #ddd;
+ border-radius: 4px;
+}
+
+#model-table_wrapper .dataTables_length select {
+ padding: 5px;
+ border: 1px solid #ddd;
+ border-radius: 4px;
+ margin: 0 5px;
+}
\ No newline at end of file
diff --git a/docs/source/_static/model_overview.js b/docs/source/_static/model_overview.js
new file mode 100644
index 000000000..925e9d23f
--- /dev/null
+++ b/docs/source/_static/model_overview.js
@@ -0,0 +1,104 @@
+/**
+ * JavaScript for interactive model overview table.
+ *
+ * This script loads the model overview data from JSON and creates
+ * an interactive DataTable with search and filtering capabilities.
+ */
+
+$(document).ready(function() {
+ // Determine the correct path to the JSON file
+ // In built HTML, the path should be relative to the current page
+ var jsonPath = '_static/model_overview_db.json';
+
+ // Load model data from JSON
+ $.getJSON(jsonPath, function(data) {
+ // Initialize DataTable
+ var table = $('#model-table').DataTable({
+ data: data,
+ columns: [
+ {
+ data: 'Model Name',
+ title: 'Model Name',
+ render: function(data, type, row) {
+ // If data is already HTML (from pandas), return as-is
+ if (type === 'display' && data && data.includes('
' +
+ 'Error loading model overview data. Please ensure the documentation was built correctly.' +
+ '
'
+ );
+ });
+});
+
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 0e58692e1..7e43bbcaf 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -132,11 +132,152 @@ def get_items(self, names):
return new_items
+def _make_estimator_overview(app):
+ """Make estimator/model overview table.
+
+ This function generates a dynamic table of all models in pytorch-forecasting
+ by querying the registry system. The table is written as HTML and JSON files
+ for inclusion in the documentation.
+ """
+ try:
+ import pandas as pd
+ from pytorch_forecasting._registry import all_objects
+
+ # Base classes to exclude from the overview
+ BASE_CLASSES = {
+ "BaseModel",
+ "BaseModelWithCovariates",
+ "AutoRegressiveBaseModel",
+ "AutoRegressiveBaseModelWithCovariates",
+ "_BaseObject",
+ "_BasePtForecaster",
+ "_BasePtForecasterV2",
+ "_BasePtForecaster_Common",
+ }
+
+ # Get all objects from registry
+ all_objs = all_objects(return_names=True, suppress_import_stdout=True)
+
+ records = []
+
+ for obj_name, obj_class in all_objs:
+ # Skip base classes
+ if obj_name in BASE_CLASSES:
+ continue
+
+ # Skip if it's not a model class (check if it has get_class_tag method)
+ if not hasattr(obj_class, "get_class_tag"):
+ continue
+
+ try:
+ # Get model name from tags or use class name
+ model_name = obj_class.get_class_tag("info:name", obj_name)
+
+ # Get authors
+ authors = obj_class.get_class_tag("authors", None)
+ if authors is None:
+ authors = "pytorch-forecasting developers"
+ elif isinstance(authors, list):
+ authors = ", ".join(authors)
+
+ # Get object type
+ object_type = obj_class.get_class_tag("object_type", "model")
+ if isinstance(object_type, list):
+ object_type = ", ".join(object_type)
+
+ # Get capabilities
+ has_exogenous = obj_class.get_class_tag("capability:exogenous", False)
+ has_multivariate = obj_class.get_class_tag("capability:multivariate", False)
+ has_pred_int = obj_class.get_class_tag("capability:pred_int", False)
+ has_flexible_history = obj_class.get_class_tag("capability:flexible_history_length", False)
+ has_cold_start = obj_class.get_class_tag("capability:cold_start", False)
+
+ # Get compute requirement
+ compute = obj_class.get_class_tag("info:compute", None)
+
+ # Get module path for documentation link
+ module_path = obj_class.__module__
+ class_name = obj_class.__name__
+
+ # Construct documentation link
+ # Convert module path to API documentation path
+ api_path = module_path.replace(".", "/")
+ doc_link = f"api/{api_path}.html#{module_path}.{class_name}"
+
+ # Create model name with link
+ model_name_link = f'{model_name}'
+
+ # Build capabilities string
+ capabilities = []
+ if has_exogenous:
+ capabilities.append("Covariates")
+ if has_multivariate:
+ capabilities.append("Multiple targets")
+ if has_pred_int:
+ capabilities.append("Uncertainty")
+ if has_flexible_history:
+ capabilities.append("Flexible history")
+ if has_cold_start:
+ capabilities.append("Cold-start")
+
+ capabilities_str = ", ".join(capabilities) if capabilities else ""
+
+ records.append({
+ "Model Name": model_name_link,
+ "Type": object_type,
+ "Authors": authors,
+ "Covariates": "✓" if has_exogenous else "",
+ "Multiple targets": "✓" if has_multivariate else "",
+ "Uncertainty": "✓" if has_pred_int else "",
+ "Flexible history": "✓" if has_flexible_history else "",
+ "Cold-start": "✓" if has_cold_start else "",
+ "Compute": str(compute) if compute is not None else "",
+ "Capabilities": capabilities_str,
+ "Module": module_path,
+ })
+ except Exception as e:
+ # Skip objects that can't be processed
+ print(f"Warning: Could not process {obj_name}: {e}")
+ continue
+
+ if not records:
+ print("Warning: No models found in registry")
+ return
+
+ # Create DataFrame
+ df = pd.DataFrame(records)
+
+ # Ensure _static directory exists
+ static_dir = SOURCE_PATH.joinpath("_static")
+ static_dir.mkdir(exist_ok=True)
+
+ # Write HTML table
+ html_file = static_dir.joinpath("model_overview_table.html")
+ html_content = df[["Model Name", "Type", "Covariates", "Multiple targets",
+ "Uncertainty", "Flexible history", "Cold-start", "Compute"]].to_html(
+ classes="model-overview-table", index=False, border=0, escape=False
+ )
+ html_file.write_text(html_content, encoding="utf-8")
+ print(f"Generated model overview table: {html_file}")
+
+ # Write JSON database for interactive filtering (optional)
+ json_file = static_dir.joinpath("model_overview_db.json")
+ df.to_json(json_file, orient="records", indent=2)
+ print(f"Generated model overview JSON: {json_file}")
+
+ except ImportError as e:
+ print(f"Warning: Could not generate model overview (missing dependency): {e}")
+ except Exception as e:
+ print(f"Warning: Error generating model overview: {e}")
+
+
def setup(app: Sphinx):
app.add_css_file("custom.css")
app.connect("autodoc-skip-member", skip)
app.add_directive("moduleautosummary", ModuleAutoSummary)
app.add_js_file("https://buttons.github.io/buttons.js", **{"async": "async"})
+ # Connect model overview generator to builder-inited event
+ app.connect("builder-inited", _make_estimator_overview)
# extension configuration
diff --git a/docs/source/model_overview.rst b/docs/source/model_overview.rst
new file mode 100644
index 000000000..541703f53
--- /dev/null
+++ b/docs/source/model_overview.rst
@@ -0,0 +1,69 @@
+.. _model_overview:
+
+Model Overview
+==============
+
+This page provides a comprehensive, searchable overview of all forecasting models available in pytorch-forecasting.
+The table is automatically generated from the model registry, ensuring it stays up-to-date as new models are added.
+
+Use the search box below to filter models by name, type, or capabilities. Click on column headers to sort the table.
+
+.. raw:: html
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+How to Use This Page
+--------------------
+
+- **Search**: Type in the search box to filter models by any column
+- **Sort**: Click column headers to sort ascending/descending
+- **Filter by Type**: Use the "Model Type" dropdown to filter by forecaster version
+- **Filter by Capability**: Use the "Capability" dropdown to find models with specific features
+- **Click Model Names**: Click on any model name to view its detailed API documentation
+
+.. raw:: html
+ :file: _static/model_overview_table.html
+
+
+
+The table includes the following information:
+
+- **Model Name**: Name of the model with link to API documentation
+- **Type**: Object type (forecaster version)
+- **Covariates**: Whether the model supports exogenous variables/covariates
+- **Multiple targets**: Whether the model can handle multiple target variables
+- **Uncertainty**: Whether the model provides uncertainty estimates
+- **Flexible history**: Whether the model supports variable history lengths
+- **Cold-start**: Whether the model can make predictions without historical data
+- **Compute**: Computational resource requirement (1-5 scale, where 5 is most intensive)
+
+For more information about selecting the right model for your use case, see the :doc:`models` page.
+
diff --git a/docs/source/models.rst b/docs/source/models.rst
index 1df0a395a..ed7efeeb5 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -21,20 +21,13 @@ Selecting an architecture
Criteria for selecting an architecture depend heavily on the use-case. There are multiple selection criteria
and you should take into account. Here is an overview over the pros and cons of the implemented models:
-.. csv-table:: Model comparison
- :header: "Name", "Covariates", "Multiple targets", "Regression", "Classification", "Probabilistic", "Uncertainty", "Interactions between series", "Flexible history length", "Cold-start", "Required computational resources (1-5, 5=most)"
+.. note::
+ The table below is automatically generated from the model registry. It includes all available models
+ with their capabilities and properties. For a more detailed, searchable overview, see the
+ :ref:`Model Overview Page `.
- :py:class:`~pytorch_forecasting.models.rnn.RecurrentNetwork`, "x", "x", "x", "", "", "", "", "x", "", 2
- :py:class:`~pytorch_forecasting.models.mlp.DecoderMLP`, "x", "x", "x", "x", "", "x", "", "x", "x", 1
- :py:class:`~pytorch_forecasting.models.nbeats.NBeats`, "", "", "x", "", "", "", "", "", "", 1
- :py:class:`~pytorch_forecasting.models.nbeats.NBeatsKAN`, "", "", "x", "", "", "", "", "", "", 1
- :py:class:`~pytorch_forecasting.models.nhits.NHiTS`, "x", "x", "x", "", "", "", "", "", "", 1
- :py:class:`~pytorch_forecasting.models.deepar.DeepAR`, "x", "x", "x", "", "x", "x", "x [#deepvar]_ ", "x", "", 3
- :py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`, "x", "x", "x", "x", "", "x", "", "x", "x", 4
- :py:class:`~pytorch_forecasting.models.tide.TiDEModel`, "x", "x", "x", "", "", "", "", "x", "", 3
- :py:class:`~pytorch_forecasting.models.xlstm.xLSTMTime`, "x", "x", "x", "", "", "", "", "x", "", 3
-
-.. [#deepvar] Accounting for correlations using a multivariate loss function which converts the network into a DeepVAR model.
+.. raw:: html
+ :file: _static/model_overview_table.html
Size and type of available data
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -72,7 +65,7 @@ If your time series are related to each other (e.g. all sales of products of the
a model that can learn relations between the timeseries can improve accuracy.
Not that only :ref:`models that can process covariates ` can
learn relationships between different timeseries.
-If the timeseries denote different entities or exhibit very similar patterns across the board,
+If the timeseries denote different entities or exhibit very similar patterns accross the board,
a model such as :py:class:`~pytorch_forecasting.models.nbeats.NBeats` will not work as well.
If you have only one or very few timeseries,
@@ -155,7 +148,13 @@ Every model should inherit from a base model in :py:mod:`~pytorch_forecasting.mo
Details and available models
-------------------------------
-See the API documentation for further details on available models:
+The table above provides an automatically generated overview of all models in pytorch-forecasting,
+including their capabilities, dependencies, and properties. The table is generated from the model
+registry, ensuring it stays up-to-date as new models are added.
+
+For a more detailed, searchable overview with filtering capabilities, see the :ref:`Model Overview Page `.
+
+See the API documentation below for further details on available models:
.. currentmodule:: pytorch_forecasting