Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ __pycache__/

mkl_fft/_pydfti.c
mkl_fft/_pydfti.cpython*.so
mkl_fft/_pydfti.*-win_amd64.pyd
mkl_fft/src/mklfft.c
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ To build `mkl_fft` from sources on Linux with Intel® OneMKL:
- `git clone https://github.com/IntelPython/mkl_fft.git mkl_fft`
- `cd mkl_fft`
- `python -m pip install .`
- `pip install scipy` (optional: for using `mkl_fft.interface.scipy_fft` module)
- `cd ..`
- `python -c "import mkl_fft"`

Expand All @@ -103,5 +104,6 @@ To build `mkl_fft` from sources on Linux with conda follow these steps:
- `git clone https://github.com/IntelPython/mkl_fft.git mkl_fft`
- `cd mkl_fft`
- `python -m pip install .`
- `conda install scipy` (optional: for using `mkl_fft.interface.scipy_fft` module)
- `cd ..`
- `python -c "import mkl_fft"`
2 changes: 1 addition & 1 deletion conda-recipe-cf/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ test:
- pytest -v --pyargs mkl_fft
requires:
- pytest
- scipy
- scipy >=1.10
imports:
- mkl_fft
- mkl_fft.interfaces
Expand Down
2 changes: 1 addition & 1 deletion conda-recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ test:
- pytest -v --pyargs mkl_fft
requires:
- pytest
- scipy
- scipy >=1.10
imports:
- mkl_fft
- mkl_fft.interfaces
Expand Down
9 changes: 8 additions & 1 deletion mkl_fft/interfaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,11 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from . import numpy_fft, scipy_fft
from . import numpy_fft

try:
import scipy.fft
except ImportError:
pass
else:
from . import scipy_fft
31 changes: 23 additions & 8 deletions mkl_fft/tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,26 @@

import mkl_fft.interfaces as mfi

try:
scipy_fft = mfi.scipy_fft
except AttributeError:
scipy_fft = None

interfaces = []
ids = []
if scipy_fft is not None:
interfaces.append(scipy_fft)
ids.append("scipy")
interfaces.append(mfi.numpy_fft)
ids.append("numpy")


@pytest.mark.parametrize("norm", [None, "forward", "backward", "ortho"])
@pytest.mark.parametrize(
"dtype", [np.float32, np.float64, np.complex64, np.complex128]
)
def test_scipy_fft(norm, dtype):
pytest.importorskip("scipy", reason="requires scipy")
x = np.ones(511, dtype=dtype)
w = mfi.scipy_fft.fft(x, norm=norm, workers=None, plan=None)
xx = mfi.scipy_fft.ifft(w, norm=norm, workers=None, plan=None)
Expand All @@ -57,6 +71,7 @@ def test_numpy_fft(norm, dtype):
@pytest.mark.parametrize("norm", [None, "forward", "backward", "ortho"])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_scipy_rfft(norm, dtype):
pytest.importorskip("scipy", reason="requires scipy")
x = np.ones(511, dtype=dtype)
w = mfi.scipy_fft.rfft(x, norm=norm, workers=None, plan=None)
xx = mfi.scipy_fft.irfft(
Expand Down Expand Up @@ -87,6 +102,7 @@ def test_numpy_rfft(norm, dtype):
"dtype", [np.float32, np.float64, np.complex64, np.complex128]
)
def test_scipy_fftn(norm, dtype):
pytest.importorskip("scipy", reason="requires scipy")
x = np.ones((37, 83), dtype=dtype)
w = mfi.scipy_fft.fftn(x, norm=norm, workers=None, plan=None)
xx = mfi.scipy_fft.ifftn(w, norm=norm, workers=None, plan=None)
Expand All @@ -109,6 +125,7 @@ def test_numpy_fftn(norm, dtype):
@pytest.mark.parametrize("norm", [None, "forward", "backward", "ortho"])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_scipy_rfftn(norm, dtype):
pytest.importorskip("scipy", reason="requires scipy")
x = np.ones((37, 83), dtype=dtype)
w = mfi.scipy_fft.rfftn(x, norm=norm, workers=None, plan=None)
xx = mfi.scipy_fft.irfftn(w, s=x.shape, norm=norm, workers=None, plan=None)
Expand Down Expand Up @@ -143,32 +160,30 @@ def _get_blacklisted_dtypes():

