Commit 5ec1dd2: allow supplying the analytic gradient
1 parent 147d7a4

File tree: 2 files changed, +46 −13 lines

src/numericalnim/optimize.nim (18 additions, 13 deletions)

```diff
@@ -239,13 +239,18 @@ proc line_search*[U, T](alpha: var U, p: Tensor[T], x0: Tensor[U], f: proc(x: Te
     alpha = 1e2
     return
 
+template analyticOrNumericGradient(analytic, f, x, options: untyped): untyped =
+  if analytic.isNil:
+    tensorGradient(f, x, fastMode=options.fastMode)
+  else:
+    analytic(x)
 
-proc steepestDescent*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], options: OptimOptions[U, StandardOptions] = steepestDescentOptions[U]()): Tensor[U] =
+proc steepestDescent*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], options: OptimOptions[U, StandardOptions] = steepestDescentOptions[U](), analyticGradient: proc(x: Tensor[U]): Tensor[T] = nil): Tensor[U] =
   ## Minimize scalar-valued function f.
   var alpha = options.alpha
   var x = x0.clone()
   var fNorm = abs(f(x0))
-  var gradient = tensorGradient(f, x0, fastMode=options.fastMode)
+  var gradient = analyticOrNumericGradient(analyticGradient, f, x0, options)
   var gradNorm = vectorNorm(gradient)
   var iters: int
   while gradNorm > options.tol*(1 + fNorm) and iters < 10000:
```
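The new `analyticOrNumericGradient` template carries the whole feature: if the caller supplies a gradient proc it is used directly; otherwise the code falls back to numerical differentiation via `tensorGradient`, exactly as before. A minimal self-contained sketch of the same dispatch pattern, using plain `seq[float]` and a central-difference fallback (these helper names are illustrative, not numericalnim's API):

```nim
type GradFn = proc (x: seq[float]): seq[float]

proc numericGradient(f: proc (x: seq[float]): float, x: seq[float], h = 1e-6): seq[float] =
  ## Central-difference approximation, standing in for tensorGradient.
  result = newSeq[float](x.len)
  for i in 0 ..< x.len:
    var xp = x   # seqs copy on assignment
    var xm = x
    xp[i] += h
    xm[i] -= h
    result[i] = (f(xp) - f(xm)) / (2 * h)

proc gradientOf(f: proc (x: seq[float]): float, x: seq[float],
                analytic: GradFn = nil): seq[float] =
  ## Same nil-check dispatch the template performs.
  if analytic.isNil: numericGradient(f, x)
  else: analytic(x)
```

Because it is a template, the dispatch expands inline at each call site; the `isNil` check happens at runtime, so passing `nil` keeps the old numerical behaviour unchanged.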
```diff
@@ -254,19 +259,19 @@ proc steepestDescent*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U],
     x += alpha * p
     let fx = f(x)
     fNorm = abs(fx)
-    gradient = tensorGradient(f, x, fastMode=options.fastMode)
+    gradient = analyticOrNumericGradient(analyticGradient, f, x, options)
     gradNorm = vectorNorm(gradient)
     iters += 1
   if iters >= 10000:
     discard "Limit of 10000 iterations reached!"
   #echo iters, " iterations done!"
   result = x
 
-proc newton*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], options: OptimOptions[U, StandardOptions] = newtonOptions[U]()): Tensor[U] =
+proc newton*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], options: OptimOptions[U, StandardOptions] = newtonOptions[U](), analyticGradient: proc(x: Tensor[U]): Tensor[T] = nil): Tensor[U] =
   var alpha = options.alpha
   var x = x0.clone()
   var fNorm = abs(f(x))
-  var gradient = tensorGradient(f, x, fastMode=options.fastMode)
+  var gradient = analyticOrNumericGradient(analyticGradient, f, x0, options)
   var gradNorm = vectorNorm(gradient)
   var hessian = tensorHessian(f, x)
   var iters: int
```
```diff
@@ -276,7 +281,7 @@ proc newton*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], options:
     x += alpha * p
     let fx = f(x)
     fNorm = abs(fx)
-    gradient = tensorGradient(f, x, fastMode=options.fastMode)
+    gradient = analyticOrNumericGradient(analyticGradient, f, x, options)
     gradNorm = vectorNorm(gradient)
     hessian = tensorHessian(f, x)
     iters += 1
```
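For context, only the gradient can be supplied analytically here; `newton` still obtains the Hessian numerically via `tensorHessian`. Each iteration solves the standard Newton system for the search direction before the line search scales the step:

$$p_k = -H(x_k)^{-1} \nabla f(x_k), \qquad x_{k+1} = x_k + \alpha_k p_k$$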
```diff
@@ -285,7 +290,7 @@ proc newton*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], options:
   #echo iters, " iterations done!"
   result = x
 
-proc bfgs_old*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], alpha: U = U(1), tol: U = U(1e-6), fastMode: bool = false): Tensor[U] =
+proc bfgs_old*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], alpha: U = U(1), tol: U = U(1e-6), fastMode: bool = false, analyticGradient: proc(x: Tensor[U]): Tensor[T] = nil): Tensor[U] =
   var x = x0.clone()
   let xLen = x.shape[0]
   var fNorm = abs(f(x))
```
```diff
@@ -329,13 +334,13 @@ proc bfgs_old*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], alpha:
   #echo iters, " iterations done!"
   result = x
 
-proc bfgs*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], options: OptimOptions[U, StandardOptions] = bfgsOptions[U]()): Tensor[U] =
+proc bfgs*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], options: OptimOptions[U, StandardOptions] = bfgsOptions[U](), analyticGradient: proc(x: Tensor[U]): Tensor[T] = nil): Tensor[U] =
   # Use gemm and gemv with preallocated Tensors and setting beta = 0
   var alpha = options.alpha
   var x = x0.clone()
   let xLen = x.shape[0]
   var fNorm = abs(f(x))
-  var gradient = 0.01*tensorGradient(f, x, fastMode=options.fastMode)
+  var gradient = 0.01*analyticOrNumericGradient(analyticGradient, f, x0, options)
   var gradNorm = vectorNorm(gradient)
   var hessianB = eye[T](xLen) # inverse of the approximated hessian
   var p = newTensor[U](xLen)
```
```diff
@@ -354,7 +359,7 @@ proc bfgs*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], options: O
     #echo "gradient iter ", iters, ": ", gradient
     line_search(alpha, p, x, f, options.lineSearchCriterion, options.fastMode)
     x += alpha * p
-    let newGradient = tensorGradient(f, x, fastMode=options.fastMode)
+    let newGradient = analyticOrNumericGradient(analyticGradient, f, x, options)
     let sk = alpha * p.reshape(xLen, 1)
 
     let yk = (newGradient - gradient).reshape(xLen, 1)
```
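Here `sk` is the step $s_k = \alpha_k p_k$ and `yk` the gradient change $y_k = \nabla f_{k+1} - \nabla f_k$; they feed the standard BFGS update of the inverse Hessian approximation stored in `hessianB`, which for reference is

$$B_{k+1}^{-1} = \left(I - \rho_k s_k y_k^\top\right) B_k^{-1} \left(I - \rho_k y_k s_k^\top\right) + \rho_k s_k s_k^\top, \qquad \rho_k = \frac{1}{y_k^\top s_k}.$$

An analytic gradient typically pays off doubly in this loop: it is both cheaper and more accurate than the finite-difference estimate that would otherwise enter $y_k$.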
```diff
@@ -409,12 +414,12 @@ proc bfgs*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], options: O
   #echo iters, " iterations done!"
   result = x
 
-proc lbfgs*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], m: int = 10, options: OptimOptions[U, LBFGSOptions[U]] = lbfgsOptions[U]()): Tensor[U] =
+proc lbfgs*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], m: int = 10, options: OptimOptions[U, LBFGSOptions[U]] = lbfgsOptions[U](), analyticGradient: proc(x: Tensor[U]): Tensor[T] = nil): Tensor[U] =
   var alpha = options.alpha
   var x = x0.clone()
   let xLen = x.shape[0]
   var fNorm = abs(f(x))
-  var gradient = 0.01*tensorGradient(f, x, fastMode=options.fastMode)
+  var gradient = 0.01*analyticOrNumericGradient(analyticGradient, f, x0, options)
   var gradNorm = vectorNorm(gradient)
   var iters: int
   #let m = 10 # number of past iterations to save
```
```diff
@@ -447,7 +452,7 @@ proc lbfgs*[U; T: not Tensor](f: proc(x: Tensor[U]): T, x0: Tensor[U], m: int =
     line_search(alpha, p, x, f, options.lineSearchCriterion, options.fastMode)
     x += alpha * p
     sk_queue.addFirst alpha*p
-    let newGradient = tensorGradient(f, x, fastMode=options.fastMode)
+    let newGradient = analyticOrNumericGradient(analyticGradient, f, x, options)
     let yk = newGradient - gradient
     yk_queue.addFirst yk
     gradient = newGradient
```
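With that, all five optimizers (`steepestDescent`, `newton`, `bfgs_old`, `bfgs`, `lbfgs`) accept an optional `analyticGradient` proc that defaults to `nil`, so existing callers are unaffected. A usage sketch based on the new signatures; the quadratic objective and its gradient below are made up for illustration:

```nim
import arraymancer
import numericalnim

proc f(x: Tensor[float]): float =
  ## Quadratic bowl with minimum at the origin.
  x[0]*x[0] + 4.0*x[1]*x[1]

proc fGrad(x: Tensor[float]): Tensor[float] =
  ## Hand-derived gradient of f.
  [2.0*x[0], 8.0*x[1]].toTensor()

let x0 = [3.0, -2.0].toTensor()
echo bfgs(f, x0)                            # numerical gradient, as before
echo bfgs(f, x0, analyticGradient = fGrad)  # skips tensorGradient entirely
```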

tests/test_optimize.nim (28 additions, 0 deletions)
```diff
@@ -70,30 +70,58 @@ suite "Multi-dim":
     ## Function of 2 variables with minimum at (1, 1)
     ## And it looks like a banana 🍌
     result = (1 - x[0])^2 + 100*(x[1] - x[0]^2)^2
+
+  proc bananaBend(x: Tensor[float]): Tensor[float] =
+    ## Calculates the gradient of the banana function
+    result = newTensor[float](2)
+    result[0] = -2 * (1 - x[0]) + 100 * 2 * (x[1] - x[0]*x[0]) * -2 * x[0]
+    result[1] = 100 * 2 * (x[1] - x[0]*x[0])
 
   let x0 = [-1.0, -1.0].toTensor()
   let correct = [1.0, 1.0].toTensor()
 
+  doAssert checkGradient(bananaFunc, bananaBend, x0, 1e-6), "Analytic gradient is wrong in test!"
+
   test "Steepest Gradient":
     let xSol = steepestDescent(bananaFunc, x0.clone)
     for x in abs(correct - xSol):
       check x < 2e-2
+
+  test "Steepest Gradient analytic":
+    let xSol = steepestDescent(bananaFunc, x0.clone, analyticGradient=bananaBend)
+    for x in abs(correct - xSol):
+      check x < 2e-2
 
   test "Newton":
     let xSol = newton(bananaFunc, x0.clone)
     for x in abs(correct - xSol):
       check x < 3e-10
 
+  test "Newton analytic":
+    let xSol = newton(bananaFunc, x0.clone, analyticGradient=bananaBend)
+    for x in abs(correct - xSol):
+      check x < 3e-10
+
   test "BFGS":
     let xSol = bfgs(bananaFunc, x0.clone)
     for x in abs(correct - xSol):
       check x < 3e-7
 
+  test "BFGS analytic":
+    let xSol = bfgs(bananaFunc, x0.clone, analyticGradient=bananaBend)
+    for x in abs(correct - xSol):
+      check x < 3e-7
+
   test "L-BFGS":
     let xSol = lbfgs(bananaFunc, x0.clone)
     for x in abs(correct - xSol):
       check x < 7e-10
 
+  test "L-BFGS analytic":
+    let xSol = lbfgs(bananaFunc, x0.clone, analyticGradient=bananaBend)
+    for x in abs(correct - xSol):
+      check x < 7e-10
+
   let correctParams = [10.4, -0.45].toTensor()
   proc fitFunc(params: Tensor[float], x: float): float =
     params[0] * exp(params[1] * x)
```
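For reference, the analytic gradient implemented by `bananaBend` and verified with `checkGradient` above follows from differentiating the banana (Rosenbrock) function:

$$f(x_0, x_1) = (1 - x_0)^2 + 100\,(x_1 - x_0^2)^2$$

$$\frac{\partial f}{\partial x_0} = -2\,(1 - x_0) - 400\,x_0\,(x_1 - x_0^2), \qquad \frac{\partial f}{\partial x_1} = 200\,(x_1 - x_0^2)$$

Both partial derivatives vanish only at $(1, 1)$, the minimum all of the tests converge to.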
