|
1 | 1 | import numpy as np |
2 | 2 | import pandas as pd |
3 | 3 | import pytest |
| 4 | + |
4 | 5 | from sklearn.ensemble import RandomForestClassifier |
5 | 6 | from sklearn.linear_model import LinearRegression |
| 7 | +from sklearn.model_selection import StratifiedKFold |
6 | 8 | from sklearn.tree import DecisionTreeRegressor |
7 | 9 |
|
8 | 10 | from feature_engine.selection import SelectByShuffling |
@@ -93,6 +95,42 @@ def test_regression_cv_2_and_mse(load_diabetes_dataset): |
93 | 95 | pd.testing.assert_frame_equal(sel.transform(X), Xtransformed) |
94 | 96 |
|
95 | 97 |
|
| 98 | +def test_cv_generator(df_test): |
| 99 | + X, y = df_test |
| 100 | + cv = StratifiedKFold(n_splits=3) |
| 101 | + |
| 102 | + X, y = df_test |
| 103 | + sel = SelectByShuffling( |
| 104 | + RandomForestClassifier(random_state=1), |
| 105 | + threshold=0.01, |
| 106 | + random_state=1, |
| 107 | + cv=3, |
| 108 | + ) |
| 109 | + sel.fit(X, y) |
| 110 | + |
| 111 | + # expected result |
| 112 | + Xtransformed = pd.DataFrame(X["var_7"].copy()) |
| 113 | + pd.testing.assert_frame_equal(sel.transform(X), Xtransformed) |
| 114 | + |
| 115 | + sel = SelectByShuffling( |
| 116 | + RandomForestClassifier(random_state=1), |
| 117 | + threshold=0.01, |
| 118 | + random_state=1, |
| 119 | + cv=cv, |
| 120 | + ) |
| 121 | + sel.fit(X, y) |
| 122 | + pd.testing.assert_frame_equal(sel.transform(X), Xtransformed) |
| 123 | + |
| 124 | + sel = SelectByShuffling( |
| 125 | + RandomForestClassifier(random_state=1), |
| 126 | + threshold=0.01, |
| 127 | + random_state=1, |
| 128 | + cv=cv.split(X, y), |
| 129 | + ) |
| 130 | + sel.fit(X, y) |
| 131 | + pd.testing.assert_frame_equal(sel.transform(X), Xtransformed) |
| 132 | + |
| 133 | + |
96 | 134 | def test_raises_threshold_error(): |
97 | 135 | with pytest.raises(ValueError): |
98 | 136 | SelectByShuffling(RandomForestClassifier(random_state=1), threshold="hello") |
|
0 commit comments