Skip to content

Commit 919ba12

Browse files
committed
h2o related fixes and test fixes
1 parent 5462026 commit 919ba12

File tree

9 files changed

+217
-82
lines changed

9 files changed

+217
-82
lines changed

ml_grid/model_classes/H2OBaseClassifier.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ def __init__(self, estimator_class=None, **kwargs):
8888
def __del__(self):
8989
"""Cleans up the shared checkpoint directory if this is the last instance."""
9090
# This is a best-effort cleanup. In multi-process scenarios,
91-
# the directory might be in use by other processes.
92-
if os.path.exists(self._checkpoint_dir) and not os.listdir(self._checkpoint_dir):
91+
# the directory might be in use by other processes. Add hasattr check for partial init.
92+
if hasattr(self, '_checkpoint_dir') and os.path.exists(self._checkpoint_dir) and not os.listdir(self._checkpoint_dir):
9393
shutil.rmtree(self._checkpoint_dir, ignore_errors=True)
9494
logger.debug(f"Cleaned up empty shared checkpoint directory: {self._checkpoint_dir}")
9595

@@ -132,7 +132,7 @@ def _validate_input_data(self, X: pd.DataFrame, y: Optional[pd.Series] = None) -
132132
Raises:
133133
ValueError: If data is invalid
134134
"""
135-
# Convert to DataFrame if needed
135+
# Convert to DataFrame if needed and ensure columns are strings
136136
if not isinstance(X, pd.DataFrame):
137137
if self.feature_names_ is not None:
138138
X = pd.DataFrame(X, columns=self.feature_names_)
@@ -142,8 +142,14 @@ def _validate_input_data(self, X: pd.DataFrame, y: Optional[pd.Series] = None) -
142142
f"Input data (X) has {X.shape[1]} columns, but expected {len(self.feature_names_)} "
143143
f"based on training features. Please ensure column count matches."
144144
) # This was the syntax error fix
145-
else: # This else block should be aligned with the outer 'if self.feature_names_ is not None:'
145+
else:
146+
# If X is a numpy array, convert it to a DataFrame and ensure
147+
# its columns are strings to prevent KeyErrors with H2O.
146148
X = pd.DataFrame(X)
149+
X.columns = X.columns.astype(str)
150+
else:
151+
# If it's already a DataFrame, still ensure columns are strings.
152+
X.columns = X.columns.astype(str)
147153

148154
# Reset index to avoid sklearn CV indexing issues
149155
# CRITICAL: If we reset X, we MUST also reset y to maintain alignment.
@@ -272,8 +278,9 @@ def _prepare_fit(self, X: pd.DataFrame, y: pd.Series):
272278
model_params.setdefault('ignore_const_cols', False)
273279

274280
# --- ROBUSTNESS FIX: Save checkpoints for model recovery ---
275-
# Unconditionally add checkpoint directory. All H2O estimators support this.
276-
model_params["export_checkpoints_dir"] = self._checkpoint_dir
281+
# Conditionally add checkpoint directory, as not all estimators (e.g., RuleFit) support it.
282+
if 'export_checkpoints_dir' in estimator_params:
283+
model_params["export_checkpoints_dir"] = self._checkpoint_dir
277284

278285
return train_h2o, x_vars, outcome_var, model_params
279286

@@ -294,6 +301,15 @@ def _get_model_params(self) -> Dict[str, Any]:
294301
if key in valid_param_keys
295302
}
296303

304+
# --- FIX for H2OTypeError (e.g., max_depth, sample_rate, learn_rate) ---
305+
# Scikit-learn's ParameterGrid/RandomizedSearchCV can pass single-element numpy arrays or lists.
306+
# H2O expects native Python types (int, float), so we convert them.
307+
for key, value in model_params.items():
308+
if isinstance(value, np.ndarray) and value.size == 1:
309+
model_params[key] = value.item()
310+
elif isinstance(value, list) and len(value) == 1:
311+
model_params[key] = value[0]
312+
297313
return model_params
298314

299315
def _handle_small_data_fallback(self, X: pd.DataFrame, y: pd.Series) -> bool:
@@ -575,31 +591,28 @@ def __sklearn_clone__(self):
575591
self.logger.debug(f"__sklearn_clone__ called: original instance {id(self)}, clone instance {id(cloned)}")
576592
return cloned # Removing dead code
577593

578-
@classmethod
579-
def _get_param_names(cls):
594+
def _get_param_names(self):
580595
"""Get parameter names for the estimator.
581596
582597
This override is necessary because we use **kwargs in __init__.
598+
It's an instance method to access parameters stored on self.
583599
584600
CRITICAL: This should ONLY return parameter names, NOT fitted attribute names.
585601
"""
586-
init_signature = inspect.signature(cls.__init__)
602+
init_signature = inspect.signature(self.__class__.__init__)
587603
init_params = [p.name for p in init_signature.parameters.values()
588604
if p.name not in ('self', 'args', 'kwargs')]
589605

590-
# For instances, also include kwargs that were set
591-
if not isinstance(cls, type):
592-
extra_params = [
593-
key for key in cls.__dict__
594-
if not key.startswith('_') # Exclude private attributes
595-
and not key.endswith('_') # CRITICAL: Exclude fitted attributes
596-
and key not in init_params
597-
and key not in ['estimator_class', 'logger'] # Exclude special attributes
598-
and key not in ['model', 'model_', 'classes_', 'feature_names_', 'model_id'] # Exclude fitted
599-
]
600-
return sorted(init_params + extra_params)
601-
602-
return sorted(init_params)
606+
extra_params = [
607+
key for key in self.__dict__
608+
if not key.startswith('_')
609+
and not (key.endswith('_') and key != 'lambda_') # Allow lambda_
610+
and key not in init_params
611+
and key not in ['estimator_class', 'logger']
612+
and key not in ['model', 'model_', 'classes_', 'feature_names_', 'model_id']
613+
]
614+
615+
return sorted(init_params + extra_params)
603616

604617
def set_params(self: "H2OBaseClassifier", **kwargs: Any) -> "H2OBaseClassifier":
605618
"""Sets the parameters of this estimator.

