# Patch summary (lib/OptimizationBase):
# - ext/OptimizationZygoteExt.jl and src/OptimizationDIExt.jl: `lag_hess_prototype` becomes
#   `zeros(Bool, length(x), length(x))` instead of `zeros(Bool, num_cons, length(x))`
#   (one site in the Zygote extension, two sites in OptimizationDIExt.jl).
# - ext/OptimizationZygoteExt.jl: adds two `instantiate_function` methods that take an
#   `OptimizationBase.ReInitCache`, one for `SecondOrder{<:AbstractADType, <:AutoZygote}` and
#   one for the `AutoSparse` wrapper of it; each reads `cache.u0` and `cache.p` and forwards
#   to the `(f, x, adtype, p, num_cons; kwargs...)` method.
# - src/cache.jl: the `cons_vjp` keyword default now calls `SciMLBase.allowsconsvjp(opt)`;
#   before this patch it called `allowsconsjvp`, the same query used for `cons_jvp`.
# - test/adtests.jl: three test sites now check that `optprob.lag_hess_prototype` has size
#   (length(x0), length(x0)) and that a Float64 `similar` of it works as a `lag_h` buffer.
# NOTE(review): the hunks below are fragments of larger definitions; context and `-` lines
# must stay byte-identical to the target files for the patch to apply.
diff --git a/lib/OptimizationBase/ext/OptimizationZygoteExt.jl b/lib/OptimizationBase/ext/OptimizationZygoteExt.jl index 461efe699..1e26909af 100644 --- a/lib/OptimizationBase/ext/OptimizationZygoteExt.jl +++ b/lib/OptimizationBase/ext/OptimizationZygoteExt.jl @@ -208,7 +208,7 @@ function OptimizationBase.instantiate_function( lag_extras = prepare_hessian( lagrangian, soadtype, x, Constant(one(eltype(x))), Constant(ones(eltype(x), num_cons)), Constant(p), strict = Val(false)) - lag_hess_prototype = zeros(Bool, num_cons, length(x)) + lag_hess_prototype = zeros(Bool, length(x), length(x)) function lag_h!(H::AbstractMatrix, θ, σ, λ) if σ == zero(eltype(θ)) @@ -288,6 +288,18 @@ function OptimizationBase.instantiate_function( f, x, adtype, p, num_cons; kwargs...) end +function OptimizationBase.instantiate_function( + f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, + adtype::DifferentiationInterface.SecondOrder{ + <:ADTypes.AbstractADType, <:ADTypes.AutoZygote}, + num_cons = 0; kwargs...) + x = cache.u0 + p = cache.p + + return OptimizationBase.instantiate_function( + f, x, adtype, p, num_cons; kwargs...) +end + function OptimizationBase.instantiate_function( f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:Union{ADTypes.AutoZygote, @@ -575,4 +587,15 @@ function OptimizationBase.instantiate_function( return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons; kwargs...) end +function OptimizationBase.instantiate_function( + f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, + adtype::ADTypes.AutoSparse{<:DifferentiationInterface.SecondOrder{ + <:ADTypes.AbstractADType, <:ADTypes.AutoZygote}}, + num_cons = 0; kwargs...) + x = cache.u0 + p = cache.p + + return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons; kwargs...) 
+end + end diff --git a/lib/OptimizationBase/src/OptimizationDIExt.jl b/lib/OptimizationBase/src/OptimizationDIExt.jl index adb84b55f..017506a03 100644 --- a/lib/OptimizationBase/src/OptimizationDIExt.jl +++ b/lib/OptimizationBase/src/OptimizationDIExt.jl @@ -198,7 +198,7 @@ function instantiate_function( lag_prep = prepare_hessian( lagrangian, soadtype, x, Constant(one(eltype(x))), Constant(ones(eltype(x), num_cons)), Constant(p)) - lag_hess_prototype = zeros(Bool, num_cons, length(x)) + lag_hess_prototype = zeros(Bool, length(x), length(x)) function lag_h!(H::AbstractMatrix, θ, σ, λ) if σ == zero(eltype(θ)) @@ -457,7 +457,7 @@ function instantiate_function( lag_prep = prepare_hessian( lagrangian, soadtype, x, Constant(one(eltype(x))), Constant(ones(eltype(x), num_cons)), Constant(p)) - lag_hess_prototype = zeros(Bool, num_cons, length(x)) + lag_hess_prototype = zeros(Bool, length(x), length(x)) function lag_h!(θ, σ, λ) if σ == zero(eltype(θ)) diff --git a/lib/OptimizationBase/src/cache.jl b/lib/OptimizationBase/src/cache.jl index 375d5aaa1..9454e4f03 100644 --- a/lib/OptimizationBase/src/cache.jl +++ b/lib/OptimizationBase/src/cache.jl @@ -62,7 +62,7 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt; g = SciMLBase.requiresgradient(opt), h = SciMLBase.requireshessian(opt), hv = SciMLBase.requireshessian(opt), fg = SciMLBase.allowsfg(opt), fgh = SciMLBase.allowsfgh(opt), cons_j = SciMLBase.requiresconsjac(opt), cons_h = SciMLBase.requiresconshess(opt), - cons_vjp = SciMLBase.allowsconsjvp(opt), cons_jvp = SciMLBase.allowsconsjvp(opt), lag_h = SciMLBase.requireslagh(opt)) + cons_vjp = SciMLBase.allowsconsvjp(opt), cons_jvp = SciMLBase.allowsconsjvp(opt), lag_h = SciMLBase.requireslagh(opt)) if structural_analysis obj_res, cons_res = symify_cache(f, prob, num_cons, manifold) diff --git a/lib/OptimizationBase/test/adtests.jl b/lib/OptimizationBase/test/adtests.jl index 6fe4eea05..1a9483deb 100644 --- a/lib/OptimizationBase/test/adtests.jl +++ 
b/lib/OptimizationBase/test/adtests.jl @@ -144,6 +144,17 @@ optprob.cons_h(H3, x0) optprob.lag_h(H4, x0, σ, μ) @test H4≈σ * H2 + μ[1] * H3[1] rtol=1e-6 + # Check that the AD-generated lag_hess_prototype is square in the number of variables + @test !isnothing(optprob.lag_hess_prototype) + @test size(optprob.lag_hess_prototype) == (length(x0), length(x0)) # n×n with n = length(x0), not num_cons×n + + # Use a Float64 `similar` of the prototype as the lag_h buffer (regression check for the num_cons×n shape) + if !isnothing(optprob.lag_hess_prototype) + H_proto = similar(optprob.lag_hess_prototype, Float64) + optprob.lag_h(H_proto, x0, σ, μ) + @test H_proto ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6 + end + G2 = Array{Float64}(undef, 2) H2 = Array{Float64}(undef, 2, 2) @@ -257,6 +268,17 @@ optprob.cons_h(H3, x0) optprob.lag_h(H4, x0, σ, μ) @test H4≈σ * H2 + μ[1] * H3[1] rtol=1e-6 + # Check that the AD-generated lag_hess_prototype is square in the number of variables + @test !isnothing(optprob.lag_hess_prototype) + @test size(optprob.lag_hess_prototype) == (length(x0), length(x0)) # n×n with n = length(x0), not num_cons×n + + # Use a Float64 `similar` of the prototype as the lag_h buffer (regression check for the num_cons×n shape) + if !isnothing(optprob.lag_hess_prototype) + H_proto = similar(optprob.lag_hess_prototype, Float64) + optprob.lag_h(H_proto, x0, σ, μ) + @test H_proto ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6 + end + G2 = Array{Float64}(undef, 2) H2 = Array{Float64}(undef, 2, 2) @@ -490,6 +512,17 @@ end optprob.lag_h(H4, x0, σ, μ) @test H4≈σ * H1 + sum(μ .* H3) rtol=1e-6 + # Check that the AD-generated lag_hess_prototype is square in the number of variables + @test !isnothing(optprob.lag_hess_prototype) + @test size(optprob.lag_hess_prototype) == (length(x0), length(x0)) # n×n with n = length(x0), not num_cons×n + + # Use a Float64 `similar` of the prototype as the lag_h buffer (NOTE(review): confirm num_cons != n here, else the old shape also passes) + if !isnothing(optprob.lag_hess_prototype) + H_proto = similar(optprob.lag_hess_prototype, Float64) + optprob.lag_h(H_proto, x0, σ, μ) + @test H_proto ≈ σ * H1 + sum(μ .* H3) rtol=1e-6 + end + G2 = Array{Float64}(undef, 2) H2 = 
Array{Float64}(undef, 2, 2)