1+
# Capability traits for the LBFGS optimizer: it supports the OptimizationCache
# interface and box bounds, requires gradients, and allows constraints (for
# which it requires a constraint Jacobian).
SciMLBase.supports_opt_cache_interface(::LBFGS) = true
SciMLBase.allowsbounds(::LBFGS) = true
SciMLBase.requiresgradient(::LBFGS) = true
SciMLBase.allowsconstraints(::LBFGS) = true
SciMLBase.requiresconsjac(::LBFGS) = true
7+
"""
    task_message_to_string(task::Vector{UInt8})

Convert an L-BFGS-B task/status byte buffer into a `String`.

A defensive copy is made because `String(::Vector{UInt8})` takes ownership
of the byte vector and truncates it to zero length; copying leaves the
caller's buffer intact in case it is reused.
"""
function task_message_to_string(task::Vector{UInt8})
    return String(copy(task))
end
11+
"""
    __map_optimizer_args(cache, opt::LBFGS; kwargs...)

Translate the common Optimization.jl keyword arguments into the keyword
arguments understood by the L-BFGS-B backend, returned as a `NamedTuple`.

- `maxiters` is forwarded as `maxiter`.
- `reltol` is forwarded as `pgtol` (projected-gradient tolerance).
- Bounds are forwarded only when both `cache.lb` and `cache.ub` are set.
- `abstol` and `maxtime` are not supported and trigger a warning.
- `callback` and `verbose` are accepted for interface compatibility but unused here.
"""
function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::LBFGS;
        callback = nothing,
        maxiters::Union{Number, Nothing} = nothing,
        maxtime::Union{Number, Nothing} = nothing,
        abstol::Union{Number, Nothing} = nothing,
        reltol::Union{Number, Nothing} = nothing,
        verbose::Bool = false,
        kwargs...)
    if !isnothing(abstol)
        @warn "common abstol is currently not used by $(opt)"
    end
    if !isnothing(maxtime)
        # Fixed: this warning previously (incorrectly) mentioned `abstol`.
        @warn "common maxtime is currently not used by $(opt)"
    end

    mapped_args = (;)

    if cache.lb !== nothing && cache.ub !== nothing
        mapped_args = (; mapped_args..., lb = cache.lb, ub = cache.ub)
    end

    if !isnothing(maxiters)
        mapped_args = (; mapped_args..., maxiter = maxiters)
    end

    if !isnothing(reltol)
        mapped_args = (; mapped_args..., pgtol = reltol)
    end

    return mapped_args
