Skip to content

Commit f90de3a

Browse files
committed
linting, minor refactoring
1 parent ffc16ea commit f90de3a

18 files changed

+522
-663
lines changed

ml_grid/pipeline/grid_search_cross_validate.py

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,64 @@
1-
import logging
21
import time
2+
import traceback
3+
import logging
34
import warnings
45
from typing import Any, Dict, List, Optional, Union
56

7+
import keras
68
import numpy as np
79
import pandas as pd
810
import tensorflow as tf
911
import torch
1012
from IPython.display import clear_output
11-
from pandas.testing import assert_index_equal
13+
from numpy import absolute, mean, std
1214
from scikeras.wrappers import KerasClassifier
1315
from sklearn import metrics
14-
15-
# from sklearn.utils.testing import ignore_warnings
16-
from sklearn.exceptions import ConvergenceWarning
17-
from sklearn.metrics import *
18-
from sklearn.model_selection import (
19-
KFold,
20-
ParameterGrid,
21-
RepeatedKFold,
22-
cross_validate,
23-
)
24-
from sklearn.preprocessing import MinMaxScaler
25-
from skopt.space import Categorical
16+
from IPython.display import display
17+
from catboost import CatBoostError
18+
from pandas.testing import assert_index_equal
2619
from xgboost.core import XGBoostError
27-
2820
from ml_grid.model_classes.H2OAutoMLClassifier import H2OAutoMLClassifier
29-
from ml_grid.model_classes.H2ODeepLearningClassifier import H2ODeepLearningClassifier
21+
from ml_grid.model_classes.H2OGBMClassifier import H2OGBMClassifier
3022
from ml_grid.model_classes.H2ODRFClassifier import H2ODRFClassifier
3123
from ml_grid.model_classes.H2OGAMClassifier import H2OGAMClassifier
32-
from ml_grid.model_classes.H2OGBMClassifier import H2OGBMClassifier
24+
from ml_grid.model_classes.H2ODeepLearningClassifier import H2ODeepLearningClassifier
3325
from ml_grid.model_classes.H2OGLMClassifier import H2OGLMClassifier
3426
from ml_grid.model_classes.H2ONaiveBayesClassifier import H2ONaiveBayesClassifier
3527
from ml_grid.model_classes.H2ORuleFitClassifier import H2ORuleFitClassifier
28+
from ml_grid.model_classes.H2OXGBoostClassifier import H2OXGBoostClassifier
3629
from ml_grid.model_classes.H2OStackedEnsembleClassifier import (
3730
H2OStackedEnsembleClassifier,
3831
)
39-
from ml_grid.model_classes.H2OXGBoostClassifier import H2OXGBoostClassifier
40-
from ml_grid.model_classes.keras_classifier_class import KerasClassifierClass
4132
from ml_grid.model_classes.NeuralNetworkKerasClassifier import NeuralNetworkClassifier
33+
34+
# from sklearn.utils.testing import ignore_warnings
35+
from sklearn.exceptions import ConvergenceWarning
36+
from sklearn.metrics import *
37+
from sklearn.metrics import (
38+
classification_report,
39+
f1_score,
40+
make_scorer,
41+
matthews_corrcoef,
42+
roc_auc_score,
43+
)
44+
from sklearn.model_selection import (
45+
GridSearchCV,
46+
ParameterGrid,
47+
RandomizedSearchCV,
48+
RepeatedKFold,
49+
KFold,
50+
cross_validate,
51+
)
52+
53+
from ml_grid.model_classes.keras_classifier_class import KerasClassifierClass
4254
from ml_grid.pipeline.hyperparameter_search import HyperparameterSearch
43-
from ml_grid.util.bayes_utils import is_skopt_space
4455
from ml_grid.util.debug_print_statements import debug_print_statements_class
4556
from ml_grid.util.global_params import global_parameters
4657
from ml_grid.util.project_score_save import project_score_save_class
4758
from ml_grid.util.validate_parameters import validate_parameters_helper
59+
from sklearn.preprocessing import MinMaxScaler
60+
from ml_grid.util.bayes_utils import calculate_combinations, is_skopt_space
61+
from skopt.space import Categorical
4862

4963

