1- struct GeneralizedLU {T,S<: AbstractArray ,P<: AbstractArray ,I<: Union{AbstractArray,Number} } < :
2- GeneralizedFactorization {T}
1+ struct BatchedLU {T,S<: AbstractArray ,P<: AbstractArray ,I<: Union{AbstractArray,Number} } < :
2+ BatchedFactorization {T}
33 factors:: S
44 ipiv:: P
55 perm:: P
66 info:: I
77end
88
9- Base. size (lu:: GeneralizedLU ) = size (lu. factors)
10- Base. size (lu:: GeneralizedLU , i) = size (lu. factors, i)
11- Base. ndims (lu:: GeneralizedLU ) = ndims (lu. factors)
12- function Base. copy (lu:: GeneralizedLU )
13- return GeneralizedLU (copy (lu. factors), copy (lu. ipiv), copy (lu. perm), copy (lu. info))
9+ Base. size (lu:: BatchedLU ) = size (lu. factors)
10+ Base. size (lu:: BatchedLU , i:: Integer ) = size (lu. factors, i)
11+ Base. ndims (lu:: BatchedLU ) = ndims (lu. factors)
12+ function Base. copy (lu:: BatchedLU )
13+ return BatchedLU (copy (lu. factors), copy (lu. ipiv), copy (lu. perm), copy (lu. info))
1414end
1515
16- function GeneralizedLU (factors:: S , ipiv:: P , perm:: P , info:: I ) where {S,P,I}
16+ function BatchedLU (factors:: S , ipiv:: P , perm:: P , info:: I ) where {S,P,I}
1717 @assert ndims (ipiv) == ndims (perm) == ndims (factors) - 1
1818 @assert ndims (info) == ndims (factors) - 2
19- return GeneralizedLU {eltype(factors),S,P,I} (factors, ipiv, perm, info)
19+ return BatchedLU {eltype(factors),S,P,I} (factors, ipiv, perm, info)
2020end
2121
2222function overloaded_lu (x:: AbstractArray , args... ; kwargs... )
@@ -37,26 +37,26 @@ function overloaded_lu(
3737 factors = @opcall transpose (factors, invperm (permdims))
3838 ipiv = @opcall transpose (ipiv, perm_perm)
3939 perm = @opcall transpose (perm, perm_perm)
40- return GeneralizedLU (factors, ipiv, perm, info)
40+ return BatchedLU (factors, ipiv, perm, info)
4141end
4242
4343function LinearAlgebra. ldiv! (
44- lu:: GeneralizedLU {T,<:AbstractArray{T,N},P,I} , B:: AbstractArray{T,M}
44+ lu:: BatchedLU {T,<:AbstractArray{T,N},P,I} , B:: AbstractArray{T,M}
4545) where {T,P,I,N,M}
4646 @assert N == M + 1
4747 ldiv! (lu, reshape (B, size (B, 1 ), 1 , size (B)[2 : end ]. .. ))
4848 return B
4949end
5050
5151function LinearAlgebra. ldiv! (
52- lu:: GeneralizedLU {T,<:AbstractArray{T,2},P,I} , B:: AbstractArray{T,2}
52+ lu:: BatchedLU {T,<:AbstractArray{T,2},P,I} , B:: AbstractArray{T,2}
5353) where {T,P,I}
5454 B .= _lu_solve_core (lu. factors, B, lu. perm)
5555 return B
5656end
5757
5858function LinearAlgebra. ldiv! (
59- lu:: GeneralizedLU {T,<:AbstractArray{T,N},P,I} , B:: AbstractArray{T,N}
59+ lu:: BatchedLU {T,<:AbstractArray{T,N},P,I} , B:: AbstractArray{T,N}
6060) where {T,P,I,N}
6161 batch_shape = size (lu. factors)[3 : end ]
6262 @assert batch_shape == size (B)[3 : end ]
@@ -83,15 +83,15 @@ function LinearAlgebra.ldiv!(
8383 return B
8484end
8585
86- function LinearAlgebra. det (lu:: GeneralizedLU {T,<:AbstractMatrix} ) where {T}
86+ function LinearAlgebra. det (lu:: BatchedLU {T,<:AbstractMatrix} ) where {T}
8787 n = LinearAlgebra. checksquare (lu)
8888 # TODO : check for non-singular matrices
8989
9090 P = prod (LinearAlgebra. diag (lu. factors))
9191 return ifelse (isodd (sum (lu. ipiv[1 : n] .!= (1 : n))), - one (T), one (T)) * P
9292end
9393
94- function LinearAlgebra. logabsdet (lu:: GeneralizedLU {T,<:AbstractMatrix} ) where {T}
94+ function LinearAlgebra. logabsdet (lu:: BatchedLU {T,<:AbstractMatrix} ) where {T}
9595 n = LinearAlgebra. checksquare (lu)
9696 Treal = real (T)
9797 # TODO : check for non-singular matrices
106106for f_wrapper in (LinearAlgebra. TransposeFactorization, LinearAlgebra. AdjointFactorization),
107107 aType in (:AbstractVecOrMat , :AbstractArray )
108108
109- @eval function LinearAlgebra. ldiv! (lu:: $ (f_wrapper){<: Any ,<: GeneralizedLU }, B:: $aType )
109+ @eval function LinearAlgebra. ldiv! (lu:: $ (f_wrapper){<: Any ,<: BatchedLU }, B:: $aType )
110110 # TODO : implement this
111111 error (" `$(f_wrapper) ` is not supported yet for LU." )
112112 return nothing
116116# currently we lower inverse to lu decomposition + triangular solve. we should
117117# instead emit getri and lower that to a fallback if the backend doesn't support
118118# it.
119- function LinearAlgebra. inv! (lu:: GeneralizedLU )
119+ function LinearAlgebra. inv! (lu:: BatchedLU )
120120 @assert ndims (lu) == 2 " Only implemented for 2D tensors"
121121 rhs = Reactant. promote_to (
122122 TracedRArray{Reactant. unwrapped_eltype (eltype (lu)),2 }, LinearAlgebra. I (size (lu, 1 ))
0 commit comments