Skip to content
Merged

Bugfix #1069

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion lib/OptimizationBase/ext/OptimizationZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ function OptimizationBase.instantiate_function(
lag_extras = prepare_hessian(
lagrangian, soadtype, x, Constant(one(eltype(x))),
Constant(ones(eltype(x), num_cons)), Constant(p), strict = Val(false))
lag_hess_prototype = zeros(Bool, num_cons, length(x))
lag_hess_prototype = zeros(Bool, length(x), length(x))

function lag_h!(H::AbstractMatrix, θ, σ, λ)
if σ == zero(eltype(θ))
Expand Down Expand Up @@ -288,6 +288,18 @@ function OptimizationBase.instantiate_function(
f, x, adtype, p, num_cons; kwargs...)
end

function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
adtype::DifferentiationInterface.SecondOrder{
<:ADTypes.AbstractADType, <:ADTypes.AutoZygote},
num_cons = 0; kwargs...)
x = cache.u0
p = cache.p

return OptimizationBase.instantiate_function(
f, x, adtype, p, num_cons; kwargs...)
end

function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, x,
adtype::ADTypes.AutoSparse{<:Union{ADTypes.AutoZygote,
Expand Down Expand Up @@ -575,4 +587,15 @@ function OptimizationBase.instantiate_function(
return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons; kwargs...)
end

function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
adtype::ADTypes.AutoSparse{<:DifferentiationInterface.SecondOrder{
<:ADTypes.AbstractADType, <:ADTypes.AutoZygote}},
num_cons = 0; kwargs...)
x = cache.u0
p = cache.p

return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons; kwargs...)
end

end
4 changes: 2 additions & 2 deletions lib/OptimizationBase/src/OptimizationDIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ function instantiate_function(
lag_prep = prepare_hessian(
lagrangian, soadtype, x, Constant(one(eltype(x))),
Constant(ones(eltype(x), num_cons)), Constant(p))
lag_hess_prototype = zeros(Bool, num_cons, length(x))
lag_hess_prototype = zeros(Bool, length(x), length(x))

function lag_h!(H::AbstractMatrix, θ, σ, λ)
if σ == zero(eltype(θ))
Expand Down Expand Up @@ -457,7 +457,7 @@ function instantiate_function(
lag_prep = prepare_hessian(
lagrangian, soadtype, x, Constant(one(eltype(x))),
Constant(ones(eltype(x), num_cons)), Constant(p))
lag_hess_prototype = zeros(Bool, num_cons, length(x))
lag_hess_prototype = zeros(Bool, length(x), length(x))

function lag_h!(θ, σ, λ)
if σ == zero(eltype(θ))
Expand Down
2 changes: 1 addition & 1 deletion lib/OptimizationBase/src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt;
g = SciMLBase.requiresgradient(opt), h = SciMLBase.requireshessian(opt),
hv = SciMLBase.requireshessian(opt), fg = SciMLBase.allowsfg(opt),
fgh = SciMLBase.allowsfgh(opt), cons_j = SciMLBase.requiresconsjac(opt), cons_h = SciMLBase.requiresconshess(opt),
cons_vjp = SciMLBase.allowsconsjvp(opt), cons_jvp = SciMLBase.allowsconsjvp(opt), lag_h = SciMLBase.requireslagh(opt))
cons_vjp = SciMLBase.allowsconsvjp(opt), cons_jvp = SciMLBase.allowsconsjvp(opt), lag_h = SciMLBase.requireslagh(opt))

if structural_analysis
obj_res, cons_res = symify_cache(f, prob, num_cons, manifold)
Expand Down
33 changes: 33 additions & 0 deletions lib/OptimizationBase/test/adtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ optprob.cons_h(H3, x0)
optprob.lag_h(H4, x0, σ, μ)
@test H4≈σ * H2 + μ[1] * H3[1] rtol=1e-6

# Test that the AD-generated lag_hess_prototype has correct dimensions
@test !isnothing(optprob.lag_hess_prototype)
@test size(optprob.lag_hess_prototype) == (length(x0), length(x0)) # Should be n×n, not num_cons×n

# Test that we can actually use it as a buffer
if !isnothing(optprob.lag_hess_prototype)
H_proto = similar(optprob.lag_hess_prototype, Float64)
optprob.lag_h(H_proto, x0, σ, μ)
@test H_proto ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6
end

G2 = Array{Float64}(undef, 2)
H2 = Array{Float64}(undef, 2, 2)

Expand Down Expand Up @@ -257,6 +268,17 @@ optprob.cons_h(H3, x0)
optprob.lag_h(H4, x0, σ, μ)
@test H4≈σ * H2 + μ[1] * H3[1] rtol=1e-6

# Test that the AD-generated lag_hess_prototype has correct dimensions
@test !isnothing(optprob.lag_hess_prototype)
@test size(optprob.lag_hess_prototype) == (length(x0), length(x0)) # Should be n×n, not num_cons×n

# Test that we can actually use it as a buffer (this would fail with the bug)
if !isnothing(optprob.lag_hess_prototype)
H_proto = similar(optprob.lag_hess_prototype, Float64)
optprob.lag_h(H_proto, x0, σ, μ)
@test H_proto ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6
end

G2 = Array{Float64}(undef, 2)
H2 = Array{Float64}(undef, 2, 2)

Expand Down Expand Up @@ -490,6 +512,17 @@ end
optprob.lag_h(H4, x0, σ, μ)
@test H4≈σ * H1 + sum(μ .* H3) rtol=1e-6

# Test that the AD-generated lag_hess_prototype has correct dimensions
@test !isnothing(optprob.lag_hess_prototype)
@test size(optprob.lag_hess_prototype) == (length(x0), length(x0)) # Should be n×n, not num_cons×n

# Test that we can actually use it as a buffer (this would fail with the bug)
if !isnothing(optprob.lag_hess_prototype)
H_proto = similar(optprob.lag_hess_prototype, Float64)
optprob.lag_h(H_proto, x0, σ, μ)
@test H_proto ≈ σ * H1 + sum(μ .* H3) rtol=1e-6
end

G2 = Array{Float64}(undef, 2)
H2 = Array{Float64}(undef, 2, 2)

Expand Down
Loading