Skip to content

Commit 147d7a4

Browse files
committed
checkGradient
1 parent 49f8050 commit 147d7a4

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

src/numericalnim/differentiate.nim

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import std/strformat
12
import arraymancer
23

34
proc diff1dForward*[U, T](f: proc(x: U): T, x0: U, h: U = U(1e-6)): T =
@@ -139,9 +140,25 @@ proc tensorHessian*[U; T: not Tensor](
139140
result[i, j] = mixed
140141
result[j, i] = mixed
141142

142-
143-
144-
143+
proc checkGradient*[U; T: not Tensor](f: proc(x: Tensor[U]): T, fGrad: proc(x: Tensor[U]): Tensor[T], x0: Tensor[U], tol: T): bool =
144+
## Checks if the provided gradient function `fGrad` gives the same values as numeric gradient.
145+
let numGrad = tensorGradient(f, x0)
146+
let grad = fGrad(x0)
147+
result = true
148+
for i, x in abs(numGrad - grad):
149+
if x > tol:
150+
echo fmt"Gradient at index {i[0]} has error: {x} (tol = {tol})"
151+
result = false
152+
153+
proc checkGradient*[U; T](f: proc(x: Tensor[U]): Tensor[T], fGrad: proc(x: Tensor[U]): Tensor[T], x0: Tensor[U], tol: T): bool =
154+
## Checks if the provided gradient function `fGrad` gives the same values as numeric gradient.
155+
let numGrad = tensorGradient(f, x0)
156+
let grad = fGrad(x0)
157+
result = true
158+
for i, x in abs(numGrad - grad):
159+
if x > tol:
160+
echo fmt"Gradient at index {i[0]} has error: {x} (tol = {tol})"
161+
result = false
145162

146163

147164

tests/test_differentiate.nim

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,10 @@ suite "Multi dimensional numeric gradients":
140140
for err in abs(numJacobian - exact):
141141
check err < 1e-10
142142

143+
test "checkGradient":
144+
check checkGradient(fScalar, scalarGradient, [0.5, 0.5, 0.5].toTensor, 6e-11)
145+
check checkGradient(fMultidim, multidimGradient, [0.5, 0.5, 0.5].toTensor, 4e-12)
146+
143147
test "Hessian scalar valued function":
144148
for x in numericalnim.linspace(0, 1, 10):
145149
for y in numericalnim.linspace(0, 1, 10):

0 commit comments

Comments
 (0)