Skip to content

Commit 6ecc90b

Browse files
committed
catboost param adjustment
1 parent 47379f3 commit 6ecc90b

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

ml_grid/pipeline/grid_search_cross_validate.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,13 @@ def __init__(
263263
self.logger.info(
264264
"Adjusted KNN n_neighbors parameter space to prevent errors on small CV folds."
265265
)
266+
267+
# Dynamically adjust CatBoost subsample parameter for small datasets
268+
if "catboost" in method_name.lower():
269+
self._adjust_catboost_parameters(parameter_space)
270+
self.logger.info(
271+
"Adjusted CatBoost subsample parameter space to prevent errors on small CV folds."
272+
)
266273

267274
# Instantiate and run the hyperparameter grid/random search
268275
search = HyperparameterSearch(
@@ -533,6 +540,43 @@ def adjust_param(param_value):
533540
elif isinstance(parameter_space, dict) and 'n_neighbors' in parameter_space:
534541
parameter_space['n_neighbors'] = adjust_param(parameter_space['n_neighbors'])
535542

543+
def _adjust_catboost_parameters(self, parameter_space: Union[Dict, List[Dict]]):
544+
"""
545+
Dynamically adjusts the 'subsample' parameter for CatBoost to prevent
546+
errors on small datasets during cross-validation.
547+
"""
548+
n_splits = self.cv.get_n_splits()
549+
n_samples_in_fold = int(len(self.X_train) * (n_splits - 1) / n_splits)
550+
551+
# Ensure n_samples_in_fold is at least 1 to avoid division by zero
552+
n_samples_in_fold = max(1, n_samples_in_fold)
553+
554+
# The minimum subsample value must be > 1/n_samples to ensure at least one sample is chosen
555+
min_subsample = 1.0 / n_samples_in_fold
556+
557+
def adjust_param(param_value):
558+
if is_skopt_space(param_value):
559+
# For skopt.space objects (Real), adjust the lower bound
560+
new_low = max(param_value.low, min_subsample)
561+
# Ensure the new low is not higher than the high
562+
if new_low > param_value.high:
563+
new_low = param_value.high
564+
param_value.low = new_low
565+
elif isinstance(param_value, (list, np.ndarray)):
566+
# For lists, filter the values
567+
new_param_value = [s for s in param_value if s >= min_subsample]
568+
if not new_param_value:
569+
# If all values are filtered out, use the smallest valid value
570+
return [min(p for p in param_value if p > 0) if any(p > 0 for p in param_value) else 1.0]
571+
return new_param_value
572+
return param_value
573+
574+
if isinstance(parameter_space, list):
575+
for params in parameter_space:
576+
if 'subsample' in params:
577+
params['subsample'] = adjust_param(params['subsample'])
578+
elif isinstance(parameter_space, dict) and 'subsample' in parameter_space:
579+
parameter_space['subsample'] = adjust_param(parameter_space['subsample'])
536580

537581

538582
def dummy_auc() -> float:

0 commit comments

Comments
 (0)