Skip to content
13 changes: 13 additions & 0 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,16 @@ function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular})
end
return y, logdet_pullback
end

#####
##### Tridiagonal
#####

function rrule(::Type{Tridiagonal}, dl, d, du)
y = Tridiagonal(dl, d, du)
function Tridiagonal_pullback(ȳ)
∂y = unthunk(ȳ)
return (NoTangent(), diag(∂y, -1), diag(∂y), diag(∂y, 1))
end
return y, Tridiagonal_pullback
end
4 changes: 4 additions & 0 deletions test/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,8 @@
end
end
end

@testset "Tridiagonal" begin
test_rrule(Tridiagonal, [1.0, 4.0], [2.0, 3.0, 4.0], [5.0, 3.0])
end
end