Skip to content

Commit 37edf6b

Browse files
committed
ensure all parameter spaces are returned
1 parent 40f2262 commit 37edf6b

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

doubleml/double_ml.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,7 +1095,7 @@ def ml_l_params(trial):
10951095
"""
10961096
# Validation
10971097

1098-
requested_learners, expanded_param_space = self._validate_optuna_param_space(ml_param_space)
1098+
expanded_param_space = self._validate_optuna_param_space(ml_param_space)
10991099
scoring_methods = self._resolve_scoring_methods(scoring_methods)
11001100
cv_splitter = resolve_optuna_cv(cv)
11011101
self._validate_optuna_setting_keys(optuna_settings)
@@ -1123,11 +1123,9 @@ def ml_l_params(trial):
11231123
optuna_settings,
11241124
)
11251125

1126-
filtered_results = {key: value for key, value in res.items() if key in requested_learners}
1127-
tuning_res[i_d] = filtered_results
1128-
1126+
tuning_res[i_d] = res
11291127
if set_as_params:
1130-
for nuisance_model, tuned_result in filtered_results.items():
1128+
for nuisance_model, tuned_result in res.items():
11311129
if tuned_result is None:
11321130
params_to_set = None
11331131
else:
@@ -1220,7 +1218,6 @@ def _validate_optuna_param_space(self, ml_param_space):
12201218
+ valid_keys_msg
12211219
+ "."
12221220
)
1223-
requested_learners = set(ml_param_space.keys())
12241221
final_param_space = {k: None for k in self.params_names}
12251222

12261223
# Validate that all parameter spaces are callables
@@ -1242,7 +1239,7 @@ def _validate_optuna_param_space(self, ml_param_space):
12421239
for param_key in [pk for pk in self.params_names if pk in ml_param_space.keys()]:
12431240
final_param_space[param_key] = ml_param_space[param_key]
12441241

1245-
return requested_learners, final_param_space
1242+
return final_param_space
12461243

12471244
def set_ml_nuisance_params(self, learner, treat_var, params):
12481245
"""

0 commit comments

Comments
 (0)