Skip to content

Commit 39da4b5

Browse files
committed
Use Val(fdtype)
1 parent b8178f1 commit 39da4b5

File tree

6 files changed

+95
-85
lines changed

6 files changed

+95
-85
lines changed

src/derivatives.jl

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,20 @@ Single-point derivatives of scalar->scalar maps.
44
function finite_difference_derivative(
55
f,
66
x::T,
7-
fdtype=Val{:central},
7+
fdtype=Val(:central),
88
returntype=eltype(x),
99
f_x=nothing;
1010
relstep=default_relstep(fdtype, T),
1111
absstep=relstep,
1212
dir=true) where {T<:Number}
1313

14+
fdtype isa Type && (fdtype = fdtype())
1415
epsilon = compute_epsilon(fdtype, x, relstep, absstep, dir)
15-
if fdtype==Val{:forward}
16+
if fdtype==Val(:forward)
1617
return (f(x+epsilon) - f(x)) / epsilon
17-
elseif fdtype==Val{:central}
18+
elseif fdtype==Val(:central)
1819
return (f(x+epsilon) - f(x-epsilon)) / (2*epsilon)
19-
elseif fdtype==Val{:complex} && returntype<:Real
20+
elseif fdtype==Val(:complex) && returntype<:Real
2021
return imag(f(x+im*epsilon)) / epsilon
2122
end
2223
fdtype_error(returntype)
@@ -36,15 +37,16 @@ function DerivativeCache(
3637
x :: AbstractArray{<:Number},
3738
fx :: Union{Nothing,AbstractArray{<:Number}} = nothing,
3839
epsilon :: Union{Nothing,AbstractArray{<:Real}} = nothing,
39-
fdtype :: Type{T1} = Val{:central},
40+
fdtype :: Type{T1} = Val(:central),
4041
returntype :: Type{T2} = eltype(x)) where {T1,T2}
4142

42-
if fdtype==Val{:complex} && !(eltype(returntype)<:Real)
43+
fdtype isa Type && (fdtype = fdtype())
44+
if fdtype==Val(:complex) && !(eltype(returntype)<:Real)
4345
fdtype_error(returntype)
4446
end
4547

