@@ -488,16 +488,16 @@ end
488488# Assumes u is larger than, or the same size as, v
489489# nffts should be greater than or equal to 2*sv-1
490490function unsafe_conv_kern_os! (out,
491+ output_indices,
491492 u:: AbstractArray{<:Any, N} ,
492493 v,
493- su,
494- sv,
495- sout,
496494 nffts) where N
495+ sout = size (out)
496+ su = size (u)
497+ sv = size (v)
497498 u_start = first .(axes (u))
498- out_axes = axes (out)
499- out_start = first .(out_axes)
500- out_stop = last .(out_axes)
499+ out_start = Tuple (first (output_indices))
500+ out_stop = Tuple (last (output_indices))
501501 ideal_save_blocksize = nffts .- sv .+ 1
502502 # Number of samples that are "missing" if the output is smaller than the
503503 # valid portion of the convolution
@@ -507,7 +507,7 @@ function unsafe_conv_kern_os!(out,
507507 nblocks = cld .(sout, save_blocksize)
508508
509509 # Pre-allocation
510- tdbuff, fdbuff, p, ip = os_prepare_conv (u , nffts)
510+ tdbuff, fdbuff, p, ip = os_prepare_conv (out , nffts)
511511 tdbuff_axes = axes (tdbuff)
512512
513513 # Transform the smaller filter
@@ -608,129 +608,178 @@ function unsafe_conv_kern_os!(out,
608608 out
609609end
610610
611- function _conv_kern_fft! (out,
612- u:: AbstractArray{T, N} ,
613- v:: AbstractArray{T, N} ,
614- su,
615- sv,
616- outsize,
617- nffts) where {T<: Real , N}
618- padded = _zeropad (u, nffts)
611+ function _conv_kern_fft! (out:: AbstractArray{T, N} ,
612+ output_indices,
613+ u:: AbstractArray{<:Real, N} ,
614+ v:: AbstractArray{<:Real, N} ) where {T<: Real , N}
615+ outsize = size (output_indices)
616+ nffts = nextfastfft (outsize)
617+ padded = _zeropad! (similar (u, T, nffts), u)
619618 p = plan_rfft (padded)
620619 uf = p * padded
621620 _zeropad! (padded, v)
622621 vf = p * padded
623622 uf .*= vf
624623 raw_out = irfft (uf, nffts[1 ])
625624 copyto! (out,
626- CartesianIndices (out) ,
625+ output_indices ,
627626 raw_out,
628627 CartesianIndices (UnitRange .(1 , outsize)))
629628end
630- function _conv_kern_fft! (out, u, v, su, sv, outsize, nffts)
631- upad = _zeropad (u, nffts)
632- vpad = _zeropad (v, nffts)
629+ function _conv_kern_fft! (out:: AbstractArray{T} , output_indices, u, v) where {T}
630+ outsize = size (output_indices)
631+ nffts = nextfastfft (outsize)
632+ upad = _zeropad! (similar (u, T, nffts), u)
633+ vpad = _zeropad! (similar (v, T, nffts), v)
633634 p! = plan_fft! (upad)
634635 ip! = inv (p!)
635636 p! * upad # Operates in place on upad
636637 p! * vpad
637638 upad .*= vpad
638639 ip! * upad
639640 copyto! (out,
640- CartesianIndices (out) ,
641+ output_indices ,
641642 upad,
642643 CartesianIndices (UnitRange .(1 , outsize)))
643644end
644645
645- # v should be smaller than u for good performance
646- function _conv_fft! (out, u, v, su, sv, outsize)
647- os_nffts = map (optimalfftfiltlength, sv, su)
648- if any (os_nffts .< outsize)
649- unsafe_conv_kern_os! (out, u, v, su, sv, outsize, os_nffts)
646+ function _conv_td! (out, output_indices, u:: AbstractArray{<:Number, N} , v:: AbstractArray{<:Number, N} ) where {N}
647+ index_offset = first (CartesianIndices (u)) + first (CartesianIndices (v)) - first (output_indices)
648+ checkbounds (out, output_indices)
649+ fill! (out, zero (eltype (out)))
650+ if size (u, 1 ) ≤ size (v, 1 ) # choose more efficient iteration order
651+ for m in CartesianIndices (u), n in CartesianIndices (v)
652+ @inbounds out[n+ m - index_offset] = muladd (u[m], v[n], out[n+ m - index_offset])
653+ end
650654 else
651- nffts = nextfastfft (outsize)
652- _conv_kern_fft! (out, u, v, su, sv, outsize, nffts)
655+ for n in CartesianIndices (v), m in CartesianIndices (u)
656+ @inbounds out[n+ m - index_offset] = muladd (u[m], v[n], out[n+ m - index_offset])
657+ end
653658 end
659+ return out
654660end
655661
662+ # whether the given axis is to be considered to carry an offset for `conv!` and `conv`
663+ conv_with_offset (:: Base.OneTo ) = false
664+ conv_with_offset (a:: Any ) = throw (ArgumentError (" unsupported axis type $(typeof (a)) " ))
656665
657- # For arrays with weird offsets
658- function _conv_similar (u, outsize, axesu, axesv)
659- out_offsets = first .(axesu) .+ first .(axesv)
660- out_axes = UnitRange .(out_offsets, out_offsets .+ outsize .- 1 )
661- similar (u, out_axes)
662- end
663- function _conv_similar (
664- u, outsize, :: NTuple{<:Any, Base.OneTo{Int}} , :: NTuple{<:Any, Base.OneTo{Int}}
665- )
666- similar (u, outsize)
667- end
668- _conv_similar (u, v, outsize) = _conv_similar (u, outsize, axes (u), axes (v))
669-
670- # Does convolution, will not switch argument order
671- function _conv! (out, u, v, su, sv, outsize)
672- # TODO : Add spatial / time domain algorithm
673- _conv_fft! (out, u, v, su, sv, outsize)
674- end
675-
676- # Does convolution, will not switch argument order
677- function _conv (u, v, su, sv)
678- outsize = su .+ sv .- 1
679- out = _conv_similar (u, v, outsize)
680- _conv! (out, u, v, su, sv, outsize)
681- end
682-
683- # We use this type definition for clarity
684- const RealOrComplexFloat = Union{AbstractFloat, Complex{T} where T<: AbstractFloat }
666+ const FFTTypes = Union{Float32, Float64, ComplexF32, ComplexF64}
685667
686- # May switch argument order
687668"""
688- conv(u,v)
689-
690- Convolution of two arrays. Uses either FFT convolution or overlap-save,
691- depending on the size of the input. `u` and `v` can be N-dimensional arrays,
692- with arbitrary indexing offsets, but their axes must be a `UnitRange`.
669+ conv!(out, u, v; algorithm=:auto)
670+
671+ Convolution of two arrays `u` and `v` with the result stored in `out`. `out`
672+ must be large enough to store the entire result; if it is even larger, the
673+ excess entries will be zeroed.
674+
675+ `out`, `u`, and `v` can be N-dimensional arrays, with arbitrary indexing
676+ offsets. If none of them has offset axes,
677+ `size(out,d) ≥ size(u,d) + size(v,d) - 1` must hold. If both input and output
678+ have offset axes, `firstindex(out,d) ≤ firstindex(u,d) + firstindex(v,d)` and
679+ `lastindex(out,d) ≥ lastindex(u,d) + lastindex(v,d)` must hold (for d = 1,...,N).
680+ A mix of offset and non-offset axes between input and output is not permitted.
681+
682+ The `algorithm` keyword allows choosing the algorithm to use:
683+ * `:direct`: Evaluates the convolution sum in time domain.
684+ * `:fft_simple`: Evaluates the convolution as a product in the frequency domain.
685+ * `:fft_overlapsave`: Evaluates the convolution block-wise as a product in the
686+ frequency domain, overlapping the resulting blocks.
687+ * `:fft`: Selects the faster of `:fft_simple` and `:fft_overlapsave` (as
688+ estimated from the input size).
689+ * `:fast`: Selects the fastest of `:direct`, `:fft_simple` and
690+ `:fft_overlapsave` (as estimated from the input size).
691+ * `:auto` (default): Equivalent to `:fast` if the data type is known to be
692+ suitable for FFT-based computation, equivalent to `:direct` otherwise.
693+
694+ !!! warning
695+ The choices made by `:fft`, `:fast`, and `:auto` are based on performance
696+ heuristics which may not result in the fastest algorithm in all cases. If
697+ best performance for a certain size/type combination is required, it is
698+ advised to do individual benchmarking and explicitly specify the desired
699+ algorithm.
693700"""
694- function conv (u:: AbstractArray{T, N} ,
695- v:: AbstractArray{T, N} ) where {T<: RealOrComplexFloat , N}
696- su = size (u)
697- sv = size (v)
698- if length (u) >= length (v)
699- _conv (u, v, su, sv)
701+ function conv! (
702+ out:: AbstractArray{T, N} ,
703+ u:: AbstractArray{<:Number, N} ,
704+ v:: AbstractArray{<:Number, N} ;
705+ algorithm= :auto
706+ ) where {T<: Number , N}
707+ output_indices = CartesianIndices (map (axes (out), axes (u), axes (v)) do ao, au, av
708+ input_has_offset = conv_with_offset (au) || conv_with_offset (av)
709+ if input_has_offset != = conv_with_offset (ao)
710+ throw (ArgumentError (" output must have offset axes if and only if the input has" ))
711+ end
712+ offset = input_has_offset ? 0 : 1
713+ return (first (au)+ first (av) : last (au)+ last (av)) .- offset
714+ end )
715+
716+ if algorithm=== :auto
717+ algorithm = T <: FFTTypes ? :fast : :direct
718+ end
719+ if algorithm=== :fast
720+ if length (u) * length (v) < 2 ^ 16 # TODO : better heuristic
721+ algorithm = :direct
722+ else
723+ algorithm = :fft
724+ end
725+ end
726+ if algorithm=== :direct
727+ return _conv_td! (out, output_indices, u, v)
700728 else
701- _conv (v, u, sv, su)
729+ if output_indices != CartesianIndices (out)
730+ fill! (out, zero (eltype (out)))
731+ end
732+ os_nffts = length (u) >= length (v) ? map (optimalfftfiltlength, size (v), size (u)) : map (optimalfftfiltlength, size (u), size (v))
733+ if algorithm=== :fft
734+ if any (os_nffts .< size (output_indices))
735+ algorithm = :fft_overlapsave
736+ else
737+ algorithm = :fft_simple
738+ end
739+ end
740+ if algorithm === :fft_overlapsave
741+ # v should be smaller than u for good performance
742+ if length (u) >= length (v)
743+ return unsafe_conv_kern_os! (out, output_indices, u, v, os_nffts)
744+ else
745+ return unsafe_conv_kern_os! (out, output_indices, v, u, os_nffts)
746+ end
747+ elseif algorithm === :fft_simple
748+ return _conv_kern_fft! (out, output_indices, u, v)
749+ else
750+ throw (ArgumentError (" algorithm must be :auto, :fast, :direct, :fft, :fft_simple, or :fft_overlapsave" ))
751+ end
702752 end
703753end
704754
705- function conv (u:: AbstractArray{<:RealOrComplexFloat, N} ,
706- v:: AbstractArray{<:RealOrComplexFloat, N} ) where N
707- fu, fv = promote (u, v)
708- conv (fu, fv)
709- end
710-
711- conv (u:: AbstractArray{<:Integer, N} , v:: AbstractArray{<:Integer, N} ) where {N} =
712- round .(Int, conv (float (u), float (v)))
755+ conv_output_axis (au, av) =
756+ conv_with_offset (au) || conv_with_offset (av) ?
757+ (first (au)+ first (av): last (au)+ last (av)) : Base. OneTo (last (au) + last (av) - 1 )
713758
714- conv (u:: AbstractArray{<:Number, N} , v:: AbstractArray{<:Number, N} ) where {N} =
715- conv (float (u), float (v))
716-
717- function conv (u:: AbstractArray{<:Number, N} ,
718- v:: AbstractArray{<:RealOrComplexFloat, N} ) where N
719- conv (float (u), v)
720- end
759+ """
760+ conv(u, v; algorithm)
721761
722- function conv (u:: AbstractArray{<:RealOrComplexFloat, N} ,
723- v:: AbstractArray{<:Number, N} ) where N
724- conv (u, float (v))
762+ Convolution of two arrays. A convolution algorithm is automatically chosen among
763+ direct convolution, FFT, or FFT overlap-save, depending on the size of the
764+ input, unless explicitly specified with the `algorithm` keyword argument; see
765+ [`conv!`](@ref) for details.
766+ """
767+ function conv (
768+ u:: AbstractArray{Tu, N} , v:: AbstractArray{Tv, N} ; kwargs...
769+ ) where {Tu<: Number , Tv<: Number , N}
770+ T = promote_type (Tu, Tv)
771+ out_axes = map (conv_output_axis, axes (u), axes (v))
772+ out = similar (u, T, out_axes)
773+ return conv! (out, u, v; kwargs... )
725774end
726775
727776function conv (A:: AbstractArray{<:Number, M} ,
728- B:: AbstractArray{<:Number, N} ) where {M, N}
777+ B:: AbstractArray{<:Number, N} ; kwargs ... ) where {M, N}
729778 if (M < N)
730- conv (cat (A, dims= N):: AbstractArray{eltype(A), N} , B)
779+ conv (cat (A, dims= N):: AbstractArray{eltype(A), N} , B; kwargs ... )
731780 else
732781 @assert M > N
733- conv (A, cat (B, dims= M):: AbstractArray{eltype(B), M} )
782+ conv (A, cat (B, dims= M):: AbstractArray{eltype(B), M} ; kwargs ... )
734783 end
735784end
736785
0 commit comments