Skip to content

Commit 21871a5

Browse files
Julien RousselJulien Roussel
authored andcommitted
tests passing
1 parent 4a47ebf commit 21871a5

File tree

3 files changed

+77
-53
lines changed

3 files changed

+77
-53
lines changed

examples/tutorials/plot_tuto_categorical.ipynb

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,39 @@
33
{
44
"cell_type": "code",
55
"execution_count": null,
6-
"id": "a220df49",
6+
"id": "7bec9ffc",
77
"metadata": {},
88
"outputs": [],
99
"source": [
1010
"%load_ext autoreload\n",
1111
"%autoreload 2"
1212
]
1313
},
14+
{
15+
"cell_type": "code",
16+
"execution_count": 122,
17+
"id": "3587aa0a",
18+
"metadata": {},
19+
"outputs": [
20+
{
21+
"data": {
22+
"text/plain": [
23+
"(1, 5)"
24+
]
25+
},
26+
"execution_count": 122,
27+
"metadata": {},
28+
"output_type": "execute_result"
29+
}
30+
],
31+
"source": [
32+
"np.array([[1, 2, 3, np.nan, 5]]).shape"
33+
]
34+
},
1435
{
1536
"cell_type": "code",
1637
"execution_count": 2,
17-
"id": "80d3ba10",
38+
"id": "b5136ec4",
1839
"metadata": {},
1940
"outputs": [
2041
{
@@ -59,7 +80,7 @@
5980
{
6081
"cell_type": "code",
6182
"execution_count": 3,
62-
"id": "96667a7c",
83+
"id": "55e57e8b",
6384
"metadata": {},
6485
"outputs": [],
6586
"source": [
@@ -69,7 +90,7 @@
6990
{
7091
"cell_type": "code",
7192
"execution_count": 4,
72-
"id": "f5ee81fb",
93+
"id": "61bcee05",
7394
"metadata": {},
7495
"outputs": [],
7596
"source": [
@@ -79,7 +100,7 @@
79100
{
80101
"cell_type": "code",
81102
"execution_count": 5,
82-
"id": "ace56085",
103+
"id": "88f12b44",
83104
"metadata": {},
84105
"outputs": [],
85106
"source": [
@@ -90,7 +111,7 @@
90111
{
91112
"cell_type": "code",
92113
"execution_count": 6,
93-
"id": "dab55dd3",
114+
"id": "34cfca81",
94115
"metadata": {},
95116
"outputs": [],
96117
"source": [
@@ -101,7 +122,7 @@
101122
{
102123
"cell_type": "code",
103124
"execution_count": 7,
104-
"id": "ba1ea100",
125+
"id": "7cefa249",
105126
"metadata": {},
106127
"outputs": [
107128
{
@@ -184,7 +205,7 @@
184205
{
185206
"cell_type": "code",
186207
"execution_count": 8,
187-
"id": "f571ad13",
208+
"id": "87e6d8c6",
188209
"metadata": {},
189210
"outputs": [],
190211
"source": [
@@ -197,7 +218,7 @@
197218
{
198219
"cell_type": "code",
199220
"execution_count": 10,
200-
"id": "d2a26bd9",
221+
"id": "394b40d9",
201222
"metadata": {},
202223
"outputs": [],
203224
"source": [
@@ -212,7 +233,7 @@
212233
{
213234
"cell_type": "code",
214235
"execution_count": 11,
215-
"id": "a005e3b6",
236+
"id": "915c5caf",
216237
"metadata": {},
217238
"outputs": [],
218239
"source": [
@@ -222,7 +243,7 @@
222243
{
223244
"cell_type": "code",
224245
"execution_count": 12,
225-
"id": "d5bdbcb3",
246+
"id": "93d0c02d",
226247
"metadata": {},
227248
"outputs": [],
228249
"source": [
@@ -232,7 +253,7 @@
232253
{
233254
"cell_type": "code",
234255
"execution_count": 15,
235-
"id": "2ad54886",
256+
"id": "ea24e781",
236257
"metadata": {},
237258
"outputs": [],
238259
"source": [
@@ -244,7 +265,7 @@
244265
{
245266
"cell_type": "code",
246267
"execution_count": 16,
247-
"id": "711a8e3e",
268+
"id": "e5347dfe",
248269
"metadata": {},
249270
"outputs": [],
250271
"source": [
@@ -254,7 +275,7 @@
254275
{
255276
"cell_type": "code",
256277
"execution_count": 17,
257-
"id": "e57379ae",
278+
"id": "325b7354",
258279
"metadata": {},
259280
"outputs": [],
260281
"source": [
@@ -268,7 +289,7 @@
268289
{
269290
"cell_type": "code",
270291
"execution_count": 61,
271-
"id": "c727306a",
292+
"id": "5d4c2127",
272293
"metadata": {},
273294
"outputs": [
274295
{
@@ -295,7 +316,7 @@
295316
{
296317
"cell_type": "code",
297318
"execution_count": 111,
298-
"id": "7668b17c",
319+
"id": "4b0ebe4e",
299320
"metadata": {},
300321
"outputs": [
301322
{
@@ -358,7 +379,7 @@
358379
{
359380
"cell_type": "code",
360381
"execution_count": 112,
361-
"id": "edcd6516",
382+
"id": "08640c07",
362383
"metadata": {},
363384
"outputs": [
364385
{
@@ -432,7 +453,7 @@
432453
},
433454
{
434455
"cell_type": "markdown",
435-
"id": "b6127f00",
456+
"id": "d8193a27",
436457
"metadata": {},
437458
"source": [
438459
"# Imputation analysis"
@@ -441,7 +462,7 @@
441462
{
442463
"cell_type": "code",
443464
"execution_count": 113,
444-
"id": "d6ad8c0c",
465+
"id": "4df8e2ce",
445466
"metadata": {},
446467
"outputs": [],
447468
"source": [
@@ -453,7 +474,7 @@
453474
{
454475
"cell_type": "code",
455476
"execution_count": 114,
456-
"id": "8834e9e6",
477+
"id": "c4681f8e",
457478
"metadata": {},
458479
"outputs": [],
459480
"source": [
@@ -464,7 +485,7 @@
464485
{
465486
"cell_type": "code",
466487
"execution_count": 115,
467-
"id": "02cb4a6e",
488+
"id": "1537a2a7",
468489
"metadata": {},
469490
"outputs": [],
470491
"source": [
@@ -476,7 +497,7 @@
476497
{
477498
"cell_type": "code",
478499
"execution_count": 116,
479-
"id": "b11df2f4",
500+
"id": "dad580cc",
480501
"metadata": {},
481502
"outputs": [
482503
{
@@ -600,7 +621,7 @@
600621
{
601622
"cell_type": "code",
602623
"execution_count": 117,
603-
"id": "671d6b3c",
624+
"id": "12a99c70",
604625
"metadata": {},
605626
"outputs": [
606627
{
@@ -632,7 +653,7 @@
632653
{
633654
"cell_type": "code",
634655
"execution_count": 120,
635-
"id": "ccc38665",
656+
"id": "8006ba1e",
636657
"metadata": {},
637658
"outputs": [
638659
{
@@ -663,7 +684,7 @@
663684
{
664685
"cell_type": "code",
665686
"execution_count": null,
666-
"id": "b8c5a4b4",
687+
"id": "b8cc543a",
667688
"metadata": {},
668689
"outputs": [],
669690
"source": []
Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,15 @@
22
import pandas as pd
33
import pytest
44
from sklearn.compose import make_column_selector as selector
5-
from sklearn.ensemble import (
6-
HistGradientBoostingClassifier,
7-
HistGradientBoostingRegressor,
8-
)
5+
96
from sklearn.pipeline import Pipeline
107
from sklearn.base import BaseEstimator
118
from sklearn.metrics import mean_squared_error
129
from sklearn.utils.estimator_checks import check_estimator
1310
from sklearn.utils.validation import check_X_y, check_array
1411
from sklearn.model_selection import train_test_split
1512
from sklearn.compose import ColumnTransformer
16-
from qolmat.imputations.estimators import (
13+
from qolmat.imputations.preprocessing import (
1714
BinTransformer,
1815
MixteHGBM,
1916
make_pipeline_mixte_preprocessing,
@@ -71,33 +68,39 @@ def bin_transformer():
7168

7269

7370
def test_fit_transform(bin_transformer):
74-
X = np.array([1, 2, 3, np.nan, 5])
75-
transformed_X = bin_transformer.fit_transform(X)
76-
assert np.array_equal(transformed_X, np.array([1, 2, 3, np.nan, 5]), equal_nan=True)
71+
X = np.array([[1, 2, 3, np.nan, 5]]).T
72+
X_transformed = bin_transformer.fit_transform(X)
73+
assert np.array_equal(X_transformed, X, equal_nan=True)
7774

7875

7976
def test_transform(bin_transformer):
8077
bin_transformer.dict_df_bins_ = {
8178
0: pd.DataFrame({"value": [1, 2, 3, 4, 5], "min": [-np.inf, 1.5, 2.5, 3.5, 4.5]})
8279
}
83-
X = np.array([4.2, -1, 3.0, 4.5, 12])
84-
transformed_X = bin_transformer.transform(X)
85-
assert np.array_equal(transformed_X, np.array([4, 1, 3, 5, 5]))
80+
X = np.array([[4.2, -1, 3.0, 4.5, 12]]).T
81+
X_transformed = bin_transformer.transform(X)
82+
print(X_transformed)
83+
print(X)
84+
assert np.array_equal(X_transformed, np.array([[4, 1, 3, 5, 5]]).T)
8685

8786

88-
def test_fit_transform_with_series(bin_transformer):
89-
X = pd.Series([1, 2, 3, np.nan, 5])
90-
transformed_X = bin_transformer.fit_transform(X)
91-
pd.testing.assert_series_equal(transformed_X, pd.Series([1, 2, 3, np.nan, 5]))
87+
def test_fit_transform_with_dataframes(bin_transformer):
88+
X = pd.DataFrame({"0": [1, 2, 3, np.nan, 5]})
89+
X_transformed = bin_transformer.fit_transform(X)
90+
print(X_transformed)
91+
print(X)
92+
pd.testing.assert_frame_equal(X_transformed, X)
9293

9394

94-
def test_transform_with_series(bin_transformer):
95+
def test_transform_with_dataframes(bin_transformer):
9596
bin_transformer.dict_df_bins_ = {
9697
0: pd.DataFrame({"value": [1, 2, 3, 4, 5], "min": [0.5, 1.5, 2.5, 3.5, 4.5]})
9798
}
98-
X = pd.Series([1, 2, 3, 4, 5])
99-
transformed_X = bin_transformer.transform(X)
100-
pd.testing.assert_series_equal(transformed_X, pd.Series([1, 2, 3, 4, 5], dtype=float))
99+
X = pd.DataFrame({"0": [1, 2, 3, 4, 5]})
100+
X_transformed = bin_transformer.transform(X)
101+
print(X_transformed)
102+
print(X)
103+
pd.testing.assert_frame_equal(X_transformed, X)
101104

102105

103106
# Testing make_pipeline_mixte_preprocessing
@@ -114,21 +117,21 @@ def test_preprocessing_pipeline(preprocessing_pipeline):
114117

115118
# Test with numerical features
116119
X_num = pd.DataFrame([[1, 2], [3, 4], [5, 6]])
117-
transformed_X = preprocessing_pipeline.fit_transform(X_num)
118-
assert isinstance(transformed_X, pd.DataFrame)
119-
assert transformed_X.shape[1] == X_num.shape[1]
120+
X_transformed = preprocessing_pipeline.fit_transform(X_num)
121+
assert isinstance(X_transformed, pd.DataFrame)
122+
assert X_transformed.shape[1] == X_num.shape[1]
120123

121124
# Test with categorical features
122125
X_cat = pd.DataFrame([["a", "b"], ["c", "d"], ["e", "f"]])
123-
transformed_X = preprocessing_pipeline.fit_transform(X_cat)
124-
assert isinstance(transformed_X, pd.DataFrame)
125-
assert transformed_X.shape[1] > X_cat.shape[1]
126+
X_transformed = preprocessing_pipeline.fit_transform(X_cat)
127+
assert isinstance(X_transformed, pd.DataFrame)
128+
assert X_transformed.shape[1] > X_cat.shape[1]
126129

127130
# Test with mixed features
128131
X_mixed = pd.DataFrame([[1, "a"], [2, "b"], [3, "c"]])
129-
transformed_X = preprocessing_pipeline.fit_transform(X_mixed)
130-
assert isinstance(transformed_X, pd.DataFrame)
131-
assert transformed_X.shape[1] > X_mixed.shape[1]
132+
X_transformed = preprocessing_pipeline.fit_transform(X_mixed)
133+
assert isinstance(X_transformed, pd.DataFrame)
134+
assert X_transformed.shape[1] > X_mixed.shape[1]
132135

133136

134137
# Testing make_robust_MixteHGB

0 commit comments

Comments
 (0)