Skip to content

Commit 5ee48e6

Browse files
authored
Linear regression component tests (#1715)
* change the lookup name * add additional tests for linear regression components
1 parent e77252d commit 5ee48e6

File tree

4 files changed

+10
-2
lines changed

4 files changed

+10
-2
lines changed

pymc_marketing/mmm/components/adstock.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def function(self, x, alpha):
5555
from __future__ import annotations
5656

5757
import numpy as np
58+
import pytensor.tensor as pt
5859
import xarray as xr
5960
from pydantic import Field, validate_call
6061

@@ -336,7 +337,7 @@ class NoAdstock(AdstockTransformation):
336337

337338
def function(self, x):
338339
"""No adstock function."""
339-
return x
340+
return pt.as_tensor_variable(x)
340341

341342
default_priors = {}
342343

pymc_marketing/mmm/components/saturation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ class NoSaturation(SaturationTransformation):
481481
482482
"""
483483

484-
lookup_name = "linear"
484+
lookup_name = "no_saturation"
485485

486486
def function(self, x, beta):
487487
"""Linear saturation function."""

tests/mmm/components/test_adstock.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
AdstockTransformation,
3030
DelayedAdstock,
3131
GeometricAdstock,
32+
NoAdstock,
3233
WeibullCDFAdstock,
3334
WeibullPDFAdstock,
3435
adstock_from_dict,
@@ -45,6 +46,7 @@ def adstocks() -> list:
4546
GeometricAdstock(l_max=10),
4647
WeibullPDFAdstock(l_max=10),
4748
WeibullCDFAdstock(l_max=10),
49+
NoAdstock(l_max=1),
4850
]
4951

5052
return [
@@ -101,6 +103,9 @@ def test_adstock_no_negative_lmax():
101103
adstocks(),
102104
)
103105
def test_adstock_sample_curve(adstock) -> None:
106+
if adstock.lookup_name == "no_adstock":
107+
raise pytest.skip(reason="NoAdstock has no parameters to sample.")
108+
104109
prior = adstock.sample_prior()
105110
assert isinstance(prior, xr.Dataset)
106111
curve = adstock.sample_curve(prior)

tests/mmm/components/test_saturation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
InverseScaledLogisticSaturation,
3232
LogisticSaturation,
3333
MichaelisMentenSaturation,
34+
NoSaturation,
3435
RootSaturation,
3536
SaturationTransformation,
3637
TanhSaturation,
@@ -56,6 +57,7 @@ def saturation_functions():
5657
HillSaturation(),
5758
HillSaturationSigmoid(),
5859
RootSaturation(),
60+
NoSaturation(),
5961
]
6062
return [
6163
pytest.param(transformation, id=transformation.lookup_name)

0 commit comments

Comments
 (0)