From ddf0d12c06e9afc7777d83777f6537bdb214757c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 26 Feb 2025 22:08:02 +0530 Subject: [PATCH] feat: mark `getproperty(::AbstractSystem, ::Symbol)` as non-differentiable --- ext/MTKChainRulesCoreExt.jl | 2 ++ test/extensions/ad.jl | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/ext/MTKChainRulesCoreExt.jl b/ext/MTKChainRulesCoreExt.jl index f153ee77de..9cf17d203d 100644 --- a/ext/MTKChainRulesCoreExt.jl +++ b/ext/MTKChainRulesCoreExt.jl @@ -103,4 +103,6 @@ function ChainRulesCore.rrule( newbuf, pullback end +ChainRulesCore.@non_differentiable Base.getproperty(sys::MTK.AbstractSystem, x::Symbol) + end diff --git a/test/extensions/ad.jl b/test/extensions/ad.jl index 0e72b2b7b7..adaf6117c6 100644 --- a/test/extensions/ad.jl +++ b/test/extensions/ad.jl @@ -124,3 +124,13 @@ fwd, back = ChainRulesCore.rrule(remake_buffer, sys, ps, idxs, vals) nsol = solve(nprob, NewtonRaphson()) @test nsol[1] ≈ 10.0 / 1.0 + 9.81 * 1.0 / 2 # anal free fall solution is y = v0*t - g*t^2/2 -> v0 = y/t + g*t/2 end + +@testset "`sys.var` is non-differentiable" begin + @variables x(t) + @mtkbuild sys = ODESystem(D(x) ~ x, t) + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0)) + + grad = Zygote.gradient(prob) do prob + prob[sys.x] + end +end