Skip to content

Commit a5b4d3f

Browse files
fix: call u0_constructor on resid_prototype
1 parent f138b5b commit a5b4d3f

File tree

3 files changed

+41
-17
lines changed

3 files changed

+41
-17
lines changed

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -345,16 +345,6 @@ function hessian_sparsity(sys::NonlinearSystem)
345345
unknowns(sys)) for eq in equations(sys)]
346346
end
347347

348-
function calculate_resid_prototype(N, u0, p)
349-
u0ElType = u0 === nothing ? Float64 : eltype(u0)
350-
if SciMLStructures.isscimlstructure(p)
351-
u0ElType = promote_type(
352-
eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]),
353-
u0ElType)
354-
end
355-
return zeros(u0ElType, N)
356-
end
357-
358348
"""
359349
```julia
360350
SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
@@ -381,6 +371,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
381371
eval_module = @__MODULE__,
382372
sparse = false, simplify = false,
383373
initialization_data = nothing, cse = true,
374+
resid_prototype = nothing,
384375
kwargs...) where {iip}
385376
if !iscomplete(sys)
386377
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearFunction`")
@@ -402,12 +393,6 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
402393
observedfun = ObservedFunctionCache(
403394
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false), cse)
404395

405-
if length(dvs) == length(equations(sys))
406-
resid_prototype = nothing
407-
else
408-
resid_prototype = calculate_resid_prototype(length(equations(sys)), u0, p)
409-
end
410-
411396
NonlinearFunction{iip}(f;
412397
sys = sys,
413398
jac = _jac === nothing ? nothing : _jac,

src/systems/problem_utils.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,22 @@ function float_type_from_varmap(varmap, floatT = Bool)
10811081
return float(floatT)
10821082
end
10831083

1084+
"""
1085+
$(TYPEDSIGNATURES)
1086+
1087+
Calculate the `resid_prototype` for a `NonlinearFunction` with `N` equations and the
1088+
provided `u0` and `p`.
1089+
"""
1090+
function calculate_resid_prototype(N::Int, u0, p)
1091+
u0ElType = u0 === nothing ? Float64 : eltype(u0)
1092+
if SciMLStructures.isscimlstructure(p)
1093+
u0ElType = promote_type(
1094+
eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]),
1095+
u0ElType)
1096+
end
1097+
return zeros(u0ElType, N)
1098+
end
1099+
10841100
"""
10851101
$(TYPEDSIGNATURES)
10861102
@@ -1292,7 +1308,14 @@ function process_SciMLProblem(
12921308
end
12931309
initialization_data = SciMLBase.remake_initialization_data(
12941310
kwargs.initialization_data, kwargs, u0, t0, p, u0, p)
1295-
kwargs = merge(kwargs,)
1311+
kwargs = merge(kwargs, (; initialization_data))
1312+
end
1313+
1314+
if constructor <: NonlinearFunction && length(dvs) != length(eqs)
1315+
kwargs = merge(kwargs,
1316+
(;
1317+
resid_prototype = u0_constructor(calculate_resid_prototype(
1318+
length(eqs), u0, p))))
12961319
end
12971320

12981321
f = constructor(sys, dvs, ps, u0; p = p,

test/nonlinearsystem.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,3 +442,19 @@ end
442442
@test !in(D(y), vs)
443443
end
444444
end
445+
446+
@testset "oop `NonlinearLeastSquaresProblem` with `u0 === nothing`" begin
447+
@variables x y
448+
@named sys = NonlinearSystem([0 ~ x - y], [], []; observed = [x ~ 1.0, y ~ 1.0])
449+
prob = NonlinearLeastSquaresProblem{false}(complete(sys), nothing)
450+
sol = solve(prob)
451+
resid = sol.resid
452+
@test resid == [0.0]
453+
@test resid isa Vector
454+
prob = NonlinearLeastSquaresProblem{false}(
455+
complete(sys), nothing; u0_constructor = splat(SVector))
456+
sol = solve(prob)
457+
resid = sol.resid
458+
@test resid == [0.0]
459+
@test resid isa SVector
460+
end

0 commit comments

Comments
 (0)