@pytest.mark.parametrize("dtype", _get_blacklisted_dtypes())
def test_scipy_no_support_for(dtype):
pytest.importorskip("scipy", reason="requires scipy")
x = np.ones(16, dtype=dtype)
assert_raises(NotImplementedError, mfi.scipy_fft.ifft, x)


def test_scipy_fft_arg_validate():
pytest.importorskip("scipy", reason="requires scipy")
with pytest.raises(ValueError):
mfi.scipy_fft.fft([1, 2, 3, 4], norm=b"invalid")

with pytest.raises(NotImplementedError):
mfi.scipy_fft.fft([1, 2, 3, 4], plan="magic")


@pytest.mark.parametrize(
"func", [mfi.scipy_fft.rfft2, mfi.numpy_fft.rfft2], ids=["scipy", "numpy"]
)
def test_axes(func):
@pytest.mark.parametrize("interface", interfaces, ids=ids)
def test_axes(interface):
x = np.arange(24.0).reshape(2, 3, 4)
res = func(x, axes=(1, 2))
res = interface.rfft2(x, axes=(1, 2))
exp = np.fft.rfft2(x, axes=(1, 2))
tol = 64 * np.finfo(np.float64).eps
assert np.allclose(res, exp, atol=tol, rtol=tol)


@pytest.mark.parametrize(
"interface", [mfi.scipy_fft, mfi.numpy_fft], ids=["scipy", "numpy"]
)
@pytest.mark.parametrize("interface", interfaces, ids=ids)
@pytest.mark.parametrize(
"func", ["fftshift", "ifftshift", "fftfreq", "rfftfreq"]
)
Expand Down
24 changes: 14 additions & 10 deletions mkl_fft/tests/third_party/scipy/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,26 @@

import numpy as np
import pytest
import scipy
from numpy.random import random
from numpy.testing import assert_allclose, assert_array_almost_equal
from pytest import raises as assert_raises

# pylint: disable=possibly-used-before-assignment
if scipy.__version__ < "1.12":
# scipy from Intel channel is 1.10 with python 3.9 and 3.10
pytest.skip("This test file needs scipy>=1.12", allow_module_level=True)
elif scipy.__version__ < "1.14":
# For python-3.11 and 3.12, scipy<1.14 is installed from Intel channel
# For python<=3.9, scipy<1.14 is installed from conda channel
# pylint: disable=no-name-in-module
from scipy._lib._array_api import size as xp_size
try:
import scipy
except ImportError:
pytest.skip("This test file needs scipy", allow_module_level=True)
else:
from scipy._lib._array_api import xp_size
if np.lib.NumpyVersion(scipy.__version__) < "1.12.0":
# scipy from Intel channel is 1.10 with python 3.9 and 3.10
pytest.skip("This test file needs scipy>=1.12", allow_module_level=True)
elif np.lib.NumpyVersion(scipy.__version__) < "1.14.0":
# For python-3.11 and 3.12, scipy<1.14 is installed from Intel channel
# For python<=3.9, scipy<1.14 is installed from conda channel
# pylint: disable=no-name-in-module
from scipy._lib._array_api import size as xp_size
else:
from scipy._lib._array_api import xp_size

from scipy._lib._array_api import is_numpy, xp_assert_close, xp_assert_equal

Expand Down
5 changes: 4 additions & 1 deletion mkl_fft/tests/third_party/scipy/test_multithreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
import pytest
from numpy.testing import assert_allclose

import mkl_fft.interfaces.scipy_fft as fft
try:
import mkl_fft.interfaces.scipy_fft as fft
except ImportError:
pytest.skip("This test file needs scipy", allow_module_level=True)


@pytest.fixture(scope="module")
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ readme = {file = "README.md", content-type = "text/markdown"}
requires-python = ">=3.9,<3.13"

[project.optional-dependencies]
test = ["pytest", "scipy"]
scipy_interface = ["scipy>=1.10"]
test = ["pytest", "scipy>=1.10"]

[project.urls]
Download = "http://github.com/IntelPython/mkl_fft"
Expand Down
Loading