Skip to content

Commit db21cf8

Browse files
fix: fix jacobian and taylor implementations
1 parent 2026034 commit db21cf8

File tree

2 files changed

+138
-102
lines changed

2 files changed

+138
-102
lines changed

lib/NonlinearSolveHomotopyContinuation/src/interface_types.jl

Lines changed: 106 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -4,86 +4,109 @@ struct Inplace <: HomotopySystemVariant end
44
struct OutOfPlace <: HomotopySystemVariant end
55
struct Scalar <: HomotopySystemVariant end
66

7+
@concrete struct ComplexJacobianWrapper{variant <: HomotopySystemVariant}
8+
f
9+
end
10+
11+
function (cjw::ComplexJacobianWrapper{Inplace})(u::AbstractVector{T}, x::AbstractVector{T}, p) where {T}
12+
x = reinterpret(Complex{T}, x)
13+
u = reinterpret(Complex{T}, u)
14+
cjw.f(u, x, p)
15+
u = parent(u)
16+
return u
17+
end
18+
19+
function (cjw::ComplexJacobianWrapper{OutOfPlace})(u::AbstractVector{T}, x::AbstractVector{T}, p) where {T}
20+
x = reinterpret(Complex{T}, x)
21+
u_tmp = cjw.f(x, p)
22+
u_tmp = reinterpret(T, u_tmp)
23+
copyto!(u, u_tmp)
24+
return u
25+
end
26+
27+
function (cjw::ComplexJacobianWrapper{Scalar})(u::AbstractVector{T}, x::AbstractVector{T}, p) where {T}
28+
x = reinterpret(Complex{T}, x)
29+
u_tmp = cjw.f(x[1], p)
30+
u[1] = real(u_tmp)
31+
u[2] = imag(u_tmp)
32+
return u
33+
end
34+
735
@concrete struct HomotopySystemWrapper{variant <: HomotopySystemVariant} <: HC.AbstractSystem
8-
prob
36+
f
37+
jac
38+
p
939
autodiff
1040
prep
1141
vars
1242
taylorvars
1343
jacobian_buffers
1444
end
1545

16-
Base.size(sys::HomotopySystemWrapper) = (length(sys.prob.u0), length(sys.prob.u0))
46+
Base.size(sys::HomotopySystemWrapper) = (length(sys.vars), length(sys.vars))
1747
HC.ModelKit.variables(sys::HomotopySystemWrapper) = sys.vars
1848

1949
function HC.ModelKit.evaluate!(u, sys::HomotopySystemWrapper{Inplace}, x, p = nothing)
20-
sys.prob.f.f(u, x, parameter_values(sys.prob))
50+
sys.f(u, x, sys.p)
2151
return u
2252
end
2353

2454
function HC.ModelKit.evaluate!(u, sys::HomotopySystemWrapper{OutOfPlace}, x, p = nothing)
25-
values = sys.prob.f.f(x, parameter_values(sys.prob))
55+
values = sys.f(x, sys.p)
2656
copyto!(u, values)
2757
return u
2858
end
2959

3060
function HC.ModelKit.evaluate!(u, sys::HomotopySystemWrapper{Scalar}, x, p = nothing)
31-
u[1] = sys.prob.f.f(only(x), parameter_values(sys.prob))
61+
u[1] = sys.f(x[1], sys.p)
3262
return u
3363
end
3464

3565
function HC.ModelKit.evaluate_and_jacobian!(u, U, sys::HomotopySystemWrapper{Inplace}, x, p = nothing)
36-
f = sys.prob.f
37-
p = parameter_values(sys.prob)
38-
if SciMLBase.has_jac(f)
39-
f.f(u, x, p)
40-
f.jac(U, x, p)
41-
return
42-
end
43-
44-
x_tmp, u_tmp, U_tmp = sys.jacobian_buffers
45-
copyto!(x_tmp, x)
46-
DI.value_and_jacobian!(f.f, u_tmp, U_tmp, sys.prep, sys.autodiff, x_tmp, DI.Constant(p))
47-
copyto!(u, u_tmp)
48-
copyto!(U, U_tmp)
66+
p = sys.p
67+
sys.f(u, x, p)
68+
sys.jac(U, x, p)
4969
return u, U
5070
end
5171

