Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .ci/scripts/run_sklearn_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
# limitations under the License.
# ===============================================================================

import os

os.environ["SCIPY_ARRAY_API"] = "1"
from sklearnex import patch_sklearn

patch_sklearn()

import argparse
import os
import sys

import pytest
Expand All @@ -43,8 +45,6 @@
if os.environ["SELECTED_TESTS"] == "all":
os.environ["SELECTED_TESTS"] = ""

os.environ["SCIPY_ARRAY_API"] = "1"

pytest_args = (
f"--rootdir={sklearn_file_dir} "
f'{os.environ["DESELECTED_TESTS"]} {os.environ["SELECTED_TESTS"]}'.split(" ")
Expand Down
2 changes: 1 addition & 1 deletion conda-recipe/run_test.bat
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ if "%PYTHON%"=="python" (
set NO_DIST=1
)


set SCIPY_ARRAY_API=1

%PYTHON% -c "from sklearnex import patch_sklearn; patch_sklearn()" || set exitcode=1

Expand Down
2 changes: 2 additions & 0 deletions conda-recipe/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ if [ -z "${PYTHON}" ]; then
export PYTHON=python
fi

export SCIPY_ARRAY_API=1

# Note: execute with argument --json-report in order to produce
# a JSON report under folder '.pytest_reports'. Other arguments
# will also be forwarded to pytest.
Expand Down
14 changes: 2 additions & 12 deletions daal4py/sklearn/linear_model/tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,14 @@
# limitations under the License.
# ==============================================================================


from os import environ

from daal4py.sklearn._utils import sklearn_check_version

# sklearn requires manual enabling of Scipy array API support
# if `array-api-compat` package is present in environment
# TODO: create generic approach to handle this for all tests
if sklearn_check_version("1.6"):
environ["SCIPY_ARRAY_API"] = "1"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why remove this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because it has no effect when done after importing scipy. It is now set from the test runner script instead.



import numpy as np
import pytest
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.utils._testing import assert_array_almost_equal

from daal4py.sklearn._utils import sklearn_check_version


def make_dataset(n_samples, n_features, kind=np.array, random_state=0, types=None):
try:
Expand Down
8 changes: 0 additions & 8 deletions onedal/svm/tests/test_svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,6 @@
# limitations under the License.
# ==============================================================================

from os import environ

# sklearn requires manual enabling of Scipy array API support
# if `array-api-compat` package is present in environment
# TODO: create generic approach to handle this for all tests
environ["SCIPY_ARRAY_API"] = "1"


import numpy as np
import pytest
import sklearn.utils.estimator_checks
Expand Down
12 changes: 1 addition & 11 deletions sklearnex/covariance/tests/test_incremental_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,6 @@
# ===============================================================================

from contextlib import nullcontext
from os import environ

from daal4py.sklearn._utils import sklearn_check_version

# sklearn requires manual enabling of Scipy array API support
# if `array-api-compat` package is present in environment
# TODO: create generic approach to handle this for all tests
if sklearn_check_version("1.6"):
environ["SCIPY_ARRAY_API"] = "1"


import numpy as np
import pytest
Expand All @@ -38,7 +28,7 @@
from sklearn.datasets import load_diabetes
from sklearn.decomposition import PCA

from daal4py.sklearn._utils import daal_check_version
from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
from onedal.tests.utils._dataframes_support import (
_as_numpy,
_convert_to_dataframe,
Expand Down
Loading