Skip to content

Commit 821241f

Browse files
authored
feat: lower SVD to enzymexla ops (#1889)
* refactor: rename to batched to avoid ambiguity
* feat: svd/svdvals lowering
* test: svdvals + batching fix
* chore: run fmt
* feat: more coverage
* test: use updated commit
* chore: bump reactant_jll version
1 parent 6b07526 commit 821241f

File tree

8 files changed

+355
-53
lines changed

8 files changed

+355
-53
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ PythonCall = "0.9.25"
105105
Random = "1.10"
106106
Random123 = "1.7"
107107
ReactantCore = "0.1.16"
108-
Reactant_jll = "0.0.263"
108+
Reactant_jll = "0.0.264"
109109
ScopedValues = "1.3.0"
110110
Scratch = "1.2"
111111
Sockets = "1.10"

src/Overlay.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,22 @@ for (jlop, rop, default_pivot) in (
269269
end
270270
end
271271

272+
for (jlop, rop) in ((:svd, :overloaded_svd),)
273+
@eval begin
274+
@reactant_overlay @noinline function LinearAlgebra.$(jlop)(
275+
x::AbstractArray; kwargs...
276+
)
277+
if use_overlayed_version(x)
278+
return TracedLinearAlgebra.$(rop)(
279+
factorization_copy(LinearAlgebra.$(jlop), x); kwargs...
280+
)
281+
else
282+
return Base.inferencebarrier(LinearAlgebra.$(jlop))(x; kwargs...)
283+
end
284+
end
285+
end
286+
end
287+
272288
@reactant_overlay @noinline function LinearAlgebra.dot(x::AbstractArray, y::AbstractArray)
273289
if use_overlayed_version(x) || use_overlayed_version(y)
274290
return TracedLinearAlgebra.overloaded_dot(x, y)

src/stdlibs/LinearAlgebra.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ..MLIR: MLIR
44
using ..Reactant: Reactant, Ops
55
using ..Reactant:
66
TracedRArray, TracedRNumber, AnyTracedRArray, AnyTracedRMatrix, AnyTracedRVector
7-
using ..Reactant: call_with_reactant
7+
using ..Reactant: call_with_reactant, unwrapped_eltype, promote_to
88
using ReactantCore: ReactantCore, materialize_traced_array, @trace
99
using Reactant_jll: Reactant_jll
1010

@@ -15,8 +15,9 @@ using LinearAlgebra: LinearAlgebra, BLAS
1515
using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum, NoPivot
1616
using LinearAlgebra: SymTridiagonal, Symmetric, Bidiagonal, Diagonal, Tridiagonal
1717
using LinearAlgebra: LowerTriangular, UnitLowerTriangular, UpperTriangular
18-
using LinearAlgebra:
19-
diag, diagm, ldiv!, det, logabsdet, lu, istriu, istril, triu!, tril!, inv!, rmul!
18+
using LinearAlgebra: I, diag, diagm, ldiv!, det, logabsdet, istriu, istril, triu!, tril!
19+
using LinearAlgebra: inv!, rmul!, normalize
20+
using LinearAlgebra: svd, lu
2021
using Libdl: Libdl
2122
using GPUArraysCore: @allowscalar
2223

src/stdlibs/factorization/Cholesky.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1-
struct GeneralizedCholesky{T,S<:AbstractArray,I<:Union{AbstractArray,Number}} <:
2-
GeneralizedFactorization{T}
1+
struct BatchedCholesky{T,S<:AbstractArray,I<:Union{AbstractArray,Number}} <:
2+
BatchedFactorization{T}
33
factors::S
44
uplo::Char
55
info::I
66
end
77

8-
function GeneralizedCholesky(factors::S, uplo::Char, info::I) where {S,I}
8+
function BatchedCholesky(factors::S, uplo::Char, info::I) where {S,I}
99
@assert ndims(info) == ndims(factors) - 2
10-
return GeneralizedCholesky{eltype(factors),S,I}(factors, uplo, info)
10+
return BatchedCholesky{eltype(factors),S,I}(factors, uplo, info)
1111
end
1212

13-
Base.size(c::GeneralizedCholesky) = size(c.factors)
14-
Base.ndims(c::GeneralizedCholesky) = ndims(c.factors)
13+
Base.size(c::BatchedCholesky) = size(c.factors)
14+
Base.size(c::BatchedCholesky, i::Integer) = size(c.factors, i)
15+
Base.ndims(c::BatchedCholesky) = ndims(c.factors)
1516

1617
function overloaded_cholesky(A::AbstractArray, ::NoPivot; check::Bool=false)
1718
return overloaded_cholesky(Reactant.promote_to(TracedRArray, A), NoPivot(); check)
@@ -41,26 +42,26 @@ function overloaded_cholesky(
4142
info = TracedRNumber{Bool}((), info.mlir_data)
4243
end
4344

44-
return GeneralizedCholesky(factors, 'U', info)
45+
return BatchedCholesky(factors, 'U', info)
4546
end
4647

4748
function LinearAlgebra.ldiv!(
48-
F::GeneralizedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,M}
49+
F::BatchedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,M}
4950
) where {T,N,M}
5051
@assert N == M + 1
5152
ldiv!(F, reshape(B, size(B, 1), 1, size(B)[2:end]...))
5253
return B
5354
end
5455

5556
function LinearAlgebra.ldiv!(
56-
F::GeneralizedCholesky{T,<:AbstractArray{T,2}}, B::AbstractArray{T,2}
57+
F::BatchedCholesky{T,<:AbstractArray{T,2}}, B::AbstractArray{T,2}
5758
) where {T}
5859
B .= _cholesky_solve_core(F.factors, B, F.uplo)
5960
return B
6061
end
6162

6263
function LinearAlgebra.ldiv!(
63-
F::GeneralizedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,N}
64+
F::BatchedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,N}
6465
) where {T,N}
6566
batch_shape = size(F.factors)[3:end]
6667
@assert batch_shape == size(B)[3:end]
Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,25 @@
1-
# Supports batched factorization
2-
abstract type GeneralizedFactorization{T} <: Factorization{T} end
1+
abstract type BatchedFactorization{T} <: Factorization{T} end
32

