@@ -373,6 +373,7 @@ def allocate_budget(
373373 self ,
374374 total_budget : float ,
375375 budget_bounds : DataArray | dict [str , tuple [float , float ]] | None = None ,
376+ x0 : np .ndarray | None = None ,
376377 minimize_kwargs : dict [str , Any ] | None = None ,
377378 return_if_fail : bool = False ,
378379 ) -> tuple [DataArray , OptimizeResult ]:
@@ -391,8 +392,11 @@ def allocate_budget(
391392 - If None, default bounds of [0, total_budget] per channel are assumed.
392393 - If a dict, must map each channel to (low, high) budget pairs (only valid if there's one dimension).
393394 - If an xarray.DataArray, must have dims (*budget_dims, "bound"), specifying [low, high] per channel cell.
395+ x0 : np.ndarray, optional
396+ Initial guess. Array of real elements of size (n,), where n is the number of driver budgets to optimize. If
397+ None, the total budget is spread uniformly across all drivers to be optimized.
394398 minimize_kwargs : dict, optional
395- Extra kwargs for `scipy.optimize.minimize`. Defaults to method "SLSQP",
399+ Extra kwargs for `scipy.optimize.minimize`. Defaults to method= "SLSQP",
396400 ftol=1e-9, maxiter=1_000.
397401 return_if_fail : bool, optional
398402 Return output even if optimization fails. Default is False.
@@ -409,8 +413,16 @@ def allocate_budget(
409413 MinimizeException
410414 If the optimization fails for any reason, the exception message will contain the details.
411415 """
416+ # set total budget
412417 self ._total_budget .set_value (np .asarray (total_budget , dtype = "float64" ))
413418
419+ # coordinate user-provided and default minimize_kwargs
420+ if minimize_kwargs is None :
421+ minimize_kwargs = self .DEFAULT_MINIMIZE_KWARGS
422+ else :
423+ # Merge with defaults (preferring user-supplied keys)
424+ minimize_kwargs = {** self .DEFAULT_MINIMIZE_KWARGS , ** minimize_kwargs }
425+
414426 # 1. Process budget bounds
415427 if budget_bounds is None :
416428 warnings .warn (
@@ -466,21 +478,21 @@ def allocate_budget(
466478 else :
467479 budgets_size = self .budgets_to_optimize .sum ().item ()
468480
469- # 5. Create an initial guess
470- initial_guess = np .ones (budgets_size ) * (total_budget / budgets_size )
471- initial_guess = initial_guess .astype (self ._budgets_flat .type .dtype )
481+ # 5. Construct the initial guess (x0) if not provided
482+ if x0 is None :
483+ x0 = np .ones (budgets_size ) * (total_budget / budgets_size ).astype (
484+ self ._budgets_flat .type .dtype
485+ )
472486
473- if minimize_kwargs is None :
474- minimize_kwargs = self .DEFAULT_MINIMIZE_KWARGS .copy ()
475- else :
476- # Merge with defaults (preferring user-supplied keys)
477- minimize_kwargs = {** self .DEFAULT_MINIMIZE_KWARGS , ** minimize_kwargs }
487+ # filter x0 based on shape/type of self._budgets_flat
488+ # will raise a TypeError if x0 does not have acceptable shape and/or type
489+ x0 = self ._budgets_flat .type .filter (x0 )
478490
479491 # 6. Run the SciPy optimizer
480492 result = minimize (
481493 fun = self ._compiled_functions [self .utility_function ]["objective_and_grad" ],
494+ x0 = x0 ,
482495 jac = True ,
483- x0 = initial_guess ,
484496 bounds = bounds ,
485497 constraints = self ._compiled_constraints ,
486498 ** minimize_kwargs ,
0 commit comments