5272
function HC.ModelKit.evaluate_and_jacobian!(u, U, sys::HomotopySystemWrapper{OutOfPlace}, x, p = nothing)
53-
f = sys.prob.f
54-
p = parameter_values(sys.prob)
55-
if SciMLBase.has_jac(f)
56-
u_tmp = f.f(x, p)
57-
copyto!(u, u_tmp)
58-
j_tmp = f.jac(U, x, p)
59-
copyto!(U, j_tmp)
60-
return
61-
end
62-
x_tmp, U_tmp = sys.jacobian_buffers
63-
copyto!(x_tmp, x)
64-
u_tmp, _ = DI.value_and_jacobian!(f.f, U_tmp, sys.prep, sys.autodiff, x_tmp, DI.Constant(p))
73+
p = sys.p
74+
u_tmp = sys.f(x, p)
6575
copyto!(u, u_tmp)
66-
copyto!(U, U_tmp)
76+
j_tmp = sys.jac(x, p)
77+
copyto!(U, j_tmp)
6778
return u, U
6879
end
6980

7081
function HC.ModelKit.evaluate_and_jacobian!(u, U, sys::HomotopySystemWrapper{Scalar}, x, p = nothing)
71-
f = sys.prob.f
72-
p = parameter_values(sys.prob)
73-
if SciMLBase.has_jac(f)
74-
HC.ModelKit.evaluate!(u, sys, x, p)
75-
U[1] = f.jac(only(x), p)
76-
else
77-
x = real(first(x))
78-
u[1], U[1] = DI.value_and_derivative(f.f, sys.prep, sys.autodiff, x, DI.Constant(p))
79-
end
82+
p = sys.p
83+
u[1] = sys.f(x[1], p)
84+
U[1] = sys.jac(x[1], p)
8085
return u, U
8186
end
8287