4-
function LinearAlgebra.TransposeFactorization(f::GeneralizedFactorization)
3+
function LinearAlgebra.TransposeFactorization(f::BatchedFactorization)
54
return LinearAlgebra.TransposeFactorization{eltype(f),typeof(f)}(f)
65
end
76

8-
function LinearAlgebra.AdjointFactorization(f::GeneralizedFactorization)
7+
function LinearAlgebra.AdjointFactorization(f::BatchedFactorization)
98
return LinearAlgebra.AdjointFactorization{eltype(f),typeof(f)}(f)
109
end
1110

12-
const GeneralizedTransposeFactorization{T} =
13-
LinearAlgebra.TransposeFactorization{T,<:GeneralizedFactorization{T}} where {T}
14-
const GeneralizedAdjointFactorization{T} =
15-
LinearAlgebra.AdjointFactorization{T,<:GeneralizedFactorization{T}} where {T}
11+
const BatchedTransposeFactorization{T} =
12+
LinearAlgebra.TransposeFactorization{T,<:BatchedFactorization{T}} where {T}
13+
const BatchedAdjointFactorization{T} =
14+
LinearAlgebra.AdjointFactorization{T,<:BatchedFactorization{T}} where {T}
1615

1716
include("Cholesky.jl")
1817
include("LU.jl")
18+
include("SVD.jl")
1919

2020
# Overload \ to support batched factorization
21-
for FT in (
22-
:GeneralizedFactorization,
23-
:GeneralizedTransposeFactorization,
24-
:GeneralizedAdjointFactorization,
25-
)
21+
for FT in
22+
(:BatchedFactorization, :BatchedTransposeFactorization, :BatchedAdjointFactorization)
2623
for aType in (:AbstractVecOrMat, :AbstractArray)
2724
@eval Base.:(\)(F::$FT, B::$aType) = _overloaded_backslash(F, B)
2825
end
@@ -32,18 +29,36 @@ for FT in (
3229
) where {T<:Union{Float32,Float64}} = _overloaded_backslash(F, B)
3330
end
3431

35-
function _overloaded_backslash(F::GeneralizedFactorization, B::AbstractArray)
36-
return ldiv!(
37-
F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B))))
38-
)
32+
function __get_B(F::Factorization, B::AbstractArray)
33+
m, n = size(F, 1), size(F, 2)
34+
if m != size(B, 1)
35+
throw(DimensionMismatch("arguments must have the same number of rows"))
36+
end
37+
38+
TFB = typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))
39+
40+
BB = similar(B, TFB, max(size(B, 1), n), size(B)[2:end]...)
41+
if n > size(B, 1)
42+
BB[1:m, ntuple(Returns(Colon()), ndims(B) - 1)...] = B
43+
else
44+
copyto!(BB, B)
45+
end
46+
47+
return BB
48+
end
49+
50+
function _overloaded_backslash(F::BatchedFactorization, B::AbstractArray)
51+
BB = __get_B(F, B)
52+
ldiv!(F, BB)
53+
return BB[1:size(F, 2), ntuple(Returns(Colon()), ndims(B) - 1)...]
3954
end
4055

41-
function _overloaded_backslash(F::GeneralizedTransposeFactorization, B::AbstractArray)
56+
function _overloaded_backslash(F::BatchedTransposeFactorization, B::AbstractArray)
4257
return conj!(adjoint(F.parent) \ conj.(B))
4358
end
4459

45-
function _overloaded_backslash(F::GeneralizedAdjointFactorization, B::AbstractArray)
46-
return ldiv!(
47-
F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B))))
48-
)
60+
function _overloaded_backslash(F::BatchedAdjointFactorization, B::AbstractArray)
61+
BB = __get_B(F, B)
62+
ldiv!(F, BB)
63+
return BB[1:size(F)[2], ntuple(Returns(Colon()), ndims(B) - 1)...]
4964
end

