Skip to content

Commit f1e7f52

Browse files
dkarraschstevengj
andauthored
Add rdiv! for Bidiagonal + small improvements (#43779)
Co-authored-by: Steven G. Johnson <stevenj@alum.mit.edu>
1 parent 3f0ae6e commit f1e7f52

File tree

5 files changed

+326
-125
lines changed

5 files changed

+326
-125
lines changed

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 193 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ end
253253

254254
adjoint(B::Bidiagonal) = Adjoint(B)
255255
transpose(B::Bidiagonal) = Transpose(B)
256-
adjoint(B::Bidiagonal{<:Real}) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? :L : :U)
256+
adjoint(B::Bidiagonal{<:Number}) = Bidiagonal(conj(B.dv), conj(B.ev), B.uplo == 'U' ? :L : :U)
257257
transpose(B::Bidiagonal{<:Number}) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? :L : :U)
258258
permutedims(B::Bidiagonal) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? 'L' : 'U')
259259
function permutedims(B::Bidiagonal, perm)
@@ -640,51 +640,39 @@ end
640640

641641
function *(A::AbstractTriangular, B::Union{SymTridiagonal, Tridiagonal})
642642
TS = promote_op(matprod, eltype(A), eltype(B))
643-
A_mul_B_td!(zeros(TS, size(A)...), A, B)
643+
A_mul_B_td!(zeros(TS, size(A)), A, B)
644644
end
645645

646-
const UpperOrUnitUpperTriangular = Union{UpperTriangular, UnitUpperTriangular}
647-
const LowerOrUnitLowerTriangular = Union{LowerTriangular, UnitLowerTriangular}
646+
const UpperOrUnitUpperTriangular{T} = Union{UpperTriangular{T}, UnitUpperTriangular{T}}
647+
const LowerOrUnitLowerTriangular{T} = Union{LowerTriangular{T}, UnitLowerTriangular{T}}
648648

649649
function *(A::UpperOrUnitUpperTriangular, B::Bidiagonal)
650650
TS = promote_op(matprod, eltype(A), eltype(B))
651-
if B.uplo == 'U'
652-
A_mul_B_td!(UpperTriangular(zeros(TS, size(A)...)), A, B)
653-
else
654-
A_mul_B_td!(zeros(TS, size(A)...), A, B)
655-
end
651+
C = A_mul_B_td!(zeros(TS, size(A)), A, B)
652+
return B.uplo == 'U' ? UpperTriangular(C) : C
656653
end
657654

658655
function *(A::LowerOrUnitLowerTriangular, B::Bidiagonal)
659656
TS = promote_op(matprod, eltype(A), eltype(B))
660-
if B.uplo == 'L'
661-
A_mul_B_td!(LowerTriangular(zeros(TS, size(A)...)), A, B)
662-
else
663-
A_mul_B_td!(zeros(TS, size(A)...), A, B)
664-
end
657+
C = A_mul_B_td!(zeros(TS, size(A)), A, B)
658+
return B.uplo == 'L' ? LowerTriangular(C) : C
665659
end
666660

667661
function *(A::Union{SymTridiagonal, Tridiagonal}, B::AbstractTriangular)
668662
TS = promote_op(matprod, eltype(A), eltype(B))
669-
A_mul_B_td!(zeros(TS, size(A)...), A, B)
663+
A_mul_B_td!(zeros(TS, size(A)), A, B)
670664
end
671665

672666
function *(A::Bidiagonal, B::UpperOrUnitUpperTriangular)
673667
TS = promote_op(matprod, eltype(A), eltype(B))
674-
if A.uplo == 'U'
675-
A_mul_B_td!(UpperTriangular(zeros(TS, size(A)...)), A, B)
676-
else
677-
A_mul_B_td!(zeros(TS, size(A)...), A, B)
678-
end
668+
C = A_mul_B_td!(zeros(TS, size(A)), A, B)
669+
return A.uplo == 'U' ? UpperTriangular(C) : C
679670
end
680671

681672
function *(A::Bidiagonal, B::LowerOrUnitLowerTriangular)
682673
TS = promote_op(matprod, eltype(A), eltype(B))
683-
if A.uplo == 'L'
684-
A_mul_B_td!(LowerTriangular(zeros(TS, size(A)...)), A, B)
685-
else
686-
A_mul_B_td!(zeros(TS, size(A)...), A, B)
687-
end
674+
C = A_mul_B_td!(zeros(TS, size(A)), A, B)
675+
return A.uplo == 'L' ? LowerTriangular(C) : C
688676
end
689677

