Skip to content

Commit 477fdb5

Browse files
fix issue for joining param_spaces
1 parent f2697c1 commit 477fdb5

File tree

1 file changed

+29
-1
lines changed

1 file changed

+29
-1
lines changed

doubleml/utils/_tune_optuna.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,35 @@ def _dml_tune_optuna(
540540

541541

542542
def _join_param_spaces(param_space_global, param_space_local):
543+
if param_space_global is None:
544+
return param_space_local
545+
if param_space_local is None:
546+
return param_space_global
547+
543548
def joined_param_space(trial):
544-
return param_space_global(trial) | param_space_local(trial)
549+
local_params = param_space_local(trial)
550+
551+
class _ProxyTrial:
552+
def __init__(self, base_trial, overrides):
553+
self._base_trial = base_trial
554+
self._overrides = overrides
555+
556+
def __getattr__(self, name):
557+
attr = getattr(self._base_trial, name)
558+
if not callable(attr) or not name.startswith("suggest_"):
559+
return attr
560+
561+
def wrapped(*args, **kwargs):
562+
key = args[0] if args else kwargs.get("name")
563+
if key in self._overrides:
564+
return self._overrides[key]
565+
return attr(*args, **kwargs)
566+
567+
return wrapped
568+
569+
proxy_trial = _ProxyTrial(trial, local_params)
570+
global_params = param_space_global(proxy_trial)
571+
572+
return {**global_params, **local_params}
545573

546574
return joined_param_space

0 commit comments

Comments
 (0)