Skip to content

Commit 0b47bbf

Browse files
committed
minor changes
1 parent f90de3a commit 0b47bbf

File tree

4 files changed

+70
-42
lines changed

4 files changed

+70
-42
lines changed

ml_grid/model_classes/H2OGAMClassifier.py

Lines changed: 10 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -143,46 +143,30 @@ def _prepare_fit(self, X: pd.DataFrame, y: pd.Series):
143143
model_params["num_knots"] = num_knots_list
144144

145145
for i, col in enumerate(gam_columns):
146+
# --- FIX: Ensure column exists before trying to access it ---
146147
if col not in X.columns:
147148
self.logger.warning(
148149
f"GAM column '{col}' not found in input data X. Skipping."
149150
)
150151
continue
151152

153+
# --- FIX: Validate knot count against unique values in the data ---
152154
n_unique = X[col].nunique()
153155
required_knots = num_knots_list[i]
154156

155-
# H2O's backend requires num_knots < n_unique.
156-
if n_unique <= required_knots:
157+
# --- ROBUSTNESS FIX for java.lang.AssertionError in H2O quantile calculation ---
158+
# The quantile calculation can fail on sparse data or data with low cardinality.
159+
# Enforce a stricter requirement: the number of unique values must be at least
160+
# double the number of knots. This provides a safer margin for the algorithm.
161+
if n_unique < (required_knots * 2):
157162
if not self._suppress_low_cardinality_error:
158163
raise ValueError(
159-
f"Number of knots ({required_knots}) must be at least one less than the number of unique values ({n_unique}) for feature '{col}'."
164+
f"Feature '{col}' has {n_unique} unique values, which is insufficient "
165+
f"for the requested {required_knots} knots. At least {required_knots * 2} unique values are required."
160166
)
161167
self.logger.warning(
162168
f"Excluding GAM column '{col}': {n_unique} unique values "
163-
f"insufficient for {required_knots} knots (require >= {required_knots + 1})."
164-
)
165-
continue
166-
167-
# Pre-check for well-defined knots
168-
try:
169-
quantiles = np.linspace(0, 1, required_knots)
170-
knot_values = X[col].quantile(quantiles)
171-
# Check for enough unique values AND that they are monotonically increasing
172-
# The diff() will be > 0 for all elements in a strictly increasing series.
173-
are_knots_valid = (knot_values.nunique() >= required_knots) and (
174-
np.all(np.diff(knot_values.to_numpy()) > 0)
175-
)
176-
177-
if not are_knots_valid:
178-
self.logger.warning(
179-
f"Excluding GAM column '{col}': Not enough unique values to generate distinct, "
180-
f"monotonically increasing knots."
181-
)
182-
continue
183-
except Exception as e:
184-
self.logger.warning(
185-
f"Excluding GAM column '{col}' due to an error during knot pre-check: {e}"
169+
f"is insufficient for {required_knots} knots (requires at least {required_knots * 2}). Skipping."
186170
)
187171
continue
188172

ml_grid/pipeline/model_class_list.py

Lines changed: 55 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -40,15 +40,60 @@
4040
from ml_grid.model_classes.light_gbm_class import LightGBMClassifierWrapper
4141
from ml_grid.model_classes.logistic_regression_class import LogisticRegressionClass
4242
from ml_grid.model_classes.mlp_classifier_class import MLPClassifierClass as MLPClassifierClass
43+
from ml_grid.model_classes.NeuralNetworkClassifier_class import (
44+
NeuralNetworkClassifier_class,
45+
)
46+
4347
from ml_grid.model_classes.quadratic_discriminant_class import (
4448
QuadraticDiscriminantAnalysisClass,
4549
)
4650
from ml_grid.model_classes.randomforest_classifier_class import (
4751
RandomForestClassifierClass,
4852
)
53+
from ml_grid.model_classes.svc_class import SVCClass
4954
from ml_grid.model_classes.xgb_classifier_class import XGBClassifierClass
5055

5156

