Skip to content

Commit 397d26f

Browse files
Julien RousselJulien Roussel
authored andcommitted
frechet distance refacto
1 parent 992246c commit 397d26f

File tree

6 files changed

+17
-35
lines changed

6 files changed

+17
-35
lines changed

examples/benchmark.md

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@ jupyter:
1616
**This notebook aims to present the Qolmat repo through an example of a multivariate time series.
1717
In Qolmat, a few data imputation methods are implemented as well as a way to evaluate their performance.**
1818

19-
```python
20-
21-
```
2219

2320
First, import some useful librairies
2421

@@ -36,26 +33,18 @@ from IPython.display import Image
3633
import pandas as pd
3734
from datetime import datetime
3835
import numpy as np
39-
import scipy
4036
import hyperopt as ho
41-
from hyperopt.pyll.base import Apply as hoApply
4237
np.random.seed(1234)
43-
import pprint
4438
from matplotlib import pyplot as plt
45-
import matplotlib.image as mpimg
4639
import matplotlib.ticker as plticker
4740

4841
tab10 = plt.get_cmap("tab10")
4942
plt.rcParams.update({'font.size': 18})
5043

51-
from typing import Optional
5244

5345
from sklearn.linear_model import LinearRegression
54-
from sklearn.ensemble import RandomForestRegressor, ExtraTreesRegressor, HistGradientBoostingRegressor
55-
5646

57-
import sys
58-
from qolmat.benchmark import comparator, missing_patterns, hyperparameters
47+
from qolmat.benchmark import comparator, missing_patterns
5948
from qolmat.imputations import imputers
6049
from qolmat.utils import data, utils, plot
6150

@@ -239,10 +228,6 @@ df_plot = data.add_datetime_features(df_plot, col_time="date")
239228
dfs_imputed = {name: imp.fit_transform(df_plot) for name, imp in dict_imputers.items()}
240229
```
241230

242-
```python tags=[]
243-
dfs_imputed["VAR_max"].groupby("station").min()
244-
```
245-
246231
```python tags=[]
247232
station = df_plot.index.get_level_values("station")[0]
248233
# station = "Huairou"

qolmat/benchmark/metrics.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -368,12 +368,12 @@ def total_variance_distance(
368368
pd.Series
369369
Total variance distance
370370
"""
371-
cols_categorical = utils._get_categorical_features(df1)
372371
return columnwise_metric(
373-
df1[cols_categorical],
374-
df2[cols_categorical],
375-
df_mask[cols_categorical],
372+
df1,
373+
df2,
374+
df_mask,
376375
_total_variance_distance_1D,
376+
type_cols="categorical",
377377
)
378378

379379

@@ -792,7 +792,7 @@ def frechet_distance(
792792
df1,
793793
df2,
794794
df_mask,
795-
frechet_distance,
795+
frechet_distance_base,
796796
min_n_rows=min_n_rows,
797797
type_cols="numerical",
798798
)
@@ -1003,10 +1003,12 @@ def pattern_based_weighted_mean_metric(
10031003
cols = df1.select_dtypes(exclude=["number"]).columns
10041004
else:
10051005
raise ValueError(f"Value {type_cols} is not valid for parameter `type_cols`!")
1006+
10061007
if np.any(df_mask & df1.isna()):
10071008
raise ValueError("The argument df1 has missing values on the mask!")
10081009
if np.any(df_mask & df2.isna()):
10091010
raise ValueError("The argument df2 has missing values on the mask!")
1011+
10101012
rows_mask = df_mask.any(axis=1)
10111013
scores = []
10121014
weights = []
@@ -1041,7 +1043,7 @@ def get_metric(name: str) -> Callable:
10411043
"KS_test": kolmogorov_smirnov_test,
10421044
"correlation_diff": mean_difference_correlation_matrix_numerical_features,
10431045
"energy": sum_energy_distances,
1044-
"frechet_single": partial(frechet_distance, method="single"),
1046+
"frechet": partial(frechet_distance, method="single"),
10451047
"frechet_pattern": partial(frechet_distance, method="pattern"),
10461048
"dist_corr_pattern": distance_anticorr_pattern,
10471049
}

qolmat/imputations/preprocessing.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,10 +314,13 @@ def make_pipeline_mixte_preprocessing(
314314

315315
ohe = OneHotEncoder(handle_unknown="ignore", use_cat_names=True)
316316
transformers += [("cat", ohe, selector(dtype_exclude=np.number))]
317-
col_transformer = ColumnTransformer(transformers=transformers).set_output(transform="pandas")
317+
col_transformer = ColumnTransformer(transformers=transformers, remainder="passthrough")
318+
col_transformer = col_transformer.set_output(transform="pandas")
318319
preprocessor = Pipeline(steps=[("col_transformer", col_transformer)])
320+
319321
if avoid_new:
320322
preprocessor.steps.append(("bins", BinTransformer()))
323+
print(preprocessor)
321324
return preprocessor
322325

323326

qolmat/utils/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def get_shape_original(M: NDArray, shape: tuple) -> NDArray:
288288

289289

290290
def create_lag_matrices(X: NDArray, p: int) -> Tuple[NDArray, NDArray]:
291-
n_rows, n_cols = X.shape
291+
n_rows, _ = X.shape
292292
n_rows_new = n_rows - p
293293
list_X_lag = [np.ones((n_rows_new, 1))]
294294
for lag in range(p):
@@ -304,7 +304,5 @@ def nan_mean_cov(X: NDArray) -> Tuple[NDArray, NDArray]:
304304
_, n_variables = X.shape
305305
means = np.nanmean(X, axis=0)
306306
cov = np.ma.cov(np.ma.masked_invalid(X), rowvar=False).data
307-
print(cov.shape)
308-
print(X.shape)
309307
cov = cov.reshape(n_variables, n_variables)
310308
return means, cov

tests/benchmark/test_metrics.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -383,20 +383,12 @@ def test_pattern_based_weighted_mean_metric(
383383

384384
def test_pattern_mae_comparison(mocker) -> None:
385385

386-
# def mock_metric(values1: pd.Series, values2: pd.Series) -> float:
387-
# call_count += 1
388-
# return 0
389-
390386
mock_metric = mocker.patch("qolmat.benchmark.metrics.accuracy_1D", return_value=0)
391-
# def fun_mean_mae(df_gauss1, df_gauss2, df_mask_gauss) -> float:
392-
# return metrics.mean_squared_error(df_gauss1, df_gauss2, df_mask_gauss).mean()
393387

394-
print(df_mask)
395388
df_nonan = df_incomplete.notna()
396-
result = metrics.pattern_based_weighted_mean_metric(
389+
metrics.pattern_based_weighted_mean_metric(
397390
df_incomplete, df_imputed, df_nonan, metric=mock_metric, min_n_rows=1
398391
)
399-
print(result)
400392
assert mock_metric.call_count == 2
401393

402394

tests/imputations/test_preprocessing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ def test_preprocessing_pipeline(preprocessing_pipeline):
198198
# Test with numerical features
199199
X_num = pd.DataFrame([[1, 2], [3, 4], [5, 6]])
200200
X_transformed = preprocessing_pipeline.fit_transform(X_num)
201+
print(X_num.shape)
202+
print(X_transformed.shape)
201203
assert isinstance(X_transformed, pd.DataFrame)
202204
assert X_transformed.shape[1] == X_num.shape[1]
203205

0 commit comments

Comments
 (0)