Skip to content

Commit e77252d

Browse files
authored
Add LinearRegression wrapper (#1713)
* implement the transformations * wrapper around linear regression * expose via the mmm module * rename variable to beta * test for linear regression * change name to NoSaturation
1 parent 462632e commit e77252d

File tree

5 files changed

+199
-0
lines changed

5 files changed

+199
-0
lines changed

pymc_marketing/mmm/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
AdstockTransformation,
2020
DelayedAdstock,
2121
GeometricAdstock,
22+
NoAdstock,
2223
WeibullCDFAdstock,
2324
WeibullPDFAdstock,
2425
adstock_from_dict,
@@ -29,6 +30,7 @@
2930
InverseScaledLogisticSaturation,
3031
LogisticSaturation,
3132
MichaelisMentenSaturation,
33+
NoSaturation,
3234
RootSaturation,
3335
SaturationTransformation,
3436
TanhSaturation,
@@ -48,6 +50,7 @@
4850
create_eta_prior,
4951
create_m_and_L_recommendations,
5052
)
53+
from pymc_marketing.mmm.linear_regression import FancyLinearRegression
5154
from pymc_marketing.mmm.linear_trend import LinearTrend
5255
from pymc_marketing.mmm.media_transformation import (
5356
MediaConfig,
@@ -68,6 +71,7 @@
6871
"BaseValidateMMM",
6972
"CovFunc",
7073
"DelayedAdstock",
74+
"FancyLinearRegression",
7175
"GeometricAdstock",
7276
"HSGPPeriodic",
7377
"HillSaturation",
@@ -81,6 +85,8 @@
8185
"MediaTransformation",
8286
"MichaelisMentenSaturation",
8387
"MonthlyFourier",
88+
"NoAdstock",
89+
"NoSaturation",
8490
"PeriodicCovFunc",
8591
"RootSaturation",
8692
"SaturationTransformation",

pymc_marketing/mmm/components/adstock.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,22 @@ def function(self, x, lam, k):
329329
}
330330

331331

332+
class NoAdstock(AdstockTransformation):
333+
"""Wrapper around no adstock transformation."""
334+
335+
lookup_name: str = "no_adstock"
336+
337+
def function(self, x):
338+
"""No adstock function."""
339+
return x
340+
341+
default_priors = {}
342+
343+
def update_priors(self, priors):
344+
"""Update priors for the no adstock transformation."""
345+
return
346+
347+
332348
def adstock_from_dict(data: dict) -> AdstockTransformation:
333349
"""Create an adstock transformation from a dictionary."""
334350
data = data.copy()

pymc_marketing/mmm/components/saturation.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,37 @@ def function(self, x, alpha, beta):
459459
}
460460

461461