690678
function *(A::BiTri, B::Diagonal)
@@ -709,7 +697,7 @@ end
709697

710698
function *(A::BiTriSym, B::BiTriSym)
711699
TS = promote_op(matprod, eltype(A), eltype(B))
712-
mul!(similar(A, TS, size(A)...), A, B)
700+
mul!(similar(A, TS, size(A)), A, B)
713701
end
714702

715703
function dot(x::AbstractVector, B::Bidiagonal, y::AbstractVector)
@@ -744,85 +732,206 @@ end
744732

745733
#Linear solvers
746734
#Generic solver using naive substitution
747-
function ldiv!(A::Bidiagonal, b::AbstractVector)
748-
require_one_based_indexing(A, b)
735+
ldiv!(A::Bidiagonal, b::AbstractVecOrMat) = @inline ldiv!(b, A, b)
736+
function ldiv!(c::AbstractVecOrMat, A::Bidiagonal, b::AbstractVecOrMat)
737+
require_one_based_indexing(c, A, b)
749738
N = size(A, 2)
750-
mb = length(b)
739+
mb, nb = size(b, 1), size(b, 2)
751740
if N != mb
752-
throw(DimensionMismatch("second dimension of A, $N, does not match the length of b, $mb"))
741+
throw(DimensionMismatch("second dimension of A, $N, does not match first dimension of b, $mb"))
742+
end
743+
mc, nc = size(c, 1), size(c, 2)
744+
if mc != mb || nc != nb
745+
throw(DimensionMismatch("size of result, ($mc, $nc), does not match the size of b, ($mb, $nb)"))
753746
end
754747

755748
if N == 0
756-
return b
749+
return copyto!(c, b)
757750
end
758751

759-
@inbounds begin
760-
if A.uplo == 'L' #do forward substitution
761-
b[1] = bj1 = A.dv[1]\b[1]
762-
for j in 2:N
763-
bj = b[j]
764-
bj -= A.ev[j - 1] * bj1
765-
dvj = A.dv[j]
766-
if iszero(dvj)
767-
throw(SingularException(j))
768-
end
769-
bj = dvj\bj
770-
b[j] = bj1 = bj
752+
zi = findfirst(iszero, A.dv)
753+
isnothing(zi) || throw(SingularException(zi))
754+
755+
@inbounds for j in 1:nb
756+
if A.uplo == 'L' #do colwise forward substitution
757+
c[1,j] = bi1 = A.dv[1] \ b[1,j]
758+
for i in 2:N
759+
c[i,j] = bi1 = A.dv[i] \ (b[i,j] - A.ev[i - 1] * bi1)
771760
end
772-
else #do backward substitution
773-
b[N] = bj1 = A.dv[N]\b[N]
774-
for j = (N - 1):-1:1
775-
bj = b[j]
776-
bj -= A.ev[j] * bj1
777-
dvj = A.dv[j]
778-
if iszero(dvj)
779-
throw(SingularException(j))
780-
end
781-
bj = dvj\bj
782-
b[j] = bj1 = bj
761+
else #do colwise backward substitution
762+
c[N,j] = bi1 = A.dv[N] \ b[N,j]
763+
for i in (N - 1):-1:1
764+
c[i,j] = bi1 = A.dv[i] \ (b[i,j] - A.ev[i] * bi1)
783765
end
784766
end
785767
end
786-
return b
787-
end
788-
function ldiv!(A::Bidiagonal, B::AbstractMatrix)
789-
require_one_based_indexing(A, B)
790-
mA, nA = size(A)
791-
n = size(B, 1)
792-
if mA != n
793-
throw(DimensionMismatch("first dimension of A, $mA, does not match the first dimension of B, $n"))
794-
end
795-
for b in eachcol(B)
796-
ldiv!(A, b)
797-
end
798-
B
768+
return c
799769
end
800-
ldiv!(A::Transpose{<:Any,<:Bidiagonal}, b::AbstractVecOrMat) = ldiv!(copy(A), b)
801-
ldiv!(A::Adjoint{<:Any,<:Bidiagonal}, b::AbstractVecOrMat) = ldiv!(copy(A), b)
770+
ldiv!(A::Transpose{<:Any,<:Bidiagonal}, b::AbstractVecOrMat) = @inline ldiv!(b, A, b)
771+
ldiv!(A::Adjoint{<:Any,<:Bidiagonal}, b::AbstractVecOrMat) = @inline ldiv!(b, A, b)
772+
ldiv!(c::AbstractVecOrMat, A::Transpose{<:Any,<:Bidiagonal}, b::AbstractVecOrMat) =
773+
(_rdiv!(transpose(c), transpose(b), transpose(A)); return c)
774+
ldiv!(c::AbstractVecOrMat, A::Adjoint{<:Any,<:Bidiagonal}, b::AbstractVecOrMat) =
775+
(_rdiv!(adjoint(c), adjoint(b), adjoint(A)); return c)
802776

