Skip to content

Commit c18a4d2

Browse files
authored
Solve triu (#1920)
1 parent 56ca039 commit c18a4d2

File tree

4 files changed

+185
-16
lines changed

4 files changed

+185
-16
lines changed

src/Matrix-Strassen.jl

Lines changed: 85 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""
22
Provides generic asymptotically fast matrix methods:
3-
- mul and mul! using the Strassen scheme
4-
- _solve_tril!
5-
- lu!
6-
- _solve_triu
3+
- `mul` and `mul!` using the Strassen scheme
4+
- `_solve_tril!`
5+
- `lu!`
6+
- `_solve_triu`
77
88
Just prefix the function name with "Strassen."; all 4 functions support a keyword
99
argument "cutoff" to indicate when the base case should be used.
@@ -40,6 +40,12 @@ function mul(A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff) where {T}
4040
end
4141

4242
#scheduling copied from the nmod_mat_mul in Flint
43+
"""
44+
Fast, recursive, generic matrix multiplication using the Strassen
45+
trick.
46+
47+
`cutoff` indicates when the recursion stops and the base case is called.
48+
"""
4349
function mul!(C::MatElem{T}, A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff) where {T}
4450
sA = size(A)
4551
sB = size(B)
@@ -274,17 +280,82 @@ function lu!(P::Perm{Int}, A; cutoff::Int = 300)
274280
return r1 + r2
275281
end
276282

277-
function _solve_triu(T::MatElem, b::MatElem; cutoff::Int = cutoff)
278-
#b*inv(T), thus solves Tx = b for T upper triangular
283+
# Recursive block solver for an upper triangular matrix T.
#
# side == :left  (default): solves x*T = b, delegated to `_solve_triu_left`.
# side == :right          : solves T*x = b, handled below.
#
# As soon as ncols(T) <= cutoff the generic `AbstractAlgebra._solve_triu`
# is used as base case (for either side).
function _solve_triu(T::MatElem, b::MatElem; cutoff::Int = cutoff, side::Symbol = :left)
  n = ncols(T)
  if n <= cutoff
    return AbstractAlgebra._solve_triu(T, b; side = side)
  end
  if side == :left
    return _solve_triu_left(T, b; cutoff = cutoff)
  end
  @assert side == :right
  @assert n == nrows(T) == nrows(b)

  m = ncols(b)
  # split points: the leading blocks get the extra row/column for odd sizes
  n2 = div(n + 1, 2)
  m2 = div(m + 1, 2)

  #=
    With
      b = [U X; V Y]
      T = [A B; 0 C]
      x = [SS RR; S R]
    the equation T*x = b reads

      C S = V                       (lower rows, left columns)
      C R = Y                       (lower rows, right columns)
      A SS + B S = U  =>  A SS = U - B S
      A RR + B R = X  =>  A RR = X - B R
  =#

  A = view(T, 1:n2, 1:n2)
  B = view(T, 1:n2, 1+n2:n)
  C = view(T, 1+n2:n, 1+n2:n)

  # Solve for one block of columns of b: first the lower rows via C, then
  # substitute into the upper rows and solve via A.
  function solve_columns(top, bottom)
    lower = _solve_triu(C, bottom; cutoff = cutoff, side = side)
    upper = mul(B, lower; cutoff = cutoff)
    upper = sub!(upper, top, upper)   # upper = top - B*lower
    upper = _solve_triu(A, upper; cutoff = cutoff, side = side)
    return upper, lower
  end

  SS, S = solve_columns(view(b, 1:n2, 1:m2), view(b, n2+1:n, 1:m2))
  RR, R = solve_columns(view(b, 1:n2, m2+1:m), view(b, n2+1:n, m2+1:m))

  return [SS RR; S R]
end
334+
335+
function _solve_triu_left(T::MatElem, b::MatElem; cutoff::Int = cutoff)
336+
#b*inv(T), thus solves xT = b for T upper triangular
337+
n = ncols(T)
338+
if n <= cutoff
339+
R = AbstractAlgebra._solve_triu_left(T, b)
340+
return R
341+
end
342+
343+
@assert ncols(b) == nrows(T) == n
284344

285345
n2 = div(n, 2) + n % 2
286346
m = nrows(b)
287347
m2 = div(m, 2) + m % 2
348+
#=
349+
b = [U V; X Y]
350+
T = [A B; 0 C]
351+
x = [S SS; R RR]
352+
353+
[S SS] [A; 0] = SA = U
354+
[R RR] [A; 0] = RA = X
355+
[S SS] [B; C] = SB + SS C = V => SS C = V - SB
356+
[R RR] [B; C] = RB + RR C = Y => RR C = Y - RB
357+
358+
=#
288359

289360
U = view(b, 1:m2, 1:n2)
290361
V = view(b, 1:m2, n2+1:n)
@@ -295,18 +366,21 @@ function _solve_triu(T::MatElem, b::MatElem; cutoff::Int = cutoff)
295366
B = view(T, 1:n2, 1+n2:n)
296367
C = view(T, 1+n2:n, 1+n2:n)
297368

298-
S = _solve_triu(A, U; cutoff)
299-
R = _solve_triu(A, X; cutoff)
369+
S = _solve_triu_left(A, U; cutoff)
370+
R = _solve_triu_left(A, X; cutoff)
300371

301372
SS = mul(S, B; cutoff)
302373
SS = sub!(SS, V, SS)
303-
SS = _solve_triu(C, SS; cutoff)
374+
SS = _solve_triu_left(C, SS; cutoff)
304375

305376
RR = mul(R, B; cutoff)
306377
RR = sub!(RR, Y, RR)
307-
RR = _solve_triu(C, RR; cutoff)
378+
RR = _solve_triu_left(C, RR; cutoff)
379+
#THINK: both pairs of solving could be combined:
380+
# solve [U; X], A to get S and R...
308381

309382
return [S SS; R RR]
310383
end
311384

385+
312386
end # module

src/Matrix.jl

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2135,7 +2135,7 @@ function rref!(A::MatrixElem{T}) where {T <: FieldElement}
21352135
V[j, i] = A[j, pivots[np + i]]
21362136
end
21372137
end
2138-
V = _solve_triu(U, V, false)
2138+
V = _solve_triu_right(U, V; unipotent = false)
21392139
for i = 1:rnk
21402140
for j = 1:i
21412141
A[j, pivots[i]] = i == j ? one(R) : R()
@@ -3411,14 +3411,14 @@ $n\times m$ matrix $x$ such that $Ux = b$. If $U$ is singular an exception
34113411
is raised. If unit is true then $U$ is assumed to have ones on its
34123412
diagonal, and the diagonal will not be read.
34133413
"""
3414-
function _solve_triu(U::MatElem{T}, b::MatElem{T}, unit::Bool = false) where {T <: FieldElement}
3414+
function _solve_triu_right(U::MatElem{T}, b::MatElem{T}; unipotent::Bool = false) where {T <: FieldElement}
34153415
n = nrows(U)
34163416
m = ncols(b)
34173417
R = base_ring(U)
34183418
X = zero(b)
34193419
Tinv = Vector{elem_type(R)}(undef, n)
34203420
tmp = Vector{elem_type(R)}(undef, n)
3421-
if unit == false
3421+
if unipotent == false
34223422
for i = 1:n
34233423
Tinv[i] = inv(U[i, i])
34243424
end
@@ -3435,7 +3435,7 @@ function _solve_triu(U::MatElem{T}, b::MatElem{T}, unit::Bool = false) where {T
34353435
end
34363436
s = reduce!(s)
34373437
s = b[j, i] - s
3438-
if unit == false
3438+
if unipotent == false
34393439
s = mul!(s, s, Tinv[j])
34403440
end
34413441
tmp[j] = s
@@ -3447,6 +3447,89 @@ function _solve_triu(U::MatElem{T}, b::MatElem{T}, unit::Bool = false) where {T
34473447
return X
34483448
end
34493449

3450+
@doc raw"""
    _solve_triu(U::MatElem{T}, b::MatElem{T}; side::Symbol = :left) where {T <: RingElement}

Let $U$ be a non-singular $n\times n$ upper triangular matrix over a ring.

If `side = :right`, $b$ has to be an $n\times m$ matrix over the same ring;
an $n\times m$ matrix $x$ such that $Ux = b$ is returned.

If `side = :left`, the default, $b$ has to be $m \times n$. In this case
$xU = b$ is solved.

In both cases an error is raised if this is not possible.

See also [`AbstractAlgebra._solve_triu_left`](@ref) and [`Strassen`](@ref) for
asymptotically fast versions.
"""
function _solve_triu(U::MatElem{T}, b::MatElem{T}; side::Symbol = :left) where {T <: RingElement}
  if side == :left
    return _solve_triu_left(U, b)
  end
  @assert side == :right
  n = nrows(U)
  m = ncols(b)
  R = base_ring(U)
  X = zero(b)
  # tmp holds the column of the solution currently being computed. It needs
  # no initialisation: the back substitution below writes tmp[j] for
  # j = n, n-1, ..., 1 and only ever reads tmp[k] for k > j.
  tmp = Vector{elem_type(R)}(undef, n)
  t = R()   # scratch element handed to addmul!
  for i = 1:m
    for j = n:-1:1
      # s = sum of U[j, k] * tmp[k] over k > j
      s = R(0)
      for k = j + 1:n
        s = addmul!(s, U[j, k], tmp[k], t)
      end
      s = b[j, i] - s
      tmp[j] = divexact(s, U[j,j])
    end
    for j = 1:n
      X[j, i] = tmp[j]
    end
  end
  return X
end
3495+
3496+
@doc raw"""
    _solve_triu_left(U::MatElem{T}, b::MatElem{T}) where {T <: RingElement}

Given a non-singular $n\times n$ upper triangular matrix $U$ over a ring
and an $m\times n$ matrix $b$ over the same ring, return an
$m\times n$ matrix $x$ such that $xU = b$. If this is not possible, an error
will be raised.

See also [`_solve_triu`](@ref) and [`Strassen`](@ref) for asymptotically fast
versions.
"""
function _solve_triu_left(U::MatElem{T}, b::MatElem{T}) where {T <: RingElement}
  n = ncols(U)
  m = nrows(b)
  R = base_ring(U)
  X = zero(b)
  # tmp holds the row of the solution currently being computed. It needs no
  # initialisation: the forward substitution below writes tmp[j] for
  # j = 1, ..., n and only ever reads tmp[k] for k < j.
  tmp = Vector{elem_type(R)}(undef, n)
  t = R()   # scratch element handed to addmul!
  for i = 1:m
    for j = 1:n
      # s = sum of U[k, j] * tmp[k] over k < j
      s = R()
      for k = 1:j-1
        s = addmul!(s, U[k, j], tmp[k], t)
      end
      s = b[i, j] - s
      tmp[j] = divexact(s, U[j,j])
    end
    for j = 1:n
      X[i, j] = tmp[j]
    end
  end
  return X
end
3532+
34503533
#solves A x = B for A intended to be lower triangular
34513534
#only the lower part is used. if f is true, then the diagonal is assumed to be 1
34523535
#used to use lu!

test/Solve-test.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,3 +259,15 @@ end
259259
@test ncols(S) == 3
260260
@test base_ring(S) == QQ
261261
end
262+
263+
@testset "solve_triu" begin
  # n x n upper triangular integer matrix with entry i + j - 1 at position
  # (i, j) for i <= j; the diagonal entries 2i - 1 are all non-zero.
  upper_triangular(n) = matrix(ZZ, n, n, [i <= j ? i + j - 1 : 0 for i in 1:n for j in 1:n])

  # Generic solver: recover x from A*x (side = :right) and from x*A (side = :left).
  A = upper_triangular(10)
  x = matrix(ZZ, rand(-10:10, 10, 10))
  @test AbstractAlgebra._solve_triu(A, A*x; side = :right) == x
  @test AbstractAlgebra._solve_triu(A, x*A; side = :left) == x

  # Strassen variant; the cutoff is chosen below the matrix size.
  A = upper_triangular(20)
  x = matrix(ZZ, rand(-10:10, 20, 20))
  @test AbstractAlgebra.Strassen._solve_triu(A, A*x; cutoff = 10, side = :right) == x
  @test AbstractAlgebra.Strassen._solve_triu(A, x*A; cutoff = 10, side = :left) == x
end

test/generic/Matrix-test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2658,7 +2658,7 @@ end
26582658
M = randmat_triu(S, -100:100)
26592659
b = rand(U, -100:100)
26602660

2661-
x = AbstractAlgebra._solve_triu(M, b, false)
2661+
x = AbstractAlgebra._solve_triu_right(M, b; unipotent = false)
26622662

26632663
@test M*x == b
26642664
end

0 commit comments

Comments
 (0)