|
40 | 40 | from ml_grid.model_classes.light_gbm_class import LightGBMClassifierWrapper |
41 | 41 | from ml_grid.model_classes.logistic_regression_class import LogisticRegressionClass |
42 | 42 | from ml_grid.model_classes.mlp_classifier_class import MLPClassifierClass as MLPClassifierClass |
| 43 | +from ml_grid.model_classes.NeuralNetworkClassifier_class import ( |
| 44 | + NeuralNetworkClassifier_class, |
| 45 | +) |
| 46 | + |
43 | 47 | from ml_grid.model_classes.quadratic_discriminant_class import ( |
44 | 48 | QuadraticDiscriminantAnalysisClass, |
45 | 49 | ) |
46 | 50 | from ml_grid.model_classes.randomforest_classifier_class import ( |
47 | 51 | RandomForestClassifierClass, |
48 | 52 | ) |
| 53 | +from ml_grid.model_classes.svc_class import SVCClass |
49 | 54 | from ml_grid.model_classes.xgb_classifier_class import XGBClassifierClass |
50 | 55 |
|
51 | 56 |
|
| 57 | +# --- ROBUST MAPPING of config names to class objects --- |
| 58 | +# This dictionary provides a direct, secure, and explicit mapping from the |
| 59 | +# string names used in the YAML config files to the actual imported Python classes. |
| 60 | +# This avoids the use of `eval()` and makes the code easier to maintain. |
| 61 | +MODEL_CLASS_MAP = { |
| 62 | + # Scikit-learn and similar |
| 63 | + "LogisticRegression": LogisticRegressionClass, |
| 64 | + "LogisticRegressionClass": LogisticRegressionClass, |
| 65 | + "RandomForestClassifier": RandomForestClassifierClass, |
| 66 | + "RandomForestClassifierClass": RandomForestClassifierClass, |
| 67 | + "XGB_class": XGBClassifierClass, |
| 68 | + "XGBClassifierClass": XGBClassifierClass, |
| 69 | + "AdaBoostClassifierClass": AdaBoostClassifierClass, |
| 70 | + "CatBoostClassifierClass": CatBoostClassifierClass, |
| 71 | + "GaussianNBClassifierClass": GaussianNBClassifierClass, |
| 72 | + "GradientBoostingClassifierClass": GradientBoostingClassifierClass, |
| 73 | + "KNeighborsClassifierClass": KNeighborsClassifierClass, |
| 74 | + "LightGBMClassifierWrapper": LightGBMClassifierWrapper, |
| 75 | + "MLPClassifierClass": MLPClassifierClass, |
| 76 | + "QuadraticDiscriminantAnalysisClass": QuadraticDiscriminantAnalysisClass, |
| 77 | + "SVCClass": SVCClass, |
| 78 | + "NeuralNetworkClassifier_class": NeuralNetworkClassifier_class, # Corrected mapping |
| 79 | + # GPU specific |
| 80 | + "KerasClassifierClass": KerasClassifierClass, |
| 81 | + "KNNGpuWrapperClass": KNNGpuWrapperClass, |
| 82 | + # H2O Models |
| 83 | + "H2O_class": H2OAutoMLClass, # Alias for AutoML |
| 84 | + "H2OAutoMLClass": H2OAutoMLClass, |
| 85 | + "H2O_GBM_class": H2O_GBM_class, |
| 86 | + "H2O_DRF_class": H2O_DRF_class, |
| 87 | + "H2O_DeepLearning_class": H2O_DeepLearning_class, |
| 88 | + "H2O_GLM_class": H2O_GLM_class, |
| 89 | + "H2O_NaiveBayes_class": H2O_NaiveBayes_class, |
| 90 | + "H2O_RuleFit_class": H2O_RuleFit_class, |
| 91 | + "H2O_XGBoost_class": H2O_XGBoost_class, |
| 92 | + "H2O_StackedEnsemble_class": H2O_StackedEnsemble_class, |
| 93 | + "H2O_GAM_class": H2O_GAM_class, |
| 94 | +} |
| 95 | + |
| 96 | + |
52 | 97 | def get_model_class_list(ml_grid_object: pipe) -> List[Any]: |
53 | 98 | """Generates a list of instantiated model classes based on the configuration. |
54 | 99 |
|
@@ -153,17 +198,16 @@ def get_model_class_list(ml_grid_object: pipe) -> List[Any]: |
153 | 198 | f"Skipping '{class_name}' because it requires a GPU, but no CUDA-enabled GPU is available." |
154 | 199 | ) |
155 | 200 | continue |
156 | | - # Try the exact name first, then try with '_class' appended for convenience |
157 | | - try: |
158 | | - model_class = eval(class_name) |
159 | | - except NameError: |
160 | | - class_name_with_suffix = f"{class_name}_class" |
161 | | - try: |
162 | | - model_class = eval(class_name_with_suffix) |
163 | | - except NameError: |
164 | | - raise NameError( |
165 | | - f"Could not find model class '{class_name}' or '{class_name_with_suffix}'. Please check the name and ensure it's imported." |
166 | | - ) |
| 201 | + |
| 202 | + # Look up the class in our explicit mapping dictionary |
| 203 | + model_class = MODEL_CLASS_MAP.get(class_name) |
| 204 | + |
| 205 | + if model_class is None: |
| 206 | + raise KeyError( |
| 207 | + f"Could not find model class '{class_name}' in MODEL_CLASS_MAP. " |
| 208 | + f"Please check the model name in your configuration and ensure it is imported and mapped in model_class_list.py." |
| 209 | + ) |
| 210 | + |
167 | 211 | # Pass X and y to constructors that accept them (like H2OStackedEnsemble) |
168 | 212 | init_signature = inspect.signature(model_class.__init__) |
169 | 213 | init_params = {} |
|
0 commit comments