44import numpy as np
55import pandas as pd
66import torch
7- from simbsig .neighbors import KNeighborsClassifier
7+ from simbsig .neighbors import KNeighborsClassifier as SimbsigKNeighborsClassifier
8+ from sklearn .neighbors import KNeighborsClassifier as SklearnKNeighborsClassifier
89from sklearn import metrics
910import logging
1011
@@ -48,18 +49,22 @@ def __init__(
4849 self .p = p
4950 self .metric = metric
5051 self .metric_params = metric_params
52+ self ._init_device = device # Store the original device parameter
53+ self .device = device
54+
55+ # Auto-detect device if not specified, or validate if specified
56+ self ._set_device (device )
5157
52- # Auto-detect device
58+ self .model : Optional [Union [SimbsigKNeighborsClassifier , SklearnKNeighborsClassifier ]] = None
59+
60+ def _set_device (self , device : Optional [str ]):
61+ """Helper to set the device, falling back to CPU if GPU is not available."""
5362 gpu_available = torch .cuda .is_available ()
5463 if device == "gpu" and not gpu_available :
55- logging .getLogger ('ml_grid' ).warning ("GPU requested for KNNWrapper, but torch.cuda.is_available() is False . Falling back to CPU." )
64+ logging .getLogger ('ml_grid' ).warning ("GPU requested for KNNWrapper, but torch.cuda is not available . Falling back to CPU." )
5665 self .device = "cpu"
57- elif device :
58- self .device = device
5966 else :
60- self .device = "gpu" if gpu_available else "cpu"
61-
62- self .model : Optional [KNeighborsClassifier ] = None
67+ self .device = device if device else ("gpu" if gpu_available else "cpu" )
6368
6469 def fit (
6570 self , X : Union [pd .DataFrame , np .ndarray ], y : Union [pd .Series , np .ndarray ]
@@ -75,17 +80,31 @@ def fit(
7580 Returns:
7681 KNNWrapper: The fitted estimator.
7782 """
78- self .model = KNeighborsClassifier (
79- n_neighbors = self .n_neighbors ,
80- weights = self .weights ,
81- algorithm = self .algorithm ,
82- leaf_size = self .leaf_size ,
83- p = self .p ,
84- metric = self .metric ,
85- metric_params = self .metric_params ,
86- device = self .device ,
87- )
88-
83+ # If the device is CPU, use the standard scikit-learn implementation
84+ # to completely avoid any simbsig/torch/cuda calls.
85+ if self .device == 'cpu' :
86+ logging .getLogger ('ml_grid' ).info ("Using scikit-learn's KNeighborsClassifier for CPU execution." )
87+ self .model = SklearnKNeighborsClassifier (
88+ n_neighbors = self .n_neighbors ,
89+ weights = self .weights ,
90+ algorithm = self .algorithm ,
91+ leaf_size = self .leaf_size ,
92+ p = self .p ,
93+ metric = self .metric ,
94+ metric_params = self .metric_params ,
95+ )
96+ else :
97+ # If GPU is intended and available, use the simbsig implementation.
98+ self .model = SimbsigKNeighborsClassifier (
99+ n_neighbors = self .n_neighbors ,
100+ weights = self .weights ,
101+ algorithm = self .algorithm ,
102+ leaf_size = self .leaf_size ,
103+ p = self .p ,
104+ metric = self .metric ,
105+ metric_params = self .metric_params ,
106+ device = self .device ,
107+ )
89108 self .model .fit (X , y )
90109 return self
91110
@@ -97,18 +116,17 @@ def get_params(self, deep: bool = False) -> Dict[str, Any]:
97116 contained subobjects that are estimators.
98117
99118 Returns:
100- Dict[str, Any]: Parameter names mapped to their values.
119+ Dict[str, Any]: Parameter names mapped to their original values.
101120 """
102121 return {
103- "device" : self .device ,
122+ "device" : self ._init_device ,
104123 "n_neighbors" : self .n_neighbors ,
105124 "weights" : self .weights ,
106125 "algorithm" : self .algorithm ,
107126 "leaf_size" : self .leaf_size ,
108127 "p" : self .p ,
109128 "metric" : self .metric ,
110129 "metric_params" : self .metric_params ,
111- "n_neighbors" : self .n_neighbors ,
112130 }
113131
114132 def predict (self , X : Union [pd .DataFrame , np .ndarray ]) -> np .ndarray :
@@ -158,5 +176,10 @@ def set_params(self, **parameters: Any) -> "KNNWrapper":
158176 KNNWrapper: The instance with updated parameters.
159177 """
160178 for parameter , value in parameters .items ():
179+ # Special handling for device to re-validate availability
180+ if parameter == 'device' :
181+ # Update both the initial and current device setting
182+ self ._init_device = value
183+ self ._set_device (value )
161184 setattr (self , parameter , value )
162185 return self
0 commit comments