Skip to content

Commit a2edc53

Browse files
Julien RousselJulien Roussel
authored andcommitted
frechet distance refacto
1 parent d5caf23 commit a2edc53

File tree

3 files changed

+92
-78
lines changed

3 files changed

+92
-78
lines changed

qolmat/benchmark/metrics.py

Lines changed: 34 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ def columnwise_metric(
5959
if type_cols == "all":
6060
cols = df1.columns
6161
elif type_cols == "numerical":
62-
cols = _get_numerical_features(df1)
62+
cols = utils._get_numerical_features(df1)
6363
elif type_cols == "categorical":
64-
cols = _get_categorical_features(df1)
64+
cols = utils._get_categorical_features(df1)
6565
else:
6666
raise ValueError(f"Value {type_cols} is not valid for parameter `type_cols`!")
6767
values = {}
@@ -282,56 +282,6 @@ def dist_wasserstein(
282282
)
283283

284284

285-
def _get_numerical_features(df1: pd.DataFrame) -> List[str]:
286-
"""Get numerical features from dataframe
287-
288-
Parameters
289-
----------
290-
df1 : pd.DataFrame
291-
292-
Returns
293-
-------
294-
List[str]
295-
List of numerical features
296-
297-
Raises
298-
------
299-
Exception
300-
No numerical feature is found
301-
"""
302-
cols_numerical = df1.select_dtypes(include=np.number).columns.tolist()
303-
if len(cols_numerical) == 0:
304-
raise Exception("No numerical feature is found.")
305-
else:
306-
return cols_numerical
307-
308-
309-
def _get_categorical_features(df1: pd.DataFrame) -> List[str]:
310-
"""Get categorical features from dataframe
311-
312-
Parameters
313-
----------
314-
df1 : pd.DataFrame
315-
316-
Returns
317-
-------
318-
List[str]
319-
List of categorical features
320-
321-
Raises
322-
------
323-
Exception
324-
No categorical feature is found
325-
"""
326-
327-
cols_numerical = df1.select_dtypes(include=np.number).columns.tolist()
328-
cols_categorical = [col for col in df1.columns.to_list() if col not in cols_numerical]
329-
if len(cols_categorical) == 0:
330-
raise Exception("No categorical feature is found.")
331-
else:
332-
return cols_categorical
333-
334-
335285
def kolmogorov_smirnov_test_1D(df1: pd.Series, df2: pd.Series) -> float:
336286
"""Compute KS test statistic of the two-sample Kolmogorov-Smirnov test for goodness of fit.
337287
See more in https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ks_2samp.html.
@@ -418,7 +368,7 @@ def total_variance_distance(
418368
pd.Series
419369
Total variance distance
420370
"""
421-
cols_categorical = _get_categorical_features(df1)
371+
cols_categorical = utils._get_categorical_features(df1)
422372
return columnwise_metric(
423373
df1[cols_categorical],
424374
df2[cols_categorical],
@@ -491,7 +441,7 @@ def mean_difference_correlation_matrix_numerical_features(
491441

492442
_check_same_number_columns(df1, df2)
493443

494-
cols_numerical = _get_numerical_features(df1)
444+
cols_numerical = utils._get_numerical_features(df1)
495445
df_corr1 = _get_correlation_pearson_matrix(df1[cols_numerical], use_p_value=use_p_value)
496446
df_corr2 = _get_correlation_pearson_matrix(df2[cols_numerical], use_p_value=use_p_value)
497447

@@ -560,7 +510,7 @@ def mean_difference_correlation_matrix_categorical_features(
560510

561511
_check_same_number_columns(df1, df2)
562512

563-
cols_categorical = _get_categorical_features(df1)
513+
cols_categorical = utils._get_categorical_features(df1)
564514
df_corr1 = _get_correlation_chi2_matrix(df1[cols_categorical], use_p_value=use_p_value)
565515
df_corr2 = _get_correlation_chi2_matrix(df2[cols_categorical], use_p_value=use_p_value)
566516

@@ -635,8 +585,8 @@ def mean_diff_corr_matrix_categorical_vs_numerical_features(
635585

636586
_check_same_number_columns(df1, df2)
637587

638-
cols_categorical = _get_categorical_features(df1)
639-
cols_numerical = _get_numerical_features(df1)
588+
cols_categorical = utils._get_categorical_features(df1)
589+
cols_numerical = utils._get_numerical_features(df1)
640590
df_corr1 = _get_correlation_f_oneway_matrix(
641591
df1, cols_categorical, cols_numerical, use_p_value=use_p_value
642592
)
@@ -763,10 +713,10 @@ def sum_pairwise_distances(
763713
###########################
764714

765715

766-
def frechet_distance(
716+
def frechet_distance_base(
767717
df1: pd.DataFrame,
768718
df2: pd.DataFrame,
769-
) -> float:
719+
) -> pd.Series:
770720
"""Compute the Fréchet distance between two dataframes df1 and df2
771721
Frechet_distance = || mu_1 - mu_2 ||_2^2 + Tr(Sigma_1 + Sigma_2 - 2(Sigma_1 . Sigma_2)^(1/2))
772722
It is normalized, df1 and df2 are first scaled by a factor (std(df1) + std(df2)) / 2
@@ -783,8 +733,8 @@ def frechet_distance(
783733
784734
Returns
785735
-------
786-
float
787-
frechet distance
736+
pd.Series
737+
Frechet distance in a Series object
788738
"""
789739

790740
if df1.shape != df2.shape:
@@ -798,16 +748,23 @@ def frechet_distance(
798748
means1, cov1 = utils.nan_mean_cov(df1.values)
799749
means2, cov2 = utils.nan_mean_cov(df2.values)
800750

801-
return algebra.frechet_distance_exact(means1, cov1, means2, cov2)
751+
distance = algebra.frechet_distance_exact(means1, cov1, means2, cov2)
752+
return pd.Series(distance, index=["All"])
802753

803754

804-
def frechet_distance_pattern(
755+
def frechet_distance(
805756
df1: pd.DataFrame,
806757
df2: pd.DataFrame,
807758
df_mask: pd.DataFrame,
759+
method: str = "single",
808760
min_n_rows: int = 10,
809761
) -> pd.Series:
810-
"""Frechet distance computed using a pattern decomposition
762+
"""
763+
Frechet distance computed using a pattern decomposition. Several variant are implemented:
764+
- the `single` method relies on a single estimation of the means and covariance matrix. It is
765+
relevent for MCAR data.
766+
- the `pattern`method relies on the aggregation of the estimated distance between each
767+
pattern. It is relevent for MAR data.
811768
812769
Parameters
813770
----------
@@ -817,6 +774,9 @@ def frechet_distance_pattern(
817774
Second empirical ditribution
818775
df_mask : pd.DataFrame
819776
Mask indicating on which values the distance has to computed on
777+
method: str
778+
Method used to compute the distance on multivariate datasets with missing values.
779+
Possible values are `robust` and `pattern`.
820780
min_n_rows: int
821781
Minimum number of rows for a KL estimation
822782
@@ -826,6 +786,8 @@ def frechet_distance_pattern(
826786
Series of computed metrics
827787
"""
828788

789+
if method == "single":
790+
return frechet_distance_base(df1, df2)
829791
return pattern_based_weighted_mean_metric(
830792
df1,
831793
df2,
@@ -890,7 +852,7 @@ def kl_divergence_gaussian(df1: pd.DataFrame, df2: pd.DataFrame) -> float:
890852
return div_kl
891853

892854

893-
def kl_divergence_pattern(
855+
def kl_divergence(
894856
df1: pd.DataFrame,
895857
df2: pd.DataFrame,
896858
df_mask: pd.DataFrame,
@@ -913,7 +875,8 @@ def kl_divergence_pattern(
913875
df_mask: pd.DataFrame
914876
Mask indicating on what values the divergence should be computed
915877
method: str
916-
Method used
878+
Method used to compute the divergence on multivariate datasets with missing values.
879+
Possible values are `columnwise` and `gaussian`.
917880
min_n_rows: int
918881
Minimum number of rows for a KL estimation
919882
@@ -1073,12 +1036,13 @@ def get_metric(name: str) -> Callable:
10731036
"wmape": weighted_mean_absolute_percentage_error,
10741037
"accuracy": accuracy,
10751038
"wasserstein_columnwise": dist_wasserstein,
1076-
"KL_columnwise": partial(kl_divergence_pattern, method="columnwise"),
1077-
"KL_gaussian": partial(kl_divergence_pattern, method="gaussian"),
1078-
"ks_test": kolmogorov_smirnov_test,
1039+
"KL_columnwise": partial(kl_divergence, method="columnwise"),
1040+
"KL_gaussian": partial(kl_divergence, method="gaussian"),
1041+
"KS_test": kolmogorov_smirnov_test,
10791042
"correlation_diff": mean_difference_correlation_matrix_numerical_features,
10801043
"energy": sum_energy_distances,
1081-
"frechet": frechet_distance_pattern,
1044+
"frechet_single": partial(frechet_distance, method="single"),
1045+
"frechet_pattern": partial(frechet_distance, method="pattern"),
10821046
"dist_corr_pattern": distance_anticorr_pattern,
10831047
}
10841048
return dict_metrics[name]

qolmat/utils/utils.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Tuple, Union
1+
from typing import List, Optional, Tuple, Union
22
import warnings
33

44
import numpy as np
@@ -12,6 +12,56 @@
1212
HyperValue = Union[int, float, str]
1313

1414

15+
def _get_numerical_features(df1: pd.DataFrame) -> List[str]:
16+
"""Get numerical features from dataframe
17+
18+
Parameters
19+
----------
20+
df1 : pd.DataFrame
21+
22+
Returns
23+
-------
24+
List[str]
25+
List of numerical features
26+
27+
Raises
28+
------
29+
Exception
30+
No numerical feature is found
31+
"""
32+
cols_numerical = df1.select_dtypes(include=np.number).columns.tolist()
33+
if len(cols_numerical) == 0:
34+
raise Exception("No numerical feature is found.")
35+
else:
36+
return cols_numerical
37+
38+
39+
def _get_categorical_features(df1: pd.DataFrame) -> List[str]:
40+
"""Get categorical features from dataframe
41+
42+
Parameters
43+
----------
44+
df1 : pd.DataFrame
45+
46+
Returns
47+
-------
48+
List[str]
49+
List of categorical features
50+
51+
Raises
52+
------
53+
Exception
54+
No categorical feature is found
55+
"""
56+
57+
cols_numerical = df1.select_dtypes(include=np.number).columns.tolist()
58+
cols_categorical = [col for col in df1.columns.to_list() if col not in cols_numerical]
59+
if len(cols_categorical) == 0:
60+
raise Exception("No categorical feature is found.")
61+
else:
62+
return cols_categorical
63+
64+
1565
def _validate_input(X: NDArray) -> pd.DataFrame:
1666
"""
1767
Checks that the input X can be converted into a DataFrame, and returns the corresponding

tests/benchmark/test_metrics.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,16 +124,16 @@ def test_wasserstein_distance(df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.
124124
@pytest.mark.parametrize("df2", [df_imputed])
125125
@pytest.mark.parametrize("df_mask", [df_mask])
126126
def test_kl_divergence(df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame) -> None:
127-
result = metrics.kl_divergence_pattern(df1, df1, df_mask, method="columnwise")
127+
result = metrics.kl_divergence(df1, df1, df_mask, method="columnwise")
128128
expected = pd.Series([0.0, 0.0], index=["col1", "col2"])
129129
pd.testing.assert_series_equal(result, expected, atol=1e-3)
130130

131-
result = metrics.kl_divergence_pattern(df1, df2, df_mask, method="columnwise")
131+
result = metrics.kl_divergence(df1, df2, df_mask, method="columnwise")
132132
expected = pd.Series([18.945, 36.637], index=["col1", "col2"])
133133
pd.testing.assert_series_equal(result, expected, atol=1e-3)
134134

135135
df_nonan = df1.notna()
136-
result = metrics.kl_divergence_pattern(df1, df2, df_nonan, method="gaussian", min_n_rows=2)
136+
result = metrics.kl_divergence(df1, df2, df_nonan, method="gaussian", min_n_rows=2)
137137
expected = pd.Series([1.029], index=["All"])
138138
pd.testing.assert_series_equal(result, expected, atol=1e-3)
139139

@@ -154,11 +154,11 @@ def test_kl_divergence_gaussian(
154154

155155
@pytest.mark.parametrize("df1", [df_incomplete])
156156
@pytest.mark.parametrize("df2", [df_imputed])
157-
def test_frechet_distance(df1: pd.DataFrame, df2: pd.DataFrame) -> None:
158-
result = metrics.frechet_distance(df1, df1)
157+
def test_frechet_distance_base(df1: pd.DataFrame, df2: pd.DataFrame) -> None:
158+
result = metrics.frechet_distance_base(df1, df1)
159159
np.testing.assert_allclose(result, 0, atol=1e-3)
160160

161-
result = metrics.frechet_distance(df1, df2)
161+
result = metrics.frechet_distance_base(df1, df2)
162162
np.testing.assert_allclose(result, 0.134, atol=1e-3)
163163

164164

@@ -320,7 +320,7 @@ def test_exception_raise_different_shapes(
320320
with pytest.raises(Exception):
321321
metrics.mean_difference_correlation_matrix_numerical_features(df1, df2, df_mask)
322322
with pytest.raises(Exception):
323-
metrics.frechet_distance(df1, df2)
323+
metrics.frechet_distance_base(df1, df2)
324324

325325

326326
@pytest.mark.parametrize("df1", [df_incomplete_cat])

0 commit comments

Comments
 (0)