@@ -979,12 +979,13 @@ def aggregate(self, aggregation="group"):
979979 def plot_effects (
980980 self ,
981981 level = 0.95 ,
982+ result_type = "effect" ,
982983 joint = True ,
983984 figsize = (12 , 8 ),
984985 color_palette = "colorblind" ,
985986 date_format = None ,
986- y_label = "Effect" ,
987- title = "Estimated ATTs by Group" ,
987+ y_label = None ,
988+ title = None ,
988989 jitter_value = None ,
989990 default_jitter = 0.1 ,
990991 ):
@@ -996,6 +997,10 @@ def plot_effects(
996997 level : float
997998 The confidence level for the intervals.
998999 Default is ``0.95``.
1000+ result_type : str
1001+ Type of result to plot. Either ``'effect'`` for point estimates, ``'rv'`` for robustness values,
1002+ ``'est_bounds'`` for estimate bounds, or ``'ci_bounds'`` for confidence interval bounds.
1003+ Default is ``'effect'``.
9991004 joint : bool
10001005 Indicates whether joint confidence intervals are computed.
10011006 Default is ``True``.
@@ -1010,10 +1015,10 @@ def plot_effects(
10101015 Default is ``None``.
10111016 y_label : str
10121017 Label for y-axis.
1013- Default is ``"Effect" ``.
1018+ Default is ``None ``.
10141019 title : str
10151020 Title for the entire plot.
1016- Default is ``"Estimated ATTs by Group" ``.
1021+ Default is ``None ``.
10171022 jitter_value : float
10181023 Amount of jitter to apply to points.
10191024 Default is ``None``.
@@ -1035,8 +1040,29 @@ def plot_effects(
10351040 """
10361041 if self .framework is None :
10371042 raise ValueError ("Apply fit() before plot_effects()." )
1043+
1044+ if result_type not in ["effect" , "rv" , "est_bounds" , "ci_bounds" ]:
1045+ raise ValueError ("result_type must be either 'effect', 'rv', 'est_bounds' or 'ci_bounds'." )
1046+
1047+ if result_type != "effect" and self ._framework .sensitivity_params is None :
1048+ raise ValueError (
1049+ f"result_type='{ result_type } ' requires sensitivity analysis. " "Please call sensitivity_analysis() first."
1050+ )
1051+
10381052 df = self ._create_ci_dataframe (level = level , joint = joint )
10391053
1054+ # Set default y_label and title based on result_type
1055+ label_configs = {
1056+ "effect" : {"y_label" : "Effect" , "title" : "Estimated ATTs by Group" },
1057+ "rv" : {"y_label" : "Robustness Value" , "title" : "Robustness Values by Group" },
1058+ "est_bounds" : {"y_label" : "Estimate Bounds" , "title" : "Estimate Bounds by Group" },
1059+ "ci_bounds" : {"y_label" : "Confidence Interval Bounds" , "title" : "Confidence Interval Bounds by Group" },
1060+ }
1061+
1062+ config = label_configs [result_type ]
1063+ y_label = y_label if y_label is not None else config ["y_label" ]
1064+ title = title if title is not None else config ["title" ]
1065+
10401066 # Sort time periods and treatment groups
10411067 first_treated_periods = sorted (df ["First Treated" ].unique ())
10421068 n_periods = len (first_treated_periods )
@@ -1068,7 +1094,7 @@ def plot_effects(
10681094 period_df = df [df ["First Treated" ] == period ]
10691095 ax = axes [idx ]
10701096
1071- self ._plot_single_group (ax , period_df , period , colors , is_datetime , jitter_value )
1097+ self ._plot_single_group (ax , period_df , period , result_type , colors , is_datetime , jitter_value )
10721098
10731099 # Set axis labels
10741100 if idx == n_periods - 1 : # Only bottom plot gets x label
@@ -1085,7 +1111,7 @@ def plot_effects(
10851111 legend_ax .axis ("off" )
10861112 legend_elements = [
10871113 Line2D ([0 ], [0 ], color = "red" , linestyle = ":" , alpha = 0.7 , label = "Treatment start" ),
1088- Line2D ([0 ], [0 ], color = "black" , linestyle = "--" , alpha = 0.5 , label = "Zero effect " ),
1114+ Line2D ([0 ], [0 ], color = "black" , linestyle = "--" , alpha = 0.5 , label = f "Zero { result_type } " ),
10891115 Line2D ([0 ], [0 ], marker = "o" , color = colors ["pre" ], linestyle = "None" , label = "Pre-treatment" , markersize = 5 ),
10901116 ]
10911117
@@ -1108,7 +1134,7 @@ def plot_effects(
11081134
11091135 return fig , axes
11101136
1111- def _plot_single_group (self , ax , period_df , period , colors , is_datetime , jitter_value ):
1137+ def _plot_single_group (self , ax , period_df , period , result_type , colors , is_datetime , jitter_value ):
11121138 """
11131139 Plot estimates for a single treatment group on the given axis.
11141140
@@ -1120,6 +1146,10 @@ def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_
11201146 DataFrame containing estimates for a specific time period.
11211147 period : int or datetime
11221148 Treatment period for this group.
1149+ result_type : str
1150+ Type of result to plot. Either ``'effect'`` for point estimates, ``'rv'`` for robustness values,
1151+ ``'est_bounds'`` for estimate bounds, or ``'ci_bounds'`` for confidence interval bounds.
1152+ Default is ``'effect'``.
11231153 colors : dict
11241154 Dictionary with 'pre', 'anticipation' (if applicable), and 'post' color values.
11251155 is_datetime : bool
@@ -1165,6 +1195,31 @@ def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_
11651195 # Define category mappings
11661196 categories = [("pre" , pre_treatment_mask ), ("anticipation" , anticipation_mask ), ("post" , post_treatment_mask )]
11671197
1198+ # Define plot configurations for each result type
1199+ plot_configs = {
1200+ "effect" : {"plot_col" : "Estimate" , "err_col_upper" : "CI Upper" , "err_col_lower" : "CI Lower" , "s_val" : 30 },
1201+ "rv" : {"plot_col" : "RV" , "plot_col_2" : "RVa" , "s_val" : 50 },
1202+ "est_bounds" : {
1203+ "plot_col" : "Estimate" ,
1204+ "err_col_upper" : "Estimate Upper Bound" ,
1205+ "err_col_lower" : "Estimate Lower Bound" ,
1206+ "s_val" : 30 ,
1207+ },
1208+ "ci_bounds" : {
1209+ "plot_col" : "Estimate" ,
1210+ "err_col_upper" : "CI Upper Bound" ,
1211+ "err_col_lower" : "CI Lower Bound" ,
1212+ "s_val" : 30 ,
1213+ },
1214+ }
1215+
1216+ config = plot_configs [result_type ]
1217+ plot_col = config ["plot_col" ]
1218+ plot_col_2 = config .get ("plot_col_2" )
1219+ err_col_upper = config .get ("err_col_upper" )
1220+ err_col_lower = config .get ("err_col_lower" )
1221+ s_val = config ["s_val" ]
1222+
11681223 # Plot each category
11691224 for category_name , mask in categories :
11701225 if not mask .any ():
@@ -1179,22 +1234,33 @@ def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_
11791234
11801235 if not category_data .empty :
11811236 ax .scatter (
1182- category_data ["jittered_x" ], category_data ["Estimate" ], color = colors [category_name ], alpha = 0.8 , s = 30
1183- )
1184- ax .errorbar (
1185- category_data ["jittered_x" ],
1186- category_data ["Estimate" ],
1187- yerr = [
1188- category_data ["Estimate" ] - category_data ["CI Lower" ],
1189- category_data ["CI Upper" ] - category_data ["Estimate" ],
1190- ],
1191- fmt = "o" ,
1192- capsize = 3 ,
1193- color = colors [category_name ],
1194- markersize = 4 ,
1195- markeredgewidth = 1 ,
1196- linewidth = 1 ,
1237+ category_data ["jittered_x" ], category_data [plot_col ], color = colors [category_name ], alpha = 0.8 , s = s_val
11971238 )
1239+ if result_type in ["effect" , "est_bounds" , "ci_bounds" ]:
1240+ ax .errorbar (
1241+ category_data ["jittered_x" ],
1242+ category_data [plot_col ],
1243+ yerr = [
1244+ category_data [plot_col ] - category_data [err_col_lower ],
1245+ category_data [err_col_upper ] - category_data [plot_col ],
1246+ ],
1247+ fmt = "o" ,
1248+ capsize = 3 ,
1249+ color = colors [category_name ],
1250+ markersize = 4 ,
1251+ markeredgewidth = 1 ,
1252+ linewidth = 1 ,
1253+ )
1254+
1255+ elif result_type == "rv" :
1256+ ax .scatter (
1257+ category_data ["jittered_x" ],
1258+ category_data [plot_col_2 ],
1259+ color = colors [category_name ],
1260+ alpha = 0.8 ,
1261+ s = s_val ,
1262+ marker = "s" ,
1263+ )
11981264
11991265 # Format axes
12001266 if is_datetime :
@@ -1431,6 +1497,8 @@ def _create_ci_dataframe(self, level=0.95, joint=True):
14311497 - 'CI Lower': Lower bound of confidence intervals
14321498 - 'CI Upper': Upper bound of confidence intervals
14331499 - 'Pre-Treatment': Boolean indicating if evaluation period is before treatment
1500+ - 'RV': Robustness values (if sensitivity_analysis() has been called before)
1501+ - 'RVa': Robustness values for (1-a) confidence bounds (if sensitivity_analysis() has been called before)
14341502
14351503 Notes
14361504 -----
@@ -1459,5 +1527,11 @@ def _create_ci_dataframe(self, level=0.95, joint=True):
14591527 "Pre-Treatment" : [gt_combination [2 ] < gt_combination [0 ] for gt_combination in self .gt_combinations ],
14601528 }
14611529 )
1462-
1530+ if self ._framework .sensitivity_params is not None :
1531+ df ["RV" ] = self .framework .sensitivity_params ["rv" ]
1532+ df ["RVa" ] = self .framework .sensitivity_params ["rva" ]
1533+ df ["CI Lower Bound" ] = self .framework .sensitivity_params ["ci" ]["lower" ]
1534+ df ["CI Upper Bound" ] = self .framework .sensitivity_params ["ci" ]["upper" ]
1535+ df ["Estimate Lower Bound" ] = self .framework .sensitivity_params ["theta" ]["lower" ]
1536+ df ["Estimate Upper Bound" ] = self .framework .sensitivity_params ["theta" ]["upper" ]
14631537 return df
0 commit comments