57+
# --- ROBUST MAPPING of config names to class objects ---
58+
# This dictionary provides a direct, secure, and explicit mapping from the
59+
# string names used in the YAML config files to the actual imported Python classes.
60+
# This avoids the use of `eval()` and makes the code easier to maintain.
61+
MODEL_CLASS_MAP = {
62+
# Scikit-learn and similar
63+
"LogisticRegression": LogisticRegressionClass,
64+
"LogisticRegressionClass": LogisticRegressionClass,
65+
"RandomForestClassifier": RandomForestClassifierClass,
66+
"RandomForestClassifierClass": RandomForestClassifierClass,
67+
"XGB_class": XGBClassifierClass,
68+
"XGBClassifierClass": XGBClassifierClass,
69+
"AdaBoostClassifierClass": AdaBoostClassifierClass,
70+
"CatBoostClassifierClass": CatBoostClassifierClass,
71+
"GaussianNBClassifierClass": GaussianNBClassifierClass,
72+
"GradientBoostingClassifierClass": GradientBoostingClassifierClass,
73+
"KNeighborsClassifierClass": KNeighborsClassifierClass,
74+
"LightGBMClassifierWrapper": LightGBMClassifierWrapper,
75+
"MLPClassifierClass": MLPClassifierClass,
76+
"QuadraticDiscriminantAnalysisClass": QuadraticDiscriminantAnalysisClass,
77+
"SVCClass": SVCClass,
78+
"NeuralNetworkClassifier_class": NeuralNetworkClassifier_class, # Corrected mapping
79+
# GPU specific
80+
"KerasClassifierClass": KerasClassifierClass,
81+
"KNNGpuWrapperClass": KNNGpuWrapperClass,
82+
# H2O Models
83+
"H2O_class": H2OAutoMLClass, # Alias for AutoML
84+
"H2OAutoMLClass": H2OAutoMLClass,
85+
"H2O_GBM_class": H2O_GBM_class,
86+
"H2O_DRF_class": H2O_DRF_class,
87+
"H2O_DeepLearning_class": H2O_DeepLearning_class,
88+
"H2O_GLM_class": H2O_GLM_class,
89+
"H2O_NaiveBayes_class": H2O_NaiveBayes_class,
90+
"H2O_RuleFit_class": H2O_RuleFit_class,
91+
"H2O_XGBoost_class": H2O_XGBoost_class,
92+
"H2O_StackedEnsemble_class": H2O_StackedEnsemble_class,
93+
"H2O_GAM_class": H2O_GAM_class,
94+
}
95+
96+
5297
def get_model_class_list(ml_grid_object: pipe) -> List[Any]:
5398
"""Generates a list of instantiated model classes based on the configuration.
5499
@@ -153,17 +198,16 @@ def get_model_class_list(ml_grid_object: pipe) -> List[Any]:
153198
f"Skipping '{class_name}' because it requires a GPU, but no CUDA-enabled GPU is available."
154199
)
155200
continue
156-
# Try the exact name first, then try with '_class' appended for convenience
157-
try:
158-
model_class = eval(class_name)
159-
except NameError:
160-
class_name_with_suffix = f"{class_name}_class"
161-
try:
162-
model_class = eval(class_name_with_suffix)
163-
except NameError:
164-
raise NameError(
165-
f"Could not find model class '{class_name}' or '{class_name_with_suffix}'. Please check the name and ensure it's imported."
166-
)
201+
202+
# Look up the class in our explicit mapping dictionary
203+
model_class = MODEL_CLASS_MAP.get(class_name)
204+
205+
if model_class is None:
206+
raise KeyError(
207+
f"Could not find model class '{class_name}' in MODEL_CLASS_MAP. "
208+
f"Please check the model name in your configuration and ensure it is imported and mapped in model_class_list.py."
209+
)
210+
167211
# Pass X and y to constructors that accept them (like H2OStackedEnsemble)
168212
init_signature = inspect.signature(model_class.__init__)
169213
init_params = {}

ml_grid/pipeline/test_data_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,7 @@ def setUp(self):
8585
},
8686
}
8787
self.drop_term_list = ["chrom", "hfe", "phlebo"]
88-
self.model_class_dict = {"LogisticRegression_class": True}
88+
self.model_class_dict = {"LogisticRegressionClass": True}
8989

9090
def tearDown(self):
9191
"""Clean up the temporary directory after each test."""

tests/test_h2o_classifiers.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -195,10 +195,9 @@ def test_h2o_gam_knot_cardinality_error(h2o_session_fixture):
195195
estimator = H2O_GAM_class(X=X, y=y, parameter_space_size="small").algorithm_implementation
196196

197197
# Set parameters that will cause the error: 5 knots for a feature with 2
198-
# unique values. # noqa: E501
198+
# unique values.
199199
# Also, we must disable the wrapper's internal error handling that
200-
# suppresses this
201-
# specific error, so that cross_val_score can raise it as intended.
200+
# suppresses this specific error, so that cross_val_score can raise it as intended.
202201
estimator.set_params(
203202
gam_columns=['feature2'],
204203
num_knots=5,
@@ -210,9 +209,10 @@ def test_h2o_gam_knot_cardinality_error(h2o_session_fixture):
210209
cv = KFold(n_splits=2, shuffle=True, random_state=42)
211210

212211
# We expect cross_val_score to fail and raise our specific ValueError
212+
# Updated regex to match the actual error message from the code
213213
with pytest.raises(
214214
ValueError,
215-
match=r"Number of knots .* must be at least one less than the number of unique values",
215+
match=r"Feature .* has \d+ unique values, which is insufficient for the requested \d+ knots\. At least \d+ unique values are required\.",
216216
):
217217
# The error_score='raise' is crucial for pytest.raises to catch the exception
218218
cross_val_score(estimator, X, y, cv=cv, error_score='raise', n_jobs=1)

0 commit comments

Comments
 (0)