1+ import pytest
2+ import numpy as np
3+ import h2o
4+ import logging
5+
6+ # Essential ml_grid imports
7+ from ml_grid .pipeline .grid_search_cross_validate import grid_search_crossvalidate
8+ from ml_grid .util .global_params import global_parameters
9+
10+ # Import all H2O model definition classes
11+ from ml_grid .model_classes .h2o_gbm_classifier_class import H2O_GBM_class
12+ from ml_grid .model_classes .h2o_drf_classifier_class import H2O_DRF_class
13+ from ml_grid .model_classes .h2o_gam_classifier_class import H2O_GAM_class
14+ from ml_grid .model_classes .h2o_deeplearning_classifier_class import H2O_DeepLearning_class
15+ from ml_grid .model_classes .h2o_glm_classifier_class import H2O_GLM_class
16+ from ml_grid .model_classes .h2o_naive_bayes_classifier_class import H2O_NaiveBayes_class
17+ from ml_grid .model_classes .h2o_rulefit_classifier_class import H2O_RuleFit_class
18+ from ml_grid .model_classes .h2o_xgboost_classifier_class import H2O_XGBoost_class
19+ from ml_grid .model_classes .h2o_stackedensemble_classifier_class import H2O_StackedEnsemble_class
20+ from ml_grid .model_classes .h2o_classifier_class import H2OAutoMLConfig as H2O_class # AutoML
21+
22+ # A mock class to simulate the main 'pipe' object for integration testing
class MockMlGridObject:
    """Minimal stand-in for the main ml_grid 'pipe' object used in integration tests.

    The same ``X``/``y`` pair is exposed as the train, test and "orig" test
    splits, and the imported ``global_parameters`` object is configured for a
    quick run with the requested search strategy.

    Note:
        ``global_parameters`` is the object imported from
        ``ml_grid.util.global_params``; it is modified in place, so every
        setting applied here remains on that shared object after the mock
        is constructed.

    Args:
        X: Feature data reused for every split.
        y: Target data reused for every split.
        search_strategy: One of ``'random'``, ``'grid'`` or ``'bayes'``.

    Raises:
        ValueError: If ``search_strategy`` is not one of the supported names.
    """

    _SEARCH_STRATEGIES = ('random', 'grid', 'bayes')

    def __init__(self, X, y, search_strategy='random'):
        # Fail fast on a typo: otherwise the search flags would silently keep
        # whatever a previous test left on the shared global_parameters object.
        if search_strategy not in self._SEARCH_STRATEGIES:
            raise ValueError(
                f"Unknown search_strategy {search_strategy!r}; "
                f"expected one of {self._SEARCH_STRATEGIES}"
            )

        self.X_train = X
        self.y_train = y
        self.X_test = X
        self.y_test = y
        self.X_test_orig = X
        self.y_test_orig = y
        self.local_param_dict = {'param_space_size': 'small'}
        self.global_params = global_parameters
        self.base_project_dir = "test_experiments/test_run"
        self.verbose = 0
        self.global_params.cv_folds = 2  # Use 2 folds for faster tests
        self.global_params.verbose = 0
        self.global_params.error_raise = True
        self.global_params.grid_n_jobs = 1  # H2O requires n_jobs=1
        self.global_params.test_mode = True  # Skips final CV for speed
        self.global_params.sub_sample_param_space_pct = 1.0

        # Configure search strategy
        if search_strategy == 'random':
            self.global_params.random_grid_search = True
            self.global_params.bayessearch = False
            self.global_params.max_param_space_iter_value = 1
        elif search_strategy == 'grid':
            # max_param_space_iter_value is intentionally not touched here; it
            # keeps whatever value the shared object already holds.
            self.global_params.random_grid_search = False
            self.global_params.bayessearch = False
        else:  # 'bayes' (validated above)
            self.global_params.random_grid_search = False
            self.global_params.bayessearch = True
            self.global_params.max_param_space_iter_value = 1

        self.logger = logging.getLogger('test_logger')
        self.logger.setLevel(logging.DEBUG)
