Skip to content

Commit b4e693d

Browse files
committed
prevent test/non gpu env calling gpu dependent models in testing.
1 parent a3e3b64 commit b4e693d

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

ml_grid/pipeline/model_class_list.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Dict, List, Optional
22
import logging
3+
import torch
34

45
from ml_grid.model_classes.adaboost_classifier_class import adaboost_class
56
from ml_grid.model_classes.catboost_classifier_class import CatBoost_class
@@ -51,6 +52,9 @@ def get_model_class_list(ml_grid_object: pipe) -> List[Any]:
5152
List[Any]: A list of instantiated model class objects.
5253
"""
5354
logger = logging.getLogger('ml_grid')
55+
56+
# Check for GPU availability once
57+
gpu_available = torch.cuda.is_available()
5458
# Get the parameter space size, defaulting to 'small' if not provided.
5559
# This prevents errors when the key is missing from the configuration.
5660
parameter_space_size = ml_grid_object.local_param_dict.get("param_space_size")
@@ -76,8 +80,8 @@ def get_model_class_list(ml_grid_object: pipe) -> List[Any]:
7680
"GaussianNB_class": True,
7781
"LightGBMClassifierWrapper": True,
7882
"adaboost_class": True,
79-
"kerasClassifier_class": True,
80-
"knn__gpu_wrapper_class": True,
83+
"kerasClassifier_class": gpu_available,
84+
"knn__gpu_wrapper_class": gpu_available,
8185
"NeuralNetworkClassifier_class": False,
8286
"TabTransformer_class": False,
8387
"h2o_classifier_class": False,
@@ -87,6 +91,12 @@ def get_model_class_list(ml_grid_object: pipe) -> List[Any]:
8791

8892
for class_name, include in model_class_dict.items():
8993
if include:
94+
# Proactively skip GPU-specific models if no GPU is available
95+
if "_gpu_" in class_name.lower() and not gpu_available:
96+
logger.warning(
97+
f"Skipping '{class_name}' because it requires a GPU, but no CUDA-enabled GPU is available."
98+
)
99+
continue
90100
# Try the exact name first, then try with '_class' appended for convenience
91101
try:
92102
model_class = eval(class_name)

0 commit comments

Comments
 (0)