diff --git a/.ci/scripts/run_sklearn_tests.py b/.ci/scripts/run_sklearn_tests.py index e8a0b07cfd..33373f6c66 100644 --- a/.ci/scripts/run_sklearn_tests.py +++ b/.ci/scripts/run_sklearn_tests.py @@ -14,12 +14,14 @@ # limitations under the License. # =============================================================================== +import os + +os.environ["SCIPY_ARRAY_API"] = "1" from sklearnex import patch_sklearn patch_sklearn() import argparse -import os import sys import pytest @@ -43,8 +45,6 @@ if os.environ["SELECTED_TESTS"] == "all": os.environ["SELECTED_TESTS"] = "" - os.environ["SCIPY_ARRAY_API"] = "1" - pytest_args = ( f"--rootdir={sklearn_file_dir} " f'{os.environ["DESELECTED_TESTS"]} {os.environ["SELECTED_TESTS"]}'.split(" ") diff --git a/conda-recipe/run_test.bat b/conda-recipe/run_test.bat index a70f5f76cf..8e47b761d1 100644 --- a/conda-recipe/run_test.bat +++ b/conda-recipe/run_test.bat @@ -28,7 +28,7 @@ if "%PYTHON%"=="python" ( set NO_DIST=1 ) - +set SCIPY_ARRAY_API=1 %PYTHON% -c "from sklearnex import patch_sklearn; patch_sklearn()" || set exitcode=1 diff --git a/conda-recipe/run_test.sh b/conda-recipe/run_test.sh index 61e96fb018..a292957b35 100755 --- a/conda-recipe/run_test.sh +++ b/conda-recipe/run_test.sh @@ -36,6 +36,8 @@ if [ -z "${PYTHON}" ]; then export PYTHON=python fi +export SCIPY_ARRAY_API=1 + # Note: execute with argument --json-report in order to produce # a JSON report under folder '.pytest_reports'. Other arguments # will also be forwarded to pytest. diff --git a/daal4py/sklearn/linear_model/tests/test_linear.py b/daal4py/sklearn/linear_model/tests/test_linear.py index 29137b475a..8590a357bf 100644 --- a/daal4py/sklearn/linear_model/tests/test_linear.py +++ b/daal4py/sklearn/linear_model/tests/test_linear.py @@ -14,24 +14,14 @@ # limitations under the License. # ============================================================================== - -from os import environ - -from daal4py.sklearn._utils import sklearn_check_version - -# sklearn requires manual enabling of Scipy array API support -# if `array-api-compat` package is present in environment -# TODO: create generic approach to handle this for all tests -if sklearn_check_version("1.6"): - environ["SCIPY_ARRAY_API"] = "1" - - import numpy as np import pytest from sklearn.datasets import make_regression from sklearn.linear_model import LinearRegression from sklearn.utils._testing import assert_array_almost_equal +from daal4py.sklearn._utils import sklearn_check_version + def make_dataset(n_samples, n_features, kind=np.array, random_state=0, types=None): try: diff --git a/onedal/svm/tests/test_svc.py b/onedal/svm/tests/test_svc.py index f97d01c091..a4fda0875f 100644 --- a/onedal/svm/tests/test_svc.py +++ b/onedal/svm/tests/test_svc.py @@ -14,14 +14,6 @@ # limitations under the License. # ============================================================================== -from os import environ - -# sklearn requires manual enabling of Scipy array API support -# if `array-api-compat` package is present in environment -# TODO: create generic approach to handle this for all tests -environ["SCIPY_ARRAY_API"] = "1" - - import numpy as np import pytest import sklearn.utils.estimator_checks diff --git a/sklearnex/covariance/tests/test_incremental_covariance.py b/sklearnex/covariance/tests/test_incremental_covariance.py index 554b1a1734..a89310856b 100644 --- a/sklearnex/covariance/tests/test_incremental_covariance.py +++ b/sklearnex/covariance/tests/test_incremental_covariance.py @@ -15,16 +15,6 @@ # =============================================================================== from contextlib import nullcontext -from os import environ - -from daal4py.sklearn._utils import sklearn_check_version - -# sklearn requires manual enabling of Scipy array API support -# if `array-api-compat` package is present in environment -# TODO: create generic approach to handle this for all tests -if sklearn_check_version("1.6"): - environ["SCIPY_ARRAY_API"] = "1" - import numpy as np import pytest @@ -38,7 +28,7 @@ from sklearn.datasets import load_diabetes from sklearn.decomposition import PCA -from daal4py.sklearn._utils import daal_check_version +from daal4py.sklearn._utils import daal_check_version, sklearn_check_version from onedal.tests.utils._dataframes_support import ( _as_numpy, _convert_to_dataframe,