Skip to content

Commit e476e82

Browse files
committed
final fixes
1 parent d4bcc8c commit e476e82

File tree

5 files changed

+45
-59
lines changed

5 files changed

+45
-59
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "0.4.0"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
99
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
10+
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
1011

1112
[compat]
1213
Aqua = "0.8"
@@ -15,6 +16,7 @@ Printf = "1"
1516
Random = "1"
1617
ScopedValues = "1"
1718
Test = "1"
19+
VectorInterface = "0.5"
1820
julia = "1.8"
1921

2022
[extras]

src/OptimKit.jl

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module OptimKit
33
using LinearAlgebra: LinearAlgebra
44
using Printf
55
using ScopedValues
6+
using VectorInterface
67
using Base: @kwdef
78

89
# Default values for the keyword arguments using ScopedValues
@@ -14,15 +15,34 @@ const GRADTOL = ScopedValue(1e-8)
1415
const MAXITER = ScopedValue(1_000_000)
1516
const VERBOSITY = ScopedValue(1)
1617

17-
_retract(x, d, α) = (x + α * d, d)
18-
_inner(x, v1, v2) = v1 === v2 ? LinearAlgebra.norm(v1)^2 : LinearAlgebra.dot(v1, v2)
18+
# Default values for the manifold structure
19+
_retract(x, d, α) = (add(x, d, α), d)
20+
_inner(x, v1, v2) = v1 === v2 ? norm(v1)^2 : real(inner(v1, v2))
1921
_transport!(v, xold, d, α, xnew) = v
20-
_scale!(v, α) = LinearAlgebra.rmul!(v, α)
21-
_add!(vdst, vsrc, α) = LinearAlgebra.axpy!(α, vsrc, vdst)
22+
_scale!(v, α) = scale!!(v, α)
23+
_add!(vdst, vsrc, α) = add!!(vdst, vsrc, α)
2224

2325
_precondition(x, g) = deepcopy(g)
2426
_finalize!(x, f, g, numiter) = x, f, g
2527

28+
# Default structs for new convergence and termination keywords
29+
@kwdef struct DefaultHasConverged{T<:Real}
30+
gradtol::T
31+
end
32+
33+
function (d::DefaultHasConverged)(x, f, g, normgrad)
34+
return normgrad <= d.gradtol
35+
end
36+
37+
@kwdef struct DefaultShouldStop
38+
maxiter::Int
39+
end
40+
41+
function (d::DefaultShouldStop)(x, f, g, numfg, numiter, t)
42+
return numiter >= d.maxiter
43+
end
44+
45+
# Optimization
2646
abstract type OptimizationAlgorithm end
2747

2848
const _xlast = Ref{Any}()
@@ -85,7 +105,6 @@ Also see [`GradientDescent`](@ref), [`ConjugateGradient`](@ref), [`LBFGS`](@ref)
85105
function optimize end
86106

87107
include("linesearches.jl")
88-
include("terminate.jl")
89108
include("gd.jl")
90109
include("cg.jl")
91110
include("lbfgs.jl")

src/tangentvector.jl

Lines changed: 0 additions & 34 deletions
This file was deleted.

src/terminate.jl

Lines changed: 0 additions & 15 deletions
This file was deleted.

test/runtests.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,30 @@ function quadraticproblem(B, y)
4949
return fg
5050
end
5151

52+
function quadratictupleproblem(B, y)
53+
function fg(x)
54+
x1, x2 = x
55+
y1, y2 = y
56+
g1 = B * (x1 - y1)
57+
g2 = x2 - y2
58+
f = dot(x1 - y1, g1) / 2 + (x2 - y2)^2 / 2
59+
return f, (g1, g2)
60+
end
61+
return fg
62+
end
63+
5264
algorithms = (GradientDescent, ConjugateGradient, LBFGS)
5365

5466
@testset "Optimization Algorithm $algtype" for algtype in algorithms
5567
n = 10
5668
y = randn(n)
5769
A = randn(n, n)
58-
fg = quadraticproblem(A' * A, y)
70+
A = A' * A
71+
fg = quadraticproblem(A, y)
5972
x₀ = randn(n)
6073
alg = algtype(; verbosity=2, gradtol=1e-12, maxiter=10_000_000)
6174
x, f, g, numfg, normgradhistory = optimize(fg, x₀, alg)
62-
@test x ≈ y rtol = 10 * cond(A) * 1e-12
75+
@test x ≈ y rtol = cond(A) * 1e-12
6376
@test f < 1e-12
6477

6578
n = 1000
@@ -68,11 +81,12 @@ algorithms = (GradientDescent, ConjugateGradient, LBFGS)
6881
smax = maximum(S)
6982
A = U * Diagonal(1 .+ S ./ smax) * U'
7083
# well conditioned, all eigenvalues between 1 and 2
71-
fg = quadraticproblem(A' * A, y)
72-
x₀ = randn(n)
84+
fg = quadratictupleproblem(A' * A, (y, 1.0))
85+
x₀ = (randn(n), 2.0)
7386
alg = algtype(; verbosity=3, gradtol=1e-8)
7487
x, f, g, numfg, normgradhistory = optimize(fg, x₀, alg)
75-
@test x ≈ y rtol = 1e-7
88+
@test x[1] ≈ y rtol = 1e-7
89+
@test x[2] ≈ 1 rtol = 1e-7
7690
@test f < 1e-12
7791
end
7892

0 commit comments

Comments
 (0)