Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ jobs:
- name: Setup environment
run: pip install -e .[test]
- name: Run doctests
run: pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py
run: pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py --no-cov
- name: Run extra tests
run: pytest docs/source/.codespell/test_notebook_to_markdown.py
run: pytest docs/source/.codespell/test_notebook_to_markdown.py --no-cov
- name: Run tests
run: pytest --cov-report=xml --no-cov-on-fail
- name: Check codespell for notebooks
Expand Down
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ repos:
- id: ruff
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
# Exclude docs/ to avoid applying strict linting rules to example notebooks
# Remove this exclusion if you want to enforce strict rules on documentation
exclude: ^docs/
# Run the formatter
- id: ruff-format
types_or: [ python, pyi, jupyter ]
Expand Down
4 changes: 3 additions & 1 deletion causalpy/data/simulate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def impact(x: np.ndarray) -> np.ndarray:

def generate_ancova_data(
N: int = 200,
pre_treatment_means: np.ndarray = np.array([10, 12]),
pre_treatment_means: np.ndarray | None = None,
treatment_effect: int = 2,
sigma: int = 1,
) -> pd.DataFrame:
Expand All @@ -324,6 +324,8 @@ def generate_ancova_data(
... )
>>> df.to_csv(pathlib.Path.cwd() / "ancova_data.csv", index=False) # doctest: +SKIP
"""
if pre_treatment_means is None:
pre_treatment_means = np.array([10, 12])
group = np.random.choice(2, size=N)
pre = np.random.normal(loc=pre_treatment_means[group])
post = pre + treatment_effect * group + np.random.normal(size=N) * sigma
Expand Down
6 changes: 3 additions & 3 deletions causalpy/experiments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""

from abc import abstractmethod
from typing import Any, Literal, Union
from typing import Any, Literal

import arviz as az
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -54,7 +54,7 @@ class BaseExperiment:
supports_bayes: bool
supports_ols: bool

def __init__(self, model: Union[PyMCModel, RegressorMixin] | None = None) -> None:
def __init__(self, model: PyMCModel | RegressorMixin | None = None) -> None:
# If a scikit-learn model was provided (identified as an instance of
# RegressorMixin), make it compatible with CausalPy by appending our custom methods.
if isinstance(model, RegressorMixin):
Expand Down Expand Up @@ -141,7 +141,7 @@ def get_plot_data_ols(self, *args: Any, **kwargs: Any) -> pd.DataFrame:

def effect_summary(
self,
window: Union[Literal["post"], tuple, slice] = "post",
window: Literal["post"] | tuple | slice = "post",
direction: Literal["increase", "decrease", "two-sided"] = "increase",
alpha: float = 0.05,
cumulative: bool = True,
Expand Down
6 changes: 2 additions & 4 deletions causalpy/experiments/diff_in_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
Difference in differences
"""

from typing import Union

