Skip to content

Commit 1c552d0

Browse files
Julien RousselJulien Roussel
authored andcommitted
Transformer wrapper impelmented
1 parent bb3c8b4 commit 1c552d0

File tree

1 file changed

+21
-25
lines changed

1 file changed

+21
-25
lines changed

tests/imputations/test_imputers.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
{"col1": [0, np.nan, 2, 3, np.nan], "col2": [-1, np.nan, 0.5, np.nan, 1.5]}
1717
)
1818

19+
df_mixed = pd.DataFrame(
20+
{"col1": [0, np.nan, 2, 3, np.nan], "col2": ["a", np.nan, "b", np.nan, "b"]}
21+
)
22+
1923
df_timeseries = pd.DataFrame(
2024
pd.DataFrame(
2125
{
@@ -92,9 +96,7 @@ def test_hyperparameters_get_hyperparameters_modified(
9296
@pytest.mark.parametrize(
9397
"imputer",
9498
[
95-
imputers.ImputerMean(),
96-
imputers.ImputerMedian(),
97-
imputers.ImputerMode(),
99+
imputers.ImputerSimple(),
98100
imputers.ImputerShuffle(),
99101
imputers.ImputerLOCF(),
100102
imputers.ImputerNOCB(),
@@ -109,13 +111,13 @@ def test_Imputer_fit_transform_on_nan_column(df: pd.DataFrame, imputer: imputers
109111

110112
@pytest.mark.parametrize("df", "string")
111113
def test_fit_transform_not_on_pandas(df: Any) -> None:
112-
imputer = imputers.ImputerMean()
114+
imputer = imputers.ImputerSimple()
113115
np.testing.assert_raises(ValueError, imputer.fit_transform, df)
114116

115117

116118
@pytest.mark.parametrize("df", [df_groups])
117119
def test_fit_transform_on_grouped(df: pd.DataFrame) -> None:
118-
imputer = imputers.ImputerMean(groups=("col1",))
120+
imputer = imputers.ImputerSimple(groups=("col1",))
119121
result = imputer.fit_transform(df)
120122
expected = pd.DataFrame(
121123
{
@@ -136,29 +138,27 @@ def test_ImputerOracle_fit_transform(df: pd.DataFrame, df_oracle: pd.DataFrame)
136138
np.testing.assert_allclose(result, expected)
137139

138140

139-
@pytest.mark.parametrize("df", [df_incomplete])
140-
def test_ImputerMean_fit_transform(df: pd.DataFrame) -> None:
141-
imputer = imputers.ImputerMean()
141+
@pytest.mark.parametrize("df", [df_mixed])
142+
def test_ImputerSimple_mean_fit_transform(df: pd.DataFrame) -> None:
143+
imputer = imputers.ImputerSimple(strategy="mean")
142144
result = imputer.fit_transform(df)
143-
expected = pd.DataFrame(
144-
{"col1": [0, 5 / 3, 2, 3, 5 / 3], "col2": [-1, 1 / 3, 0.5, 1 / 3, 1.5]}
145-
)
145+
expected = pd.DataFrame({"col1": [0, 5 / 3, 2, 3, 5 / 3], "col2": ["a", "b", "b", "b", "b"]})
146146
np.testing.assert_allclose(result, expected)
147147

148148

149-
@pytest.mark.parametrize("df", [df_incomplete])
150-
def test_ImputerMedian_fit_transform(df: pd.DataFrame) -> None:
151-
imputer = imputers.ImputerMedian()
149+
@pytest.mark.parametrize("df", [df_mixed])
150+
def test_ImputerSimple_median_fit_transform(df: pd.DataFrame) -> None:
151+
imputer = imputers.ImputerSimple()
152152
result = imputer.fit_transform(df)
153-
expected = pd.DataFrame({"col1": [0, 2, 2, 3, 2], "col2": [-1, 0.5, 0.5, 0.5, 1.5]})
153+
expected = pd.DataFrame({"col1": [0, 2, 2, 3, 2], "col2": ["a", "b", "b", "b", "b"]})
154154
np.testing.assert_allclose(result, expected)
155155

156156

157-
@pytest.mark.parametrize("df", [df_incomplete])
158-
def test_ImputerMode_fit_transform(df: pd.DataFrame) -> None:
159-
imputer = imputers.ImputerMode()
157+
@pytest.mark.parametrize("df", [df_mixed])
158+
def test_ImputerSimple_mode_fit_transform(df: pd.DataFrame) -> None:
159+
imputer = imputers.ImputerSimple(strategy="most_frequent")
160160
result = imputer.fit_transform(df)
161-
expected = pd.DataFrame({"col1": [0, 0, 2, 3, 0], "col2": [-1, -1, 0.5, -1, 1.5]})
161+
expected = pd.DataFrame({"col1": [0, 0, 2, 3, 0], "col2": ["a", "b", "b", "b", "b"]})
162162
np.testing.assert_allclose(result, expected)
163163

164164

@@ -277,9 +277,7 @@ def test_ImputerRpcaNoisy_fit_transform(df: pd.DataFrame) -> None:
277277
df_grouped = pd.DataFrame(dict_values, index=index_grouped)
278278

279279
list_imputers = [
280-
imputers.ImputerMean(groups=("group",)),
281-
imputers.ImputerMedian(groups=("group",)),
282-
imputers.ImputerMode(groups=("group",)),
280+
imputers.ImputerSimple(groups=("group",)),
283281
imputers.ImputerShuffle(groups=("group",)),
284282
imputers.ImputerLOCF(groups=("group",)),
285283
imputers.ImputerNOCB(groups=("group",)),
@@ -306,9 +304,7 @@ def test_models_fit_transform_grouped(imputer):
306304
[
307305
imputers._Imputer(),
308306
imputers.ImputerOracle(),
309-
imputers.ImputerMean(),
310-
imputers.ImputerMedian(),
311-
imputers.ImputerMode(),
307+
imputers.ImputerSimple(),
312308
imputers.ImputerShuffle(),
313309
imputers.ImputerLOCF(),
314310
imputers.ImputerNOCB(),

0 commit comments

Comments
 (0)