@@ -452,21 +452,6 @@ def __init__(
452452 if not getattr (self .global_parameters , 'test_mode' , False ):
453453 # Fit on the full training data first
454454 current_algorithm .fit (self .X_train , self .y_train )
455-
456- # --- TENSORFLOW PERFORMANCE FIX (Corrected Position) ---
457- # Pre-compile the predict function for Keras/TF models to avoid retracing warnings.
458- # This is done AFTER fitting and before cross-validation.
459- if isinstance (current_algorithm , (KerasClassifier , kerasClassifier_class , NeuralNetworkClassifier )):
460- try :
461- self .logger .debug ("Pre-compiling TensorFlow predict function to avoid retracing." )
462- n_features = self .X_train .shape [1 ]
463- # Define an input signature that allows for variable batch size.
464- input_signature = [tf .TensorSpec (shape = (None , n_features ), dtype = tf .float32 )]
465- # Access the underlying Keras model via .model_
466- current_algorithm .model_ .predict .get_concrete_function (input_signature )
467- except Exception as e :
468- self .logger .warning (f"Could not pre-compile TF function. Performance may be impacted. Error: { e } " )
469-
470455 # --- CRITICAL FIX: Pass the pandas Series, not the numpy array ---
471456 # Passing the numpy array (y_train.to_numpy()) causes index misalignment
472457 # with the pandas DataFrame (X_train_final) inside sklearn's CV,
@@ -481,6 +466,21 @@ def __init__(
481466 pre_dispatch = 80 ,
482467 error_score = self .error_raise , # Raise error if cross-validation fails
483468 )
469+
470+ # --- TENSORFLOW PERFORMANCE FIX (Corrected Position) ---
471+ # Pre-compile the predict function for Keras/TF models to avoid retracing warnings.
472+ # This is done AFTER fitting and before cross-validation.
473+ if isinstance (current_algorithm , (KerasClassifier , kerasClassifier_class , NeuralNetworkClassifier )):
474+ try :
475+ self .logger .debug ("Pre-compiling TensorFlow predict function to avoid retracing." )
476+ n_features = self .X_train .shape [1 ]
477+ # Define an input signature that allows for variable batch size.
478+ input_signature = [tf .TensorSpec (shape = (None , n_features ), dtype = tf .float32 )]
479+ # Access the underlying Keras model via .model_
480+ current_algorithm .model_ .predict .get_concrete_function (input_signature )
481+ except Exception as e :
482+ self .logger .warning (f"Could not pre-compile TF function. Performance may be impacted. Error: { e } " )
483+
484484
485485
486486 except XGBoostError as e :
0 commit comments