5064
class grid_search_crossvalidate:
@@ -412,10 +426,10 @@ def __init__(
412426

413427
# Catch only one class present AUC not defined:
414428

415-
# dummy_auc_scorer = make_scorer(dummy_auc)
416429
if len(np.unique(self.y_train)) < 2:
417430
raise ValueError(
418-
"Only one class present in y_train. ROC AUC score is not defined in that case. grid_search_cross_validate>>>cross_validate"
431+
"Only one class present in y_train. ROC AUC score is not defined "
432+
"in that case. grid_search_cross_validate>>>cross_validate"
419433
)
420434

421435
if self.global_parameters.verbose >= 1:
@@ -434,12 +448,10 @@ def __init__(
434448
# Default scores if cross-validation fails
435449
default_scores = {
436450
"test_accuracy": [
437-
0.5
438-
], # Default to random classifier performance (0.5 for binary classification)
451+
0.5 # Default to random classifier performance
452+
],
439453
"test_f1": [0.5], # Default F1 score (again, 0.5 for random classification)
440-
"test_auc": [
441-
0.5
442-
], # Default ROC AUC score (0.5 for random classifier) #is only auc not roc_auc?
454+
"test_auc": [0.5], # Default ROC AUC score (0.5 for random classifier)
443455
"fit_time": [0], # No fitting time if the model fails
444456
"score_time": [0], # No scoring time if the model fails
445457
"train_score": [0.5], # Default train score
@@ -625,7 +637,8 @@ def __init__(
625637
# Print a warning if the execution time exceeds the threshold
626638
if elapsed_time > time_threshold:
627639
self.logger.warning(
628-
f"Cross-validation took too long ({elapsed_time:.2f} seconds). Consider optimizing the parameters or reducing CV folds."
640+
f"Cross-validation took too long ({elapsed_time:.2f} seconds). "
641+
"Consider optimizing the parameters or reducing CV folds."
629642
)
630643
else:
631644
self.logger.info(

ml_grid/pipeline/model_class_list.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
from typing import Any, Dict, List, Optional
77

88
import torch
9+
10+
from ml_grid.pipeline.data import pipe
11+
12+
# Import all model classes to make them available for eval()
913
from ml_grid.model_classes.adaboost_classifier_class import AdaBoostClassifierClass
10-
from ml_grid.model_classes.catboost_classifier_class import (
11-
CatBoostClassifierClass,
12-
)
14+
from ml_grid.model_classes.catboost_classifier_class import CatBoostClassifierClass
1315
from ml_grid.model_classes.gaussiannb_class import (
1416
GaussianNBClassifierClass,
1517
)
@@ -20,39 +22,31 @@
2022
from ml_grid.model_classes.h2o_deeplearning_classifier_class import (
2123
H2O_DeepLearning_class,
2224
)
23-
from ml_grid.model_classes.h2o_drf_classifier_class import H2ODRFClass as H2O_DRF_class
25+
from ml_grid.model_classes.h2o_drf_classifier_class import H2ODRFClass as H2O_DRF_class
2426
from ml_grid.model_classes.h2o_gam_classifier_class import H2OGAMClass as H2O_GAM_class
25-
from ml_grid.model_classes.h2o_gbm_classifier_class import (
26-
H2O_GBM_class,
27-
)
27+
from ml_grid.model_classes.h2o_gbm_classifier_class import H2O_GBM_class
2828
from ml_grid.model_classes.h2o_glm_classifier_class import H2O_GLM_class
29-
from ml_grid.model_classes.h2o_naive_bayes_classifier_class import (
30-
H2O_NaiveBayes_class,
29+
from ml_grid.model_classes.h2o_naive_bayes_classifier_class import H2O_NaiveBayes_class
30+
from ml_grid.model_classes.h2o_rulefit_classifier_class import (
31+
H2ORuleFitClass as H2O_RuleFit_class,
3132
)
32-
from ml_grid.model_classes.h2o_rulefit_classifier_class import H2ORuleFitClass as H2O_RuleFit_class
3333
from ml_grid.model_classes.h2o_stackedensemble_classifier_class import (
3434
H2O_StackedEnsemble_class,
3535
)
3636
from ml_grid.model_classes.h2o_xgboost_classifier_class import H2O_XGBoost_class
37-
from ml_grid.model_classes.keras_classifier_class import KerasClassifierClass
3837
from ml_grid.model_classes.knn_classifier_class import KNeighborsClassifierClass
3938
from ml_grid.model_classes.knn_gpu_classifier_class import KNNGpuWrapperClass
40-
from ml_grid.model_classes.light_gbm_class import LightGBMClassifierWrapper
39+
from ml_grid.model_classes.keras_classifier_class import KerasClassifierClass
40+
from ml_grid.model_classes.light_gbm_class import LightGBMClassifierWrapper
4141
from ml_grid.model_classes.logistic_regression_class import LogisticRegressionClass
42-
from ml_grid.model_classes.mlp_classifier_class import MLPClassifierClass
43-
from ml_grid.model_classes.NeuralNetworkClassifier_class import (
44-
NeuralNetworkClassifier_class,
45-
)
42+
from ml_grid.model_classes.mlp_classifier_class import MLPClassifierClass as MLPClassifierClass
4643
from ml_grid.model_classes.quadratic_discriminant_class import (
4744
QuadraticDiscriminantAnalysisClass,
4845
)
4946
from ml_grid.model_classes.randomforest_classifier_class import (
5047
RandomForestClassifierClass,
5148
)
52-
from ml_grid.model_classes.svc_class import SVCClass
53-
from ml_grid.model_classes.tabtransformer_classifier_class import TabTransformerClass
54-
from ml_grid.model_classes.xgb_classifier_class import XGBClassifierClass
55-
from ml_grid.pipeline.data import pipe
49+
from ml_grid.model_classes.xgb_classifier_class import XGBClassifierClass
5650

5751

5852
def get_model_class_list(ml_grid_object: pipe) -> List[Any]:
@@ -98,7 +92,7 @@ def get_model_class_list(ml_grid_object: pipe) -> List[Any]:
9892
"LogisticRegressionClass": True,
9993
"KNeighborsClassifierClass": True,
10094
"QuadraticDiscriminantAnalysisClass": True,
101-
"SVCClass": True,
95+
"SVCClass": False,
10296
"XGBClassifierClass": True,
10397
"MLPClassifierClass": True,
10498
"RandomForestClassifierClass": True,

0 commit comments

Comments
 (0)