Skip to content

Commit 88a4c76

Browse files
authored
URL pathlike for consistent paths.data_dir API (#1797)
* url pathlike implementation * remove previous way in docstrings * remove previous way in docs
1 parent 893a7f2 commit 88a4c76

File tree

8 files changed

+132
-23
lines changed

8 files changed

+132
-23
lines changed

README.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,10 @@ from pymc_marketing.mmm import (
8585
LogisticSaturation,
8686
MMM,
8787
)
88+
from pymc_marketing.paths import data_dir
8889

89-
data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/mmm_example.csv"
90-
data = pd.read_csv(data_url, parse_dates=["date_week"])
90+
file_path = data_dir / "mmm_example.csv"
91+
data = pd.read_csv(file_path, parse_dates=["date_week"])
9192

9293
mmm = MMM(
9394
adstock=GeometricAdstock(l_max=8),
@@ -168,9 +169,10 @@ import matplotlib.pyplot as plt
168169
import pandas as pd
169170
import seaborn as sns
170171
from pymc_marketing import clv
172+
from pymc_marketing.paths import data_dir
171173

172-
data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/clv_quickstart.csv"
173-
data = pd.read_csv(data_url)
174+
file_path = data_dir / "clv_quickstart.csv"
175+
data = pd.read_csv(file_path)
174176
data["customer_id"] = data.index
175177

176178
beta_geo_model = clv.BetaGeoModel(data=data)

docs/source/getting_started/quickstart/clv/index.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@ import matplotlib.pyplot as plt
77
import pandas as pd
88
import seaborn as sns
99
from pymc_marketing import clv
10+
from pymc_marketing.paths import data_dir
1011

11-
data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/clv_quickstart.csv"
12-
data = pd.read_csv(data_url)
12+
file_path = data_dir / "clv_quickstart.csv"
13+
data = pd.read_csv(file_path)
1314
data["customer_id"] = data.index
1415

1516
beta_geo_model = clv.BetaGeoModel(data=data)

docs/source/getting_started/quickstart/mmm/index.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@ from pymc_marketing.mmm import (
88
LogisticSaturation,
99
MMM,
1010
)
11+
from pymc_marketing.paths import data_dir
1112

12-
data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/mmm_example.csv"
13-
data = pd.read_csv(data_url, parse_dates=["date_week"])
13+
file_path = data_dir / "mmm_example.csv"
14+
data = pd.read_csv(file_path, parse_dates=["date_week"])
1415

1516
mmm = MMM(
1617
adstock=GeometricAdstock(l_max=8),

pymc_marketing/mlflow.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,15 @@
7777
LogisticSaturation,
7878
MMM,
7979
)
80+
from pymc_marketing.paths import data_dir
8081
import pymc_marketing.mlflow
8182
8283
pymc_marketing.mlflow.autolog(log_mmm=True)
8384
8485
# Usual PyMC-Marketing model code
8586
86-
data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/mmm_example.csv"
87-
data = pd.read_csv(data_url, parse_dates=["date_week"])
87+
file_path = data_dir / "mmm_example.csv"
88+
data = pd.read_csv(file_path, parse_dates=["date_week"])
8889
8990
X = data.drop("y", axis=1)
9091
y = data["y"]
@@ -122,15 +123,16 @@
122123
import mlflow
123124
124125
from pymc_marketing.clv import BetaGeoModel
126+
from pymc_marketing.paths import data_dir
125127
126128
import pymc_marketing.mlflow
127129
128130
pymc_marketing.mlflow.autolog(log_clv=True)
129131
130132
mlflow.set_experiment("CLV Experiment")
131133
132-
data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/clv_quickstart.csv"
133-
data = pd.read_csv(data_url)
134+
file_path = data_dir / "clv_quickstart.csv"
135+
data = pd.read_csv(file_path)
134136
data["customer_id"] = data.index
135137
136138
model = BetaGeoModel(data=data)
@@ -832,15 +834,16 @@ def log_mmm(
832834
LogisticSaturation,
833835
MMM,
834836
)
837+
from pymc_marketing.paths import data_dir
835838
import pymc_marketing.mlflow
836839
from pymc_marketing.mlflow import log_mmm
837840
838841
pymc_marketing.mlflow.autolog(log_mmm=True)
839842
840843
# Usual PyMC-Marketing model code
841844
842-
data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/mmm_example.csv"
843-
data = pd.read_csv(data_url, parse_dates=["date_week"])
845+
file_path = data_dir / "mmm_example.csv"
846+
data = pd.read_csv(file_path, parse_dates=["date_week"])
844847
845848
X = data.drop("y", axis=1)
846849
y = data["y"]
@@ -1117,14 +1120,15 @@ def autolog(
11171120
LogisticSaturation,
11181121
MMM,
11191122
)
1123+
from pymc_marketing.paths import data_dir
11201124
import pymc_marketing.mlflow
11211125
11221126
pymc_marketing.mlflow.autolog(log_mmm=True)
11231127
11241128
# Usual PyMC-Marketing model code
11251129
1126-
data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/mmm_example.csv"
1127-
data = pd.read_csv(data_url, parse_dates=["date_week"])
1130+
file_path = data_dir / "mmm_example.csv"
1131+
data = pd.read_csv(file_path, parse_dates=["date_week"])
11281132
11291133
X = data.drop("y", axis=1)
11301134
y = data["y"]
@@ -1163,15 +1167,16 @@ def autolog(
11631167
import mlflow
11641168
11651169
from pymc_marketing.clv import BetaGeoModel
1170+
from pymc_marketing.paths import data_dir
11661171
11671172
import pymc_marketing.mlflow
11681173
11691174
pymc_marketing.mlflow.autolog(log_clv=True)
11701175
11711176
mlflow.set_experiment("CLV Experiment")
11721177
1173-
data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/clv_quickstart.csv"
1174-
data = pd.read_csv(data_url)
1178+
file_path = data_dir / "clv_quickstart.csv"
1179+
data = pd.read_csv(file_path)
11751180
data["customer_id"] = data.index
11761181
11771182
model = BetaGeoModel(data=data)

pymc_marketing/mmm/evaluation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,11 +195,12 @@ def compute_summary_metrics(
195195
LogisticSaturation,
196196
MMM,
197197
)
198+
from pymc_marketing.paths import data_dir
198199
from pymc_marketing.mmm.evaluation import compute_summary_metrics
199200
200201
# Usual PyMC-Marketing demo model code
201-
data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/mmm_example.csv"
202-
data = pd.read_csv(data_url, parse_dates=["date_week"])
202+
file_path = data_dir / "mmm_example.csv"
203+
data = pd.read_csv(file_path, parse_dates=["date_week"])
203204
204205
X = data.drop("y", axis=1)
205206
y = data["y"]

pymc_marketing/mmm/mmm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -940,9 +940,10 @@ class MMM(
940940
LogisticSaturation,
941941
MMM,
942942
)
943+
from pymc_marketing.paths import data_dir
943944
944-
data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/mmm_example.csv"
945-
data = pd.read_csv(data_url, parse_dates=["date_week"])
945+
file_path = data_dir / "mmm_example.csv"
946+
data = pd.read_csv(file_path, parse_dates=["date_week"])
946947
947948
mmm = MMM(
948949
date_column="date_week",

pymc_marketing/paths.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,80 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
"""Paths for the project."""
14+
"""Paths for the project.
15+
16+
The `data_dir` will be the local directory where the data files are stored in forked
17+
repositories. If the directory does not exist, `data_dir` will instead be a URLPath object pointing
18+
to the data directory on the main branch of the pymc-labs/pymc-marketing repository.
19+
20+
"""
21+
22+
from __future__ import annotations
23+
24+
from dataclasses import dataclass
25+
from os import PathLike
1526

1627
from pyprojroot import here
1728

1829
root = here()
1930
data_dir = root / "data"
31+
32+
33+
URL = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/{branch}/data"
34+
35+
36+
@dataclass
37+
class URLPath(PathLike):
38+
"""A class representing a URL path which can be used like a file path.
39+
40+
Parameters
41+
----------
42+
url : str
43+
The URL to the data directory or file.
44+
45+
"""
46+
47+
url: str
48+
49+
def __fspath__(self) -> str:
50+
"""Return the URL as a string when the object is used as a file path."""
51+
return self.url
52+
53+
def __truediv__(self, other: str) -> URLPath:
54+
"""Combine the URL with another path component."""
55+
return URLPath(f"{self.url}/{other}")
56+
57+
58+
def create_data_url(branch: str) -> URLPath:
59+
"""Create a URLPath object for the data directory on a specific branch.
60+
61+
Parameters
62+
----------
63+
branch : str
64+
The branch name to create the URL for.
65+
66+
Returns
67+
-------
68+
URLPath
69+
An object representing the URL path to the data directory on the specified branch.
70+
71+
Examples
72+
--------
73+
Read MMM data from the main branch:
74+
75+
.. code-block:: python
76+
77+
import pandas as pd
78+
79+
from pymc_marketing.paths import create_data_url
80+
81+
data_dir = create_data_url("main")
82+
file = data_dir / "mmm_example.csv"
83+
df = pd.read_csv(file)
84+
85+
"""
86+
return URLPath(URL.format(branch=branch))
87+
88+
89+
if not data_dir.exists():
90+
data_dir = create_data_url("main")

tests/test_paths.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from os import fspath
1415
from pathlib import Path
1516

17+
import pytest
1618
from pyprojroot import here
1719

1820
from pymc_marketing import paths
@@ -37,3 +39,28 @@ def test_paths_are_absolute() -> None:
3739
"""Test that all defined paths are absolute."""
3840
assert paths.root.is_absolute()
3941
assert paths.data_dir.is_absolute()
42+
43+
44+
@pytest.fixture(scope="module")
45+
def data_url() -> paths.URLPath:
46+
return paths.create_data_url("main")
47+
48+
49+
def test_create_data_url(data_url) -> None:
50+
assert (
51+
data_url.url
52+
== "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data"
53+
)
54+
55+
56+
def test_url_path_fspath(data_url) -> None:
57+
"""Test the __fspath__ method of URLPath."""
58+
assert fspath(data_url) == data_url.url
59+
60+
61+
def test_url_path_truediv(data_url) -> None:
62+
"""Test the __truediv__ method of URLPath."""
63+
new_path = data_url / "new_file.csv"
64+
expected_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/new_file.csv"
65+
assert new_path.url == expected_url
66+
assert isinstance(new_path, paths.URLPath)

0 commit comments

Comments
 (0)