Commit a803d36

test fixes

1 parent 5d34e84 commit a803d36
File tree

6 files changed: +269 -557 lines changed


tests/conftest.py

Lines changed: 26 additions & 3 deletions

@@ -1,7 +1,30 @@
-"""pytest configuration file to adjust the Python path."""
+"""
+Pytest configuration file for shared fixtures.
 
+This file makes fixtures available to all test files in this directory
+and its subdirectories without needing to import them.
+"""
+
+import pytest
+import pandas as pd
+import numpy as np
+import h2o
+
+# Add the project root directory to the Python path
 import sys
 import os
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 
-# Add the project root directory to the Python path
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+@pytest.fixture(scope="session")
+def h2o_session_fixture():
+    """Initializes H2O once per test session for stability and speed."""
+    h2o.init(nthreads=1, log_level="FATA")
+    yield
+    h2o.shutdown(prompt=False)
+
+@pytest.fixture(scope="session")
+def synthetic_data():
+    """Provides a simple, reusable dataset for testing classifiers."""
+    X = pd.DataFrame(np.random.rand(50, 3), columns=['f1', 'f2', 'f3'])
+    y = pd.Series(np.random.randint(0, 2, 50), name="outcome")
+    return X, y
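
For reference, a minimal sketch (not part of this commit) of how a test module consumes these conftest.py fixtures: pytest discovers conftest.py automatically and injects fixtures by parameter name, so no import of conftest is needed. The test names below are illustrative.

def test_synthetic_data_shape(synthetic_data):
    # The session-scoped fixture returns the same (X, y) pair for every test.
    X, y = synthetic_data
    assert X.shape == (50, 3)
    assert set(y.unique()) <= {0, 1}

def test_h2o_cluster_is_up(h2o_session_fixture):
    # Requesting the fixture guarantees h2o.init() has already been called.
    import h2o
    assert h2o.cluster() is not None
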

tests/test_h2o_classifiers.py

Lines changed: 13 additions & 27 deletions

@@ -22,27 +22,6 @@
 from ml_grid.model_classes.h2o_classifier_class import H2OAutoMLConfig as H2O_class  # This is the AutoML class
 
 
-# A session-scoped fixture to initialize H2O once for all tests
-@pytest.fixture(scope="session", autouse=True)
-def h2o_session_fixture():
-    """Initializes H2O at the beginning of the test session and shuts it down at the end."""
-    h2o.init(nthreads=1, log_level="FATA")  # Use 1 thread for faster, more stable test runs
-    yield
-    h2o.shutdown(prompt=False)
-
-
-# A pytest fixture to create synthetic data for tests
-@pytest.fixture
-def synthetic_data():
-    """Provides a simple dataset for testing classifiers."""
-    X = pd.DataFrame({
-        'feature1': np.random.rand(50),
-        'feature2': np.random.rand(50),
-        'feature3': np.random.randint(0, 4, 50)
-    })
-    y = pd.Series(np.random.randint(0, 2, 50), name="outcome")
-    return X, y
-
 @pytest.fixture
 def tiny_problematic_data():
     """
@@ -98,11 +77,14 @@ def h2o_model_instance(request, synthetic_data):
     return instance.algorithm_implementation
 
 # Use pytest.mark.parametrize to run the same test for all classifiers
-def test_h2o_classifier_fit_predict(h2o_model_instance, synthetic_data):
+def test_h2o_classifier_fit_predict(h2o_model_instance, synthetic_data, h2o_session_fixture):
     """
     Tests the basic fit and predict functionality of each H2O wrapper.
     """
     X, y = synthetic_data
+    # Clean up frames from any previous test runs to avoid conflicts
+    h2o.remove_all()
+
     estimator = h2o_model_instance
 
     # 1. Fit the model
@@ -134,6 +116,10 @@ def test_h2o_classifiers_with_cross_validation(model_class, tiny_problematic_data):
     """
     X, y = tiny_problematic_data
 
+    # Clean up frames from any previous test runs to avoid conflicts
+    if h2o.cluster().is_running():
+        h2o.remove_all()
+
     # Handle special instantiation for AutoML class
     if model_class == H2O_class:
         instance = model_class(parameter_space_size="small")
@@ -149,10 +135,6 @@ def test_h2o_classifiers_with_cross_validation(model_class, tiny_problematic_data):
     # Use 5-fold CV. On 10 samples, this creates 8-sample training folds.
     cv = KFold(n_splits=5, shuffle=True, random_state=42)
 
-    # Clean up frames from any previous test runs to avoid conflicts
-    if h2o.cluster().is_running():
-        h2o.remove_all()
-
     # The tiny_problematic_data can cause folds with constant features.
     # The H2OBaseClassifier wrapper correctly raises a RuntimeError in this case.
     # We expect this test to either complete successfully OR fail gracefully with
@@ -164,11 +146,15 @@ def test_h2o_classifiers_with_cross_validation(model_class, tiny_problematic_data):
         assert "fit on a single constant feature" in str(e), f"Caught unexpected RuntimeError: {e}"
 
 
-def test_h2o_gam_knot_cardinality_error():
+def test_h2o_gam_knot_cardinality_error(h2o_session_fixture):
     """
     Tests that H2OGAMClassifier raises a specific ValueError when a feature
     in a CV fold has fewer unique values than the number of knots.
     """
+    # The h2o_session_fixture ensures the cluster is running.
+    # Clean up frames from any previous test runs to avoid conflicts
+    h2o.remove_all()
+
     # Create data where 'feature2' has low cardinality
     X = pd.DataFrame({
         'feature1': np.random.rand(20),
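
The comments in the cross-validation test above describe a tolerant pattern whose body falls outside the displayed hunks: the CV run should either complete or fail only with the wrapper's documented constant-feature error. A hedged sketch of that pattern, assuming the test drives scikit-learn's cross_val_score; the helper name below is illustrative, not the commit's exact code.

from sklearn.model_selection import KFold, cross_val_score

def run_cv_accepting_constant_feature_error(estimator, X, y):
    # Either the CV run completes, or the H2O wrapper raises its documented
    # RuntimeError about fitting on a single constant feature.
    cv = KFold(n_splits=5, shuffle=True, random_state=42)
    try:
        cross_val_score(estimator, X, y, cv=cv, error_score="raise")
    except RuntimeError as e:
        assert "fit on a single constant feature" in str(e), (
            f"Caught unexpected RuntimeError: {e}"
        )
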
Lines changed: 211 additions & 0 deletions

@@ -0,0 +1,211 @@
+import pytest
+import numpy as np
+import h2o
+import logging
+
+# Essential ml_grid imports
+from ml_grid.pipeline.grid_search_cross_validate import grid_search_crossvalidate
+from ml_grid.util.global_params import global_parameters
+
+# Import all H2O model definition classes
+from ml_grid.model_classes.h2o_gbm_classifier_class import H2O_GBM_class
+from ml_grid.model_classes.h2o_drf_classifier_class import H2O_DRF_class
+from ml_grid.model_classes.h2o_gam_classifier_class import H2O_GAM_class
+from ml_grid.model_classes.h2o_deeplearning_classifier_class import H2O_DeepLearning_class
+from ml_grid.model_classes.h2o_glm_classifier_class import H2O_GLM_class
+from ml_grid.model_classes.h2o_naive_bayes_classifier_class import H2O_NaiveBayes_class
+from ml_grid.model_classes.h2o_rulefit_classifier_class import H2O_RuleFit_class
+from ml_grid.model_classes.h2o_xgboost_classifier_class import H2O_XGBoost_class
+from ml_grid.model_classes.h2o_stackedensemble_classifier_class import H2O_StackedEnsemble_class
+from ml_grid.model_classes.h2o_classifier_class import H2OAutoMLConfig as H2O_class  # AutoML
+
+# A mock class to simulate the main 'pipe' object for integration testing
+class MockMlGridObject:
+    def __init__(self, X, y, search_strategy='random'):
+        self.X_train = X
+        self.y_train = y
+        self.X_test = X
+        self.y_test = y
+        self.X_test_orig = X
+        self.y_test_orig = y
+        self.local_param_dict = {'param_space_size': 'small'}
+        self.global_params = global_parameters
+        self.base_project_dir = "test_experiments/test_run"
+        self.verbose = 0
+        self.global_params.cv_folds = 2  # Use 2 folds for faster tests
+        self.global_params.verbose = 0
+        self.global_params.error_raise = True
+        self.global_params.grid_n_jobs = 1  # H2O requires n_jobs=1
+        self.global_params.test_mode = True  # Skips final CV for speed
+        self.global_params.sub_sample_param_space_pct = 1.0
+
+        # Configure search strategy
+        if search_strategy == 'random':
+            self.global_params.random_grid_search = True
+            self.global_params.bayessearch = False
+            self.global_params.max_param_space_iter_value = 1
+        elif search_strategy == 'grid':
+            self.global_params.random_grid_search = False
+            self.global_params.bayessearch = False
+        elif search_strategy == 'bayes':
+            self.global_params.random_grid_search = False
+            self.global_params.bayessearch = True
+            self.global_params.max_param_space_iter_value = 1
+
+        self.logger = logging.getLogger('test_logger')
+        self.logger.setLevel(logging.DEBUG)
+
+def _prepare_h2o_param_space(instance, model_class, search_strategy):
+    """
+    Helper function to prepare and sanitize H2O parameter spaces for testing.
+    This centralizes the logic for limiting runtimes and ensuring compatibility.
+    """
+    param_space = instance.parameter_space
+
+    # Flatten list of dicts into a single dict if necessary
+    if isinstance(param_space, list):
+        flat_param_space = {}
+        for d in param_space:
+            flat_param_space.update(d)
+        param_space = flat_param_space
+
+    # For grid search, values must be in a list, even if it's a single element.
+    # Grid search treats each list element as a separate value to test.
+    # For random/bayes, we can use small lists for sampling.
+
+    # 1. For AutoML, force a very short runtime
+    if model_class == H2O_class:
+        if search_strategy == 'grid':
+            param_space['max_runtime_secs'] = [5]
+            param_space['max_models'] = [2]
+            param_space['sort_metric'] = ["AUC"]
+        else:
+            param_space['max_runtime_secs'] = [5]
+            param_space['max_models'] = [2]
+            param_space['sort_metric'] = ["AUC"]
+
+    # 2. For tree-based models, force minimal trees
+    if 'ntrees' in param_space:
+        if search_strategy == 'grid':
+            param_space['ntrees'] = [2]  # Wrap single value in a list
+        else:
+            param_space['ntrees'] = [2, 3]  # Small list for sampling
+
+    # 3. For Deep Learning, force minimal epochs
+    if model_class == H2O_DeepLearning_class and 'epochs' in param_space:
+        if search_strategy == 'grid':
+            param_space['epochs'] = [1]  # Wrap single value in a list
+        else:
+            param_space['epochs'] = [1, 2]
+
+    # 4. Add max_runtime_secs to ALL models as a safety net
+    if 'max_runtime_secs' not in param_space:
+        if search_strategy == 'grid':
+            param_space['max_runtime_secs'] = [10]  # 10 second timeout
+        else:
+            param_space['max_runtime_secs'] = [10]
+
+    # 5. For non-Bayesian searches, convert skopt distributions to concrete values
+    if search_strategy != 'bayes':
+        for key, value in param_space.items():
+            if hasattr(value, 'rvs'):  # It's a skopt distribution
+                if search_strategy == 'grid':
+                    # Single concrete value for grid search, wrapped in a list
+                    param_space[key] = [value.rvs(random_state=0)]
+                else:  # random search
+                    # Convert to small list
+                    if hasattr(value, 'categories'):
+                        cats = list(value.categories)
+                        param_space[key] = cats[:2] if len(cats) > 2 else cats
+                    elif hasattr(value, 'low') and isinstance(value.low, int):
+                        param_space[key] = [value.low, min(value.low + 1, value.high)]
+                    elif hasattr(value, 'low') and isinstance(value.low, float):
+                        param_space[key] = [value.low, (value.low + value.high) / 2]
+
+    return param_space
+
+# --- PERFORMANCE FIX: Use only fast models ---
+# Exclude AutoML and StackedEnsemble as they are too slow/complex for integration tests
+H2O_MODELS_TO_TEST = [
+    H2O_GLM_class,  # Fast, simple
+    H2O_DRF_class,  # Can be fast with limited trees
+    H2O_GBM_class,  # Can be fast with limited trees
+]
+
+# Optional: Add a separate slow test for comprehensive coverage
+H2O_SLOW_MODELS = [
+    H2O_DeepLearning_class,
+    H2O_GAM_class,
+    H2O_NaiveBayes_class,
+    H2O_RuleFit_class,
+    H2O_XGBoost_class,
+]
+
+@pytest.mark.parametrize("search_strategy", ["random", "bayes", "grid"])
+@pytest.mark.parametrize("model_class", H2O_MODELS_TO_TEST)
+def test_h2o_search_integrations(model_class, search_strategy, synthetic_data, h2o_session_fixture):
+    """
+    Tests H2O models with all search strategies (Randomized, Bayes, Grid).
+    This test is parameterized by both model and search strategy to ensure
+    maximum isolation between test runs, which is more stable for H2O.
+    """
+    X, y = synthetic_data
+
+    if model_class == H2O_class:
+        instance = model_class(parameter_space_size="small")
+    else:
+        instance = model_class(X=X, y=y, parameter_space_size="small")
+
+    mock_ml_grid_object = MockMlGridObject(X, y, search_strategy=search_strategy)
+
+    param_space = _prepare_h2o_param_space(
+        instance=instance,
+        model_class=model_class,
+        search_strategy=search_strategy
+    )
+
+    # Clean H2O state before each test
+    h2o.remove_all()
+
+    result = grid_search_crossvalidate(
+        algorithm_implementation=instance.algorithm_implementation,
+        parameter_space=param_space,
+        method_name=instance.method_name,
+        ml_grid_object=mock_ml_grid_object
+    )
+
+    assert isinstance(result.grid_search_cross_validate_score_result, float)
+
+    # Additional cleanup
+    h2o.remove_all()
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("search_strategy", ["random"])  # Only test one strategy for slow models
+@pytest.mark.parametrize("model_class", H2O_SLOW_MODELS)
+def test_h2o_slow_models(model_class, search_strategy, synthetic_data, h2o_session_fixture):
+    """
+    Separate test for slower H2O models. Run with: pytest -m slow
+    """
+    X, y = synthetic_data
+
+    instance = model_class(X=X, y=y, parameter_space_size="small")
+    mock_ml_grid_object = MockMlGridObject(X, y, search_strategy=search_strategy)
+
+    param_space = _prepare_h2o_param_space(
+        instance=instance,
+        model_class=model_class,
+        search_strategy=search_strategy
+    )
+
+    h2o.remove_all()
+
+    result = grid_search_crossvalidate(
+        algorithm_implementation=instance.algorithm_implementation,
+        parameter_space=param_space,
+        method_name=instance.method_name,
+        ml_grid_object=mock_ml_grid_object
+    )
+
+    assert isinstance(result.grid_search_cross_validate_score_result, float)
+    h2o.remove_all()
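
Step 5 of _prepare_h2o_param_space duck-types scikit-optimize dimensions rather than importing them. A minimal sketch, assuming scikit-optimize is installed and using illustrative parameter names, of how those hasattr checks classify each dimension for the random-search branch:

from skopt.space import Categorical, Integer, Real

# Illustrative search space; only the attribute checks matter here.
space = {
    "ntrees": Integer(2, 50),                    # exposes .low / .high as ints
    "learn_rate": Real(0.01, 0.3),               # exposes .low / .high as floats
    "booster": Categorical(["gbtree", "dart"]),  # exposes .categories
}

for name, dim in space.items():
    assert hasattr(dim, "rvs")  # every skopt dimension can draw samples
    if hasattr(dim, "categories"):
        values = list(dim.categories)[:2]
    elif isinstance(dim.low, int):
        values = [dim.low, min(dim.low + 1, dim.high)]
    else:
        values = [dim.low, (dim.low + dim.high) / 2]
    print(name, values)

Note that the @pytest.mark.slow marker used by test_h2o_slow_models should be registered under the markers option in pytest.ini (or the equivalent pyproject.toml section) so pytest does not warn about an unknown mark; the fast suite can then deselect these cases with pytest -m "not slow".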
