|
8 | 8 | from joblib import cpu_count |
9 | 9 | from scipy.sparse import csr_matrix |
10 | 10 | from sklearn.base import BaseEstimator, TransformerMixin |
| 11 | +from sklearn.utils import Tags, TargetTags, TransformerTags |
| 12 | +from sklearn.utils.validation import validate_data |
11 | 13 |
|
12 | 14 | from ..utils import TransformerChecksMixin, postprocess_knn_csr |
13 | 15 |
|
@@ -85,7 +87,7 @@ def _metric_info(self): |
85 | 87 |
|
86 | 88 | def fit(self, X, y=None): |
87 | 89 | normalize = self._metric_info.get("normalize", False) |
88 | | - X = self._validate_data(X, dtype=np.float32, copy=normalize) |
| 90 | + X = validate_data(self, X, dtype=np.float32, copy=normalize) |
89 | 91 | self.n_samples_fit_ = X.shape[0] |
90 | 92 | if self.n_jobs == -1: |
91 | 93 | n_jobs = cpu_count() |
@@ -157,14 +159,11 @@ def _transform(self, X): |
157 | 159 | def fit_transform(self, X, y=None): |
158 | 160 | return self.fit(X, y=y)._transform(X=None) |
159 | 161 |
|
160 | | - def _more_tags(self): |
161 | | - return { |
162 | | - "_xfail_checks": { |
163 | | - "check_estimators_pickle": "Cannot pickle FAISS index", |
164 | | - "check_methods_subset_invariance": "Unable to reset FAISS internal RNG", |
165 | | - }, |
166 | | - "requires_y": False, |
167 | | - "preserves_dtype": [np.float32], |
| 162 | + def __sklearn_tags__(self) -> Tags: |
| 163 | + return Tags( |
| 164 | + estimator_type="transformer", |
| 165 | + target_tags=TargetTags(required=False), |
| 166 | + transformer_tags=TransformerTags(preserves_dtype=[np.float32]), |
168 | 167 | # Could be made deterministic *if* we could reset FAISS's internal RNG |
169 | | - "non_deterministic": True, |
170 | | - } |
| 168 | + non_deterministic=True, |
| 169 | + ) |
0 commit comments