diff --git a/ext/MTKChainRulesCoreExt.jl b/ext/MTKChainRulesCoreExt.jl index f153ee77de..9cf17d203d 100644 --- a/ext/MTKChainRulesCoreExt.jl +++ b/ext/MTKChainRulesCoreExt.jl @@ -103,4 +103,6 @@ function ChainRulesCore.rrule( newbuf, pullback end +ChainRulesCore.@non_differentiable Base.getproperty(sys::MTK.AbstractSystem, x::Symbol) + end diff --git a/test/extensions/ad.jl b/test/extensions/ad.jl index 0e72b2b7b7..adaf6117c6 100644 --- a/test/extensions/ad.jl +++ b/test/extensions/ad.jl @@ -124,3 +124,13 @@ fwd, back = ChainRulesCore.rrule(remake_buffer, sys, ps, idxs, vals) nsol = solve(nprob, NewtonRaphson()) @test nsol[1] ≈ 10.0 / 1.0 + 9.81 * 1.0 / 2 # anal free fall solution is y = v0*t - g*t^2/2 -> v0 = y/t + g*t/2 end + +@testset "`sys.var` is non-differentiable" begin + @variables x(t) + @mtkbuild sys = ODESystem(D(x) ~ x, t) + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0)) + + grad = Zygote.gradient(prob) do prob + prob[sys.x] + end +end