import arviz as az
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -98,7 +96,7 @@ def __init__(
time_variable_name: str,
group_variable_name: str,
post_treatment_variable_name: str = "post_treatment",
model: Union[PyMCModel, RegressorMixin] | None = None,
model: PyMCModel | RegressorMixin | None = None,
**kwargs: dict,
) -> None:
super().__init__(model=model)
Expand Down Expand Up @@ -234,7 +232,7 @@ def __init__(
elif isinstance(self.model, RegressorMixin):
# We need the coefficient on the interaction term.
# Map each coefficient label to its value, e.g. {intercept: value}
coef_map = dict(zip(self.labels, self.model.get_coeffs()))
coef_map = dict(zip(self.labels, self.model.get_coeffs(), strict=False))
# Create and find the interaction term based on the values user provided
interaction_term = (
f"{self.group_variable_name}:{self.post_treatment_variable_name}"
Expand Down
8 changes: 6 additions & 2 deletions causalpy/experiments/instrumental_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def input_validation(self) -> None:
"""Warning. The treatment variable is not Binary.
This is not necessarily a problem but it violates
the assumption of a simple IV experiment.
The coefficients should be interpreted appropriately."""
The coefficients should be interpreted appropriately.""",
UserWarning,
stacklevel=2,
)

def get_2SLS_fit(self) -> None:
Expand Down Expand Up @@ -195,7 +197,9 @@ def get_naive_OLS_fit(self) -> None:
ols_reg = sk_lin_reg().fit(self.X, self.y)
beta_params = list(ols_reg.coef_[0][1:])
beta_params.insert(0, ols_reg.intercept_[0])
self.ols_beta_params = dict(zip(self._x_design_info.column_names, beta_params))
self.ols_beta_params = dict(
zip(self._x_design_info.column_names, beta_params, strict=False)
)
self.ols_reg = ols_reg

def plot(self, *args, **kwargs) -> None: # type: ignore[override]
Expand Down
22 changes: 11 additions & 11 deletions causalpy/experiments/interrupted_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Interrupted Time Series Analysis
"""

from typing import Any, List, Union
from typing import Any

import arviz as az
import numpy as np
Expand Down Expand Up @@ -91,9 +91,9 @@ class InterruptedTimeSeries(BaseExperiment):
def __init__(
self,
data: pd.DataFrame,
treatment_time: Union[int, float, pd.Timestamp],
treatment_time: int | float | pd.Timestamp,
formula: str,
model: Union[PyMCModel, RegressorMixin] | None = None,
model: PyMCModel | RegressorMixin | None = None,
**kwargs: dict,
) -> None:
super().__init__(model=model)
Expand Down Expand Up @@ -155,7 +155,7 @@ def __init__(
# fit the model to the observed (pre-intervention) data
if isinstance(self.model, PyMCModel):
is_bsts_like = isinstance(
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
self.model, BayesianBasisExpansionTimeSeries | StateSpaceTimeSeries
)

if is_bsts_like:
Expand Down Expand Up @@ -183,7 +183,7 @@ def __init__(
# score the goodness of fit to the pre-intervention data
if isinstance(self.model, PyMCModel):
is_bsts_like = isinstance(
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
self.model, BayesianBasisExpansionTimeSeries | StateSpaceTimeSeries
)
if is_bsts_like:
X_score = self.pre_X.values if self.pre_X.shape[1] > 0 else None # type: ignore[attr-defined]
Expand All @@ -202,7 +202,7 @@ def __init__(
# get the model predictions of the observed (pre-intervention) data
if isinstance(self.model, PyMCModel):
is_bsts_like = isinstance(
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
self.model, BayesianBasisExpansionTimeSeries | StateSpaceTimeSeries
)
if is_bsts_like:
X_pre_predict = self.pre_X.values if self.pre_X.shape[1] > 0 else None # type: ignore[attr-defined]
Expand All @@ -220,7 +220,7 @@ def __init__(
# calculate the counterfactual (post period)
if isinstance(self.model, PyMCModel):
is_bsts_like = isinstance(
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
self.model, BayesianBasisExpansionTimeSeries | StateSpaceTimeSeries
)
if is_bsts_like:
X_post_predict = (
Expand All @@ -244,7 +244,7 @@ def __init__(
# calculate impact - use appropriate y data format for each model type
if isinstance(self.model, PyMCModel):
is_bsts_like = isinstance(
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
self.model, BayesianBasisExpansionTimeSeries | StateSpaceTimeSeries
)
if is_bsts_like:
pre_y_for_impact = self.pre_y.isel(treated_units=0)
Expand Down Expand Up @@ -275,7 +275,7 @@ def __init__(
)

def input_validation(
self, data: pd.DataFrame, treatment_time: Union[int, float, pd.Timestamp]
self, data: pd.DataFrame, treatment_time: int | float | pd.Timestamp
) -> None:
"""Validate the input data and model formula for correctness"""
if isinstance(data.index, pd.DatetimeIndex) and not isinstance(
Expand Down Expand Up @@ -303,7 +303,7 @@ def summary(self, round_to: int | None = None) -> None:

def _bayesian_plot(
self, round_to: int | None = 2, **kwargs: dict
) -> tuple[plt.Figure, List[plt.Axes]]:
) -> tuple[plt.Figure, list[plt.Axes]]:
"""
Plot the results

Expand Down Expand Up @@ -481,7 +481,7 @@ def _bayesian_plot(

def _ols_plot(
self, round_to: int | None = 2, **kwargs: dict
) -> tuple[plt.Figure, List[plt.Axes]]:
) -> tuple[plt.Figure, list[plt.Axes]]:
"""
Plot the results

Expand Down
8 changes: 3 additions & 5 deletions causalpy/experiments/inverse_propensity_weighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
Inverse propensity weighting
"""

from typing import List

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -263,7 +261,7 @@ def plot_ate(
method: str | None = None,
prop_draws: int = 100,
ate_draws: int = 300,
) -> tuple[plt.Figure, List[plt.Axes]]:
) -> tuple[plt.Figure, list[plt.Axes]]:
if idata is None:
idata = self.model.idata
if method is None:
Expand Down Expand Up @@ -325,7 +323,7 @@ def make_hists(idata, i, axs, method=method):
BBBBCC"""

fig, axs = plt.subplot_mosaic(mosaic, figsize=(20, 13))
axs = [axs[k] for k in axs.keys()]
axs = [axs[k] for k in axs]
axs[0].axvline(
0.1, linestyle="--", label="Low Extreme Propensity Scores", color="black"
)
Expand Down Expand Up @@ -412,7 +410,7 @@ def plot_balance_ecdf(
covariate: str,
idata: az.InferenceData | None = None,
weighting_scheme: str | None = None,
) -> tuple[plt.Figure, List[plt.Axes]]:
) -> tuple[plt.Figure, list[plt.Axes]]:
"""
Plotting function takes a single covariate and shows the
differences in the ECDF between the treatment and control
Expand Down
4 changes: 1 addition & 3 deletions causalpy/experiments/prepostnegd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
Pretest/posttest nonequivalent group design
"""

from typing import List

import arviz as az
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -227,7 +225,7 @@ def summary(self, round_to: int | None = None) -> None:

def _bayesian_plot(
self, round_to: int | None = None, **kwargs: dict
) -> tuple[plt.Figure, List[plt.Axes]]:
) -> tuple[plt.Figure, list[plt.Axes]]:
"""Generate plot for ANOVA-like experiments with non-equivalent group designs."""
fig, ax = plt.subplots(
2, 1, figsize=(7, 9), gridspec_kw={"height_ratios": [3, 1]}
Expand Down
6 changes: 3 additions & 3 deletions causalpy/experiments/regression_discontinuity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
"""

import warnings # noqa: I001
from typing import Union


import numpy as np
Expand Down Expand Up @@ -88,7 +87,7 @@ def __init__(
data: pd.DataFrame,
formula: str,
treatment_threshold: float,
model: Union[PyMCModel, RegressorMixin] | None = None,
model: PyMCModel | RegressorMixin | None = None,
running_variable_name: str = "x",
epsilon: float = 0.001,
bandwidth: float = np.inf,
Expand All @@ -112,6 +111,7 @@ def __init__(
warnings.warn(
f"Choice of bandwidth parameter has lead to only {len(filtered_data)} remaining datapoints. Consider increasing the bandwidth parameter.", # noqa: E501
UserWarning,
stacklevel=2,
)
y, X = dmatrices(formula, filtered_data)
else:
Expand Down Expand Up @@ -218,7 +218,7 @@ def input_validation(self) -> None:
self.data = self.data.copy()
self.data["treated"] = self.data["treated"].astype(bool)

def _is_treated(self, x: Union[np.ndarray, pd.Series]) -> np.ndarray:
def _is_treated(self, x: np.ndarray | pd.Series) -> np.ndarray:
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.

.. warning::
Expand Down
4 changes: 2 additions & 2 deletions causalpy/experiments/regression_kink.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"""

import warnings # noqa: I001
from typing import Union


from matplotlib import pyplot as plt
Expand Down Expand Up @@ -75,6 +74,7 @@ def __init__(
warnings.warn(
f"Choice of bandwidth parameter has lead to only {len(filtered_data)} remaining datapoints. Consider increasing the bandwidth parameter.", # noqa: E501
UserWarning,
stacklevel=2,
)
y, X = dmatrices(formula, filtered_data)
else:
Expand Down Expand Up @@ -192,7 +192,7 @@ def _probe_kink_point(self) -> tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
mu_kink_right = predicted["posterior_predictive"].sel(obs_ind=2)["mu"]
return mu_kink_left, mu_kink, mu_kink_right

def _is_treated(self, x: Union[np.ndarray, pd.Series]) -> np.ndarray:
def _is_treated(self, x: np.ndarray | pd.Series) -> np.ndarray:
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.""" # noqa: E501
return np.greater_equal(x, self.kink_point)

Expand Down
12 changes: 5 additions & 7 deletions causalpy/experiments/synthetic_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
Synthetic Control Experiment
"""

from typing import List, Union

import arviz as az
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -83,10 +81,10 @@ class SyntheticControl(BaseExperiment):
def __init__(
self,
data: pd.DataFrame,
treatment_time: Union[int, float, pd.Timestamp],
treatment_time: int | float | pd.Timestamp,
control_units: list[str],
treated_units: list[str],
model: Union[PyMCModel, RegressorMixin] | None = None,
model: PyMCModel | RegressorMixin | None = None,
**kwargs: dict,
) -> None:
super().__init__(model=model)
Expand Down Expand Up @@ -186,7 +184,7 @@ def __init__(
)

def input_validation(
self, data: pd.DataFrame, treatment_time: Union[int, float, pd.Timestamp]
self, data: pd.DataFrame, treatment_time: int | float | pd.Timestamp
) -> None:
"""Validate the input data and model formula for correctness"""
if isinstance(data.index, pd.DatetimeIndex) and not isinstance(
Expand Down Expand Up @@ -221,7 +219,7 @@ def _bayesian_plot(
round_to: int | None = None,
treated_unit: str | None = None,
**kwargs: dict,
) -> tuple[plt.Figure, List[plt.Axes]]:
) -> tuple[plt.Figure, list[plt.Axes]]:
"""
Plot the results for a specific treated unit

Expand Down Expand Up @@ -375,7 +373,7 @@ def _ols_plot(
round_to: int | None = None,
treated_unit: str | None = None,
**kwargs: dict,
) -> tuple[plt.Figure, List[plt.Axes]]:
) -> tuple[plt.Figure, list[plt.Axes]]:
"""
Plot the results for OLS model for a specific treated unit

Expand Down
8 changes: 4 additions & 4 deletions causalpy/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Plotting utility functions.
"""

from typing import Any, Dict, Tuple, Union
from typing import Any

import arviz as az
import matplotlib.pyplot as plt
Expand All @@ -28,13 +28,13 @@


def plot_xY(
x: Union[pd.DatetimeIndex, np.ndarray, pd.Index, pd.Series, ExtensionArray],
x: pd.DatetimeIndex | np.ndarray | pd.Index | pd.Series | ExtensionArray,
Y: xr.DataArray,
ax: plt.Axes,
plot_hdi_kwargs: Dict[str, Any] | None = None,
plot_hdi_kwargs: dict[str, Any] | None = None,
hdi_prob: float = 0.94,
label: str | None = None,
) -> Tuple[Line2D, PolyCollection]:
) -> tuple[Line2D, PolyCollection]:
"""Plot HDI intervals.

Parameters
Expand Down
Loading