@@ -64,8 +64,14 @@ def __init__(
6464
6565 @property
6666 def labels_ (self ) -> np .ndarray :
67- if not hasattr (self , "_labels " ):
67+ if not hasattr (self , "_cluster_centers " ):
6868 raise NotFittedError (self ._CORESET_ESTIMATOR_ERROR )
69+ elif not hasattr (self , "_labels" ):
70+ raise ValueError (
71+ "The labels have not been computed because the coreset "
72+ "was fit using partial_fit. "
73+ "Please call predict on your data to obtain the labels."
74+ )
6975 return self ._labels
7076
7177 @property
@@ -114,6 +120,7 @@ def partial_fit(
114120
115121 def _fit_coreset (
116122 self ,
123+ X : Optional [np .ndarray ] = None ,
117124 ) -> None :
118125 if self .coreset_estimator is None :
119126 from sklearn .cluster import KMeans
@@ -127,10 +134,13 @@ def _fit_coreset(
127134 self ._coreset_points , sample_weight = self ._coreset_weights
128135 )
129136 self ._cluster_centers : np .ndarray = self .coreset_estimator .cluster_centers_
130- self ._labels : np .ndarray = self .coreset_estimator .labels_
137+ if X is not None :
138+ self ._labels : np .ndarray = self .coreset_estimator .predict (X )
131139 self ._inertia : float = self .coreset_estimator .inertia_
132140
133- def _compute_coreset (self , fit_coreset : bool = False ) -> "BICO" :
141+ def _compute_coreset (
142+ self , X : Optional [np .ndarray ] = None , fit_coreset : bool = False
143+ ) -> "BICO" :
134144 if not hasattr (self , "bico_obj_" ):
135145 raise NotFittedError (
136146 "This BICO instance is not fitted yet. " "Call `fit` or `partial_fit`."
@@ -152,7 +162,7 @@ def _compute_coreset(self, fit_coreset: bool = False) -> "BICO":
152162 self ._n_features_out = n_found_points
153163
154164 if self .fit_coreset or fit_coreset :
155- self ._fit_coreset ()
165+ self ._fit_coreset (X )
156166
157167 return self
158168
@@ -188,7 +198,9 @@ def _fit(
188198 _DLL .addData (self .bico_obj_ , c_array , c_n )
189199
190200 if not partial or fit_coreset :
191- self ._compute_coreset (fit_coreset )
201+ self ._compute_coreset (
202+ X = _X if not partial else None , fit_coreset = fit_coreset
203+ )
192204
193205 return self
194206
@@ -204,9 +216,9 @@ def fit_predict(
204216 return self .labels_
205217
206218 def predict (self , X : Sequence [Sequence [float ]]) -> Any :
207- self ._fit_coreset ()
208-
209219 if self .coreset_estimator is None :
210220 raise NotFittedError (self ._CORESET_ESTIMATOR_ERROR )
211221
222+ self ._fit_coreset ()
223+
212224 return self .coreset_estimator .predict (X )
0 commit comments