From b5e2ce26e80081e723480d802d41163b76c4de95 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 21 May 2025 16:51:47 -0400 Subject: [PATCH 01/25] make solve algorithms use __solve --- lib/BracketingNonlinearSolve/src/alefeld.jl | 2 +- lib/BracketingNonlinearSolve/src/bisection.jl | 2 +- lib/BracketingNonlinearSolve/src/brent.jl | 2 +- lib/BracketingNonlinearSolve/src/falsi.jl | 2 +- lib/BracketingNonlinearSolve/src/itp.jl | 2 +- lib/BracketingNonlinearSolve/src/muller.jl | 2 +- lib/BracketingNonlinearSolve/src/ridder.jl | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/BracketingNonlinearSolve/src/alefeld.jl b/lib/BracketingNonlinearSolve/src/alefeld.jl index 6880f8c95..86807986a 100644 --- a/lib/BracketingNonlinearSolve/src/alefeld.jl +++ b/lib/BracketingNonlinearSolve/src/alefeld.jl @@ -8,7 +8,7 @@ algorithm 4.1 because, in certain sense, the second algorithm(4.2) is an optimal """ struct Alefeld <: AbstractBracketingAlgorithm end -function CommonSolve.solve( +function SciMLBase.__solve( prob::IntervalNonlinearProblem, alg::Alefeld, args...; maxiters = 1000, abstol = nothing, kwargs... ) diff --git a/lib/BracketingNonlinearSolve/src/bisection.jl b/lib/BracketingNonlinearSolve/src/bisection.jl index 91c17a775..5f056abbc 100644 --- a/lib/BracketingNonlinearSolve/src/bisection.jl +++ b/lib/BracketingNonlinearSolve/src/bisection.jl @@ -19,7 +19,7 @@ A common bisection method. exact_right::Bool = false end -function CommonSolve.solve( +function SciMLBase.__solve( prob::IntervalNonlinearProblem, alg::Bisection, args...; maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs... ) diff --git a/lib/BracketingNonlinearSolve/src/brent.jl b/lib/BracketingNonlinearSolve/src/brent.jl index 7baebc90c..6199bf29a 100644 --- a/lib/BracketingNonlinearSolve/src/brent.jl +++ b/lib/BracketingNonlinearSolve/src/brent.jl @@ -5,7 +5,7 @@ Left non-allocating Brent method. """ struct Brent <: AbstractBracketingAlgorithm end -function CommonSolve.solve( +function SciMLBase.__solve( prob::IntervalNonlinearProblem, alg::Brent, args...; maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs... ) diff --git a/lib/BracketingNonlinearSolve/src/falsi.jl b/lib/BracketingNonlinearSolve/src/falsi.jl index 3074a5eb4..a2bdbde1f 100644 --- a/lib/BracketingNonlinearSolve/src/falsi.jl +++ b/lib/BracketingNonlinearSolve/src/falsi.jl @@ -5,7 +5,7 @@ A non-allocating regula falsi method. """ struct Falsi <: AbstractBracketingAlgorithm end -function CommonSolve.solve( +function SciMLBase.__solve( prob::IntervalNonlinearProblem, alg::Falsi, args...; maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs... ) diff --git a/lib/BracketingNonlinearSolve/src/itp.jl b/lib/BracketingNonlinearSolve/src/itp.jl index cbf5818bf..c733dc25f 100644 --- a/lib/BracketingNonlinearSolve/src/itp.jl +++ b/lib/BracketingNonlinearSolve/src/itp.jl @@ -56,7 +56,7 @@ function ITP(; scaled_k1::Real = 0.2, k2::Real = 2, n0::Int = 10) return ITP(scaled_k1, k2, n0) end -function CommonSolve.solve( +function SciMLBase.__solve( prob::IntervalNonlinearProblem, alg::ITP, args...; maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs... ) diff --git a/lib/BracketingNonlinearSolve/src/muller.jl b/lib/BracketingNonlinearSolve/src/muller.jl index 7b89236a0..1e321969b 100644 --- a/lib/BracketingNonlinearSolve/src/muller.jl +++ b/lib/BracketingNonlinearSolve/src/muller.jl @@ -27,7 +27,7 @@ end Muller() = Muller(nothing) -function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Muller, args...; +function SciMLBase.__solve(prob::IntervalNonlinearProblem, alg::Muller, args...; abstol = nothing, maxiters = 1000, kwargs...) @assert !SciMLBase.isinplace(prob) "`Muller` only supports out-of-place problems." xᵢ₋₂, xᵢ = prob.tspan diff --git a/lib/BracketingNonlinearSolve/src/ridder.jl b/lib/BracketingNonlinearSolve/src/ridder.jl index 9192897c5..9e38d25b6 100644 --- a/lib/BracketingNonlinearSolve/src/ridder.jl +++ b/lib/BracketingNonlinearSolve/src/ridder.jl @@ -5,7 +5,7 @@ A non-allocating ridder method. """ struct Ridder <: AbstractBracketingAlgorithm end -function CommonSolve.solve( +function SciMLBase.__solve( prob::IntervalNonlinearProblem, alg::Ridder, args...; maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs... ) From 5205bdf91cfdb478313c44401b4aebe698b3f392 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 21 May 2025 16:52:17 -0400 Subject: [PATCH 02/25] add bracketingnonlinear_solve_up --- .../src/BracketingNonlinearSolve.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl b/lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl index 64db84621..ce99cc982 100644 --- a/lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl +++ b/lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl @@ -24,10 +24,21 @@ include("ridder.jl") function CommonSolve.solve(prob::IntervalNonlinearProblem; kwargs...) return CommonSolve.solve(prob, ITP(); kwargs...) end + function CommonSolve.solve(prob::IntervalNonlinearProblem, nothing, args...; kwargs...) return CommonSolve.solve(prob, ITP(), args...; kwargs...) end +function CommonSolve.solve(prob::IntervalNonlinearProblem, + alg::AbstractBracketingAlgorithm, args...; sensealg = nothing, kwargs...) + return bracketingnonlinear_solve_up(prob::IntervalNonlinearProblem, sensealg, prob.p, alg, args...; kwargs...) +end + + +function bracketingnonlinear_solve_up(prob::IntervalNonlinearProblem, sensealg, p, alg, args...; kwargs...) + return SciMLBase.__solve(prob, alg, args...; kwargs...) +end + @setup_workload begin for T in (Float32, Float64) prob_brack = IntervalNonlinearProblem{false}( From ec41473236eea1ae54f1ccc5968e60a12644ebbf Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 21 May 2025 17:19:42 -0400 Subject: [PATCH 03/25] add extensions --- lib/BracketingNonlinearSolve/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/BracketingNonlinearSolve/Project.toml b/lib/BracketingNonlinearSolve/Project.toml index 2c1526b07..db5284791 100644 --- a/lib/BracketingNonlinearSolve/Project.toml +++ b/lib/BracketingNonlinearSolve/Project.toml @@ -19,6 +19,8 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" [extensions] BracketingNonlinearSolveForwardDiffExt = "ForwardDiff" +BracketingNonlinearSolveChainRulesCoreExt = "ChainRulesCore" +BracketingNonlinearSolveDiffEqBaseExt = "DiffEqBase" [compat] Aqua = "0.8.9" From ef321121604ce1d3cbe001bbdac246a8b46a9d5e Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 21 May 2025 17:20:29 -0400 Subject: [PATCH 04/25] add Bracketing ChainRulesCoreExt --- ...acketingNonlinearSolveChainRulesCoreExt.jl | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl diff --git a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl new file mode 100644 index 000000000..cda2cece0 --- /dev/null +++ b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl @@ -0,0 +1,31 @@ +module BracketingNonlinearSolveChainRulesCoreExt + +using CommonSolve: CommonSolve +using ForwardDiff: ForwardDiff +using DiffEqBase + +using BracketingNonlinearSolve: bracketingnonlinear_solve_up, is_extension_loaded + +function ChainRulesCore.rrule( + ::typeof(bracketingnonlinear_solve_up), + prob::IntervalNonlinearProblem, + sensealg, p, alg, args...; kwargs... +) + # DiffEqBase is needed for problem/function constructor adjoint + !is_extension_loaded(Val(:DiffEqBase)) && + error("Adjoint sensitivity analysis requires `DiffEqBase.jl` to be explicitly loaded.") + out = solve(prob) + u = out.u + f = SciMLBase.unwrapped_f(prob.f) + function ∇bracketingnonlinear_solve_up(Δ) + # Δ = dg/du + λ = only(ForwardDiff.derivative(u -> f(u, p), only(u)) \ Δ.u) + dgdp = -λ * ForwardDiff.derivative(p -> f(u, p), only(p)) + return (NoTangent(), NoTangent(), NoTangent(), + dgdp, NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + end + return out, ∇bracketingnonlinear_solve_up +end + +end \ No newline at end of file From 1610a516e2c5aceb18a06e06ab1014af39b3bae9 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 21 May 2025 17:20:57 -0400 Subject: [PATCH 05/25] better error message to make sure problem constructor adjoints exist --- .../ext/BracketingNonlinearSolveDiffEqBaseExt.jl | 5 +++++ lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl | 2 ++ 2 files changed, 7 insertions(+) create mode 100644 lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveDiffEqBaseExt.jl diff --git a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveDiffEqBaseExt.jl b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveDiffEqBaseExt.jl new file mode 100644 index 000000000..da5616e75 --- /dev/null +++ b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveDiffEqBaseExt.jl @@ -0,0 +1,5 @@ +module BracketingNonlinearSolveDiffEqBaseExt + +BracketingNonlinearSolve.is_extension_loaded(::Val{:DiffEqBase}) = true + +end diff --git a/lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl b/lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl index ce99cc982..f291368c1 100644 --- a/lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl +++ b/lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl @@ -39,6 +39,8 @@ function bracketingnonlinear_solve_up(prob::IntervalNonlinearProblem, sensealg, return SciMLBase.__solve(prob, alg, args...; kwargs...) end +is_extension_loaded(::Val) = false + @setup_workload begin for T in (Float32, Float64) prob_brack = IntervalNonlinearProblem{false}( From 845ec7fc409971e74f8c88be012772e459c225a0 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 21 May 2025 17:26:46 -0400 Subject: [PATCH 06/25] add weakdeps --- lib/BracketingNonlinearSolve/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/BracketingNonlinearSolve/Project.toml b/lib/BracketingNonlinearSolve/Project.toml index db5284791..387743275 100644 --- a/lib/BracketingNonlinearSolve/Project.toml +++ b/lib/BracketingNonlinearSolve/Project.toml @@ -16,6 +16,8 @@ NonlinearSolveBase = {path = "../NonlinearSolveBase"} [weakdeps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" [extensions] BracketingNonlinearSolveForwardDiffExt = "ForwardDiff" From b8eb03f766cb725007b68b8a5656211f6cab03fb Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 21 May 2025 17:42:23 -0400 Subject: [PATCH 07/25] add test --- .../test/adjoint_tests.jl | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 lib/BracketingNonlinearSolve/test/adjoint_tests.jl diff --git a/lib/BracketingNonlinearSolve/test/adjoint_tests.jl b/lib/BracketingNonlinearSolve/test/adjoint_tests.jl new file mode 100644 index 000000000..dc05776c6 --- /dev/null +++ b/lib/BracketingNonlinearSolve/test/adjoint_tests.jl @@ -0,0 +1,18 @@ +@testitem "Simple Adjoint Test" tags=[:adjoint] begin + using ForwardDiff, Zygote, DiffEqBase + + ff(u, p) = u^2 .- p[1] + + function solve_nlprob(p) + prob = IntervalNonlinearProblem{false}(ff, [1.0, 2.0], p) + sol = solve(prob, Broyden()) + res = sol isa AbstractArray ? sol : sol.u + return sum(abs2, res) + end + + p = [3.0, 2.0] + + ∂p_zygote = only(Zygote.gradient(solve_nlprob, p)) + ∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p) + @test ∂p_zygote ≈ ∂p_forwarddiff +end From a878e8e0fea5943fca5e1403fb4a211b415865df Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 21 May 2025 17:42:46 -0400 Subject: [PATCH 08/25] fix test --- lib/BracketingNonlinearSolve/test/adjoint_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/BracketingNonlinearSolve/test/adjoint_tests.jl b/lib/BracketingNonlinearSolve/test/adjoint_tests.jl index dc05776c6..b7a751e23 100644 --- a/lib/BracketingNonlinearSolve/test/adjoint_tests.jl +++ b/lib/BracketingNonlinearSolve/test/adjoint_tests.jl @@ -4,7 +4,7 @@ ff(u, p) = u^2 .- p[1] function solve_nlprob(p) - prob = IntervalNonlinearProblem{false}(ff, [1.0, 2.0], p) + prob = IntervalNonlinearProblem{false}(ff, (1.0, 3.0), p) sol = solve(prob, Broyden()) res = sol isa AbstractArray ? sol : sol.u return sum(abs2, res) From db469c130c44174f9472fd08f86de00e238c47de Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 21 May 2025 21:49:14 -0400 Subject: [PATCH 09/25] use SciMLBase instead --- .../ext/BracketingNonlinearSolveChainRulesCoreExt.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl index cda2cece0..afb9d4888 100644 --- a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl +++ b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl @@ -1,8 +1,8 @@ module BracketingNonlinearSolveChainRulesCoreExt using CommonSolve: CommonSolve -using ForwardDiff: ForwardDiff -using DiffEqBase +using ForwardDiff +using SciMLBase using BracketingNonlinearSolve: bracketingnonlinear_solve_up, is_extension_loaded @@ -12,8 +12,6 @@ function ChainRulesCore.rrule( sensealg, p, alg, args...; kwargs... ) # DiffEqBase is needed for problem/function constructor adjoint - !is_extension_loaded(Val(:DiffEqBase)) && - error("Adjoint sensitivity analysis requires `DiffEqBase.jl` to be explicitly loaded.") out = solve(prob) u = out.u f = SciMLBase.unwrapped_f(prob.f) From d52842f73202d24aa9771e731fab88db4440c0cc Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 21 May 2025 22:06:19 -0400 Subject: [PATCH 10/25] use gradient, p might not be scalar --- .../ext/BracketingNonlinearSolveChainRulesCoreExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl index afb9d4888..171ae3f78 100644 --- a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl +++ b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl @@ -18,7 +18,7 @@ function ChainRulesCore.rrule( function ∇bracketingnonlinear_solve_up(Δ) # Δ = dg/du λ = only(ForwardDiff.derivative(u -> f(u, p), only(u)) \ Δ.u) - dgdp = -λ * ForwardDiff.derivative(p -> f(u, p), only(p)) + dgdp = -λ * ForwardDiff.gradient(p -> f(u, p), p) return (NoTangent(), NoTangent(), NoTangent(), dgdp, NoTangent(), ntuple(_ -> NoTangent(), length(args))...) From ff43257c5baea4544393e42fef971b34dc90b3e0 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 21 May 2025 22:06:42 -0400 Subject: [PATCH 11/25] add zygote as trigger for chainrulescore extension --- lib/BracketingNonlinearSolve/Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/BracketingNonlinearSolve/Project.toml b/lib/BracketingNonlinearSolve/Project.toml index 387743275..371f3c496 100644 --- a/lib/BracketingNonlinearSolve/Project.toml +++ b/lib/BracketingNonlinearSolve/Project.toml @@ -18,10 +18,11 @@ NonlinearSolveBase = {path = "../NonlinearSolveBase"} ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] BracketingNonlinearSolveForwardDiffExt = "ForwardDiff" -BracketingNonlinearSolveChainRulesCoreExt = "ChainRulesCore" +BracketingNonlinearSolveChainRulesCoreExt = ["ChainRulesCore", "Zygote"] BracketingNonlinearSolveDiffEqBaseExt = "DiffEqBase" [compat] From 634b1c3fc13b2cc606b653dd5a7695a463807c27 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 21 May 2025 22:12:16 -0400 Subject: [PATCH 12/25] account for both derivative and gradient --- .../ext/BracketingNonlinearSolveChainRulesCoreExt.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl index 171ae3f78..df2534e53 100644 --- a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl +++ b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl @@ -18,7 +18,11 @@ function ChainRulesCore.rrule( function ∇bracketingnonlinear_solve_up(Δ) # Δ = dg/du λ = only(ForwardDiff.derivative(u -> f(u, p), only(u)) \ Δ.u) - dgdp = -λ * ForwardDiff.gradient(p -> f(u, p), p) + if p isa Number + dgdp = -λ * ForwardDiff.derivative(p -> f(u, p), p) + else + dgdp = -λ * ForwardDiff.gradient(p -> f(u, p), p) + end return (NoTangent(), NoTangent(), NoTangent(), dgdp, NoTangent(), ntuple(_ -> NoTangent(), length(args))...) From f593cc4420d7cff04a9eda1c0586177e3eae4b4b Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 21 May 2025 22:13:20 -0400 Subject: [PATCH 13/25] old docstring --- .../ext/BracketingNonlinearSolveChainRulesCoreExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl index df2534e53..f22bce621 100644 --- a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl +++ b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl @@ -11,7 +11,6 @@ function ChainRulesCore.rrule( prob::IntervalNonlinearProblem, sensealg, p, alg, args...; kwargs... ) - # DiffEqBase is needed for problem/function constructor adjoint out = solve(prob) u = out.u f = SciMLBase.unwrapped_f(prob.f) From 53cd09a1f46be629b6160283124b3c281593329c Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 21 May 2025 22:41:40 -0400 Subject: [PATCH 14/25] add ForwardDiff trigger, more using --- lib/BracketingNonlinearSolve/Project.toml | 2 +- .../ext/BracketingNonlinearSolveChainRulesCoreExt.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/BracketingNonlinearSolve/Project.toml b/lib/BracketingNonlinearSolve/Project.toml index 371f3c496..019fec869 100644 --- a/lib/BracketingNonlinearSolve/Project.toml +++ b/lib/BracketingNonlinearSolve/Project.toml @@ -22,7 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] BracketingNonlinearSolveForwardDiffExt = "ForwardDiff" -BracketingNonlinearSolveChainRulesCoreExt = ["ChainRulesCore", "Zygote"] +BracketingNonlinearSolveChainRulesCoreExt = ["ChainRulesCore", "Zygote", "ForwardDiff"] BracketingNonlinearSolveDiffEqBaseExt = "DiffEqBase" [compat] diff --git a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl index f22bce621..acc765e41 100644 --- a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl +++ b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl @@ -1,8 +1,9 @@ module BracketingNonlinearSolveChainRulesCoreExt using CommonSolve: CommonSolve -using ForwardDiff +using ForwardDiff: ForwardDiff using SciMLBase +using ChainRulesCore using BracketingNonlinearSolve: bracketingnonlinear_solve_up, is_extension_loaded From bac45ad89567b607c58646ae7b6b77fc4eda6c6d Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 21 May 2025 22:44:56 -0400 Subject: [PATCH 15/25] get rid of unnecessary Zygote --- lib/BracketingNonlinearSolve/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/BracketingNonlinearSolve/Project.toml b/lib/BracketingNonlinearSolve/Project.toml index 019fec869..09fe636e4 100644 --- a/lib/BracketingNonlinearSolve/Project.toml +++ b/lib/BracketingNonlinearSolve/Project.toml @@ -22,7 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] BracketingNonlinearSolveForwardDiffExt = "ForwardDiff" -BracketingNonlinearSolveChainRulesCoreExt = ["ChainRulesCore", "Zygote", "ForwardDiff"] +BracketingNonlinearSolveChainRulesCoreExt = ["ChainRulesCore", "ForwardDiff"] BracketingNonlinearSolveDiffEqBaseExt = "DiffEqBase" [compat] From 77b9e5ae6389a79ffe2b7b271790258771dc724d Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 22 May 2025 09:33:29 -0400 Subject: [PATCH 16/25] fix adjoint test --- lib/BracketingNonlinearSolve/test/adjoint_tests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/BracketingNonlinearSolve/test/adjoint_tests.jl b/lib/BracketingNonlinearSolve/test/adjoint_tests.jl index b7a751e23..3b29de330 100644 --- a/lib/BracketingNonlinearSolve/test/adjoint_tests.jl +++ b/lib/BracketingNonlinearSolve/test/adjoint_tests.jl @@ -10,9 +10,9 @@ return sum(abs2, res) end - p = [3.0, 2.0] + p = [2.0, 2.0] - ∂p_zygote = only(Zygote.gradient(solve_nlprob, p)) + ∂p_zygote = Zygote.gradient(solve_nlprob, p) ∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p) @test ∂p_zygote ≈ ∂p_forwarddiff end From 0ed1191266016810fddaa9f26238150f1fc65a6f Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 22 May 2025 10:53:00 -0400 Subject: [PATCH 17/25] don't need diffeqbase ext stuff --- lib/BracketingNonlinearSolve/Project.toml | 1 - .../ext/BracketingNonlinearSolveChainRulesCoreExt.jl | 2 +- .../ext/BracketingNonlinearSolveDiffEqBaseExt.jl | 5 ----- lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl | 1 - 4 files changed, 1 insertion(+), 8 deletions(-) delete mode 100644 lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveDiffEqBaseExt.jl diff --git a/lib/BracketingNonlinearSolve/Project.toml b/lib/BracketingNonlinearSolve/Project.toml index 09fe636e4..db493dce2 100644 --- a/lib/BracketingNonlinearSolve/Project.toml +++ b/lib/BracketingNonlinearSolve/Project.toml @@ -17,7 +17,6 @@ NonlinearSolveBase = {path = "../NonlinearSolveBase"} [weakdeps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] diff --git a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl index acc765e41..a336384a5 100644 --- a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl +++ b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl @@ -5,7 +5,7 @@ using ForwardDiff: ForwardDiff using SciMLBase using ChainRulesCore -using BracketingNonlinearSolve: bracketingnonlinear_solve_up, is_extension_loaded +using BracketingNonlinearSolve: bracketingnonlinear_solve_up function ChainRulesCore.rrule( ::typeof(bracketingnonlinear_solve_up), diff --git a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveDiffEqBaseExt.jl b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveDiffEqBaseExt.jl deleted file mode 100644 index da5616e75..000000000 --- a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveDiffEqBaseExt.jl +++ /dev/null @@ -1,5 +0,0 @@ -module BracketingNonlinearSolveDiffEqBaseExt - -BracketingNonlinearSolve.is_extension_loaded(::Val{:DiffEqBase}) = true - -end diff --git a/lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl b/lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl index f291368c1..9337ac6fe 100644 --- a/lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl +++ b/lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl @@ -39,7 +39,6 @@ function bracketingnonlinear_solve_up(prob::IntervalNonlinearProblem, sensealg, return SciMLBase.__solve(prob, alg, args...; kwargs...) end -is_extension_loaded(::Val) = false @setup_workload begin for T in (Float32, Float64) From 7d48d7d35c5cc37a28b0c5621dc4f43fa1b1e844 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 22 May 2025 10:53:15 -0400 Subject: [PATCH 18/25] load bracketing nonlinear solve in test --- lib/BracketingNonlinearSolve/test/adjoint_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/BracketingNonlinearSolve/test/adjoint_tests.jl b/lib/BracketingNonlinearSolve/test/adjoint_tests.jl index 3b29de330..14ac46845 100644 --- a/lib/BracketingNonlinearSolve/test/adjoint_tests.jl +++ b/lib/BracketingNonlinearSolve/test/adjoint_tests.jl @@ -1,5 +1,5 @@ @testitem "Simple Adjoint Test" tags=[:adjoint] begin - using ForwardDiff, Zygote, DiffEqBase + using ForwardDiff, Zygote, DiffEqBase, BracketingNonlinearSolve ff(u, p) = u^2 .- p[1] From f458c9648c2d6afaf330ef1514e65e89f593b516 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 22 May 2025 10:54:30 -0400 Subject: [PATCH 19/25] fix project.toml --- lib/BracketingNonlinearSolve/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/BracketingNonlinearSolve/Project.toml b/lib/BracketingNonlinearSolve/Project.toml index db493dce2..6b73bff37 100644 --- a/lib/BracketingNonlinearSolve/Project.toml +++ b/lib/BracketingNonlinearSolve/Project.toml @@ -22,7 +22,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] BracketingNonlinearSolveForwardDiffExt = "ForwardDiff" BracketingNonlinearSolveChainRulesCoreExt = ["ChainRulesCore", "ForwardDiff"] -BracketingNonlinearSolveDiffEqBaseExt = "DiffEqBase" [compat] Aqua = "0.8.9" From 50ce8605d1e1cca1ec98d37b35c950f394c76f61 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 22 May 2025 10:57:28 -0400 Subject: [PATCH 20/25] add Zygote to test deps --- lib/BracketingNonlinearSolve/Project.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/BracketingNonlinearSolve/Project.toml b/lib/BracketingNonlinearSolve/Project.toml index 6b73bff37..4f89781f4 100644 --- a/lib/BracketingNonlinearSolve/Project.toml +++ b/lib/BracketingNonlinearSolve/Project.toml @@ -17,7 +17,7 @@ NonlinearSolveBase = {path = "../NonlinearSolveBase"} [weakdeps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + [extensions] BracketingNonlinearSolveForwardDiffExt = "ForwardDiff" @@ -45,6 +45,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "Test", "TestItemRunner"] +test = ["Aqua", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "Test", "TestItemRunner", "Zygote"] From d70f9b3894dd067483e41f6f601d1ca3c24eba73 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 22 May 2025 11:24:58 -0400 Subject: [PATCH 21/25] test should use Bisection --- lib/BracketingNonlinearSolve/test/adjoint_tests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/BracketingNonlinearSolve/test/adjoint_tests.jl b/lib/BracketingNonlinearSolve/test/adjoint_tests.jl index 14ac46845..fb2e67794 100644 --- a/lib/BracketingNonlinearSolve/test/adjoint_tests.jl +++ b/lib/BracketingNonlinearSolve/test/adjoint_tests.jl @@ -1,11 +1,11 @@ @testitem "Simple Adjoint Test" tags=[:adjoint] begin - using ForwardDiff, Zygote, DiffEqBase, BracketingNonlinearSolve + using ForwardDiff, Zygote, BracketingNonlinearSolve ff(u, p) = u^2 .- p[1] function solve_nlprob(p) prob = IntervalNonlinearProblem{false}(ff, (1.0, 3.0), p) - sol = solve(prob, Broyden()) + sol = solve(prob, Bisection()) res = sol isa AbstractArray ? sol : sol.u return sum(abs2, res) end From 7f3db450bc6324cfe462f77544a5e70d8fd7daa7 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 22 May 2025 11:25:28 -0400 Subject: [PATCH 22/25] account for Thunks, non tangent types --- .../ext/BracketingNonlinearSolveChainRulesCoreExt.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl index a336384a5..5f9b2e7dc 100644 --- a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl +++ b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl @@ -12,12 +12,14 @@ function ChainRulesCore.rrule( prob::IntervalNonlinearProblem, sensealg, p, alg, args...; kwargs... ) - out = solve(prob) + out = solve(prob, alg) u = out.u f = SciMLBase.unwrapped_f(prob.f) function ∇bracketingnonlinear_solve_up(Δ) + Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ # Δ = dg/du - λ = only(ForwardDiff.derivative(u -> f(u, p), only(u)) \ Δ.u) + Δ isa Tangent ? delu = Δ.u : delu = Δ + λ = only(ForwardDiff.derivative(u -> f(u, p), only(u)) \ delu) if p isa Number dgdp = -λ * ForwardDiff.derivative(p -> f(u, p), p) else From 671d23acc7c2af0c166665749f71c836b5875c5a Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 22 May 2025 11:44:51 -0400 Subject: [PATCH 23/25] fix test --- lib/BracketingNonlinearSolve/test/adjoint_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/BracketingNonlinearSolve/test/adjoint_tests.jl b/lib/BracketingNonlinearSolve/test/adjoint_tests.jl index fb2e67794..50c2dce2a 100644 --- a/lib/BracketingNonlinearSolve/test/adjoint_tests.jl +++ b/lib/BracketingNonlinearSolve/test/adjoint_tests.jl @@ -12,7 +12,7 @@ p = [2.0, 2.0] - ∂p_zygote = Zygote.gradient(solve_nlprob, p) + ∂p_zygote = only(Zygote.gradient(solve_nlprob, p)) ∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p) @test ∂p_zygote ≈ ∂p_forwarddiff end From 789e04b910c48f3ae15561ba01a6da79576459c9 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 22 May 2025 11:45:13 -0400 Subject: [PATCH 24/25] make imports explicit, add ompat bounds --- lib/BracketingNonlinearSolve/Project.toml | 2 ++ .../ext/BracketingNonlinearSolveChainRulesCoreExt.jl | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/lib/BracketingNonlinearSolve/Project.toml b/lib/BracketingNonlinearSolve/Project.toml index 4f89781f4..176c78452 100644 --- a/lib/BracketingNonlinearSolve/Project.toml +++ b/lib/BracketingNonlinearSolve/Project.toml @@ -25,6 +25,7 @@ BracketingNonlinearSolveChainRulesCoreExt = ["ChainRulesCore", "ForwardDiff"] [compat] Aqua = "0.8.9" +ChainRulesCore = "1.24" CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" ExplicitImports = "1.10.1" @@ -37,6 +38,7 @@ SciMLBase = "2.69" Test = "1.10" TestItemRunner = "1" julia = "1.10" +Zygote = "0.6.69, 0.7" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" diff --git a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl index 5f9b2e7dc..6bf5c1ab1 100644 --- a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl +++ b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl @@ -1,9 +1,9 @@ module BracketingNonlinearSolveChainRulesCoreExt -using CommonSolve: CommonSolve +using CommonSolve: CommonSolve, solve using ForwardDiff: ForwardDiff -using SciMLBase -using ChainRulesCore +using SciMLBase: SciMLBase, IntervalNonlinearProblem +using ChainRulesCore: ChainRulesCore, AbstractThunk, NoTangent, Tangent, unthunk using BracketingNonlinearSolve: bracketingnonlinear_solve_up From d8b82aff6a2515bcbb45ac1fe3416a13ff5c41d5 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 22 May 2025 20:18:01 +0000 Subject: [PATCH 25/25] Update lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl --- .../ext/BracketingNonlinearSolveChainRulesCoreExt.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl index 6bf5c1ab1..388517367 100644 --- a/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl +++ b/lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl @@ -1,12 +1,11 @@ module BracketingNonlinearSolveChainRulesCoreExt -using CommonSolve: CommonSolve, solve -using ForwardDiff: ForwardDiff -using SciMLBase: SciMLBase, IntervalNonlinearProblem +using BracketingNonlinearSolve: bracketingnonlinear_solve_up, CommonSolve, SciMLBase +using CommonSolve: solve +using SciMLBase: IntervalNonlinearProblem +using ForwardDiff using ChainRulesCore: ChainRulesCore, AbstractThunk, NoTangent, Tangent, unthunk -using BracketingNonlinearSolve: bracketingnonlinear_solve_up - function ChainRulesCore.rrule( ::typeof(bracketingnonlinear_solve_up), prob::IntervalNonlinearProblem,