Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions docs/source/_static/custom.css
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,81 @@ span.highlighted {
.highlight > pre > .s2 {
color: #647db6 !important;
}

/* Model overview table styling */
.model-overview-table {
width: 100%;
font-size: 0.9em;
border-collapse: collapse;
margin: 20px 0;
}

.model-overview-table th {
background-color: #f0f0f0;
font-weight: bold;
padding: 12px;
text-align: left;
border-bottom: 2px solid #ddd;
}

.model-overview-table td {
padding: 10px 12px;
border-bottom: 1px solid #ddd;
}

.model-overview-table tr:hover {
background-color: #f9f9f9;
}

.model-overview-table a {
color: #647db6;
text-decoration: none;
}

.model-overview-table a:hover {
text-decoration: underline;
color: #ee4c2c;
}

#model-overview-container {
margin: 20px 0;
}

#model-filters {
margin-bottom: 15px;
padding: 10px;
background-color: #f9f9f9;
border-radius: 4px;
}

#model-filters label {
margin-right: 15px;
font-weight: 500;
}

#model-filters select {
margin-left: 5px;
padding: 5px 10px;
border: 1px solid #ddd;
border-radius: 4px;
background-color: white;
}

/* DataTables styling overrides */
#model-table_wrapper {
margin-top: 20px;
}

#model-table_wrapper .dataTables_filter input {
margin-left: 10px;
padding: 5px;
border: 1px solid #ddd;
border-radius: 4px;
}

#model-table_wrapper .dataTables_length select {
padding: 5px;
border: 1px solid #ddd;
border-radius: 4px;
margin: 0 5px;
}
104 changes: 104 additions & 0 deletions docs/source/_static/model_overview.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/**
* JavaScript for interactive model overview table.
*
* This script loads the model overview data from JSON and creates
* an interactive DataTable with search and filtering capabilities.
*/

$(document).ready(function() {
// Determine the correct path to the JSON file
// In built HTML, the path should be relative to the current page
var jsonPath = '_static/model_overview_db.json';

// Load model data from JSON
$.getJSON(jsonPath, function(data) {
// Initialize DataTable
var table = $('#model-table').DataTable({
data: data,
columns: [
{
data: 'Model Name',
title: 'Model Name',
render: function(data, type, row) {
// If data is already HTML (from pandas), return as-is
if (type === 'display' && data && data.includes('<a')) {
return data;
}
return data || row['Model Name'] || '';
}
},
{ data: 'Type', title: 'Type' },
{ data: 'Covariates', title: 'Covariates' },
{ data: 'Multiple targets', title: 'Multiple targets' },
{ data: 'Uncertainty', title: 'Uncertainty' },
{ data: 'Flexible history', title: 'Flexible history' },
{ data: 'Cold-start', title: 'Cold-start' },
{ data: 'Compute', title: 'Compute' }
],
pageLength: 25,
order: [[0, 'asc']],
responsive: true,
dom: 'lfrtip',
language: {
search: "Search models:",
lengthMenu: "Show _MENU_ models per page",
info: "Showing _START_ to _END_ of _TOTAL_ models",
infoEmpty: "No models found",
infoFiltered: "(filtered from _MAX_ total models)"
}
});

// Filter by type
$('#type-filter').on('change', function() {
var val = $(this).val();
table.column(1).search(val).draw();
});

// Filter by capability
$('#capability-filter').on('change', function() {
var val = $(this).val();
if (val === '') {
// Clear all capability filters
table.columns([2, 3, 4, 5, 6]).search('').draw();
} else {
// Map capability name to column index
var capabilityMap = {
'Covariates': 2,
'Multiple targets': 3,
'Uncertainty': 4,
'Flexible history': 5,
'Cold-start': 6
};

var colIdx = capabilityMap[val];
if (colIdx !== undefined) {
// Clear all capability columns first
table.columns([2, 3, 4, 5, 6]).search('');
// Then search in the specific column
table.column(colIdx).search('✓').draw();
} else {
// If capability not found, search in all columns
table.search(val).draw();
}
}
});

// Clear filters when "All" is selected
$('#type-filter, #capability-filter').on('change', function() {
if ($(this).val() === '') {
if ($(this).attr('id') === 'type-filter') {
table.column(1).search('').draw();
}
}
});
}).fail(function(jqXHR, textStatus, errorThrown) {
// Handle error loading JSON
console.error('Error loading model overview data:', textStatus, errorThrown);
$('#model-table').html(
'<tr><td colspan="8" style="text-align: center; padding: 20px;">' +
'Error loading model overview data. Please ensure the documentation was built correctly.' +
'</td></tr>'
);
});
});

141 changes: 141 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,152 @@ def get_items(self, names):
return new_items


