Commit 4fa1650

add test that the r2 scores differ across treated units
1 parent 26691ec

1 file changed: +87 −0 lines changed

causalpy/tests/test_multi_unit_wsf.py

@@ -305,3 +305,90 @@ def test_scoring_single_unit(self, single_treated_data):
        # R2 should be reasonable
        assert score["r2"] >= -1  # R2 can be negative for very bad fits
        assert score["r2_std"] >= 0  # Standard deviation should be non-negative

    def test_r2_scores_differ_across_units(self, rng):
        """Test that R² scores differ across treated units.

        This is a defensive test to ensure that each treated unit is being scored
        independently and not getting identical scores due to implementation bugs.
        """
        n_obs = 100  # Use more observations for better differentiation
        n_control = 4

        # Control unit data
        control_data = {}
        for i in range(n_control):
            control_data[f"control_{i}"] = rng.normal(0, 1, n_obs)

        # Create treated units with deliberately different quality of fit
        treated_data = {}

        # Treated unit 0: Good fit (close to control combination)
        weights_0 = rng.dirichlet(np.ones(n_control))
        treated_data["treated_0"] = sum(
            weights_0[i] * control_data[f"control_{i}"] for i in range(n_control)
        ) + rng.normal(0, 0.05, n_obs)  # Low noise

        # Treated unit 1: Medium fit
        weights_1 = rng.dirichlet(np.ones(n_control))
        treated_data["treated_1"] = sum(
            weights_1[i] * control_data[f"control_{i}"] for i in range(n_control)
        ) + rng.normal(0, 0.3, n_obs)  # Medium noise

        # Treated unit 2: Poor fit (mostly random)
        treated_data["treated_2"] = rng.normal(0, 2, n_obs)  # Largely independent

        # Create DataFrame
        df = pd.DataFrame({**control_data, **treated_data})

        # Prepare data for model
        control_units = [f"control_{i}" for i in range(n_control)]
        treated_units = ["treated_0", "treated_1", "treated_2"]

        X = xr.DataArray(
            df[control_units].values,
            dims=["obs_ind", "coeffs"],
            coords={
                "obs_ind": df.index,
                "coeffs": control_units,
            },
        )

        y = xr.DataArray(
            df[treated_units].values,
            dims=["obs_ind", "treated_units"],
            coords={
                "obs_ind": df.index,
                "treated_units": treated_units,
            },
        )

        coords = {
            "coeffs": control_units,
            "treated_units": treated_units,
            "obs_ind": np.arange(n_obs),
        }

        # Fit model and score
        wsf = WeightedSumFitter(sample_kwargs=sample_kwargs)
        wsf.fit(X, y, coords=coords)
        scores = wsf.score(X, y)

        # Extract R² values for each treated unit
        r2_values = [scores[f"{unit}_r2"] for unit in treated_units]

        # Test that not all R² values are the same
        # Use a tolerance to avoid issues with floating point precision
        assert not np.allclose(r2_values, r2_values[0], atol=1e-6), (
            f"All R² scores are too similar: {r2_values}. "
            "This suggests the scoring might not be working correctly for individual units."
        )

        # Rather than asserting the exact ordering (good > medium > poor fit),
        # which could occasionally fail due to randomness, check that the scores
        # show reasonable variation across treated units
        r2_std = np.std(r2_values)
        assert r2_std > 0.01, (
            f"R² standard deviation is too low ({r2_std}), suggesting insufficient variation "
            "between treated units. This might indicate a scoring implementation issue."
        )

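For context on what the final assertions exercise: the test assumes `score()` returns one R² summary per treated unit, keyed like `treated_0_r2`, alongside the aggregate `r2`/`r2_std` checked by the earlier single-unit test. Below is a minimal sketch of such a per-unit Bayesian R² (the Gelman et al. 2019 definition), computed over posterior-predictive draws; the function name and array shapes are illustrative assumptions, not CausalPy's actual internals.

import numpy as np

def per_unit_bayesian_r2(y_obs, y_pred_draws):
    """Per-unit Bayesian R²: mean and std across posterior-predictive draws.

    y_obs:        (n_obs, n_units) observed outcomes (hypothetical shapes)
    y_pred_draws: (n_draws, n_obs, n_units) posterior-predictive draws
    """
    # Per draw and unit: R² = var(fitted) / (var(fitted) + var(residual))
    var_fit = y_pred_draws.var(axis=1)                        # (n_draws, n_units)
    var_res = (y_obs[None, :, :] - y_pred_draws).var(axis=1)  # (n_draws, n_units)
    r2 = var_fit / (var_fit + var_res)                        # (n_draws, n_units)
    return r2.mean(axis=0), r2.std(axis=0)  # one (mean, std) pair per unit

Under a definition like this, the three synthetic units should separate cleanly: treated_0 adds sd-0.05 noise to a convex combination of the controls, so its residual variance is tiny and its R² sits near 1, while treated_2 is generated independently of the controls, pushing its R² down; that spread is exactly what the r2_std > 0.01 assertion looks for. The new test can be run on its own with pytest causalpy/tests/test_multi_unit_wsf.py -k test_r2_scores_differ_across_units.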