Skip to content

Commit 3b7d4a9

Browse files
Julien RousselJulien Roussel
authored andcommitted
tests extended to all preprocessing.py
1 parent 51e237f commit 3b7d4a9

File tree

1 file changed

+52
-1
lines changed

1 file changed

+52
-1
lines changed

tests/imputations/test_preprocessing.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from sklearn.compose import make_column_selector as selector
55

66
from sklearn.pipeline import Pipeline
7-
from sklearn.base import BaseEstimator
7+
from sklearn.base import BaseEstimator, TransformerMixin
88
from sklearn.metrics import mean_squared_error
99
from sklearn.utils.estimator_checks import check_estimator
1010
from sklearn.utils.validation import check_X_y, check_array
@@ -14,6 +14,7 @@
1414
BinTransformer,
1515
MixteHGBM,
1616
OneHotEncoderProjector,
17+
WrapperTransformer,
1718
make_pipeline_mixte_preprocessing,
1819
make_robust_MixteHGB,
1920
)
@@ -130,6 +131,56 @@ def test_inverse_transform_OneHotEncoderProjector(encoder):
130131
pd.testing.assert_frame_equal(df, df_back)
131132

132133

134+
##############################
135+
# Testing WrapperTransformer #
136+
##############################
137+
138+
139+
class DummyTransformer(TransformerMixin, BaseEstimator):
140+
def fit(self, X, y=None):
141+
return self
142+
143+
def transform(self, X):
144+
return X
145+
146+
def fit_transform(self, X, y=None):
147+
return self.fit(X, y).transform(X)
148+
149+
def inverse_transform(self, X, y=None):
150+
return X
151+
152+
153+
@pytest.fixture
154+
def wrapper_transformer():
155+
transformer = DummyTransformer()
156+
wrapper = DummyTransformer()
157+
return WrapperTransformer(transformer, wrapper)
158+
159+
160+
def test_fit_WrapperTransformer(wrapper_transformer):
161+
X = np.array([[1, 2], [3, 4]])
162+
result = wrapper_transformer.fit(X)
163+
assert result == wrapper_transformer
164+
165+
166+
def test_fit_transform_WrapperTransformer(wrapper_transformer):
167+
X = np.array([[1, 2], [3, 4]])
168+
result = wrapper_transformer.fit_transform(X)
169+
assert np.array_equal(result, X)
170+
171+
172+
def test_transform_WrapperTransformer(wrapper_transformer):
173+
X = np.array([[1, 2], [3, 4]])
174+
result = wrapper_transformer.transform(X)
175+
assert np.array_equal(result, X)
176+
177+
178+
def test_fit_transform_with_dataframes_WrapperTransformer(wrapper_transformer):
179+
df = pd.DataFrame({"C1": ["a", "b", "b"], "C2": ["c", "d", "c"]})
180+
result = wrapper_transformer.fit_transform(df)
181+
pd.testing.assert_frame_equal(result, df)
182+
183+
133184
#############################################
134185
# Testing make_pipeline_mixte_preprocessing #
135186
#############################################

0 commit comments

Comments
 (0)