Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,5 @@ The package also provides `mkl_fft._numpy_fft` and `mkl_fft._scipy_fft` interfac

To build ``mkl_fft`` from sources on Linux:
- install a recent version of MKL, if necessary;
- execute ``source /path_to_oneapi/mkl/latest/env/vars.sh`` ;
- execute ``source /path_to_oneapi/mkl/latest/env/vars.sh``;
- execute ``python -m pip install .``
40 changes: 33 additions & 7 deletions mkl_fft/_numpy_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@
]

import re
import warnings

import numpy as np
from numpy import array, asanyarray, conjugate, prod, sqrt, take

from . import _float_utils
Expand Down Expand Up @@ -688,22 +690,46 @@ def ihfft(a, n=None, axis=-1, norm=None):
return output


def _cook_nd_args(a, s=None, axes=None, invreal=0):
# copied from: https://github.com/numpy/numpy/blob/main/numpy/fft/_pocketfft.py
def _cook_nd_args(a, s=None, axes=None, invreal=False):
if s is None:
shapeless = 1
shapeless = True
if axes is None:
s = list(a.shape)
else:
s = take(a.shape, axes)
else:
shapeless = 0
shapeless = False
s = list(s)
if axes is None:
if not shapeless and np.__version__ >= "2.0":
msg = (
"`axes` should not be `None` if `s` is not `None` "
"(Deprecated in NumPy 2.0). In a future version of NumPy, "
"this will raise an error and `s[i]` will correspond to "
"the size along the transformed axis specified by "
"`axes[i]`. To retain current behaviour, pass a sequence "
"[0, ..., k-1] to `axes` for an array of dimension k."
)
warnings.warn(msg, DeprecationWarning, stacklevel=3)
axes = list(range(-len(s), 0))
if len(s) != len(axes):
raise ValueError("Shape and axes have different lengths.")
if invreal and shapeless:
s[-1] = (a.shape[axes[-1]] - 1) * 2
if None in s and np.__version__ >= "2.0":
msg = (
"Passing an array containing `None` values to `s` is "
"deprecated in NumPy 2.0 and will raise an error in "
"a future version of NumPy. To use the default behaviour "
"of the corresponding 1-D transform, pass the value matching "
"the default for its `n` parameter. To use the default "
"behaviour for every axis, the `s` argument can be omitted."
)
warnings.warn(msg, DeprecationWarning, stacklevel=3)
# use the whole input array along axis `i` if `s[i] == -1 or None`
s = [a.shape[_a] if _s in [-1, None] else _s for _s, _a in zip(s, axes)]

return s, axes


Expand Down Expand Up @@ -808,6 +834,7 @@ def fftn(a, s=None, axes=None, norm=None):
"""
_check_norm(norm)
x = _float_utils.__downcast_float128_array(a)
s, axes = _cook_nd_args(x, s, axes)

if norm in (None, "backward"):
fsc = 1.0
Expand Down Expand Up @@ -920,6 +947,7 @@ def ifftn(a, s=None, axes=None, norm=None):
"""
_check_norm(norm)
x = _float_utils.__downcast_float128_array(a)
s, axes = _cook_nd_args(x, s, axes)

if norm in (None, "backward"):
fsc = 1.0
Expand Down Expand Up @@ -1215,16 +1243,15 @@ def rfftn(a, s=None, axes=None, norm=None):
"""
_check_norm(norm)
x = _float_utils.__downcast_float128_array(a)
s, axes = _cook_nd_args(x, s, axes)

if norm in (None, "backward"):
fsc = 1.0
elif norm == "forward":
x = asanyarray(x)
s, axes = _cook_nd_args(x, s, axes)
fsc = frwd_sc_nd(s, x.shape)
else:
x = asanyarray(x)
s, axes = _cook_nd_args(x, s, axes)
fsc = sqrt(frwd_sc_nd(s, x.shape))

return trycall(
Expand Down Expand Up @@ -1369,16 +1396,15 @@ def irfftn(a, s=None, axes=None, norm=None):
"""
_check_norm(norm)
x = _float_utils.__downcast_float128_array(a)
s, axes = _cook_nd_args(x, s, axes, invreal=True)

if norm in (None, "backward"):
fsc = 1.0
elif norm == "forward":
x = asanyarray(x)
s, axes = _cook_nd_args(x, s, axes, invreal=1)
fsc = frwd_sc_nd(s, x.shape)
else:
x = asanyarray(x)
s, axes = _cook_nd_args(x, s, axes, invreal=1)
fsc = sqrt(frwd_sc_nd(s, x.shape))

return trycall(
Expand Down
4 changes: 2 additions & 2 deletions mkl_fft/_pydfti.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1100,7 +1100,7 @@ def _fftnd_impl(x, s=None, axes=None, overwrite_x=False, direction=+1, double fs
_direct_fftnd,
{'overwrite_x': overwrite_x, 'direction': direction, 'fsc': fsc},
res
)
)
else:
sc = <object> fsc
return _iter_fftnd(x, s=s, axes=axes,
Expand Down Expand Up @@ -1200,7 +1200,7 @@ def rfftn(x, s=None, axes=None, fwd_scale=1.0):
a = _fix_dimensions(a, tuple(ss), axes)
if len(set(axes)) == len(axes) and len(axes) == a.ndim and len(axes) > 2:
ss, aa = _remove_axis(s, axes, -1)
ind = [slice(None,None,1),] * len(s)
ind = [slice(None, None, 1),] * len(s)
for ii in range(a.shape[la]):
ind[la] = ii
tind = tuple(ind)
Expand Down
1 change: 1 addition & 0 deletions mkl_fft/tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def test_scipy_rfftn(norm, dtype):
assert np.allclose(x, xx, atol=tol, rtol=tol)


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.parametrize("norm", [None, "forward", "backward", "ortho"])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_numpy_rfftn(norm, dtype):
Expand Down
6 changes: 3 additions & 3 deletions mkl_fft/tests/test_pocketfft.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ def test_s_negative_1(self, op):
# should use the whole input array along the first axis
assert op(x, s=(-1, 5), axes=(0, 1)).shape == (10, 5)

@pytest.mark.skip("no warning is raised in mkl_ftt")
@pytest.mark.skipif(np.__version__ < "2.0", reason="Requires numpy >= 2.0")
@pytest.mark.parametrize(
"op", [mkl_fft.fftn, mkl_fft.ifftn, mkl_fft.rfftn, mkl_fft.irfftn]
)
Expand All @@ -519,14 +519,14 @@ def test_s_axes_none(self, op):
with pytest.warns(match="`axes` should not be `None` if `s`"):
op(x, s=(-1, 5))

@pytest.mark.skip("no warning is raised in mkl_ftt")
@pytest.mark.skipif(np.__version__ < "2.0", reason="Requires numpy >= 2.0")
@pytest.mark.parametrize("op", [mkl_fft.fft2, mkl_fft.ifft2])
def test_s_axes_none_2D(self, op):
x = np.arange(100).reshape(10, 10)
with pytest.warns(match="`axes` should not be `None` if `s`"):
op(x, s=(-1, 5), axes=None)

@pytest.mark.skip("no warning is raised in mkl_ftt")
@pytest.mark.skipif(np.__version__ < "2.0", reason="Requires numpy >= 2.0")
@pytest.mark.parametrize(
"op",
[
Expand Down
Loading