Skip to content

Commit 3fab6f9

Browse files
committed
formatting
1 parent 831dcb8 commit 3fab6f9

File tree

2 files changed

+52
-41
lines changed

2 files changed

+52
-41
lines changed

doubleml/did/did_multi.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,24 +1040,25 @@ def plot_effects(
10401040
"""
10411041
if self.framework is None:
10421042
raise ValueError("Apply fit() before plot_effects().")
1043-
1043+
10441044
if result_type not in ["effect", "rv", "est_bounds", "ci_bounds"]:
10451045
raise ValueError("result_type must be either 'effect', 'rv', 'est_bounds' or 'ci_bounds'.")
1046-
1046+
10471047
if result_type != "effect" and self._framework.sensitivity_params is None:
1048-
raise ValueError(f"result_type='{result_type}' requires sensitivity analysis. "
1049-
"Please call sensitivity_analysis() first.")
1050-
1048+
raise ValueError(
1049+
f"result_type='{result_type}' requires sensitivity analysis. " "Please call sensitivity_analysis() first."
1050+
)
1051+
10511052
df = self._create_ci_dataframe(level=level, joint=joint)
10521053

10531054
# Set default y_label and title based on result_type
10541055
label_configs = {
10551056
"effect": {"y_label": "Effect", "title": "Estimated ATTs by Group"},
10561057
"rv": {"y_label": "Robustness Value", "title": "Robustness Values by Group"},
10571058
"est_bounds": {"y_label": "Estimate Bounds", "title": "Estimate Bounds by Group"},
1058-
"ci_bounds": {"y_label": "Confidence Interval Bounds", "title": "Confidence Interval Bounds by Group"}
1059+
"ci_bounds": {"y_label": "Confidence Interval Bounds", "title": "Confidence Interval Bounds by Group"},
10591060
}
1060-
1061+
10611062
config = label_configs[result_type]
10621063
y_label = y_label if y_label is not None else config["y_label"]
10631064
title = title if title is not None else config["title"]
@@ -1198,10 +1199,20 @@ def _plot_single_group(self, ax, period_df, period, result_type, colors, is_date
11981199
plot_configs = {
11991200
"effect": {"plot_col": "Estimate", "err_col_upper": "CI Upper", "err_col_lower": "CI Lower", "s_val": 30},
12001201
"rv": {"plot_col": "RV", "plot_col_2": "RVa", "s_val": 50},
1201-
"est_bounds": {"plot_col": "Estimate", "err_col_upper": "Estimate Upper Bound", "err_col_lower": "Estimate Lower Bound", "s_val": 30},
1202-
"ci_bounds": {"plot_col": "Estimate", "err_col_upper": "CI Upper Bound", "err_col_lower": "CI Lower Bound", "s_val": 30}
1202+
"est_bounds": {
1203+
"plot_col": "Estimate",
1204+
"err_col_upper": "Estimate Upper Bound",
1205+
"err_col_lower": "Estimate Lower Bound",
1206+
"s_val": 30,
1207+
},
1208+
"ci_bounds": {
1209+
"plot_col": "Estimate",
1210+
"err_col_upper": "CI Upper Bound",
1211+
"err_col_lower": "CI Lower Bound",
1212+
"s_val": 30,
1213+
},
12031214
}
1204-
1215+
12051216
config = plot_configs[result_type]
12061217
plot_col = config["plot_col"]
12071218
plot_col_2 = config.get("plot_col_2")
@@ -1240,13 +1251,17 @@ def _plot_single_group(self, ax, period_df, period, result_type, colors, is_date
12401251
markeredgewidth=1,
12411252
linewidth=1,
12421253
)
1243-
1254+
12441255
elif result_type == "rv":
12451256
ax.scatter(
1246-
category_data["jittered_x"], category_data[plot_col_2], color=colors[category_name], alpha=0.8, s=s_val,
1247-
marker="s"
1257+
category_data["jittered_x"],
1258+
category_data[plot_col_2],
1259+
color=colors[category_name],
1260+
alpha=0.8,
1261+
s=s_val,
1262+
marker="s",
12481263
)
1249-
1264+
12501265
# Format axes
12511266
if is_datetime:
12521267
period_str = np.datetime64(period, self._dml_data.datetime_unit)
@@ -1519,4 +1534,4 @@ def _create_ci_dataframe(self, level=0.95, joint=True):
15191534
df["CI Upper Bound"] = self.framework.sensitivity_params["ci"]["upper"]
15201535
df["Estimate Lower Bound"] = self.framework.sensitivity_params["theta"]["lower"]
15211536
df["Estimate Upper Bound"] = self.framework.sensitivity_params["theta"]["upper"]
1522-
return df
1537+
return df

doubleml/did/tests/test_did_multi_plot.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -195,102 +195,98 @@ def test_plot_effects_result_types(doubleml_did_fixture):
195195
fig_effect, axes_effect = dml_obj.plot_effects(result_type="effect")
196196
assert isinstance(fig_effect, plt.Figure)
197197
assert isinstance(axes_effect, list)
198-
198+
199199
# Check that the default y-label is set correctly
200200
assert axes_effect[0].get_ylabel() == "Effect"
201-
201+
202202
plt.close("all")
203203

204204

205205
@pytest.mark.ci
206206
def test_plot_effects_result_type_rv(doubleml_did_fixture):
207207
"""Test plot_effects with result_type='rv' (requires sensitivity analysis)."""
208208
dml_obj = doubleml_did_fixture["model"]
209-
209+
210210
# Perform sensitivity analysis first
211211
dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03)
212-
212+
213213
# Test result_type='rv'
214214
fig_rv, axes_rv = dml_obj.plot_effects(result_type="rv")
215215
assert isinstance(fig_rv, plt.Figure)
216216
assert isinstance(axes_rv, list)
217-
217+
218218
# Check that the y-label is set correctly
219219
assert axes_rv[0].get_ylabel() == "Robustness Value"
220-
220+
221221
plt.close("all")
222222

223223

224224
@pytest.mark.ci
225225
def test_plot_effects_result_type_est_bounds(doubleml_did_fixture):
226226
"""Test plot_effects with result_type='est_bounds' (requires sensitivity analysis)."""
227227
dml_obj = doubleml_did_fixture["model"]
228-
228+
229229
# Perform sensitivity analysis first
230230
dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03)
231-
231+
232232
# Test result_type='est_bounds'
233233
fig_est, axes_est = dml_obj.plot_effects(result_type="est_bounds")
234234
assert isinstance(fig_est, plt.Figure)
235235
assert isinstance(axes_est, list)
236-
236+
237237
# Check that the y-label is set correctly
238238
assert axes_est[0].get_ylabel() == "Estimate Bounds"
239-
239+
240240
plt.close("all")
241241

242242

243243
@pytest.mark.ci
244244
def test_plot_effects_result_type_ci_bounds(doubleml_did_fixture):
245245
"""Test plot_effects with result_type='ci_bounds' (requires sensitivity analysis)."""
246246
dml_obj = doubleml_did_fixture["model"]
247-
247+
248248
# Perform sensitivity analysis first
249249
dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03)
250-
250+
251251
# Test result_type='ci_bounds'
252252
fig_ci, axes_ci = dml_obj.plot_effects(result_type="ci_bounds")
253253
assert isinstance(fig_ci, plt.Figure)
254254
assert isinstance(axes_ci, list)
255-
255+
256256
# Check that the y-label is set correctly
257257
assert axes_ci[0].get_ylabel() == "Confidence Interval Bounds"
258-
258+
259259
plt.close("all")
260260

261261

262262
@pytest.mark.ci
263263
def test_plot_effects_result_type_invalid(doubleml_did_fixture):
264264
"""Test plot_effects with invalid result_type."""
265265
dml_obj = doubleml_did_fixture["model"]
266-
266+
267267
# Test with invalid result_type
268268
with pytest.raises(ValueError, match="result_type must be either"):
269269
dml_obj.plot_effects(result_type="invalid_type")
270-
270+
271271
plt.close("all")
272272

273273

274274
@pytest.mark.ci
275275
def test_plot_effects_result_type_with_custom_labels(doubleml_did_fixture):
276276
"""Test plot_effects with result_type and custom labels."""
277277
dml_obj = doubleml_did_fixture["model"]
278-
278+
279279
# Perform sensitivity analysis first
280280
dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03)
281-
281+
282282
# Test result_type with custom labels
283283
custom_title = "Custom Sensitivity Plot"
284284
custom_ylabel = "Custom Bounds Label"
285-
286-
fig, axes = dml_obj.plot_effects(
287-
result_type="est_bounds",
288-
title=custom_title,
289-
y_label=custom_ylabel
290-
)
291-
285+
286+
fig, axes = dml_obj.plot_effects(result_type="est_bounds", title=custom_title, y_label=custom_ylabel)
287+
292288
assert isinstance(fig, plt.Figure)
293289
assert fig._suptitle.get_text() == custom_title
294290
assert axes[0].get_ylabel() == custom_ylabel
295-
291+
296292
plt.close("all")

0 commit comments

Comments
 (0)