def _make_estimator_overview(app):
"""Make estimator/model overview table.

This function generates a dynamic table of all models in pytorch-forecasting
by querying the registry system. The table is written as HTML and JSON files
for inclusion in the documentation.
"""
try:
import pandas as pd
from pytorch_forecasting._registry import all_objects

# Base classes to exclude from the overview
BASE_CLASSES = {
"BaseModel",
"BaseModelWithCovariates",
"AutoRegressiveBaseModel",
"AutoRegressiveBaseModelWithCovariates",
"_BaseObject",
"_BasePtForecaster",
"_BasePtForecasterV2",
"_BasePtForecaster_Common",
}

# Get all objects from registry
all_objs = all_objects(return_names=True, suppress_import_stdout=True)

records = []

for obj_name, obj_class in all_objs:
# Skip base classes
if obj_name in BASE_CLASSES:
continue

# Skip if it's not a model class (check if it has get_class_tag method)
if not hasattr(obj_class, "get_class_tag"):
continue

try:
# Get model name from tags or use class name
model_name = obj_class.get_class_tag("info:name", obj_name)

# Get authors
authors = obj_class.get_class_tag("authors", None)
if authors is None:
authors = "pytorch-forecasting developers"
elif isinstance(authors, list):
authors = ", ".join(authors)

# Get object type
object_type = obj_class.get_class_tag("object_type", "model")
if isinstance(object_type, list):
object_type = ", ".join(object_type)

# Get capabilities
has_exogenous = obj_class.get_class_tag("capability:exogenous", False)
has_multivariate = obj_class.get_class_tag("capability:multivariate", False)
has_pred_int = obj_class.get_class_tag("capability:pred_int", False)
has_flexible_history = obj_class.get_class_tag("capability:flexible_history_length", False)
has_cold_start = obj_class.get_class_tag("capability:cold_start", False)

# Get compute requirement
compute = obj_class.get_class_tag("info:compute", None)

# Get module path for documentation link
module_path = obj_class.__module__
class_name = obj_class.__name__

# Construct documentation link
# Convert module path to API documentation path
api_path = module_path.replace(".", "/")
doc_link = f"api/{api_path}.html#{module_path}.{class_name}"

# Create model name with link
model_name_link = f'<a href="{doc_link}">{model_name}</a>'

# Build capabilities string
capabilities = []
if has_exogenous:
capabilities.append("Covariates")
if has_multivariate:
capabilities.append("Multiple targets")
if has_pred_int:
capabilities.append("Uncertainty")
if has_flexible_history:
capabilities.append("Flexible history")
if has_cold_start:
capabilities.append("Cold-start")

capabilities_str = ", ".join(capabilities) if capabilities else ""

records.append({
"Model Name": model_name_link,
"Type": object_type,
"Authors": authors,
"Covariates": "✓" if has_exogenous else "",
"Multiple targets": "✓" if has_multivariate else "",
"Uncertainty": "✓" if has_pred_int else "",
"Flexible history": "✓" if has_flexible_history else "",
"Cold-start": "✓" if has_cold_start else "",
"Compute": str(compute) if compute is not None else "",
"Capabilities": capabilities_str,
"Module": module_path,
})
except Exception as e:
# Skip objects that can't be processed
print(f"Warning: Could not process {obj_name}: {e}")
continue

if not records:
print("Warning: No models found in registry")
return

# Create DataFrame
df = pd.DataFrame(records)

# Ensure _static directory exists
static_dir = SOURCE_PATH.joinpath("_static")
static_dir.mkdir(exist_ok=True)

# Write HTML table
html_file = static_dir.joinpath("model_overview_table.html")
html_content = df[["Model Name", "Type", "Covariates", "Multiple targets",
"Uncertainty", "Flexible history", "Cold-start", "Compute"]].to_html(
classes="model-overview-table", index=False, border=0, escape=False
)
html_file.write_text(html_content, encoding="utf-8")
print(f"Generated model overview table: {html_file}")

# Write JSON database for interactive filtering (optional)
json_file = static_dir.joinpath("model_overview_db.json")
df.to_json(json_file, orient="records", indent=2)
print(f"Generated model overview JSON: {json_file}")

except ImportError as e:
print(f"Warning: Could not generate model overview (missing dependency): {e}")
except Exception as e:
print(f"Warning: Error generating model overview: {e}")


def setup(app: Sphinx):
app.add_css_file("custom.css")
app.connect("autodoc-skip-member", skip)
app.add_directive("moduleautosummary", ModuleAutoSummary)
app.add_js_file("https://buttons.github.io/buttons.js", **{"async": "async"})
# Connect model overview generator to builder-inited event
app.connect("builder-inited", _make_estimator_overview)


# extension configuration
Expand Down
Loading
Loading