@@ -1040,24 +1040,25 @@ def plot_effects(
10401040 """
10411041 if self .framework is None :
10421042 raise ValueError ("Apply fit() before plot_effects()." )
1043-
1043+
10441044 if result_type not in ["effect" , "rv" , "est_bounds" , "ci_bounds" ]:
10451045 raise ValueError ("result_type must be either 'effect', 'rv', 'est_bounds' or 'ci_bounds'." )
1046-
1046+
10471047 if result_type != "effect" and self ._framework .sensitivity_params is None :
1048- raise ValueError (f"result_type='{ result_type } ' requires sensitivity analysis. "
1049- "Please call sensitivity_analysis() first." )
1050-
1048+ raise ValueError (
1049+ f"result_type='{ result_type } ' requires sensitivity analysis. " "Please call sensitivity_analysis() first."
1050+ )
1051+
10511052 df = self ._create_ci_dataframe (level = level , joint = joint )
10521053
10531054 # Set default y_label and title based on result_type
10541055 label_configs = {
10551056 "effect" : {"y_label" : "Effect" , "title" : "Estimated ATTs by Group" },
10561057 "rv" : {"y_label" : "Robustness Value" , "title" : "Robustness Values by Group" },
10571058 "est_bounds" : {"y_label" : "Estimate Bounds" , "title" : "Estimate Bounds by Group" },
1058- "ci_bounds" : {"y_label" : "Confidence Interval Bounds" , "title" : "Confidence Interval Bounds by Group" }
1059+ "ci_bounds" : {"y_label" : "Confidence Interval Bounds" , "title" : "Confidence Interval Bounds by Group" },
10591060 }
1060-
1061+
10611062 config = label_configs [result_type ]
10621063 y_label = y_label if y_label is not None else config ["y_label" ]
10631064 title = title if title is not None else config ["title" ]
@@ -1198,10 +1199,20 @@ def _plot_single_group(self, ax, period_df, period, result_type, colors, is_date
11981199 plot_configs = {
11991200 "effect" : {"plot_col" : "Estimate" , "err_col_upper" : "CI Upper" , "err_col_lower" : "CI Lower" , "s_val" : 30 },
12001201 "rv" : {"plot_col" : "RV" , "plot_col_2" : "RVa" , "s_val" : 50 },
1201- "est_bounds" : {"plot_col" : "Estimate" , "err_col_upper" : "Estimate Upper Bound" , "err_col_lower" : "Estimate Lower Bound" , "s_val" : 30 },
1202- "ci_bounds" : {"plot_col" : "Estimate" , "err_col_upper" : "CI Upper Bound" , "err_col_lower" : "CI Lower Bound" , "s_val" : 30 }
1202+ "est_bounds" : {
1203+ "plot_col" : "Estimate" ,
1204+ "err_col_upper" : "Estimate Upper Bound" ,
1205+ "err_col_lower" : "Estimate Lower Bound" ,
1206+ "s_val" : 30 ,
1207+ },
1208+ "ci_bounds" : {
1209+ "plot_col" : "Estimate" ,
1210+ "err_col_upper" : "CI Upper Bound" ,
1211+ "err_col_lower" : "CI Lower Bound" ,
1212+ "s_val" : 30 ,
1213+ },
12031214 }
1204-
1215+
12051216 config = plot_configs [result_type ]
12061217 plot_col = config ["plot_col" ]
12071218 plot_col_2 = config .get ("plot_col_2" )
@@ -1240,13 +1251,17 @@ def _plot_single_group(self, ax, period_df, period, result_type, colors, is_date
12401251 markeredgewidth = 1 ,
12411252 linewidth = 1 ,
12421253 )
1243-
1254+
12441255 elif result_type == "rv" :
12451256 ax .scatter (
1246- category_data ["jittered_x" ], category_data [plot_col_2 ], color = colors [category_name ], alpha = 0.8 , s = s_val ,
1247- marker = "s"
1257+ category_data ["jittered_x" ],
1258+ category_data [plot_col_2 ],
1259+ color = colors [category_name ],
1260+ alpha = 0.8 ,
1261+ s = s_val ,
1262+ marker = "s" ,
12481263 )
1249-
1264+
12501265 # Format axes
12511266 if is_datetime :
12521267 period_str = np .datetime64 (period , self ._dml_data .datetime_unit )
@@ -1519,4 +1534,4 @@ def _create_ci_dataframe(self, level=0.95, joint=True):
15191534 df ["CI Upper Bound" ] = self .framework .sensitivity_params ["ci" ]["upper" ]
15201535 df ["Estimate Lower Bound" ] = self .framework .sensitivity_params ["theta" ]["lower" ]
15211536 df ["Estimate Upper Bound" ] = self .framework .sensitivity_params ["theta" ]["upper" ]
1522- return df
1537+ return df
0 commit comments