Skip to content

Commit c89883d

Browse files
committed
h2o related
1 parent 919ba12 commit c89883d

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

ml_grid/pipeline/grid_search_cross_validate.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -153,10 +153,20 @@ def __init__(
153 153

154 154
self.y_test_orig = self.ml_grid_object_iter.y_test_orig
155 155

156-
# --- DEFINITIVE FIX for H2O data type error in CV ---
157-
# Convert the target variable to a categorical type *before* it's passed
158-
# to any H2O or search function. This ensures H2OFrame correctly infers
159-
# the type, even in complex nested pipelines like BayesSearchCV.
156+
# --- ROBUST DATA TYPE HANDLING ---
157+
# Ensure X_train is a pandas DataFrame and y_train is a pandas Series
158+
# with aligned indices. This handles inputs being numpy arrays (from tests)
159+
# or pandas objects, preventing AttributeError and ensuring consistency.
160+
161+
# 1. Ensure X_train is a DataFrame.
162+
if not isinstance(self.X_train, pd.DataFrame):
163+
self.X_train = pd.DataFrame(self.X_train).rename(columns=str)
164+
165+
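# 2. Ensure y_train is a pandas object: wrap non-pandas input (e.g. a numpy array) in a Series, using X_train's index for alignment. An existing Series or DataFrame is left unchanged.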
166+
if not isinstance(self.y_train, (pd.Series, pd.DataFrame)):
167+
self.y_train = pd.Series(self.y_train, index=self.X_train.index)
168+
169+
# 3. Ensure target is categorical for classification models (especially H2O).
160 170
self.y_train = self.y_train.astype('category')
161 171

162 172
# --- CRITICAL FIX for H2O Stacked Ensemble response column mismatch ---

0 commit comments

Comments
 (0)