@@ -116,9 +116,12 @@ def test_fit_successful(mock_h2o_init, mock_h2o_cluster, mock_h2o_frame, classif
116116 assert classifier_instance .feature_types_ == {'feature1' : 'real' , 'feature2' : 'real' , 'feature3' : 'enum' }
117117
118118@patch ('h2o.get_model' )
119+ @patch ('h2o.get_frame' )
120+ @patch ('h2o.assign' )
119121@patch ('h2o.H2OFrame' )
120122@patch ('h2o.cluster' )
121- def test_predict_successful (mock_h2o_cluster , mock_h2o_frame , mock_h2o_get_model , classifier_instance , sample_data ):
123+ def test_predict_successful (mock_h2o_cluster , mock_h2o_frame , mock_h2o_assign , mock_h2o_get_frame , mock_h2o_get_model ,
124+ classifier_instance , sample_data ):
122125 """
123126 Tests the `predict` method on a pre-fitted classifier.
124127 """
@@ -132,12 +135,26 @@ def test_predict_successful(mock_h2o_cluster, mock_h2o_frame, mock_h2o_get_model
132135
133136 # --- Setup Mocks ---
134137 # Mock the H2OFrame that will be created from the input data
135- mock_frame_instance = MagicMock ()
136- mock_frame_instance .nrows = len (X ) # This is the crucial fix
137- mock_h2o_frame .return_value = mock_frame_instance
138+ mock_tmp_frame = MagicMock (spec = h2o .H2OFrame )
139+ mock_h2o_frame .return_value = mock_tmp_frame
138140
141+ # Mock the final frame that get_frame will return
142+ mock_final_frame = MagicMock (spec = h2o .H2OFrame )
143+ mock_final_frame .nrows = len (X )
144+ mock_h2o_get_frame .return_value = mock_final_frame
139145 # Mock the model object that `h2o.get_model` will return
140- mock_model = MockH2OEstimator ()
146+ # --- FIX: Replace the real predict method with a MagicMock ---
147+ # Instantiate the mock estimator
148+ mock_model = MockH2OEstimator ()
149+ # Create a mock for the predict method so we can assert calls
150+ mock_model .predict = MagicMock ()
151+ # Configure the mock's return value to simulate H2O's behavior
152+ mock_pred_frame = MagicMock ()
153+ mock_pred_frame .as_data_frame .return_value = pd .DataFrame ({
154+ 'predict' : np .random .randint (0 , 2 , len (X ))
155+ })
156+ mock_model .predict .return_value = mock_pred_frame
157+
141158 mock_h2o_get_model .return_value = mock_model
142159
143160 # Mock H2O cluster status
@@ -150,8 +167,12 @@ def test_predict_successful(mock_h2o_cluster, mock_h2o_frame, mock_h2o_get_model
150167 # 1. Check that the model was retrieved from H2O
151168 mock_h2o_get_model .assert_called_with ("fitted_model_123" )
152169
153- # 2. Check that an H2OFrame was created for the prediction data with correct types
154- mock_h2o_frame .assert_called_with (X , column_names = list (X .columns ), column_types = classifier_instance .feature_types_ )
170+ # 2. Check that the new frame creation logic was called
171+ mock_h2o_frame .assert_called_once_with (X , column_names = list (X .columns ), column_types = classifier_instance .feature_types_ )
172+ mock_h2o_assign .assert_called_once_with (mock_tmp_frame , ANY )
173+ mock_h2o_get_frame .assert_called_once ()
174+ # Verify the model's predict method was called with the final mocked frame
175+ mock_model .predict .assert_called_once_with (mock_final_frame )
155176
156177 # 3. Check the output of the prediction
157178 assert isinstance (predictions , np .ndarray )
0 commit comments