Skip to content

Commit a891405

Browse files
Permit BudgetOptimizer.allocate_budget() to take x0 as an argument (#1565)
* formatting * expose x0 as an arg in allocate_budget * change typehint and validate dtype * streamline shape and dtype validation * add new bo test with manually set x0 --------- Co-authored-by: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com>
1 parent a2cf877 commit a891405

File tree

2 files changed

+58
-12
lines changed

2 files changed

+58
-12
lines changed

pymc_marketing/mmm/budget_optimizer.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ def allocate_budget(
373373
self,
374374
total_budget: float,
375375
budget_bounds: DataArray | dict[str, tuple[float, float]] | None = None,
376+
x0: np.ndarray | None = None,
376377
minimize_kwargs: dict[str, Any] | None = None,
377378
return_if_fail: bool = False,
378379
) -> tuple[DataArray, OptimizeResult]:
@@ -391,8 +392,11 @@ def allocate_budget(
391392
- If None, default bounds of [0, total_budget] per channel are assumed.
392393
- If a dict, must map each channel to (low, high) budget pairs (only valid if there's one dimension).
393394
- If an xarray.DataArray, must have dims (*budget_dims, "bound"), specifying [low, high] per channel cell.
395+
x0 : np.ndarray, optional
396+
Initial guess. Array of real elements of size (n,), where n is the number of driver budgets to optimize. If
397+
None, the total budget is spread uniformly across all drivers to be optimized.
394398
minimize_kwargs : dict, optional
395-
Extra kwargs for `scipy.optimize.minimize`. Defaults to method "SLSQP",
399+
Extra kwargs for `scipy.optimize.minimize`. Defaults to method="SLSQP",
396400
ftol=1e-9, maxiter=1_000.
397401
return_if_fail : bool, optional
398402
Return output even if optimization fails. Default is False.
@@ -409,8 +413,16 @@ def allocate_budget(
409413
MinimizeException
410414
If the optimization fails for any reason, the exception message will contain the details.
411415
"""
416+
# set total budget
412417
self._total_budget.set_value(np.asarray(total_budget, dtype="float64"))
413418

419+
# coordinate user-provided and default minimize_kwargs
420+
if minimize_kwargs is None:
421+
minimize_kwargs = self.DEFAULT_MINIMIZE_KWARGS
422+
else:
423+
# Merge with defaults (preferring user-supplied keys)
424+
minimize_kwargs = {**self.DEFAULT_MINIMIZE_KWARGS, **minimize_kwargs}
425+
414426
# 1. Process budget bounds
415427
if budget_bounds is None:
416428
warnings.warn(
@@ -466,21 +478,21 @@ def allocate_budget(
466478
else:
467479
budgets_size = self.budgets_to_optimize.sum().item()
468480

469-
# 5. Create an initial guess
470-
initial_guess = np.ones(budgets_size) * (total_budget / budgets_size)
471-
initial_guess = initial_guess.astype(self._budgets_flat.type.dtype)
481+
# 5. Construct the initial guess (x0) if not provided
482+
if x0 is None:
483+
x0 = np.ones(budgets_size) * (total_budget / budgets_size).astype(
484+
self._budgets_flat.type.dtype
485+
)
472486

473-
if minimize_kwargs is None:
474-
minimize_kwargs = self.DEFAULT_MINIMIZE_KWARGS.copy()
475-
else:
476-
# Merge with defaults (preferring user-supplied keys)
477-
minimize_kwargs = {**self.DEFAULT_MINIMIZE_KWARGS, **minimize_kwargs}
487+
# filter x0 based on shape/type of self._budgets_flat
488+
# will raise a TypeError if x0 does not have acceptable shape and/or type
489+
x0 = self._budgets_flat.type.filter(x0)
478490

479491
# 6. Run the SciPy optimizer
480492
result = minimize(
481493
fun=self._compiled_functions[self.utility_function]["objective_and_grad"],
494+
x0=x0,
482495
jac=True,
483-
x0=initial_guess,
484496
bounds=bounds,
485497
constraints=self._compiled_constraints,
486498
**minimize_kwargs,

tests/mmm/test_budget_optimizer.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,36 @@ def dummy_df():
5959

6060

6161
@pytest.mark.parametrize(
62-
argnames="total_budget, budget_bounds, parameters, minimize_kwargs, expected_optimal, expected_response",
62+
argnames="total_budget, budget_bounds, x0, parameters, minimize_kwargs, expected_optimal, expected_response",
6363
argvalues=[
6464
(
6565
100,
6666
None,
67+
None,
68+
{
69+
"saturation_params": {
70+
"lam": np.array(
71+
[[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]
72+
), # dims: chain, draw, channel
73+
"beta": np.array(
74+
[[[0.5, 1.0], [0.5, 1.0]], [[0.5, 1.0], [0.5, 1.0]]]
75+
), # dims: chain, draw, channel
76+
},
77+
"adstock_params": {
78+
"alpha": np.array(
79+
[[[0.5, 0.7], [0.5, 0.7]], [[0.5, 0.7], [0.5, 0.7]]]
80+
) # dims: chain, draw, channel
81+
},
82+
},
83+
None,
84+
{"channel_1": 54.78357587906867, "channel_2": 45.21642412093133},
85+
48.8,
86+
),
87+
# set x0 manually
88+
(
89+
100,
90+
None,
91+
np.array([50, 50]),
6792
{
6893
"saturation_params": {
6994
"lam": np.array(
@@ -91,6 +116,7 @@ def dummy_df():
91116
channel=["channel_1", "channel_2"],
92117
bound=["lower", "upper"],
93118
),
119+
None,
94120
{
95121
"saturation_params": {
96122
"lam": np.array(
@@ -121,6 +147,7 @@ def dummy_df():
121147
channel=["channel_1", "channel_2"],
122148
bound=["lower", "upper"],
123149
),
150+
None,
124151
{
125152
"saturation_params": {
126153
"lam": np.array(
@@ -142,11 +169,17 @@ def dummy_df():
142169
2.38e-10,
143170
),
144171
],
145-
ids=["default_minimizer_kwargs", "custom_minimizer_kwargs", "zero_total_budget"],
172+
ids=[
173+
"default_minimizer_kwargs",
174+
"manually_set_x0",
175+
"custom_minimizer_kwargs",
176+
"zero_total_budget",
177+
],
146178
)
147179
def test_allocate_budget(
148180
total_budget,
149181
budget_bounds,
182+
x0,
150183
parameters,
151184
minimize_kwargs,
152185
expected_optimal,
@@ -184,6 +217,7 @@ def test_allocate_budget(
184217
optimal_budgets, optimization_res = optimizer.allocate_budget(
185218
total_budget=total_budget,
186219
budget_bounds=budget_bounds,
220+
x0=x0,
187221
minimize_kwargs=minimize_kwargs,
188222
)
189223

0 commit comments

Comments
 (0)