
Commit a7f159f

Merge pull request #112 from kanav99/kg/gradient
No mutation in Input
2 parents: a68da8d + e08dd89

File tree: 1 file changed (+43 -38 lines)


src/gradients.jl

Lines changed: 43 additions & 38 deletions
@@ -1,7 +1,8 @@
-struct GradientCache{CacheType1, CacheType2, CacheType3, fdtype, returntype, inplace}
+struct GradientCache{CacheType1, CacheType2, CacheType3, CacheType4, fdtype, returntype, inplace}
     fx :: CacheType1
     c1 :: CacheType2
     c2 :: CacheType3
+    c3 :: CacheType4
 end
 
 function GradientCache(
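
Note (not part of the commit): the new fourth cache slot is what makes the "no mutation in input" behaviour possible. A hedged reading of the field roles, pieced together from the comments and constructor hunks below:

    # Cache slots as used in this file (roles inferred from the hunks below):
    #   fx :: CacheType1   # cached f(x), or `nothing`
    #   c1 :: CacheType2   # x1 buffer (vector->scalar) or fx1 buffer (scalar->vector)
    #   c2 :: CacheType3   # epsilon buffer, fx2 buffer, or `nothing`
    #   c3 :: CacheType4   # new: writable copy of x, perturbed instead of the input
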
@@ -33,6 +34,7 @@ function GradientCache(
                 _c2 = nothing
             end
         end
+        _c3 = similar(x)
     else # the scalar->vector case
         # need cache arrays for fx1 and fx2, except in complex mode, which needs one complex array
         if fdtype != Val{:complex}
@@ -42,10 +44,11 @@ function GradientCache(
             _c1 = zero(Complex{eltype(x)}) .* df
             _c2 = nothing
         end
+        _c3 = x
     end
 
-    GradientCache{Nothing,typeof(_c1),typeof(_c2),fdtype,
-                  returntype,inplace}(nothing,_c1,_c2)
+    GradientCache{Nothing,typeof(_c1),typeof(_c2),typeof(_c3),fdtype,
+                  returntype,inplace}(nothing,_c1,_c2,_c3)
 
 end
 
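In the vector-to-scalar branch the constructor now allocates `_c3 = similar(x)` as a scratch copy of the input, while in the scalar-to-vector branch `_c3 = x` simply aliases it (there the perturbations go into the fx1/fx2 buffers, so nothing needs protecting). A hedged sketch of what that means for a constructed cache; the package name and call pattern are assumptions based on FiniteDiff.jl's public API, not something shown in this diff:

    using FiniteDiff   # assumption: this file lives in FiniteDiff.jl

    x  = rand(3); df = similar(x)
    cache_v = FiniteDiff.GradientCache(df, x, Val{:central})   # vector -> scalar case
    # cache_v.c3 is a freshly allocated buffer shaped like x (from `_c3 = similar(x)`)

    y = 0.5; dfy = rand(3)
    cache_s = FiniteDiff.GradientCache(dfy, y, Val{:central})  # scalar -> vector case
    # cache_s.c3 === y: the scalar input itself, no scratch copy needed (from `_c3 = x`)
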
@@ -102,10 +105,10 @@ end
 function finite_difference_gradient(
     f,
     x,
-    cache::GradientCache{T1,T2,T3,fdtype,returntype,inplace};
+    cache::GradientCache{T1,T2,T3,T4,fdtype,returntype,inplace};
     relstep=default_relstep(fdtype, eltype(x)),
     absstep=relstep,
-    dir=true) where {T1,T2,T3,fdtype,returntype,inplace}
+    dir=true) where {T1,T2,T3,T4,fdtype,returntype,inplace}
 
     if typeof(x) <: AbstractArray
         df = zero(returntype) .* x
@@ -122,18 +125,19 @@ function finite_difference_gradient!(
     df,
     f,
     x,
-    cache::GradientCache{T1,T2,T3,fdtype,returntype,inplace};
+    cache::GradientCache{T1,T2,T3,T4,fdtype,returntype,inplace};
     relstep=default_relstep(fdtype, eltype(x)),
     absstep=relstep,
-    dir=true) where {T1,T2,T3,fdtype,returntype,inplace}
+    dir=true) where {T1,T2,T3,T4,fdtype,returntype,inplace}
 
     # NOTE: in this case epsilon is a vector, we need two arrays for epsilon and x1
     # c1 denotes x1, c2 is epsilon
-    fx, c1, c2 = cache.fx, cache.c1, cache.c2
+    fx, c1, c2, c3 = cache.fx, cache.c1, cache.c2, cache.c3
     if fdtype != Val{:complex} && ArrayInterface.fast_scalar_indexing(c2)
         @. c2 = compute_epsilon(fdtype, x, relstep, absstep, dir)
         copyto!(c1,x)
     end
+    copyto!(c3,x)
     if fdtype == Val{:forward}
         @inbounds for i ∈ eachindex(x)
             if ArrayInterface.fast_scalar_indexing(c2)
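
The mechanism introduced here, and repeated in the remaining methods below, is: copy the input into the scratch buffer once per call (`copyto!(c3,x)`), perturb the copy, and restore it afterwards. A stand-alone illustration of that idea with a hypothetical helper, not code from this package:

    # Hypothetical, simplified forward-difference gradient using the same
    # "perturb a scratch copy, never the input" pattern as this diff:
    function forward_gradient!(df, f, x, c3, eps)
        copyto!(c3, x)                 # scratch copy of the input
        fx0 = f(x)                     # base value at the unperturbed point
        for i in eachindex(x)
            old = c3[i]
            c3[i] += eps               # perturb the copy only
            df[i] = (f(c3) - fx0) / eps
            c3[i] = old                # restore the copy; x itself is never written
        end
        return df
    end
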
@@ -174,18 +178,18 @@ function finite_difference_gradient!(
             c1_old = ArrayInterface.allowed_getindex(c1,i)
             ArrayInterface.allowed_setindex!(c1,c1_old + epsilon, i)
             x_old = ArrayInterface.allowed_getindex(x,i)
-            ArrayInterface.allowed_setindex!(x,x_old - epsilon,i)
-            df_tmp = real((f(c1) - f(x)) / (2*epsilon))
+            ArrayInterface.allowed_setindex!(c3,x_old - epsilon,i)
+            df_tmp = real((f(c1) - f(c3)) / (2*epsilon))
             if eltype(df)<:Complex
                 ArrayInterface.allowed_setindex!(c1,c1_old + im*epsilon,i)
-                ArrayInterface.allowed_setindex!(x,x_old - im*epsilon,i)
-                df_tmp2 = im*imag( (f(c1) - f(x)) / (2*im*epsilon) )
+                ArrayInterface.allowed_setindex!(c3,x_old - im*epsilon,i)
+                df_tmp2 = im*imag( (f(c1) - f(c3)) / (2*im*epsilon) )
                 ArrayInterface.allowed_setindex!(df,df_tmp-df_tmp2,i)
             else
                 ArrayInterface.allowed_setindex!(df,df_tmp,i)
             end
             ArrayInterface.allowed_setindex!(c1,c1_old, i)
-            ArrayInterface.allowed_setindex!(x,x_old,i)
+            ArrayInterface.allowed_setindex!(c3,x_old,i)
         end
     elseif fdtype == Val{:complex} && returntype <: Real
         copyto!(c1,x)
@@ -207,44 +211,45 @@ function finite_difference_gradient!(
     df::StridedVector{<:Number},
     f,
     x::StridedVector{<:Number},
-    cache::GradientCache{T1,T2,T3,fdtype,returntype,inplace};
+    cache::GradientCache{T1,T2,T3,T4,fdtype,returntype,inplace};
     relstep=default_relstep(fdtype, eltype(x)),
     absstep=relstep,
-    dir=true) where {T1,T2,T3,fdtype,returntype,inplace}
+    dir=true) where {T1,T2,T3,T4,fdtype,returntype,inplace}
 
     # c1 is x1 if we need a complex copy of x, otherwise Nothing
     # c2 is Nothing
-    fx, c1, c2 = cache.fx, cache.c1, cache.c2
+    fx, c1, c2, c3 = cache.fx, cache.c1, cache.c2, cache.c3
     if fdtype != Val{:complex}
         if eltype(df)<:Complex && !(eltype(x)<:Complex)
             copyto!(c1,x)
         end
     end
+    copyto!(c3,x)
     if fdtype == Val{:forward}
         for i ∈ eachindex(x)
             epsilon = compute_epsilon(fdtype, x[i], relstep, absstep, dir)
             x_old = x[i]
             if typeof(fx) != Nothing
-                x[i] += epsilon
-                dfi = (f(x) - fx) / epsilon
-                x[i] = x_old
+                c3[i] += epsilon
+                dfi = (f(c3) - fx) / epsilon
+                c3[i] = x_old
             else
                 fx0 = f(x)
-                x[i] += epsilon
-                dfi = (f(x) - fx0) / epsilon
-                x[i] = x_old
+                c3[i] += epsilon
+                dfi = (f(c3) - fx0) / epsilon
+                c3[i] = x_old
             end
 
             df[i] = real(dfi)
             if eltype(df)<:Complex
                 if eltype(x)<:Complex
-                    x[i] += im * epsilon
+                    c3[i] += im * epsilon
                     if typeof(fx) != Nothing
-                        dfi = (f(x) - fx) / (im*epsilon)
+                        dfi = (f(c3) - fx) / (im*epsilon)
                     else
-                        dfi = (f(x) - fx0) / (im*epsilon)
+                        dfi = (f(c3) - fx0) / (im*epsilon)
                     end
-                    x[i] = x_old
+                    c3[i] = x_old
                 else
                     c1[i] += im * epsilon
                     if typeof(fx) != Nothing
@@ -261,19 +266,19 @@ function finite_difference_gradient!(
         @inbounds for i ∈ eachindex(x)
             epsilon = compute_epsilon(fdtype, x[i], relstep, absstep, dir)
             x_old = x[i]
-            x[i] += epsilon
-            dfi = f(x)
-            x[i] = x_old - epsilon
-            dfi -= f(x)
-            x[i] = x_old
+            c3[i] += epsilon
+            dfi = f(c3)
+            c3[i] = x_old - epsilon
+            dfi -= f(c3)
+            c3[i] = x_old
             df[i] = real(dfi / (2*epsilon))
             if eltype(df)<:Complex
                 if eltype(x)<:Complex
-                    x[i] += im*epsilon
-                    dfi = f(x)
-                    x[i] = x_old - im*epsilon
-                    dfi -= f(x)
-                    x[i] = x_old
+                    c3[i] += im*epsilon
+                    dfi = f(c3)
+                    c3[i] = x_old - im*epsilon
+                    dfi -= f(c3)
+                    c3[i] = x_old
                 else
                     c1[i] += im*epsilon
                     dfi = f(c1)
@@ -306,10 +311,10 @@ function finite_difference_gradient!(
     df,
     f,
     x::Number,
-    cache::GradientCache{T1,T2,T3,fdtype,returntype,inplace};
+    cache::GradientCache{T1,T2,T3,T4,fdtype,returntype,inplace};
     relstep=default_relstep(fdtype, eltype(x)),
     absstep=relstep,
-    dir=true) where {T1,T2,T3,fdtype,returntype,inplace}
+    dir=true) where {T1,T2,T3,T4,fdtype,returntype,inplace}
 
     # NOTE: in this case epsilon is a scalar, we need two arrays for fx1 and fx2
     # c1 denotes fx1, c2 is fx2, sizes guaranteed by the cache constructor
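
Taken together, the change routes every write that previously hit `x` into the cache's `c3` buffer, so the caller's input survives a gradient call untouched. A hedged end-to-end sketch of the behaviour this commit is aiming for; `using FiniteDiff` and the exact call signature are assumptions based on the package's documented API rather than anything shown above:

    using FiniteDiff   # assumption: the package providing finite_difference_gradient!

    f(x) = sum(abs2, x)                 # vector -> scalar test function
    x  = rand(4)
    x0 = copy(x)                        # snapshot of the input
    df = similar(x)

    cache = FiniteDiff.GradientCache(df, x, Val{:central})
    FiniteDiff.finite_difference_gradient!(df, f, x, cache)

    @assert x == x0                     # with this commit, the input is left unmutated
    @assert df ≈ 2 .* x                 # gradient of sum(abs2, x) is 2x
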
