Skip to content

Commit 549b381

Browse files
fix broken broken tests for Mooncake.
1 parent bd4fba9 commit 549b381

File tree

3 files changed

+32
-27
lines changed

3 files changed

+32
-27
lines changed

lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,23 @@ function ChainRulesCore.frule(::typeof(NonlinearSolveBase.solve_up), prob,
1919
end
2020

2121
function ChainRulesCore.rrule(::typeof(NonlinearSolveBase.solve_up), prob::AbstractNonlinearProblem,
22-
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
23-
u0, p, args...; originator = SciMLBase.ChainRulesOriginator(),
24-
kwargs...)
25-
NonlinearSolveBase._solve_adjoint(
22+
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
23+
u0, p, args...; originator = SciMLBase.ChainRulesOriginator(),
24+
kwargs...)
25+
primal, inner_thunking_pb = NonlinearSolveBase._solve_adjoint(
2626
prob, sensealg, u0, p,
2727
originator, args...;
2828
kwargs...)
29+
30+
# when using mooncake ∂sol would be a NamedTuple Tangent with cotangents of all the solution struct's fields.
31+
# However the pullback for this rule - "steadystatebackpass" as defined in SciMLSensitivity/src/concrete_solve.jl/
32+
# handles AD only when ∂sol is a ChainRulesCore.AbstractThunk object or a sol.u vector and similar data structures (not namedtuples).
33+
# When using Mooncake, we pass in sol.u to inner_thunking_pb directly as this is the only field relevant to the solution's cotangent (given solve_up, AbstractNonlinearProblem setting).
34+
35+
function solve_up_adjoint(∂sol)
36+
return inner_thunking_pb(∂sol isa Tangent{Any,<:NamedTuple} ? ∂sol.u : ∂sol)
37+
end
38+
return primal, solve_up_adjoint
2939
end
3040

3141
end

lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,25 @@ module NonlinearSolveBaseMooncakeExt
22

33
using NonlinearSolveBase, Mooncake
44
using SciMLBase: SciMLBase
5-
import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive,
6-
@from_rrule, @zero_adjoint, @mooncake_overlay, MinimalCtx,
7-
NoPullback
5+
using Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive,
6+
@from_chainrules, @zero_adjoint, @mooncake_overlay, MinimalCtx,
7+
NoPullback
88

9-
@from_rrule(MinimalCtx,
10-
Tuple{
11-
typeof(NonlinearSolveBase.solve_up),
12-
SciMLBase.AbstractNonlinearProblem,
13-
Union{Nothing, SciMLBase.AbstractSensitivityAlgorithm},
14-
Any,
15-
Any,
16-
Any
17-
},
18-
true,)
9+
@from_chainrules MinimalCtx Tuple{typeof(NonlinearSolveBase.solve_up),
10+
SciMLBase.AbstractNonlinearProblem,
11+
Union{Nothing,SciMLBase.AbstractSensitivityAlgorithm},
12+
Any,
13+
Any,
14+
Any
15+
} true
1916

2017
# Dispatch for auto-alg
21-
@from_rrule(MinimalCtx,
22-
Tuple{
23-
typeof(NonlinearSolveBase.solve_up),
24-
SciMLBase.AbstractNonlinearProblem,
25-
Union{Nothing, SciMLBase.AbstractSensitivityAlgorithm},
26-
Any,
27-
Any
28-
},
29-
true,)
18+
@from_chainrules MinimalCtx Tuple{
19+
typeof(NonlinearSolveBase.solve_up),
20+
SciMLBase.AbstractNonlinearProblem,
21+
Union{Nothing,SciMLBase.AbstractSensitivityAlgorithm},
22+
Any,
23+
Any
24+
} true
3025

3126
end

test/adjoint_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,5 @@
2323

2424
@test ∂p_zygote ∂p_tracker ∂p_reversediff ∂p_enzyme
2525
@test ∂p_zygote ∂p_forwarddiff ∂p_tracker ∂p_reversediff ∂p_enzyme
26-
@test_broken ∂p_forwarddiff ∂p_mooncake
26+
@test ∂p_forwarddiff ∂p_mooncake
2727
end

0 commit comments

Comments
 (0)