4040
4141# # Linear algebra ##
4242
43- LinearAlgebra. UpperTriangular (A:: TrackedMatrix ) = track (UpperTriangular, A)
44- @grad function LinearAlgebra. UpperTriangular (A:: AbstractMatrix )
45- return UpperTriangular (data (A)), Δ-> (UpperTriangular (Δ),)
43+ # Work around https://github.com/FluxML/Tracker.jl/pull/9#issuecomment-480051767
44+
45+ upper (A:: AbstractMatrix ) = UpperTriangular (A)
46+ lower (A:: AbstractMatrix ) = LowerTriangular (A)
47+ function upper (C:: Cholesky )
48+ if C. uplo == ' U'
49+ return upper (C. factors)
50+ else
51+ return copy (lower (C. factors)' )
52+ end
53+ end
54+ function lower (C:: Cholesky )
55+ if C. uplo == ' U'
56+ return copy (upper (C. factors)' )
57+ else
58+ return lower (C. factors)
59+ end
60+ end
61+
62+ LinearAlgebra. LowerTriangular (A:: TrackedMatrix ) = lower (A)
63+ lower (A:: TrackedMatrix ) = track (lower, A)
64+ @grad lower (A) = lower (Tracker. data (A)), ∇ -> (lower (∇),)
65+
66+ LinearAlgebra. UpperTriangular (A:: TrackedMatrix ) = upper (A)
67+ upper (A:: TrackedMatrix ) = track (upper, A)
68+ @grad upper (A) = upper (Tracker. data (A)), ∇ -> (upper (∇),)
69+
70+ function Base. copy (
71+ A:: TrackedArray {T, 2 , <: Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}} },
72+ ) where {T <: Real }
73+ return track (copy, A)
74+ end
75+ @grad function Base. copy (
76+ A:: TrackedArray {T, 2 , <: Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}} },
77+ ) where {T <: Real }
78+ return copy (data (A)), ∇ -> (copy (∇),)
4679end
4780
4881function LinearAlgebra. cholesky (A:: TrackedMatrix ; check= true )
@@ -57,40 +90,10 @@ function turing_chol(A::AbstractMatrix, check)
5790end
5891turing_chol (A:: TrackedMatrix , check) = track (turing_chol, A, check)
5992@grad function turing_chol (A:: AbstractMatrix , check)
60- C, back = pullback (unsafe_cholesky , data (A), data (check))
93+ C, back = pullback (_turing_chol , data (A), data (check))
6194 return (C. factors, C. info), Δ-> back ((factors= data (Δ[1 ]),))
6295end
63-
64- unsafe_cholesky (x, check) = cholesky (x, check= check)
65- @adjoint function unsafe_cholesky (Σ:: Real , check)
66- C = cholesky (Σ; check= check)
67- return C, function (Δ:: NamedTuple )
68- issuccess (C) || return (zero (Σ), nothing )
69- (Δ. factors[1 , 1 ] / (2 * C. U[1 , 1 ]), nothing )
70- end
71- end
72- @adjoint function unsafe_cholesky (Σ:: Diagonal , check)
73- C = cholesky (Σ; check= check)
74- return C, function (Δ:: NamedTuple )
75- issuccess (C) || (Diagonal (zero (diag (Δ. factors))), nothing )
76- (Diagonal (diag (Δ. factors) .* inv .(2 .* C. factors. diag)), nothing )
77- end
78- end
79- @adjoint function unsafe_cholesky (Σ:: Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}} , check)
80- C = cholesky (Σ; check= check)
81- return C, function (Δ:: NamedTuple )
82- issuccess (C) || return (zero (Δ. factors), nothing )
83- U, Ū = C. U, Δ. factors
84- Σ̄ = Ū * U'
85- Σ̄ = copytri! (Σ̄, ' U' )
86- Σ̄ = ldiv! (U, Σ̄)
87- BLAS. trsm! (' R' , ' U' , ' T' , ' N' , one (eltype (Σ)), U. data, Σ̄)
88- @inbounds for n in diagind (Σ̄)
89- Σ̄[n] /= 2
90- end
91- return (UpperTriangular (Σ̄), nothing )
92- end
93- end
96+ _turing_chol (x, check) = cholesky (x, check= check)
9497
9598# Specialised logdet for cholesky to target the triangle directly.
9699logdet_chol_tri (U:: AbstractMatrix ) = 2 * sum (log, U[diagind (U)])
0 commit comments