From d8aa5d6aff884219843cfa1eb6b955379bdfc3d7 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 22 Oct 2025 19:33:41 +0200 Subject: [PATCH] updated validate_data to versions > 1.6 --- .gitignore | 8 +++++++- skltemplate/__init__.py | 3 ++- skltemplate/_template.py | 19 +++++++++++-------- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 69e8d8a..486bd44 100644 --- a/.gitignore +++ b/.gitignore @@ -75,4 +75,10 @@ target/ .LSOverride # auto-generated files -skltemplate/_version.py \ No newline at end of file +skltemplate/_version.py + +# linters and formatters +.vscode/ +.idea/ +.mypy_cache/ +.ruff_cache/ diff --git a/skltemplate/__init__.py b/skltemplate/__init__.py index 66806e5..31ec217 100644 --- a/skltemplate/__init__.py +++ b/skltemplate/__init__.py @@ -1,8 +1,9 @@ # Authors: scikit-learn-contrib developers # License: BSD 3 clause +from sklearn import __version__ + from ._template import TemplateClassifier, TemplateEstimator, TemplateTransformer -from ._version import __version__ __all__ = [ "TemplateEstimator", diff --git a/skltemplate/_template.py b/skltemplate/_template.py index 62a25f2..9077713 100644 --- a/skltemplate/_template.py +++ b/skltemplate/_template.py @@ -4,12 +4,13 @@ # Authors: scikit-learn-contrib developers # License: BSD 3 clause +# mypy: ignore-errors import numpy as np from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin, _fit_context from sklearn.metrics import euclidean_distances from sklearn.utils.multiclass import check_classification_targets -from sklearn.utils.validation import check_is_fitted +from sklearn.utils.validation import check_is_fitted, validate_data class TemplateEstimator(BaseEstimator): @@ -73,12 +74,14 @@ def fit(self, X, y): self : object Returns self. """ - # `_validate_data` is defined in the `BaseEstimator` class. + # `validate_data` is defined in the sklearn.utils.validation module. 
# It allows to: # - run different checks on the input data; # - define some attributes associated to the input data: `n_features_in_` and # `feature_names_in_`. - X, y = self._validate_data(X, y, accept_sparse=True) + + X, y = validate_data(self, X, y, accept_sparse=True) + self.is_fitted_ = True # `fit` should always return `self` return self @@ -100,7 +103,7 @@ def predict(self, X): check_is_fitted(self) # We need to set reset=False because we don't want to overwrite `n_features_in_` # `feature_names_in_` but only check that the shape is consistent. - X = self._validate_data(X, accept_sparse=True, reset=False) + X = validate_data(self, X, accept_sparse=True, reset=False) return np.ones(X.shape[0], dtype=np.int64) @@ -182,7 +185,7 @@ def fit(self, X, y): # - run different checks on the input data; # - define some attributes associated to the input data: `n_features_in_` and # `feature_names_in_`. - X, y = self._validate_data(X, y) + X, y = validate_data(self, X, y) # We need to make sure that we have a classification task check_classification_targets(y) @@ -216,7 +219,7 @@ def predict(self, X): # Input validation # We need to set reset=False because we don't want to overwrite `n_features_in_` # `feature_names_in_` but only check that the shape is consistent. - X = self._validate_data(X, reset=False) + X = validate_data(self, X, reset=False) closest = np.argmin(euclidean_distances(X, self.X_), axis=1) return self.y_[closest] @@ -272,7 +275,7 @@ def fit(self, X, y=None): self : object Returns self. """ - X = self._validate_data(X, accept_sparse=True) + X = validate_data(self, X, accept_sparse=True) # Return the transformer return self @@ -297,7 +300,7 @@ def transform(self, X): # Input validation # We need to set reset=False because we don't want to overwrite `n_features_in_` # `feature_names_in_` but only check that the shape is consistent. 
- X = self._validate_data(X, accept_sparse=True, reset=False) + X = validate_data(self, X, accept_sparse=True, reset=False) return np.sqrt(X) def _more_tags(self):