803777
### Generic promotion methods and fallbacks
804778
function \(A::Bidiagonal{<:Number}, B::AbstractVecOrMat{<:Number})
805779
TA, TB = eltype(A), eltype(B)
806-
TAB = typeof((zero(TA)*zero(TB) + zero(TA)*zero(TB))/one(TA))
807-
ldiv!(convert(AbstractArray{TAB}, A), copy_oftype(B, TAB))
780+
TAB = typeof((oneunit(TA))\oneunit(TB))
781+
ldiv!(zeros(TAB, size(B)), A, B)
808782
end
809-
\(A::Bidiagonal, B::AbstractVecOrMat) = ldiv!(A, copy(B))
810-
function \(tA::Transpose{<:Number,<:Bidiagonal{<:Number}}, B::AbstractVecOrMat{<:Number})
811-
A = tA.parent
812-
TA, TB = eltype(A), eltype(B)
813-
TAB = typeof((zero(TA)*zero(TB) + zero(TA)*zero(TB))/one(TA))
814-
ldiv!(transpose(convert(AbstractArray{TAB}, A)), copy_oftype(B, TAB))
783+
\(A::Bidiagonal, B::AbstractVecOrMat) = ldiv!(copy(B), A, B)
784+
\(tA::Transpose{<:Any,<:Bidiagonal}, B::AbstractVecOrMat) = copy(tA) \ B
785+
\(adjA::Adjoint{<:Any,<:Bidiagonal}, B::AbstractVecOrMat) = copy(adjA) \ B
786+
787+
### Triangular specializations
788+
function \(B::Bidiagonal{<:Number}, U::UpperOrUnitUpperTriangular{<:Number})
789+
T = typeof((oneunit(eltype(B)))\oneunit(eltype(U)))
790+
A = ldiv!(zeros(T, size(U)), B, U)
791+
return B.uplo == 'U' ? UpperTriangular(A) : A
815792
end
816-
\(tA::Transpose{<:Any,<:Bidiagonal}, B::AbstractVecOrMat) = ldiv!(tA, copy(B))
817-
function \(adjA::Adjoint{<:Number,<:Bidiagonal{<:Number}}, B::AbstractVecOrMat{<:Number})
818-
A = adjA.parent
819-
TA, TB = eltype(A), eltype(B)
820-
TAB = typeof((zero(TA)*zero(TB) + zero(TA)*zero(TB))/one(TA))
821-
ldiv!(adjoint(convert(AbstractArray{TAB}, A)), copy_oftype(B, TAB))
793+
function \(B::Bidiagonal, U::UpperOrUnitUpperTriangular)
794+
A = ldiv!(copy(parent(U)), B, U)
795+
return B.uplo == 'U' ? UpperTriangular(A) : A
796+
end
797+
function \(B::Bidiagonal{<:Number}, L::LowerOrUnitLowerTriangular{<:Number})
798+
T = typeof((oneunit(eltype(B)))\oneunit(eltype(L)))
799+
A = ldiv!(zeros(T, size(L)), B, L)
800+
return B.uplo == 'L' ? LowerTriangular(A) : A
801+
end
802+
function \(B::Bidiagonal, L::LowerOrUnitLowerTriangular)
803+
A = ldiv!(copy(parent(L)), B, L)
804+
return B.uplo == 'L' ? LowerTriangular(A) : A
822805
end
823-
\(adjA::Adjoint{<:Any,<:Bidiagonal}, B::AbstractVecOrMat) = ldiv!(adjA, copy(B))
806+
807+
function \(U::UpperOrUnitUpperTriangular{<:Number}, B::Bidiagonal{<:Number})
808+
T = typeof((oneunit(eltype(U)))/oneunit(eltype(B)))
809+
A = ldiv!(U, copy_similar(B, T))
810+
return B.uplo == 'U' ? UpperTriangular(A) : A
811+
end
812+
function \(L::LowerOrUnitLowerTriangular{<:Number}, B::Bidiagonal{<:Number})
813+
T = typeof((oneunit(eltype(L)))/oneunit(eltype(B)))
814+
A = ldiv!(L, copy_similar(B, T))
815+
return B.uplo == 'L' ? LowerTriangular(A) : A
816+
end
817+
### Diagonal specialization
818+
function \(B::Bidiagonal{<:Number}, D::Diagonal{<:Number})
819+
T = typeof((oneunit(eltype(B)))\oneunit(eltype(D)))
820+
A = ldiv!(zeros(T, size(D)), B, D)
821+
return B.uplo == 'U' ? UpperTriangular(A) : LowerTriangular(A)
822+
end
823+
824+
function _rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal)
825+
require_one_based_indexing(C, A, B)
826+
m, n = size(A)
827+
if size(B, 1) != n
828+
throw(DimensionMismatch("right hand side B needs first dimension of size $n, has size $(size(B,1))"))
829+
end
830+
mc, nc = size(C)
831+
if mc != m || nc != n
832+
throw(DimensionMismatch("expect output to have size ($m, $n), but got ($mc, $nc)"))
833+
end
834+
835+
zi = findfirst(iszero, B.dv)
836+
isnothing(zi) || throw(SingularException(zi))
837+
838+
if B.uplo == 'L'
839+
diagB = B.dv[n]
840+
for i in 1:m
841+
C[i,n] = A[i,n] / diagB
842+
end
843+
for j in n-1:-1:1
844+
diagB = B.dv[j]
845+
offdiagB = B.ev[j]
846+
for i in 1:m
847+
C[i,j] = (A[i,j] - C[i,j+1]*offdiagB)/diagB
848+
end
849+
end
850+
else
851+
diagB = B.dv[1]
852+
for i in 1:m
853+
C[i,1] = A[i,1] / diagB
854+
end
855+
for j in 2:n
856+
diagB = B.dv[j]
857+
offdiagB = B.ev[j-1]
858+
for i = 1:m
859+
C[i,j] = (A[i,j] - C[i,j-1]*offdiagB)/diagB
860+
end
861+
end
862+
end
863+
C
864+
end
865+
rdiv!(A::AbstractMatrix, B::Bidiagonal) = @inline _rdiv!(A, A, B)
866+
rdiv!(A::AbstractMatrix, B::Adjoint{<:Any,<:Bidiagonal}) = @inline _rdiv!(A, A, B)
867+
rdiv!(A::AbstractMatrix, B::Transpose{<:Any,<:Bidiagonal}) = @inline _rdiv!(A, A, B)
868+
rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::Adjoint{<:Any,<:Bidiagonal}) =
869+
(ldiv!(adjoint(C), adjoint(B), adjoint(A)); return C)
870+
rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::Transpose{<:Any,<:Bidiagonal}) =
871+
(ldiv!(transpose(C), transpose(B), transpose(A)); return C)
872+
873+
function /(A::AbstractMatrix{<:Number}, B::Bidiagonal{<:Number})
874+
TA, TB = eltype(A), eltype(B)
875+
TAB = typeof((oneunit(TA))/oneunit(TB))
876+
_rdiv!(zeros(TAB, size(A)), A, B)
877+
end
878+
/(A::AbstractMatrix, B::Bidiagonal) = _rdiv!(copy(A), A, B)
879+
880+
### Triangular specializations
881+
function /(U::UpperOrUnitUpperTriangular{<:Number}, B::Bidiagonal{<:Number})
882+
T = typeof((oneunit(eltype(U)))/oneunit(eltype(B)))
883+
A = _rdiv!(zeros(T, size(U)), U, B)
884+
return B.uplo == 'U' ? UpperTriangular(A) : A
885+
end
886+
function /(U::UpperOrUnitUpperTriangular, B::Bidiagonal)
887+
A = _rdiv!(copy(parent(U)), U, B)
888+
return B.uplo == 'U' ? UpperTriangular(A) : A
889+
end
890+
function /(L::LowerOrUnitLowerTriangular{<:Number}, B::Bidiagonal{<:Number})
891+
T = typeof((oneunit(eltype(L)))/oneunit(eltype(B)))
892+
A = _rdiv!(zeros(T, size(L)), L, B)
893+
return B.uplo == 'L' ? LowerTriangular(A) : A
894+
end
895+
function /(L::LowerOrUnitLowerTriangular, B::Bidiagonal)
896+
A = _rdiv!(copy(parent(L)), L, B)
897+
return B.uplo == 'L' ? LowerTriangular(A) : A
898+
end
899+
function /(B::Bidiagonal{<:Number}, U::UpperOrUnitUpperTriangular{<:Number})
900+
T = typeof((oneunit(eltype(B)))/oneunit(eltype(U)))
901+
A = rdiv!(copy_similar(B, T), U)
902+
return B.uplo == 'U' ? UpperTriangular(A) : A
903+
end
904+
function /(B::Bidiagonal{<:Number}, L::LowerOrUnitLowerTriangular{<:Number})
905+
T = typeof((oneunit(eltype(B)))\oneunit(eltype(L)))
906+
A = rdiv!(copy_similar(B, T), L)
907+
return B.uplo == 'L' ? LowerTriangular(A) : A
908+
end
909+
### Diagonal specialization
910+
function /(D::Diagonal{<:Number}, B::Bidiagonal{<:Number})
911+
T = typeof((oneunit(eltype(D)))/oneunit(eltype(B)))
912+
A = _rdiv!(zeros(T, size(D)), D, B)
913+
return B.uplo == 'U' ? UpperTriangular(A) : LowerTriangular(A)
914+
end
915+
916+
/(A::AbstractMatrix, B::Transpose{<:Any,<:Bidiagonal}) = A / copy(B)
917+
/(A::AbstractMatrix, B::Adjoint{<:Any,<:Bidiagonal}) = A / copy(B)
918+
# disambiguation
919+
/(A::AdjointAbsVec{<:Number}, B::Bidiagonal{<:Number}) = adjoint(adjoint(B) \ parent(A))
920+
/(A::TransposeAbsVec{<:Number}, B::Bidiagonal{<:Number}) = transpose(transpose(B) \ parent(A))
921+
/(A::AdjointAbsVec, B::Bidiagonal) = adjoint(adjoint(B) \ parent(A))
922+
/(A::TransposeAbsVec, B::Bidiagonal) = transpose(transpose(B) \ parent(A))
923+
/(A::AdjointAbsVec, B::Transpose{<:Any,<:Bidiagonal}) = adjoint(adjoint(B) \ parent(A))
924+
/(A::TransposeAbsVec, B::Transpose{<:Any,<:Bidiagonal}) = transpose(transpose(B) \ parent(A))
925+
/(A::AdjointAbsVec, B::Adjoint{<:Any,<:Bidiagonal}) = adjoint(adjoint(B) \ parent(A))
926+
/(A::TransposeAbsVec, B::Adjoint{<:Any,<:Bidiagonal}) = transpose(transpose(B) \ parent(A))
824927

