Skip to content

Commit ebf1625

Browse files
committed
Fix thread safety issue in finite_difference_jacobian! when using
central differencing. Fix initialization of HessianCache Test thread safety for gradients, Jacobians, and Hessians
1 parent f564a08 commit ebf1625

File tree

3 files changed

+50
-21
lines changed

3 files changed

+50
-21
lines changed

src/hessians.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ function HessianCache(xpp,xpm,xmp,xmm,
1111
HessianCache{typeof(xpp),fdtype,inplace}(xpp,xpm,xmp,xmm)
1212
end
1313

14-
function HessianCache(x,fdtype=Val{:hcentral},
15-
inplace = x isa StaticArray ? Val{false} : Val{true})
16-
HessianCache{typeof(x),fdtype,inplace}(copy(x),copy(x),copy(x),copy(x))
14+
function HessianCache(x, fdtype=Val{:hcentral},
15+
inplace = x isa StaticArray ? Val{false} : Val{true})
16+
cx = copy(x)
17+
HessianCache{typeof(cx),fdtype,inplace}(cx, copy(x), copy(x), copy(x))
1718
end
1819

1920
function finite_difference_hessian(f, x,

src/jacobians.jl

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -356,13 +356,13 @@ function finite_difference_jacobian!(
356356
if sparsity isa Nothing
357357
x1_save = ArrayInterface.allowed_getindex(x1,color_i)
358358
epsilon = compute_epsilon(Val{:forward}, x1_save, relstep, absstep, dir)
359-
ArrayInterface.allowed_setindex!(x1,x1_save + epsilon,color_i)
359+
ArrayInterface.allowed_setindex!(x1, x1_save + epsilon, color_i)
360360
f(fx1, x1)
361361
# J is dense, so either it is truly dense or this is the
362362
# compressed form of the coloring, so write into it.
363363
@. J[:,color_i] = (vfx1 - vfx) / epsilon
364364
# Now return x1 back to its original value
365-
ArrayInterface.allowed_setindex!(x1,x1_save,color_i)
365+
ArrayInterface.allowed_setindex!(x1, x1_save, color_i)
366366
else # Perturb along the colorvec vector
367367
@. fx1 = x1 * (_color == color_i)
368368
tmp = norm(fx1)
@@ -381,9 +381,9 @@ function finite_difference_jacobian!(
381381
+= means requires a zero'd out start
382382
=#
383383
if J isa SparseMatrixCSC
384-
@. void_setindex!((J.nzval,),getindex((J.nzval,),rows_index) + (getindex((_color,),cols_index) == color_i) * getindex((vfx1,),rows_index),rows_index)
384+
@. void_setindex!((J.nzval,), getindex((J.nzval,), rows_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx1,), rows_index), rows_index)
385385
else
386-
@. void_setindex!((J,),getindex((J,),rows_index, cols_index) + (getindex((_color,),cols_index) == color_i) * getindex((vfx1,),rows_index),rows_index, cols_index)
386+
@. void_setindex!((J,), getindex((J,), rows_index, cols_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx1,), rows_index), rows_index, cols_index)
387387
end
388388
end
389389
# Now return x1 back to its original value
@@ -394,16 +394,14 @@ function finite_difference_jacobian!(
394394
vfx1 = _vec(fx1)
395395
@inbounds for color_i 1:maximum(colorvec)
396396
if sparsity isa Nothing
397-
x_save = ArrayInterface.allowed_getindex(x,color_i)
398-
x1_save = ArrayInterface.allowed_getindex(x1,color_i)
397+
x_save = ArrayInterface.allowed_getindex(x, color_i)
399398
epsilon = compute_epsilon(Val{:central}, x_save, relstep, absstep, dir)
400-
ArrayInterface.allowed_setindex!(x1,x1_save+epsilon,color_i)
401-
ArrayInterface.allowed_setindex!(x,x_save-epsilon,color_i)
399+
ArrayInterface.allowed_setindex!(x1, x_save + epsilon, color_i)
402400
f(fx1, x1)
403-
f(fx, x)
401+
ArrayInterface.allowed_setindex!(x1, x_save - epsilon, color_i)
402+
f(fx, x1)
404403
@. J[:,color_i] = (vfx1 - vfx) / 2epsilon
405-
ArrayInterface.allowed_setindex!(x1,x1_save,color_i)
406-
ArrayInterface.allowed_setindex!(x,x_save,color_i)
404+
ArrayInterface.allowed_setindex!(x1, x_save, color_i)
407405
else # Perturb along the colorvec vector
408406
@. fx1 = x1 * (_color == color_i)
409407
tmp = norm(fx1)
@@ -417,9 +415,9 @@ function finite_difference_jacobian!(
417415
_colorediteration!(J,sparsity,rows_index,cols_index,vfx1,colorvec,color_i,n)
418416
else
419417
if J isa SparseMatrixCSC
420-
@. void_setindex!((J.nzval,),getindex((J.nzval,),rows_index) + (getindex((_color,),cols_index) == color_i) * getindex((vfx1,),rows_index),rows_index)
418+
@. void_setindex!((J.nzval,), getindex((J.nzval,), rows_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx1,), rows_index), rows_index)
421419
else
422-
@. void_setindex!((J,),getindex((J,),rows_index, cols_index) + (getindex((_color,),cols_index) == color_i) * getindex((vfx1,),rows_index),rows_index, cols_index)
420+
@. void_setindex!((J,), getindex((J,), rows_index, cols_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx1,), rows_index), rows_index, cols_index)
423421
end
424422
end
425423
@. x1 = x1 - epsilon * (_color == color_i)
@@ -430,11 +428,11 @@ function finite_difference_jacobian!(
430428
epsilon = eps(eltype(x))
431429
@inbounds for color_i 1:maximum(colorvec)
432430
if sparsity isa Nothing
433-
x1_save = ArrayInterface.allowed_getindex(x1,color_i)
434-
ArrayInterface.allowed_setindex!(x1,x1_save + im*epsilon, color_i)
431+
x1_save = ArrayInterface.allowed_getindex(x1, color_i)
432+
ArrayInterface.allowed_setindex!(x1, x1_save + im*epsilon, color_i)
435433
f(fx,x1)
436434
@. J[:,color_i] = imag(vfx) / epsilon
437-
ArrayInterface.allowed_setindex!(x1,x1_save,color_i)
435+
ArrayInterface.allowed_setindex!(x1, x1_save,color_i)
438436
else # Perturb along the colorvec vector
439437
@. x1 = x1 + im * epsilon * (_color == color_i)
440438
f(fx,x1)
@@ -443,9 +441,9 @@ function finite_difference_jacobian!(
443441
_colorediteration!(J,sparsity,rows_index,cols_index,vfx,colorvec,color_i,n)
444442
else
445443
if J isa SparseMatrixCSC
446-
@. void_setindex!((J.nzval,),getindex((J.nzval,),rows_index) + (getindex((_color,),cols_index) == color_i) * getindex((vfx,),rows_index),rows_index)
444+
@. void_setindex!((J.nzval,), getindex((J.nzval,), rows_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx,),rows_index), rows_index)
447445
else
448-
@. void_setindex!((J,),getindex((J,),rows_index, cols_index) + (getindex((_color,),cols_index) == color_i) * getindex((vfx,),rows_index),rows_index, cols_index)
446+
@. void_setindex!((J,), getindex((J,), rows_index, cols_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx,), rows_index), rows_index, cols_index)
449447
end
450448
end
451449
@. x1 = x1 - im * epsilon * (_color == color_i)

test/finitedifftests.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,3 +425,33 @@ H = similar(H_ref)
425425
@test err_func(FiniteDiff.finite_difference_hessian!(H, f, x), H_ref) < 1e-4
426426
@test err_func(FiniteDiff.finite_difference_hessian!(H, f, x, hcache), H_ref) < 1e-4
427427
end
428+
429+
# Thread safety
430+
# create an an abstract array type that doesn't allow setindex
431+
struct ImmutableVector <: DenseVector{Float64}
432+
x::Vector{Float64}
433+
end
434+
Base.size(x::ImmutableVector) = size(x.x)
435+
Base.getindex(x::ImmutableVector, i::Integer) = x.x[i]
436+
@testset "thread safety" begin
437+
@testset "Gradients with diff type $difftype" for difftype in (Val{:forward}, Val{:central}, Val{:complex})
438+
g = FiniteDiff.finite_difference_gradient(sum, ImmutableVector(ones(2)), difftype)
439+
@test g ones(2)
440+
FiniteDiff.finite_difference_gradient!(g, sum, ImmutableVector(ones(2)), difftype)
441+
@test g ones(2)
442+
end
443+
444+
@testset "Hessians (only supported diff type is :hcentral)" begin
445+
H = FiniteDiff.finite_difference_hessian(t -> sum(abs2, t)/2, ImmutableVector(ones(2)))
446+
@test H Matrix(I, 2, 2)
447+
FiniteDiff.finite_difference_hessian!(parent(H), t -> sum(abs2, t)/2, ImmutableVector(ones(2)))
448+
@test H Matrix(I, 2, 2)
449+
end
450+
451+
@testset "Jacobians with diff type $difftype" for difftype in (Val{:forward}, Val{:central}, Val{:complex})
452+
J = FiniteDiff.finite_difference_jacobian(identity, ImmutableVector(ones(2)), difftype)
453+
@test J Matrix(I, 2, 2)
454+
FiniteDiff.finite_difference_jacobian!(J, (out, in) -> out .= in, ImmutableVector(ones(2)), difftype)
455+
@test J Matrix(I, 2, 2)
456+
end
457+
end

0 commit comments

Comments
 (0)