57+
def _prepare_h2o_param_space(instance, model_class, search_strategy):
    """Build a small, test-friendly parameter space for an H2O model definition.

    The space is taken from ``instance.parameter_space``. A list of dicts is
    merged into a single dict (later entries override earlier ones); a dict is
    shallow-copied, so the overrides applied here never modify the instance's
    own parameter space.

    Overrides applied, in order:
      1. AutoML (``H2O_class``): short runtime, two models, AUC sort metric.
      2. Any space containing ``ntrees``: a minimal tree count.
      3. Deep Learning spaces containing ``epochs``: a minimal epoch count.
      4. ``max_runtime_secs`` is added as a safety net when absent.
      5. For non-Bayesian searches, values exposing ``rvs`` (distribution
         objects) are replaced by concrete lists.

    Args:
        instance: Model definition object exposing ``parameter_space``.
        model_class: The class ``instance`` was created from.
        search_strategy: ``'grid'``, ``'random'`` or ``'bayes'``.

    Returns:
        dict: A new parameter-space dict prepared for the given strategy.
    """
    raw_space = instance.parameter_space

    if isinstance(raw_space, list):
        # Flatten list of dicts into a single dict
        param_space = {}
        for partial_space in raw_space:
            param_space.update(partial_space)
    else:
        # Work on a copy so the instance's own parameter space is left intact.
        param_space = dict(raw_space)

    is_grid = search_strategy == 'grid'

    # For grid search every value must be a list (each element is one grid
    # point); random/bayes may use small lists to sample from.

    # 1. For AutoML, force a very short runtime (same values for every strategy)
    if model_class is H2O_class:
        param_space['max_runtime_secs'] = [5]
        param_space['max_models'] = [2]
        param_space['sort_metric'] = ["AUC"]

    # 2. For tree-based models, force minimal trees
    if 'ntrees' in param_space:
        param_space['ntrees'] = [2] if is_grid else [2, 3]

    # 3. For Deep Learning, force minimal epochs
    if model_class is H2O_DeepLearning_class and 'epochs' in param_space:
        param_space['epochs'] = [1] if is_grid else [1, 2]

    # 4. Add max_runtime_secs (10 second timeout) to ALL models as safety net
    param_space.setdefault('max_runtime_secs', [10])

    # 5. For non-Bayesian searches, convert distributions to concrete values
    if search_strategy != 'bayes':
        # Iterate over a snapshot because entries are reassigned in the loop.
        for key, value in list(param_space.items()):
            if not hasattr(value, 'rvs'):
                continue  # already a concrete value / list

            if is_grid:
                # Single concrete value for grid search, wrapped in a list.
                # NOTE(review): some distribution types may return a length-1
                # array from rvs() rather than a scalar - confirm the grid
                # search accepts that shape.
                param_space[key] = [value.rvs(random_state=0)]
            elif hasattr(value, 'categories'):
                # Random search: keep at most the first two categories
                param_space[key] = list(value.categories)[:2]
            elif hasattr(value, 'low') and isinstance(value.low, int):
                param_space[key] = [value.low, min(value.low + 1, value.high)]
            elif hasattr(value, 'low') and isinstance(value.low, float):
                param_space[key] = [value.low, (value.low + value.high) / 2]
            # Any other distribution is left unchanged for random search.

    return param_space
126+
# Fast tier exercised against every search strategy.
# AutoML and StackedEnsemble are deliberately left out: they are too
# slow/complex for integration tests.
#   - GLM: fast, simple
#   - DRF / GBM: can be fast with a limited number of trees
H2O_MODELS_TO_TEST = [H2O_GLM_class, H2O_DRF_class, H2O_GBM_class]
134+
# Slower tier, kept in a separate list so it can back an optional,
# more comprehensive test run.
H2O_SLOW_MODELS = [
    H2O_DeepLearning_class, H2O_GAM_class, H2O_NaiveBayes_class,
    H2O_RuleFit_class, H2O_XGBoost_class,
]
143+
@pytest.mark.parametrize("search_strategy", ["random", "bayes", "grid"])
@pytest.mark.parametrize("model_class", H2O_MODELS_TO_TEST)
def test_h2o_search_integrations(model_class, search_strategy, synthetic_data, h2o_session_fixture):
    """
    Tests H2O models with all search strategies (Randomized, Bayes, Grid).

    The test is parameterized by both model and search strategy so that each
    combination runs as its own test case, which is more stable for H2O.
    H2O objects are removed before the search and again afterwards; the final
    removal runs even when the search raises or the assertion fails, so a
    failing case does not leave objects behind for the next one.

    ``synthetic_data`` and ``h2o_session_fixture`` are pytest fixtures defined
    outside this module (presumably in a conftest - verify).
    """
    X, y = synthetic_data

    # AutoML is constructed without data; every other model takes X and y.
    if model_class is H2O_class:
        instance = model_class(parameter_space_size="small")
    else:
        instance = model_class(X=X, y=y, parameter_space_size="small")

    mock_ml_grid_object = MockMlGridObject(X, y, search_strategy=search_strategy)

    param_space = _prepare_h2o_param_space(
        instance=instance,
        model_class=model_class,
        search_strategy=search_strategy,
    )

    # Clean H2O state before each test
    h2o.remove_all()
    try:
        result = grid_search_crossvalidate(
            algorithm_implementation=instance.algorithm_implementation,
            parameter_space=param_space,
            method_name=instance.method_name,
            ml_grid_object=mock_ml_grid_object,
        )

        assert isinstance(result.grid_search_cross_validate_score_result, float)
    finally:
        # Always clean up, including on failure
        h2o.remove_all()
181+
182+
@pytest.mark.slow
@pytest.mark.parametrize("search_strategy", ["random"])  # Only test one strategy for slow models
@pytest.mark.parametrize("model_class", H2O_SLOW_MODELS)
def test_h2o_slow_models(model_class, search_strategy, synthetic_data, h2o_session_fixture):
    """
    Separate test for slower H2O models. Run with: pytest -m slow

    H2O objects are removed before the search and again afterwards; the final
    removal runs even when the search raises or the assertion fails, so a
    failing case does not leave objects behind for the next one.

    ``synthetic_data`` and ``h2o_session_fixture`` are pytest fixtures defined
    outside this module (presumably in a conftest - verify).
    """
    X, y = synthetic_data

    instance = model_class(X=X, y=y, parameter_space_size="small")
    mock_ml_grid_object = MockMlGridObject(X, y, search_strategy=search_strategy)

    param_space = _prepare_h2o_param_space(
        instance=instance,
        model_class=model_class,
        search_strategy=search_strategy,
    )

    # Clean H2O state before the run
    h2o.remove_all()
    try:
        result = grid_search_crossvalidate(
            algorithm_implementation=instance.algorithm_implementation,
            parameter_space=param_space,
            method_name=instance.method_name,
            ml_grid_object=mock_ml_grid_object,
        )

        assert isinstance(result.grid_search_cross_validate_score_result, float)
    finally:
        # Always clean up, including on failure
        h2o.remove_all()
0 commit comments