Commit 315e34a

Add new row synthesis single table metric (#226)

* Add synthetic uniqueness single table metric and tests
* Add warning for edge case
* Update metric name
* Update implementation
* Add input validation
* Fix unit test
* Fix edge cases in new row synthesis query
* Update query logic
* Update unit test

1 parent 1e39044 commit 315e34a

6 files changed: +276 -0 lines changed

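Before the per-file diffs, a quick usage sketch of the metric this commit introduces. The DataFrames, column names, and values below are invented for illustration; only the import path and the compute signature come from the diff:

import pandas as pd

from sdmetrics.single_table import NewRowSynthesis

# Invented example data: the first synthetic row copies a real row exactly.
real_data = pd.DataFrame({
    'age': [21, 35, 44],
    'city': ['Lyon', 'Oslo', 'Kyiv'],
})
synthetic_data = pd.DataFrame({
    'age': [21, 29, 50],
    'city': ['Lyon', 'Oslo', 'Rome'],
})
metadata = {
    'fields': {
        'age': {'type': 'numerical', 'subtype': 'int'},
        'city': {'type': 'categorical'},
    },
}

# 2 of the 3 synthetic rows match no real row, so the score is 2/3.
score = NewRowSynthesis.compute(real_data, synthetic_data, metadata)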

sdmetrics/multi_table/multi_single_table.py

Lines changed: 6 additions & 0 deletions
@@ -241,6 +241,12 @@ class BNLikelihood(MultiSingleTableMetric):
     single_table_metric = single_table.bayesian_network.BNLikelihood
 
 
+class NewRowSynthesis(MultiSingleTableMetric):
+    """MultiSingleTableMetric based on SingleTable NewRowSynthesis."""
+
+    single_table_metric = single_table.new_row_synthesis.NewRowSynthesis
+
+
 class BNLogLikelihood(MultiSingleTableMetric):
     """MultiSingleTableMetric based on SingleTable BNLogLikelihood."""
 
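
Since this wrapper only points MultiSingleTableMetric at the single-table class, usage mirrors the other multi-table metrics: the single-table score is computed per table and the results are aggregated. A sketch under that assumption, with table names and data invented; the metadata shape with a top-level 'tables' key follows the convention of the other multi-table metrics:

import pandas as pd

from sdmetrics.multi_table.multi_single_table import NewRowSynthesis

# Invented two-table datasets keyed by table name.
real_data = {
    'users': pd.DataFrame({'plan': ['free', 'pro'], 'age': [30, 41]}),
    'logins': pd.DataFrame({'count': [3, 7], 'device': ['ios', 'web']}),
}
synthetic_data = {
    'users': pd.DataFrame({'plan': ['pro', 'free'], 'age': [25, 52]}),
    'logins': pd.DataFrame({'count': [5, 2], 'device': ['web', 'ios']}),
}
metadata = {
    'tables': {
        'users': {'fields': {'plan': {'type': 'categorical'},
                             'age': {'type': 'numerical', 'subtype': 'int'}}},
        'logins': {'fields': {'count': {'type': 'numerical', 'subtype': 'int'},
                              'device': {'type': 'categorical'}}},
    },
}

# Aggregates the per-table single-table scores.
score = NewRowSynthesis.compute(real_data, synthetic_data, metadata)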

sdmetrics/reports/single_table/plot_utils.py

Lines changed: 5 additions & 0 deletions
@@ -85,10 +85,15 @@ def _get_similarity_correlation_matrix(score_breakdowns, columns):
     Args:
         score_breakdowns (dict):
             Mapping of metric to the score breakdown result.
+        columns (list[string] or set[string]):
+            A list or set of column names.
 
     Returns:
         pandas.DataFrame
     """
+    if isinstance(columns, set):
+        columns = list(columns)
+
     similarity_correlation = pd.DataFrame(
         index=columns,
         columns=columns,
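
The new set-to-list guard exists because pandas refuses an unordered set as an index; a minimal illustration (column names invented):

import pandas as pd

columns = {'age', 'salary'}

# pd.DataFrame(index=columns) raises "TypeError: Set type is unordered",
# so the guard converts the set to a list before building the matrix.
matrix = pd.DataFrame(index=list(columns), columns=list(columns))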

sdmetrics/single_table/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@
 from sdmetrics.single_table.multi_single_column import (
     BoundaryAdherence, CategoryCoverage, CSTest, KSComplement, MissingValueSimilarity,
     MultiSingleColumnMetric, RangeCoverage, StatisticSimilarity, TVComplement)
+from sdmetrics.single_table.new_row_synthesis import NewRowSynthesis
 from sdmetrics.single_table.privacy.base import CategoricalPrivacyMetric, NumericalPrivacyMetric
 from sdmetrics.single_table.privacy.cap import (
     CategoricalCAP, CategoricalGeneralizedCAP, CategoricalZeroCAP)
@@ -88,4 +89,5 @@
     'StatisticSimilarity',
     'TVComplement',
     'RangeCoverage',
+    'NewRowSynthesis',
 ]
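
With this export in place, both import paths resolve to the same class; a quick check (the Direct alias is ours, just for the comparison):

from sdmetrics.single_table import NewRowSynthesis
from sdmetrics.single_table.new_row_synthesis import NewRowSynthesis as Direct

assert NewRowSynthesis is Direct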

sdmetrics/single_table/base.py

Lines changed: 6 additions & 0 deletions
@@ -1,5 +1,6 @@
 """Base Single Table metric class."""
 
+import copy
 from operator import attrgetter
 
 import pandas as pd
@@ -103,6 +104,11 @@ def _validate_inputs(cls, real_data, synthetic_data, metadata=None):
             (pandas.DataFrame, pandas.DataFrame, dict):
                 The validated data and metadata.
         """
+        real_data = real_data.copy()
+        synthetic_data = synthetic_data.copy()
+        if metadata is not None:
+            metadata = copy.deepcopy(metadata)
+
         if set(real_data.columns) != set(synthetic_data.columns):
             raise ValueError('`real_data` and `synthetic_data` must have the same columns')
 
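
The defensive copies matter because NewRowSynthesis.compute mutates its inputs, overwriting datetime columns via pd.to_numeric; without the copy, that conversion would leak into the caller's DataFrame. A small sketch of the aliasing problem (data invented):

import pandas as pd

caller_df = pd.DataFrame({'signup': pd.to_datetime(['2020-01-02', '2021-01-04'])})

# Sharing the object: an in-place conversion changes the caller's column.
shared = caller_df
shared['signup'] = pd.to_numeric(shared['signup'])
assert caller_df['signup'].dtype == 'int64'  # the caller's data silently changed

# Copying first, as _validate_inputs now does, keeps the caller's frame intact.
caller_df = pd.DataFrame({'signup': pd.to_datetime(['2020-01-02', '2021-01-04'])})
isolated = caller_df.copy()
isolated['signup'] = pd.to_numeric(isolated['signup'])
assert caller_df['signup'].dtype == 'datetime64[ns]'  # unchanged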

sdmetrics/single_table/new_row_synthesis.py

Lines changed: 123 additions & 0 deletions

@@ -0,0 +1,123 @@
+"""New Row Synthesis metric for single table."""
+import warnings
+
+import pandas as pd
+
+from sdmetrics.goal import Goal
+from sdmetrics.single_table.base import SingleTableMetric
+
+
+class NewRowSynthesis(SingleTableMetric):
+    """NewRowSynthesis Single Table metric.
+
+    This metric measures whether each row in the synthetic data is new,
+    or whether it exactly matches a row in the real data.
+
+    Attributes:
+        name (str):
+            Name to use when reports about this metric are printed.
+        goal (sdmetrics.goal.Goal):
+            The goal of this metric.
+        min_value (Union[float, tuple[float]]):
+            Minimum value or values that this metric can take.
+        max_value (Union[float, tuple[float]]):
+            Maximum value or values that this metric can take.
+    """
+
+    name = 'NewRowSynthesis'
+    goal = Goal.MAXIMIZE
+    min_value = 0
+    max_value = 1
+
+    @classmethod
+    def compute(cls, real_data, synthetic_data, metadata=None, numerical_match_tolerance=0.01,
+                synthetic_sample_size=None):
+        """Compute this metric.
+
+        This metric looks for matches between the real and synthetic data for
+        the compatible columns. This metric also looks for matches in missing values.
+
+        Args:
+            real_data (Union[numpy.ndarray, pandas.DataFrame]):
+                The values from the real dataset.
+            synthetic_data (Union[numpy.ndarray, pandas.DataFrame]):
+                The values from the synthetic dataset.
+            metadata (dict):
+                Table metadata dict.
+            numerical_match_tolerance (float):
+                A float larger than 0 representing how close two numerical values have to be
+                in order to be considered a match. Defaults to `0.01`.
+            synthetic_sample_size (int):
+                The number of synthetic rows to sample before computing this metric.
+                Use this to speed up the computation time if you have a large amount
+                of synthetic data. Note that the final score may not be as precise if
+                your sample size is low. Defaults to ``None``, which does not sample,
+                and uses all of the provided rows.
+
+        Returns:
+            float:
+                The new row synthesis score.
+        """
+        real_data, synthetic_data, metadata = cls._validate_inputs(
+            real_data, synthetic_data, metadata)
+
+        if synthetic_sample_size is not None:
+            if synthetic_sample_size > len(synthetic_data):
+                warnings.warn(f'The provided `synthetic_sample_size` of {synthetic_sample_size} '
+                              'is larger than the number of synthetic data rows '
+                              f'({len(synthetic_data)}). Proceeding without sampling.')
+            else:
+                synthetic_data = synthetic_data.sample(n=synthetic_sample_size)
+
+        numerical_fields = []
+        discrete_fields = []
+        for field, field_meta in metadata['fields'].items():
+            if field_meta['type'] == 'datetime':
+                real_data[field] = pd.to_numeric(real_data[field])
+                synthetic_data[field] = pd.to_numeric(synthetic_data[field])
+                numerical_fields.append(field)
+            elif field_meta['type'] == 'numerical':
+                numerical_fields.append(field)
+            else:
+                discrete_fields.append(field)
+
+        num_unique_rows = 0
+        for index, row in synthetic_data.iterrows():
+            row_filter = []
+            for field in real_data.columns:
+                if pd.isna(row[field]):
+                    field_filter = f'{field}.isnull()'
+                elif field in numerical_fields:
+                    field_filter = (
+                        f'abs({field} - {row[field]}) <= '
+                        f'{abs(numerical_match_tolerance * row[field])}'
+                    )
+                else:
+                    if real_data[field].dtype == 'O':
+                        field_filter = f"{field} == '{row[field]}'"
+                    else:
+                        field_filter = f'{field} == {row[field]}'
+
+                row_filter.append(field_filter)
+
+            matches = real_data.query(' and '.join(row_filter))
+            if matches is None or matches.empty:
+                num_unique_rows += 1
+
+        return num_unique_rows / len(synthetic_data)
+
+    @classmethod
+    def normalize(cls, raw_score):
+        """Normalize the raw score.
+
+        Notice that this simply delegates to the base ``SingleTableMetric``.
+
+        Args:
+            raw_score (float):
+                The value of the metric from `compute`.
+
+        Returns:
+            float:
+                The normalized value of the metric.
+        """
+        return super().normalize(raw_score)
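
To make the matching concrete: for each synthetic row, compute assembles one filter string per column and counts the row as new only when real_data.query finds no match. A worked sketch of the filter for an invented row, using the default numerical_match_tolerance of 0.01:

import pandas as pd

real_data = pd.DataFrame({'age': [35, 44], 'city': ['Oslo', 'Kyiv']})

# Filters compute would build for the synthetic row (age=35, city='Oslo'):
row_filter = [
    'abs(age - 35) <= 0.35',  # numerical column: within 1% of the synthetic value
    "city == 'Oslo'",         # object column: exact, quoted match
]
matches = real_data.query(' and '.join(row_filter))

assert not matches.empty  # this row matches a real row, so it is not counted as new
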
tests/unit/single_table/test_new_row_synthesis.py

Lines changed: 134 additions & 0 deletions

@@ -0,0 +1,134 @@
+from unittest.mock import patch
+
+import numpy as np
+import pandas as pd
+
+from sdmetrics.single_table import NewRowSynthesis
+
+
+class TestNewRowSynthesis:
+
+    def test_compute(self):
+        """Test the ``compute`` method and expect that the new row synthesis score is returned."""
+        # Setup
+        real_data = pd.DataFrame({
+            'col1': [0, 1, 2, 3, 4],
+            'col2': [1, 2, 1, 3, 4],
+            'col3': ['a', 'b', 'c', 'd', 'b'],
+            'col4': [1.32, np.nan, 1.43, np.nan, 2.0],
+            'col5': [51, 52, 53, 54, 55],
+            'col6': ['2020-01-02', '2021-01-04', '2021-05-03', '2022-10-11', '2022-11-13'],
+        })
+        synthetic_data = pd.DataFrame({
+            'col1': [0, 1, 2, 3, 4],
+            'col2': [1, 3, 4, 2, 2],
+            'col3': ['a', 'b', 'c', 'b', 'e'],
+            'col4': [1.32, 1.56, 1.21, np.nan, 1.90],
+            'col5': [51, 51, 54, 55, 53],
+            'col6': ['2020-01-02', '2022-11-24', '2022-06-01', '2021-04-12', '2020-12-11'],
+        })
+        metadata = {
+            'fields': {
+                'col1': {'type': 'id', 'subtype': 'int'},
+                'col2': {'type': 'numerical', 'subtype': 'int'},
+                'col3': {'type': 'categorical'},
+                'col4': {'type': 'numerical', 'subtype': 'float'},
+                'col5': {'type': 'categorical'},
+                'col6': {'type': 'datetime', 'format': '%Y-%m-%d'},
+            },
+        }
+        metric = NewRowSynthesis()
+
+        # Run
+        score = metric.compute(real_data, synthetic_data, metadata)
+
+        # Assert
+        assert score == 0.8
+
+    def test_compute_with_sample_size(self):
+        """Test the ``compute`` method with a sample size.
+
+        Expect that the new row synthesis score is returned.
+        """
+        # Setup
+        real_data = pd.DataFrame({
+            'col1': [1, 2, 1, 3, 4],
+            'col2': ['a', 'b', 'c', 'd', 'b'],
+            'col3': [1.32, np.nan, 1.43, np.nan, 2.0],
+        })
+        synthetic_data = pd.DataFrame({
+            'col1': [1, 3, 4, 2, 2],
+            'col2': ['a', 'b', 'c', 'd', 'e'],
+            'col3': [1.46, 1.56, 1.21, np.nan, 1.92],
+        })
+        metadata = {
+            'fields': {
+                'col1': {'type': 'numerical', 'subtype': 'int'},
+                'col2': {'type': 'categorical'},
+                'col3': {'type': 'numerical', 'subtype': 'float'},
+            },
+        }
+        sample_size = 2
+        metric = NewRowSynthesis()
+
+        # Run
+        score = metric.compute(
+            real_data, synthetic_data, metadata, synthetic_sample_size=sample_size)
+
+        # Assert
+        assert score == 1
+
+    @patch('sdmetrics.single_table.new_row_synthesis.warnings')
+    def test_compute_with_sample_size_too_large(self, warnings_mock):
+        """Test the ``compute`` method with a sample size larger than the number of rows.
+
+        Expect that the new row synthesis score is returned. Expect a warning to be raised.
+        """
+        # Setup
+        real_data = pd.DataFrame({
+            'col1': [1, 2, 1, 3, 4],
+            'col2': ['a', 'b', 'c', 'd', 'b'],
+            'col3': [1.32, np.nan, 1.43, np.nan, 2.0],
+        })
+        synthetic_data = pd.DataFrame({
+            'col1': [1, 3, 4, 2, 2],
+            'col2': ['a', 'b', 'c', 'd', 'e'],
+            'col3': [1.35, 1.56, 1.21, np.nan, 1.92],
+        })
+        metadata = {
+            'fields': {
+                'col1': {'type': 'numerical', 'subtype': 'int'},
+                'col2': {'type': 'categorical'},
+                'col3': {'type': 'numerical', 'subtype': 'float'},
+            },
+        }
+        sample_size = 15
+        metric = NewRowSynthesis()
+
+        # Run
+        score = metric.compute(
+            real_data, synthetic_data, metadata, synthetic_sample_size=sample_size)
+
+        # Assert
+        assert score == 1
+        warnings_mock.warn.assert_called_once_with(
+            'The provided `synthetic_sample_size` of 15 is larger than the number of '
+            'synthetic data rows (5). Proceeding without sampling.'
+        )
+
+    @patch('sdmetrics.single_table.new_row_synthesis.SingleTableMetric.normalize')
+    def test_normalize(self, normalize_mock):
+        """Test the ``normalize`` method.
+
+        Expect that the inherited ``normalize`` method is called.
+        """
+        # Setup
+        metric = NewRowSynthesis()
+        raw_score = 0.9
+
+        # Run
+        result = metric.normalize(raw_score)
+
+        # Assert
+        normalize_mock.assert_called_once_with(raw_score)
+        assert result == normalize_mock.return_value
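
For reference on the expected values: in test_compute, only the first synthetic row (0, 1, 'a', 1.32, 51, '2020-01-02') exactly reproduces a real row, so 4 of the 5 synthetic rows are new and the score is 4 / 5 = 0.8. In both sampling tests, no synthetic row falls within the 1% numerical tolerance of any real row, so every sampled row is new and the score is 1 regardless of which rows the sampler draws.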
