From b501fda1b949c28f9baee2a96576e59e48e25ec1 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 20 May 2024 21:28:47 +0800 Subject: [PATCH] Support Tridiagonal in to_vec --- Project.toml | 2 +- src/to_vec.jl | 16 ++++++++++++++-- test/to_vec.jl | 2 ++ 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 4107c5a..d6c26eb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FiniteDifferences" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.12.31" +version = "0.12.32" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/to_vec.jl b/src/to_vec.jl index af991f3..2d20313 100644 --- a/src/to_vec.jl +++ b/src/to_vec.jl @@ -111,14 +111,26 @@ function to_vec(x::T) where {T<:LinearAlgebra.HermOrSym} return x_vec, HermOrSym_from_vec end -function to_vec(X::Diagonal) - x_vec, back = to_vec(Matrix(X)) +function to_vec(x::Diagonal) + x_vec, back = to_vec(Matrix(x)) function Diagonal_from_vec(x_vec) return Diagonal(back(x_vec)) end return x_vec, Diagonal_from_vec end +function to_vec(x::Tridiagonal) + x_vec, back = to_vec((x.dl, x.d, x.du)) + # Other field (du2) of a Tridiagonal is not part of its value and is really a kind of cache + function Tridiagonal_from_vec(x_vec) + return Tridiagonal(back(x_vec)...) + end + return x_vec, Tridiagonal_from_vec +end + + + + function to_vec(X::Transpose) x_vec, back = to_vec(Matrix(X)) function Transpose_from_vec(x_vec) diff --git a/test/to_vec.jl b/test/to_vec.jl index 5e5ab5f..f18e756 100644 --- a/test/to_vec.jl +++ b/test/to_vec.jl @@ -88,6 +88,8 @@ end test_to_vec(reshape([1.0, randn(T, 5, 4, 3), randn(T, 4, 3), 2.0], 2, 2); check_inferred=false) test_to_vec(UpperTriangular(randn(T, 13, 13))) test_to_vec(Diagonal(randn(T, 7))) + test_to_vec(Tridiagonal(randn(T, 3), randn(T, 4), randn(T, 3))) + test_to_vec(DummyType(randn(T, 2, 9))) test_to_vec(SVector{2, T}(1.0, 2.0); check_inferred=false) test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0); check_inferred=false)