@@ -263,6 +263,13 @@ def __init__(
263263 self .logger .info (
264264 "Adjusted KNN n_neighbors parameter space to prevent errors on small CV folds."
265265 )
266+
267+ # Dynamically adjust CatBoost subsample parameter for small datasets
268+ if "catboost" in method_name .lower ():
269+ self ._adjust_catboost_parameters (parameter_space )
270+ self .logger .info (
271+ "Adjusted CatBoost subsample parameter space to prevent errors on small CV folds."
272+ )
266273
267274 # Instantiate and run the hyperparameter grid/random search
268275 search = HyperparameterSearch (
@@ -533,6 +540,43 @@ def adjust_param(param_value):
533540 elif isinstance (parameter_space , dict ) and 'n_neighbors' in parameter_space :
534541 parameter_space ['n_neighbors' ] = adjust_param (parameter_space ['n_neighbors' ])
535542
543+ def _adjust_catboost_parameters (self , parameter_space : Union [Dict , List [Dict ]]):
544+ """
545+ Dynamically adjusts the 'subsample' parameter for CatBoost to prevent
546+ errors on small datasets during cross-validation.
547+ """
548+ n_splits = self .cv .get_n_splits ()
549+ n_samples_in_fold = int (len (self .X_train ) * (n_splits - 1 ) / n_splits )
550+
551+ # Ensure n_samples_in_fold is at least 1 to avoid division by zero
552+ n_samples_in_fold = max (1 , n_samples_in_fold )
553+
554+ # The minimum subsample value must be > 1/n_samples to ensure at least one sample is chosen
555+ min_subsample = 1.0 / n_samples_in_fold
556+
557+ def adjust_param (param_value ):
558+ if is_skopt_space (param_value ):
559+ # For skopt.space objects (Real), adjust the lower bound
560+ new_low = max (param_value .low , min_subsample )
561+ # Ensure the new low is not higher than the high
562+ if new_low > param_value .high :
563+ new_low = param_value .high
564+ param_value .low = new_low
565+ elif isinstance (param_value , (list , np .ndarray )):
566+ # For lists, filter the values
567+ new_param_value = [s for s in param_value if s >= min_subsample ]
568+ if not new_param_value :
569+ # If all values are filtered out, use the smallest valid value
570+ return [min (p for p in param_value if p > 0 ) if any (p > 0 for p in param_value ) else 1.0 ]
571+ return new_param_value
572+ return param_value
573+
574+ if isinstance (parameter_space , list ):
575+ for params in parameter_space :
576+ if 'subsample' in params :
577+ params ['subsample' ] = adjust_param (params ['subsample' ])
578+ elif isinstance (parameter_space , dict ) and 'subsample' in parameter_space :
579+ parameter_space ['subsample' ] = adjust_param (parameter_space ['subsample' ])
536580
537581
538582def dummy_auc () -> float :
0 commit comments