@@ -256,7 +256,7 @@ def _check_tuning_inputs(
256256 return resolve_optuna_cv (cv )
257257
258258
259- def _get_optuna_settings (optuna_settings , params_name = None ):
259+ def _get_optuna_settings (optuna_settings , params_name ):
260260 """
261261 Get Optuna settings, considering defaults, user-provided values, and learner-specific overrides.
262262
@@ -265,7 +265,7 @@ def _get_optuna_settings(optuna_settings, params_name=None):
265265 optuna_settings : dict or None
266266 User-provided Optuna settings.
267267 params_name : str
268- Name of the learner to check for specific setting, e.g. `ml_g0` or `ml_g1` for `DoubleMLIRM`.
268+ Name of the nuisance params to check for specific setting, e.g. `ml_g0` or `ml_g1` for `DoubleMLIRM`.
269269
270270 Returns
271271 -------
@@ -286,18 +286,25 @@ def _get_optuna_settings(optuna_settings, params_name=None):
286286
287287 # Find matching learner-specific settings, handles the case to match ml_g to ml_g0, ml_g1, etc.
288288 learner_specific_settings = {}
289- if any (params_name in key for key in learner_or_params_keys ):
290- for k in learner_or_params_keys :
291- if params_name in k and params_name != k :
292- learner_specific_settings = optuna_settings [k ]
289+ prefix_matches = [key for key in learner_or_params_keys if key != params_name and params_name .startswith (key )]
290+ if prefix_matches :
291+ learner_key = max (prefix_matches , key = len )
292+ learner_specific_settings = optuna_settings [learner_key ]
293+ if not isinstance (learner_specific_settings , dict ):
294+ raise TypeError (f"Optuna settings for '{ learner_key } ' must be a dict." )
293295
294296 # set params specific settings
295297 params_specific_settings = {}
296298 if params_name in learner_or_params_keys :
297299 params_specific_settings = optuna_settings [params_name ]
300+ if not isinstance (params_specific_settings , dict ):
301+ raise TypeError (f"Optuna settings for '{ params_name } ' must be a dict." )
298302
299303 # Merge settings: defaults < base < learner-specific < params_specific
300- resolved = default_settings .copy () | base_settings | learner_specific_settings | params_specific_settings
304+ resolved = default_settings .copy ()
305+ resolved |= base_settings
306+ resolved |= learner_specific_settings
307+ resolved |= params_specific_settings
301308
302309 # Validate types
303310 if not isinstance (resolved ["study_kwargs" ], dict ):
0 commit comments