Skip to content

Commit 7652e2d

Browse files
Implement time-domain convolution and use it for integers (#545)
1 parent 1dae6a3 commit 7652e2d

File tree

6 files changed

+231
-97
lines changed

6 files changed

+231
-97
lines changed

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1313
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1414
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1515

16+
[weakdeps]
17+
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
18+
19+
[extensions]
20+
OffsetArraysExt = "OffsetArrays"
21+
1622
[compat]
1723
Bessels = "0.2"
1824
DelimitedFiles = "1.6"

docs/src/convolutions.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
```@docs
44
conv
5+
conv!
56
deconv
67
xcorr
78
```

ext/OffsetArraysExt.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
module OffsetArraysExt
2+
import DSP
3+
import OffsetArrays
4+
5+
DSP.conv_with_offset(::OffsetArrays.IdOffsetRange) = true
6+
7+
end

src/DSP.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using FFTW
44
using LinearAlgebra: mul!, rmul!
55
using IterTools: subsets
66

7-
export conv, deconv, filt, filt!, xcorr
7+
export conv, conv!, deconv, filt, filt!, xcorr
88

99
# This function has methods added in `periodograms` but is not exported,
1010
# so we define it here so one can do `DSP.allocate_output` instead of

src/dspbase.jl

Lines changed: 138 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -488,16 +488,16 @@ end
488488
# Assumes u is larger than, or the same size as, v
489489
# nfft should be greater than or equal to 2*sv-1
490490
function unsafe_conv_kern_os!(out,
491+
output_indices,
491492
u::AbstractArray{<:Any, N},
492493
v,
493-
su,
494-
sv,
495-
sout,
496494
nffts) where N
495+
sout = size(out)
496+
su = size(u)
497+
sv = size(v)
497498
u_start = first.(axes(u))
498-
out_axes = axes(out)
499-
out_start = first.(out_axes)
500-
out_stop = last.(out_axes)
499+
out_start = Tuple(first(output_indices))
500+
out_stop = Tuple(last(output_indices))
501501
ideal_save_blocksize = nffts .- sv .+ 1
502502
# Number of samples that are "missing" if the output is smaller than the
503503
# valid portion of the convolution
@@ -507,7 +507,7 @@ function unsafe_conv_kern_os!(out,
507507
nblocks = cld.(sout, save_blocksize)
508508

509509
# Pre-allocation
510-
tdbuff, fdbuff, p, ip = os_prepare_conv(u, nffts)
510+
tdbuff, fdbuff, p, ip = os_prepare_conv(out, nffts)
511511
tdbuff_axes = axes(tdbuff)
512512

513513
# Transform the smaller filter
@@ -608,129 +608,178 @@ function unsafe_conv_kern_os!(out,
608608
out
609609
end
610610

611-
function _conv_kern_fft!(out,
612-
u::AbstractArray{T, N},
613-
v::AbstractArray{T, N},
614-
su,
615-
sv,
616-
outsize,
617-
nffts) where {T<:Real, N}
618-
padded = _zeropad(u, nffts)
611+
function _conv_kern_fft!(out::AbstractArray{T, N},
612+
output_indices,
613+
u::AbstractArray{<:Real, N},
614+
v::AbstractArray{<:Real, N}) where {T<:Real, N}
615+
outsize = size(output_indices)
616+
nffts = nextfastfft(outsize)
617+
padded = _zeropad!(similar(u, T, nffts), u)
619618
p = plan_rfft(padded)
620619
uf = p * padded
621620
_zeropad!(padded, v)
622621
vf = p * padded
623622
uf .*= vf
624623
raw_out = irfft(uf, nffts[1])
625624
copyto!(out,
626-
CartesianIndices(out),
625+
output_indices,
627626
raw_out,
628627
CartesianIndices(UnitRange.(1, outsize)))
629628
end
630-
function _conv_kern_fft!(out, u, v, su, sv, outsize, nffts)
631-
upad = _zeropad(u, nffts)
632-
vpad = _zeropad(v, nffts)
629+
function _conv_kern_fft!(out::AbstractArray{T}, output_indices, u, v) where {T}
630+
outsize = size(output_indices)
631+
nffts = nextfastfft(outsize)
632+
upad = _zeropad!(similar(u, T, nffts), u)
633+
vpad = _zeropad!(similar(v, T, nffts), v)
633634
p! = plan_fft!(upad)
634635
ip! = inv(p!)
635636
p! * upad # Operates in place on upad
636637
p! * vpad
637638
upad .*= vpad
638639
ip! * upad
639640
copyto!(out,
640-
CartesianIndices(out),
641+
output_indices,
641642
upad,
642643
CartesianIndices(UnitRange.(1, outsize)))
643644
end
644645

645-
# v should be smaller than u for good performance
646-
function _conv_fft!(out, u, v, su, sv, outsize)
647-
os_nffts = map(optimalfftfiltlength, sv, su)
648-
if any(os_nffts .< outsize)
649-
unsafe_conv_kern_os!(out, u, v, su, sv, outsize, os_nffts)
646+
function _conv_td!(out, output_indices, u::AbstractArray{<:Number, N}, v::AbstractArray{<:Number, N}) where {N}
647+
index_offset = first(CartesianIndices(u)) + first(CartesianIndices(v)) - first(output_indices)
648+
checkbounds(out, output_indices)
649+
fill!(out, zero(eltype(out)))
650+
if size(u, 1) size(v, 1) # choose more efficient iteration order
651+
for m in CartesianIndices(u), n in CartesianIndices(v)
652+
@inbounds out[n+m - index_offset] = muladd(u[m], v[n], out[n+m - index_offset])
653+
end
650654
else
651-
nffts = nextfastfft(outsize)
652-
_conv_kern_fft!(out, u, v, su, sv, outsize, nffts)
655+
for n in CartesianIndices(v), m in CartesianIndices(u)
656+
@inbounds out[n+m - index_offset] = muladd(u[m], v[n], out[n+m - index_offset])
657+
end
653658
end
659+
return out
654660
end
655661

662+
# whether the given axis are to be considered to carry an offset for `conv!` and `conv`
663+
conv_with_offset(::Base.OneTo) = false
664+
conv_with_offset(a::Any) = throw(ArgumentError("unsupported axis type $(typeof(a))"))
656665

657-
# For arrays with weird offsets
658-
function _conv_similar(u, outsize, axesu, axesv)
659-
out_offsets = first.(axesu) .+ first.(axesv)
660-
out_axes = UnitRange.(out_offsets, out_offsets .+ outsize .- 1)
661-
similar(u, out_axes)
662-
end
663-
function _conv_similar(
664-
u, outsize, ::NTuple{<:Any, Base.OneTo{Int}}, ::NTuple{<:Any, Base.OneTo{Int}}
665-
)
666-
similar(u, outsize)
667-
end
668-
_conv_similar(u, v, outsize) = _conv_similar(u, outsize, axes(u), axes(v))
669-
670-
# Does convolution, will not switch argument order
671-
function _conv!(out, u, v, su, sv, outsize)
672-
# TODO: Add spatial / time domain algorithm
673-
_conv_fft!(out, u, v, su, sv, outsize)
674-
end
675-
676-
# Does convolution, will not switch argument order
677-
function _conv(u, v, su, sv)
678-
outsize = su .+ sv .- 1
679-
out = _conv_similar(u, v, outsize)
680-
_conv!(out, u, v, su, sv, outsize)
681-
end
682-
683-
# We use this type definition for clarity
684-
const RealOrComplexFloat = Union{AbstractFloat, Complex{T} where T<:AbstractFloat}
666+
const FFTTypes = Union{Float32, Float64, ComplexF32, ComplexF64}
685667

686-
# May switch argument order
687668
"""
688-
conv(u,v)
689-
690-
Convolution of two arrays. Uses either FFT convolution or overlap-save,
691-
depending on the size of the input. `u` and `v` can be N-dimensional arrays,
692-
with arbitrary indexing offsets, but their axes must be a `UnitRange`.
669+
conv!(out, u, v; algorithm=:auto)
670+
671+
Convolution of two arrays `u` and `v` with the result stored in `out`. `out`
672+
must be large enough to store the entire result; if it is even larger, the
673+
excess entries will be zeroed.
674+
675+
`out`, `u`, and `v` can be N-dimensional arrays, with arbitrary indexing
676+
offsets. If none of them has offset axes,
677+
`size(out,d) ≥ size(u,d) + size(v,d) - 1` must hold. If both input and output
678+
have offset axes, `firstindex(out,d) ≤ firstindex(u,d) + firstindex(v,d)` and
679+
`lastindex(out,d) ≥ lastindex(u,d) + lastindex(v,d)` must hold (for d = 1,...,N).
680+
A mix of offset and non-offset axes between input and output is not permitted.
681+
682+
The `algorithm` keyword allows choosing the algorithm to use:
683+
* `:direct`: Evaluates the convolution sum in time domain.
684+
* `:fft_simple`: Evaluates the convolution as a product in the frequency domain.
685+
* `:fft_overlapsave`: Evaluates the convolution block-wise as a product in the
686+
frequency domain, overlapping the resulting blocks.
687+
* `:fft`: Selects the faster of `:fft_simple` and `:fft_overlapsave` (as
688+
estimated from the input size).
689+
* `:fast`: Selects the fastest of `:direct`, `:fft_simple` and
690+
`:fft_overlapsave` (as estimated from the input size).
691+
* `:auto` (default): Equivalent to `:fast` if the data type is known to be
692+
suitable for FFT-based computation, equivalent to `:direct` otherwise.
693+
694+
!!! warning
695+
The choices made by `:fft`, `:fast`, and `:auto` are based on performance
696+
heuristics which may not result in the fastest algorithm in all cases. If
697+
best performance for a certain size/type combination is required, it is
698+
advised to do individual benchmarking and explicitly specify the desired
699+
algorithm.
693700
"""
694-
function conv(u::AbstractArray{T, N},
695-
v::AbstractArray{T, N}) where {T<:RealOrComplexFloat, N}
696-
su = size(u)
697-
sv = size(v)
698-
if length(u) >= length(v)
699-
_conv(u, v, su, sv)
701+
function conv!(
702+
out::AbstractArray{T, N},
703+
u::AbstractArray{<:Number, N},
704+
v::AbstractArray{<:Number, N};
705+
algorithm=:auto
706+
) where {T<:Number, N}
707+
output_indices = CartesianIndices(map(axes(out), axes(u), axes(v)) do ao, au, av
708+
input_has_offset = conv_with_offset(au) || conv_with_offset(av)
709+
if input_has_offset !== conv_with_offset(ao)
710+
throw(ArgumentError("output must have offset axes if and only if the input has"))
711+
end
712+
offset = input_has_offset ? 0 : 1
713+
return (first(au)+first(av) : last(au)+last(av)) .- offset
714+
end)
715+
716+
if algorithm===:auto
717+
algorithm = T <: FFTTypes ? :fast : :direct
718+
end
719+
if algorithm===:fast
720+
if length(u) * length(v) < 2^16 # TODO: better heuristic
721+
algorithm = :direct
722+
else
723+
algorithm = :fft
724+
end
725+
end
726+
if algorithm===:direct
727+
return _conv_td!(out, output_indices, u, v)
700728
else
701-
_conv(v, u, sv, su)
729+
if output_indices != CartesianIndices(out)
730+
fill!(out, zero(eltype(out)))
731+
end
732+
os_nffts = length(u) >= length(v) ? map(optimalfftfiltlength, size(v), size(u)) : map(optimalfftfiltlength, size(u), size(v))
733+
if algorithm===:fft
734+
if any(os_nffts .< size(output_indices))
735+
algorithm = :fft_overlapsave
736+
else
737+
algorithm = :fft_simple
738+
end
739+
end
740+
if algorithm === :fft_overlapsave
741+
# v should be smaller than u for good performance
742+
if length(u) >= length(v)
743+
return unsafe_conv_kern_os!(out, output_indices, u, v, os_nffts)
744+
else
745+
return unsafe_conv_kern_os!(out, output_indices, v, u, os_nffts)
746+
end
747+
elseif algorithm === :fft_simple
748+
return _conv_kern_fft!(out, output_indices, u, v)
749+
else
750+
throw(ArgumentError("algorithm must be :auto, :fast, :direct, :fft, :fft_simple, or :fft_overlapsave"))
751+
end
702752
end
703753
end
704754

705-
function conv(u::AbstractArray{<:RealOrComplexFloat, N},
706-
v::AbstractArray{<:RealOrComplexFloat, N}) where N
707-
fu, fv = promote(u, v)
708-
conv(fu, fv)
709-
end
710-
711-
conv(u::AbstractArray{<:Integer, N}, v::AbstractArray{<:Integer, N}) where {N} =
712-
round.(Int, conv(float(u), float(v)))
755+
conv_output_axis(au, av) =
756+
conv_with_offset(au) || conv_with_offset(av) ?
757+
(first(au)+first(av):last(au)+last(av)) : Base.OneTo(last(au) + last(av) - 1)
713758

714-
conv(u::AbstractArray{<:Number, N}, v::AbstractArray{<:Number, N}) where {N} =
715-
conv(float(u), float(v))
716-
717-
function conv(u::AbstractArray{<:Number, N},
718-
v::AbstractArray{<:RealOrComplexFloat, N}) where N
719-
conv(float(u), v)
720-
end
759+
"""
760+
conv(u, v; algorithm)
721761
722-
function conv(u::AbstractArray{<:RealOrComplexFloat, N},
723-
v::AbstractArray{<:Number, N}) where N
724-
conv(u, float(v))
762+
Convolution of two arrays. A convolution algorithm is automatically chosen among
763+
direct convolution, FFT, or FFT overlap-save, depending on the size of the
764+
input, unless explicitly specified with the `algorithm` keyword argument; see
765+
[`conv!`](@ref) for details.
766+
"""
767+
function conv(
768+
u::AbstractArray{Tu, N}, v::AbstractArray{Tv, N}; kwargs...
769+
) where {Tu<:Number, Tv<:Number, N}
770+
T = promote_type(Tu, Tv)
771+
out_axes = map(conv_output_axis, axes(u), axes(v))
772+
out = similar(u, T, out_axes)
773+
return conv!(out, u, v; kwargs...)
725774
end
726775

727776
function conv(A::AbstractArray{<:Number, M},
728-
B::AbstractArray{<:Number, N}) where {M, N}
777+
B::AbstractArray{<:Number, N}; kwargs...) where {M, N}
729778
if (M < N)
730-
conv(cat(A, dims=N)::AbstractArray{eltype(A), N}, B)
779+
conv(cat(A, dims=N)::AbstractArray{eltype(A), N}, B; kwargs...)
731780
else
732781
@assert M > N
733-
conv(A, cat(B, dims=M)::AbstractArray{eltype(B), M})
782+
conv(A, cat(B, dims=M)::AbstractArray{eltype(B), M}; kwargs...)
734783
end
735784
end
736785

0 commit comments

Comments
 (0)