
Commit 49f8050

improve options types with specialization for each method (inspired by @Vindaar)
1 parent 772097c commit 49f8050

File tree

1 file changed: +69 -21 lines

src/numericalnim/optimize.nim

@@ -126,20 +126,67 @@ proc secant*(f: proc(x: float64): float64, start: array[2, float64], precision:
 type LineSearchCriterion = enum
     Armijo, Wolfe, WolfeStrong, NoLineSearch

-type OptimOptions*[U] = object
-    tol*, alpha*, lambda0*: U
-    fastMode*: bool
-    maxIterations*: int
-    lineSearchCriterion*: LineSearchCriterion
-
-proc optimOptions*[U](tol: U = U(1e-6), alpha: U = U(1), lambda0: U = U(1), fastMode: bool = false, maxIterations: int = 10000, lineSearchCriterion: LineSearchCriterion = NoLineSearch): OptimOptions[U] =
+type
+    OptimOptions*[U, AO] = object
+        tol*, alpha*: U
+        fastMode*: bool
+        maxIterations*: int
+        lineSearchCriterion*: LineSearchCriterion
+        algoOptions*: AO
+    StandardOptions* = object
+    LevMarqOptions*[U] = object
+        lambda0*: U
+    LBFGSOptions*[U] = object
+        savedIterations*: int
+
+proc optimOptions*[U](tol: U = U(1e-6), alpha: U = U(1), lambda0: U = U(1), fastMode: bool = false, maxIterations: int = 10000, lineSearchCriterion: LineSearchCriterion = NoLineSearch): OptimOptions[U, StandardOptions] =
     result.tol = tol
     result.alpha = alpha
     result.lambda0 = lambda0
     result.fastMode = fastMode
     result.maxIterations = maxIterations
     result.lineSearchCriterion = lineSearchCriterion

+proc steepestDescentOptions*[U](tol: U = U(1e-6), alpha: U = U(0.001), fastMode: bool = false, maxIterations: int = 10000, lineSearchCriterion: LineSearchCriterion = NoLineSearch): OptimOptions[U, StandardOptions] =
+    result.tol = tol
+    result.alpha = alpha
+    result.fastMode = fastMode
+    result.maxIterations = maxIterations
+    result.lineSearchCriterion = lineSearchCriterion
+
+proc newtonOptions*[U](tol: U = U(1e-6), alpha: U = U(1), fastMode: bool = false, maxIterations: int = 10000, lineSearchCriterion: LineSearchCriterion = NoLineSearch): OptimOptions[U, StandardOptions] =
+    result.tol = tol
+    result.alpha = alpha
+    result.fastMode = fastMode
+    result.maxIterations = maxIterations
+    result.lineSearchCriterion = lineSearchCriterion
+
+proc bfgsOptions*[U](tol: U = U(1e-6), alpha: U = U(1), fastMode: bool = false, maxIterations: int = 10000, lineSearchCriterion: LineSearchCriterion = NoLineSearch): OptimOptions[U, StandardOptions] =
+    result.tol = tol
+    result.alpha = alpha
+    result.fastMode = fastMode
+    result.maxIterations = maxIterations
+    result.lineSearchCriterion = lineSearchCriterion
+
+proc lbfgsOptions*[U](savedIterations: int = 10, tol: U = U(1e-6), alpha: U = U(1), fastMode: bool = false, maxIterations: int = 10000, lineSearchCriterion: LineSearchCriterion = NoLineSearch): OptimOptions[U, LBFGSOptions[U]] =
+    result.tol = tol
+    result.alpha = alpha
+    result.fastMode = fastMode
+    result.maxIterations = maxIterations
+    result.lineSearchCriterion = lineSearchCriterion
+    result.algoOptions.savedIterations = savedIterations
+
+proc levmarqOptions*[U](lambda0: U = U(1), tol: U = U(1e-6), alpha: U = U(1), fastMode: bool = false, maxIterations: int = 10000, lineSearchCriterion: LineSearchCriterion = NoLineSearch): OptimOptions[U, LevMarqOptions[U]] =
+    result.tol = tol
+    result.alpha = alpha
+    result.fastMode = fastMode
+    result.maxIterations = maxIterations
+    result.lineSearchCriterion = lineSearchCriterion
+    result.algoOptions.lambda0 = lambda0
+
+
+
+
 proc vectorNorm*[T](v: Tensor[T]): T =
     ## Calculates the norm of the vector, ie the sqrt(Σ vᵢ²)
     assert v.rank == 1, "v must be a 1d vector!"
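
The commit replaces the single OptimOptions[U] type with a generic OptimOptions[U, AO] whose algoOptions: AO field carries algorithm-specific settings (StandardOptions is an empty marker, LevMarqOptions holds lambda0, LBFGSOptions holds savedIterations), plus one constructor per method with matching defaults. An illustrative sketch of the resulting API, not part of the commit (assumes import numericalnim):

    # Shared settings live directly on OptimOptions:
    let bOpts = bfgsOptions[float](tol = 1e-8, maxIterations = 1000)
    # Algorithm-specific settings live in the nested algoOptions field:
    let lOpts = lbfgsOptions[float](savedIterations = 5)
    echo lOpts.algoOptions.savedIterations   # -> 5
    let mOpts = levmarqOptions[float](lambda0 = 0.1)
    echo mOpts.algoOptions.lambda0           # -> 0.1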
@@ -193,11 +240,7 @@ proc line_search*[U, T](alpha: var U, p: Tensor[T], x0: Tensor[U], f: proc(x: Te
         return


-
-
-
-
-proc steepestDescent*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], options: OptimOptions[U] = optimOptions[U](alpha = U(0.001))): Tensor[U] =
+proc steepestDescent*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], options: OptimOptions[U, StandardOptions] = steepestDescentOptions[U]()): Tensor[U] =
     ## Minimize scalar-valued function f.
     var alpha = options.alpha
     var x = x0.clone()
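
Each solver's default options now come from its own constructor, so steepestDescent defaults to alpha = 0.001 without the caller spelling out optimOptions[U](alpha = U(0.001)). A hypothetical call with the new signature, not part of the commit (the names f and x0 are made up; assumes import arraymancer, numericalnim):

    proc f(x: Tensor[float]): float =
        # a simple convex bowl with its minimum at (1, 2)
        (x[0] - 1.0) * (x[0] - 1.0) + (x[1] - 2.0) * (x[1] - 2.0)

    let x0 = @[0.0, 0.0].toTensor
    let xMin = steepestDescent(f, x0,
        options = steepestDescentOptions[float](alpha = 0.01))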
@@ -219,7 +262,7 @@ proc steepestDescent*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U],
     #echo iters, " iterations done!"
     result = x

-proc newton*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], options: OptimOptions[U] = optimOptions[U]()): Tensor[U] =
+proc newton*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], options: OptimOptions[U, StandardOptions] = newtonOptions[U]()): Tensor[U] =
     var alpha = options.alpha
     var x = x0.clone()
     var fNorm = abs(f(x))
@@ -286,7 +329,7 @@ proc bfgs_old*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], alpha:
     #echo iters, " iterations done!"
     result = x

-proc bfgs*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], options: OptimOptions[U] = optimOptions[U]()): Tensor[U] =
+proc bfgs*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], options: OptimOptions[U, StandardOptions] = bfgsOptions[U]()): Tensor[U] =
     # Use gemm and gemv with preallocated Tensors and setting beta = 0
     var alpha = options.alpha
     var x = x0.clone()
@@ -366,7 +409,7 @@ proc bfgs*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], options: O
     #echo iters, " iterations done!"
     result = x

-proc lbfgs*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], m: int = 10, options: OptimOptions[U] = optimOptions[U]()): Tensor[U] =
+proc lbfgs*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], m: int = 10, options: OptimOptions[U, LBFGSOptions[U]] = lbfgsOptions[U]()): Tensor[U] =
     var alpha = options.alpha
     var x = x0.clone()
     let xLen = x.shape[0]
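
lbfgs is the one solver here whose payload type differs: its default now comes from lbfgsOptions, which fills algoOptions.savedIterations. A sketch under the same made-up f and x0 as above:

    let opts = lbfgsOptions[float](savedIterations = 5, tol = 1e-8)
    let xMin = lbfgs(f, x0, options = opts)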
@@ -420,7 +463,7 @@ proc lbfgs*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], m: int =
     #echo iters, " iterations done!"
     result = x

-proc levmarq*[U; T: not Tensor](f: proc(params: Tensor[U], x: U): T, params0: Tensor[U], xData: Tensor[U], yData: Tensor[T], alpha = U(1), tol: U = U(1e-6), lambda0: U = U(1), fastMode = false): Tensor[U] =
+proc levmarq*[U; T: not Tensor](f: proc(params: Tensor[U], x: U): T, params0: Tensor[U], xData: Tensor[U], yData: Tensor[T], options: OptimOptions[U, LevmarqOptions[U]] = levmarqOptions[U]()): Tensor[U] =
     assert xData.rank == 1
     assert yData.rank == 1
     assert params0.rank == 1
@@ -434,21 +477,26 @@ proc levmarq*[U; T: not Tensor](f: proc(params: Tensor[U], x: U): T, params0: Te
     result = map2_inline(xData, yData):
         f(params, x) - y

-    var lambdaCoeff = lambda0
+    let errorFunc = # proc that returns the scalar error
+        proc (params: Tensor[U]): T =
+            let res = residualFunc(params)
+            result = dot(res, res)
+
+    var lambdaCoeff = options.algoOptions.lambda0

     var params = params0.clone()
-    var gradient = tensorGradient(residualFunc, params, fastMode=fastMode)
+    var gradient = tensorGradient(residualFunc, params, fastMode=options.fastMode)
     var residuals = residualFunc(params)
     var resNorm = vectorNorm(residuals)
     var gradNorm = vectorNorm(squeeze(gradient * residuals.reshape(xLen, 1)))
     var iters: int
     let eyeNN = eye[T](paramsLen)
-    while gradNorm > tol*(1 + resNorm) and iters < 10000:
+    while gradNorm > options.tol*(1 + resNorm) and iters < 10000:
         let rhs = -gradient * residuals.reshape(xLen, 1)
         let lhs = gradient * gradient.transpose + lambdaCoeff * eyeNN
         let p = solve(lhs, rhs)
-        params += p * alpha
-        gradient = tensorGradient(residualFunc, params, fastMode=fastMode)
+        params += p * options.alpha
+        gradient = tensorGradient(residualFunc, params, fastMode=options.fastMode)
         residuals = residualFunc(params)
         let newGradNorm = vectorNorm(squeeze(gradient * residuals.reshape(xLen, 1)))
         if newGradNorm / gradNorm < 0.9: # we have improved, decrease lambda → more Gauss-Newton
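
levmarq's former loose parameters (alpha, tol, lambda0, fastMode) now all travel in the options object, with lambda0 read from algoOptions. A hypothetical straight-line fit with the new signature (model and the data points are made up for illustration):

    proc model(params: Tensor[float], x: float): float =
        params[0] + params[1] * x   # intercept + slope

    let xData = @[0.0, 1.0, 2.0, 3.0].toTensor
    let yData = @[1.0, 3.1, 4.9, 7.0].toTensor
    let p0 = @[0.0, 0.0].toTensor
    let fit = levmarq(model, p0, xData, yData,
        options = levmarqOptions[float](lambda0 = 0.5))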

0 commit comments