end
43+
"""
    SciMLBase.__solve(cache::OptimizationCache{..., O <: LBFGS, ...})

Solve the optimization problem held in `cache` with the L-BFGS-B backend.

Constrained problems (`cache.f.cons !== nothing`) are handled with an
augmented-Lagrangian outer loop: equality constraints (where
`lcons[i] == ucons[i]`) get multipliers `λ`, inequality constraints get
multipliers `μ`, and the penalty `ρ` grows by `γ` whenever the constraint
violation does not shrink by factor `τ`. Each outer iteration runs a bounded
L-BFGS-B solve on the augmented Lagrangian.
"""
function SciMLBase.__solve(cache::OptimizationCache{
        F,
        RC,
        LB,
        UB,
        LC,
        UC,
        S,
        O,
        D,
        P,
        C
}) where {
        F,
        RC,
        LB,
        UB,
        LC,
        UC,
        S,
        O <: LBFGS,
        D,
        P,
        C
}
    maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)

    local x

    solver_kwargs = __map_optimizer_args(cache, cache.opt; maxiters, cache.solver_args...)

    if !isnothing(cache.f.cons)
        # Split constraints into equalities (lcons[i] == ucons[i]) and inequalities.
        eq_inds = [cache.lcons[i] == cache.ucons[i] for i in eachindex(cache.lcons)]
        ineq_inds = (!).(eq_inds)

        # Augmented-Lagrangian hyperparameters taken from the solver struct.
        τ = cache.opt.τ
        γ = cache.opt.γ
        λmin = cache.opt.λmin
        λmax = cache.opt.λmax
        μmin = cache.opt.μmin
        μmax = cache.opt.μmax
        ϵ = cache.opt.ϵ

        # Multipliers: λ for equality constraints, μ for inequality constraints.
        λ = zeros(eltype(cache.u0), sum(eq_inds))
        μ = zeros(eltype(cache.u0), sum(ineq_inds))

        cons_tmp = zeros(eltype(cache.u0), length(cache.lcons))
        cache.f.cons(cons_tmp, cache.u0)
        # Initial penalty, balanced against the objective magnitude at u0.
        ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, cache.p))) / norm(cons_tmp)))

        # Augmented-Lagrangian objective evaluated at θ.
        _loss = function (θ)
            x = cache.f(θ, cache.p)
            cons_tmp .= zero(eltype(θ))
            cache.f.cons(cons_tmp, θ)
            # Shift constraint values so feasibility means eq == 0 and ineq <= 0.
            cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds]
            cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds]
            opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
            if cache.callback(opt_state, x...)
                error("Optimization halted by callback.")
            end
            return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) +
                   1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+ (ρ .* cons_tmp[ineq_inds]))) .^ 2)
        end

        prev_eqcons = zero(λ)
        θ = cache.u0
        β = max.(cons_tmp[ineq_inds], Ref(0.0))
        prevβ = zero(β)
        # Dense index lists of the equality / inequality constraint positions.
        eqidxs = [eq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
        ineqidxs = [ineq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
        eqidxs = eqidxs[eqidxs .!= nothing]
        ineqidxs = ineqidxs[ineqidxs .!= nothing]
        # Gradient of the augmented Lagrangian (dense constraint Jacobian).
        function aug_grad(G, θ)
            cache.f.grad(G, θ)
            if !isnothing(cache.f.cons_jac_prototype)
                J = Float64.(cache.f.cons_jac_prototype)
            else
                J = zeros((length(cache.lcons), length(θ)))
            end
            cache.f.cons_j(J, θ)
            __tmp = zero(cons_tmp)
            cache.f.cons(__tmp, θ)
            __tmp[eq_inds] .= __tmp[eq_inds] .- cache.lcons[eq_inds]
            __tmp[ineq_inds] .= __tmp[ineq_inds] .- cache.ucons[ineq_inds]
            G .+= sum(
                λ[i] .* J[idx, :] + ρ * (__tmp[idx] .* J[idx, :])
                for (i, idx) in enumerate(eqidxs);
                init = zero(G)) # should be jvp
            G .+= sum(
                1 / ρ * (max.(Ref(0.0), μ[i] .+ (ρ .* __tmp[idx])) .* J[idx, :])
                for (i, idx) in enumerate(ineqidxs);
                init = zero(G)) # should be jvp
        end

        opt_ret = ReturnCode.MaxIters
        n = length(cache.u0)

        # NOTE(review): a placeholder line `sol = solve(....)` was removed here —
        # it does not parse as Julia and its result was never used.

        # Bounds are passed positionally to the inner solver, not as kwargs.
        solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing))

        for i in 1:maxiters
            prev_eqcons .= cons_tmp[eq_inds] .- cache.lcons[eq_inds]
            prevβ .= copy(β)
            # NOTE(review): `optimizer` and `bounds` are not defined in this file —
            # TODO confirm they are brought into scope by the enclosing module.
            res = optimizer(_loss, aug_grad, θ, bounds; solver_kwargs...,
                m = cache.opt.m, pgtol = sqrt(ϵ), maxiter = maxiters / 100)
            θ = res[2]
            cons_tmp .= 0.0
            cache.f.cons(cons_tmp, θ)
            # Multiplier updates, clamped to [λmin, λmax] and [μmin, μmax].
            λ = max.(min.(λmax, λ .+ ρ * (cons_tmp[eq_inds] .- cache.lcons[eq_inds])), λmin)
            β = max.(cons_tmp[ineq_inds], -1 .* μ ./ ρ)
            μ = min.(μmax, max.(μ .+ ρ * cons_tmp[ineq_inds], μmin))
            # Grow the penalty when the constraint violation did not shrink by τ.
            if max(norm(cons_tmp[eq_inds] .- cache.lcons[eq_inds], Inf), norm(β, Inf)) >
               τ * max(norm(prev_eqcons, Inf), norm(prevβ, Inf))
                ρ = γ * ρ
            end
            # Converged when relative equality violation and β are both below ϵ.
            if norm(
                (cons_tmp[eq_inds] .- cache.lcons[eq_inds]) ./ cons_tmp[eq_inds], Inf) <
               ϵ && norm(β, Inf) < ϵ
                opt_ret = ReturnCode.Success
                break
            end
        end
    end

    # NOTE(review): `res` and `opt_ret` are only assigned inside the constrained
    # branch above; an unconstrained problem reaches this point and throws
    # UndefVarError — TODO confirm the unconstrained path is handled elsewhere
    # or add it here.
    stats = Optimization.OptimizationStats(; iterations = maxiters,
        time = 0.0, fevals = maxiters, gevals = maxiters)
    return SciMLBase.build_solution(
        cache, cache.opt, res[2], cache.f(res[2], cache.p)[1],
        stats = stats, retcode = opt_ret)
end