From 44090d587ad5acdb143554c7254ad63c083033cf Mon Sep 17 00:00:00 2001 From: Simon Byrne Date: Thu, 31 Jan 2019 17:31:25 -0800 Subject: [PATCH 1/3] add PosDefCholeskyFactor Fixes #6. --- docs/src/index.md | 1 + src/special_arrays.jl | 62 ++++++++++++++++++++++++++++++++++++++++++- test/runtests.jl | 11 ++++++++ 3 files changed, 73 insertions(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index 69f76d8..61c3d2c 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -98,6 +98,7 @@ as𝕀 ```@docs UnitVector CorrCholeskyFactor +PosDefCholeskyFactor ``` # Defining custom transformations diff --git a/src/special_arrays.jl b/src/special_arrays.jl index 653d2a0..8bf9b72 100644 --- a/src/special_arrays.jl +++ b/src/special_arrays.jl @@ -1,4 +1,4 @@ -export UnitVector, CorrCholeskyFactor +export UnitVector, CorrCholeskyFactor, PosDefCholeskyFactor """ (y, r, ℓ) = $SIGNATURES @@ -129,3 +129,63 @@ function inverse!(x::RealVector, t::CorrCholeskyFactor, U::UpperTriangular) end x end + + + +""" + PosDefCholeskyFactor(n) + +Cholesky factor of a symmetric positive-definite matrix of size `n`. + +Transforms ``n×(n+1)/2`` real numbers to an ``n×n`` upper-triangular matrix `Ω`, such that +`Ω'*Ω` is a positive definite matrix. +""" +@calltrans struct PosDefCholeskyFactor <: VectorTransform + n::Int + function PosDefCholeskyFactor(n) + @argcheck n ≥ 1 "Dimension should be positive." + new(n) + end +end + +dimension(t::PosDefCholeskyFactor) = (t.n*(t.n+1)) ÷ 2 + +function transform_with(flag::LogJacFlag, t::PosDefCholeskyFactor, x::RealVector) + @unpack n = t + T = extended_eltype(x) + ℓ = logjac_zero(flag, T) + U = Matrix{T}(undef, n, n) + index = firstindex(x) + @inbounds for col in 1:n + for row in 1:(col-1) + U[row, col] = x[index] + index += 1 + end + if flag isa NoLogJac + U[col, col] = transform(asℝ₊, x[index]) + else + U[col, col], ℓi = transform_and_logjac(asℝ₊, x[index]) + ℓ += ℓi + end + index += 1 + end + UpperTriangular(U), ℓ +end + +inverse_eltype(t::PosDefCholeskyFactor, U::UpperTriangular) = extended_eltype(U) + +function inverse!(x::RealVector, t::PosDefCholeskyFactor, U::UpperTriangular) + @unpack n = t + @argcheck size(U, 1) == n + @argcheck length(x) == dimension(t) + index = firstindex(x) + @inbounds for col in 1:n + for row in 1:(col-1) + x[index] = U[row,col] + index += 1 + end + x[index] = inverse(asℝ₊, U[col, col]) + index += 1 + end + x +end diff --git a/test/runtests.jl b/test/runtests.jl index ddac45b..e30e988 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -339,3 +339,14 @@ end @test_nowarn @inferred transform(t, x) @test_nowarn @inferred transform_and_logjac(t, x) end + +@testset "positive definite cholesky factor" begin + t = PosDefCholeskyFactor(4) + d = dimension(t) + + v = randn(d) + U = transform(t, v) + @test U <: UpperTriangular + @test size(U) = (4,4) + @test inverse(t,U) ≈ v +end From ace4fc6279294d3c8893b7db5fc96305dc4708fe Mon Sep 17 00:00:00 2001 From: Simon Byrne Date: Fri, 1 Feb 2019 09:36:54 -0800 Subject: [PATCH 2/3] isa instead of <: --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index e30e988..4297eaf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ import Flux, ForwardDiff, ReverseDiff using LogDensityProblems: Value, ValueGradient using TransformVariables: AbstractTransform, ScalarTransform, VectorTransform, ArrayTransform, - unit_triangular_dimension, logistic, logistic_logjac, logit + unit_triangular_dimension, logistic, logistic_logjac, l<:ogit include("test_utilities.jl") @@ -346,7 +346,7 @@ end v = randn(d) U = transform(t, v) - @test U <: UpperTriangular + @test U isa UpperTriangular @test size(U) = (4,4) @test inverse(t,U) ≈ v end From fa52a5ce1f6e59ad922a423a05d6583c4a58f16e Mon Sep 17 00:00:00 2001 From: Simon Byrne Date: Fri, 1 Feb 2019 09:37:29 -0800 Subject: [PATCH 3/3] typo --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 4297eaf..fa0c6ba 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ import Flux, ForwardDiff, ReverseDiff using LogDensityProblems: Value, ValueGradient using TransformVariables: AbstractTransform, ScalarTransform, VectorTransform, ArrayTransform, - unit_triangular_dimension, logistic, logistic_logjac, l<:ogit + unit_triangular_dimension, logistic, logistic_logjac, logit include("test_utilities.jl")