|
4 | 4 | from sklearn.compose import make_column_selector as selector |
5 | 5 |
|
6 | 6 | from sklearn.pipeline import Pipeline |
7 | | -from sklearn.base import BaseEstimator |
| 7 | +from sklearn.base import BaseEstimator, TransformerMixin |
8 | 8 | from sklearn.metrics import mean_squared_error |
9 | 9 | from sklearn.utils.estimator_checks import check_estimator |
10 | 10 | from sklearn.utils.validation import check_X_y, check_array |
|
14 | 14 | BinTransformer, |
15 | 15 | MixteHGBM, |
16 | 16 | OneHotEncoderProjector, |
| 17 | + WrapperTransformer, |
17 | 18 | make_pipeline_mixte_preprocessing, |
18 | 19 | make_robust_MixteHGB, |
19 | 20 | ) |
@@ -130,6 +131,56 @@ def test_inverse_transform_OneHotEncoderProjector(encoder): |
130 | 131 | pd.testing.assert_frame_equal(df, df_back) |
131 | 132 |
|
132 | 133 |
|
| 134 | +############################## |
| 135 | +# Testing WrapperTransformer # |
| 136 | +############################## |
| 137 | + |
| 138 | + |
| 139 | +class DummyTransformer(TransformerMixin, BaseEstimator): |
| 140 | + def fit(self, X, y=None): |
| 141 | + return self |
| 142 | + |
| 143 | + def transform(self, X): |
| 144 | + return X |
| 145 | + |
| 146 | + def fit_transform(self, X, y=None): |
| 147 | + return self.fit(X, y).transform(X) |
| 148 | + |
| 149 | + def inverse_transform(self, X, y=None): |
| 150 | + return X |
| 151 | + |
| 152 | + |
| 153 | +@pytest.fixture |
| 154 | +def wrapper_transformer(): |
| 155 | + transformer = DummyTransformer() |
| 156 | + wrapper = DummyTransformer() |
| 157 | + return WrapperTransformer(transformer, wrapper) |
| 158 | + |
| 159 | + |
| 160 | +def test_fit_WrapperTransformer(wrapper_transformer): |
| 161 | + X = np.array([[1, 2], [3, 4]]) |
| 162 | + result = wrapper_transformer.fit(X) |
| 163 | + assert result == wrapper_transformer |
| 164 | + |
| 165 | + |
| 166 | +def test_fit_transform_WrapperTransformer(wrapper_transformer): |
| 167 | + X = np.array([[1, 2], [3, 4]]) |
| 168 | + result = wrapper_transformer.fit_transform(X) |
| 169 | + assert np.array_equal(result, X) |
| 170 | + |
| 171 | + |
| 172 | +def test_transform_WrapperTransformer(wrapper_transformer): |
| 173 | + X = np.array([[1, 2], [3, 4]]) |
| 174 | + result = wrapper_transformer.transform(X) |
| 175 | + assert np.array_equal(result, X) |
| 176 | + |
| 177 | + |
| 178 | +def test_fit_transform_with_dataframes_WrapperTransformer(wrapper_transformer): |
| 179 | + df = pd.DataFrame({"C1": ["a", "b", "b"], "C2": ["c", "d", "c"]}) |
| 180 | + result = wrapper_transformer.fit_transform(df) |
| 181 | + pd.testing.assert_frame_equal(result, df) |
| 182 | + |
| 183 | + |
133 | 184 | ############################################# |
134 | 185 | # Testing make_pipeline_mixte_preprocessing # |
135 | 186 | ############################################# |
|
0 commit comments