Skip to content

Commit ee9ef76

Browse files
committed
docs
1 parent a61e1e7 commit ee9ef76

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

examples/tutorials/plot_tuto_diffusion_models.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
Tutorial for imputers based on diffusion models
44
===============================================
55
6-
In this tutorial, we show how to use :class:`~qolmat.imputations.diffusions.diffusions.TabDDPM`
7-
and :class:`~qolmat.imputations.diffusions.diffusions.TabDDPMTS` classes.
6+
In this tutorial, we show how to use :class:`~qolmat.imputations.diffusions.ddpms.TabDDPM`
7+
and :class:`~qolmat.imputations.diffusions.ddpms.TsDDPM` classes.
88
"""
99

1010
# %%
@@ -15,7 +15,7 @@
1515
from qolmat.benchmark import comparator, missing_patterns
1616

1717
from qolmat.imputations.imputers_pytorch import ImputerDiffusion
18-
from qolmat.imputations.diffusions.diffusions import TabDDPM, TabDDPMTS
18+
from qolmat.imputations.diffusions.ddpms import TabDDPM, TsDDPM
1919

2020
# %%
2121
# 1. Data
@@ -34,8 +34,8 @@
3434
# 2. Hyperparameters for the wapper ImputerDiffusion
3535
# ---------------------------------------------------------------
3636
# We use the wapper :class:`~qolmat.imputations.imputers_pytorch.ImputerDiffusion` for our
37-
# diffusion models (e.g., :class:`~qolmat.imputations.diffusions.diffusions.TabDDPM`,
38-
# :class:`~qolmat.imputations.diffusions.diffusions.TabDDPMTS`). The most important hyperparameter
37+
# diffusion models (e.g., :class:`~qolmat.imputations.diffusions.ddpms.TabDDPM`,
38+
# :class:`~qolmat.imputations.diffusions.ddpms.TsDDPM`). The most important hyperparameter
3939
# is ``model`` where we select a diffusion base model for the task of imputation
4040
# (e.g., ``model=TabDDPM()``).
4141
# Other hyperparams are for training the selected diffusion model.
@@ -62,7 +62,7 @@
6262
df_data_valid = df_data.iloc[:5000]
6363

6464
tabddpm = ImputerDiffusion(
65-
model=TabDDPM(), epochs=50, batch_size=100, x_valid=df_data_valid, print_valid=True
65+
model=TabDDPM(), epochs=10, batch_size=100, x_valid=df_data_valid, print_valid=True
6666
)
6767
tabddpm = tabddpm.fit(df_data)
6868

@@ -115,7 +115,7 @@
115115
# %%
116116
# 3. Hyperparameters for TabDDPM
117117
# ---------------------------------------------------------------
118-
# :class:`~qolmat.imputations.diffusions.diffusions.TabDDPM` is a diffusion model based on
118+
# :class:`~qolmat.imputations.diffusions.ddpms.TabDDPM` is a diffusion model based on
119119
# Denoising Diffusion Probabilistic Models [1] for imputing tabular data. Several important
120120
# hyperparameters are
121121
#
@@ -164,10 +164,10 @@
164164
results.groupby(axis=0, level=0).mean().groupby(axis=0, level=0).mean()
165165

166166
# %%
167-
# 4. Hyperparameters for TabDDPMTS
167+
# 4. Hyperparameters for TsDDPM
168168
# ---------------------------------------------------------------
169-
# :class:`~qolmat.imputations.diffusions.diffusions.TabDDPMTS` is built on top of
170-
# :class:`~qolmat.imputations.diffusions.diffusions.TabDDPM` to capture time-based relationships
169+
# :class:`~qolmat.imputations.diffusions.ddpms.TsDDPM` is built on top of
170+
# :class:`~qolmat.imputations.diffusions.ddpms.TabDDPM` to capture time-based relationships
171171
# between data points in a dataset.
172172
#
173173
# Two important hyperparameters for processing time-series data are ``index_datetime``
@@ -183,7 +183,7 @@
183183
# `link <https://pandas.pydata.org/pandas-docs/
184184
# stable/user_guide/timeseries.html#offset-aliases>`_
185185
#
186-
# For TabDDPMTS, we have two options for splitting data:
186+
# For TsDDPM, we have two options for splitting data:
187187
#
188188
# * ``is_rolling=False`` (default value): the data is splited by using
189189
# pandas.DataFrame.resample(rule=freq_str). There is no duplication of row between chunks,
@@ -196,8 +196,8 @@
196196

197197
dict_imputers = {
198198
"tabddpm": ImputerDiffusion(model=TabDDPM(num_sampling=5), epochs=10, batch_size=100),
199-
"tabddpmts": ImputerDiffusion(
200-
model=TabDDPMTS(num_sampling=5, is_rolling=True),
199+
"TsDDPM": ImputerDiffusion(
200+
model=TsDDPM(num_sampling=5, is_rolling=True),
201201
epochs=10,
202202
batch_size=100,
203203
index_datetime="datetime",

0 commit comments

Comments
 (0)