Skip to content

Commit 7962b50

Browse files
committed
init
1 parent a9ddbfa commit 7962b50

15 files changed

+295
-176
lines changed

causalpy/data/simulate_data.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def _smoothed_gaussian_random_walk(
3131
gaussian_random_walk_mu: float,
3232
gaussian_random_walk_sigma: float,
3333
N: int,
34-
lowess_kwargs: dict[str, Any],
34+
lowess_kwargs: dict,
3535
) -> tuple[np.ndarray, np.ndarray]:
3636
"""
3737
Generates Gaussian random walk data and applies LOWESS.
@@ -57,7 +57,7 @@ def generate_synthetic_control_data(
5757
treatment_time: int = 70,
5858
grw_mu: float = 0.25,
5959
grw_sigma: float = 1,
60-
lowess_kwargs: dict[str, Any] | None = None,
60+
lowess_kwargs: dict = default_lowess_kwargs,
6161
) -> tuple[pd.DataFrame, np.ndarray]:
6262
"""
6363
Generates data for synthetic control example.
@@ -78,9 +78,6 @@ def generate_synthetic_control_data(
7878
>>> from causalpy.data.simulate_data import generate_synthetic_control_data
7979
>>> df, weightings_true = generate_synthetic_control_data(treatment_time=70)
8080
"""
81-
if lowess_kwargs is None:
82-
lowess_kwargs = default_lowess_kwargs
83-
8481
# 1. Generate non-treated variables
8582
df = pd.DataFrame(
8683
{
@@ -166,7 +163,9 @@ def generate_time_series_data(
166163
return df
167164

168165

169-
def generate_time_series_data_seasonal(treatment_time: pd.Timestamp) -> pd.DataFrame:
166+
def generate_time_series_data_seasonal(
167+
treatment_time: pd.Timestamp,
168+
) -> pd.DataFrame:
170169
"""
171170
Generates 10 years of monthly data with seasonality
172171
"""
@@ -184,7 +183,9 @@ def generate_time_series_data_seasonal(treatment_time: pd.Timestamp) -> pd.DataF
184183

185184
N = df.shape[0]
186185
idx = np.arange(N)[df.index > treatment_time]
187-
df["causal effect"] = 100 * gamma(10).pdf(np.arange(0, N, 1) - np.min(idx))
186+
df["causal effect"] = 100 * gamma(10).pdf(
187+
np.array(np.arange(0, N, 1)) - int(np.min(idx))
188+
)
188189

189190
df["y"] += df["causal effect"]
190191
df["y"] += norm(0, 2).rvs(N)
@@ -310,8 +311,8 @@ def impact(x: np.ndarray) -> np.ndarray:
310311
def generate_ancova_data(
311312
N: int = 200,
312313
pre_treatment_means: np.ndarray = np.array([10, 12]),
313-
treatment_effect: float = 2,
314-
sigma: float = 1,
314+
treatment_effect: int = 2,
315+
sigma: int = 1,
315316
) -> pd.DataFrame:
316317
"""
317318
Generate ANCOVA example data
@@ -445,7 +446,7 @@ def generate_multicell_geolift_data() -> pd.DataFrame:
445446

446447

447448
def generate_seasonality(
448-
n: int = 12, amplitude: float = 1, length_scale: float = 0.5
449+
n: int = 12, amplitude: int = 1, length_scale: float = 0.5
449450
) -> np.ndarray:
450451
"""Generate monthly seasonality by sampling from a Gaussian process with a
451452
Gaussian kernel, using numpy code"""
@@ -463,9 +464,9 @@ def generate_seasonality(
463464
def periodic_kernel(
464465
x1: np.ndarray,
465466
x2: np.ndarray,
466-
period: float = 1,
467-
length_scale: float = 1,
468-
amplitude: float = 1,
467+
period: int = 1,
468+
length_scale: float = 1.0,
469+
amplitude: int = 1,
469470
) -> np.ndarray:
470471
"""Generate a periodic kernel for gaussian process"""
471472
return amplitude**2 * np.exp(
@@ -475,10 +476,10 @@ def periodic_kernel(
475476

476477
def create_series(
477478
n: int = 52,
478-
amplitude: float = 1,
479-
length_scale: float = 2,
479+
amplitude: int = 1,
480+
length_scale: int = 2,
480481
n_years: int = 4,
481-
intercept: float = 3,
482+
intercept: int = 3,
482483
) -> np.ndarray:
483484
"""
484485
Returns numpy tile with generated seasonality data repeated over

causalpy/experiments/base.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""
1717

1818
from abc import abstractmethod
19+
from typing import Any, Union
1920

2021
import arviz as az
2122
import matplotlib.pyplot as plt
@@ -29,10 +30,12 @@
2930
class BaseExperiment:
3031
"""Base class for quasi experimental designs."""
3132

33+
labels: list[str]
34+
3235
supports_bayes: bool
3336
supports_ols: bool
3437

35-
def __init__(self, model=None):
38+
def __init__(self, model: Union[PyMCModel, RegressorMixin] | None = None) -> None:
3639
# Ensure we've made any provided Scikit Learn model (as identified as being type
3740
# RegressorMixin) compatible with CausalPy by appending our custom methods.
3841
if isinstance(model, RegressorMixin):
@@ -50,16 +53,19 @@ def __init__(self, model=None):
5053
if self.model is None:
5154
raise ValueError("model not set or passed.")
5255

56+
def fit(self, *args: Any, **kwargs: Any) -> None:
57+
raise NotImplementedError("fit method not implemented")
58+
5359
@property
54-
def idata(self):
60+
def idata(self) -> az.InferenceData:
5561
"""Return the InferenceData object of the model. Only relevant for PyMC models."""
5662
return self.model.idata
5763

58-
def print_coefficients(self, round_to=None):
64+
def print_coefficients(self, round_to: int | None = None) -> None:
5965
"""Ask the model to print its coefficients."""
6066
self.model.print_coefficients(self.labels, round_to)
6167

62-
def plot(self, *args, **kwargs) -> tuple:
68+
def plot(self, *args: Any, **kwargs: Any) -> tuple:
6369
"""Plot the model.
6470
6571
Internally, this function dispatches to either `_bayesian_plot` or `_ols_plot`
@@ -75,16 +81,16 @@ def plot(self, *args, **kwargs) -> tuple:
7581
raise ValueError("Unsupported model type")
7682

7783
@abstractmethod
78-
def _bayesian_plot(self, *args, **kwargs):
84+
def _bayesian_plot(self, *args: Any, **kwargs: Any) -> tuple:
7985
"""Abstract method for plotting the model."""
8086
raise NotImplementedError("_bayesian_plot method not yet implemented")
8187

8288
@abstractmethod
83-
def _ols_plot(self, *args, **kwargs):
89+
def _ols_plot(self, *args: Any, **kwargs: Any) -> tuple:
8490
"""Abstract method for plotting the model."""
8591
raise NotImplementedError("_ols_plot method not yet implemented")
8692

87-
def get_plot_data(self, *args, **kwargs) -> pd.DataFrame:
93+
def get_plot_data(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
8894
"""Recover the data of an experiment along with the prediction and causal impact information.
8995
9096
Internally, this function dispatches to either :func:`get_plot_data_bayesian` or :func:`get_plot_data_ols`
@@ -98,11 +104,11 @@ def get_plot_data(self, *args, **kwargs) -> pd.DataFrame:
98104
raise ValueError("Unsupported model type")
99105

100106
@abstractmethod
101-
def get_plot_data_bayesian(self, *args, **kwargs):
107+
def get_plot_data_bayesian(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
102108
"""Abstract method for recovering plot data."""
103109
raise NotImplementedError("get_plot_data_bayesian method not yet implemented")
104110

105111
@abstractmethod
106-
def get_plot_data_ols(self, *args, **kwargs):
112+
def get_plot_data_ols(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
107113
"""Abstract method for recovering plot data."""
108114
raise NotImplementedError("get_plot_data_ols method not yet implemented")

causalpy/experiments/diff_in_diff.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
Difference in differences
1616
"""
1717

18+
from typing import Union
19+
1820
import arviz as az
1921
import numpy as np
2022
import pandas as pd
@@ -92,8 +94,8 @@ def __init__(
9294
time_variable_name: str,
9395
group_variable_name: str,
9496
post_treatment_variable_name: str = "post_treatment",
95-
model=None,
96-
**kwargs,
97+
model: Union[PyMCModel, RegressorMixin] | None = None,
98+
**kwargs: dict,
9799
) -> None:
98100
super().__init__(model=model)
99101
self.causal_impact: xr.DataArray | float | None
@@ -234,14 +236,14 @@ def __init__(
234236
f"{self.group_variable_name}:{self.post_treatment_variable_name}"
235237
)
236238
matched_key = next((k for k in coef_map if interaction_term in k), None)
237-
att = coef_map.get(matched_key)
239+
att = coef_map.get(matched_key) if matched_key is not None else None
238240
self.causal_impact = att
239241
else:
240242
raise ValueError("Model type not recognized")
241243

242244
return
243245

244-
def input_validation(self):
246+
def input_validation(self) -> None:
245247
# Validate formula structure and interaction interaction terms
246248
self._validate_formula_interaction_terms()
247249

@@ -269,7 +271,7 @@ def input_validation(self):
269271
coded. Consisting of 0's and 1's only."""
270272
)
271273

272-
def _validate_formula_interaction_terms(self):
274+
def _validate_formula_interaction_terms(self) -> None:
273275
"""
274276
Validate that the formula contains at most one interaction term and no three-way or higher-order interactions.
275277
Raises FormulaException if more than one interaction term is found or if any interaction term has more than 2 variables.
@@ -299,7 +301,7 @@ def _validate_formula_interaction_terms(self):
299301
"Multiple interaction terms are not currently supported as they complicate interpretation of the causal effect."
300302
)
301303

302-
def summary(self, round_to=None) -> None:
304+
def summary(self, round_to: int | None = 2) -> None:
303305
"""Print summary of main results and model coefficients.
304306
305307
:param round_to:
@@ -311,11 +313,13 @@ def summary(self, round_to=None) -> None:
311313
print(self._causal_impact_summary_stat(round_to))
312314
self.print_coefficients(round_to)
313315

314-
def _causal_impact_summary_stat(self, round_to=None) -> str:
316+
def _causal_impact_summary_stat(self, round_to: int | None = None) -> str:
315317
"""Computes the mean and 94% credible interval bounds for the causal impact."""
316318
return f"Causal impact = {convert_to_string(self.causal_impact, round_to=round_to)}"
317319

318-
def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
320+
def _bayesian_plot(
321+
self, round_to: int | None = None, **kwargs: dict
322+
) -> tuple[plt.Figure, plt.Axes]:
319323
"""
320324
Plot the results
321325
@@ -463,9 +467,10 @@ def _plot_causal_impact_arrow(results, ax):
463467
)
464468
return fig, ax
465469

466-
def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
470+
def _ols_plot(
471+
self, round_to: int | None = 2, **kwargs: dict
472+
) -> tuple[plt.Figure, plt.Axes]:
467473
"""Generate plot for difference-in-differences"""
468-
round_to = kwargs.get("round_to")
469474
fig, ax = plt.subplots()
470475

471476
# Plot raw data
@@ -528,11 +533,15 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
528533
va="center",
529534
)
530535
# formatting
536+
# In OLS context, causal_impact should be a float, but mypy doesn't know this
537+
causal_impact_value = (
538+
float(self.causal_impact) if self.causal_impact is not None else 0.0
539+
)
531540
ax.set(
532541
xlim=[-0.05, 1.1],
533542
xticks=[0, 1],
534543
xticklabels=["pre", "post"],
535-
title=f"Causal impact = {round_num(self.causal_impact, round_to)}",
544+
title=f"Causal impact = {round_num(causal_impact_value, round_to)}",
536545
)
537546
ax.legend(fontsize=LEGEND_FONT_SIZE)
538547
return fig, ax

causalpy/experiments/instrumental_variable.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,10 @@ def __init__(
9797
data: pd.DataFrame,
9898
instruments_formula: str,
9999
formula: str,
100-
model=None,
101-
priors=None,
102-
**kwargs,
103-
):
100+
model: BaseExperiment | None = None,
101+
priors: dict | None = None,
102+
**kwargs: dict,
103+
) -> None:
104104
super().__init__(model=model)
105105
self.expt_type = "Instrumental Variable Regression"
106106
self.data = data
@@ -138,11 +138,11 @@ def __init__(
138138
"lkj_sd": 1,
139139
}
140140
self.priors = priors
141-
self.model.fit(
141+
self.model.fit( # type: ignore[call-arg,union-attr]
142142
X=self.X, Z=self.Z, y=self.y, t=self.t, coords=COORDS, priors=self.priors
143143
)
144144

145-
def input_validation(self):
145+
def input_validation(self) -> None:
146146
"""Validate the input data and model formula for correctness"""
147147
treatment = self.instruments_formula.split("~")[0]
148148
test = treatment.strip() in self.instruments_data.columns
@@ -165,7 +165,7 @@ def input_validation(self):
165165
The coefficients should be interpreted appropriately."""
166166
)
167167

168-
def get_2SLS_fit(self):
168+
def get_2SLS_fit(self) -> None:
169169
"""
170170
Two Stage Least Squares Fit
171171
@@ -187,7 +187,7 @@ def get_2SLS_fit(self):
187187
self.first_stage_reg = first_stage_reg
188188
self.second_stage_reg = second_stage_reg
189189

190-
def get_naive_OLS_fit(self):
190+
def get_naive_OLS_fit(self) -> None:
191191
"""
192192
Naive Ordinary Least Squares
193193
@@ -199,7 +199,7 @@ def get_naive_OLS_fit(self):
199199
self.ols_beta_params = dict(zip(self._x_design_info.column_names, beta_params))
200200
self.ols_reg = ols_reg
201201

202-
def plot(self, round_to=None):
202+
def plot(self, *args, **kwargs) -> None: # type: ignore[override]
203203
"""
204204
Plot the results
205205
@@ -208,7 +208,7 @@ def plot(self, round_to=None):
208208
"""
209209
raise NotImplementedError("Plot method not implemented.")
210210

211-
def summary(self, round_to=None) -> None:
211+
def summary(self, round_to: int | None = None) -> None:
212212
"""Print summary of main results and model coefficients.
213213
214214
:param round_to:

0 commit comments

Comments
 (0)