462+
class NoSaturation(SaturationTransformation):
463+
"""Wrapper around linear saturation function.
464+
465+
For more information, see :func:`pymc_marketing.mmm.transformers.linear_saturation`.
466+
467+
.. plot::
468+
:context: close-figs
469+
470+
import matplotlib.pyplot as plt
471+
import numpy as np
472+
from pymc_marketing.mmm import NoSaturation
473+
474+
rng = np.random.default_rng(0)
475+
476+
saturation = NoSaturation()
477+
prior = saturation.sample_prior(random_seed=rng)
478+
curve = saturation.sample_curve(prior)
479+
saturation.plot_curve(curve, random_seed=rng)
480+
plt.show()
481+
482+
"""
483+
484+
lookup_name = "linear"
485+
486+
def function(self, x, beta):
487+
"""Linear saturation function."""
488+
return pt.as_tensor_variable(beta * x)
489+
490+
default_priors = {"beta": Prior("HalfNormal", sigma=1)}
491+
492+
462493
def saturation_from_dict(data: dict) -> SaturationTransformation:
463494
"""Get a saturation function from a dictionary."""
464495
data = data.copy()
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright 2022 - 2025 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Linear regression model implemented using the MMM class."""
15+
16+
from pymc_marketing.mmm.components.adstock import NoAdstock
17+
from pymc_marketing.mmm.components.saturation import NoSaturation
18+
from pymc_marketing.mmm.mmm import MMM
19+
20+
21+
def FancyLinearRegression(
22+
**mmm_kwargs,
23+
) -> MMM:
24+
"""Create wrapper around MMM for a linear regression model.
25+
26+
See :func:`pymc_marketing.mmm.mmm.MMM` for more details.
27+
28+
Parameters
29+
----------
30+
mmm_kwargs
31+
Keyword arguments to pass to the MMM constructor.
32+
33+
Returns
34+
-------
35+
MMM
36+
An instance of the MMM class with linear regression settings.
37+
38+
Examples
39+
--------
40+
Load a saved MMM model with linear regression settings:
41+
42+
.. code-block:: python
43+
44+
from pymc_marketing.mmm import MMM
45+
46+
linear_regression = MMM.load("linear_regression_model.nc")
47+
48+
"""
49+
return MMM(
50+
adstock=NoAdstock(l_max=1),
51+
saturation=NoSaturation(),
52+
**mmm_kwargs,
53+
)
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2022 - 2025 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pandas as pd
16+
import pytest
17+
18+
from pymc_marketing.mmm.linear_regression import FancyLinearRegression
19+
from pymc_marketing.mmm.mmm import MMM
20+
21+
seed: int = sum(map(ord, "pymc_marketing"))
22+
rng: np.random.Generator = np.random.default_rng(seed=seed)
23+
24+
25+
@pytest.fixture(scope="module")
26+
def generate_data():
27+
def _generate_data(date_data: pd.DatetimeIndex) -> pd.DataFrame:
28+
n: int = date_data.size
29+
30+
return pd.DataFrame(
31+
data={
32+
"date": date_data,
33+
"channel_1": rng.integers(low=0, high=400, size=n),
34+
"channel_2": rng.integers(low=0, high=50, size=n),
35+
"control_1": rng.gamma(shape=1000, scale=500, size=n),
36+
"control_2": rng.gamma(shape=100, scale=5, size=n),
37+
"other_column_1": rng.integers(low=0, high=100, size=n),
38+
"other_column_2": rng.normal(loc=0, scale=1, size=n),
39+
}
40+
)
41+
42+
return _generate_data
43+
44+
45+
@pytest.fixture(scope="module")
46+
def toy_X(generate_data) -> pd.DataFrame:
47+
date_data: pd.DatetimeIndex = pd.date_range(
48+
start="2019-06-01", end="2021-12-31", freq="W-MON"
49+
)
50+
51+
return generate_data(date_data)
52+
53+
54+
@pytest.fixture(scope="module")
55+
def toy_y(toy_X: pd.DataFrame) -> pd.Series:
56+
return pd.Series(data=rng.integers(low=0, high=100, size=toy_X.shape[0]))
57+
58+
59+
@pytest.fixture(scope="module")
60+
def linear_regression() -> MMM:
61+
return FancyLinearRegression(
62+
date_column="date",
63+
channel_columns=["channel_1", "channel_2"],
64+
control_columns=["control_1", "control_2"],
65+
)
66+
67+
68+
def test_fancy_linear_regression(
69+
linear_regression: MMM,
70+
toy_X: pd.DataFrame,
71+
toy_y: pd.Series,
72+
mock_pymc_sample,
73+
) -> None:
74+
"""Test that FancyLinearRegression returns an instance of MMM."""
75+
assert isinstance(linear_regression, MMM)
76+
77+
linear_regression.fit(toy_X, toy_y)
78+
79+
assert set(linear_regression.fit_result.variables) == {
80+
"chain",
81+
"draw",
82+
"date",
83+
"channel",
84+
"channel_contributions",
85+
"control",
86+
"control_contributions",
87+
"intercept",
88+
"mu",
89+
"gamma_control",
90+
"saturation_beta",
91+
"total_contributions",
92+
"y_sigma",
93+
}

0 commit comments

Comments
 (0)