83-
function HC.ModelKit.taylor!(u::AbstractVector, ::Val{N}, sys::HomotopySystemWrapper{Inplace}, x::HC.ModelKit.TaylorVector{M}, p = nothing) where {N, M}
84-
f = sys.prob.f
85-
p = parameter_values(sys.prob)
86-
buffer, vars = sys.taylorvars
88+
for V in (Inplace, OutOfPlace, Scalar)
89+
@eval function HC.ModelKit.evaluate_and_jacobian!(u, U, sys::HomotopySystemWrapper{$V, F, J}, x, p = nothing) where {F, J <: ComplexJacobianWrapper}
90+
p = sys.p
91+
U_tmp = sys.jacobian_buffers
92+
x = reinterpret(Float64, x)
93+
u = reinterpret(Float64, u)
94+
DI.value_and_jacobian!(sys.jac, u, U_tmp, sys.prep, sys.autodiff, x, DI.Constant(p))
95+
U = reinterpret(Float64, U)
96+
@inbounds for j in axes(U, 2)
97+
jj = 2j - 1
98+
for i in axes(U, 1)
99+
U[i, j] = U_tmp[i, jj]
100+
end
101+
end
102+
u = parent(u)
103+
U = parent(U)
104+
105+
return u, U
106+
end
107+
end
108+
109+
function update_taylorvars_from_taylorvector!(vars, x::HC.ModelKit.TaylorVector{M}) where {M}
87110
for i in eachindex(vars)
88111
for j in 0:M-1
89112
vars[i][j] = x[i, j + 1]
@@ -92,44 +115,57 @@ function HC.ModelKit.taylor!(u::AbstractVector, ::Val{N}, sys::HomotopySystemWra
92115
vars[i][j] = zero(vars[i][j])
93116
end
94117
end
95-
f.f(buffer, vars, p)
96-
if u isa Vector
97-
for i in eachindex(vars)
98-
u[i] = buffer[i][N]
99-
end
100-
else
101-
for i in eachindex(vars)
102-
u[i] = ntuple(j -> buffer[i][j - 1], Val(N + 1))
118+
end
119+
120+
function update_taylorvars_from_taylorvector!(vars, x::AbstractVector)
121+
for i in eachindex(vars)
122+
vars[i][0] = x[i]
123+
for j in 1:4
124+
vars[i][j] = zero(vars[i][j])
103125
end
104126
end
105-
return u
106127
end
107128

108-
function HC.ModelKit.taylor!(u::AbstractVector, ::Val{N}, sys::HomotopySystemWrapper{OutOfPlace}, x::HC.ModelKit.TaylorVector{M}, p = nothing) where {N, M}
109-
f = sys.prob.f
110-
p = parameter_values(sys.prob)
111-
vars = sys.taylorvars
129+
function update_maybe_taylorvector_from_taylorvars!(u::Vector, vars, buffer, ::Val{N}) where {N}
112130
for i in eachindex(vars)
113-
for j in 0:M
114-
vars[i][j] = x[i, j + 1]
115-
end
131+
u[i] = buffer[i][N]
116132
end
117-
buffer = f.f(vars, p)
133+
end
134+
135+
function update_maybe_taylorvector_from_taylorvars!(u::HC.ModelKit.TaylorVector, vars, buffer, ::Val{N}) where {N}
118136
for i in eachindex(vars)
119137
u[i] = ntuple(j -> buffer[i][j - 1], Val(N + 1))
120138
end
121139
end
122140

123-
function HC.ModelKit.taylor!(u::AbstractVector, ::Val{N}, sys::HomotopySystemWrapper{Scalar}, x::HC.ModelKit.TaylorVector{M}, p = nothing) where {N, M}
124-
f = sys.prob.f
125-
p = parameter_values(sys.prob)
141+
function HC.ModelKit.taylor!(u::AbstractVector, ::Val{N}, sys::HomotopySystemWrapper{Inplace}, x, p = nothing) where {N}
142+
f = sys.f
143+
p = sys.p
144+
buffer, vars = sys.taylorvars
145+
update_taylorvars_from_taylorvector!(vars, x)
146+
f(buffer, vars, p)
147+
update_maybe_taylorvector_from_taylorvars!(u, vars, buffer, Val(N))
148+
return u
149+
end
150+
151+
function HC.ModelKit.taylor!(u::AbstractVector, ::Val{N}, sys::HomotopySystemWrapper{OutOfPlace}, x, p = nothing) where {N}
152+
f = sys.f
153+
p = sys.p
154+
vars = sys.taylorvars
155+
update_taylorvars_from_taylorvector!(vars, x)
156+
buffer = f(vars, p)
157+
update_maybe_taylorvector_from_taylorvars!(u, vars, buffer, Val(N))
158+
return u
159+
end
160+
161+
function HC.ModelKit.taylor!(u::AbstractVector, ::Val{N}, sys::HomotopySystemWrapper{Scalar}, x, p = nothing) where {N}
162+
f = sys.f
163+
p = sys.p
126164
var = sys.taylorvars
127-
for i in 0:M
128-
var[i] = x[1, i + 1]
129-
end
130-
taylor = f.f(var, p)
131-
val = ntuple(i -> taylor[i - 1], Val(N + 1))
132-
u[1] = val
165+
update_taylorvars_from_taylorvector!((var,), x)
166+
buffer = f(var, p)
167+
update_maybe_taylorvector_from_taylorvars!(u, (var,), (buffer,), Val(N))
168+
return u
133169
end
134170

135171
@concrete struct GuessHomotopy <: HC.AbstractHomotopy
@@ -165,14 +201,6 @@ end
165201
HC.ModelKit.taylor!(u, v::Val{N}, H::GuessHomotopy, tx, t, incremental::Bool) where {N} =
166202
HC.ModelKit.taylor!(u, v, H, tx, t)
167203

168-
function HC.ModelKit.taylor!(u, ::Val{1}, h::GuessHomotopy, x::AbstractVector{<:Number}, t, p = nothing)
169-
HC.ModelKit.evaluate!(u, h.sys, x, p)
170-
@inbounds for i in eachindex(u)
171-
u[i] -= h.fu0[i]
172-
end
173-
return u
174-
end
175-
176204
function HC.ModelKit.taylor!(u, ::Val{N}, h::GuessHomotopy, x, t, p = nothing) where {N}
177205
HC.ModelKit.taylor!(h.taylorbuffer, Val(N), h.sys, x, p)
178206
@inbounds for i in eachindex(u)

lib/NonlinearSolveHomotopyContinuation/src/solve.jl

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,22 @@ function homotopy_continuation_preprocessing(prob::NonlinearProblem, alg::Homoto
1111
isscalar = u0 isa Number
1212
iip = SciMLBase.isinplace(prob)
1313

14+
variant = iip ? Inplace : isscalar ? Scalar : OutOfPlace
15+
1416
# jacobian handling
15-
if SciMLBase.has_jac(f)
17+
if SciMLBase.has_jac(f.f)
1618
# use if present
1719
prep = nothing
18-
elseif iip
19-
# prepare a DI jacobian if not
20-
prep = DI.prepare_jacobian(prob.f.f, copy(state_values(prob)), alg.autodiff, u0, DI.Constant(p))
21-
elseif isscalar
22-
prep = DI.prepare_derivative(prob.f.f, alg.autodiff, u0, DI.Constant(p))
20+
jac = f.jac
2321
else
24-
prep = DI.prepare_jacobian(prob.f.f, alg.autodiff, u0, DI.Constant(p))
22+
# prepare a DI jacobian if not
23+
jac = ComplexJacobianWrapper{variant}(f.f.f)
24+
tmp = if isscalar
25+
Vector{Float64}(undef, 2)
26+
else
27+
similar(u0, Float64, 2length(u0))
28+
end
29+
prep = DI.prepare_jacobian(jac, tmp, alg.autodiff, copy(tmp), DI.Constant(p))
2530
end
2631

2732
# variables for HC to use
@@ -41,16 +46,13 @@ function homotopy_continuation_preprocessing(prob::NonlinearProblem, alg::Homoto
4146
end
4247

4348
jacobian_buffers = if isscalar
44-
nothing
45-
elseif iip
46-
(similar(u0), similar(u0), similar(u0, length(u0), length(u0)))
49+
Matrix{Float64}(undef, 2, 2)
4750
else
48-
(similar(u0), similar(u0, length(u0), length(u0)))
51+
similar(u0, Float64, 2length(u0), 2length(u0))
4952
end
5053

5154
# HC-compatible system
52-
variant = iip ? Inplace : isscalar ? Scalar : OutOfPlace
53-
hcsys = HomotopySystemWrapper{variant}(prob, alg.autodiff, prep, vars, taylorvars, jacobian_buffers)
55+
hcsys = HomotopySystemWrapper{variant}(f.f.f, jac, p, alg.autodiff, prep, vars, taylorvars, jacobian_buffers)
5456

5557
return f, hcsys
5658
end
@@ -113,14 +115,14 @@ function CommonSolve.solve(prob::NonlinearProblem, alg::HomotopyContinuationJL{f
113115
fu0 = NonlinearSolveBase.Utils.evaluate_f(prob, u0_p)
114116

115117
homotopy = GuessHomotopy(hcsys, fu0)
116-
if u0_p isa Number
117-
u0_p = [u0_p]
118+
orig_sol = HC.solve(homotopy, u0_p isa Number ? [[u0_p]] : [u0_p]; alg.kwargs..., kwargs...)
119+
realsols = map(res -> res.solution, HC.results(orig_sol; only_real = true))
120+
if u0 isa Number
121+
realsols = map(only, realsols)
118122
end
119-
orig_sol = HC.solve(homotopy, [u0_p]; alg.kwargs..., kwargs...)
120-
realsols = HC.results(orig_sol; only_real = true)
121123

122124
# no real solutions or infeasible solution
123-
if isempty(realsols) || any(<=(denominator_abstol), f.denominator(real.(only(realsols).solution), p))
125+
if isempty(realsols) || any(<=(denominator_abstol), map(abs, f.denominator(real.(only(realsols)), p)))
124126
retcode = if isempty(realsols)
125127
SciMLBase.ReturnCode.ConvergenceFailure
126128
else
@@ -129,15 +131,21 @@ function CommonSolve.solve(prob::NonlinearProblem, alg::HomotopyContinuationJL{f
129131
resid = NonlinearSolveBase.Utils.evaluate_f(prob, u0)
130132
return SciMLBase.build_solution(prob, alg, u0, resid; retcode, original = orig_sol)
131133
end
132-
realsol = only(realsols)
134+
135+
realsol = real(only(realsols))
133136
T = eltype(u0)
134-
validsols = f.unpolynomialize(realsol.solution, p)
135-
sol, idx = findmin(validsols) do sol
136-
norm(sol - u0)
137+
validsols = f.unpolynomialize(realsol, p)
138+
_, idx = findmin(validsols) do sol
139+
norm(sol - u0_p)
140+
end
141+
u = map(real, validsols[idx])
142+
143+
if u0 isa Number
144+
u = only(u)
137145
end
138-
resid = NonlinearSolveBase.Utils.evaluate_f(prob, u0)
146+
resid = NonlinearSolveBase.Utils.evaluate_f(prob, u)
139147

140148
retcode = SciMLBase.ReturnCode.Success
141-
return SciMLBase.build_solution(prob, alg, u0, resid; retcode, original = orig_sol)
149+
return SciMLBase.build_solution(prob, alg, u, resid; retcode, original = orig_sol)
142150
end
143151

0 commit comments

Comments
 (0)