825928
factorize(A::Bidiagonal) = A
929+
function inv(B::Bidiagonal{T}) where T
930+
n = size(B, 1)
931+
dest = zeros(typeof(oneunit(T)\one(T)), (n, n))
932+
ldiv!(dest, B, Diagonal{typeof(one(T)\one(T))}(I, n))
933+
return B.uplo == 'U' ? UpperTriangular(dest) : LowerTriangular(dest)
934+
end
826935

827936
# Eigensystems
828937
eigvals(M::Bidiagonal) = M.dv

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,8 @@ function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta
351351
return C
352352
end
353353

354-
/(A::AbstractVecOrMat, D::Diagonal) = _rdiv!(similar(A, promote_op(/, eltype(A), eltype(D)), size(A)), A, D)
354+
/(A::AbstractVecOrMat, D::Diagonal) =
355+
_rdiv!((promote_op(/, eltype(A), eltype(D))).(A), A, D)
355356

356357
rdiv!(A::AbstractVecOrMat, D::Diagonal) = @inline _rdiv!(A, A, D)
357358
# avoid copy when possible via internal 3-arg backend
@@ -372,7 +373,8 @@ function _rdiv!(B::AbstractVecOrMat, A::AbstractVecOrMat, D::Diagonal)
372373
B
373374
end
374375

375-
\(D::Diagonal, B::AbstractVecOrMat) = ldiv!(similar(B, promote_op(\, eltype(D), eltype(B)), size(B)), D, B)
376+
\(D::Diagonal, B::AbstractVecOrMat) =
377+
ldiv!(promote_op(\, eltype(D), eltype(B)).(B), D, B)
376378

377379
ldiv!(D::Diagonal, B::AbstractVecOrMat) = @inline ldiv!(B, D, B)
378380
function ldiv!(B::AbstractVecOrMat, D::Diagonal, A::AbstractVecOrMat)

0 commit comments

Comments
 (0)