@@ -540,7 +540,35 @@ def _dml_tune_optuna(
540540
541541
542542def _join_param_spaces (param_space_global , param_space_local ):
543+ if param_space_global is None :
544+ return param_space_local
545+ if param_space_local is None :
546+ return param_space_global
547+
543548 def joined_param_space (trial ):
544- return param_space_global (trial ) | param_space_local (trial )
549+ local_params = param_space_local (trial )
550+
551+ class _ProxyTrial :
552+ def __init__ (self , base_trial , overrides ):
553+ self ._base_trial = base_trial
554+ self ._overrides = overrides
555+
556+ def __getattr__ (self , name ):
557+ attr = getattr (self ._base_trial , name )
558+ if not callable (attr ) or not name .startswith ("suggest_" ):
559+ return attr
560+
561+ def wrapped (* args , ** kwargs ):
562+ key = args [0 ] if args else kwargs .get ("name" )
563+ if key in self ._overrides :
564+ return self ._overrides [key ]
565+ return attr (* args , ** kwargs )
566+
567+ return wrapped
568+
569+ proxy_trial = _ProxyTrial (trial , local_params )
570+ global_params = param_space_global (proxy_trial )
571+
572+ return {** global_params , ** local_params }
545573
546574 return joined_param_space
0 commit comments