ml_grid/model_classes/H2OGAMClassifier.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ def _prepare_fit(self, X: pd.DataFrame, y: pd.Series):
3737
self._fallback_to_glm = False # Reset flag
3838

3939
# --- 1. Parameter Preprocessing for GAM ---
40+
self.logger.debug(f"DEBUG: Before GAM column processing, model_params['gam_columns'] type: {type(model_params.get('gam_columns'))}, value: {model_params.get('gam_columns')}")
41+
4042
if 'gam_columns' not in model_params or not model_params['gam_columns']:
4143
self.logger.warning("H2OGAMClassifier: 'gam_columns' not provided or empty. Defaulting to all numerical features.")
4244
numeric_cols = [col for col in x_vars if train_h2o[col].types[col] in ['int', 'real']]
@@ -46,6 +48,12 @@ def _prepare_fit(self, X: pd.DataFrame, y: pd.Series):
4648
model_params['gam_columns'] = [model_params['gam_columns']]
4749
elif isinstance(model_params['gam_columns'], tuple):
4850
model_params['gam_columns'] = list(model_params['gam_columns'])
51+
# --- FIX for TypeError: object of type 'int' has no len() ---
52+
elif isinstance(model_params['gam_columns'], int):
53+
# If an integer is passed (e.g., from a hyperparameter search),
54+
# convert it to a list containing the column name as a string.
55+
# H2O expects column names to be strings.
56+
model_params['gam_columns'] = [str(model_params['gam_columns'])]
4957
elif isinstance(model_params['gam_columns'], list) and model_params['gam_columns'] and isinstance(model_params['gam_columns'][0], list):
5058
model_params['gam_columns'] = [item for sublist in model_params['gam_columns'] for item in sublist]
5159

@@ -127,7 +135,7 @@ def _prepare_fit(self, X: pd.DataFrame, y: pd.Series):
127135
suitable_gam_cols.append(col)
128136
suitable_knots.append(required_knots)
129137
if i < len(bs_list): suitable_bs.append(bs_list[i])
130-
if i < len(scale_list): suitable_scale.append(scale_list[i])
138+
if scale_list and i < len(scale_list): suitable_scale.append(scale_list[i])
131139

132140
if not suitable_gam_cols:
133141
self.logger.warning("No suitable GAM columns found after checking cardinality. Falling back to GLM.")

ml_grid/model_classes/H2OGLMClassifier.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from h2o.estimators import H2OGeneralizedLinearEstimator
22
from .H2OBaseClassifier import H2OBaseClassifier
3+
import pandas as pd
4+
from typing import Any, Dict
35

