Skip to content

Commit 70e5afe

Browse files
authored
Merge pull request #369 from DoubleML/p-add-rv-plot
Add plotting of rv and sensitivity bounds for DiD model
2 parents 99e4116 + 3fab6f9 commit 70e5afe

File tree

2 files changed

+203
-23
lines changed

2 files changed

+203
-23
lines changed

doubleml/did/did_multi.py

Lines changed: 97 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -979,12 +979,13 @@ def aggregate(self, aggregation="group"):
979979
def plot_effects(
980980
self,
981981
level=0.95,
982+
result_type="effect",
982983
joint=True,
983984
figsize=(12, 8),
984985
color_palette="colorblind",
985986
date_format=None,
986-
y_label="Effect",
987-
title="Estimated ATTs by Group",
987+
y_label=None,
988+
title=None,
988989
jitter_value=None,
989990
default_jitter=0.1,
990991
):
@@ -996,6 +997,10 @@ def plot_effects(
996997
level : float
997998
The confidence level for the intervals.
998999
Default is ``0.95``.
1000+
result_type : str
1001+
Type of result to plot. Either ``'effect'`` for point estimates, ``'rv'`` for robustness values,
1002+
``'est_bounds'`` for estimate bounds, or ``'ci_bounds'`` for confidence interval bounds.
1003+
Default is ``'effect'``.
9991004
joint : bool
10001005
Indicates whether joint confidence intervals are computed.
10011006
Default is ``True``.
@@ -1010,10 +1015,10 @@ def plot_effects(
10101015
Default is ``None``.
10111016
y_label : str
10121017
Label for y-axis.
1013-
Default is ``"Effect"``.
1018+
Default is ``None``.
10141019
title : str
10151020
Title for the entire plot.
1016-
Default is ``"Estimated ATTs by Group"``.
1021+
Default is ``None``.
10171022
jitter_value : float
10181023
Amount of jitter to apply to points.
10191024
Default is ``None``.
@@ -1035,8 +1040,29 @@ def plot_effects(
10351040
"""
10361041
if self.framework is None:
10371042
raise ValueError("Apply fit() before plot_effects().")
1043+
1044+
if result_type not in ["effect", "rv", "est_bounds", "ci_bounds"]:
1045+
raise ValueError("result_type must be either 'effect', 'rv', 'est_bounds' or 'ci_bounds'.")
1046+
1047+
if result_type != "effect" and self._framework.sensitivity_params is None:
1048+
raise ValueError(
1049+
f"result_type='{result_type}' requires sensitivity analysis. " "Please call sensitivity_analysis() first."
1050+
)
1051+
10381052
df = self._create_ci_dataframe(level=level, joint=joint)
10391053

1054+
# Set default y_label and title based on result_type
1055+
label_configs = {
1056+
"effect": {"y_label": "Effect", "title": "Estimated ATTs by Group"},
1057+
"rv": {"y_label": "Robustness Value", "title": "Robustness Values by Group"},
1058+
"est_bounds": {"y_label": "Estimate Bounds", "title": "Estimate Bounds by Group"},
1059+
"ci_bounds": {"y_label": "Confidence Interval Bounds", "title": "Confidence Interval Bounds by Group"},
1060+
}
1061+
1062+
config = label_configs[result_type]
1063+
y_label = y_label if y_label is not None else config["y_label"]
1064+
title = title if title is not None else config["title"]
1065+
10401066
# Sort time periods and treatment groups
10411067
first_treated_periods = sorted(df["First Treated"].unique())
10421068
n_periods = len(first_treated_periods)
@@ -1068,7 +1094,7 @@ def plot_effects(
10681094
period_df = df[df["First Treated"] == period]
10691095
ax = axes[idx]
10701096

1071-
self._plot_single_group(ax, period_df, period, colors, is_datetime, jitter_value)
1097+
self._plot_single_group(ax, period_df, period, result_type, colors, is_datetime, jitter_value)
10721098

10731099
# Set axis labels
10741100
if idx == n_periods - 1: # Only bottom plot gets x label
@@ -1085,7 +1111,7 @@ def plot_effects(
10851111
legend_ax.axis("off")
10861112
legend_elements = [
10871113
Line2D([0], [0], color="red", linestyle=":", alpha=0.7, label="Treatment start"),
1088-
Line2D([0], [0], color="black", linestyle="--", alpha=0.5, label="Zero effect"),
1114+
Line2D([0], [0], color="black", linestyle="--", alpha=0.5, label=f"Zero {result_type}"),
10891115
Line2D([0], [0], marker="o", color=colors["pre"], linestyle="None", label="Pre-treatment", markersize=5),
10901116
]
10911117

@@ -1108,7 +1134,7 @@ def plot_effects(
11081134

11091135
return fig, axes
11101136

1111-
def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_value):
1137+
def _plot_single_group(self, ax, period_df, period, result_type, colors, is_datetime, jitter_value):
11121138
"""
11131139
Plot estimates for a single treatment group on the given axis.
11141140
@@ -1120,6 +1146,10 @@ def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_
11201146
DataFrame containing estimates for a specific time period.
11211147
period : int or datetime
11221148
Treatment period for this group.
1149+
result_type : str
1150+
Type of result to plot. Either ``'effect'`` for point estimates, ``'rv'`` for robustness values,
1151+
``'est_bounds'`` for estimate bounds, or ``'ci_bounds'`` for confidence interval bounds.
1152+
Default is ``'effect'``.
11231153
colors : dict
11241154
Dictionary with 'pre', 'anticipation' (if applicable), and 'post' color values.
11251155
is_datetime : bool
@@ -1165,6 +1195,31 @@ def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_
11651195
# Define category mappings
11661196
categories = [("pre", pre_treatment_mask), ("anticipation", anticipation_mask), ("post", post_treatment_mask)]
11671197

1198+
# Define plot configurations for each result type
1199+
plot_configs = {
1200+
"effect": {"plot_col": "Estimate", "err_col_upper": "CI Upper", "err_col_lower": "CI Lower", "s_val": 30},
1201+
"rv": {"plot_col": "RV", "plot_col_2": "RVa", "s_val": 50},
1202+
"est_bounds": {
1203+
"plot_col": "Estimate",
1204+
"err_col_upper": "Estimate Upper Bound",
1205+
"err_col_lower": "Estimate Lower Bound",
1206+
"s_val": 30,
1207+
},
1208+
"ci_bounds": {
1209+
"plot_col": "Estimate",
1210+
"err_col_upper": "CI Upper Bound",
1211+
"err_col_lower": "CI Lower Bound",
1212+
"s_val": 30,
1213+
},
1214+
}
1215+
1216+
config = plot_configs[result_type]
1217+
plot_col = config["plot_col"]
1218+
plot_col_2 = config.get("plot_col_2")
1219+
err_col_upper = config.get("err_col_upper")
1220+
err_col_lower = config.get("err_col_lower")
1221+
s_val = config["s_val"]
1222+
11681223
# Plot each category
11691224
for category_name, mask in categories:
11701225
if not mask.any():
@@ -1179,22 +1234,33 @@ def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_
11791234

11801235
if not category_data.empty:
11811236
ax.scatter(
1182-
category_data["jittered_x"], category_data["Estimate"], color=colors[category_name], alpha=0.8, s=30
1183-
)
1184-
ax.errorbar(
1185-
category_data["jittered_x"],
1186-
category_data["Estimate"],
1187-
yerr=[
1188-
category_data["Estimate"] - category_data["CI Lower"],
1189-
category_data["CI Upper"] - category_data["Estimate"],
1190-
],
1191-
fmt="o",
1192-
capsize=3,
1193-
color=colors[category_name],
1194-
markersize=4,
1195-
markeredgewidth=1,
1196-
linewidth=1,
1237+
category_data["jittered_x"], category_data[plot_col], color=colors[category_name], alpha=0.8, s=s_val
11971238
)
1239+
if result_type in ["effect", "est_bounds", "ci_bounds"]:
1240+
ax.errorbar(
1241+
category_data["jittered_x"],
1242+
category_data[plot_col],
1243+
yerr=[
1244+
category_data[plot_col] - category_data[err_col_lower],
1245+
category_data[err_col_upper] - category_data[plot_col],
1246+
],
1247+
fmt="o",
1248+
capsize=3,
1249+
color=colors[category_name],
1250+
markersize=4,
1251+
markeredgewidth=1,
1252+
linewidth=1,
1253+
)
1254+
1255+
elif result_type == "rv":
1256+
ax.scatter(
1257+
category_data["jittered_x"],
1258+
category_data[plot_col_2],
1259+
color=colors[category_name],
1260+
alpha=0.8,
1261+
s=s_val,
1262+
marker="s",
1263+
)
11981264

11991265
# Format axes
12001266
if is_datetime:
@@ -1431,6 +1497,8 @@ def _create_ci_dataframe(self, level=0.95, joint=True):
14311497
- 'CI Lower': Lower bound of confidence intervals
14321498
- 'CI Upper': Upper bound of confidence intervals
14331499
- 'Pre-Treatment': Boolean indicating if evaluation period is before treatment
1500+
- 'RV': Robustness values (if sensitivity_analysis() has been called before)
1501+
- 'RVa': Robustness values for (1-a) confidence bounds (if sensitivity_analysis() has been called before)
14341502
14351503
Notes
14361504
-----
@@ -1459,5 +1527,11 @@ def _create_ci_dataframe(self, level=0.95, joint=True):
14591527
"Pre-Treatment": [gt_combination[2] < gt_combination[0] for gt_combination in self.gt_combinations],
14601528
}
14611529
)
1462-
1530+
if self._framework.sensitivity_params is not None:
1531+
df["RV"] = self.framework.sensitivity_params["rv"]
1532+
df["RVa"] = self.framework.sensitivity_params["rva"]
1533+
df["CI Lower Bound"] = self.framework.sensitivity_params["ci"]["lower"]
1534+
df["CI Upper Bound"] = self.framework.sensitivity_params["ci"]["upper"]
1535+
df["Estimate Lower Bound"] = self.framework.sensitivity_params["theta"]["lower"]
1536+
df["Estimate Upper Bound"] = self.framework.sensitivity_params["theta"]["upper"]
14631537
return df

doubleml/did/tests/test_did_multi_plot.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,109 @@ def test_plot_effects_jitter(doubleml_did_fixture):
184184
assert fig_default != fig
185185

186186
plt.close("all")
187+
188+
189+
@pytest.mark.ci
190+
def test_plot_effects_result_types(doubleml_did_fixture):
191+
"""Test plot_effects with different result types."""
192+
dml_obj = doubleml_did_fixture["model"]
193+
194+
# Test default result_type='effect'
195+
fig_effect, axes_effect = dml_obj.plot_effects(result_type="effect")
196+
assert isinstance(fig_effect, plt.Figure)
197+
assert isinstance(axes_effect, list)
198+
199+
# Check that the default y-label is set correctly
200+
assert axes_effect[0].get_ylabel() == "Effect"
201+
202+
plt.close("all")
203+
204+
205+
@pytest.mark.ci
206+
def test_plot_effects_result_type_rv(doubleml_did_fixture):
207+
"""Test plot_effects with result_type='rv' (requires sensitivity analysis)."""
208+
dml_obj = doubleml_did_fixture["model"]
209+
210+
# Perform sensitivity analysis first
211+
dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03)
212+
213+
# Test result_type='rv'
214+
fig_rv, axes_rv = dml_obj.plot_effects(result_type="rv")
215+
assert isinstance(fig_rv, plt.Figure)
216+
assert isinstance(axes_rv, list)
217+
218+
# Check that the y-label is set correctly
219+
assert axes_rv[0].get_ylabel() == "Robustness Value"
220+
221+
plt.close("all")
222+
223+
224+
@pytest.mark.ci
225+
def test_plot_effects_result_type_est_bounds(doubleml_did_fixture):
226+
"""Test plot_effects with result_type='est_bounds' (requires sensitivity analysis)."""
227+
dml_obj = doubleml_did_fixture["model"]
228+
229+
# Perform sensitivity analysis first
230+
dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03)
231+
232+
# Test result_type='est_bounds'
233+
fig_est, axes_est = dml_obj.plot_effects(result_type="est_bounds")
234+
assert isinstance(fig_est, plt.Figure)
235+
assert isinstance(axes_est, list)
236+
237+
# Check that the y-label is set correctly
238+
assert axes_est[0].get_ylabel() == "Estimate Bounds"
239+
240+
plt.close("all")
241+
242+
243+
@pytest.mark.ci
244+
def test_plot_effects_result_type_ci_bounds(doubleml_did_fixture):
245+
"""Test plot_effects with result_type='ci_bounds' (requires sensitivity analysis)."""
246+
dml_obj = doubleml_did_fixture["model"]
247+
248+
# Perform sensitivity analysis first
249+
dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03)
250+
251+
# Test result_type='ci_bounds'
252+
fig_ci, axes_ci = dml_obj.plot_effects(result_type="ci_bounds")
253+
assert isinstance(fig_ci, plt.Figure)
254+
assert isinstance(axes_ci, list)
255+
256+
# Check that the y-label is set correctly
257+
assert axes_ci[0].get_ylabel() == "Confidence Interval Bounds"
258+
259+
plt.close("all")
260+
261+
262+
@pytest.mark.ci
263+
def test_plot_effects_result_type_invalid(doubleml_did_fixture):
264+
"""Test plot_effects with invalid result_type."""
265+
dml_obj = doubleml_did_fixture["model"]
266+
267+
# Test with invalid result_type
268+
with pytest.raises(ValueError, match="result_type must be either"):
269+
dml_obj.plot_effects(result_type="invalid_type")
270+
271+
plt.close("all")
272+
273+
274+
@pytest.mark.ci
275+
def test_plot_effects_result_type_with_custom_labels(doubleml_did_fixture):
276+
"""Test plot_effects with result_type and custom labels."""
277+
dml_obj = doubleml_did_fixture["model"]
278+
279+
# Perform sensitivity analysis first
280+
dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03)
281+
282+
# Test result_type with custom labels
283+
custom_title = "Custom Sensitivity Plot"
284+
custom_ylabel = "Custom Bounds Label"
285+
286+
fig, axes = dml_obj.plot_effects(result_type="est_bounds", title=custom_title, y_label=custom_ylabel)
287+
288+
assert isinstance(fig, plt.Figure)
289+
assert fig._suptitle.get_text() == custom_title
290+
assert axes[0].get_ylabel() == custom_ylabel
291+
292+
plt.close("all")

0 commit comments

Comments
 (0)