@@ -1095,7 +1095,7 @@ def ml_l_params(trial):
10951095 """
10961096 # Validation
10971097
1098- requested_learners , expanded_param_space = self ._validate_optuna_param_space (ml_param_space )
1098+ expanded_param_space = self ._validate_optuna_param_space (ml_param_space )
10991099 scoring_methods = self ._resolve_scoring_methods (scoring_methods )
11001100 cv_splitter = resolve_optuna_cv (cv )
11011101 self ._validate_optuna_setting_keys (optuna_settings )
@@ -1123,11 +1123,9 @@ def ml_l_params(trial):
11231123 optuna_settings ,
11241124 )
11251125
1126- filtered_results = {key : value for key , value in res .items () if key in requested_learners }
1127- tuning_res [i_d ] = filtered_results
1128-
1126+ tuning_res [i_d ] = res
11291127 if set_as_params :
1130- for nuisance_model , tuned_result in filtered_results .items ():
1128+ for nuisance_model , tuned_result in res .items ():
11311129 if tuned_result is None :
11321130 params_to_set = None
11331131 else :
@@ -1220,7 +1218,6 @@ def _validate_optuna_param_space(self, ml_param_space):
12201218 + valid_keys_msg
12211219 + "."
12221220 )
1223- requested_learners = set (ml_param_space .keys ())
12241221 final_param_space = {k : None for k in self .params_names }
12251222
12261223 # Validate that all parameter spaces are callables
@@ -1242,7 +1239,7 @@ def _validate_optuna_param_space(self, ml_param_space):
12421239 for param_key in [pk for pk in self .params_names if pk in ml_param_space .keys ()]:
12431240 final_param_space [param_key ] = ml_param_space [param_key ]
12441241
1245- return requested_learners , final_param_space
1242+ return final_param_space
12461243
12471244 def set_ml_nuisance_params (self , learner , treat_var , params ):
12481245 """
0 commit comments