Skip to content

Commit 0913b1b

Browse files
committed
refactor _get_optuna_settings to simplify learner-specific settings retrieval and merge logic
1 parent 84dfa64 commit 0913b1b

File tree

1 file changed

+4
-12
lines changed

1 file changed

+4
-12
lines changed

doubleml/utils/_tune_optuna.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -286,25 +286,17 @@ def _get_optuna_settings(optuna_settings, params_name):
286286

287287
# Find matching learner-specific settings, handles the case to match ml_g to ml_g0, ml_g1, etc.
288288
learner_specific_settings = {}
289-
prefix_matches = [key for key in learner_or_params_keys if key != params_name and params_name.startswith(key)]
290-
if prefix_matches:
291-
learner_key = max(prefix_matches, key=len)
292-
learner_specific_settings = optuna_settings[learner_key]
293-
if not isinstance(learner_specific_settings, dict):
294-
raise TypeError(f"Optuna settings for '{learner_key}' must be a dict.")
289+
for k in learner_or_params_keys:
290+
if k in params_name and params_name != k:
291+
learner_specific_settings = optuna_settings[k]
295292

296293
# set params specific settings
297294
params_specific_settings = {}
298295
if params_name in learner_or_params_keys:
299296
params_specific_settings = optuna_settings[params_name]
300-
if not isinstance(params_specific_settings, dict):
301-
raise TypeError(f"Optuna settings for '{params_name}' must be a dict.")
302297

303298
# Merge settings: defaults < base < learner-specific < params_specific
304-
resolved = default_settings.copy()
305-
resolved |= base_settings
306-
resolved |= learner_specific_settings
307-
resolved |= params_specific_settings
299+
resolved = default_settings.copy() | base_settings | learner_specific_settings | params_specific_settings
308300

309301
# Validate types
310302
if not isinstance(resolved["study_kwargs"], dict):

0 commit comments

Comments
 (0)