Skip to content

Commit 2d81794

Browse files
Merge pull request #116 from JuliaDiff/myb/instance
Make FiniteDiff take both instances and types
2 parents b8178f1 + eae102d commit 2d81794

File tree

7 files changed

+125
-108
lines changed

7 files changed

+125
-108
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FiniteDiff"
22
uuid = "6a86dc24-6348-571c-b903-95158fe2bd41"
3-
version = "2.6.0"
3+
version = "2.7.0"
44

55
[deps]
66
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/derivatives.jl

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,20 @@ Single-point derivatives of scalar->scalar maps.
44
function finite_difference_derivative(
55
f,
66
x::T,
7-
fdtype=Val{:central},
7+
fdtype=Val(:central),
88
returntype=eltype(x),
99
f_x=nothing;
1010
relstep=default_relstep(fdtype, T),
1111
absstep=relstep,
1212
dir=true) where {T<:Number}
1313

14+
fdtype isa Type && (fdtype = fdtype())
1415
epsilon = compute_epsilon(fdtype, x, relstep, absstep, dir)
15-
if fdtype==Val{:forward}
16+
if fdtype==Val(:forward)
1617
return (f(x+epsilon) - f(x)) / epsilon
17-
elseif fdtype==Val{:central}
18+
elseif fdtype==Val(:central)
1819
return (f(x+epsilon) - f(x-epsilon)) / (2*epsilon)
19-
elseif fdtype==Val{:complex} && returntype<:Real
20+
elseif fdtype==Val(:complex) && returntype<:Real
2021
return imag(f(x+im*epsilon)) / epsilon
2122
end
2223
fdtype_error(returntype)
@@ -36,15 +37,16 @@ function DerivativeCache(
3637
x :: AbstractArray{<:Number},
3738
fx :: Union{Nothing,AbstractArray{<:Number}} = nothing,
3839
epsilon :: Union{Nothing,AbstractArray{<:Real}} = nothing,
39-
fdtype :: Type{T1} = Val{:central},
40+
fdtype :: Type{T1} = Val(:central),
4041
returntype :: Type{T2} = eltype(x)) where {T1,T2}
4142

42-
if fdtype==Val{:complex} && !(eltype(returntype)<:Real)
43+
fdtype isa Type && (fdtype = fdtype())
44+
if fdtype==Val(:complex) && !(eltype(returntype)<:Real)
4345
fdtype_error(returntype)
4446
end
4547

