Skip to content

Commit b6c98f1

Browse files
committed
fix: reduce computation time
1 parent 11fe4f6 commit b6c98f1

File tree

1 file changed

+24
-21
lines changed

1 file changed

+24
-21
lines changed

examples/tutorials/plot_tuto_diffusion_models.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99

1010
# %%
11+
import pandas as pd
1112
import numpy as np
1213
import matplotlib.pyplot as plt
1314

@@ -18,14 +19,20 @@
1819
from qolmat.imputations.diffusions.ddpms import TabDDPM, TsDDPM
1920

2021
# %%
21-
# 1. Data
22+
# 1. Time-series data
2223
# ---------------------------------------------------------------
2324
# We use the public Beijing Multi-Site Air-Quality Data Set.
2425
# It consists of hourly air pollutant data from 12 Chinese nationally-controlled air-quality
2526
# monitoring sites. The original data from which the features were extracted comes from
26-
# https://archive.ics.uci.edu/static/public/501/beijing+multi+site+air+quality+data.zip
27+
# https://archive.ics.uci.edu/static/public/501/beijing+multi+site+air+quality+data.zip.
28+
# For this tutorial, we only use a small subset of this data:
29+
# 1000 rows and 2 features (TEMP, PRES).
2730

2831
df_data = data.get_data_corrupted("Beijing")
32+
df_data = df_data[["TEMP", "PRES"]].iloc[:1000]
33+
df_data.index = df_data.index.set_levels(
34+
[df_data.index.levels[0], pd.to_datetime(df_data.index.levels[1])]
35+
)
2936

3037
print("Number of nan at each column:")
3138
print(df_data.isna().sum())
@@ -59,7 +66,7 @@
5966
# * ``print_valid``: a boolean to display/hide the training progress (including epoch_loss,
6067
# remaining training duration and performance scores computed by the metrics above).
6168

62-
df_data_valid = df_data.iloc[:5000]
69+
df_data_valid = df_data.iloc[:500]
6370

6471
tabddpm = ImputerDiffusion(
6572
model=TabDDPM(), epochs=10, batch_size=100, x_valid=df_data_valid, print_valid=True
@@ -71,7 +78,6 @@
7178

7279
print(tabddpm.get_summary_architecture())
7380

74-
7581
# %%
7682
# We also get the summary of the training progress with ``get_summary_training()``
7783

@@ -144,22 +150,20 @@
144150
# * ``dim_embedding``: dimension of hidden layers in residual blocks (``int = 128``)
145151
#
146152
# Let's see an example below. We can observe that a large ``num_sampling`` generally improves
147-
# reconstruction errors (mae, wmape) but increases distribution distance (KL_columnwise,
148-
# wasserstein_columnwise).
153+
# reconstruction errors (mae) but increases distribution distance (KL_columnwise).
149154

150155
dict_imputers = {
151156
"num_sampling=5": ImputerDiffusion(model=TabDDPM(num_sampling=5), epochs=10, batch_size=100),
152-
"num_sampling=20": ImputerDiffusion(model=TabDDPM(num_sampling=10), epochs=10, batch_size=100),
157+
"num_sampling=10": ImputerDiffusion(model=TabDDPM(num_sampling=10), epochs=10, batch_size=100),
153158
}
154159

155160
comparison = comparator.Comparator(
156161
dict_imputers,
157162
selected_columns=df_data.columns,
158-
generator_holes=missing_patterns.UniformHoleGenerator(n_splits=4),
159-
metrics=["mae", "wmape", "KL_columnwise", "wasserstein_columnwise"],
160-
max_evals=10,
163+
generator_holes=missing_patterns.UniformHoleGenerator(n_splits=2),
164+
metrics=["mae", "KL_columnwise"],
161165
)
162-
results = comparison.compare(df_data.iloc[:5000])
166+
results = comparison.compare(df_data)
163167

164168
results.groupby(axis=0, level=0).mean().groupby(axis=0, level=0).mean()
165169

@@ -174,7 +178,7 @@
174178
# and ``freq_str``.
175179
# E.g., ``ImputerDiffusion(model=TabDDPM(), index_datetime='datetime', freq_str='1D')``,
176180
#
177-
# * ``index_datetime``: the column name of datetime in index.
181+
# * ``index_datetime``: the name of the datetime level in the index. Its values must be pandas datetime objects (e.g. converted with ``pd.to_datetime``).
178182
#
179183
# * ``freq_str``: the time-series frequency for splitting data into a list of chunks (each chunk
180184
# has the same number of rows). These chunks are then fetched in batches.
@@ -196,23 +200,22 @@
196200

197201
dict_imputers = {
198202
"tabddpm": ImputerDiffusion(model=TabDDPM(num_sampling=5), epochs=10, batch_size=100),
199-
"TsDDPM": ImputerDiffusion(
200-
model=TsDDPM(num_sampling=5, is_rolling=True),
203+
"tsddpm": ImputerDiffusion(
204+
model=TsDDPM(num_sampling=5, is_rolling=False),
201205
epochs=10,
202-
batch_size=100,
203-
index_datetime="datetime",
204-
freq_str="10D",
206+
batch_size=5,
207+
index_datetime="date",
208+
freq_str="5D",
205209
),
206210
}
207211

208212
comparison = comparator.Comparator(
209213
dict_imputers,
210214
selected_columns=df_data.columns,
211-
generator_holes=missing_patterns.UniformHoleGenerator(n_splits=4),
212-
metrics=["mae", "wmape", "KL_columnwise", "wasserstein_columnwise"],
213-
max_evals=10,
215+
generator_holes=missing_patterns.UniformHoleGenerator(n_splits=2),
216+
metrics=["mae", "KL_columnwise"],
214217
)
215-
results = comparison.compare(df_data.iloc[:5000])
218+
results = comparison.compare(df_data)
216219

217220
results.groupby(axis=0, level=0).mean().groupby(axis=0, level=0).mean()
218221

0 commit comments

Comments
 (0)