@@ -165,7 +165,18 @@ Base.:(==)(A::CompositeMap, B::CompositeMap) =
165165 (eltype (A) == eltype (B) && all (A. maps .== B. maps))
166166
167167# multiplication with vectors/matrices
168- _unsafe_mul! (y, A:: CompositeMap , x:: AbstractVector ) = _compositemul! (y, A, x)
168+ function Base.:(* )(A:: CompositeMap , x:: AbstractVector )
169+ MulStyle (A) === TwoArg () ?
170+ foldr (* , reverse (A. maps), init= x) :
171+ invoke (* , Tuple{LinearMap, AbstractVector}, A, x)
172+ end
173+
174+ function _unsafe_mul! (y, A:: CompositeMap , x:: AbstractVector )
175+ MulStyle (A) === TwoArg () ?
176+ copyto! (y, foldr (* , reverse (A. maps), init= x)) :
177+ _compositemul! (y, A, x)
178+ return y
179+ end
169180_unsafe_mul! (y, A:: CompositeMap , x:: AbstractMatrix ) = _compositemul! (y, A, x)
170181
171182function _compositemul! (y, A:: CompositeMap{<:Any,<:Tuple{LinearMap}} , x,
@@ -174,10 +185,50 @@ function _compositemul!(y, A::CompositeMap{<:Any,<:Tuple{LinearMap}}, x,
174185 return _unsafe_mul! (y, A. maps[1 ], x)
175186end
176187function _compositemul! (y, A:: CompositeMap{<:Any,<:Tuple{LinearMap,LinearMap}} , x,
177- source = similar (y, (size (A. maps[1 ],1 ), size (x)[2 : end ]. .. )),
188+ source = nothing ,
189+ dest = nothing )
190+ if isnothing (source)
191+ z = convert (AbstractArray, A. maps[1 ] * x)
192+ _unsafe_mul! (y, A. maps[2 ], z)
193+ return y
194+ else
195+ _unsafe_mul! (source, A. maps[1 ], x)
196+ _unsafe_mul! (y, A. maps[2 ], source)
197+ return y
198+ end
199+ end
200+ _compositemul! (y, A:: CompositeMap{<:Any,<:LinearMapTuple} , x, s = nothing , d = nothing ) =
201+ _compositemulN! (y, A, x, s, d)
202+ function _compositemul! (y, A:: CompositeMap{<:Any,<:LinearMapVector} , x,
203+ source = nothing ,
178204 dest = nothing )
179- _unsafe_mul! (source, A. maps[1 ], x)
180- _unsafe_mul! (y, A. maps[2 ], source)
205+ N = length (A. maps)
206+ if N == 1
207+ return _unsafe_mul! (y, A. maps[1 ], x)
208+ elseif N == 2
209+ return _unsafe_mul! (y, A. maps[2 ] * A. maps[1 ], x)
210+ else
211+ return _compositemulN! (y, A, x, source, dest)
212+ end
213+ end
214+
215+ function _compositemulN! (y, A:: CompositeMap , x,
216+ src = nothing ,
217+ dst = nothing )
218+ N = length (A. maps) # ≥ 3
219+ source = isnothing (src) ?
220+ convert (AbstractArray, A. maps[1 ] * x) :
221+ _unsafe_mul! (src, A. maps[1 ], x)
222+ dest = isnothing (dst) ?
223+ convert (AbstractArray, A. maps[2 ] * source) :
224+ _unsafe_mul! (dst, A. maps[2 ], source)
225+ dest, source = source, dest # alternate dest and source
226+ for n in 3 : N- 1
227+ dest = _resize (dest, (size (A. maps[n], 1 ), size (x)[2 : end ]. .. ))
228+ _unsafe_mul! (dest, A. maps[n], source)
229+ dest, source = source, dest # alternate dest and source
230+ end
231+ _unsafe_mul! (y, A. maps[N], source)
181232 return y
182233end
183234
@@ -197,48 +248,3 @@ function _resize(dest::AbstractMatrix, sz::Tuple{<:Integer,<:Integer})
197248 size (dest) == sz && return dest
198249 similar (dest, sz)
199250end
200-
201- function _compositemul! (y, A:: CompositeMap{<:Any,<:LinearMapTuple} , x,
202- source = similar (y, (size (A. maps[1 ],1 ), size (x)[2 : end ]. .. )),
203- dest = similar (y, (size (A. maps[2 ],1 ), size (x)[2 : end ]. .. )))
204- N = length (A. maps)
205- _unsafe_mul! (source, A. maps[1 ], x)
206- for n in 2 : N- 1
207- dest = _resize (dest, (size (A. maps[n],1 ), size (x)[2 : end ]. .. ))
208- _unsafe_mul! (dest, A. maps[n], source)
209- dest, source = source, dest # alternate dest and source
210- end
211- _unsafe_mul! (y, A. maps[N], source)
212- return y
213- end
214-
215- function _compositemul! (y, A:: CompositeMap{<:Any,<:LinearMapVector} , x)
216- N = length (A. maps)
217- if N == 1
218- return _unsafe_mul! (y, A. maps[1 ], x)
219- elseif N == 2
220- return _compositemul2! (y, A, x)
221- else
222- return _compositemulN! (y, A, x)
223- end
224- end
225-
226- function _compositemul2! (y, A:: CompositeMap{<:Any,<:LinearMapVector} , x,
227- source = similar (y, (size (A. maps[1 ],1 ), size (x)[2 : end ]. .. )))
228- _unsafe_mul! (source, A. maps[1 ], x)
229- _unsafe_mul! (y, A. maps[2 ], source)
230- return y
231- end
232- function _compositemulN! (y, A:: CompositeMap{<:Any,<:LinearMapVector} , x,
233- source = similar (y, (size (A. maps[1 ],1 ), size (x)[2 : end ]. .. )),
234- dest = similar (y, (size (A. maps[2 ],1 ), size (x)[2 : end ]. .. )))
235- N = length (A. maps)
236- _unsafe_mul! (source, A. maps[1 ], x)
237- for n in 2 : N- 1
238- dest = _resize (dest, (size (A. maps[n],1 ), size (x)[2 : end ]. .. ))
239- _unsafe_mul! (dest, A. maps[n], source)
240- dest, source = source, dest # alternate dest and source
241- end
242- _unsafe_mul! (y, A. maps[N], source)
243- return y
244- end
0 commit comments