From 89636f8e6b18e578bf8b396e1435c8a163d141bf Mon Sep 17 00:00:00 2001 From: "Benjamin T. Vincent" Date: Mon, 27 Oct 2025 15:19:00 +0000 Subject: [PATCH 1/2] initial bug fix + regression test --- Makefile | 2 +- causalpy/plot_utils.py | 35 ++++++++--- causalpy/tests/test_plot_utils.py | 97 +++++++++++++++++++++++++++++++ 3 files changed, 126 insertions(+), 8 deletions(-) create mode 100644 causalpy/tests/test_plot_utils.py diff --git a/Makefile b/Makefile index d109ae39..d6670904 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ doctest: pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py test: - pytest + python -m pytest uml: pyreverse -o png causalpy --output-directory docs/source/_static --ignore tests diff --git a/causalpy/plot_utils.py b/causalpy/plot_utils.py index 5ad596ce..03b7021f 100644 --- a/causalpy/plot_utils.py +++ b/causalpy/plot_utils.py @@ -93,10 +93,31 @@ def get_hdi_to_df( :param hdi_prob: The size of the HDI, default is 0.94 """ - hdi = ( - az.hdi(x, hdi_prob=hdi_prob) - .to_dataframe() - .unstack(level="hdi") - .droplevel(0, axis=1) - ) - return hdi + hdi_result = az.hdi(x, hdi_prob=hdi_prob) + hdi_df = hdi_result.to_dataframe().unstack(level="hdi") + + # Handle MultiIndex columns from unstack operation + # After unstack, we may have MultiIndex like: [('mu', 'lower'), ('mu', 'higher'), ('coord', 'lower'), ('coord', 'higher')] + # We need to extract only the data variable columns (first level), not coordinate columns + if isinstance(hdi_df.columns, pd.MultiIndex): + # Get the name of the data variable (should be at level 0) + # For xarray DataArrays, the variable name is typically at index 0 + data_var_names = hdi_df.columns.get_level_values(0).unique() + + # Filter to include only actual data variables (excluding coordinate names that became columns) + # The data variable is typically the one that was originally in the DataArray/Dataset + # For simple cases, it's often just the first unique value + if len(data_var_names) > 1: + # Find the numeric data variable (not string coordinates) + for var_name in data_var_names: + if ( + hdi_df[(var_name, hdi_df.columns.get_level_values(1)[0])].dtype + != "object" + ): + hdi_df = hdi_df[var_name] + break + else: + # Only one variable, select it + hdi_df = hdi_df[data_var_names[0]] + + return hdi_df diff --git a/causalpy/tests/test_plot_utils.py b/causalpy/tests/test_plot_utils.py new file mode 100644 index 00000000..26965784 --- /dev/null +++ b/causalpy/tests/test_plot_utils.py @@ -0,0 +1,97 @@ +# Copyright 2025 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for plot utility functions +""" + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from causalpy.plot_utils import get_hdi_to_df + + +@pytest.mark.integration +def test_get_hdi_to_df_with_coordinate_dimensions(): + """ + Regression test for bug where get_hdi_to_df returned string coordinate values + instead of numeric HDI values when xarray had named coordinate dimensions. + + This bug manifested in multi-cell synthetic control experiments where columns + like 'pred_hdi_upper_94' contained the string "treated_agg" instead of + numeric upper bound values. + + See: https://github.com/pymc-labs/CausalPy/issues/532 + """ + # Create a mock xarray DataArray similar to what's produced in synthetic control + # with a coordinate dimension like 'treated_units' + np.random.seed(42) + n_chains = 2 + n_draws = 100 + n_obs = 10 + + # Simulate posterior samples with a named coordinate + data = np.random.normal(loc=5.0, scale=0.5, size=(n_chains, n_draws, n_obs)) + + xr_data = xr.DataArray( + data, + dims=["chain", "draw", "obs_ind"], + coords={ + "chain": np.arange(n_chains), + "draw": np.arange(n_draws), + "obs_ind": np.arange(n_obs), + "treated_units": "treated_agg", # This coordinate caused the bug + }, + ) + + # Call get_hdi_to_df + result = get_hdi_to_df(xr_data, hdi_prob=0.94) + + # Assertions to verify the bug is fixed + assert isinstance(result, pd.DataFrame), "Result should be a DataFrame" + + # Check that we have exactly 2 columns (lower and higher) + assert result.shape[1] == 2, f"Expected 2 columns, got {result.shape[1]}" + + # Check column names + assert "lower" in result.columns, "Should have 'lower' column" + assert "higher" in result.columns, "Should have 'higher' column" + + # CRITICAL: Check that columns contain numeric data, not strings + assert result["lower"].dtype in [ + np.float64, + np.float32, + ], f"'lower' column should be numeric, got {result['lower'].dtype}" + assert result["higher"].dtype in [ + np.float64, + np.float32, + ], f"'higher' column should be numeric, got {result['higher'].dtype}" + + # Check that no string values like 'treated_agg' appear in the data + assert not (result["lower"].astype(str).str.contains("treated_agg").any()), ( + "'lower' column should not contain coordinate string values" + ) + assert not (result["higher"].astype(str).str.contains("treated_agg").any()), ( + "'higher' column should not contain coordinate string values" + ) + + # Verify HDI ordering + assert (result["lower"] <= result["higher"]).all(), ( + "'lower' should be <= 'higher' for all rows" + ) + + # Verify reasonable HDI values (should be around the mean of 5.0) + assert result["lower"].min() > 3.0, "HDI lower bounds should be reasonable" + assert result["higher"].max() < 7.0, "HDI upper bounds should be reasonable" From fa5f5afa9e01462076d9b57514f0574236c4e378 Mon Sep 17 00:00:00 2001 From: "Benjamin T. Vincent" Date: Mon, 27 Oct 2025 15:23:16 +0000 Subject: [PATCH 2/2] massive simplification of the fix --- causalpy/plot_utils.py | 32 +++++++++----------------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/causalpy/plot_utils.py b/causalpy/plot_utils.py index 03b7021f..4eefa173 100644 --- a/causalpy/plot_utils.py +++ b/causalpy/plot_utils.py @@ -94,30 +94,16 @@ def get_hdi_to_df( The size of the HDI, default is 0.94 """ hdi_result = az.hdi(x, hdi_prob=hdi_prob) - hdi_df = hdi_result.to_dataframe().unstack(level="hdi") - # Handle MultiIndex columns from unstack operation - # After unstack, we may have MultiIndex like: [('mu', 'lower'), ('mu', 'higher'), ('coord', 'lower'), ('coord', 'higher')] - # We need to extract only the data variable columns (first level), not coordinate columns - if isinstance(hdi_df.columns, pd.MultiIndex): - # Get the name of the data variable (should be at level 0) - # For xarray DataArrays, the variable name is typically at index 0 - data_var_names = hdi_df.columns.get_level_values(0).unique() + # Get the data variable name (typically 'mu' or 'x') + # We select only the data variable column to exclude coordinates like 'treated_units' + data_var = list(hdi_result.data_vars)[0] - # Filter to include only actual data variables (excluding coordinate names that became columns) - # The data variable is typically the one that was originally in the DataArray/Dataset - # For simple cases, it's often just the first unique value - if len(data_var_names) > 1: - # Find the numeric data variable (not string coordinates) - for var_name in data_var_names: - if ( - hdi_df[(var_name, hdi_df.columns.get_level_values(1)[0])].dtype - != "object" - ): - hdi_df = hdi_df[var_name] - break - else: - # Only one variable, select it - hdi_df = hdi_df[data_var_names[0]] + # Convert to DataFrame, select only the data variable column, then unstack + # This prevents coordinate values (like 'treated_agg') from appearing as columns + hdi_df = hdi_result[data_var].to_dataframe()[[data_var]].unstack(level="hdi") + + # Remove the top level of column MultiIndex to get just 'lower' and 'higher' + hdi_df.columns = hdi_df.columns.droplevel(0) return hdi_df