Skip to content

Commit ec3cf20

Browse files
committed
tensorflow performance fix
1 parent 815edae commit ec3cf20

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

ml_grid/pipeline/grid_search_cross_validate.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,20 @@ def __init__(
453453
# Fit on the full training data first
454454
current_algorithm.fit(self.X_train, self.y_train)
455455

456+
# --- TENSORFLOW PERFORMANCE FIX (Corrected Position) ---
457+
# Pre-compile the predict function for Keras/TF models to avoid retracing warnings.
458+
# This is done AFTER fitting and before cross-validation.
459+
if isinstance(current_algorithm, (KerasClassifier, kerasClassifier_class, NeuralNetworkClassifier)):
460+
try:
461+
self.logger.debug("Pre-compiling TensorFlow predict function to avoid retracing.")
462+
n_features = self.X_train.shape[1]
463+
# Define an input signature that allows for variable batch size.
464+
input_signature = [tf.TensorSpec(shape=(None, n_features), dtype=tf.float32)]
465+
# Access the underlying Keras model via .model_
466+
current_algorithm.model_.predict.get_concrete_function(input_signature)
467+
except Exception as e:
468+
self.logger.warning(f"Could not pre-compile TF function. Performance may be impacted. Error: {e}")
469+
456470
# --- CRITICAL FIX: Pass the pandas Series, not the numpy array ---
457471
# Passing the numpy array (y_train.to_numpy()) causes index misalignment
458472
# with the pandas DataFrame (X_train_final) inside sklearn's CV,

0 commit comments

Comments
 (0)