src/stdlibs/factorization/LU.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
1-
struct GeneralizedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Number}} <:
2-
GeneralizedFactorization{T}
1+
struct BatchedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Number}} <:
2+
BatchedFactorization{T}
33
factors::S
44
ipiv::P
55
perm::P
66
info::I
77
end
88

9-
Base.size(lu::GeneralizedLU) = size(lu.factors)
10-
Base.size(lu::GeneralizedLU, i) = size(lu.factors, i)
11-
Base.ndims(lu::GeneralizedLU) = ndims(lu.factors)
12-
function Base.copy(lu::GeneralizedLU)
13-
return GeneralizedLU(copy(lu.factors), copy(lu.ipiv), copy(lu.perm), copy(lu.info))
9+
Base.size(lu::BatchedLU) = size(lu.factors)
10+
Base.size(lu::BatchedLU, i::Integer) = size(lu.factors, i)
11+
Base.ndims(lu::BatchedLU) = ndims(lu.factors)
12+
function Base.copy(lu::BatchedLU)
13+
return BatchedLU(copy(lu.factors), copy(lu.ipiv), copy(lu.perm), copy(lu.info))
1414
end
1515

16-
function GeneralizedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I}
16+
function BatchedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I}
1717
@assert ndims(ipiv) == ndims(perm) == ndims(factors) - 1
1818
@assert ndims(info) == ndims(factors) - 2
19-
return GeneralizedLU{eltype(factors),S,P,I}(factors, ipiv, perm, info)
19+
return BatchedLU{eltype(factors),S,P,I}(factors, ipiv, perm, info)
2020
end
2121

2222
function overloaded_lu(x::AbstractArray, args...; kwargs...)
@@ -37,26 +37,26 @@ function overloaded_lu(
3737
factors = @opcall transpose(factors, invperm(permdims))
3838
ipiv = @opcall transpose(ipiv, perm_perm)
3939
perm = @opcall transpose(perm, perm_perm)
40-
return GeneralizedLU(factors, ipiv, perm, info)
40+
return BatchedLU(factors, ipiv, perm, info)
4141
end
4242

4343
function LinearAlgebra.ldiv!(
44-
lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,M}
44+
lu::BatchedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,M}
4545
) where {T,P,I,N,M}
4646
@assert N == M + 1
4747
ldiv!(lu, reshape(B, size(B, 1), 1, size(B)[2:end]...))
4848
return B
4949
end
5050

5151
function LinearAlgebra.ldiv!(
52-
lu::GeneralizedLU{T,<:AbstractArray{T,2},P,I}, B::AbstractArray{T,2}
52+
lu::BatchedLU{T,<:AbstractArray{T,2},P,I}, B::AbstractArray{T,2}
5353
) where {T,P,I}
5454
B .= _lu_solve_core(lu.factors, B, lu.perm)
5555
return B
5656
end
5757

5858
function LinearAlgebra.ldiv!(
59-
lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,N}
59+
lu::BatchedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,N}
6060
) where {T,P,I,N}
6161
batch_shape = size(lu.factors)[3:end]
6262
@assert batch_shape == size(B)[3:end]
@@ -83,15 +83,15 @@ function LinearAlgebra.ldiv!(
8383
return B
8484
end
8585

86-
function LinearAlgebra.det(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T}
86+
function LinearAlgebra.det(lu::BatchedLU{T,<:AbstractMatrix}) where {T}
8787
n = LinearAlgebra.checksquare(lu)
8888
# TODO: check for non-singular matrices
8989

9090
P = prod(LinearAlgebra.diag(lu.factors))
9191
return ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(T), one(T)) * P
9292
end
9393

94-
function LinearAlgebra.logabsdet(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T}
94+
function LinearAlgebra.logabsdet(lu::BatchedLU{T,<:AbstractMatrix}) where {T}
9595
n = LinearAlgebra.checksquare(lu)
9696
Treal = real(T)
9797
# TODO: check for non-singular matrices
@@ -106,7 +106,7 @@ end
106106
for f_wrapper in (LinearAlgebra.TransposeFactorization, LinearAlgebra.AdjointFactorization),
107107
aType in (:AbstractVecOrMat, :AbstractArray)
108108

109-
@eval function LinearAlgebra.ldiv!(lu::$(f_wrapper){<:Any,<:GeneralizedLU}, B::$aType)
109+
@eval function LinearAlgebra.ldiv!(lu::$(f_wrapper){<:Any,<:BatchedLU}, B::$aType)
110110
# TODO: implement this
111111
error("`$(f_wrapper)` is not supported yet for LU.")
112112
return nothing
@@ -116,7 +116,7 @@ end
116116
# currently we lower inverse to lu decomposition + triangular solve. we should
117117
# instead emit getri and lower that to a fallback if the backend doesn't support
118118
# it.
119-
function LinearAlgebra.inv!(lu::GeneralizedLU)
119+
function LinearAlgebra.inv!(lu::BatchedLU)
120120
@assert ndims(lu) == 2 "Only implemented for 2D tensors"
121121
rhs = Reactant.promote_to(
122122
TracedRArray{Reactant.unwrapped_eltype(eltype(lu)),2}, LinearAlgebra.I(size(lu, 1))

0 commit comments

Comments (0)