Skip to content

Commit e861492

Browse files
authored
Merge pull request #24 from leburgel/lb/optimtest_tweak
Modify `optimtest` to perform less function evaluations
2 parents e476e82 + 61a67b0 commit e861492

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

src/OptimKit.jl

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -129,22 +129,24 @@ Test the compatibility between the computation of the gradient, the retraction a
129129
It is up to the user to check that the values in `dfs1` and `dfs2` match up to expected precision, by inspecting the numerical values or plotting them. If these values don't match, the linesearch in `optimize` cannot be expected to work.
130130
"""
131131
function optimtest(fg, x, d=fg(x)[2]; alpha=-0.1:0.001:0.1, retract=_retract, inner=_inner)
132-
f0, g0 = fg(x)
133-
fs = Vector{typeof(f0)}(undef, length(alpha) - 1)
134-
dfs1 = similar(fs, length(alpha) - 1)
135-
dfs2 = similar(fs, length(alpha) - 1)
136-
for i in 1:(length(alpha) - 1)
137-
a1 = alpha[i]
138-
a2 = alpha[i + 1]
139-
f1, = fg(retract(x, d, a1)[1])
140-
f2, = fg(retract(x, d, a2)[1])
141-
dfs1[i] = (f2 - f1) / (a2 - a1)
142-
xmid, dmid = retract(x, d, (a1 + a2) / 2)
132+
# evaluate function at given edge points
133+
fs_edges = map(alpha) do a
134+
f, = fg(retract(x, d, a)[1])
135+
return f
136+
end
137+
a1s = alpha[1:(end - 1)]
138+
a2s = alpha[2:end]
139+
dfs1 = (fs_edges[2:end] .- fs_edges[1:(end - 1)]) ./ (a2s .- a1s)
140+
# evaluate function and gradient at midpoints
141+
alphas = collect((a1s + a2s) / 2)
142+
fs_dfs = map(alphas) do a
143+
xmid, dmid = retract(x, d, a)
143144
fmid, gmid = fg(xmid)
144-
fs[i] = fmid
145-
dfs2[i] = inner(xmid, dmid, gmid)
145+
df = inner(xmid, dmid, gmid)
146+
return fmid, df
146147
end
147-
alphas = collect((alpha[2:end] + alpha[1:(end - 1)]) / 2)
148+
fs = first.(fs_dfs)
149+
dfs2 = last.(fs_dfs)
148150
return alphas, fs, dfs1, dfs2
149151
end
150152

0 commit comments

Comments
 (0)