Skip to content

Commit 45cd3fa

Browse files
authored
Fix MulStyle of CompositeMaps (#200)
1 parent bc62f53 commit 45cd3fa

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

src/composition.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ Base.mapreduce(::typeof(identity), ::typeof(Base.mul_prod), maps::LinearMapTuple
1919
Base.mapreduce(::typeof(identity), ::typeof(Base.mul_prod), maps::AbstractVector{<:LinearMap{T}}) where {T} =
2020
CompositeMap{T}(reverse(maps))
2121

22+
MulStyle(A::CompositeMap) = MulStyle(A.maps...) === TwoArg() ? TwoArg() : ThreeArg()
23+
2224
# basic methods
2325
Base.size(A::CompositeMap) = (size(A.maps[end], 1), size(A.maps[1], 2))
2426
Base.axes(A::CompositeMap) = (axes(A.maps[end])[1], axes(A.maps[1])[2])
@@ -173,7 +175,7 @@ end
173175

174176
function _unsafe_mul!(y, A::CompositeMap, x::AbstractVector)
175177
MulStyle(A) === TwoArg() ?
176-
copyto!(y, foldr(*, reverse(A.maps), init=x)) :
178+
copyto!(y, A*x) :
177179
_compositemul!(y, A, x)
178180
return y
179181
end

test/composition.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,28 @@ using LinearMaps: LinearMapVector, LinearMapTuple
158158
@test P * ones(3) == (LowerTriangular(ones(3,3))^i) * ones(3)
159159
end
160160
end
161+
162+
# test product of 2-arg FunctionMaps
163+
# the following tests don't work when wrapped in a testset
164+
N = 100
165+
function planA()
166+
y = zeros(N) # workspace
167+
A = LinearMap{Float64}(x -> (y .= x .+ 1; y), N)
168+
return A, y
169+
end
170+
function planB()
171+
y = zeros(N) # workspace
172+
A = LinearMap{Float64}(x -> (y .= x ./ 2; y), N)
173+
return A, y
174+
end
175+
A, ya = planA()
176+
B, yb = planB()
177+
x = zeros(N)
178+
C = @inferred A*B; C*x
179+
@test C*x === ya == ones(N)
180+
D = @inferred B*A; D*x
181+
@test D*x === yb == fill(0.5, N)
182+
@test (@allocated C*x) == 0
183+
mul!(deepcopy(ya), C, x)
184+
y = deepcopy(ya)
185+
@test (@allocated mul!(y, C, x)) == 0

0 commit comments

Comments
 (0)