Skip to content

Commit ff9b306

Browse files
authored
Merge pull request #7 from tbeason/optim
remove optim dep, borrow functionality, update deps
2 parents 3c4633d + 0a01e6e commit ff9b306

File tree

3 files changed

+81
-10
lines changed

3 files changed

+81
-10
lines changed

Project.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
name = "NonparametricRegression"
22
uuid = "db432338-e110-4b7a-9c53-0ace38eb8f7f"
33
authors = ["Tyler Beason <tbeas12@gmail.com>"]
4-
version = "0.2.0"
4+
version = "0.2.1"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9-
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
109
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1110
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1211

1312
[compat]
1413
DocStringExtensions = "0.8, 0.9"
15-
Optim = "1"
16-
StaticArrays = "1.2, 1.3"
14+
StaticArrays = "1.2, 1.3, 1.4"
1715
julia = "1.6"
1816

1917
[extras]

src/NonparametricRegression.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ module NonparametricRegression
88
using LinearAlgebra
99
using Statistics
1010
using StaticArrays
11-
using Optim
1211
using DocStringExtensions
1312

1413

@@ -19,6 +18,7 @@ export optimalbandwidth, leaveoneoutCV, optimizeAICc
1918
export npregress
2019

2120
include("kernels.jl")
21+
include("univariateopt.jl")
2222

2323

2424
########################################
@@ -319,8 +319,8 @@ end
319319
function leaveoneoutCV(x, y; kernelfun=NormalKernel, method=:lc, hLB = silvermanbw(y)/100, hUB = silvermanbw(y)*100,trimmed=true)
320320
objfun(h) = leaveoneoutCV_mse(h,x,y; kernelfun, method, trimmed)
321321
opt = optimize(objfun,hLB,hUB)
322-
@assert opt.converged "Convergence failed, cannot find optimal bandwidth."
323-
return Optim.minimizer(opt)
322+
# @assert opt.converged "Convergence failed, cannot find optimal bandwidth."
323+
return opt
324324
end
325325

326326

@@ -337,8 +337,8 @@ end
337337
function optimizeAICc(x, y; kernelfun=NormalKernel, method=:lc, hLB = silvermanbw(y)/100, hUB = silvermanbw(y)*100)
338338
objfun(h) = estimatorAICc(h,x,y; kernelfun, method)
339339
opt = optimize(objfun,hLB,hUB)
340-
@assert opt.converged "Convergence failed, cannot find optimal bandwidth."
341-
return Optim.minimizer(opt)
340+
# @assert opt.converged "Convergence failed, cannot find optimal bandwidth."
341+
return opt
342342
end
343343

344344

@@ -402,7 +402,7 @@ The keyword argument `method` can be `:lc` for a local constant estimator (Nadar
402402
403403
The keyword argument `bandwidthselection` should be `:aicc` for the bias-correct AICc method or `:loocv` for leave-one-out cross validation. `hLB` and `hUB` are the lower and upper bounds used when searching for the optimal bandwidth.
404404
405-
The keyword argument `kernelfun` should be a function that constructs a `KernelFunctions.jl` kernel with a given bandwidth. Defaults to `NormalKernel` (defined and exported by this package.)
405+
The keyword argument `kernelfun` should be a function that constructs a kernel with a given bandwidth. Defaults to `NormalKernel` (defined and exported by this package.)
406406
"""
407407
function npregress(x, y, xgrid=x; kernelfun=NormalKernel, method=:lc, bandwidthselection=:aicc, hLB = silvermanbw(y)/100, hUB = silvermanbw(y)*100)
408408
h = optimalbandwidth(x,y; kernelfun, method, bandwidthselection, hLB, hUB)

src/univariateopt.jl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
2+
# borrowed from KernelDensity.jl
3+
# https://github.com/JuliaStats/KernelDensity.jl/blob/master/src/univariate.jl
4+
5+
"""
6+
optimize(f, x_lower, x_upper; iterations=1000, rel_tol=nothing, abs_tol=nothing)
7+
8+
Minimize the function `f` in the interval `x_lower..x_upper`, using the
9+
[golden-section search](https://en.wikipedia.org/wiki/Golden-section_search).
10+
Return an approximate minimum `x̃` or error if such approximate minimum cannot be found.
11+
12+
This algorithm assumes that `-f` is unimodal on the interval `x_lower..x_upper`,
13+
that is to say, there exists a unique `x` in `x_lower..x_upper` such that `f` is
14+
decreasing on `x_lower..x` and increasing on `x..x_upper`.
15+
16+
`rel_tol` and `abs_tol` determine the relative and absolute tolerance, that is
17+
to say, the returned value `x̃` should differ from the actual minimum `x` at most
18+
`abs_tol + rel_tol * abs(x̃)`.
19+
If not manually specified, `rel_tol` and `abs_tol` default to `sqrt(eps(T))` and
20+
`eps(T)` respectively, where `T` is the floating point type of `x_lower` and `x_upper`.
21+
22+
`iterations` determines the maximum number of iterations allowed before convergence.
23+
24+
This is a private, unexported function, used internally to select the optimal bandwidth
25+
automatically.
26+
"""
27+
function optimize(f, x_lower, x_upper; iterations=1000, rel_tol=nothing, abs_tol=nothing)
28+
29+
if x_lower > x_upper
30+
error("x_lower must be less than x_upper")
31+
end
32+
33+
T = promote_type(typeof(x_lower/1), typeof(x_upper/1))
34+
rtol = something(rel_tol, sqrt(eps(T)))
35+
atol = something(abs_tol, eps(T))
36+
37+
function midpoint_and_convergence(lower, upper)
38+
midpoint = (lower + upper) / 2
39+
tol = atol + rtol * midpoint
40+
midpoint, (upper - lower) <= 2tol
41+
end
42+
43+
invphi::T = 0.5 * (sqrt(5) - 1)
44+
invphisq::T = 0.5 * (3 - sqrt(5))
45+
46+
a::T, b::T = x_lower, x_upper
47+
h = b - a
48+
c = a + invphisq * h
49+
d = a + invphi * h
50+
51+
fc, fd = f(c), f(d)
52+
53+
for _ in 1:iterations
54+
h *= invphi
55+
if fc < fd
56+
m, converged = midpoint_and_convergence(a, d)
57+
converged && return m
58+
b = d
59+
d, fd = c, fc
60+
c = a + invphisq * h
61+
fc = f(c)
62+
else
63+
m, converged = midpoint_and_convergence(c, b)
64+
converged && return m
65+
a = c
66+
c, fc = d, fd
67+
d = a + invphi * h
68+
fd = f(d)
69+
end
70+
end
71+
72+
error("Reached maximum number of iterations without convergence.")
73+
end

0 commit comments

Comments
 (0)