Skip to content

Commit 0671fe3

Browse files
committed
add Mooncake to adjoint tests
1 parent e8b6aea commit 0671fe3

File tree

4 files changed

+14
-7
lines changed

4 files changed

+14
-7
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ LineSearches = "7.3"
9999
LinearAlgebra = "1.10"
100100
LinearSolve = "2.36.1, 3"
101101
MINPACK = "1.2"
102+
Mooncake = "0.4"
102103
MPI = "0.20.22"
103104
NLSolvers = "0.5"
104105
NLsolve = "4.5"
@@ -146,6 +147,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
146147
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
147148
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
148149
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
150+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
149151
NLSolvers = "337daf1e-9722-11e9-073e-8b9effe078ba"
150152
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
151153
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
@@ -170,4 +172,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
170172
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
171173

172174
[targets]
173-
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker", "SciMLSensitivity", "Enzyme"]
175+
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker", "SciMLSensitivity", "Enzyme", "Mooncake"]

lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module NonlinearSolveBaseMooncakeExt
22

33
using NonlinearSolveBase, Mooncake
4-
using SciMLBase: ADOriginator, ChainRulesOriginator, MooncakeOriginator
4+
using SciMLBase: SciMLBase
55
import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive,
66
@from_rrule, @zero_adjoint, @mooncake_overlay, MinimalCtx,
77
NoPullback

lib/NonlinearSolveBase/src/solve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,7 @@ function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callba
795795
kwargs...)
796796
alg = extract_alg(args, kwargs, prob.kwargs)
797797
if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling
798-
_prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0,
798+
_prob = get_concrete_problem(prob, true; u0 = u0,
799799
p = p, kwargs...)
800800
else
801801
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
@@ -817,7 +817,7 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba
817817
kwargs...)
818818
alg = extract_alg(args, kwargs, prob.kwargs)
819819
if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling
820-
_prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0,
820+
_prob = get_concrete_problem(prob, true; u0 = u0,
821821
p = p, kwargs...)
822822
else
823823
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)

test/adjoint_tests.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
@testitem "Adjoint Tests" tags = [:adjoint] begin
2-
using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, Enzyme
1+
@testitem "Adjoint Tests" tags = [:nopre] begin
2+
using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, Enzyme, Mooncake
33

44
ff(u, p) = u .^ 2 .- p
55

@@ -17,6 +17,11 @@
1717
∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p)))
1818
∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p)
1919
∂p_enzyme = Enzyme.gradient(Enzyme.Reverse, solve_nlprob, p)[1]
20-
@test ∂p_zygote ∂p_tracker ∂p_reversediff ∂p_enzyme
20+
21+
cache = Mooncake.prepare_gradient_cache(solve_nlprob, p)
22+
∂p_mooncake = Mooncake.value_and_gradient!!(cache, solve_nlprob, p)[2][2]
23+
24+
@test ∂p_zygote ∂p_tracker ∂p_reversediff ∂p_enzyme
2125
@test ∂p_zygote ∂p_forwarddiff ∂p_tracker ∂p_reversediff ∂p_enzyme
26+
@test_broken ∂p_forwarddiff ∂p_mooncake
2227
end

0 commit comments

Comments
 (0)