diff --git a/docs/src/man/automatic_differentiation.md b/docs/src/man/automatic_differentiation.md
index a09c91b6..df43a993 100644
--- a/docs/src/man/automatic_differentiation.md
+++ b/docs/src/man/automatic_differentiation.md
@@ -93,6 +93,31 @@ julia> norm(majorsymmetric(E) - E_sym)
 0.0
 ```
 
+## Differentiating mutating functions
+Some applications require the derivative of the output of a function `f(x, s)`
+with respect to `x`, where `f` also mutates `s`.
+In these cases, we don't want the derivative of `s` w.r.t. `x`;
+the value stored in `s` should be only the value part, not the dual part.
+For scalars, `ForwardDiff.jl` provides `ForwardDiff.value`,
+and for tensors, `Tensors.jl` provides `Tensors.extract_value`.
+
+```@docs
+Tensors.extract_value
+```
+
+A simple example of this use case is
+```@example
+function mutating_fun(x::Vec, state::Vector)
+    state[1] = Tensors.extract_value(x)
+    return x
+end
+
+x = rand(Vec{2}); state = zeros(Vec{2}, 1)
+gradient(a -> mutating_fun(a, state), x)
+# Check that `state` was correctly updated with the extracted value
+state[1] == x
+```
+
 ## Inserting a known derivative
 When conditionals are used in a function evaluation, automatic differentiation
 may yield the wrong result. Consider, the simplified example of the function
diff --git a/src/automatic_differentiation.jl b/src/automatic_differentiation.jl
index d0aaa877..a2eda3a9 100644
--- a/src/automatic_differentiation.jl
+++ b/src/automatic_differentiation.jl
@@ -19,6 +19,16 @@ end
 # Value extraction #
 ####################
 
+"""
+    extract_value(v::AbstractTensor)
+
+If `v` is used inside a differentiation call, such that
+`eltype(v) <: ForwardDiff.Dual`, return the value part of `v`
+(i.e. strip the dual part). Otherwise, just return `v`.
+"""
+extract_value(v::AbstractTensor{<:Any,<:Any,<:Dual}) = _extract_value(v)
+extract_value(v::AbstractTensor) = v
+
 # Scalar output -> Scalar value
 """
 function _extract_value(v::ForwardDiff.Dual)
diff --git a/test/test_ad.jl b/test/test_ad.jl
index bb696b4c..d6ad37b8 100644
--- a/test/test_ad.jl
+++ b/test/test_ad.jl
@@ -337,5 +337,26 @@ S(C) = S(C, μ, Kb)
         end
     end
 
+
+    @testset "value_extraction" begin
+        function mutating_fun(x::Vec, state::Vector; use_extract::Bool, contract::Bool)
+            v = contract ? x⋅x : x
+            state[1] = use_extract ? Tensors.extract_value(v) : v
+            return x
+        end
+        TT = Vec{2,Float64}
+        x = rand(TT); state = zeros(TT, 1)
+        gradient(a -> mutating_fun(a, state; use_extract=true, contract=false), x)
+        # Check that `state` was correctly updated with the extracted value
+        @test state[1] == x
+        # Check that the test works: should fail if extract_value is not used
+        @test_throws MethodError gradient(a -> mutating_fun(a, state; use_extract=false, contract=false), x)
+        # extract_value is not defined for a <:Real (here the Dual from x⋅x)
+        @test_throws MethodError gradient(a -> mutating_fun(a, state; use_extract=true, contract=true), x)
+        # Check that `state` also gets correctly updated when not differentiating
+        x = rand(TT);
+        mutating_fun(x, state; use_extract=true, contract=false)
+        @test state[1] == x
+    end
 
 end # testsection
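
For reference, a minimal usage sketch of the API added above, combining the tensor case (`Tensors.extract_value`) and the scalar case (`ForwardDiff.value`) mentioned in the documentation. The constitutive-style routine `update_state!` and its linear "material law" are hypothetical illustrations, not part of Tensors.jl or this patch:

```julia
using Tensors, ForwardDiff

# Hypothetical constitutive-style update: returns a toy linear "stress" and
# records state from the strain `ε`. The stored state is stripped of dual
# parts, so it stays plain `Float64`-based even while differentiating.
function update_state!(tensor_state::Vector, scalar_state::Vector{Float64}, ε::SymmetricTensor{2,3})
    tensor_state[1] = Tensors.extract_value(ε)  # tensor state: Tensors.extract_value
    scalar_state[1] = ForwardDiff.value(ε ⊡ ε)  # scalar state: ForwardDiff.value
    return 2.0 * ε                              # toy linear "material law"
end

tensor_state = [zero(SymmetricTensor{2,3})]
scalar_state = zeros(1)
ε = rand(SymmetricTensor{2,3})

# Differentiate the output w.r.t. `ε` (a 4th order tangent); the state arrays
# receive only the value parts during the dual evaluation.
dσdε = gradient(e -> update_state!(tensor_state, scalar_state, e), ε)
@assert tensor_state[1] == ε
@assert scalar_state[1] == ε ⊡ ε
```

Stripping the dual parts keeps the state arrays concretely `Float64`-typed, which is also why the `MethodError` checks in the test above fail loudly when `extract_value` is omitted.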