55
66import numpy as np
77from numpy import (atleast_1d , poly , polyval , roots , real , asarray ,
8- resize , pi , absolute , sqrt , tan , log10 ,
8+ pi , absolute , sqrt , tan , log10 ,
99 arcsinh , sin , exp , cosh , arccosh , ceil , conjugate ,
1010 zeros , sinh , append , concatenate , prod , ones , full , array ,
1111 mintypecode )
1717from scipy ._lib ._util import float_factorial
1818from scipy .signal ._arraytools import _validate_fs
1919
20+ import scipy ._lib .array_api_extra as xpx
21+ from scipy ._lib ._array_api import array_namespace , xp_promote , xp_size
22+
2023
2124__all__ = ['findfreqs' , 'freqs' , 'freqz' , 'tf2zpk' , 'zpk2tf' , 'normalize' ,
2225 'lp2lp' , 'lp2hp' , 'lp2bp' , 'lp2bs' , 'bilinear' , 'iirdesign' ,
@@ -1676,7 +1679,7 @@ def idx_worst(p):
16761679 return sos
16771680
16781681
1679- def _align_nums (nums ):
1682+ def _align_nums (nums , xp ):
16801683 """Aligns the shapes of multiple numerators.
16811684
16821685 Given an array of numerator coefficient arrays [[a_1, a_2,...,
@@ -1701,19 +1704,19 @@ def _align_nums(nums):
17011704 # The statement can throw a ValueError if one
17021705 # of the numerators is a single digit and another
17031706 # is array-like e.g. if nums = [5, [1, 2, 3]]
1704- nums = asarray (nums )
1707+ nums = xp . asarray (nums )
17051708
1706- if not np . issubdtype (nums .dtype , np . number ):
1709+ if not xp . isdtype (nums .dtype , "numeric" ):
17071710 raise ValueError ("dtype of numerator is non-numeric" )
17081711
17091712 return nums
17101713
17111714 except ValueError :
1712- nums = [np . atleast_1d ( num ) for num in nums ]
1713- max_width = max (num . size for num in nums )
1715+ nums = [xpx . atleast_nd ( xp . asarray ( num ), ndim = 1 ) for num in nums ]
1716+ max_width = max (xp_size ( num ) for num in nums )
17141717
17151718 # pre-allocate
1716- aligned_nums = np .zeros ((len ( nums ) , max_width ))
1719+ aligned_nums = xp .zeros ((nums . shape [ 0 ] , max_width ))
17171720
17181721 # Create numerators with padded zeros
17191722 for index , num in enumerate (nums ):
@@ -1722,6 +1725,26 @@ def _align_nums(nums):
17221725 return aligned_nums
17231726
17241727
1728+ def _trim_zeros (filt , trim = 'fb' ):
1729+ # https://github.com/numpy/numpy/blob/v2.1.0/numpy/lib/_function_base_impl.py#L1874-L1925
1730+ first = 0
1731+ trim = trim .upper ()
1732+ if 'F' in trim :
1733+ for i in filt :
1734+ if i != 0. :
1735+ break
1736+ else :
1737+ first = first + 1
1738+ last = filt .shape [0 ]
1739+ if 'B' in trim :
1740+ for i in filt [::- 1 ]:
1741+ if i != 0. :
1742+ break
1743+ else :
1744+ last = last - 1
1745+ return filt [first :last ]
1746+
1747+
17251748def normalize (b , a ):
17261749 """Normalize numerator/denominator of a continuous-time transfer function.
17271750
@@ -1778,30 +1801,33 @@ def normalize(b, a):
17781801 Badly conditioned filter coefficients (numerator): the results may be meaningless
17791802
17801803 """
1781- num , den = b , a
1804+ xp = array_namespace ( b , a )
17821805
1783- den = np .asarray (den )
1784- den = np .atleast_1d (den )
1785- num = np .atleast_2d (_align_nums (num ))
1806+ den = xp .asarray (a )
1807+ den = xpx .atleast_nd (den , ndim = 1 , xp = xp )
1808+
1809+ num = xp .asarray (b )
1810+ num = xpx .atleast_nd (_align_nums (num , xp ), ndim = 2 , xp = xp )
17861811
17871812 if den .ndim != 1 :
17881813 raise ValueError ("Denominator polynomial must be rank-1 array." )
17891814 if num .ndim > 2 :
17901815 raise ValueError ("Numerator polynomial must be rank-1 or"
17911816 " rank-2 array." )
1792- if np .all (den == 0 ):
1817+ if xp .all (den == 0 ):
17931818 raise ValueError ("Denominator must have at least on nonzero element." )
17941819
17951820 # Trim leading zeros in denominator, leave at least one.
1796- den = np . trim_zeros (den , 'f' )
1821+ den = _trim_zeros (den , 'f' )
17971822
17981823 # Normalize transfer function
17991824 num , den = num / den [0 ], den / den [0 ]
18001825
18011826 # Count numerator columns that are all zero
18021827 leading_zeros = 0
1803- for col in num .T :
1804- if np .allclose (col , 0 , atol = 1e-14 ):
1828+ for j in range (num .shape [- 1 ]):
1829+ col = num [:, j ]
1830+ if xp .all (xp .abs (col ) <= 1e-14 ):
18051831 leading_zeros += 1
18061832 else :
18071833 break
@@ -1879,22 +1905,49 @@ def lp2lp(b, a, wo=1.0):
18791905 >>> plt.legend()
18801906
18811907 """
1882- a , b = map (atleast_1d , (a , b ))
1908+ xp = array_namespace (a , b )
1909+ a , b = map (xp .asarray , (a , b ))
1910+ a , b = xp_promote (a , b , force_floating = True , xp = xp )
1911+ a = xpx .atleast_nd (a , ndim = 1 , xp = xp )
1912+ b = xpx .atleast_nd (b , ndim = 1 , xp = xp )
1913+
18831914 try :
18841915 wo = float (wo )
18851916 except TypeError :
18861917 wo = float (wo [0 ])
1887- d = len ( a )
1888- n = len ( b )
1918+ d = a . shape [ 0 ]
1919+ n = b . shape [ 0 ]
18891920 M = max ((d , n ))
1890- pwo = pow ( wo , np .arange (M - 1 , - 1 , - 1 ) )
1921+ pwo = wo ** xp .arange (M - 1 , - 1 , - 1 , dtype = xp . float64 )
18911922 start1 = max ((n - d , 0 ))
18921923 start2 = max ((d - n , 0 ))
18931924 b = b * pwo [start1 ] / pwo [start2 :]
18941925 a = a * pwo [start1 ] / pwo [start1 :]
18951926 return normalize (b , a )
18961927
18971928
1929+ def _resize (a , new_shape , xp ):
1930+ # https://github.com/numpy/numpy/blob/v2.2.4/numpy/_core/fromnumeric.py#L1535
1931+ a = xp .reshape (a , (- 1 ,))
1932+
1933+ new_size = 1
1934+ for dim_length in new_shape :
1935+ new_size *= dim_length
1936+ if dim_length < 0 :
1937+ raise ValueError (
1938+ 'all elements of `new_shape` must be non-negative'
1939+ )
1940+
1941+ if xp_size (a ) == 0 or new_size == 0 :
1942+ # First case must zero fill. The second would have repeats == 0.
1943+ return xp .zeros_like (a , shape = new_shape )
1944+
1945+ repeats = - (- new_size // xp_size (a )) # ceil division
1946+ a = xp .concat ((a ,) * repeats )[:new_size ]
1947+
1948+ return xp .reshape (a , new_shape )
1949+
1950+
18981951def lp2hp (b , a , wo = 1.0 ):
18991952 r"""
19001953 Transform a lowpass filter prototype to a highpass filter.
@@ -1953,27 +2006,34 @@ def lp2hp(b, a, wo=1.0):
19532006 >>> plt.legend()
19542007
19552008 """
1956- a , b = map (atleast_1d , (a , b ))
2009+ xp = array_namespace (a , b )
2010+
2011+ a , b = map (xp .asarray , (a , b ))
2012+ a , b = xp_promote (a , b , force_floating = True , xp = xp )
2013+ a = xpx .atleast_nd (a , ndim = 1 , xp = xp )
2014+ b = xpx .atleast_nd (b , ndim = 1 , xp = xp )
2015+
19572016 try :
19582017 wo = float (wo )
19592018 except TypeError :
19602019 wo = float (wo [0 ])
1961- d = len ( a )
1962- n = len ( b )
2020+ d = a . shape [ 0 ]
2021+ n = b . shape [ 0 ]
19632022 if wo != 1 :
1964- pwo = pow ( wo , np .arange (max ((d , n ))) )
2023+ pwo = wo ** xp .arange (max ((d , n )), dtype = xp . float64 )
19652024 else :
1966- pwo = np .ones (max ((d , n )), b .dtype . char )
2025+ pwo = xp .ones (max ((d , n )), dtype = b .dtype )
19672026 if d >= n :
1968- outa = a [::- 1 ] * pwo
1969- outb = resize (b , (d ,))
2027+ outa = xp .flip (a ) * pwo
2028+ outb = xp .concat ((xp .zeros (n , dtype = b .dtype ), ))
2029+ outb = _resize (b , (d ,), xp = xp )
19702030 outb [n :] = 0.0
1971- outb [:n ] = b [:: - 1 ] * pwo [:n ]
2031+ outb [:n ] = xp . flip ( b ) * pwo [:n ]
19722032 else :
1973- outb = b [:: - 1 ] * pwo
1974- outa = resize (a , (n ,))
2033+ outb = xp . flip ( b ) * pwo
2034+ outa = _resize (a , (n ,), xp = xp )
19752035 outa [d :] = 0.0
1976- outa [:d ] = a [:: - 1 ] * pwo [:d ]
2036+ outa [:d ] = xp . flip ( a ) * pwo [:d ]
19772037
19782038 return normalize (outb , outa )
19792039
@@ -2038,16 +2098,20 @@ def lp2bp(b, a, wo=1.0, bw=1.0):
20382098 >>> plt.ylabel('Amplitude [dB]')
20392099 >>> plt.legend()
20402100 """
2101+ xp = array_namespace (a , b )
2102+
2103+ a , b = map (xp .asarray , (a , b ))
2104+ a , b = xp_promote (a , b , force_floating = True , xp = xp )
2105+ a = xpx .atleast_nd (a , ndim = 1 , xp = xp )
2106+ b = xpx .atleast_nd (b , ndim = 1 , xp = xp )
20412107
2042- a , b = map (atleast_1d , (a , b ))
2043- D = len (a ) - 1
2044- N = len (b ) - 1
2045- artype = mintypecode ((a , b ))
2108+ D = a .shape [0 ] - 1
2109+ N = b .shape [0 ] - 1
20462110 ma = max ([N , D ])
20472111 Np = N + ma
20482112 Dp = D + ma
2049- bprime = np .empty (Np + 1 , artype )
2050- aprime = np .empty (Dp + 1 , artype )
2113+ bprime = xp .empty (Np + 1 , dtype = b . dtype )
2114+ aprime = xp .empty (Dp + 1 , dtype = a . dtype )
20512115 wosq = wo * wo
20522116 for j in range (Np + 1 ):
20532117 val = 0.0
@@ -2126,15 +2190,20 @@ def lp2bs(b, a, wo=1.0, bw=1.0):
21262190 >>> plt.ylabel('Amplitude [dB]')
21272191 >>> plt.legend()
21282192 """
2129- a , b = map (atleast_1d , (a , b ))
2130- D = len (a ) - 1
2131- N = len (b ) - 1
2132- artype = mintypecode ((a , b ))
2193+ xp = array_namespace (a , b )
2194+
2195+ a , b = map (xp .asarray , (a , b ))
2196+ a , b = xp_promote (a , b , force_floating = True , xp = xp )
2197+ a = xpx .atleast_nd (a , ndim = 1 , xp = xp )
2198+ b = xpx .atleast_nd (b , ndim = 1 , xp = xp )
2199+
2200+ D = a .shape [0 ] - 1
2201+ N = b .shape [0 ] - 1
21332202 M = max ([N , D ])
21342203 Np = M + M
21352204 Dp = M + M
2136- bprime = np .empty (Np + 1 , artype )
2137- aprime = np .empty (Dp + 1 , artype )
2205+ bprime = xp .empty (Np + 1 , dtype = b . dtype )
2206+ aprime = xp .empty (Dp + 1 , dtype = a . dtype )
21382207 wosq = wo * wo
21392208 for j in range (Np + 1 ):
21402209 val = 0.0
0 commit comments