46-
if fdtype!=Val{:forward} && typeof(fx)!=Nothing
47-
@warn("Pre-computed function values are only useful for fdtype==Val{:forward}.")
48+
if fdtype!=Val(:forward) && typeof(fx)!=Nothing
49+
@warn("Pre-computed function values are only useful for fdtype==Val(:forward).")
4850
_fx = nothing
4951
else
5052
# more runtime sanity checks?
@@ -54,8 +56,8 @@ function DerivativeCache(
5456
if typeof(epsilon)!=Nothing && typeof(x)<:StridedArray && typeof(fx)<:Union{Nothing,StridedArray} && 1==2
5557
@warn("StridedArrays don't benefit from pre-allocating epsilon.")
5658
_epsilon = nothing
57-
elseif typeof(epsilon)!=Nothing && fdtype==Val{:complex}
58-
@warn("Val{:complex} makes the epsilon array redundant.")
59+
elseif typeof(epsilon)!=Nothing && fdtype==Val(:complex)
60+
@warn("Val(:complex) makes the epsilon array redundant.")
5961
_epsilon = nothing
6062
else
6163
if typeof(epsilon)==Nothing || eltype(epsilon)!=real(eltype(x))
@@ -72,7 +74,7 @@ Compute the derivative df of a scalar-valued map f at a collection of points x.
7274
function finite_difference_derivative(
7375
f,
7476
x,
75-
fdtype = Val{:central},
77+
fdtype = Val(:central),
7678
returntype = eltype(x), # return type of f
7779
fx = nothing,
7880
epsilon = nothing;
@@ -87,7 +89,7 @@ function finite_difference_derivative!(
8789
df,
8890
f,
8991
x,
90-
fdtype = Val{:central},
92+
fdtype = Val(:central),
9193
returntype = eltype(x),
9294
fx = nothing,
9395
epsilon = nothing;
@@ -111,15 +113,15 @@ function finite_difference_derivative!(
111113
if typeof(epsilon) != Nothing
112114
@. epsilon = compute_epsilon(fdtype, x, relstep, absstep, dir)
113115
end
114-
if fdtype == Val{:forward}
116+
if fdtype == Val(:forward)
115117
if typeof(fx) == Nothing
116118
@. df = (f(x+epsilon) - f(x)) / epsilon
117119
else
118120
@. df = (f(x+epsilon) - fx) / epsilon
119121
end
120-
elseif fdtype == Val{:central}
122+
elseif fdtype == Val(:central)
121123
@. df = (f(x+epsilon) - f(x-epsilon)) / (2 * epsilon)
122-
elseif fdtype == Val{:complex} && returntype<:Real
124+
elseif fdtype == Val(:complex) && returntype<:Real
123125
epsilon_complex = eps(eltype(x))
124126
@. df = imag(f(x+im*epsilon_complex)) / epsilon_complex
125127
else
@@ -142,25 +144,25 @@ function finite_difference_derivative!(
142144
absstep=relstep,
143145
dir=true) where {T1,T2,fdtype,returntype}
144146

145-
if fdtype == Val{:forward}
147+
if fdtype == Val(:forward)
146148
fx = cache.fx
147149
@inbounds for i eachindex(x)
148-
epsilon = compute_epsilon(Val{:forward}, x[i], relstep, absstep, dir)
150+
epsilon = compute_epsilon(Val(:forward), x[i], relstep, absstep, dir)
149151
x_plus = x[i] + epsilon
150152
if typeof(fx) == Nothing
151153
df[i] = (f(x_plus) - f(x[i])) / epsilon
152154
else
153155
df[i] = (f(x_plus) - fx[i]) / epsilon
154156
end
155157
end
156-
elseif fdtype == Val{:central}
158+
elseif fdtype == Val(:central)
157159
@inbounds for i eachindex(x)
158-
epsilon = compute_epsilon(Val{:central}, x[i], relstep, absstep, dir)
160+
epsilon = compute_epsilon(Val(:central), x[i], relstep, absstep, dir)
159161
epsilon_double_inv = one(typeof(epsilon)) / (2*epsilon)
160162
x_plus, x_minus = x[i]+epsilon, x[i]-epsilon
161163
df[i] = (f(x_plus) - f(x_minus)) * epsilon_double_inv
162164
end
163-
elseif fdtype == Val{:complex}
165+
elseif fdtype == Val(:complex)
164166
epsilon_complex = eps(eltype(x))
165167
@inbounds for i eachindex(x)
166168
df[i] = imag(f(x[i]+im*epsilon_complex)) / epsilon_complex

src/epsilons.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,39 @@ Very heavily inspired by Calculus.jl, but with an emphasis on performance and Di
66
Compute the finite difference interval epsilon.
77
Reference: Numerical Recipes, chapter 5.7.
88
=#
9-
@inline function compute_epsilon(::Type{Val{:forward}}, x::T, relstep::Real, absstep::Real, dir::Real) where T<:Number
9+
@inline function compute_epsilon(::Val{:forward}, x::T, relstep::Real, absstep::Real, dir::Real) where T<:Number
1010
return max(relstep*abs(x), absstep)*dir
1111
end
1212

13-
@inline function compute_epsilon(::Type{Val{:central}}, x::T, relstep::Real, absstep::Real, dir=nothing) where T<:Number
13+
@inline function compute_epsilon(::Val{:central}, x::T, relstep::Real, absstep::Real, dir=nothing) where T<:Number
1414
return max(relstep*abs(x), absstep)
1515
end
1616

17-
@inline function compute_epsilon(::Type{Val{:hcentral}}, x::T, relstep::Real, absstep::Real, dir=nothing) where T<:Number
17+
@inline function compute_epsilon(::Val{:hcentral}, x::T, relstep::Real, absstep::Real, dir=nothing) where T<:Number
1818
return max(relstep*abs(x), absstep)
1919
end
2020

21-
@inline function compute_epsilon(::Type{Val{:complex}}, x::T, ::Union{Nothing,T}=nothing, ::Union{Nothing,T}=nothing, dir=nothing) where T<:Real
21+
@inline function compute_epsilon(::Val{:complex}, x::T, ::Union{Nothing,T}=nothing, ::Union{Nothing,T}=nothing, dir=nothing) where T<:Real
2222
return eps(T)
2323
end
2424

25-
@inline function default_relstep(fdtype::DataType, ::Type{T}) where T<:Number
26-
if fdtype==Val{:forward}
25+
default_relstep(v::Type, T) = default_relstep(v(), T)
26+
@inline function default_relstep(::Val{fdtype}, ::Type{T}) where {fdtype,T<:Number}
27+
if fdtype==:forward
2728
return sqrt(eps(real(T)))
28-
elseif fdtype==Val{:central}
29+
elseif fdtype==:central
2930
return cbrt(eps(real(T)))
30-
elseif fdtype==Val{:hcentral}
31+
elseif fdtype==:hcentral
3132
eps(T)^(1/4)
3233
else
3334
return one(real(T))
3435
end
3536
end
3637

37-
function fdtype_error(funtype::Type{T}=Float64) where T
38-
if funtype<:Real
38+
function fdtype_error(::Type{T}=Float64) where T
39+
if T<:Real
3940
error("Unrecognized fdtype: valid values are Val{:forward}, Val{:central} and Val{:complex}.")
40-
elseif funtype<:Complex
41+
elseif T<:Complex
4142
error("Unrecognized fdtype: valid values are Val{:forward} or Val{:central}.")
4243
else
4344
error("Unrecognized returntype: should be a subtype of Real or Complex.")

src/gradients.jl

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@ end
88
function GradientCache(
99
df,
1010
x,
11-
fdtype = Val{:central},
11+
fdtype = Val(:central),
1212
returntype = eltype(df),
1313
inplace = Val{true})
1414

15+
fdtype isa Type && (fdtype = fdtype())
1516
if typeof(x)<:AbstractArray # the vector->scalar case
16-
if fdtype!=Val{:complex} # complex-mode FD only needs one cache, for x+eps*im
17+
if fdtype!=Val(:complex) # complex-mode FD only needs one cache, for x+eps*im
1718
if typeof(x)<:StridedVector
1819
if eltype(df)<:Complex && !(eltype(x)<:Complex)
1920
_c1 = zero(Complex{eltype(x)}) .* x
@@ -37,7 +38,7 @@ function GradientCache(
3738
_c3 = similar(x)
3839
else # the scalar->vector case
3940
# need cache arrays for fx1 and fx2, except in complex mode, which needs one complex array
40-
if fdtype != Val{:complex}
41+
if fdtype != Val(:complex)
4142
_c1 = similar(df)
4243
_c2 = similar(df)
4344
else
@@ -55,7 +56,7 @@ end
5556
function finite_difference_gradient(
5657
f,
5758
x,
58-
fdtype = Val{:central},
59+
fdtype = Val(:central),
5960
returntype = eltype(x),
6061
inplace = Val{true},
6162
fx = nothing,
@@ -89,7 +90,7 @@ function finite_difference_gradient!(
8990
df,
9091
f,
9192
x,
92-
fdtype=Val{:central},
93+
fdtype=Val(:central),
9394
returntype=eltype(df),
9495
inplace=Val{true},
9596
fx=nothing,
@@ -133,12 +134,12 @@ function finite_difference_gradient!(
133134
# NOTE: in this case epsilon is a vector, we need two arrays for epsilon and x1
134135
# c1 denotes x1, c2 is epsilon
135136
fx, c1, c2, c3 = cache.fx, cache.c1, cache.c2, cache.c3
136-
if fdtype != Val{:complex} && ArrayInterface.fast_scalar_indexing(c2)
137+
if fdtype != Val(:complex) && ArrayInterface.fast_scalar_indexing(c2)
137138
@. c2 = compute_epsilon(fdtype, x, relstep, absstep, dir)
138139
copyto!(c1,x)
139140
end
140141
copyto!(c3,x)
141-
if fdtype == Val{:forward}
142+
if fdtype == Val(:forward)
142143
@inbounds for i eachindex(x)
143144
if ArrayInterface.fast_scalar_indexing(c2)
144145
epsilon = ArrayInterface.allowed_getindex(c2,i)*dir
@@ -168,7 +169,7 @@ function finite_difference_gradient!(
168169
ArrayInterface.allowed_setindex!(c1,c1_old,i)
169170
end
170171
end
171-
elseif fdtype == Val{:central}
172+
elseif fdtype == Val(:central)
172173
@inbounds for i eachindex(x)
173174
if ArrayInterface.fast_scalar_indexing(c2)
174175
epsilon = ArrayInterface.allowed_getindex(c2,i)*dir
@@ -191,7 +192,7 @@ function finite_difference_gradient!(
191192
ArrayInterface.allowed_setindex!(c1,c1_old, i)
192193
ArrayInterface.allowed_setindex!(c3,x_old,i)
193194
end
194-
elseif fdtype == Val{:complex} && returntype <: Real
195+
elseif fdtype == Val(:complex) && returntype <: Real
195196
copyto!(c1,x)
196197
epsilon_complex = eps(real(eltype(x)))
197198
# we use c1 here to avoid typing issues with x
@@ -219,13 +220,13 @@ function finite_difference_gradient!(
219220
# c1 is x1 if we need a complex copy of x, otherwise Nothing
220221
# c2 is Nothing
221222
fx, c1, c2, c3 = cache.fx, cache.c1, cache.c2, cache.c3
222-
if fdtype != Val{:complex}
223+
if fdtype != Val(:complex)
223224
if eltype(df)<:Complex && !(eltype(x)<:Complex)
224225
copyto!(c1,x)
225226
end
226227
end
227228
copyto!(c3,x)
228-
if fdtype == Val{:forward}
229+
if fdtype == Val(:forward)
229230
for i eachindex(x)
230231
epsilon = compute_epsilon(fdtype, x[i], relstep, absstep, dir)
231232
x_old = x[i]
@@ -262,7 +263,7 @@ function finite_difference_gradient!(
262263
df[i] -= im * imag(dfi)
263264
end
264265
end
265-
elseif fdtype == Val{:central}
266+
elseif fdtype == Val(:central)
266267
@inbounds for i eachindex(x)
267268
epsilon = compute_epsilon(fdtype, x[i], relstep, absstep, dir)
268269
x_old = x[i]
@@ -289,7 +290,7 @@ function finite_difference_gradient!(
289290
df[i] -= im*imag(dfi / (2*im*epsilon))
290291
end
291292
end
292-
elseif fdtype==Val{:complex} && returntype<:Real && eltype(df)<:Real && eltype(x)<:Real
293+
elseif fdtype==Val(:complex) && returntype<:Real && eltype(df)<:Real && eltype(x)<:Real
293294
copyto!(c1,x)
294295
epsilon_complex = eps(real(eltype(x)))
295296
# we use c1 here to avoid typing issues with x
@@ -324,8 +325,8 @@ function finite_difference_gradient!(
324325
_c1, _c2 = c1, c2
325326
end
326327

327-
if fdtype == Val{:forward}
328-
epsilon = compute_epsilon(Val{:forward}, x, relstep, absstep, dir)
328+
if fdtype == Val(:forward)
329+
epsilon = compute_epsilon(Val(:forward), x, relstep, absstep, dir)
329330
if inplace == Val{true}
330331
f(c1, x+epsilon)
331332
else
@@ -341,8 +342,8 @@ function finite_difference_gradient!(
341342
end
342343
@. df = (_c1 - _c2) / epsilon
343344
end
344-
elseif fdtype == Val{:central}
345-
epsilon = compute_epsilon(Val{:central}, x, relstep, absstep, dir)
345+
elseif fdtype == Val(:central)
346+
epsilon = compute_epsilon(Val(:central), x, relstep, absstep, dir)
346347
if inplace == Val{true}
347348
f(c1, x+epsilon)
348349
f(c2, x-epsilon)
@@ -351,7 +352,7 @@ function finite_difference_gradient!(
351352
_c2 = f(x-epsilon)
352353
end
353354
@. df = (_c1 - _c2) / (2*epsilon)
354-
elseif fdtype == Val{:complex} && returntype <: Real
355+
elseif fdtype == Val(:complex) && returntype <: Real
355356
epsilon_complex = eps(real(eltype(x)))
356357
if inplace == Val{true}
357358
f(c1, x+im*epsilon_complex)

src/hessians.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,21 @@ struct HessianCache{T,fdtype,inplace}
66
end
77

88
function HessianCache(xpp,xpm,xmp,xmm,
9-
fdtype=Val{:hcentral},
9+
fdtype=Val(:hcentral),
1010
inplace = x isa StaticArray ? Val{false} : Val{true})
11+
fdtype isa Type && (fdtype = fdtype())
1112
HessianCache{typeof(xpp),fdtype,inplace}(xpp,xpm,xmp,xmm)
1213
end
1314

14-
function HessianCache(x, fdtype=Val{:hcentral},
15+
function HessianCache(x, fdtype=Val(:hcentral),
1516
inplace = x isa StaticArray ? Val{false} : Val{true})
1617
cx = copy(x)
18+
fdtype isa Type && (fdtype = fdtype())
1719
HessianCache{typeof(cx),fdtype,inplace}(cx, copy(x), copy(x), copy(x))
1820
end
1921

2022
function finite_difference_hessian(f, x,
21-
fdtype = Val{:hcentral},
23+
fdtype = Val(:hcentral),
2224
inplace = x isa StaticArray ? Val{false} : Val{true};
2325
relstep = default_relstep(fdtype, eltype(x)),
2426
absstep = relstep)
@@ -40,7 +42,7 @@ end
4042

4143
function finite_difference_hessian!(H,f,
4244
x,
43-
fdtype = Val{:hcentral},
45+
fdtype = Val(:hcentral),
4446
inplace = x isa StaticArray ? Val{false} : Val{true};
4547
relstep=default_relstep(fdtype, eltype(x)),
4648
absstep=relstep)
@@ -54,7 +56,7 @@ function finite_difference_hessian!(H,f,x,
5456
relstep = default_relstep(fdtype, eltype(x)),
5557
absstep = relstep) where {T,fdtype,inplace}
5658

57-
@assert fdtype == Val{:hcentral}
59+
@assert fdtype == Val(:hcentral)
5860
n = length(x)
5961
xpp, xpm, xmp, xmm = cache.xpp, cache.xpm, cache.xmp, cache.xmm
6062
fx = f(x)
@@ -65,7 +67,7 @@ function finite_difference_hessian!(H,f,x,
6567

6668
for i = 1:n
6769
xi = ArrayInterface.allowed_getindex(x,i)
68-
epsilon = compute_epsilon(Val{:hcentral}, xi, relstep, absstep)
70+
epsilon = compute_epsilon(Val(:hcentral), xi, relstep, absstep)
6971

7072
if inplace === Val{true}
7173
ArrayInterface.allowed_setindex!(xpp,xi + epsilon,i)
@@ -76,7 +78,7 @@ function finite_difference_hessian!(H,f,x,
7678
end
7779

7880
ArrayInterface.allowed_setindex!(H,(f(_xpp) - 2*fx + f(_xmm)) / epsilon^2,i,i)
79-
epsiloni = compute_epsilon(Val{:central}, xi, relstep, absstep)
81+
epsiloni = compute_epsilon(Val(:central), xi, relstep, absstep)
8082
xp = xi + epsiloni
8183
xm = xi - epsiloni
8284

@@ -94,7 +96,7 @@ function finite_difference_hessian!(H,f,x,
9496

9597
for j = i+1:n
9698
xj = ArrayInterface.allowed_getindex(x,j)
97-
epsilonj = compute_epsilon(Val{:central}, xj, relstep, absstep)
99+
epsilonj = compute_epsilon(Val(:central), xj, relstep, absstep)
98100
xp = xj + epsilonj
99101
xm = xj - epsilonj
100102

0 commit comments

Comments
 (0)