Skip to content

Commit 2abce81

Browse files
MAINT: Set SCIPY_ARRAY_API=1 on tests (#2806)
* add necessary env variable for array api on scipy * try another way * remove from places where it has no effect * more fixes
1 parent 268eb80 commit 2abce81

File tree

6 files changed

+9
-35
lines changed

6 files changed

+9
-35
lines changed

.ci/scripts/run_sklearn_tests.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
# limitations under the License.
1515
# ===============================================================================
1616

17+
import os
18+
19+
os.environ["SCIPY_ARRAY_API"] = "1"
1720
from sklearnex import patch_sklearn
1821

1922
patch_sklearn()
2023

2124
import argparse
22-
import os
2325
import sys
2426

2527
import pytest
@@ -43,8 +45,6 @@
4345
if os.environ["SELECTED_TESTS"] == "all":
4446
os.environ["SELECTED_TESTS"] = ""
4547

46-
os.environ["SCIPY_ARRAY_API"] = "1"
47-
4848
pytest_args = (
4949
f"--rootdir={sklearn_file_dir} "
5050
f'{os.environ["DESELECTED_TESTS"]} {os.environ["SELECTED_TESTS"]}'.split(" ")

conda-recipe/run_test.bat

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ if "%PYTHON%"=="python" (
2828
set NO_DIST=1
2929
)
3030

31-
31+
set SCIPY_ARRAY_API=1
3232

3333
%PYTHON% -c "from sklearnex import patch_sklearn; patch_sklearn()" || set exitcode=1
3434

conda-recipe/run_test.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ if [ -z "${PYTHON}" ]; then
3636
export PYTHON=python
3737
fi
3838

39+
export SCIPY_ARRAY_API=1
40+
3941
# Note: execute with argument --json-report in order to produce
4042
# a JSON report under folder '.pytest_reports'. Other arguments
4143
# will also be forwarded to pytest.

daal4py/sklearn/linear_model/tests/test_linear.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,14 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17-
18-
from os import environ
19-
20-
from daal4py.sklearn._utils import sklearn_check_version
21-
22-
# sklearn requires manual enabling of Scipy array API support
23-
# if `array-api-compat` package is present in environment
24-
# TODO: create generic approach to handle this for all tests
25-
if sklearn_check_version("1.6"):
26-
environ["SCIPY_ARRAY_API"] = "1"
27-
28-
2917
import numpy as np
3018
import pytest
3119
from sklearn.datasets import make_regression
3220
from sklearn.linear_model import LinearRegression
3321
from sklearn.utils._testing import assert_array_almost_equal
3422

23+
from daal4py.sklearn._utils import sklearn_check_version
24+
3525

3626
def make_dataset(n_samples, n_features, kind=np.array, random_state=0, types=None):
3727
try:

onedal/svm/tests/test_svc.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,6 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17-
from os import environ
18-
19-
# sklearn requires manual enabling of Scipy array API support
20-
# if `array-api-compat` package is present in environment
21-
# TODO: create generic approach to handle this for all tests
22-
environ["SCIPY_ARRAY_API"] = "1"
23-
24-
2517
import numpy as np
2618
import pytest
2719
import sklearn.utils.estimator_checks

sklearnex/covariance/tests/test_incremental_covariance.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,6 @@
1515
# ===============================================================================
1616

1717
from contextlib import nullcontext
18-
from os import environ
19-
20-
from daal4py.sklearn._utils import sklearn_check_version
21-
22-
# sklearn requires manual enabling of Scipy array API support
23-
# if `array-api-compat` package is present in environment
24-
# TODO: create generic approach to handle this for all tests
25-
if sklearn_check_version("1.6"):
26-
environ["SCIPY_ARRAY_API"] = "1"
27-
2818

2919
import numpy as np
3020
import pytest
@@ -38,7 +28,7 @@
3828
from sklearn.datasets import load_diabetes
3929
from sklearn.decomposition import PCA
4030

41-
from daal4py.sklearn._utils import daal_check_version
31+
from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
4232
from onedal.tests.utils._dataframes_support import (
4333
_as_numpy,
4434
_convert_to_dataframe,

0 commit comments

Comments
 (0)