from typing import Any, Dict, List, Optional
import logging
+import torch

from ml_grid.model_classes.adaboost_classifier_class import adaboost_class
from ml_grid.model_classes.catboost_classifier_class import CatBoost_class
@@ -51,6 +52,9 @@ def get_model_class_list(ml_grid_object: pipe) -> List[Any]:
        List[Any]: A list of instantiated model class objects.
    """
    logger = logging.getLogger('ml_grid')
+
+    # Check for GPU availability once
+    gpu_available = torch.cuda.is_available()
    # Get the parameter space size, defaulting to 'small' if not provided.
    # This prevents errors when the key is missing from the configuration.
    parameter_space_size = ml_grid_object.local_param_dict.get("param_space_size")
@@ -76,8 +80,8 @@ def get_model_class_list(ml_grid_object: pipe) -> List[Any]:
7680 "GaussianNB_class" : True ,
7781 "LightGBMClassifierWrapper" : True ,
7882 "adaboost_class" : True ,
79- "kerasClassifier_class" : True ,
80- "knn__gpu_wrapper_class" : True ,
83+ "kerasClassifier_class" : gpu_available ,
84+ "knn__gpu_wrapper_class" : gpu_available ,
8185 "NeuralNetworkClassifier_class" : False ,
8286 "TabTransformer_class" : False ,
8387 "h2o_classifier_class" : False ,
@@ -87,6 +91,12 @@ def get_model_class_list(ml_grid_object: pipe) -> List[Any]:

    for class_name, include in model_class_dict.items():
        if include:
+            # Proactively skip GPU-specific models if no GPU is available
+            if "_gpu_" in class_name.lower() and not gpu_available:
+                logger.warning(
+                    f"Skipping '{class_name}' because it requires a GPU, but no CUDA-enabled GPU is available."
+                )
+                continue
            # Try the exact name first, then try with '_class' appended for convenience
            try:
                model_class = eval(class_name)
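For reference, below is a minimal, self-contained sketch of how the new gating behaves. The `_gpu_` naming convention, the `gpu_available` flag, and the dictionary entries shown are taken from the diff above; the trimmed `model_class_dict` and the `selected` list are illustrative only, not the project's actual code path.

```python
# Minimal sketch of the GPU gating introduced in this change.
# Assumes only that torch is installed; the trimmed dictionary is illustrative.
import logging
import torch

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("ml_grid")

# Check for GPU availability once, as the diff does.
gpu_available = torch.cuda.is_available()

# Illustrative subset of the real dictionary: GPU-dependent entries are
# enabled only when a CUDA device is present.
model_class_dict = {
    "adaboost_class": True,
    "kerasClassifier_class": gpu_available,
    "knn__gpu_wrapper_class": gpu_available,
}

selected = []
for class_name, include in model_class_dict.items():
    if not include:
        continue
    # Same name-based guard as in the diff: skip GPU-specific wrappers
    # when no CUDA-enabled GPU is available.
    if "_gpu_" in class_name.lower() and not gpu_available:
        logger.warning(
            f"Skipping '{class_name}' because it requires a GPU, but no "
            "CUDA-enabled GPU is available."
        )
        continue
    selected.append(class_name)

print(selected)  # e.g. ['adaboost_class'] on a CPU-only machine
```

The dictionary gating and the name-based skip are redundant on purpose: the flag disables known GPU models up front, while the `"_gpu_"` check catches any GPU-named class that is still enabled.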