diff --git a/causalpy/__init__.py b/causalpy/__init__.py index 5587fb3e..09384669 100644 --- a/causalpy/__init__.py +++ b/causalpy/__init__.py @@ -41,4 +41,5 @@ "RegressionKink", "skl_models", "SyntheticControl", + "variable_selection_priors", ] diff --git a/causalpy/experiments/instrumental_variable.py b/causalpy/experiments/instrumental_variable.py index 15427b40..4fdf7429 100644 --- a/causalpy/experiments/instrumental_variable.py +++ b/causalpy/experiments/instrumental_variable.py @@ -51,6 +51,16 @@ class InstrumentalVariable(BaseExperiment): If priors are not specified we will substitute MLE estimates for the beta coefficients. Example: ``priors = {"mus": [0, 0], "sigmas": [1, 1], "eta": 2, "lkj_sd": 2}``. + vs_prior_type : str or None, default=None + Type of variable selection prior: 'spike_and_slab', 'horseshoe', or None. + If None, uses standard normal priors. + vs_hyperparams : dict, optional + Hyperparameters for variable selection priors. Only used if vs_prior_type + is not None. + binary_treatment : bool, default=False + An indicator for whether the treatment to be modelled is binary or not. + Determines which PyMC model we use to model the joint outcome and + treatment. Example -------- @@ -85,6 +95,16 @@ class InstrumentalVariable(BaseExperiment): ... formula=formula, ... model=InstrumentalVariableRegression(sample_kwargs=sample_kwargs), ... ) + >>> # With variable selection + >>> iv = cp.InstrumentalVariable( + ... instruments_data=instruments_data, + ... data=data, + ... instruments_formula=instruments_formula, + ... formula=formula, + ... model=InstrumentalVariableRegression(sample_kwargs=sample_kwargs), + ... vs_prior_type="spike_and_slab", + ... vs_hyperparams={"slab_sigma": 5.0}, + ... 
) """ supports_ols = False @@ -98,6 +118,9 @@ def __init__( formula: str, model: BaseExperiment | None = None, priors: dict | None = None, + vs_prior_type=None, + vs_hyperparams=None, + binary_treatment=False, **kwargs: dict, ) -> None: super().__init__(model=model) @@ -107,6 +130,9 @@ def __init__( self.formula = formula self.instruments_formula = instruments_formula self.model = model + self.vs_prior_type = vs_prior_type + self.vs_hyperparams = vs_hyperparams or {} + self.binary_treatment = binary_treatment self.input_validation() y, X = dmatrices(formula, self.data) @@ -130,15 +156,33 @@ def __init__( COORDS = {"instruments": self.labels_instruments, "covariates": self.labels} self.coords = COORDS if priors is None: - priors = { - "mus": [self.ols_beta_first_params, self.ols_beta_second_params], - "sigmas": [1, 1], - "eta": 2, - "lkj_sd": 1, - } + if binary_treatment: + # Different default priors for binary treatment + priors = { + "mus": [self.ols_beta_first_params, self.ols_beta_second_params], + "sigmas": [1, 1], + "sigma_U": 1.0, + "rho_bounds": [-0.99, 0.99], + } + else: + # Original continuous treatment priors + priors = { + "mus": [self.ols_beta_first_params, self.ols_beta_second_params], + "sigmas": [1, 1], + "eta": 2, + "lkj_sd": 1, + } self.priors = priors self.model.fit( # type: ignore[call-arg,union-attr] - X=self.X, Z=self.Z, y=self.y, t=self.t, coords=COORDS, priors=self.priors + X=self.X, + Z=self.Z, + y=self.y, + t=self.t, + coords=COORDS, + priors=self.priors, + vs_prior_type=vs_prior_type, + vs_hyperparams=vs_hyperparams, + binary_treatment=self.binary_treatment, ) def input_validation(self) -> None: @@ -159,9 +203,8 @@ def input_validation(self) -> None: if check_binary: warnings.warn( """Warning. The treatment variable is not Binary. - This is not necessarily a problem but it violates - the assumption of a simple IV experiment. 
- The coefficients should be interpreted appropriately.""" + We will use the multivariate normal likelihood + for continuous treatment.""" ) def get_2SLS_fit(self) -> None: diff --git a/causalpy/pymc_models.py b/causalpy/pymc_models.py index 9adc3bcf..71f5210e 100644 --- a/causalpy/pymc_models.py +++ b/causalpy/pymc_models.py @@ -27,6 +27,7 @@ from pymc_extras.prior import Prior from causalpy.utils import round_num +from causalpy.variable_selection_priors import VariableSelectionPrior class PyMCModel(pm.Model): @@ -679,7 +680,10 @@ def build_model( # type: ignore y: np.ndarray, t: np.ndarray, coords: Dict[str, Any], - priors: Dict[str, Any], + priors, + vs_prior_type=None, + vs_hyperparams=None, + binary_treatment=False, ) -> None: """Specify model with treatment regression and focal regression data and priors. @@ -702,48 +706,126 @@ def build_model( # type: ignore Dictionary of priors for the mus and sigmas of both regressions. Example: ``priors = {"mus": [0, 0], "sigmas": [1, 1], "eta": 2, "lkj_sd": 2}``. + vs_prior_type: An optional string. Can be "spike_and_slab" + or "horseshoe" or "normal" + vs_hyperparams: An optional dictionary of priors for the + variable selection hyperparameters + binary_treatment: A flag for determining the relevant + likelihood to be used. 
+ """ # --- Priors --- with self: self.add_coords(coords) - beta_t = pm.Normal( - name="beta_t", - mu=priors["mus"][0], - sigma=priors["sigmas"][0], - dims="instruments", - ) - beta_z = pm.Normal( - name="beta_z", - mu=priors["mus"][1], - sigma=priors["sigmas"][1], - dims="covariates", - ) - sd_dist = pm.Exponential.dist(priors["lkj_sd"], shape=2) - chol, corr, sigmas = pm.LKJCholeskyCov( - name="chol_cov", - eta=priors["eta"], - n=2, - sd_dist=sd_dist, - ) - # compute and store the covariance matrix - pm.Deterministic(name="cov", var=pt.dot(l=chol, r=chol.T)) - - # --- Parameterization --- - mu_y = pm.Deterministic(name="mu_y", var=pt.dot(X, beta_z)) - # focal regression - mu_t = pm.Deterministic(name="mu_t", var=pt.dot(Z, beta_t)) - # instrumental regression - mu = pm.Deterministic(name="mu", var=pt.stack(tensors=(mu_y, mu_t), axis=1)) - - # --- Likelihood --- - pm.MvNormal( - name="likelihood", - mu=mu, - chol=chol, - observed=np.stack(arrays=(y.flatten(), t.flatten()), axis=1), - shape=(X.shape[0], 2), - ) + + if vs_prior_type and ("mus" in priors or "sigmas" in priors): + warnings.warn( + "Variable selection priors specified. " + "The 'mus' and 'sigmas' in the priors dict will be ignored " + "for beta coefficients. Only 'eta' and 'lkj_sd' will be used." 
+ ) + + # Create coefficient priors + if vs_prior_type: + # Use variable selection priors + self.vs_prior_treatment = VariableSelectionPrior( + vs_prior_type, vs_hyperparams + ) + self.vs_prior_outcome = VariableSelectionPrior( + vs_prior_type, vs_hyperparams + ) + + beta_t = self.vs_prior_treatment.create_prior( + name="beta_t", n_params=Z.shape[1], dims="instruments", X=Z + ) + + beta_z = self.vs_prior_outcome.create_prior( + name="beta_z", n_params=X.shape[1], dims="covariates", X=X + ) + else: + # Use standard normal priors + beta_t = pm.Normal( + name="beta_t", + mu=priors["mus"][0], + sigma=priors["sigmas"][0], + dims="instruments", + ) + beta_z = pm.Normal( + name="beta_z", + mu=priors["mus"][1], + sigma=priors["sigmas"][1], + dims="covariates", + ) + + if binary_treatment: + # Binary treatment formulation with correlated latent errors + sigma_U = pm.Exponential("sigma_U", priors.get("sigma_U", 1.0)) + + # Correlation parameter with bounds + rho_lower = priors.get("rho_bounds", [-0.99, 0.99])[0] + rho_upper = priors.get("rho_bounds", [-0.99, 0.99])[1] + + # Use tanh transform to keep correlation in valid range + rho_unconstr = pm.Normal("rho_unconstr", 0, 0.5) + rho = pm.Deterministic("rho", pm.math.tanh(rho_unconstr)) + + # Clip to ensure numerical stability + rho_clipped = pt.clip(rho, rho_lower + 0.01, rho_upper - 0.01) + + # Cholesky decomposition for correlated errors + inverse_rho = pm.math.sqrt(pm.math.maximum(1 - rho_clipped**2, 1e-12)) + chol = pt.stack([[sigma_U, 0.0], [sigma_U * rho_clipped, inverse_rho]]) + + # Draw latent errors + eps_raw = pm.Normal("eps_raw", 0, 1, shape=(X.shape[0], 2)) + eps = pm.Deterministic("eps", pt.dot(eps_raw, chol.T)) + + U = eps[:, 0] # Outcome error + V = eps[:, 1] # Treatment error + + # Treatment equation (logit link for binary treatment) + mu_treatment = pm.Deterministic("mu_t", pt.dot(Z, beta_t) + V) + p_t = pm.math.invlogit(mu_treatment) + pm.Bernoulli("likelihood_treatment", p=p_t, observed=t.flatten()) + + # 
Outcome equation + mu_outcome = pm.Deterministic("mu_y", pt.dot(X, beta_z) + U) + pm.Normal( + "likelihood_outcome", + mu=mu_outcome, + sigma=sigma_U, + observed=y.flatten(), + ) + + else: + sd_dist = pm.Exponential.dist(priors["lkj_sd"], shape=2) + chol, _, _ = pm.LKJCholeskyCov( + name="chol_cov", + eta=priors["eta"], + n=2, + sd_dist=sd_dist, + ) + # compute and store the covariance matrix + pm.Deterministic(name="cov", var=pt.dot(l=chol, r=chol.T)) + + # --- Parameterization --- + mu_y = pm.Deterministic(name="mu_y", var=pt.dot(X, beta_z)) + # focal regression + mu_t = pm.Deterministic(name="mu_t", var=pt.dot(Z, beta_t)) + # instrumental regression + mu = pm.Deterministic( + name="mu", var=pt.stack(tensors=(mu_y, mu_t), axis=1) + ) + + # --- Likelihood --- + pm.MvNormal( + name="likelihood", + mu=mu, + chol=chol, + observed=np.stack(arrays=(y.flatten(), t.flatten()), axis=1), + shape=(X.shape[0], 2), + ) def sample_predictive_distribution(self, ppc_sampler: str | None = "jax") -> None: """Function to sample the Multivariate Normal posterior predictive @@ -777,50 +859,35 @@ def sample_predictive_distribution(self, ppc_sampler: str | None = "jax") -> Non ) ) - def fit( # type: ignore + def fit( # type: ignore[override] self, - X: np.ndarray, - Z: np.ndarray, - y: np.ndarray, - t: np.ndarray, - coords: Dict[str, Any], - priors: Dict[str, Any], - ppc_sampler: str | None = None, - ) -> az.InferenceData: - """Draw samples from posterior distribution and potentially from - the prior and posterior predictive distributions. - - Parameters - ---------- - X : np.ndarray - Array used to predict our outcome y. - Z : np.ndarray - Array used to predict our treatment variable t. - y : np.ndarray - Array of values representing our focal outcome y. - t : np.ndarray - Array representing the treatment variable. - coords : dict - Dictionary with coordinate names for named dimensions. - priors : dict - Dictionary of priors for the model. 
- ppc_sampler : str, optional - Sampler for posterior predictive distribution. Can be 'jax', - 'pymc', or None. Defaults to None, so the user can determine - if they wish to spend time sampling the posterior predictive - distribution independently. - - Returns - ------- - az.InferenceData - InferenceData object containing the samples. + X, + Z, + y, + t, + coords, + priors, + ppc_sampler=None, + vs_prior_type=None, + vs_hyperparams=None, + binary_treatment: bool = False, + ): # type: ignore[override] + """Draw samples from posterior distribution and potentially + from the prior and posterior predictive distributions. The + fit call can take values for the + ppc_sampler = ['jax', 'pymc', None] + We default to None, so the user can determine if they wish + to spend time sampling the posterior predictive distribution + independently. """ # Ensure random_seed is used in sample_prior_predictive() and # sample_posterior_predictive() if provided in sample_kwargs. # Use JAX for ppc sampling of multivariate likelihood - self.build_model(X, Z, y, t, coords, priors) + self.build_model( + X, Z, y, t, coords, priors, vs_prior_type, vs_hyperparams, binary_treatment + ) with self: self.idata = pm.sample(**self.sample_kwargs) self.sample_predictive_distribution(ppc_sampler=ppc_sampler) @@ -926,6 +993,7 @@ def fit_outcome_model( normal_outcome: bool = True, spline_component: bool = False, winsorize_boundary: float = 0.0, + spline_knots: int = 30, ) -> tuple[az.InferenceData, pm.Model]: """ Fit a Bayesian outcome model using covariates and previously estimated propensity scores. 
@@ -966,6 +1034,9 @@ def fit_outcome_model( If we wish to winsorize the propensity score this can be set to clip the high and low values of the propensity at 0 + winsorize_boundary and 1-winsorize_boundary + spline_knots: int, default 30 + The number of knots we use in the 0 - 1 interval to create our spline function + Returns ------- idata_outcome : arviz.InferenceData @@ -1029,11 +1100,11 @@ class initialisation. "beta_ps_spline", priors["beta_ps"][0], priors["beta_ps"][1], - size=34, + size=spline_knots + 4, ) B = dmatrix( "bs(ps, knots=knots, degree=3, include_intercept=True, lower_bound=0, upper_bound=1) - 1", - {"ps": p, "knots": np.linspace(0, 1, 30)}, + {"ps": p, "knots": np.linspace(0, 1, spline_knots)}, ) B_f = np.asarray(B, order="F") splines_summed = pm.Deterministic( diff --git a/causalpy/tests/test_integration_pymc_examples.py b/causalpy/tests/test_integration_pymc_examples.py index 00068507..6675a298 100644 --- a/causalpy/tests/test_integration_pymc_examples.py +++ b/causalpy/tests/test_integration_pymc_examples.py @@ -682,6 +682,116 @@ def test_iv_reg(mock_pymc_sample): result.get_plot_data() +@pytest.mark.integration +def test_iv_binary_treatment(mock_pymc_sample): + df = cp.load_data("risk") + df["binary_trt"] = np.random.binomial(1, 0.5, len(df)) + instruments_formula = "binary_trt ~ 1 + risk + logmort0" + formula = "loggdp ~ 1 + binary_trt + risk" + instruments_data = df[["risk", "logmort0", "binary_trt"]] + data = df[["loggdp", "risk", "binary_trt"]] + + result = cp.InstrumentalVariable( + instruments_data=instruments_data, + data=data, + instruments_formula=instruments_formula, + formula=formula, + model=cp.pymc_models.InstrumentalVariableRegression( + sample_kwargs=sample_kwargs + ), + binary_treatment=True, + ) + result.model.sample_predictive_distribution(ppc_sampler="pymc") + assert isinstance(df, pd.DataFrame) + assert isinstance(data, pd.DataFrame) + assert isinstance(instruments_data, pd.DataFrame) + assert isinstance(result, 
cp.InstrumentalVariable) + assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"] + assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"] + with pytest.raises(NotImplementedError): + result.get_plot_data() + assert "rho" in result.model.named_vars + + +@pytest.mark.integration +def test_iv_reg_vs_prior(mock_pymc_sample): + df = cp.load_data("risk") + instruments_formula = "risk ~ 1 + logmort0" + formula = "loggdp ~ 1 + risk" + instruments_data = df[["risk", "logmort0"]] + data = df[["loggdp", "risk"]] + + result = cp.InstrumentalVariable( + instruments_data=instruments_data, + data=data, + instruments_formula=instruments_formula, + formula=formula, + model=cp.pymc_models.InstrumentalVariableRegression( + sample_kwargs=sample_kwargs + ), + vs_prior_type="spike_and_slab", + vs_hyperparams={"pi_alpha": 5}, + ) + result.model.sample_predictive_distribution(ppc_sampler="pymc") + assert isinstance(df, pd.DataFrame) + assert isinstance(data, pd.DataFrame) + assert isinstance(instruments_data, pd.DataFrame) + assert isinstance(result, cp.InstrumentalVariable) + assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"] + assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"] + with pytest.raises(NotImplementedError): + result.get_plot_data() + assert "gamma_beta_t" in result.model.named_vars + assert "pi_beta_t" in result.model.named_vars + summary = result.model.vs_prior_outcome.get_inclusion_probabilities( + result.idata, "beta_z" + ) + assert isinstance(summary, pd.DataFrame) + with pytest.raises(ValueError): + summary = result.model.vs_prior_outcome.get_shrinkage_factors( + result.idata, "beta_z" + ) + + +@pytest.mark.integration +def test_iv_reg_vs_prior_hs(mock_pymc_sample): + df = cp.load_data("risk") + instruments_formula = "risk ~ 1 + logmort0" + formula = "loggdp ~ 1 + risk" + instruments_data = df[["risk", "logmort0"]] + data = df[["loggdp", "risk"]] + + result = 
cp.InstrumentalVariable( + instruments_data=instruments_data, + data=data, + instruments_formula=instruments_formula, + formula=formula, + model=cp.pymc_models.InstrumentalVariableRegression( + sample_kwargs=sample_kwargs + ), + vs_prior_type="horseshoe", + ) + result.model.sample_predictive_distribution(ppc_sampler="pymc") + assert isinstance(df, pd.DataFrame) + assert isinstance(data, pd.DataFrame) + assert isinstance(instruments_data, pd.DataFrame) + assert isinstance(result, cp.InstrumentalVariable) + assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"] + assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"] + with pytest.raises(NotImplementedError): + result.get_plot_data() + assert "tau_beta_t" in result.model.named_vars + assert "tau_beta_z" in result.model.named_vars + summary = result.model.vs_prior_outcome.get_shrinkage_factors( + result.idata, "beta_z" + ) + assert isinstance(summary, pd.DataFrame) + with pytest.raises(ValueError): + summary = result.model.vs_prior_outcome.get_inclusion_probabilities( + result.idata, "beta_z" + ) + + @pytest.mark.integration def test_inverse_prop(mock_pymc_sample): """Test the InversePropensityWeighting class.""" diff --git a/causalpy/tests/test_variable_selection_priors.py b/causalpy/tests/test_variable_selection_priors.py new file mode 100644 index 00000000..1b464be6 --- /dev/null +++ b/causalpy/tests/test_variable_selection_priors.py @@ -0,0 +1,125 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pymc as pm +import pytest + +from causalpy.variable_selection_priors import ( + HorseshoePrior, + SpikeAndSlabPrior, + VariableSelectionPrior, + create_variable_selection_prior, +) + + +@pytest.fixture +def sample_data(): + """Generate sample design matrix for testing.""" + rng = np.random.default_rng(42) + n_obs = 100 + n_features = 5 + X = rng.normal(size=(n_obs, n_features)) + return X + + +@pytest.fixture +def coords(): + """Generate sample coordinates for PyMC models.""" + return {"features": [f"x_{i}" for i in range(5)]} + + +def test_create_variable_in_model_context(coords): + """Test that create_variable works in PyMC model context.""" + prior = SpikeAndSlabPrior(dims="features") + + with pm.Model(coords=coords) as model: + beta = prior.create_variable("beta") + + # Check that beta was created + assert "beta" in model.named_vars + assert beta.name == "beta" + + # Check that intermediate variables were created + assert "pi_beta" in model.named_vars + assert "beta_raw" in model.named_vars + assert "gamma_beta" in model.named_vars + + +def test_create_variable_in_model_context_horseshoe(coords): + """Test that create_variable works in PyMC model context.""" + prior = HorseshoePrior(dims="features") + + with pm.Model(coords=coords) as model: + beta = prior.create_variable("beta") + + # Check that beta was created + assert "beta" in model.named_vars + assert beta.name == "beta" + + # Check that intermediate variables were created + assert "tau_beta" in model.named_vars + assert "lambda_beta" in model.named_vars + assert "c2_beta" in model.named_vars + assert "lambda_tilde_beta" in model.named_vars + assert "beta_raw" in model.named_vars + + +def test_create_prior_spike_and_slab(coords): + """Test create_prior for spike-and-slab.""" + vs_prior = VariableSelectionPrior("spike_and_slab", hyperparams={"pi_alpha": 5}) + + with 
pm.Model(coords=coords) as model: + beta = vs_prior.create_prior(name="beta", n_params=5, dims="features") + + assert "beta" in model.named_vars + assert beta.name == "beta" + + +def test_create_prior_horseshoe(coords, sample_data): + """Test create_prior for horseshoe.""" + vs_prior = VariableSelectionPrior("horseshoe") + + with pm.Model(coords=coords) as model: + beta = vs_prior.create_prior( + name="beta", n_params=5, dims="features", X=sample_data + ) + + assert "beta" in model.named_vars + assert beta.name == "beta" + + +def test_create_prior_normal(coords, sample_data): + """Test create_prior for horseshoe.""" + vs_prior = VariableSelectionPrior("normal") + + with pm.Model(coords=coords) as model: + beta = vs_prior.create_prior(name="beta", n_params=5, dims="features") + + assert "beta" in model.named_vars + assert beta.name == "beta" + + +def test_convenience_function_with_custom_hyperparams(coords): + """Test convenience function with custom hyperparameters.""" + with pm.Model(coords=coords) as model: + _ = create_variable_selection_prior( + prior_type="spike_and_slab", + name="beta", + n_params=5, + dims="features", + hyperparams={"slab_sigma": 5}, + ) + + assert "beta" in model.named_vars diff --git a/causalpy/variable_selection_priors.py b/causalpy/variable_selection_priors.py new file mode 100644 index 00000000..8c8b2001 --- /dev/null +++ b/causalpy/variable_selection_priors.py @@ -0,0 +1,594 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Generic variable selection priors for PyMC models using pymc-extras Prior class. + +This module provides reusable prior specifications that can be applied to any +PyMC model with coefficient vectors (beta parameters). Supports spike-and-slab +and horseshoe priors for automatic variable selection and shrinkage, built on +top of the pymc-extras Prior infrastructure. +""" + +from typing import Any, Dict, Optional, Union + +import numpy as np +import pandas as pd +import pymc as pm +import pytensor.tensor as pt +from pymc_extras.prior import Prior + + +def _relaxed_bernoulli_transform( + p: Union[float, pt.TensorVariable], temperature: float = 0.1 +): + """ + Transform function for relaxed (continuous) Bernoulli distribution. + + This provides a continuous approximation to a Bernoulli distribution, + useful for gradient-based inference. As temperature → 0, this approaches + a true binary distribution. + + Parameters + ---------- + p : float or PyMC variable + Probability parameter. + temperature : float, default=0.1 + Temperature parameter (lower = more binary). + + Returns + ------- + function + Transform function that takes uniform random variable. + """ + + def transform(u): + logit_p = pt.log(p) - pt.log(1 - p) + return pm.math.sigmoid((logit_p + pt.log(u) - pt.log(1 - u)) / temperature) + + return transform + + +class SpikeAndSlabPrior: + """ + Spike-and-slab prior using pymc-extras Prior class. + + Creates a mixture prior with a point mass at zero (spike) and a diffuse + normal distribution (slab), implemented as: + + .. math:: + \beta_{j} = \gamma_{j} \cdot \beta_{j}^{\text{raw}} \\ + \beta_{j}^{\text{raw}} \sim \mathcal{N}(0, \sigma_{\text{slab}}^{2}), \qquad + \gamma_{j} \in [0,1]. 
+ + Parameters + ---------- + pi_alpha : float, default=2 + Beta prior alpha for selection probability + pi_beta : float, default=2 + Beta prior beta for selection probability + slab_sigma : float, default=2 + Standard deviation of slab (non-zero) component + temperature : float, default=0.1 + Relaxation parameter for binary approximation (lower = more binary) + dims : str or tuple, optional + Dimension names for the coefficient vector + + Example + ------- + >>> import pymc as pm + >>> from causalpy.variable_selection_priors import SpikeAndSlabPrior + >>> spike_slab = SpikeAndSlabPrior(dims="features") + >>> coords = {"features": ["a", "b", "c", "d", "e"]} + >>> with pm.Model(coords=coords) as model: + ... beta = spike_slab.create_variable("beta") + """ + + def __init__( + self, + pi_alpha: float = 2, + pi_beta: float = 2, + slab_sigma: float = 2, + temperature: float = 0.1, + dims: Optional[Union[str, tuple]] = None, + ): + self.pi_alpha = pi_alpha + self.pi_beta = pi_beta + self.slab_sigma = slab_sigma + self.temperature = temperature + self.dims = dims if isinstance(dims, tuple) or dims is None else (dims,) + + def create_variable(self, name: str) -> pm.Deterministic: + """ + Create spike-and-slab variable. 
+ + Parameters + ---------- + name : str + Name for the coefficient vector + + Returns + ------- + pm.Deterministic + Coefficient vector with spike-and-slab prior + """ + # Selection probability using Prior class + pi_prior = Prior("Beta", alpha=self.pi_alpha, beta=self.pi_beta) + pi = pi_prior.create_variable(f"pi_{name}") + + # Raw coefficients (slab component) using Prior class + slab_prior = Prior("Normal", mu=0, sigma=self.slab_sigma, dims=self.dims) + beta_raw = slab_prior.create_variable(f"{name}_raw") + + # Selection indicators using relaxed Bernoulli + # We use Uniform and transform it + u = pm.Uniform(f"gamma_{name}_u", 0, 1, dims=self.dims) + transform_fn = _relaxed_bernoulli_transform(pi, self.temperature) + gamma = pm.Deterministic(f"gamma_{name}", transform_fn(u), dims=self.dims) + + # Actual coefficients + return pm.Deterministic(name, gamma * beta_raw, dims=self.dims) + + +class HorseshoePrior: + """ + Regularized horseshoe prior using pymc-extras Prior class. + + Provides continuous shrinkage with heavy tails, allowing strong signals + to escape shrinkage while weak signals are dampened: + + .. math:: + \beta_{j} & = \tau \cdot \lambda_{j} \cdot \beta_{j}^{raw} \\ + \lambda_{j} & = \sqrt{ \dfrac{c^{2}\lambda_{j}^{2}}{c^{2} + \tau^{2}\lambda_{j}^{2}} } + + Parameters + ---------- + tau0 : float, optional + Global shrinkage parameter. If None, computed from data. + nu : float, default=3 + Degrees of freedom for half-t prior on tau + c2_alpha : float, default=2 + InverseGamma alpha for regularization parameter + c2_beta : float, default=2 + InverseGamma beta for regularization parameter + dims : str or tuple, optional + Dimension names for the coefficient vector + + Example + ------- + >>> import pymc as pm + >>> from causalpy.variable_selection_priors import HorseshoePrior + >>> horseshoe = HorseshoePrior(dims="features") + >>> coords = {"features": ["a", "b", "c", "d", "e"]} + >>> with pm.Model(coords=coords) as model: + ... 
beta = horseshoe.create_variable("beta") + """ + + def __init__( + self, + tau0: Optional[float] = None, + nu: float = 3, + c2_alpha: float = 2, + c2_beta: float = 2, + dims: Optional[Union[str, tuple]] = None, + ): + self.tau0 = tau0 + self.nu = nu + self.c2_alpha = c2_alpha + self.c2_beta = c2_beta + self.dims = dims if isinstance(dims, tuple) or dims is None else (dims,) + + def create_variable(self, name: str) -> pm.Deterministic: + """ + Create horseshoe variable. + + Parameters + ---------- + name : str + Name for the coefficient vector + + Returns + ------- + pm.Deterministic + Coefficient vector with horseshoe prior + """ + # Global shrinkage using Prior class + tau_prior = Prior("HalfStudentT", nu=self.nu, sigma=self.tau0 or 1.0) + tau = tau_prior.create_variable(f"tau_{name}") + + # Local shrinkage parameters using Prior class + lambda_prior = Prior("HalfCauchy", beta=1.0, dims=self.dims) + lambda_ = lambda_prior.create_variable(f"lambda_{name}") + + # Regularization parameter using Prior class + c2_prior = Prior("InverseGamma", alpha=self.c2_alpha, beta=self.c2_beta) + c2 = c2_prior.create_variable(f"c2_{name}") + + # Regularized local shrinkage + lambda_tilde = pm.Deterministic( + f"lambda_tilde_{name}", + pm.math.sqrt(c2 * lambda_**2 / (c2 + tau**2 * lambda_**2)), + dims=self.dims, + ) + + # Raw coefficients using Prior class + raw_prior = Prior("Normal", mu=0, sigma=1, dims=self.dims) + beta_raw = raw_prior.create_variable(f"{name}_raw") + + # Actual coefficients + return pm.Deterministic(name, beta_raw * lambda_tilde * tau, dims=self.dims) + + +class VariableSelectionPrior: + """ + Factory for creating variable selection priors on coefficient vectors. + + This class provides a unified interface for different types of variable + selection priors that can be applied to any beta coefficient in a PyMC model. + Built on top of pymc-extras Prior class for consistency and interoperability. 
+ + Supported prior types: + - 'spike_and_slab': Mixture prior with near-zero spike and diffuse slab + - 'horseshoe': Continuous shrinkage with adaptive regularization + - 'normal': Standard normal prior (no selection, for comparison) + + Parameters + ---------- + prior_type : str + Type of prior: 'spike_and_slab', 'horseshoe', or 'normal' + hyperparams : dict, optional + Hyperparameters specific to the chosen prior type. If None, defaults are used. + + For 'spike_and_slab': + - pi_alpha: float (default=2) - Beta prior alpha for selection probability + - pi_beta: float (default=2) - Beta prior beta for selection probability + - slab_sigma: float (default=2) - SD of slab (non-zero) component + - temperature: float (default=0.1) - Relaxation parameter for binary approximation + + For 'horseshoe': + - tau0: float (default=None) - Global shrinkage, auto-computed if None + - nu: float (default=3) - Degrees of freedom for half-t prior on tau + - c2_alpha: float (default=2) - InverseGamma alpha for regularization + - c2_beta: float (default=2) - InverseGamma beta for regularization + + For 'normal': + - mu: float or array (default=0) - Prior mean + - sigma: float or array (default=1) - Prior SD + + Example + ------- + >>> import pymc as pm + >>> from causalpy.variable_selection_priors import VariableSelectionPrior + >>> # Create spike-and-slab prior + >>> vs_prior = VariableSelectionPrior("spike_and_slab") + >>> coords = {"features": ["a", "b", "c", "d", "e"]} + >>> with pm.Model(coords=coords) as model: + ... # Create coefficients with variable selection + ... 
beta = vs_prior.create_prior(name="beta", n_params=5, dims="features") + """ + + def __init__(self, prior_type: str, hyperparams: Optional[Dict[str, Any]] = None): + """Initialize the variable selection prior factory.""" + self.prior_type = prior_type.lower() + self.hyperparams = hyperparams or {} + + if self.prior_type not in ["spike_and_slab", "horseshoe", "normal"]: + raise ValueError( + f"Unknown prior_type: {prior_type}. " + "Must be 'spike_and_slab', 'horseshoe', or 'normal'" + ) + + # Will be set when create_prior is called + self._prior_instance = None + + def _get_default_hyperparams( + self, n_params: int, X: Optional[np.ndarray] = None + ) -> Dict[str, Any]: + """ + Get default hyperparameters for the chosen prior type. + + Parameters + ---------- + n_params : int + Number of parameters (dimension of beta vector) + X : array-like, optional + Design matrix for computing data-adaptive defaults (horseshoe only) + + Returns + ------- + dict + Default hyperparameters + """ + if self.prior_type == "spike_and_slab": + return { + "pi_alpha": 2, + "pi_beta": 2, + "slab_sigma": 2, + "temperature": 0.1, + } + + elif self.prior_type == "horseshoe": + # Compute tau0 using rule of thumb from Piironen & Vehtari (2017) + if X is not None: + p = n_params + p0 = min(5.0, p / 2) # Expected number of nonzero coefficients + sigma_est = 1.0 + n = X.shape[0] + tau0 = (p0 / (p - p0)) * (sigma_est / np.sqrt(n)) + else: + # Fallback if no data provided + tau0 = 1.0 / np.sqrt(n_params) + + return { + "tau0": tau0, + "nu": 3, + "c2_alpha": 2, + "c2_beta": 2, + } + + else: # normal + return { + "mu": 0, + "sigma": 1, + } + + def create_prior( + self, + name: str, + n_params: int, + dims: Optional[Union[str, tuple]] = None, + X: Optional[np.ndarray] = None, + hyperparams: Optional[Dict[str, Any]] = None, + ) -> Union[pm.Deterministic, pm.Distribution]: + """ + Create the specified prior on a coefficient vector. + + This is the main method to use. 
It creates the appropriate prior type + based on the configuration and returns the PyMC variable. + + Parameters + ---------- + name : str + Name for the coefficient vector (e.g., 'beta', 'b', 'coef') + n_params : int + Number of parameters (length of coefficient vector) + dims : str or tuple, optional + Dimension name(s) for the coefficient vector + X : array-like, optional + Design matrix for computing data-adaptive hyperparameters + (used only for horseshoe priors) + hyperparams : dict, optional + Override default hyperparameters for this specific prior instance + + Returns + ------- + PyMC variable + The coefficient vector with the specified prior + + Example + ------- + >>> import pymc as pm + >>> import pandas as pd + >>> from causalpy.variable_selection_priors import VariableSelectionPrior + >>> vs_prior = VariableSelectionPrior("spike_and_slab") + >>> coords = {"features": ["a", "b", "c", "d", "e"]} + >>> with pm.Model(coords=coords) as model: + ... beta = vs_prior.create_prior("beta", n_params=4, dims="features") + """ + # Merge instance and call-specific hyperparameters + default_hp = self._get_default_hyperparams(n_params, X) + merged_hp = {**default_hp, **self.hyperparams} + if hyperparams: + merged_hp.update(hyperparams) + + # Normalize dims + if isinstance(dims, str): + dims = (dims,) + + # Create the appropriate prior + if self.prior_type == "spike_and_slab": + self._prior_instance = SpikeAndSlabPrior( + pi_alpha=merged_hp["pi_alpha"], + pi_beta=merged_hp["pi_beta"], + slab_sigma=merged_hp["slab_sigma"], + temperature=merged_hp["temperature"], + dims=dims, + ) # type: ignore[assignment] + return self._prior_instance.create_variable(name) # type: ignore[attr-defined] + + elif self.prior_type == "horseshoe": + self._prior_instance = HorseshoePrior( + tau0=merged_hp["tau0"], + nu=merged_hp["nu"], + c2_alpha=merged_hp["c2_alpha"], + c2_beta=merged_hp["c2_beta"], + dims=dims, + ) # type: ignore[assignment] + return 
self._prior_instance.create_variable(name) # type: ignore[attr-defined] + + else: # normal + # Use Prior class directly for normal + normal_prior = Prior( + "Normal", mu=merged_hp["mu"], sigma=merged_hp["sigma"], dims=dims + ) + return normal_prior.create_variable(name) + + def get_inclusion_probabilities( + self, idata, param_name: str, threshold: float = 0.5 + ) -> pd.DataFrame: + """ + Extract variable inclusion probabilities from fitted model. + + Only applicable for spike-and-slab priors. Returns the posterior + probability that each coefficient is "selected" (non-zero). + + Parameters + ---------- + idata : arviz.InferenceData + Fitted model inference data + param_name : str + Name of the coefficient parameter (must match name in create_prior) + threshold : float, default=0.5 + Threshold for considering a variable "selected" + + Returns + ------- + dict + Dictionary with keys: + - 'probabilities': Array of inclusion probabilities per coefficient + - 'selected': Boolean array indicating which are selected + - 'gamma_mean': Mean of gamma (indicator) variables + + Raises + ------ + ValueError + If prior_type is not 'spike_and_slab' or gamma variables not found + + """ + if self.prior_type != "spike_and_slab": + raise ValueError( + "Inclusion probabilities only available for 'spike_and_slab' priors" + ) + + gamma_name = f"gamma_{param_name}" + + if gamma_name not in idata.posterior: + raise ValueError( + f"Could not find '{gamma_name}' in posterior. " + f"Make sure you used the correct parameter name." 
+ ) + + import arviz as az + + # Extract gamma values + gamma = az.extract(idata.posterior[gamma_name]) + + # Compute inclusion probabilities + probabilities = (gamma > threshold).mean(dim="sample").to_array() + gamma_mean = gamma.mean(dim="sample").to_array() + selected = probabilities > threshold + + summary = { + "probabilities": probabilities, + "selected": selected, + "gamma_mean": gamma_mean, + } + probs = summary["probabilities"].T + df = pd.DataFrame(index=list(range(len(probs)))) + + df["prob"] = probs + df["selected"] = summary["selected"].T + df["gamma_mean"] = summary["gamma_mean"].T + return df + + def get_shrinkage_factors(self, idata, param_name: str) -> pd.DataFrame: + """ + Extract shrinkage factors from horseshoe prior. + + Only applicable for horseshoe priors. Returns the effective shrinkage + applied to each coefficient: κ_j = τ · λ̃_j + + Parameters + ---------- + idata : arviz.InferenceData + Fitted model inference data + param_name : str + Name of the coefficient parameter + + Returns + ------- + dict + Dictionary with keys: + - 'shrinkage_factors': Array of shrinkage factors per coefficient + - 'tau': Global shrinkage parameter + - 'lambda_tilde': Regularized local shrinkage parameters + + Raises + ------ + ValueError + If prior_type is not 'horseshoe' or required variables not found + + """ + if self.prior_type != "horseshoe": + raise ValueError("Shrinkage factors only available for 'horseshoe' priors") + + import arviz as az + + tau_name = f"tau_{param_name}" + lambda_tilde_name = f"lambda_tilde_{param_name}" + + if tau_name not in idata.posterior: + raise ValueError(f"Could not find '{tau_name}' in posterior") + if lambda_tilde_name not in idata.posterior: + raise ValueError(f"Could not find '{lambda_tilde_name}' in posterior") + + # Extract components + tau = az.extract(idata.posterior[tau_name]).to_array() + lambda_tilde = az.extract(idata.posterior[lambda_tilde_name]).to_array() + + shrinkage_factor = np.array( + [tau[0, i] * 
lambda_tilde[0, :, :] for i in range(len(tau))] + ) + shrinkage_factor = shrinkage_factor.mean(axis=2) + + summary = { + "shrinkage_factors": shrinkage_factor, + "tau": tau.mean(), + "lambda_tilde": lambda_tilde.mean(dim=("sample")), + } + probs = summary["shrinkage_factors"].T + df = pd.DataFrame(index=list(range(len(probs)))) + df["shrinkage_factor"] = probs + + df["lambda_tilde"] = summary["lambda_tilde"].T + df["tau"] = np.mean(tau).item() + return df + + +def create_variable_selection_prior( + prior_type: str, + name: str, + n_params: int, + dims: Optional[Union[str, tuple]] = None, + X: Optional[np.ndarray] = None, + hyperparams: Optional[Dict[str, Any]] = None, +) -> Union[pm.Deterministic, pm.Distribution]: + """ + Convenience function to create a variable selection prior in one call. + + This is a shorthand for creating a VariableSelectionPrior instance and + calling create_prior() in one step. + + Parameters + ---------- + prior_type : str + Type of prior: 'spike_and_slab', 'horseshoe', or 'normal' + name : str + Name for the coefficient vector + n_params : int + Number of parameters + dims : str or tuple, optional + Dimension name(s) + X : array-like, optional + Design matrix for data-adaptive hyperparameters + hyperparams : dict, optional + Custom hyperparameters + + Returns + ------- + PyMC variable + The coefficient vector with specified prior + + """ + vs_prior = VariableSelectionPrior(prior_type, hyperparams) + return vs_prior.create_prior(name, n_params, dims, X) diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index 4704ef6c..4ed4f3af 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,10 +1,10 @@ - interrogate: 95.5% + interrogate: 94.6% - + @@ -12,8 +12,8 @@ interrogate interrogate - 95.5% - 95.5% + 94.6% + 94.6% diff --git a/docs/source/notebooks/index.md b/docs/source/notebooks/index.md index c9ae0ad7..be879e33 100644 --- 
a/docs/source/notebooks/index.md +++ b/docs/source/notebooks/index.md @@ -65,6 +65,7 @@ rkink_pymc.ipynb iv_pymc.ipynb iv_weak_instruments.ipynb +iv_vs_priors.ipynb ::: :::{toctree} diff --git a/docs/source/notebooks/iv_vs_priors.ipynb b/docs/source/notebooks/iv_vs_priors.ipynb new file mode 100644 index 00000000..a3faa0c1 --- /dev/null +++ b/docs/source/notebooks/iv_vs_priors.ipynb @@ -0,0 +1,2219 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "532c6736", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n" + ] + } + ], + "source": [ + "import arviz as az\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "import pymc as pm\n", + "\n", + "import causalpy as cp\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "id": "b1b3aa75", + "metadata": {}, + "source": [ + "## Variable Selection Priors and Instrumental Variable Designs\n", + "\n", + "When building causal inference models, we often face a dilemma: we want to control for confounders to get unbiased causal estimates, but we're not always certain which variables are the true confounders. Include too few, and we risk omitted variable bias. Include too many, and we introduce noise that inflates our uncertainty or, worse, creates multicollinearity that destabilizes our estimates.\n", + "\n", + "Traditional approaches force us to make hard choices upfront—which variables to include, which to exclude. Ideally, these choices should be driven by theory. But what if we could let the data help us make these decisions while still maintaining the principled probabilistic framework of Bayesian inference? This is where variable selection priors come in. Let's first simulate some data with some natural confounding structure. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "046aa8e0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Y_contY_binT_contT_binalphafeature_0feature_1feature_2feature_3feature_4...feature_8feature_9feature_10feature_11feature_12feature_13Y_cont_scaledY_bin_scaledT_cont_scaledT_bin_scaled
03.169837-0.1703461.11339409.2361021.2944410.4182410.536286-0.615573-1.173784...0.5593931.111766-0.2160690.451496-0.8631890.3191800.268475-0.4383620.347809-1.022452
110.4790496.6629902.27202013.787487-0.8850050.3150130.8101381.1372140.203685...-0.0171790.4108180.6700741.9549440.888255-0.5065180.9166371.3210110.7260140.977650
27.3078214.1189822.06294613.311157-1.0755371.0585361.9089441.1229140.611691...0.0506411.245489-1.070642-0.060250-1.8576380.8069130.6354210.6660080.6577670.977650
39.7813600.6496474.04390418.4234500.9693800.6981990.3144191.446987-2.729092...-1.0524291.143079-1.757701-1.167276-0.1646201.6870040.854768-0.2272391.3044020.977650
45.7392836.8120300.64241716.6139220.245569-0.3384910.814305-0.8597980.334968...-0.5476220.7787240.6789561.715229-0.439130-0.2388470.4963271.3593850.1940700.977650
..................................................................
2495-5.099912-0.870746-1.40972206.5656300.2262521.5896530.056005-0.386026-0.462251...0.9876330.246870-0.2029170.1785790.7631860.527462-0.464865-0.618693-0.475800-1.022452
2496-32.742858-7.337551-8.46843504.760520-0.4957920.5460020.2090720.6666140.400847...-0.798775-0.6164830.4315521.2389570.957759-0.583051-2.916171-2.283696-2.779944-1.022452
24976.7598041.9120401.615921010.4459341.7783730.097808-0.8076580.380358-0.455391...0.6680401.4719630.573966-0.288768-0.8610250.3726570.5868240.0977880.511847-1.022452
2498-11.249395-1.938808-3.10352905.715321-0.1138720.7474801.635159-1.136585-0.007239...-2.404202-2.0745700.0228781.3450180.7053610.414329-1.010186-0.893686-1.028702-1.022452
249921.6582587.6754015.66095215.741192-0.103523-1.0838280.8968270.1462431.363973...-1.310958-0.0042240.2066200.012787-1.7777830.3401761.9079811.5816761.8322470.977650
\n", + "

2500 rows × 23 columns

\n", + "
" + ], + "text/plain": [ + " Y_cont Y_bin T_cont T_bin alpha feature_0 feature_1 \\\n", + "0 3.169837 -0.170346 1.113394 0 9.236102 1.294441 0.418241 \n", + "1 10.479049 6.662990 2.272020 1 3.787487 -0.885005 0.315013 \n", + "2 7.307821 4.118982 2.062946 1 3.311157 -1.075537 1.058536 \n", + "3 9.781360 0.649647 4.043904 1 8.423450 0.969380 0.698199 \n", + "4 5.739283 6.812030 0.642417 1 6.613922 0.245569 -0.338491 \n", + "... ... ... ... ... ... ... ... \n", + "2495 -5.099912 -0.870746 -1.409722 0 6.565630 0.226252 1.589653 \n", + "2496 -32.742858 -7.337551 -8.468435 0 4.760520 -0.495792 0.546002 \n", + "2497 6.759804 1.912040 1.615921 0 10.445934 1.778373 0.097808 \n", + "2498 -11.249395 -1.938808 -3.103529 0 5.715321 -0.113872 0.747480 \n", + "2499 21.658258 7.675401 5.660952 1 5.741192 -0.103523 -1.083828 \n", + "\n", + " feature_2 feature_3 feature_4 ... feature_8 feature_9 feature_10 \\\n", + "0 0.536286 -0.615573 -1.173784 ... 0.559393 1.111766 -0.216069 \n", + "1 0.810138 1.137214 0.203685 ... -0.017179 0.410818 0.670074 \n", + "2 1.908944 1.122914 0.611691 ... 0.050641 1.245489 -1.070642 \n", + "3 0.314419 1.446987 -2.729092 ... -1.052429 1.143079 -1.757701 \n", + "4 0.814305 -0.859798 0.334968 ... -0.547622 0.778724 0.678956 \n", + "... ... ... ... ... ... ... ... \n", + "2495 0.056005 -0.386026 -0.462251 ... 0.987633 0.246870 -0.202917 \n", + "2496 0.209072 0.666614 0.400847 ... -0.798775 -0.616483 0.431552 \n", + "2497 -0.807658 0.380358 -0.455391 ... 0.668040 1.471963 0.573966 \n", + "2498 1.635159 -1.136585 -0.007239 ... -2.404202 -2.074570 0.022878 \n", + "2499 0.896827 0.146243 1.363973 ... 
-1.310958 -0.004224 0.206620 \n", + "\n", + " feature_11 feature_12 feature_13 Y_cont_scaled Y_bin_scaled \\\n", + "0 0.451496 -0.863189 0.319180 0.268475 -0.438362 \n", + "1 1.954944 0.888255 -0.506518 0.916637 1.321011 \n", + "2 -0.060250 -1.857638 0.806913 0.635421 0.666008 \n", + "3 -1.167276 -0.164620 1.687004 0.854768 -0.227239 \n", + "4 1.715229 -0.439130 -0.238847 0.496327 1.359385 \n", + "... ... ... ... ... ... \n", + "2495 0.178579 0.763186 0.527462 -0.464865 -0.618693 \n", + "2496 1.238957 0.957759 -0.583051 -2.916171 -2.283696 \n", + "2497 -0.288768 -0.861025 0.372657 0.586824 0.097788 \n", + "2498 1.345018 0.705361 0.414329 -1.010186 -0.893686 \n", + "2499 0.012787 -1.777783 0.340176 1.907981 1.581676 \n", + "\n", + " T_cont_scaled T_bin_scaled \n", + "0 0.347809 -1.022452 \n", + "1 0.726014 0.977650 \n", + "2 0.657767 0.977650 \n", + "3 1.304402 0.977650 \n", + "4 0.194070 0.977650 \n", + "... ... ... \n", + "2495 -0.475800 -1.022452 \n", + "2496 -2.779944 -1.022452 \n", + "2497 0.511847 -1.022452 \n", + "2498 -1.028702 -1.022452 \n", + "2499 1.832247 0.977650 \n", + "\n", + "[2500 rows x 23 columns]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def inv_logit(z):\n", + " \"\"\"Compute the inverse logit (sigmoid) of z.\"\"\"\n", + " return 1 / (1 + np.exp(-z))\n", + "\n", + "\n", + "def simulate_data(n=2500, alpha_true=3.0, rho=0.6, cate_estimation=False):\n", + " # Exclusion restrictions:\n", + " # X[0], X[1] affect both Y and T (confounders)\n", + " # X[2], X[3] affect ONLY T (instruments for T)\n", + " # X[4] affects ONLY Y (predictor of Y only)\n", + "\n", + " betaY = np.array(\n", + " [0.5, -0.3, 0.0, 0.0, 0.4, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n", + " ) # X[2], X[3] excluded\n", + " betaD = np.array(\n", + " [0.7, 0.1, -0.4, 0.3, 0.0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n", + " ) # X[4] excluded\n", + " p = len(betaY)\n", + "\n", + " # noise variances and correlation\n", + " sigma_U = 3.0\n", + " 
sigma_V = 3.0\n", + "\n", + " # design matrix (n × p) with mean-zero columns\n", + " X = np.random.normal(size=(n, p))\n", + " X = (X - X.mean(axis=0)) / X.std(axis=0)\n", + "\n", + " mean = [0, 0]\n", + " cov = [[sigma_U**2, rho * sigma_U * sigma_V], [rho * sigma_U * sigma_V, sigma_V**2]]\n", + " errors = np.random.multivariate_normal(mean, cov, size=n)\n", + " U = errors[:, 0] # error in outcome equation\n", + " V = errors[:, 1] #\n", + "\n", + " # continuous treatment\n", + " T_cont = X @ betaD + V\n", + "\n", + " # latent variable for binary treatment\n", + " T_latent = X @ betaD + V\n", + " T_bin = np.random.binomial(n=1, p=inv_logit(T_latent), size=n)\n", + "\n", + " alpha_individual = 3.0 + 2.5 * X[:, 0]\n", + "\n", + " # outcomes\n", + " Y_cont = alpha_true * T_cont + X @ betaY + U\n", + " if cate_estimation:\n", + " Y_bin = alpha_individual * T_bin + X @ betaY + U\n", + " else:\n", + " Y_bin = alpha_true * T_bin + X @ betaY + U\n", + "\n", + " # combine into DataFrame\n", + " data = pd.DataFrame(\n", + " {\n", + " \"Y_cont\": Y_cont,\n", + " \"Y_bin\": Y_bin,\n", + " \"T_cont\": T_cont,\n", + " \"T_bin\": T_bin,\n", + " }\n", + " )\n", + " data[\"alpha\"] = alpha_true + alpha_individual\n", + " for j in range(p):\n", + " data[f\"feature_{j}\"] = X[:, j]\n", + " data[\"Y_cont_scaled\"] = (data[\"Y_cont\"] - data[\"Y_cont\"].mean()) / data[\n", + " \"Y_cont\"\n", + " ].std(ddof=1)\n", + " data[\"Y_bin_scaled\"] = (data[\"Y_bin\"] - data[\"Y_bin\"].mean()) / data[\"Y_bin\"].std(\n", + " ddof=1\n", + " )\n", + " data[\"T_cont_scaled\"] = (data[\"T_cont\"] - data[\"T_cont\"].mean()) / data[\n", + " \"T_cont\"\n", + " ].std(ddof=1)\n", + " data[\"T_bin_scaled\"] = (data[\"T_bin\"] - data[\"T_bin\"].mean()) / data[\"T_bin\"].std(\n", + " ddof=1\n", + " )\n", + " return data\n", + "\n", + "\n", + "data = simulate_data()\n", + "instruments_data = data.copy()\n", + "features = [col for col in data.columns if \"feature\" in col]\n", + "formula = \"Y_cont ~ T_cont + 
\" + \" + \".join(features)\n", + "instruments_formula = \"T_cont ~ 1 + \" + \" + \".join(features)\n", + "data" + ] + }, + { + "cell_type": "markdown", + "id": "e2472e18", + "metadata": {}, + "source": [ + "CausalPy's `Variable Selection` module provides a way to encode our uncertainty about variable relevance directly into the prior distribution. Rather than choosing which predictors to include, we specify priors that allow coefficients to be shrunk toward zero (or exactly zero) when the data doesn't support their inclusion. The key insight is that variable selection becomes part of the inference problem rather than a preprocessing step. The module offers two fundamentally different approaches to variable selection, each reflecting a different belief about how sparsity manifests in the world.\n", + "\n", + "#### The Spike-and-Slab: Discrete Choices\n", + "\n", + "The spike-and-slab prior embodies a binary worldview: each variable either matters or it doesn't. Mathematically, we express this as:\n", + "\n", + "$$ \\beta_{j} = \\gamma_{j} \\cdot \\beta_{j_\\text{raw}}$$\n", + "\n", + "such that \n", + "\n", + "$$ \\gamma_{j} \\in \\{0, 1\\}$$\n", + "\n", + "So we have the \"spike\"—the coefficient is exactly zero. When $\\gamma_{j} = 1$, we have the \"slab\" i.e. the coefficient takes on a value from the raw distribution.\n", + "This approach appeals to our intuition about many real-world scenarios. Consider a propensity score model predicting whether someone receives a treatment. Some demographic variables might genuinely have no relationship with treatment assignment, while others are strongly predictive. The spike-and-slab says: let's let each variable clearly declare itself as relevant or irrelevant.\n", + "\n", + "#### The Regularised Horseshoe: Gentle Moderation\n", + "\n", + "The horseshoe prior takes a different philosophical stance. 
Instead of discrete selection, it says: effects exist on a continuum from negligible to substantial, and we should shrink them proportionally to their signal strength. Small effects get heavily shrunk (possibly to near-zero), while large effects escape shrinkage almost entirely.\n", + "\n", + "$$ \\beta_{j} = \\tau \\cdot \\tilde{\\lambda}_j \\cdot \\beta_{j\\text{raw}}$$\n", + "\n", + "where $\\tau$ is a global shrinkage parameter shared across all coefficients, and $\\tilde{\\lambda}_j$ is local or specific to each coefficient and regularised so as to ensure finite variance. \n" + ] + }, + { + "cell_type": "markdown", + "id": "806df6ea", + "metadata": {}, + "source": [ + "### Hyperparameters for Variable Selection Priors\n", + "\n", + "You can control the behaviour of the variable selection priors through some of the hyperparameters available. For the spike and slab prior, the most important hyperparameters are `temperature`, `pi_alpha`, and `pi_beta`. \n", + "\n", + "Because our sampler doesn't like discrete variables, we're approximating a Bernoulli outcome in our sampling to define the spike and slab. The approximation is governed by the `temperature` parameter. The default value of 0.1 works well in most cases, creating indicators that cluster near 0 or 1 without causing sampling difficulties.\n", + "\n", + "The selection probability parameters `pi_alpha` and `pi_beta` encode your prior belief about sparsity. With both set to 2 (the default), you're placing a Beta(2,2) prior on π, the overall proportion of selected variables. This is symmetric around 0.5 but slightly concentrated there—you're saying \"I don't know how many variables are relevant, but probably not all of them and probably not none of them.\"" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "ae848fe9", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, axs = plt.subplots(1, 3, figsize=(20, 6))\n", + "axs = axs.flatten()\n", + "axs[0].hist(pm.draw(pm.Beta.dist(2, 2), 1000), ec=\"black\", color=\"slateblue\")\n", + "axs[1].hist(pm.draw(pm.Beta.dist(2, 5), 1000), ec=\"black\", color=\"slateblue\")\n", + "axs[2].hist(pm.draw(pm.Beta.dist(5, 2), 1000), ec=\"black\", color=\"slateblue\")\n", + "axs[1].set_title(r\"Various Distributions for the $\\pi$ hyperparameter\", size=20);" + ] + }, + { + "cell_type": "markdown", + "id": "3237bb49", + "metadata": {}, + "source": [ + "We'll now fit two models and estimate the implied treatment effect." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "763ca253", + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "--------------------------------------------------------------------------------\n", + "Model 1: Normal Priors (No Variable Selection)\n", + "--------------------------------------------------------------------------------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/nathanielforde/mambaforge/envs/CausalPy/lib/python3.13/site-packages/causalpy/experiments/instrumental_variable.py:187: UserWarning: Warning. 
The treatment variable is not Binary.\n", + " This is not necessarily a problem but it violates\n", + " the assumption of a simple IV experiment.\n", + " The coefficients should be interpreted appropriately.\n", + " warnings.warn(\n", + "Initializing NUTS using jitter+adapt_diag...\n", + "Multiprocess sampling (4 chains in 4 jobs)\n", + "NUTS: [beta_t, beta_z, chol_cov]\n", + "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n", + "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n", + "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n", + "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b332c2837f1849329b3b561a9765f3e5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/nathanielforde/mambaforge/envs/CausalPy/lib/python3.13/site-packages/pytensor/compile/function/types.py:1039: RuntimeWarning: invalid value encountered in accumulate\n", + " outputs = vm() if output_subset is None else vm(output_subset=output_subset)\n", + "/Users/nathanielforde/mambaforge/envs/CausalPy/lib/python3.13/site-packages/pytensor/compile/function/types.py:1039: RuntimeWarning: invalid value encountered in accumulate\n", + " outputs = vm() if output_subset is None else vm(output_subset=output_subset)\n", + "/Users/nathanielforde/mambaforge/envs/CausalPy/lib/python3.13/site-packages/pytensor/compile/function/types.py:1039: RuntimeWarning: invalid value encountered in accumulate\n", + " outputs = vm() if output_subset is None else vm(output_subset=output_subset)\n", + 
"/Users/nathanielforde/mambaforge/envs/CausalPy/lib/python3.13/site-packages/pytensor/compile/function/types.py:1039: RuntimeWarning: invalid value encountered in accumulate\n", + " outputs = vm() if output_subset is None else vm(output_subset=output_subset)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Sampling 4 chains for 2_000 tune and 1_000 draw iterations (8_000 + 4_000 draws total) took 133 seconds.\n",
+      "The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details\n",
+      "/Users/nathanielforde/mambaforge/envs/CausalPy/lib/python3.13/site-packages/causalpy/experiments/instrumental_variable.py:187: UserWarning: Warning. The treatment variable is not Binary.\n",
+      "                This is not necessarily a problem but it violates\n",
+      "                the assumption of a simple IV experiment.\n",
+      "                The coefficients should be interpreted appropriately.\n",
+      "  warnings.warn(\n",
+      "/Users/nathanielforde/mambaforge/envs/CausalPy/lib/python3.13/site-packages/causalpy/pymc_models.py:699: UserWarning: Variable selection priors specified. The 'mus' and 'sigmas' in the priors dict will be ignored for beta coefficients. Only 'eta' and 'lkj_sd' will be used.\n",
+      "  warnings.warn(\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "--------------------------------------------------------------------------------\n",
+      "Model 2: Spike-and-Slab Priors\n",
+      "--------------------------------------------------------------------------------\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Initializing NUTS using jitter+adapt_diag...\n",
+      "Multiprocess sampling (4 chains in 4 jobs)\n",
+      "NUTS: [pi_beta_t, beta_t_raw, gamma_beta_t_u, pi_beta_z, beta_z_raw, gamma_beta_z_u, chol_cov]\n",
+      "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n",
+      "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n",
+      "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n",
+      "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "db312f02efe743c386a4f2b4449f8904",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/nathanielforde/mambaforge/envs/CausalPy/lib/python3.13/site-packages/pytensor/compile/function/types.py:1039: RuntimeWarning: invalid value encountered in accumulate\n",
+      "  outputs = vm() if output_subset is None else vm(output_subset=output_subset)\n",
+      "/Users/nathanielforde/mambaforge/envs/CausalPy/lib/python3.13/site-packages/pytensor/compile/function/types.py:1039: RuntimeWarning: invalid value encountered in accumulate\n",
+      "  outputs = vm() if output_subset is None else vm(output_subset=output_subset)\n",
+      "/Users/nathanielforde/mambaforge/envs/CausalPy/lib/python3.13/site-packages/pytensor/compile/function/types.py:1039: RuntimeWarning: invalid value encountered in accumulate\n",
+      "  outputs = vm() if output_subset is None else vm(output_subset=output_subset)\n",
+      "/Users/nathanielforde/mambaforge/envs/CausalPy/lib/python3.13/site-packages/pytensor/compile/function/types.py:1039: RuntimeWarning: invalid value encountered in accumulate\n",
+      "  outputs = vm() if output_subset is None else vm(output_subset=output_subset)\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Sampling 4 chains for 2_000 tune and 1_000 draw iterations (8_000 + 4_000 draws total) took 551 seconds.\n",
+      "There were 167 divergences after tuning. Increase `target_accept` or reparameterize.\n",
+      "The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details\n",
+      "The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details\n"
+     ]
+    }
+   ],
+   "source": [
+    "sample_kwargs = {\n",
+    "    \"draws\": 1000,\n",
+    "    \"tune\": 2000,\n",
+    "    \"chains\": 4,\n",
+    "    \"cores\": 4,\n",
+    "    \"target_accept\": 0.95,\n",
+    "    \"progressbar\": True,\n",
+    "    \"random_seed\": 42,\n",
+    "    \"mp_ctx\": \"spawn\",\n",
+    "}\n",
+    "\n",
+    "# =========================================================================\n",
+    "# Model 1: Normal priors (no selection)\n",
+    "# =========================================================================\n",
+    "print(\"\\n\" + \"-\" * 80)\n",
+    "print(\"Model 1: Normal Priors (No Variable Selection)\")\n",
+    "print(\"-\" * 80)\n",
+    "\n",
+    "result_normal = cp.InstrumentalVariable(\n",
+    "    instruments_data=instruments_data,\n",
+    "    data=data,\n",
+    "    instruments_formula=instruments_formula,\n",
+    "    formula=formula,\n",
+    "    model=cp.pymc_models.InstrumentalVariableRegression(sample_kwargs=sample_kwargs),\n",
+    "    vs_prior_type=None,  # No variable selection\n",
+    ")\n",
+    "\n",
+    "# =========================================================================\n",
+    "# Model 2: Spike-and-Slab priors\n",
+    "# =========================================================================\n",
+    "print(\"\\n\" + \"-\" * 80)\n",
+    "print(\"Model 2: Spike-and-Slab Priors\")\n",
+    "print(\"-\" * 80)\n",
+    "\n",
+    "result_spike_slab = cp.InstrumentalVariable(\n",
+    "    instruments_data=instruments_data,\n",
+    "    data=data,\n",
+    "    instruments_formula=instruments_formula,\n",
+    "    formula=formula,\n",
+    "    model=cp.pymc_models.InstrumentalVariableRegression(sample_kwargs=sample_kwargs),\n",
+    "    vs_prior_type=\"spike_and_slab\",\n",
+    "    vs_hyperparams={\"pi_alpha\": 2, \"pi_beta\": 2, \"slab_sigma\": 2, \"temperature\": 0.1},\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2ccb6e0b",
+   "metadata": {},
+   "source": [
+    "The two models have quite distinct structures. First, consider the normal IV model without variable selection priors. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "id": "e97a9ca2",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": [
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "clusterinstruments (15)\n",
+       "\n",
+       "instruments (15)\n",
+       "\n",
+       "\n",
+       "clustercovariates (16)\n",
+       "\n",
+       "covariates (16)\n",
+       "\n",
+       "\n",
+       "cluster3\n",
+       "\n",
+       "3\n",
+       "\n",
+       "\n",
+       "cluster2 x 2\n",
+       "\n",
+       "2 x 2\n",
+       "\n",
+       "\n",
+       "cluster2\n",
+       "\n",
+       "2\n",
+       "\n",
+       "\n",
+       "cluster2500\n",
+       "\n",
+       "2500\n",
+       "\n",
+       "\n",
+       "cluster2500 x 2\n",
+       "\n",
+       "2500 x 2\n",
+       "\n",
+       "\n",
+       "\n",
+       "beta_t\n",
+       "\n",
+       "beta_t\n",
+       "~\n",
+       "Normal\n",
+       "\n",
+       "\n",
+       "\n",
+       "mu_t\n",
+       "\n",
+       "mu_t\n",
+       "~\n",
+       "Deterministic\n",
+       "\n",
+       "\n",
+       "\n",
+       "beta_t->mu_t\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "beta_z\n",
+       "\n",
+       "beta_z\n",
+       "~\n",
+       "Normal\n",
+       "\n",
+       "\n",
+       "\n",
+       "mu_y\n",
+       "\n",
+       "mu_y\n",
+       "~\n",
+       "Deterministic\n",
+       "\n",
+       "\n",
+       "\n",
+       "beta_z->mu_y\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "chol_cov\n",
+       "\n",
+       "chol_cov\n",
+       "~\n",
+       "_LKJCholeskyCov\n",
+       "\n",
+       "\n",
+       "\n",
+       "cov\n",
+       "\n",
+       "cov\n",
+       "~\n",
+       "Deterministic\n",
+       "\n",
+       "\n",
+       "\n",
+       "chol_cov->cov\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "chol_cov_corr\n",
+       "\n",
+       "chol_cov_corr\n",
+       "~\n",
+       "Deterministic\n",
+       "\n",
+       "\n",
+       "\n",
+       "chol_cov->chol_cov_corr\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "chol_cov_stds\n",
+       "\n",
+       "chol_cov_stds\n",
+       "~\n",
+       "Deterministic\n",
+       "\n",
+       "\n",
+       "\n",
+       "chol_cov->chol_cov_stds\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "likelihood\n",
+       "\n",
+       "likelihood\n",
+       "~\n",
+       "Multivariate_normal\n",
+       "\n",
+       "\n",
+       "\n",
+       "chol_cov->likelihood\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "mu\n",
+       "\n",
+       "mu\n",
+       "~\n",
+       "Deterministic\n",
+       "\n",
+       "\n",
+       "\n",
+       "mu_y->mu\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "mu_t->mu\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "mu->likelihood\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n"
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 17,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "pm.model_to_graphviz(result_normal.model)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "34f3a1b7",
+   "metadata": {},
+   "source": [
+    "Now compare the structure of the spike and slab model. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "id": "4f8c2685",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": [
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "clusterinstruments (15)\n",
+       "\n",
+       "instruments (15)\n",
+       "\n",
+       "\n",
+       "clustercovariates (16)\n",
+       "\n",
+       "covariates (16)\n",
+       "\n",
+       "\n",
+       "cluster3\n",
+       "\n",
+       "3\n",
+       "\n",
+       "\n",
+       "cluster2 x 2\n",
+       "\n",
+       "2 x 2\n",
+       "\n",
+       "\n",
+       "cluster2\n",
+       "\n",
+       "2\n",
+       "\n",
+       "\n",
+       "cluster2500\n",
+       "\n",
+       "2500\n",
+       "\n",
+       "\n",
+       "cluster2500 x 2\n",
+       "\n",
+       "2500 x 2\n",
+       "\n",
+       "\n",
+       "\n",
+       "pi_beta_z\n",
+       "\n",
+       "pi_beta_z\n",
+       "~\n",
+       "Beta\n",
+       "\n",
+       "\n",
+       "\n",
+       "gamma_beta_z\n",
+       "\n",
+       "gamma_beta_z\n",
+       "~\n",
+       "Deterministic\n",
+       "\n",
+       "\n",
+       "\n",
+       "pi_beta_z->gamma_beta_z\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "pi_beta_t\n",
+       "\n",
+       "pi_beta_t\n",
+       "~\n",
+       "Beta\n",
+       "\n",
+       "\n",
+       "\n",
+       "gamma_beta_t\n",
+       "\n",
+       "gamma_beta_t\n",
+       "~\n",
+       "Deterministic\n",
+       "\n",
+       "\n",
+       "\n",
+       "pi_beta_t->gamma_beta_t\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "beta_t_raw\n",
+       "\n",
+       "beta_t_raw\n",
+       "~\n",
+       "Normal\n",
+       "\n",
+       "\n",
+       "\n",
+       "beta_t\n",
+       "\n",
+       "beta_t\n",
+       "~\n",
+       "Deterministic\n",
+       "\n",
+       "\n",
+       "\n",
+       "beta_t_raw->beta_t\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "mu_t\n",
+       "\n",
+       "mu_t\n",
+       "~\n",
+       "Deterministic\n",
+       "\n",
+       "\n",
+       "\n",
+       "beta_t->mu_t\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "gamma_beta_t->beta_t\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "gamma_beta_t_u\n",
+       "\n",
+       "gamma_beta_t_u\n",
+       "~\n",
+       "Uniform\n",
+       "\n",
+       "\n",
+       "\n",
+       "gamma_beta_t_u->gamma_beta_t\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "gamma_beta_z_u\n",
+       "\n",
+       "gamma_beta_z_u\n",
+       "~\n",
+       "Uniform\n",
+       "\n",
+       "\n",
+       "\n",
+       "gamma_beta_z_u->gamma_beta_z\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "beta_z\n",
+       "\n",
+       "beta_z\n",
+       "~\n",
+       "Deterministic\n",
+       "\n",
+       "\n",
+       "\n",
+       "mu_y\n",
+       "\n",
+       "mu_y\n",
+       "~\n",
+       "Deterministic\n",
+       "\n",
+       "\n",
+       "\n",
+       "beta_z->mu_y\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "gamma_beta_z->beta_z\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "beta_z_raw\n",
+       "\n",
+       "beta_z_raw\n",
+       "~\n",
+       "Normal\n",
+       "\n",
+       "\n",
+       "\n",
+       "beta_z_raw->beta_z\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "chol_cov\n",
+       "\n",
+       "chol_cov\n",
+       "~\n",
+       "_LKJCholeskyCov\n",
+       "\n",
+       "\n",
+       "\n",
+       "cov\n",
+       "\n",
+       "cov\n",
+       "~\n",
+       "Deterministic\n",
+       "\n",
+       "\n",
+       "\n",
+       "chol_cov->cov\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "chol_cov_corr\n",
+       "\n",
+       "chol_cov_corr\n",
+       "~\n",
+       "Deterministic\n",
+       "\n",
+       "\n",
+       "\n",
+       "chol_cov->chol_cov_corr\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "chol_cov_stds\n",
+       "\n",
+       "chol_cov_stds\n",
+       "~\n",
+       "Deterministic\n",
+       "\n",
+       "\n",
+       "\n",
+       "chol_cov->chol_cov_stds\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "likelihood\n",
+       "\n",
+       "likelihood\n",
+       "~\n",
+       "Multivariate_normal\n",
+       "\n",
+       "\n",
+       "\n",
+       "chol_cov->likelihood\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "mu\n",
+       "\n",
+       "mu\n",
+       "~\n",
+       "Deterministic\n",
+       "\n",
+       "\n",
+       "\n",
+       "mu_y->mu\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "mu_t->mu\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "mu->likelihood\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n"
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 18,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "pm.model_to_graphviz(result_spike_slab.model)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "368660c8",
+   "metadata": {},
+   "source": [
+    "Despite seeing some divergences in our spike and slab model, most other sampler health metrics seem healthy."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "id": "0755095c",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "az.plot_energy(result_spike_slab.idata, figsize=(20, 6));" + ] + }, + { + "cell_type": "markdown", + "id": "5bffd8b6", + "metadata": {}, + "source": [ + "And since we know the true data generating conditions we can also assess the derived posterior treatment estimates. " + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "838e0726", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(20, 6))\n", + "az.plot_posterior(\n", + " result_normal.idata,\n", + " var_names=[\"beta_z\"],\n", + " coords={\"covariates\": [\"T_cont\"]},\n", + " ax=ax,\n", + " label=\"Normal\",\n", + ")\n", + "az.plot_posterior(\n", + " result_spike_slab.idata,\n", + " var_names=[\"beta_z\"],\n", + " coords={\"covariates\": [\"T_cont\"]},\n", + " ax=ax,\n", + " color=\"green\",\n", + " label=\"spike and slab\",\n", + ")\n", + "ax.axvline(3, color=\"black\", linestyle=\"--\", label=\"True value\");" + ] + }, + { + "cell_type": "markdown", + "id": "057b4f5d", + "metadata": {}, + "source": [ + "This plot suggests that the spike and slab prior was better able to ignore noise in the process and zero in on the true effect. This will not always work, but it is sensible practice to at least check how sensitive the estimates are to different prior settings. We can observe how aggressively the spike and slab prior worked to cull unwanted variables from each model by comparing the coefficient values across the two models." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "127888b7", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "axs = az.plot_forest(\n", + " [result_spike_slab.idata, result_normal.idata],\n", + " var_names=[\"beta_z\"],\n", + " combined=True,\n", + " model_names=[\"Spike and Slab\", \"Normal\"],\n", + " r_hat=True,\n", + ")\n", + "axs[0].set_title(\"Parameter Comparison Outcome Model \\n Baseline v Spike and Slab\");" + ] + }, + { + "cell_type": "markdown", + "id": "f09b24bf", + "metadata": {}, + "source": [ + "#### The Treatment Model\n", + "\n", + "Variable selection is applied to both the outcome and the treatment model. In this way we calibrate our parameters to the joint patterns of realisations between these two endogenous variables. " + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "acafc928", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "axs = az.plot_forest(\n", + " [result_spike_slab.idata, result_normal.idata],\n", + " var_names=[\"beta_t\"],\n", + " combined=True,\n", + " model_names=[\"Spike and Slab\", \"Normal\"],\n", + " r_hat=True,\n", + ")\n", + "\n", + "axs[0].set_title(\"Parameter Comparison Treatment Model \\n Baseline v Spike and Slab\");" + ] + }, + { + "cell_type": "markdown", + "id": "07f7d95b", + "metadata": {}, + "source": [ + "The spike and slab prior can also output direct inclusion probabilities that can be used for communication regarding which variables were \"selected\" in the process." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f2a0b213", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
probselectedgamma_mean
00.01750False0.020546
11.00000True0.991285
20.64100True0.659357
30.58200True0.616580
40.17625False0.192949
50.06075False0.068297
60.66600True0.702671
70.01000False0.012679
80.01300False0.016342
90.00975False0.012392
100.01700False0.019508
110.01625False0.021217
120.02000False0.024469
130.01825False0.023474
140.02700False0.031636
150.01650False0.020863
\n", + "
" + ], + "text/plain": [ + " prob selected gamma_mean\n", + "0 0.01750 False 0.020546\n", + "1 1.00000 True 0.991285\n", + "2 0.64100 True 0.659357\n", + "3 0.58200 True 0.616580\n", + "4 0.17625 False 0.192949\n", + "5 0.06075 False 0.068297\n", + "6 0.66600 True 0.702671\n", + "7 0.01000 False 0.012679\n", + "8 0.01300 False 0.016342\n", + "9 0.00975 False 0.012392\n", + "10 0.01700 False 0.019508\n", + "11 0.01625 False 0.021217\n", + "12 0.02000 False 0.024469\n", + "13 0.01825 False 0.023474\n", + "14 0.02700 False 0.031636\n", + "15 0.01650 False 0.020863" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "summary = result_spike_slab.model.vs_prior_outcome.get_inclusion_probabilities(\n", + " result_spike_slab.idata, \"beta_z\"\n", + ")\n", + "summary" + ] + }, + { + "cell_type": "markdown", + "id": "38568d27", + "metadata": {}, + "source": [ + "### Horseshoe\n", + "\n", + "The horseshoe prior is a sophisticated continuous shrinkage method designed for regularization and variable selection in high-dimensional regression settings. Unlike discrete selection approaches, it operates through a elegant hierarchical structure that adaptively shrinks coefficients based on the strength of their signal in the data. The key to the implementation is the hierarchical $\\lambda$ component: \n", + "\n", + "$$ \\tilde{\\lambda}_j = \\sqrt{\\frac{c^2 \\lambda_j^2}{c^2 + \\tau^2 \\lambda_j^2}} $$\n", + "\n", + "is composed of individual local shrinkage parameters and $c^2$ is a regularization parameter that prevents over-shrinkage of genuinely large signals. \n", + "\n", + "#### The $\\tau_0$ hyperparameter\n", + "\n", + "Like the `temperature` parameter in the spike and slab model, the $\\tau_0$ parameter determines the overall level of sparsity expected in the model. However, the $tau_0$ will by default be derived from the data and the number of covariates in your data. 
While both the horseshoe and spike-and-slab priors address variable selection and sparsity, they embody fundamentally different philosophies about how to achieve these goals. The horseshoe embraces continuity, creating a smooth gradient of shrinkage where all coefficients remain in the model but are pulled toward zero with varying intensity. " + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "16bb5f90", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, axs = plt.subplots(1, 3, figsize=(20, 6))\n", + "axs = axs.flatten()\n", + "axs[0].hist(\n", + " pm.draw(pm.InverseGamma.dist(2, 2), 1000) ** 2,\n", + " ec=\"black\",\n", + " color=\"slateblue\",\n", + " bins=30,\n", + ")\n", + "axs[1].hist(\n", + " pm.draw(pm.InverseGamma.dist(3, 3), 1000) ** 2,\n", + " ec=\"black\",\n", + " color=\"slateblue\",\n", + " bins=30,\n", + ")\n", + "axs[2].hist(\n", + " pm.draw(pm.InverseGamma.dist(4, 4), 1000) ** 2,\n", + " ec=\"black\",\n", + " color=\"slateblue\",\n", + " bins=30,\n", + ")\n", + "axs[1].set_title(r\"Various Distributions for the $c^{2}$ hyperparameter\", size=20);" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "63edfa4e", + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "--------------------------------------------------------------------------------\n", + "Model 3: Horseshoe Priors\n", + "--------------------------------------------------------------------------------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/nathanielforde/mambaforge/envs/CausalPy/lib/python3.13/site-packages/causalpy/experiments/instrumental_variable.py:187: UserWarning: Warning. The treatment variable is not Binary.\n", + " This is not necessarily a problem but it violates\n", + " the assumption of a simple IV experiment.\n", + " The coefficients should be interpreted appropriately.\n", + " warnings.warn(\n", + "/Users/nathanielforde/mambaforge/envs/CausalPy/lib/python3.13/site-packages/causalpy/pymc_models.py:699: UserWarning: Variable selection priors specified. The 'mus' and 'sigmas' in the priors dict will be ignored for beta coefficients. 
Only 'eta' and 'lkj_sd' will be used.\n", + " warnings.warn(\n", + "Initializing NUTS using jitter+adapt_diag...\n", + "Multiprocess sampling (4 chains in 4 jobs)\n", + "NUTS: [tau_beta_t, lambda_beta_t, c2_beta_t, beta_t_raw, tau_beta_z, lambda_beta_z, c2_beta_z, beta_z_raw, chol_cov]\n", + "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n", + "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n", + "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n", + "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1410b52c995348e397aed1a2843f2bb7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/nathanielforde/mambaforge/envs/CausalPy/lib/python3.13/site-packages/pytensor/compile/function/types.py:1039: RuntimeWarning: invalid value encountered in accumulate\n", + " outputs = vm() if output_subset is None else vm(output_subset=output_subset)\n", + "/Users/nathanielforde/mambaforge/envs/CausalPy/lib/python3.13/site-packages/pytensor/compile/function/types.py:1039: RuntimeWarning: invalid value encountered in accumulate\n", + " outputs = vm() if output_subset is None else vm(output_subset=output_subset)\n", + "/Users/nathanielforde/mambaforge/envs/CausalPy/lib/python3.13/site-packages/pytensor/compile/function/types.py:1039: RuntimeWarning: invalid value encountered in accumulate\n", + " outputs = vm() if output_subset is None else vm(output_subset=output_subset)\n", + "/Users/nathanielforde/mambaforge/envs/CausalPy/lib/python3.13/site-packages/pytensor/compile/function/types.py:1039: RuntimeWarning: invalid value encountered in 
accumulate\n", + " outputs = vm() if output_subset is None else vm(output_subset=output_subset)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Sampling 4 chains for 2_000 tune and 1_000 draw iterations (8_000 + 4_000 draws total) took 534 seconds.\n",
+      "There were 16 divergences after tuning. Increase `target_accept` or reparameterize.\n",
+      "The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details\n"
+     ]
+    }
+   ],
+   "source": [
+    "# =========================================================================\n",
+    "# Model 3: Horseshoe priors\n",
+    "# =========================================================================\n",
+    "print(\"\\n\" + \"-\" * 80)\n",
+    "print(\"Model 3: Horseshoe Priors\")\n",
+    "print(\"-\" * 80)\n",
+    "\n",
+    "result_horseshoe = cp.InstrumentalVariable(\n",
+    "    instruments_data=instruments_data,\n",
+    "    data=data,\n",
+    "    instruments_formula=instruments_formula,\n",
+    "    formula=formula,\n",
+    "    model=cp.pymc_models.InstrumentalVariableRegression(sample_kwargs=sample_kwargs),\n",
+    "    vs_prior_type=\"horseshoe\",\n",
+    "    vs_hyperparams={\"c2_alpha\": 3, \"c2_beta\": 3},\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "db7d86aa",
+   "metadata": {},
+   "source": [
+    "Similar to the inclusion probabilities in the spike and slab model, a horseshoe model can output the relative shrinkage factor that gets applied to each variable's inclusion. This method of variable selection is less decisive than spike and slab, but it also mitigates the risk of completely zeroing out the small but real contributions of certain variables."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 41,
+   "id": "9c283ee1",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
shrinkage_factorlambda_tildetau
00.0147801.2225240.036811
10.69773957.7139610.036811
20.16028213.2577980.036811
30.12964910.7240230.036811
40.0766766.3422980.036811
50.0256502.1216800.036811
60.19018615.7313980.036811
70.0141981.1743820.036811
80.0146361.2106340.036811
90.0140241.1600020.036811
100.0144411.1944850.036811
110.0189071.5639340.036811
120.0185181.5317340.036811
130.0174151.4405140.036811
140.0192381.5912910.036811
150.0174811.4459460.036811
\n", + "
" + ], + "text/plain": [ + " shrinkage_factor lambda_tilde tau\n", + "0 0.014780 1.222524 0.036811\n", + "1 0.697739 57.713961 0.036811\n", + "2 0.160282 13.257798 0.036811\n", + "3 0.129649 10.724023 0.036811\n", + "4 0.076676 6.342298 0.036811\n", + "5 0.025650 2.121680 0.036811\n", + "6 0.190186 15.731398 0.036811\n", + "7 0.014198 1.174382 0.036811\n", + "8 0.014636 1.210634 0.036811\n", + "9 0.014024 1.160002 0.036811\n", + "10 0.014441 1.194485 0.036811\n", + "11 0.018907 1.563934 0.036811\n", + "12 0.018518 1.531734 0.036811\n", + "13 0.017415 1.440514 0.036811\n", + "14 0.019238 1.591291 0.036811\n", + "15 0.017481 1.445946 0.036811" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "summary = result_horseshoe.model.vs_prior_outcome.get_shrinkage_factors(\n", + " result_horseshoe.idata, \"beta_z\"\n", + ")\n", + "summary" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "82b0121c", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(20, 6))\n", + "az.plot_posterior(\n", + " result_normal.idata, var_names=[\"beta_z\"], coords={\"covariates\": [\"T_cont\"]}, ax=ax\n", + ")\n", + "az.plot_posterior(\n", + " result_horseshoe.idata,\n", + " var_names=[\"beta_z\"],\n", + " coords={\"covariates\": [\"T_cont\"]},\n", + " ax=ax,\n", + " color=\"green\",\n", + ")\n", + "ax.axvline(3, color=\"black\", linestyle=\"--\");" + ] + }, + { + "cell_type": "markdown", + "id": "e15d4f1e", + "metadata": {}, + "source": [ + "In this case it seems the horseshoe prior leads a bi-modal posterior estimate of the treatment effect suggesting a kind of indecision about the level of sparsity to apply. " + ] + }, + { + "cell_type": "markdown", + "id": "fc265f5d", + "metadata": {}, + "source": [ + "### Binary Treatment Case\n", + "\n", + "Our data generating function output two different simulation scenarios, where the treatment was either continuous or binary. This allows us to demonstrate the joint modelling of the binary treatment outcome which uses a Bernoulli likelihood for the treatment variable and latent confounding to model the joint realisation of treatment and outcome. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "89e61d28", + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "--------------------------------------------------------------------------------\n", + "Model 1: Normal Priors Binary Treatment (No Variable Selection)\n", + "--------------------------------------------------------------------------------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Initializing NUTS using jitter+adapt_diag...\n", + "Multiprocess sampling (4 chains in 4 jobs)\n", + "NUTS: [beta_t, beta_z, sigma_U, rho_unconstr, eps_raw]\n", + "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n", + "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n", + "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n", + "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c7276bf767824813a312c4c3ccff7c01", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Sampling 4 chains for 2_000 tune and 1_000 draw iterations (8_000 + 4_000 draws total) took 68 seconds.\n",
+      "The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details\n",
+      "The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details\n"
+     ]
+    }
+   ],
+   "source": [
+    "formula = \"Y_bin ~ T_bin + \" + \" + \".join(features)\n",
+    "instruments_formula = \"T_bin ~ 1 + \" + \" + \".join(features)\n",
+    "\n",
+    "\n",
+    "# =========================================================================\n",
+    "# Model 1: Normal priors (no selection)\n",
+    "# =========================================================================\n",
+    "print(\"\\n\" + \"-\" * 80)\n",
+    "print(\"Model 1: Normal Priors Binary Treatment (No Variable Selection)\")\n",
+    "print(\"-\" * 80)\n",
+    "\n",
+    "result_normal_binary = cp.InstrumentalVariable(\n",
+    "    instruments_data=instruments_data,\n",
+    "    data=data,\n",
+    "    instruments_formula=instruments_formula,\n",
+    "    formula=formula,\n",
+    "    model=cp.pymc_models.InstrumentalVariableRegression(sample_kwargs=sample_kwargs),\n",
+    "    vs_prior_type=None,  # No variable selection\n",
+    "    binary_treatment=True,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 46,
+   "id": "98c1b50a",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(20, 6))\n", + "az.plot_posterior(\n", + " result_normal.idata, var_names=[\"beta_z\"], coords={\"covariates\": [\"T_cont\"]}, ax=ax\n", + ")\n", + "az.plot_posterior(\n", + " result_normal_binary.idata,\n", + " var_names=[\"beta_z\"],\n", + " coords={\"covariates\": [\"T_bin\"]},\n", + " ax=ax,\n", + " color=\"green\",\n", + ")\n", + "ax.axvline(3, color=\"black\", linestyle=\"--\");" + ] + }, + { + "cell_type": "markdown", + "id": "36ad018b", + "metadata": {}, + "source": [ + "\n", + "### Conclusion: Choosing Your Path Through Uncertainty\n", + "\n", + "Variable selection priors offer a principled way to navigate the tension between model complexity and causal identification. Rather than forcing binary decisions about which variables to include, these priors encode our uncertainty about variable relevance directly into the inferential framework. But as we've seen, the choice between spike-and-slab and horseshoe reflects deeper commitments about how sparsity manifests in the world.\n", + "\n", + "**The spike-and-slab prior** embodies decisiveness. It asks: which variables truly matter? By pushing coefficients toward exactly zero or allowing them to take on substantial values, it produces interpretable inclusion probabilities that clearly communicate which predictors the model has \"selected.\" This approach shines when you believe that many potential confounders are genuine noise—included out of caution but ultimately irrelevant to the causal mechanism. The discrete nature of selection also makes results easier to communicate to stakeholders who think in terms of \"what factors matter?\" \n", + "\n", + "**The horseshoe prior** embraces nuance. It acknowledges that effects exist on a continuum, and that small but real contributions shouldn't be entirely zeroed out. 
The continuous shrinkage allows weak signals to persist (heavily damped) while strong signals emerge largely unscathed. This is valuable when you suspect that multiple confounders have genuine but varying degrees of influence, and when premature exclusion of any single variable might introduce bias. The regularization parameter $c^2$ acts as a safeguard, preventing even the horseshoe's aggressive shrinkage from overwhelming genuinely large effects.\n", + " \n", + "In our simulations, both approaches identified the true treatment effect of 3, though they arrived there differently. The spike-and-slab showed more confidence, producing tighter posterior intervals by decisively excluding noise variables. The horseshoe's bi-modal posterior in some specifications revealed its uncertainty about the appropriate level of sparsity—a kind of probabilistic humility that spike-and-slab's discrete choices don't allow.\n", + "\n", + "**Practical Guidance:**\n", + " \n", + "- **Use spike-and-slab when** you have strong priors about sparsity (many potential confounders, few true ones), when interpretability matters (stakeholders want to know \"what's included?\"), or when you're willing to trade some flexibility for more decisive inference.\n", + " \n", + "- **Use horseshoe when** you're uncertain about sparsity levels, when small effects might still matter for causal identification, or when you want the model to smoothly adapt its shrinkage to the data without hard inclusion/exclusion decisions.\n", + " \n", + "- **Use neither when** theory clearly identifies your confounders, when sample size is large relative to the number of predictors, or when the cost of Type I errors (including irrelevant variables) is low relative to Type II errors (excluding true confounders).\n", + " \n", + "**Final Thoughts:**\n", + " \n", + "Variable selection priors don't eliminate the need for causal reasoning. 
They don't tell you which variables are *causally* relevant, only which are *statistically* predictive. But when used thoughtfully—guided by theory about potential confounders, informed by domain knowledge about likely sparsity patterns, and validated through sensitivity analysis—they offer a middle path between the Scylla of over-specification (including everything) and the Charybdis of under-specification (excluding too much). Used within a joint model of treatment and outcome variable, the argument of a variable selection routine represents an attempt to calibrate the parameters to select the instrument structure. What variable selection is really doing in joint treatment-outcome models is calibrating the parameters to discover patterns consistent with instrument structure *if such structure exists in the data*. The horseshoe shrinks away coefficients that appear redundant given the covariance structure between treatment, outcome, and covariates. The spike-and-slab actively excludes variables that don't contribute to explaining either margin after accounting for shared variation.\n", + " \n", + "The ideal use of variable selection in instrumental variable designs is not as a replacement for domain knowledge but as a consistency check. The real power of these methods lies not in automation but in transparency. By making variable selection part of the posterior distribution rather than a pre-processing step, we can quantify and communicate our uncertainty about model structure itself. This moves us closer to the goal of all principled causal inference: not just estimating effects, but understanding the limits of what we can learn from the data we have.\n", + " \n", + "As always in causal inference, the model is a question posed to the data. 
Variable selection priors help us ask that question more precisely, but we still need theory to tell us if we're asking the right question at all.\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "CausalPy", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}