From 55ddcacd9a55290c175cd86e547a65c9e610a4a2 Mon Sep 17 00:00:00 2001 From: Knut Andreas Meyer Date: Sun, 17 Sep 2023 17:42:04 +0200 Subject: [PATCH 1/4] Add Tensors.extract_value --- docs/src/man/automatic_differentiation.md | 25 +++++++++++++++++++++++ src/automatic_differentiation.jl | 10 +++++++++ test/test_ad.jl | 13 +++++++++++- 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/docs/src/man/automatic_differentiation.md b/docs/src/man/automatic_differentiation.md index a09c91b6..df43a993 100644 --- a/docs/src/man/automatic_differentiation.md +++ b/docs/src/man/automatic_differentiation.md @@ -93,6 +93,31 @@ julia> norm(majorsymmetric(E) - E_sym) 0.0 ``` +## Differentiating mutating functions +Some applications require the derivative of the output of a function `f(x,s)`, +wrt. `x`, where `f` also mutates `s`. +In these cases, we don't want the derivative of `s` wrt. `x`, and +the value to be set in `s` should only be the value and not be the dual part. +For scalars, `ForwardDiff.jl` provides `ForwardDiff.value`, +and for tensors, `Tensors.jl` provides `Tensors.extract_value`. + +```@docs +Tensors.extract_value +``` + +A simple example of the use-case is +```@example +function mutating_fun(x::Vec, state::Vector) + state[1] = Tensors.extract_value(x) + return x +end + +x = rand(Vec{2}); state = zeros(Vec{2}, 1) +gradient(a -> mutating_fun(a, state), x) +# Check that it got correctly modified by the extracted value +state[1] == x +``` + ## Inserting a known derivative When conditionals are used in a function evaluation, automatic differentiation may yield the wrong result. 
Consider, the simplified example of the function diff --git a/src/automatic_differentiation.jl b/src/automatic_differentiation.jl index 277dfec3..858dbddb 100644 --- a/src/automatic_differentiation.jl +++ b/src/automatic_differentiation.jl @@ -19,6 +19,16 @@ end # Value extraction # #################### +""" + extract_value(v::AbstractTensor) + +If `v` is used in a differentiation, such that +`eltype(v) <: ForwardDiff.Dual`, extract the value part of the dual number. +Otherwise, just return `v`. +""" +extract_value(v::AbstractTensor{<:Any,<:Any,<:Dual}) = _extract_value(v) +extract_value(v::AbstractTensor) = v + # Scalar output -> Scalar value """ function _extract_value(v::ForwardDiff.Dual) diff --git a/test/test_ad.jl b/test/test_ad.jl index cb0d76b3..51d62166 100644 --- a/test/test_ad.jl +++ b/test/test_ad.jl @@ -327,6 +327,17 @@ S(C) = S(C, μ, Kb) end end - + + @testset "value_extraction" begin + function mutating_fun(x::Vec, state::Vector, use_extract::Bool) + state[1] = use_extract ? Tensors.extract_value(x) : x + return x + end + T = Vec{2,Float64} + x = rand(T); state = zeros(T, 1) + gradient(a -> mutating_fun(a, state, true), x) + @test state[1] == x # Check that it got correctly modified by the extracted value + @test_throws MethodError gradient(a -> mutating_fun(a, state, false), x) + end end # testsection From ac6fdd4849dc116758c475094d59c030644387e1 Mon Sep 17 00:00:00 2001 From: Knut Andreas Meyer Date: Sun, 17 Sep 2023 20:24:23 +0200 Subject: [PATCH 2/4] Add one more test --- test/test_ad.jl | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/test/test_ad.jl b/test/test_ad.jl index 51d62166..60c912bd 100644 --- a/test/test_ad.jl +++ b/test/test_ad.jl @@ -329,15 +329,20 @@ S(C) = S(C, μ, Kb) end @testset "value_extraction" begin - function mutating_fun(x::Vec, state::Vector, use_extract::Bool) - state[1] = use_extract ? 
Tensors.extract_value(x) : x + function mutating_fun(x::Vec, state::Vector; use_extract::Bool, contract::Bool) + v = contract ? x⋅x : x + state[1] = use_extract ? Tensors.extract_value(v) : v return x end - T = Vec{2,Float64} - x = rand(T); state = zeros(T, 1) - gradient(a -> mutating_fun(a, state, true), x) - @test state[1] == x # Check that it got correctly modified by the extracted value - @test_throws MethodError gradient(a -> mutating_fun(a, state, false), x) + TT = Vec{2,Float64} + x = rand(TT); state = zeros(TT, 1) + gradient(a -> mutating_fun(a, state; use_extract=true, contract=false), x) + # Check that it got correctly modified by the extracted value + @test state[1] == x + # Check that test works: Should fail if extract_value is not used + @test_throws MethodError gradient(a -> mutating_fun(a, state; use_extract=false, contract=false), x) + # Do not allow extract_value on a <:Real + @test_throws MethodError gradient(a -> mutating_fun(a, state; use_extract=false, contract=true), x) end end # testsection From 5e58596d3b2f6765a5f107fa9ad6969dc7de818c Mon Sep 17 00:00:00 2001 From: Knut Andreas Meyer Date: Sun, 17 Sep 2023 21:17:50 +0200 Subject: [PATCH 3/4] Add further test to complete coverage --- test/test_ad.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_ad.jl b/test/test_ad.jl index 60c912bd..f399ce71 100644 --- a/test/test_ad.jl +++ b/test/test_ad.jl @@ -343,6 +343,10 @@ S(C) = S(C, μ, Kb) @test_throws MethodError gradient(a -> mutating_fun(a, state; use_extract=false, contract=false), x) # Do not allow extract_value on a <:Real @test_throws MethodError gradient(a -> mutating_fun(a, state; use_extract=false, contract=true), x) + # Check that it get correctly modified when not differentiating + x = rand(TT); + mutating_fun(x, state; use_extract=true, contract=true) + @test state[1] == x end end # testsection From f9e1efeb6c9513dd5c0dafbc5be9b823f1fa2711 Mon Sep 17 00:00:00 2001 From: Knut Andreas Meyer Date: Sun, 17 Sep 2023 
21:49:45 +0200 Subject: [PATCH 4/4] Typo --- test/test_ad.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ad.jl b/test/test_ad.jl index f399ce71..51465f4b 100644 --- a/test/test_ad.jl +++ b/test/test_ad.jl @@ -345,7 +345,7 @@ S(C) = S(C, μ, Kb) @test_throws MethodError gradient(a -> mutating_fun(a, state; use_extract=false, contract=true), x) # Check that it get correctly modified when not differentiating x = rand(TT); - mutating_fun(x, state; use_extract=true, contract=true) + mutating_fun(x, state; use_extract=true, contract=false) @test state[1] == x end