Skip to content

Commit 208f7d8

Browse files
Compute labels for fit (#5)
* Remove useless imports
* Remove the randomness class
* Forgot to use seed
* Fix it such that the return labels actually refer to the original data and not to the coreset
* Fix mypy errors
1 parent 6f225dc commit 208f7d8

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

bico/core.py

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -64,8 +64,14 @@ def __init__(
6464

6565
@property
6666
def labels_(self) -> np.ndarray:
67-
if not hasattr(self, "_labels"):
67+
if not hasattr(self, "_cluster_centers"):
6868
raise NotFittedError(self._CORESET_ESTIMATOR_ERROR)
69+
elif not hasattr(self, "_labels"):
70+
raise ValueError(
71+
"The labels have not been computed because the coreset "
72+
"was fit using partial_fit. "
73+
"Please call predict on your data to obtain the labels."
74+
)
6975
return self._labels
7076

7177
@property
@@ -114,6 +120,7 @@ def partial_fit(
114120

115121
def _fit_coreset(
116122
self,
123+
X: Optional[np.ndarray] = None,
117124
) -> None:
118125
if self.coreset_estimator is None:
119126
from sklearn.cluster import KMeans
@@ -127,10 +134,13 @@ def _fit_coreset(
127134
self._coreset_points, sample_weight=self._coreset_weights
128135
)
129136
self._cluster_centers: np.ndarray = self.coreset_estimator.cluster_centers_
130-
self._labels: np.ndarray = self.coreset_estimator.labels_
137+
if X is not None:
138+
self._labels: np.ndarray = self.coreset_estimator.predict(X)
131139
self._inertia: float = self.coreset_estimator.inertia_
132140

133-
def _compute_coreset(self, fit_coreset: bool = False) -> "BICO":
141+
def _compute_coreset(
142+
self, X: Optional[np.ndarray] = None, fit_coreset: bool = False
143+
) -> "BICO":
134144
if not hasattr(self, "bico_obj_"):
135145
raise NotFittedError(
136146
"This BICO instance is not fitted yet. " "Call `fit` or `partial_fit`."
@@ -152,7 +162,7 @@ def _compute_coreset(self, fit_coreset: bool = False) -> "BICO":
152162
self._n_features_out = n_found_points
153163

154164
if self.fit_coreset or fit_coreset:
155-
self._fit_coreset()
165+
self._fit_coreset(X)
156166

157167
return self
158168

@@ -188,7 +198,9 @@ def _fit(
188198
_DLL.addData(self.bico_obj_, c_array, c_n)
189199

190200
if not partial or fit_coreset:
191-
self._compute_coreset(fit_coreset)
201+
self._compute_coreset(
202+
X=_X if not partial else None, fit_coreset=fit_coreset
203+
)
192204

193205
return self
194206

@@ -204,9 +216,9 @@ def fit_predict(
204216
return self.labels_
205217

206218
def predict(self, X: Sequence[Sequence[float]]) -> Any:
207-
self._fit_coreset()
208-
209219
if self.coreset_estimator is None:
210220
raise NotFittedError(self._CORESET_ESTIMATOR_ERROR)
211221

222+
self._fit_coreset()
223+
212224
return self.coreset_estimator.predict(X)

0 commit comments

Comments
 (0)