46-
if fdtype!=Val{:forward} && typeof(fx)!=Nothing
47-
@warn("Pre-computed function values are only useful for fdtype==Val{:forward}.")
48+
if fdtype!=Val(:forward) && typeof(fx)!=Nothing
49+
@warn("Pre-computed function values are only useful for fdtype==Val(:forward).")
4850
_fx = nothing
4951
else
5052
# more runtime sanity checks?
@@ -54,8 +56,8 @@ function DerivativeCache(
5456
if typeof(epsilon)!=Nothing && typeof(x)<:StridedArray && typeof(fx)<:Union{Nothing,StridedArray} && 1==2
5557
@warn("StridedArrays don't benefit from pre-allocating epsilon.")
5658
_epsilon = nothing
57-
elseif typeof(epsilon)!=Nothing && fdtype==Val{:complex}
58-
@warn("Val{:complex} makes the epsilon array redundant.")
59+
elseif typeof(epsilon)!=Nothing && fdtype==Val(:complex)
60+
@warn("Val(:complex) makes the epsilon array redundant.")
5961
_epsilon = nothing
6062
else
6163
if typeof(epsilon)==Nothing || eltype(epsilon)!=real(eltype(x))
@@ -72,7 +74,7 @@ Compute the derivative df of a scalar-valued map f at a collection of points x.
7274
function finite_difference_derivative(
7375
f,
7476
x,
75-
fdtype = Val{:central},
77+
fdtype = Val(:central),
7678
returntype = eltype(x), # return type of f
7779
fx = nothing,
7880
epsilon = nothing;
@@ -87,7 +89,7 @@ function finite_difference_derivative!(
8789
df,
8890
f,
8991
x,
90-
fdtype = Val{:central},
92+
fdtype = Val(:central),
9193
returntype = eltype(x),
9294
fx = nothing,
9395
epsilon = nothing;
@@ -111,15 +113,15 @@ function finite_difference_derivative!(
111113
if typeof(epsilon) != Nothing
112114
@. epsilon = compute_epsilon(fdtype, x, relstep, absstep, dir)
113115
end
114-
if fdtype == Val{:forward}
116+
if fdtype == Val(:forward)
115117
if typeof(fx) == Nothing
116118
@. df = (f(x+epsilon) - f(x)) / epsilon
117119
else
118120
@. df = (f(x+epsilon) - fx) / epsilon
119121
end
120-
elseif fdtype == Val{:central}
122+
elseif fdtype == Val(:central)
121123
@. df = (f(x+epsilon) - f(x-epsilon)) / (2 * epsilon)
122-
elseif fdtype == Val{:complex} && returntype<:Real
124+
elseif fdtype == Val(:complex) && returntype<:Real
123125
epsilon_complex = eps(eltype(x))
124126
@. df = imag(f(x+im*epsilon_complex)) / epsilon_complex
125127
else
@@ -142,25 +144,25 @@ function finite_difference_derivative!(
142144
absstep=relstep,
143145
dir=true) where {T1,T2,fdtype,returntype}
144146

145-
if fdtype == Val{:forward}
147+
if fdtype == Val(:forward)
146148
fx = cache.fx
147149
@inbounds for i eachindex(x)
148-
epsilon = compute_epsilon(Val{:forward}, x[i], relstep, absstep, dir)
150+
epsilon = compute_epsilon(Val(:forward), x[i], relstep, absstep, dir)
149151
x_plus = x[i] + epsilon
150152
if typeof(fx) == Nothing
151153
df[i] = (f(x_plus) - f(x[i])) / epsilon
152154
else
153155
df[i] = (f(x_plus) - fx[i]) / epsilon
154156
end
155157
end
156-
elseif fdtype == Val{:central}
158+
elseif fdtype == Val(:central)
157159
@inbounds for i eachindex(x)
158-
epsilon = compute_epsilon(Val{:central}, x[i], relstep, absstep, dir)
160+
epsilon = compute_epsilon(Val(:central), x[i], relstep, absstep, dir)
159161
epsilon_double_inv = one(typeof(epsilon)) / (2*epsilon)
160162
x_plus, x_minus = x[i]+epsilon, x[i]-epsilon
161163
df[i] = (f(x_plus) - f(x_minus)) * epsilon_double_inv
162164
end
163-
elseif fdtype == Val{:complex}
165+
elseif fdtype == Val(:complex)
164166
epsilon_complex = eps(eltype(x))
165167
@inbounds for i eachindex(x)
166168
df[i] = imag(f(x[i]+im*epsilon_complex)) / epsilon_complex

src/epsilons.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,39 @@ Very heavily inspired by Calculus.jl, but with an emphasis on performance and Di
66
Compute the finite difference interval epsilon.
77
Reference: Numerical Recipes, chapter 5.7.
88
=#
9-
@inline function compute_epsilon(::Type{Val{:forward}}, x::T, relstep::Real, absstep::Real, dir::Real) where T<:Number
9+
@inline function compute_epsilon(::Val{:forward}, x::T, relstep::Real, absstep::Real, dir::Real) where T<:Number
1010
return max(relstep*abs(x), absstep)*dir
1111
end
1212

13-
@inline function compute_epsilon(::Type{Val{:central}}, x::T, relstep::Real, absstep::Real, dir=nothing) where T<:Number
13+
@inline function compute_epsilon(::Val{:central}, x::T, relstep::Real, absstep::Real, dir=nothing) where T<:Number
1414
return max(relstep*abs(x), absstep)
1515
end
1616

17-
@inline function compute_epsilon(::Type{Val{:hcentral}}, x::T, relstep::Real, absstep::Real, dir=nothing) where T<:Number
17+
@inline function compute_epsilon(::Val{:hcentral}, x::T, relstep::Real, absstep::Real, dir=nothing) where T<:Number
1818
return max(relstep*abs(x), absstep)
1919
end
2020

21-
@inline function compute_epsilon(::Type{Val{:complex}}, x::T, ::Union{Nothing,T}=nothing, ::Union{Nothing,T}=nothing, dir=nothing) where T<:Real
21+
@inline function compute_epsilon(::Val{:complex}, x::T, ::Union{Nothing,T}=nothing, ::Union{Nothing,T}=nothing, dir=nothing) where T<:Real
2222
return eps(T)
2323
end
2424

25-
@inline function default_relstep(fdtype::DataType, ::Type{T}) where T<:Number
26-
if fdtype==Val{:forward}
25+
default_relstep(v::Type, T) = default_relstep(v(), T)
26+
@inline function default_relstep(::Val{fdtype}, ::Type{T}) where {fdtype,T<:Number}
27+
if fdtype==:forward
2728
return sqrt(eps(real(T)))
28-
elseif fdtype==Val{:central}
29+
elseif fdtype==:central
2930
return cbrt(eps(real(T)))
30-
elseif fdtype==Val{:hcentral}
31+
elseif fdtype==:hcentral
3132
eps(T)^(1/4)
3233
else
3334
return one(real(T))
3435
end
3536
end
3637

37-
function fdtype_error(funtype::Type{T}=Float64) where T
38-
if funtype<:Real
38+
function fdtype_error(::Type{T}=Float64) where T
39+
if T<:Real
3940
error("Unrecognized fdtype: valid values are Val{:forward}, Val{:central} and Val{:complex}.")
40-
elseif funtype<:Complex
41+
elseif T<:Complex
4142
error("Unrecognized fdtype: valid values are Val{:forward} or Val{:central}.")
4243
else
4344
error("Unrecognized returntype: should be a subtype of Real or Complex.")

src/gradients.jl

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ end
88
function GradientCache(
99
df,
1010
x,
11-
fdtype = Val{:central},
11+
fdtype = Val(:central),
1212
returntype = eltype(df),
13-
inplace = Val{true})
13+
inplace = Val(true))
1414

15+
fdtype isa Type && (fdtype = fdtype())
16+
inplace isa Type && (inplace = inplace())
1517
if typeof(x)<:AbstractArray # the vector->scalar case
16-
if fdtype!=Val{:complex} # complex-mode FD only needs one cache, for x+eps*im
18+
if fdtype!=Val(:complex) # complex-mode FD only needs one cache, for x+eps*im
1719
if typeof(x)<:StridedVector
1820
if eltype(df)<:Complex && !(eltype(x)<:Complex)
1921
_c1 = zero(Complex{eltype(x)}) .* x
@@ -37,7 +39,7 @@ function GradientCache(
3739
_c3 = similar(x)
3840
else # the scalar->vector case
3941
# need cache arrays for fx1 and fx2, except in complex mode, which needs one complex array
40-
if fdtype != Val{:complex}
42+
if fdtype != Val(:complex)
4143
_c1 = similar(df)
4244
_c2 = similar(df)
4345
else
@@ -55,20 +57,21 @@ end
5557
function finite_difference_gradient(
5658
f,
5759
x,
58-
fdtype = Val{:central},
60+
fdtype = Val(:central),
5961
returntype = eltype(x),
60-
inplace = Val{true},
62+
inplace = Val(true),
6163
fx = nothing,
6264
c1 = nothing,
6365
c2 = nothing;
6466
relstep=default_relstep(fdtype, eltype(x)),
6567
absstep=relstep,
6668
dir=true)
6769

70+
inplace isa Type && (inplace = inplace())
6871
if typeof(x) <: AbstractArray
6972
df = zero(returntype) .* x
7073
else
71-
if inplace == Val{true}
74+
if inplace == Val(true)
7275
if typeof(fx)==Nothing && typeof(c1)==Nothing && typeof(c2)==Nothing
7376
error("In the scalar->vector in-place map case, at least one of fx, c1 or c2 must be provided, otherwise we cannot infer the return size.")
7477
else
@@ -89,9 +92,9 @@ function finite_difference_gradient!(
8992
df,
9093
f,
9194
x,
92-
fdtype=Val{:central},
95+
fdtype=Val(:central),
9396
returntype=eltype(df),
94-
inplace=Val{true},
97+
inplace=Val(true),
9598
fx=nothing,
9699
c1=nothing,
97100
c2=nothing;
@@ -133,12 +136,12 @@ function finite_difference_gradient!(
133136
# NOTE: in this case epsilon is a vector, we need two arrays for epsilon and x1
134137
# c1 denotes x1, c2 is epsilon
135138
fx, c1, c2, c3 = cache.fx, cache.c1, cache.c2, cache.c3
136-
if fdtype != Val{:complex} && ArrayInterface.fast_scalar_indexing(c2)
139+
if fdtype != Val(:complex) && ArrayInterface.fast_scalar_indexing(c2)
137140
@. c2 = compute_epsilon(fdtype, x, relstep, absstep, dir)
138141
copyto!(c1,x)
139142
end
140143
copyto!(c3,x)
141-
if fdtype == Val{:forward}
144+
if fdtype == Val(:forward)
142145
@inbounds for i eachindex(x)
143146
if ArrayInterface.fast_scalar_indexing(c2)
144147
epsilon = ArrayInterface.allowed_getindex(c2,i)*dir
@@ -168,7 +171,7 @@ function finite_difference_gradient!(
168171
ArrayInterface.allowed_setindex!(c1,c1_old,i)
169172
end
170173
end
171-
elseif fdtype == Val{:central}
174+
elseif fdtype == Val(:central)
172175
@inbounds for i eachindex(x)
173176
if ArrayInterface.fast_scalar_indexing(c2)
174177
epsilon = ArrayInterface.allowed_getindex(c2,i)*dir
@@ -191,7 +194,7 @@ function finite_difference_gradient!(
191194
ArrayInterface.allowed_setindex!(c1,c1_old, i)
192195
ArrayInterface.allowed_setindex!(c3,x_old,i)
193196
end
194-
elseif fdtype == Val{:complex} && returntype <: Real
197+
elseif fdtype == Val(:complex) && returntype <: Real
195198
copyto!(c1,x)
196199
epsilon_complex = eps(real(eltype(x)))
197200
# we use c1 here to avoid typing issues with x
@@ -219,13 +222,13 @@ function finite_difference_gradient!(
219222
# c1 is x1 if we need a complex copy of x, otherwise Nothing
220223
# c2 is Nothing
221224
fx, c1, c2, c3 = cache.fx, cache.c1, cache.c2, cache.c3
222-
if fdtype != Val{:complex}
225+
if fdtype != Val(:complex)
223226
if eltype(df)<:Complex && !(eltype(x)<:Complex)
224227
copyto!(c1,x)
225228
end
226229
end
227230
copyto!(c3,x)
228-
if fdtype == Val{:forward}
231+
if fdtype == Val(:forward)
229232
for i eachindex(x)
230233
epsilon = compute_epsilon(fdtype, x[i], relstep, absstep, dir)
231234
x_old = x[i]
@@ -262,7 +265,7 @@ function finite_difference_gradient!(
262265
df[i] -= im * imag(dfi)
263266
end
264267
end
265-
elseif fdtype == Val{:central}
268+
elseif fdtype == Val(:central)
266269
@inbounds for i eachindex(x)
267270
epsilon = compute_epsilon(fdtype, x[i], relstep, absstep, dir)
268271
x_old = x[i]
@@ -289,7 +292,7 @@ function finite_difference_gradient!(
289292
df[i] -= im*imag(dfi / (2*im*epsilon))
290293
end
291294
end
292-
elseif fdtype==Val{:complex} && returntype<:Real && eltype(df)<:Real && eltype(x)<:Real
295+
elseif fdtype==Val(:complex) && returntype<:Real && eltype(df)<:Real && eltype(x)<:Real
293296
copyto!(c1,x)
294297
epsilon_complex = eps(real(eltype(x)))
295298
# we use c1 here to avoid typing issues with x
@@ -320,40 +323,40 @@ function finite_difference_gradient!(
320323
# c1 denotes fx1, c2 is fx2, sizes guaranteed by the cache constructor
321324
fx, c1, c2 = cache.fx, cache.c1, cache.c2
322325

323-
if inplace == Val{true}
326+
if inplace == Val(true)
324327
_c1, _c2 = c1, c2
325328
end
326329

327-
if fdtype == Val{:forward}
328-
epsilon = compute_epsilon(Val{:forward}, x, relstep, absstep, dir)
329-
if inplace == Val{true}
330+
if fdtype == Val(:forward)
331+
epsilon = compute_epsilon(Val(:forward), x, relstep, absstep, dir)
332+
if inplace == Val(true)
330333
f(c1, x+epsilon)
331334
else
332335
_c1 = f(x+epsilon)
333336
end
334337
if typeof(fx) != Nothing
335338
@. df = (_c1 - fx) / epsilon
336339
else
337-
if inplace == Val{true}
340+
if inplace == Val(true)
338341
f(c2, x)
339342
else
340343
_c2 = f(x)
341344
end
342345
@. df = (_c1 - _c2) / epsilon
343346
end
344-
elseif fdtype == Val{:central}
345-
epsilon = compute_epsilon(Val{:central}, x, relstep, absstep, dir)
346-
if inplace == Val{true}
347+
elseif fdtype == Val(:central)
348+
epsilon = compute_epsilon(Val(:central), x, relstep, absstep, dir)
349+
if inplace == Val(true)
347350
f(c1, x+epsilon)
348351
f(c2, x-epsilon)
349352
else
350353
_c1 = f(x+epsilon)
351354
_c2 = f(x-epsilon)
352355
end
353356
@. df = (_c1 - _c2) / (2*epsilon)
354-
elseif fdtype == Val{:complex} && returntype <: Real
357+
elseif fdtype == Val(:complex) && returntype <: Real
355358
epsilon_complex = eps(real(eltype(x)))
356-
if inplace == Val{true}
359+
if inplace == Val(true)
357360
f(c1, x+im*epsilon_complex)
358361
else
359362
_c1 = f(x+im*epsilon_complex)

0 commit comments

Comments
 (0)