|
3 | 3 | Tutorial for imputers based on diffusion models |
4 | 4 | =============================================== |
5 | 5 |
|
6 | | -In this tutorial, we show how to use :class:`~qolmat.imputations.diffusions.diffusions.TabDDPM` |
7 | | -and :class:`~qolmat.imputations.diffusions.diffusions.TabDDPMTS` classes. |
| 6 | +In this tutorial, we show how to use :class:`~qolmat.imputations.diffusions.ddpms.TabDDPM` |
| 7 | +and :class:`~qolmat.imputations.diffusions.ddpms.TsDDPM` classes. |
8 | 8 | """ |
9 | 9 |
|
10 | 10 | # %% |
|
15 | 15 | from qolmat.benchmark import comparator, missing_patterns |
16 | 16 |
|
17 | 17 | from qolmat.imputations.imputers_pytorch import ImputerDiffusion |
18 | | -from qolmat.imputations.diffusions.diffusions import TabDDPM, TabDDPMTS |
| 18 | +from qolmat.imputations.diffusions.ddpms import TabDDPM, TsDDPM |
19 | 19 |
|
20 | 20 | # %% |
21 | 21 | # 1. Data |
|
34 | 34 | # 2. Hyperparameters for the wrapper ImputerDiffusion |
35 | 35 | # --------------------------------------------------------------- |
36 | 36 | # We use the wrapper :class:`~qolmat.imputations.imputers_pytorch.ImputerDiffusion` for our |
37 | | -# diffusion models (e.g., :class:`~qolmat.imputations.diffusions.diffusions.TabDDPM`, |
38 | | -# :class:`~qolmat.imputations.diffusions.diffusions.TabDDPMTS`). The most important hyperparameter |
| 37 | +# diffusion models (e.g., :class:`~qolmat.imputations.diffusions.ddpms.TabDDPM`, |
| 38 | +# :class:`~qolmat.imputations.diffusions.ddpms.TsDDPM`). The most important hyperparameter |
39 | 39 | # is ``model`` where we select a diffusion base model for the task of imputation |
40 | 40 | # (e.g., ``model=TabDDPM()``). |
41 | 41 | # Other hyperparams are for training the selected diffusion model. |
|
62 | 62 | df_data_valid = df_data.iloc[:5000] |
63 | 63 |
|
64 | 64 | tabddpm = ImputerDiffusion( |
65 | | - model=TabDDPM(), epochs=50, batch_size=100, x_valid=df_data_valid, print_valid=True |
| 65 | + model=TabDDPM(), epochs=10, batch_size=100, x_valid=df_data_valid, print_valid=True |
66 | 66 | ) |
67 | 67 | tabddpm = tabddpm.fit(df_data) |
68 | 68 |
|
|
115 | 115 | # %% |
116 | 116 | # 3. Hyperparameters for TabDDPM |
117 | 117 | # --------------------------------------------------------------- |
118 | | -# :class:`~qolmat.imputations.diffusions.diffusions.TabDDPM` is a diffusion model based on |
| 118 | +# :class:`~qolmat.imputations.diffusions.ddpms.TabDDPM` is a diffusion model based on |
119 | 119 | # Denoising Diffusion Probabilistic Models [1] for imputing tabular data. Several important |
120 | 120 | # hyperparameters are |
121 | 121 | # |
|
164 | 164 | results.groupby(axis=0, level=0).mean().groupby(axis=0, level=0).mean() |
165 | 165 |
|
166 | 166 | # %% |
167 | | -# 4. Hyperparameters for TabDDPMTS |
| 167 | +# 4. Hyperparameters for TsDDPM |
168 | 168 | # --------------------------------------------------------------- |
169 | | -# :class:`~qolmat.imputations.diffusions.diffusions.TabDDPMTS` is built on top of |
170 | | -# :class:`~qolmat.imputations.diffusions.diffusions.TabDDPM` to capture time-based relationships |
| 169 | +# :class:`~qolmat.imputations.diffusions.ddpms.TsDDPM` is built on top of |
| 170 | +# :class:`~qolmat.imputations.diffusions.ddpms.TabDDPM` to capture time-based relationships |
171 | 171 | # between data points in a dataset. |
172 | 172 | # |
173 | 173 | # Two important hyperparameters for processing time-series data are ``index_datetime`` |
|
183 | 183 | # `link <https://pandas.pydata.org/pandas-docs/ |
184 | 184 | # stable/user_guide/timeseries.html#offset-aliases>`_ |
185 | 185 | # |
186 | | -# For TabDDPMTS, we have two options for splitting data: |
| 186 | +# For TsDDPM, we have two options for splitting data: |
187 | 187 | # |
188 | 188 | # * ``is_rolling=False`` (default value): the data is split by using |
189 | 189 | # pandas.DataFrame.resample(rule=freq_str). There is no duplication of row between chunks, |
|
196 | 196 |
|
197 | 197 | dict_imputers = { |
198 | 198 | "tabddpm": ImputerDiffusion(model=TabDDPM(num_sampling=5), epochs=10, batch_size=100), |
199 | | - "tabddpmts": ImputerDiffusion( |
200 | | - model=TabDDPMTS(num_sampling=5, is_rolling=True), |
| 199 | + "TsDDPM": ImputerDiffusion( |
| 200 | + model=TsDDPM(num_sampling=5, is_rolling=True), |
201 | 201 | epochs=10, |
202 | 202 | batch_size=100, |
203 | 203 | index_datetime="datetime", |
|
0 commit comments