Skip to content

Commit 0e861cb

Browse files
Fix ShiftedBetaGeoModel Cohort Bug (#2115)
* pred method bug comments * fix sbg pred bug --------- Co-authored-by: Juan Orduz <juanitorduz@gmail.com>
1 parent ffb3756 commit 0e861cb

File tree

6 files changed

+4757
-1121
lines changed

6 files changed

+4757
-1121
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,6 @@ dmypy.json
147147
# Gallery images
148148
docs/source/gallery/images/
149149
docs/gettext/
150+
151+
# Sandbox folder
152+
sandbox/

docs/source/notebooks/clv/dev/sBG_cohort.ipynb

Lines changed: 4611 additions & 1118 deletions
Large diffs are not rendered by default.

pymc_marketing/clv/distributions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -928,9 +928,9 @@ class ShiftedBetaGeometric(Discrete):
928928
Hardie and Fader describe this distribution with the following PMF and survival functions in [1]_:
929929
930930
.. math::
931-
\mathbb{P}(T=t|\alpha,\beta) = (\frac{B(\alpha+1,\beta+t-1)}{B(\alpha,\beta}),t=1,2,... \\
931+
\mathbb{P}(T=t|\alpha,\beta) = \frac{B(\alpha+1,\beta+t-1)}{B(\alpha,\beta)},t=1,2,... \\
932932
\begin{align}
933-
\mathbb{S}(t|\alpha,\beta) = (\frac{B(\alpha,\beta+t)}{B(\alpha,\beta}),t=1,2,... \\
933+
\mathbb{S}(t|\alpha,\beta) = \frac{B(\alpha,\beta+t)}{B(\alpha,\beta)},t=1,2,... \\
934934
\end{align}
935935
936936
======== ===============================================

pymc_marketing/clv/models/shifted_beta_geo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ def _extract_predictive_variables(
421421

422422
# Map cohort indices for each customer
423423
pred_cohort_idx = pd.Categorical(
424-
customer_cohort_map.values, categories=self.cohorts
424+
customer_cohort_map.values, categories=cohorts_present
425425
).codes
426426

427427
# Reconstruct customer-level parameters
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import numpy as np
2+
import pandas as pd
3+
4+
# Summary data from original research paper for ShiftedBetaGeoModel
5+
# From Table 1 in https://faculty.wharton.upenn.edu/wp-content/uploads/2012/04/Fader_hardie_jim_07.pdf
6+
# Percentage of customers still alive per period for the two cohorts
# ("highend" and "regular"); row 0 is the 100% starting point.
sbg_research_data = pd.DataFrame(
    {
        "highend": [
            100.0, 86.9, 74.3, 65.3, 59.3, 55.1, 51.7,
            49.1, 46.8, 44.5, 42.7, 40.9, 39.4,
        ],
        "regular": [
            100.0, 63.1, 46.8, 38.2, 32.6, 28.9, 26.2,
            24.1, 22.3, 20.7, 19.4, 18.3, 17.3,
        ],
    }
)
40+
41+
42+
def generate_sbg_data(
    n_customers: int,
    n_time_periods: int,
    survival_data: pd.DataFrame,
) -> pd.DataFrame:
    """
    Generate individual-level customer churn data from aggregate percentage alive data.

    Parameters
    ----------
    n_customers : int
        Number of customers to simulate for each cohort.
    n_time_periods : int
        Number of leading rows (time periods) of ``survival_data`` to use.
        The same value is written to the ``T`` column of every customer.
    survival_data : pd.DataFrame
        DataFrame with one column per cohort and values as percentage alive
        (0-100) in each successive time period.

    Returns
    -------
    pd.DataFrame
        DataFrame with columns: customer_id, recency, T, cohort.
        Contains ``n_customers`` rows for every cohort (column) in the
        survival data; ``customer_id`` runs from 1 over the combined rows.
    """

    def _individual_data_from_percentage_alive(percentage_alive, initial_customers):
        """Convert percentage alive data to individual churn times."""
        # NOTE(review): the integer cast truncates, so a product such as
        # 630.9999999 becomes 630 rather than 631 — confirm this is intended.
        n_alive = np.asarray(percentage_alive / 100 * initial_customers, dtype=int)

        died_at = np.zeros((initial_customers,), dtype=int)
        counter = 0
        # Customers lost between period t-1 and period t are assigned churn time t.
        for t, diff in enumerate((n_alive[:-1] - n_alive[1:]), start=1):
            died_at[counter : counter + diff] = t
            counter += diff

        # Everyone still alive is censored one step after the last observed
        # transition. len(n_alive) equals "last t + 1" whenever the loop ran,
        # and stays defined when there are fewer than two periods (the loop
        # variable would be unbound in that case).
        censoring_t = len(n_alive)
        died_at[counter:] = censoring_t

        return died_at

    # Truncate data to requested number of time periods
    truncated_df = survival_data[:n_time_periods]

    # Generate individual churn data for each cohort
    # NOTE(review): T is taken from n_time_periods, not from the number of rows
    # actually available — they differ if n_time_periods exceeds the data length.
    datasets = [
        pd.DataFrame(
            {
                "recency": _individual_data_from_percentage_alive(
                    truncated_df[cohort_name], n_customers
                ),
                "T": n_time_periods,
                "cohort": cohort_name,
            }
        )
        for cohort_name in truncated_df.columns
    ]

    # Combine all cohorts into a single dataset
    combined_dataset = pd.concat(datasets, ignore_index=True)

    # Create customer_id column from index
    combined_dataset["customer_id"] = combined_dataset.index + 1

    return combined_dataset[["customer_id", "recency", "T", "cohort"]]

tests/clv/models/test_shifted_beta_geo.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,6 +928,40 @@ def test_covariate_cols_only_in_config(self, covariate_test_data):
928928
model.build_model()
929929
assert "dropout_covariate" in model.model.coords
930930

931+
def test_predictions_with_covariates_subset_cohorts(self):
    """Predict on data whose cohorts are a strict subset of the fitted ones.

    The model is fitted on cohorts A, B and C with one dropout covariate,
    then asked for predictions on customers from B and C only. Both
    prediction methods must return one value per customer on the last axis.
    """
    # Fit on three cohorts of ten customers each.
    fit_frame = pd.DataFrame(
        {
            "customer_id": range(30),
            "recency": [3, 4, 5] * 10,
            "T": [5] * 30,
            "cohort": ["A"] * 10 + ["B"] * 10 + ["C"] * 10,
            "channel": [0, 1, 0, 1, 0] * 6,
        }
    )
    config = {"dropout_covariate_cols": ["channel"]}
    fitted = ShiftedBetaGeoModel(data=fit_frame, model_config=config)
    fitted.fit(method="map", maxeval=10)

    # Cohort "A" is deliberately absent, so the cohorts present do not
    # start at the first fitted cohort.
    subset_frame = pd.DataFrame(
        {
            "customer_id": [100, 101, 102],
            "T": [3, 3, 3],
            "cohort": ["B", "C", "C"],
            "channel": [1, 0, 1],
        }
    )
    n_pred_customers = 3

    # Neither call should raise IndexError.
    alive = fitted.expected_probability_alive(data=subset_frame, future_t=1)
    assert alive.shape[-1] == n_pred_customers

    retained = fitted.expected_retention_rate(data=subset_frame, future_t=1)
    assert retained.shape[-1] == n_pred_customers
964+
931965

932966
class TestShiftedBetaGeoModelIndividual:
933967
@classmethod

0 commit comments

Comments
 (0)