1- struct GradientCache{CacheType1, CacheType2, CacheType3, fdtype, returntype, inplace}
1+ struct GradientCache{CacheType1, CacheType2, CacheType3, CacheType4, fdtype, returntype, inplace}
22 fx :: CacheType1
33 c1 :: CacheType2
44 c2 :: CacheType3
5+ c3 :: CacheType4
56end
67
78function GradientCache (
@@ -33,6 +34,7 @@ function GradientCache(
3334 _c2 = nothing
3435 end
3536 end
37+ _c3 = similar (x)
3638 else # the scalar->vector case
3739 # need cache arrays for fx1 and fx2, except in complex mode, which needs one complex array
3840 if fdtype != Val{:complex }
@@ -42,10 +44,11 @@ function GradientCache(
4244 _c1 = zero (Complex{eltype (x)}) .* df
4345 _c2 = nothing
4446 end
47+ _c3 = x
4548 end
4649
47- GradientCache{Nothing,typeof (_c1),typeof (_c2),fdtype,
48- returntype,inplace}(nothing ,_c1,_c2)
50+ GradientCache{Nothing,typeof (_c1),typeof (_c2),typeof (_c3), fdtype,
51+ returntype,inplace}(nothing ,_c1,_c2,_c3 )
4952
5053end
5154
@@ -102,10 +105,10 @@ end
102105function finite_difference_gradient (
103106 f,
104107 x,
105- cache:: GradientCache{T1,T2,T3,fdtype,returntype,inplace} ;
108+ cache:: GradientCache{T1,T2,T3,T4, fdtype,returntype,inplace} ;
106109 relstep= default_relstep (fdtype, eltype (x)),
107110 absstep= relstep,
108- dir= true ) where {T1,T2,T3,fdtype,returntype,inplace}
111+ dir= true ) where {T1,T2,T3,T4, fdtype,returntype,inplace}
109112
110113 if typeof (x) <: AbstractArray
111114 df = zero (returntype) .* x
@@ -122,18 +125,19 @@ function finite_difference_gradient!(
122125 df,
123126 f,
124127 x,
125- cache:: GradientCache{T1,T2,T3,fdtype,returntype,inplace} ;
128+ cache:: GradientCache{T1,T2,T3,T4, fdtype,returntype,inplace} ;
126129 relstep= default_relstep (fdtype, eltype (x)),
127130 absstep= relstep,
128- dir= true ) where {T1,T2,T3,fdtype,returntype,inplace}
131+ dir= true ) where {T1,T2,T3,T4, fdtype,returntype,inplace}
129132
130133 # NOTE: in this case epsilon is a vector, we need two arrays for epsilon and x1
131134 # c1 denotes x1, c2 is epsilon
132- fx, c1, c2 = cache. fx, cache. c1, cache. c2
135+ fx, c1, c2, c3 = cache. fx, cache. c1, cache. c2, cache . c3
133136 if fdtype != Val{:complex } && ArrayInterface. fast_scalar_indexing (c2)
134137 @. c2 = compute_epsilon (fdtype, x, relstep, absstep, dir)
135138 copyto! (c1,x)
136139 end
140+ copyto! (c3,x)
137141 if fdtype == Val{:forward }
138142 @inbounds for i ∈ eachindex (x)
139143 if ArrayInterface. fast_scalar_indexing (c2)
@@ -174,18 +178,18 @@ function finite_difference_gradient!(
174178 c1_old = ArrayInterface. allowed_getindex (c1,i)
175179 ArrayInterface. allowed_setindex! (c1,c1_old + epsilon, i)
176180 x_old = ArrayInterface. allowed_getindex (x,i)
177- ArrayInterface. allowed_setindex! (x ,x_old - epsilon,i)
178- df_tmp = real ((f (c1) - f (x )) / (2 * epsilon))
181+ ArrayInterface. allowed_setindex! (c3 ,x_old - epsilon,i)
182+ df_tmp = real ((f (c1) - f (c3 )) / (2 * epsilon))
179183 if eltype (df)<: Complex
180184 ArrayInterface. allowed_setindex! (c1,c1_old + im* epsilon,i)
181- ArrayInterface. allowed_setindex! (x ,x_old - im* epsilon,i)
182- df_tmp2 = im* imag ( (f (c1) - f (x )) / (2 * im* epsilon) )
185+ ArrayInterface. allowed_setindex! (c3 ,x_old - im* epsilon,i)
186+ df_tmp2 = im* imag ( (f (c1) - f (c3 )) / (2 * im* epsilon) )
183187 ArrayInterface. allowed_setindex! (df,df_tmp- df_tmp2,i)
184188 else
185189 ArrayInterface. allowed_setindex! (df,df_tmp,i)
186190 end
187191 ArrayInterface. allowed_setindex! (c1,c1_old, i)
188- ArrayInterface. allowed_setindex! (x ,x_old,i)
192+ ArrayInterface. allowed_setindex! (c3 ,x_old,i)
189193 end
190194 elseif fdtype == Val{:complex } && returntype <: Real
191195 copyto! (c1,x)
@@ -207,44 +211,45 @@ function finite_difference_gradient!(
207211 df:: StridedVector{<:Number} ,
208212 f,
209213 x:: StridedVector{<:Number} ,
210- cache:: GradientCache{T1,T2,T3,fdtype,returntype,inplace} ;
214+ cache:: GradientCache{T1,T2,T3,T4, fdtype,returntype,inplace} ;
211215 relstep= default_relstep (fdtype, eltype (x)),
212216 absstep= relstep,
213- dir= true ) where {T1,T2,T3,fdtype,returntype,inplace}
217+ dir= true ) where {T1,T2,T3,T4, fdtype,returntype,inplace}
214218
215219 # c1 is x1 if we need a complex copy of x, otherwise Nothing
216220 # c2 is Nothing
217- fx, c1, c2 = cache. fx, cache. c1, cache. c2
221+ fx, c1, c2, c3 = cache. fx, cache. c1, cache. c2, cache . c3
218222 if fdtype != Val{:complex }
219223 if eltype (df)<: Complex && ! (eltype (x)<: Complex )
220224 copyto! (c1,x)
221225 end
222226 end
227+ copyto! (c3,x)
223228 if fdtype == Val{:forward }
224229 for i ∈ eachindex (x)
225230 epsilon = compute_epsilon (fdtype, x[i], relstep, absstep, dir)
226231 x_old = x[i]
227232 if typeof (fx) != Nothing
228- x [i] += epsilon
229- dfi = (f (x ) - fx) / epsilon
230- x [i] = x_old
233+ c3 [i] += epsilon
234+ dfi = (f (c3 ) - fx) / epsilon
235+ c3 [i] = x_old
231236 else
232237 fx0 = f (x)
233- x [i] += epsilon
234- dfi = (f (x ) - fx0) / epsilon
235- x [i] = x_old
238+ c3 [i] += epsilon
239+ dfi = (f (c3 ) - fx0) / epsilon
240+ c3 [i] = x_old
236241 end
237242
238243 df[i] = real (dfi)
239244 if eltype (df)<: Complex
240245 if eltype (x)<: Complex
241- x [i] += im * epsilon
246+ c3 [i] += im * epsilon
242247 if typeof (fx) != Nothing
243- dfi = (f (x ) - fx) / (im* epsilon)
248+ dfi = (f (c3 ) - fx) / (im* epsilon)
244249 else
245- dfi = (f (x ) - fx0) / (im* epsilon)
250+ dfi = (f (c3 ) - fx0) / (im* epsilon)
246251 end
247- x [i] = x_old
252+ c3 [i] = x_old
248253 else
249254 c1[i] += im * epsilon
250255 if typeof (fx) != Nothing
@@ -261,19 +266,19 @@ function finite_difference_gradient!(
261266 @inbounds for i ∈ eachindex (x)
262267 epsilon = compute_epsilon (fdtype, x[i], relstep, absstep, dir)
263268 x_old = x[i]
264- x [i] += epsilon
265- dfi = f (x )
266- x [i] = x_old - epsilon
267- dfi -= f (x )
268- x [i] = x_old
269+ c3 [i] += epsilon
270+ dfi = f (c3 )
271+ c3 [i] = x_old - epsilon
272+ dfi -= f (c3 )
273+ c3 [i] = x_old
269274 df[i] = real (dfi / (2 * epsilon))
270275 if eltype (df)<: Complex
271276 if eltype (x)<: Complex
272- x [i] += im* epsilon
273- dfi = f (x )
274- x [i] = x_old - im* epsilon
275- dfi -= f (x )
276- x [i] = x_old
277+ c3 [i] += im* epsilon
278+ dfi = f (c3 )
279+ c3 [i] = x_old - im* epsilon
280+ dfi -= f (c3 )
281+ c3 [i] = x_old
277282 else
278283 c1[i] += im* epsilon
279284 dfi = f (c1)
@@ -306,10 +311,10 @@ function finite_difference_gradient!(
306311 df,
307312 f,
308313 x:: Number ,
309- cache:: GradientCache{T1,T2,T3,fdtype,returntype,inplace} ;
314+ cache:: GradientCache{T1,T2,T3,T4, fdtype,returntype,inplace} ;
310315 relstep= default_relstep (fdtype, eltype (x)),
311316 absstep= relstep,
312- dir= true ) where {T1,T2,T3,fdtype,returntype,inplace}
317+ dir= true ) where {T1,T2,T3,T4, fdtype,returntype,inplace}
313318
314319 # NOTE: in this case epsilon is a scalar, we need two arrays for fx1 and fx2
315320 # c1 denotes fx1, c2 is fx2, sizes guaranteed by the cache constructor
0 commit comments