Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ as𝕀
```@docs
UnitVector
CorrCholeskyFactor
PosDefCholeskyFactor
```

# Defining custom transformations
Expand Down
62 changes: 61 additions & 1 deletion src/special_arrays.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export UnitVector, CorrCholeskyFactor
export UnitVector, CorrCholeskyFactor, PosDefCholeskyFactor

"""
(y, r, ℓ) = $SIGNATURES
Expand Down Expand Up @@ -129,3 +129,63 @@ function inverse!(x::RealVector, t::CorrCholeskyFactor, U::UpperTriangular)
end
x
end



"""
PosDefCholeskyFactor(n)

Cholesky factor of a symmetric positive-definite matrix of size `n`.

Transforms ``n×(n+1)/2`` real numbers to an ``n×n`` upper-triangular matrix `Ω`, such that
`Ω'*Ω` is a positive definite matrix.
"""
@calltrans struct PosDefCholeskyFactor <: VectorTransform
n::Int
function PosDefCholeskyFactor(n)
@argcheck n ≥ 1 "Dimension should be positive."
new(n)
end
end

dimension(t::PosDefCholeskyFactor) = (t.n*(t.n+1)) ÷ 2

function transform_with(flag::LogJacFlag, t::PosDefCholeskyFactor, x::RealVector)
@unpack n = t
T = extended_eltype(x)
ℓ = logjac_zero(flag, T)
U = Matrix{T}(undef, n, n)
index = firstindex(x)
@inbounds for col in 1:n
for row in 1:(col-1)
U[row, col] = x[index]
index += 1
end
if flag isa NoLogJac
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could avoid code duplication here by using transform_with(flag, ...). When flag::NoLogJac, it is "summed" automatically as a no-op.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

transform_with doesn't seem to work for scalar transforms with scalar args:

julia> TransformVariables.transform_with(TransformVariables.NoLogJac(), asℝ₊, 1.0)
ERROR: MethodError: no method matching transform_with(::TransformVariables.NoLogJac, ::TransformVariables.ShiftedExp{true,Float64}, ::Float64)
Closest candidates are:
  transform_with(::TransformVariables.NoLogJac, ::TransformVariables.ScalarTransform, ::AbstractArray{T,1} where T<:Real) at /Users/simon/.julia/dev/TransformVariables/src/scalar.jl:15
  transform_with(::TransformVariables.LogJac, ::TransformVariables.ScalarTransform, ::AbstractArray{T,1} where T<:Real) at /Users/simon/.julia/dev/TransformVariables/src/scalar.jl:18
  transform_with(::TransformVariables.LogJacFlag, ::UnitVector, ::AbstractArray{T,1} where T<:Real) at /Users/simon/.julia/dev/TransformVariables/src/special_arrays.jl:45
  ...
Stacktrace:
 [1] top-level scope at none:0

U[col, col] = transform(asℝ₊, x[index])
else
U[col, col], ℓi = transform_and_logjac(asℝ₊, x[index])
ℓ += ℓi
end
index += 1
end
UpperTriangular(U), ℓ
end

inverse_eltype(t::PosDefCholeskyFactor, U::UpperTriangular) = extended_eltype(U)

function inverse!(x::RealVector, t::PosDefCholeskyFactor, U::UpperTriangular)
@unpack n = t
@argcheck size(U, 1) == n
@argcheck length(x) == dimension(t)
index = firstindex(x)
@inbounds for col in 1:n
for row in 1:(col-1)
x[index] = U[row,col]
index += 1
end
x[index] = inverse(asℝ₊, U[col, col])
index += 1
end
x
end
11 changes: 11 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,14 @@ end
@test_nowarn @inferred transform(t, x)
@test_nowarn @inferred transform_and_logjac(t, x)
end

@testset "positive definite cholesky factor" begin
t = PosDefCholeskyFactor(4)
d = dimension(t)

v = randn(d)
U = transform(t, v)
@test U <: UpperTriangular
@test size(U) = (4,4)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You probably want a == here.

@test inverse(t,U) ≈ v
end
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please

  • also test for U being positive (semi) definite (eg just check the diagonal),
  • check the log jacobian using ForwardDiff (you find examples of this in the tests),
  • test that the results are @inferred

Happy to help if necessary.