46
class H2OGLMClassifier(H2OBaseClassifier):
57
"""A scikit-learn compatible wrapper for H2O's Generalized Linear Models.
@@ -10,7 +12,30 @@ def __init__(self, **kwargs):
1012
All keyword arguments are passed directly to the H2OGeneralizedLinearEstimator.
1113
Example args: family='binomial', alpha=0.5
1214
"""
15+
# --- FIX for scikit-learn cloning and H2O's 'lambda' parameter ---
16+
# scikit-learn's get_params() will return 'lambda_', but the user might
17+
# provide 'lambda' in the parameter grid. We must handle both cases.
18+
if 'lambda' in kwargs and 'lambda_' not in kwargs:
19+
kwargs['lambda_'] = kwargs.pop('lambda')
20+
1321
# Remove estimator_class from kwargs if present (happens during sklearn clone)
1422
kwargs.pop('estimator_class', None)
1523
# Pass the specific estimator class
16-
super().__init__(estimator_class=H2OGeneralizedLinearEstimator, **kwargs)
24+
super().__init__(estimator_class=H2OGeneralizedLinearEstimator, **kwargs)
25+
26+
def fit(self, X: pd.DataFrame, y: pd.Series, **kwargs) -> "H2OGLMClassifier":
27+
"""
28+
Fits the H2O GLM model and then corrects the 'lambda_' parameter name for
29+
compatibility with the H2O backend during prediction.
30+
"""
31+
# Call the parent class's fit method to perform the actual training
32+
super().fit(X, y, **kwargs)
33+
34+
# --- CRITICAL FIX for predict-time NullPointerException ---
35+
# The H2O backend's predict method requires the 'lambda' parameter, but the
36+
# Python object may hold it as 'lambda_'. We must ensure the final model
37+
# object has the correct 'lambda' parameter set in its internal params dict.
38+
if self.model_ and 'lambda_' in self.model_.params:
39+
self.model_.params['lambda'] = self.model_.params.pop('lambda_')
40+
41+
return self

tests/conftest.py

Lines changed: 66 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,76 @@
1-
"""
2-
Pytest configuration file for shared fixtures.
3-
4-
This file makes fixtures available to all test files in this directory
5-
and its subdirectories without needing to import them.
6-
"""
1+
# tests/conftest.py
72

83
import pytest
9-
import pandas as pd
10-
import numpy as np
114
import h2o
12-
13-
# Add the project root directory to the Python path
14-
import sys
5+
import logging
156
import os
16-
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7+
8+
# --- Tame TensorFlow ---
9+
# Set log level to suppress info/warnings before importing
10+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
11+
try:
12+
import tensorflow as tf
13+
# Explicitly prevent TF from allocating any GPU memory.
14+
# This stops it from conflicting with H2O's Java VM.
15+
tf.config.set_visible_devices([], 'GPU')
16+
print("\n--- [Fixture Config] TensorFlow GPU explicitly disabled. ---")
17+
except ImportError:
18+
print("\n--- [Fixture Config] TensorFlow not found, skipping GPU disable. ---")
19+
pass
20+
# --- End Tame TensorFlow ---
21+
1722

1823
@pytest.fixture(scope="session")
1924
def h2o_session_fixture():
20-
"""Initializes H2O once per test session for stability and speed."""
21-
h2o.init(nthreads=1, log_level="FATA")
22-
yield
23-
h2o.shutdown(prompt=False)
25+
"""
26+
Session-scoped fixture to initialize and shut down the H2O cluster.
27+
This ensures h2o.init() is called only ONCE for the entire test session.
28+
"""
29+
print("\n--- [H2O Fixture] Initializing H2O cluster... ---")
30+
31+
# Stop h2o from printing progress bars, which can hang in pytest
32+
h2o.no_progress()
33+
34+
# Set up logging
35+
logging.getLogger('h2o').setLevel(logging.DEBUG)
36+
37+
try:
38+
# Start the H2O cluster.
39+
h2o.init(
40+
nthreads=-1, # Use all available cores
41+
max_mem_size="4g", # Adjust as needed
42+
log_level="DEBUG"
43+
)
44+
print("--- [H2O Fixture] H2O cluster initialized successfully. ---")
45+
46+
# Yield to let the tests run
47+
yield
48+
49+
finally:
50+
# This code runs *after* all tests in the session are complete
51+
print("\n--- [H2O Fixture] Shutting down H2O cluster... ---")
52+
53+
# Call remove_all() BEFORE shutdown() to avoid ConnectionError
54+
h2o.remove_all()
55+
h2o.cluster().shutdown()
56+
57+
print("--- [H2O Fixture] H2O cluster shutdown complete. ---")
2458

2559
@pytest.fixture(scope="session")
2660
def synthetic_data():
27-
"""Provides a simple, reusable dataset for testing classifiers."""
28-
X = pd.DataFrame(np.random.rand(50, 3), columns=['f1', 'f2', 'f3'])
29-
y = pd.Series(np.random.randint(0, 2, 50), name="outcome")
30-
return X, y
61+
"""Generates simple synthetic data for classification."""
62+
try:
63+
from sklearn.datasets import make_classification
64+
65+
# Keep n_samples large as a safety precaution
66+
X, y = make_classification(
67+
n_samples=1000,
68+
n_features=10,
69+
n_informative=5,
70+
n_redundant=0,
71+
n_classes=2,
72+
random_state=42
73+
)
74+
return X, y
75+
except ImportError:
76+
pytest.skip("sklearn not installed, skipping synthetic_data generation")

0 commit comments

Comments
 (0)