@@ -59,9 +59,9 @@ def columnwise_metric(
5959 if type_cols == "all" :
6060 cols = df1 .columns
6161 elif type_cols == "numerical" :
62- cols = _get_numerical_features (df1 )
62+ cols = utils . _get_numerical_features (df1 )
6363 elif type_cols == "categorical" :
64- cols = _get_categorical_features (df1 )
64+ cols = utils . _get_categorical_features (df1 )
6565 else :
6666 raise ValueError (f"Value { type_cols } is not valid for parameter `type_cols`!" )
6767 values = {}
@@ -282,56 +282,6 @@ def dist_wasserstein(
282282 )
283283
284284
285- def _get_numerical_features (df1 : pd .DataFrame ) -> List [str ]:
286- """Get numerical features from dataframe
287-
288- Parameters
289- ----------
290- df1 : pd.DataFrame
291-
292- Returns
293- -------
294- List[str]
295- List of numerical features
296-
297- Raises
298- ------
299- Exception
300- No numerical feature is found
301- """
302- cols_numerical = df1 .select_dtypes (include = np .number ).columns .tolist ()
303- if len (cols_numerical ) == 0 :
304- raise Exception ("No numerical feature is found." )
305- else :
306- return cols_numerical
307-
308-
309- def _get_categorical_features (df1 : pd .DataFrame ) -> List [str ]:
310- """Get categorical features from dataframe
311-
312- Parameters
313- ----------
314- df1 : pd.DataFrame
315-
316- Returns
317- -------
318- List[str]
319- List of categorical features
320-
321- Raises
322- ------
323- Exception
324- No categorical feature is found
325- """
326-
327- cols_numerical = df1 .select_dtypes (include = np .number ).columns .tolist ()
328- cols_categorical = [col for col in df1 .columns .to_list () if col not in cols_numerical ]
329- if len (cols_categorical ) == 0 :
330- raise Exception ("No categorical feature is found." )
331- else :
332- return cols_categorical
333-
334-
335285def kolmogorov_smirnov_test_1D (df1 : pd .Series , df2 : pd .Series ) -> float :
336286 """Compute KS test statistic of the two-sample Kolmogorov-Smirnov test for goodness of fit.
337287 See more in https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ks_2samp.html.
@@ -418,7 +368,7 @@ def total_variance_distance(
418368 pd.Series
419369 Total variance distance
420370 """
421- cols_categorical = _get_categorical_features (df1 )
371+ cols_categorical = utils . _get_categorical_features (df1 )
422372 return columnwise_metric (
423373 df1 [cols_categorical ],
424374 df2 [cols_categorical ],
@@ -491,7 +441,7 @@ def mean_difference_correlation_matrix_numerical_features(
491441
492442 _check_same_number_columns (df1 , df2 )
493443
494- cols_numerical = _get_numerical_features (df1 )
444+ cols_numerical = utils . _get_numerical_features (df1 )
495445 df_corr1 = _get_correlation_pearson_matrix (df1 [cols_numerical ], use_p_value = use_p_value )
496446 df_corr2 = _get_correlation_pearson_matrix (df2 [cols_numerical ], use_p_value = use_p_value )
497447
@@ -560,7 +510,7 @@ def mean_difference_correlation_matrix_categorical_features(
560510
561511 _check_same_number_columns (df1 , df2 )
562512
563- cols_categorical = _get_categorical_features (df1 )
513+ cols_categorical = utils . _get_categorical_features (df1 )
564514 df_corr1 = _get_correlation_chi2_matrix (df1 [cols_categorical ], use_p_value = use_p_value )
565515 df_corr2 = _get_correlation_chi2_matrix (df2 [cols_categorical ], use_p_value = use_p_value )
566516
@@ -635,8 +585,8 @@ def mean_diff_corr_matrix_categorical_vs_numerical_features(
635585
636586 _check_same_number_columns (df1 , df2 )
637587
638- cols_categorical = _get_categorical_features (df1 )
639- cols_numerical = _get_numerical_features (df1 )
588+ cols_categorical = utils . _get_categorical_features (df1 )
589+ cols_numerical = utils . _get_numerical_features (df1 )
640590 df_corr1 = _get_correlation_f_oneway_matrix (
641591 df1 , cols_categorical , cols_numerical , use_p_value = use_p_value
642592 )
@@ -763,10 +713,10 @@ def sum_pairwise_distances(
763713###########################
764714
765715
766- def frechet_distance (
716+ def frechet_distance_base (
767717 df1 : pd .DataFrame ,
768718 df2 : pd .DataFrame ,
769- ) -> float :
719+ ) -> pd . Series :
770720 """Compute the Fréchet distance between two dataframes df1 and df2
771721 Frechet_distance = || mu_1 - mu_2 ||_2^2 + Tr(Sigma_1 + Sigma_2 - 2(Sigma_1 . Sigma_2)^(1/2))
772722 It is normalized, df1 and df2 are first scaled by a factor (std(df1) + std(df2)) / 2
@@ -783,8 +733,8 @@ def frechet_distance(
783733
784734 Returns
785735 -------
786- float
787- frechet distance
736+ pd.Series
737+ Frechet distance in a Series object
788738 """
789739
790740 if df1 .shape != df2 .shape :
@@ -798,16 +748,23 @@ def frechet_distance(
798748 means1 , cov1 = utils .nan_mean_cov (df1 .values )
799749 means2 , cov2 = utils .nan_mean_cov (df2 .values )
800750
801- return algebra .frechet_distance_exact (means1 , cov1 , means2 , cov2 )
751+ distance = algebra .frechet_distance_exact (means1 , cov1 , means2 , cov2 )
752+ return pd .Series (distance , index = ["All" ])
802753
803754
804- def frechet_distance_pattern (
755+ def frechet_distance (
805756 df1 : pd .DataFrame ,
806757 df2 : pd .DataFrame ,
807758 df_mask : pd .DataFrame ,
759+ method : str = "single" ,
808760 min_n_rows : int = 10 ,
809761) -> pd .Series :
810- """Frechet distance computed using a pattern decomposition
762+ """
763+ Frechet distance computed using a pattern decomposition. Several variant are implemented:
764+ - the `single` method relies on a single estimation of the means and covariance matrix. It is
765+ relevent for MCAR data.
766+ - the `pattern`method relies on the aggregation of the estimated distance between each
767+ pattern. It is relevent for MAR data.
811768
812769 Parameters
813770 ----------
@@ -817,6 +774,9 @@ def frechet_distance_pattern(
817774 Second empirical ditribution
818775 df_mask : pd.DataFrame
819776 Mask indicating on which values the distance has to computed on
777+ method: str
778+ Method used to compute the distance on multivariate datasets with missing values.
779+ Possible values are `robust` and `pattern`.
820780 min_n_rows: int
821781 Minimum number of rows for a KL estimation
822782
@@ -826,6 +786,8 @@ def frechet_distance_pattern(
826786 Series of computed metrics
827787 """
828788
789+ if method == "single" :
790+ return frechet_distance_base (df1 , df2 )
829791 return pattern_based_weighted_mean_metric (
830792 df1 ,
831793 df2 ,
@@ -890,7 +852,7 @@ def kl_divergence_gaussian(df1: pd.DataFrame, df2: pd.DataFrame) -> float:
890852 return div_kl
891853
892854
893- def kl_divergence_pattern (
855+ def kl_divergence (
894856 df1 : pd .DataFrame ,
895857 df2 : pd .DataFrame ,
896858 df_mask : pd .DataFrame ,
@@ -913,7 +875,8 @@ def kl_divergence_pattern(
913875 df_mask: pd.DataFrame
914876 Mask indicating on what values the divergence should be computed
915877 method: str
916- Method used
878+ Method used to compute the divergence on multivariate datasets with missing values.
879+ Possible values are `columnwise` and `gaussian`.
917880 min_n_rows: int
918881 Minimum number of rows for a KL estimation
919882
@@ -1073,12 +1036,13 @@ def get_metric(name: str) -> Callable:
10731036 "wmape" : weighted_mean_absolute_percentage_error ,
10741037 "accuracy" : accuracy ,
10751038 "wasserstein_columnwise" : dist_wasserstein ,
1076- "KL_columnwise" : partial (kl_divergence_pattern , method = "columnwise" ),
1077- "KL_gaussian" : partial (kl_divergence_pattern , method = "gaussian" ),
1078- "ks_test " : kolmogorov_smirnov_test ,
1039+ "KL_columnwise" : partial (kl_divergence , method = "columnwise" ),
1040+ "KL_gaussian" : partial (kl_divergence , method = "gaussian" ),
1041+ "KS_test " : kolmogorov_smirnov_test ,
10791042 "correlation_diff" : mean_difference_correlation_matrix_numerical_features ,
10801043 "energy" : sum_energy_distances ,
1081- "frechet" : frechet_distance_pattern ,
1044+ "frechet_single" : partial (frechet_distance , method = "single" ),
1045+ "frechet_pattern" : partial (frechet_distance , method = "pattern" ),
10821046 "dist_corr_pattern" : distance_anticorr_pattern ,
10831047 }
10841048 return dict_metrics [name ]
0 commit comments