Skip to content

Commit 701af97

Browse files
committed
Add callback tronls
1 parent 1f08101 commit 701af97

File tree

2 files changed

+84
-30
lines changed

2 files changed

+84
-30
lines changed

src/tronls.jl

Lines changed: 79 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ nonlinear least-squares problems:
2929
3030
min ½‖F(x)‖² s.t. ℓ ≦ x ≦ u
3131
32+
For advanced usage, first define a `TronSolverNLS` to preallocate the memory used in the algorithm, and then call `solve!`:
33+
solver = TronSolverNLS(nls)
34+
solve!(solver, nls; kwargs...)
35+
3236
# Arguments
3337
- `nls::AbstractNLSModel{T, V}` represents the model to solve, see `NLPModels.jl`.
3438
The keyword arguments may include
@@ -43,12 +47,28 @@ The keyword arguments may include
4347
- `max_cgiter::Int = 50`: subproblem iteration limit.
4448
- `cgtol::T = T(0.1)`: subproblem tolerance.
4549
- `atol::T = √eps(T)`: absolute tolerance.
46-
- `rtol::T = √eps(T)`: relative tolerance, the algorithm stops when ||∇f(xᵏ)|| ≤ atol + rtol * ||∇f(x⁰)||.
50+
- `rtol::T = √eps(T)`: relative tolerance, the algorithm stops when ∇f(xᵏ) ≤ atol + rtol * ∇f(x⁰).
4751
- `verbose::Int = 0`: if > 0, display iteration details every `verbose` iteration.
4852
4953
# Output
5054
The value returned is a `GenericExecutionStats`, see `SolverCore.jl`.
5155
56+
# Callback
57+
The callback is called at each iteration.
58+
The expected signature of the callback is `callback(nlp, solver, stats)`, and its output is ignored.
59+
Changing any of the input arguments will affect the subsequent iterations.
60+
In particular, setting `stats.status = :user` will stop the algorithm.
61+
All relevant information should be available in `nlp` and `solver`.
62+
Notably, you can access, and modify, the following:
63+
- `solver.x`: current iterate;
64+
- `solver.gx`: current gradient;
65+
- `stats`: structure holding the output of the algorithm (`GenericExecutionStats`), which contains, among other things:
66+
- `stats.dual_feas`: norm of current gradient;
67+
- `stats.iter`: current iteration counter;
68+
- `stats.objective`: current objective function value;
69+
- `stats.status`: current status of the algorithm. Should be `:unknown` unless the algorithm attained a stopping criterion. Changing this to anything will stop the algorithm, but you should use `:user` to properly indicate the intention.
70+
- `stats.elapsed_time`: elapsed time in seconds.
71+
5272
# References
5373
This is an adaptation for bound-constrained nonlinear least-squares problems of the TRON method described in
5474
@@ -123,6 +143,7 @@ function SolverCore.solve!(
123143
solver::TronSolverNLS{T, V},
124144
nlp::AbstractNLSModel{T, V},
125145
stats::GenericExecutionStats{T, V};
146+
callback = (args...) -> nothing,
126147
subsolver_logger::AbstractLogger = NullLogger(),
127148
x::V = nlp.meta.x0,
128149
subsolver::Symbol = :lsmr,
@@ -152,7 +173,7 @@ function SolverCore.solve!(
152173

153174
iter = 0
154175
start_time = time()
155-
el_time = 0.0
176+
set_time!(stats, 0.0)
156177

157178
solver.x .= x
158179
x = solver.x
@@ -183,10 +204,11 @@ function SolverCore.solve!(
183204
ϵ = atol + rtol * πx
184205
fmin = min(-one(T), fx) / eps(eltype(x))
185206
optimal = πx <= ϵ
186-
tired = el_time > max_time || neval_obj(nlp) > max_eval 0
187207
unbounded = fx < fmin
188-
stalled = false
189-
status = :unknown
208+
209+
set_iter!(stats, 0)
210+
set_objective!(stats, fx)
211+
set_dual_residual!(stats, πx)
190212

191213
αC = one(T)
192214
tr = TRONTrustRegion(gt, min(max(one(T), πx / 10), 100))
@@ -195,7 +217,30 @@ function SolverCore.solve!(
195217
[Int, T, T, T, T, String],
196218
hdr_override = Dict(:f => "f(x)", :dual => "π", :radius => "Δ"),
197219
)
198-
while !(optimal || tired || stalled || unbounded)
220+
221+
set_status!(
222+
stats,
223+
get_status(
224+
nlp,
225+
elapsed_time = stats.elapsed_time,
226+
optimal = optimal,
227+
unbounded = unbounded,
228+
max_eval = max_eval,
229+
max_time = max_time,
230+
),
231+
)
232+
233+
callback(nlp, solver, stats)
234+
235+
done =
236+
(stats.status == :first_order) ||
237+
(stats.status == :max_eval) ||
238+
(stats.status == :max_time) ||
239+
(stats.status == :user) ||
240+
(stats.status == :small_step) ||
241+
(stats.status == :neg_pred)
242+
243+
while !done
199244
# Current iteration
200245
xc .= x
201246
fc = fx
@@ -207,7 +252,6 @@ function SolverCore.solve!(
207252
if cauchy_status != :success
208253
@error "Cauchy step returned: $cauchy_status"
209254
status = cauchy_status
210-
stalled = true
211255
continue
212256
end
213257
s, As, cgits, cginfo = with_logger(subsolver_logger) do
@@ -232,11 +276,10 @@ function SolverCore.solve!(
232276
ared, pred = aredpred!(tr, nlp, fc, fx, qs, x, s, slope)
233277
if pred 0
234278
status = :neg_pred
235-
stalled = true
236279
continue
237280
end
238281
tr.ratio = ared / pred
239-
verbose > 0 && mod(iter, verbose) == 0 && @info log_row([iter, fx, πx, Δ, tr.ratio, cginfo])
282+
verbose > 0 && mod(stats.iter, verbose) == 0 && @info log_row([stats.iter, fx, πx, Δ, tr.ratio, cginfo])
240283

241284
s_norm = nrm2(n, s)
242285
if num_success_iters == 0
@@ -267,32 +310,39 @@ function SolverCore.solve!(
267310
x .= xc
268311
end
269312

270-
iter += 1
271-
el_time = time() - start_time
272-
tired = el_time > max_time || neval_obj(nlp) > max_eval 0
313+
set_iter!(stats, stats.iter + 1)
314+
set_objective!(stats, fx)
315+
set_time!(stats, time() - start_time)
316+
set_dual_residual!(stats, πx)
317+
273318
optimal = πx <= ϵ
274319
unbounded = fx < fmin
275-
end
276-
verbose > 0 && @info log_row(Any[iter, fx, πx, tr.radius])
277320

278-
if tired
279-
if el_time > max_time
280-
status = :max_time
281-
elseif neval_obj(nlp) > max_eval 0
282-
status = :max_eval
283-
end
284-
elseif optimal
285-
status = :first_order
286-
elseif unbounded
287-
status = :unbounded
321+
set_status!(
322+
stats,
323+
get_status(
324+
nlp,
325+
elapsed_time = stats.elapsed_time,
326+
optimal = optimal,
327+
unbounded = unbounded,
328+
max_eval = max_eval,
329+
max_time = max_time,
330+
),
331+
)
332+
333+
callback(nlp, solver, stats)
334+
335+
done =
336+
(stats.status == :first_order) ||
337+
(stats.status == :max_eval) ||
338+
(stats.status == :max_time) ||
339+
(stats.status == :user) ||
340+
(stats.status == :small_step) ||
341+
(stats.status == :neg_pred)
288342
end
343+
verbose > 0 && @info log_row(Any[stats.iter, fx, πx, tr.radius])
289344

290-
set_status!(stats, status)
291345
set_solution!(stats, x)
292-
set_objective!(stats, fx)
293-
set_residuals!(stats, zero(T), πx)
294-
set_iter!(stats, iter)
295-
set_time!(stats, el_time)
296346
stats
297347
end
298348

@@ -403,7 +453,6 @@ function cauchy_ls(
403453
end
404454
# TODO: Correctly assess why this fails
405455
if α < sqrt(nextfloat(zero(α)))
406-
stalled = true
407456
search = false
408457
status = :small_step
409458
end

test/callback.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,9 @@ end
4646
trunk(nls, callback = cb)
4747
end
4848
@test stats.iter == 8
49+
50+
stats = with_logger(NullLogger()) do
51+
tron(nls, callback = cb)
52+
end
53+
@test stats.iter == 8
4954
end

0 commit comments

Comments
 (0)