@@ -368,12 +368,12 @@ def total_variance_distance(
368368 pd.Series
369369 Total variance distance
370370 """
371- cols_categorical = utils ._get_categorical_features (df1 )
372371 return columnwise_metric (
373- df1 [ cols_categorical ] ,
374- df2 [ cols_categorical ] ,
375- df_mask [ cols_categorical ] ,
372+ df1 ,
373+ df2 ,
374+ df_mask ,
376375 _total_variance_distance_1D ,
376+ type_cols = "categorical" ,
377377 )
378378
379379
@@ -792,7 +792,7 @@ def frechet_distance(
792792 df1 ,
793793 df2 ,
794794 df_mask ,
795- frechet_distance ,
795+ frechet_distance_base ,
796796 min_n_rows = min_n_rows ,
797797 type_cols = "numerical" ,
798798 )
@@ -1003,10 +1003,12 @@ def pattern_based_weighted_mean_metric(
10031003 cols = df1 .select_dtypes (exclude = ["number" ]).columns
10041004 else :
10051005 raise ValueError (f"Value { type_cols } is not valid for parameter `type_cols`!" )
1006+
10061007 if np .any (df_mask & df1 .isna ()):
10071008 raise ValueError ("The argument df1 has missing values on the mask!" )
10081009 if np .any (df_mask & df2 .isna ()):
10091010 raise ValueError ("The argument df2 has missing values on the mask!" )
1011+
10101012 rows_mask = df_mask .any (axis = 1 )
10111013 scores = []
10121014 weights = []
@@ -1041,7 +1043,7 @@ def get_metric(name: str) -> Callable:
10411043 "KS_test" : kolmogorov_smirnov_test ,
10421044 "correlation_diff" : mean_difference_correlation_matrix_numerical_features ,
10431045 "energy" : sum_energy_distances ,
1044- "frechet_single " : partial (frechet_distance , method = "single" ),
1046+ "frechet " : partial (frechet_distance , method = "single" ),
10451047 "frechet_pattern" : partial (frechet_distance , method = "pattern" ),
10461048 "dist_corr_pattern" : distance_anticorr_pattern ,
10471049 }
0 commit comments