Skip to content

Commit f2697c1

Browse files
fix setting-merge bug in _get_optuna_settings
1 parent 9445cd3 commit f2697c1

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

doubleml/utils/_tune_optuna.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def _check_tuning_inputs(
256256
return resolve_optuna_cv(cv)
257257

258258

259-
def _get_optuna_settings(optuna_settings, params_name=None):
259+
def _get_optuna_settings(optuna_settings, params_name):
260260
"""
261261
Get Optuna settings, considering defaults, user-provided values, and learner-specific overrides.
262262
@@ -265,7 +265,7 @@ def _get_optuna_settings(optuna_settings, params_name=None):
265265
optuna_settings : dict or None
266266
User-provided Optuna settings.
267267
params_name : str
268-
Name of the learner to check for specific setting, e.g. `ml_g0` or `ml_g1` for `DoubleMLIRM`.
268+
Name of the nuisance params to check for specific setting, e.g. `ml_g0` or `ml_g1` for `DoubleMLIRM`.
269269
270270
Returns
271271
-------
@@ -286,18 +286,25 @@ def _get_optuna_settings(optuna_settings, params_name=None):
286286

287287
# Find matching learner-specific settings, handles the case to match ml_g to ml_g0, ml_g1, etc.
288288
learner_specific_settings = {}
289-
if any(params_name in key for key in learner_or_params_keys):
290-
for k in learner_or_params_keys:
291-
if params_name in k and params_name != k:
292-
learner_specific_settings = optuna_settings[k]
289+
prefix_matches = [key for key in learner_or_params_keys if key != params_name and params_name.startswith(key)]
290+
if prefix_matches:
291+
learner_key = max(prefix_matches, key=len)
292+
learner_specific_settings = optuna_settings[learner_key]
293+
if not isinstance(learner_specific_settings, dict):
294+
raise TypeError(f"Optuna settings for '{learner_key}' must be a dict.")
293295

294296
# set params specific settings
295297
params_specific_settings = {}
296298
if params_name in learner_or_params_keys:
297299
params_specific_settings = optuna_settings[params_name]
300+
if not isinstance(params_specific_settings, dict):
301+
raise TypeError(f"Optuna settings for '{params_name}' must be a dict.")
298302

299303
# Merge settings: defaults < base < learner-specific < params_specific
300-
resolved = default_settings.copy() | base_settings | learner_specific_settings | params_specific_settings
304+
resolved = default_settings.copy()
305+
resolved |= base_settings
306+
resolved |= learner_specific_settings
307+
resolved |= params_specific_settings
301308

302309
# Validate types
303310
if not isinstance(resolved["study_kwargs"], dict):

0 commit comments

Comments
 (0)