|
| 1 | +struct EmbeddedMap{T, As <: LinearMap, Rs <: AbstractVector{Int}, |
| 2 | + Cs <: AbstractVector{Int}} <: LinearMap{T} |
| 3 | + lmap::As |
| 4 | + dims::Dims{2} |
| 5 | + rows::Rs # typically i1:i2 with 1 <= i1 <= i2 <= size(map,1) |
| 6 | + cols::Cs # typically j1:j2 with 1 <= j1 <= j2 <= size(map,2) |
| 7 | + |
| 8 | + function EmbeddedMap{T}(map::As, dims::Dims{2}, rows::Rs, cols::Cs) where {T, |
| 9 | + As <: LinearMap, Rs <: AbstractVector{Int}, Cs <: AbstractVector{Int}} |
| 10 | + check_index(rows, size(map, 1), dims[1]) |
| 11 | + check_index(cols, size(map, 2), dims[2]) |
| 12 | + return new{T,As,Rs,Cs}(map, dims, rows, cols) |
| 13 | + end |
| 14 | +end |
| 15 | + |
| 16 | +EmbeddedMap(map::LinearMap{T}, dims::Dims{2}; offset::Dims{2}) where {T} = |
| 17 | + EmbeddedMap{T}(map, dims, offset[1] .+ (1:size(map, 1)), offset[2] .+ (1:size(map, 2))) |
| 18 | +EmbeddedMap(map::LinearMap, dims::Dims{2}, rows::AbstractVector{Int}, cols::AbstractVector{Int}) = |
| 19 | + EmbeddedMap{eltype(map)}(map, dims, rows, cols) |
| 20 | + |
| 21 | +@static if VERSION >= v"1.8-" |
| 22 | + Base.reverse(A::LinearMap; dims=:) = _reverse(A, dims) |
| 23 | + function _reverse(A, dims::Integer) |
| 24 | + if dims == 1 |
| 25 | + return EmbeddedMap(A, size(A), reverse(axes(A, 1)), axes(A, 2)) |
| 26 | + elseif dims == 2 |
| 27 | + return EmbeddedMap(A, size(A), axes(A, 1), reverse(axes(A, 2))) |
| 28 | + else |
| 29 | + throw(ArgumentError("invalid dims argument to reverse, should be 1 or 2, got $dims")) |
| 30 | + end |
| 31 | + end |
| 32 | + _reverse(A, ::Colon) = EmbeddedMap(A, size(A), map(reverse, axes(A))...) |
| 33 | + _reverse(A, dims::NTuple{1,Integer}) = _reverse(A, first(dims)) |
| 34 | + function _reverse(A, dims::NTuple{M,Integer}) where {M} |
| 35 | + dimrev = ntuple(k -> k in dims, 2) |
| 36 | + if 2 < M || M != sum(dimrev) |
| 37 | + throw(ArgumentError("invalid dimensions $dims in reverse!")) |
| 38 | + end |
| 39 | + ax = ntuple(k -> dimrev[k] ? reverse(axes(A, k)) : axes(A, k), 2) |
| 40 | + return EmbeddedMap(A, size(A), ax...) |
| 41 | + end |
| 42 | +end |
| 43 | + |
| 44 | +function check_index(index::AbstractVector{Int}, dimA::Int, dimB::Int) |
| 45 | + length(index) != dimA && throw(ArgumentError("invalid length of index vector")) |
| 46 | + minimum(index) <= 0 && throw(ArgumentError("minimal index is below 1")) |
| 47 | + maximum(index) > dimB && throw(ArgumentError( |
| 48 | + "maximal index $(maximum(index)) exceeds dimension $dimB" |
| 49 | + )) |
| 50 | + # _isvalidstep(index) || throw(ArgumentError("non-monotone index set")) |
| 51 | + nothing |
| 52 | +end |
| 53 | + |
| 54 | +# _isvalidstep(index::AbstractRange) = step(index) > 0 |
| 55 | +# _isvalidstep(index::AbstractVector) = all(diff(index) .> 0) |
| 56 | + |
| 57 | +Base.size(A::EmbeddedMap) = A.dims |
| 58 | + |
| 59 | +# sufficient but not necessary conditions |
| 60 | +LinearAlgebra.issymmetric(A::EmbeddedMap) = |
| 61 | + issymmetric(A.lmap) && (A.dims[1] == A.dims[2]) && (A.rows == A.cols) |
| 62 | +LinearAlgebra.ishermitian(A::EmbeddedMap) = |
| 63 | + ishermitian(A.lmap) && (A.dims[1] == A.dims[2]) && (A.rows == A.cols) |
| 64 | + |
| 65 | +Base.:(==)(A::EmbeddedMap, B::EmbeddedMap) = |
| 66 | + (eltype(A) == eltype(B)) && (A.lmap == B.lmap) && |
| 67 | + (A.dims == B.dims) && (A.rows == B.rows) && (A.cols == B.cols) |
| 68 | + |
| 69 | +LinearAlgebra.adjoint(A::EmbeddedMap) = EmbeddedMap(adjoint(A.lmap), reverse(A.dims), A.cols, A.rows) |
| 70 | +LinearAlgebra.transpose(A::EmbeddedMap) = EmbeddedMap(transpose(A.lmap), reverse(A.dims), A.cols, A.rows) |
| 71 | + |
| 72 | +for (In, Out) in ((AbstractVector, AbstractVecOrMat), (AbstractMatrix, AbstractMatrix)) |
| 73 | + @eval function _unsafe_mul!(y::$Out, A::EmbeddedMap, x::$In) |
| 74 | + fill!(y, zero(eltype(y))) |
| 75 | + _unsafe_mul!(selectdim(y, 1, A.rows), A.lmap, selectdim(x, 1, A.cols)) |
| 76 | + return y |
| 77 | + end |
| 78 | + @eval function _unsafe_mul!(y::$Out, A::EmbeddedMap, x::$In, alpha::Number, beta::Number) |
| 79 | + LinearAlgebra._rmul_or_fill!(y, beta) |
| 80 | + _unsafe_mul!(selectdim(y, 1, A.rows), A.lmap, selectdim(x, 1, A.cols), alpha, !iszero(beta)) |
| 81 | + return y |
| 82 | + end |
| 83 | +end |
0 commit comments