diff --git a/Project.toml b/Project.toml index b060ada85..93e5e3456 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.4.0" +version = "1.5.0" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/projection.jl b/src/projection.jl index b4532390c..4b07b2762 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -83,8 +83,10 @@ _maybe_call(f, x) = f Returns a `ProjectTo{T}` functor which projects a differential `dx` onto the relevant tangent space for `x`. -At present this undersands only `x::Number`, `x::AbstractArray` and `x::Ref`. -It should not be called on arguments of an `rrule` method which accepts other types. +Custom `ProjectTo` methods are provided for many subtypes of `Number` (to e.g. ensure precision), +and `AbstractArray` (to e.g. ensure sparsity structure is maintained by tangent). +Called on unknown types it will (as of v1.5.0) simply return `identity`, thus can be safely +applied to arbitrary `rrule` arguments. # Examples ```jldoctest @@ -112,7 +114,7 @@ julia> ProjectTo([1 2; 3 4]') # no special structure, integers are promoted to ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(2), Base.OneTo(2))) ``` """ -ProjectTo(::Any) # just to attach docstring +ProjectTo(::Any) = identity # Generic (::ProjectTo{T})(dx::AbstractZero) where {T} = dx @@ -143,6 +145,11 @@ ProjectTo{P}(::NamedTuple{T, <:Tuple{_PZ, Vararg{<:_PZ}}}) where {P,T} = Project # Bool ProjectTo(::Bool) = ProjectTo{NoTangent}() # same projector as ProjectTo(::AbstractZero) above +# Other never-differentiable types +for T in (:Symbol, :Char, :AbstractString, :RoundingMode, :IndexStyle) + @eval ProjectTo(::$T) = ProjectTo{NoTangent}() +end + # Numbers ProjectTo(::Real) = ProjectTo{Real}() ProjectTo(::Complex) = ProjectTo{Complex}() diff --git a/test/projection.jl b/test/projection.jl index 53e3e0bcb..ba61fb8da 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -11,6 +11,9 @@ Base.real(x::Dual) = x Base.float(x::Dual) = Dual(float(x.value), float(x.partial)) Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial)) +# Trivial struct +struct NoSuperType end + @testset "projection" begin ##### @@ -24,7 +27,6 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial)) @test ProjectTo(2.0+3.0im)(1+1im) === 1.0+1.0im @test ProjectTo(2.0)(1+1im) === 1.0 - # storage @test ProjectTo(1)(pi) === pi @test ProjectTo(1 + im)(pi) === ComplexF64(pi) @@ -94,9 +96,10 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial)) @test y1[1] == [1 2] @test !(y1 isa Adjoint) && !(y1[1] isa Adjoint) - # arrays of unknown things - @test_throws MethodError ProjectTo([:x, :y]) - @test_throws MethodError ProjectTo(Any[:x, :y]) + # arrays of other things + @test ProjectTo([:x, :y]) isa ProjectTo{NoTangent} + @test ProjectTo(Any['x', "y"]) isa ProjectTo{NoTangent} + @test ProjectTo([(1,2), (3,4), (5,6)]) isa ProjectTo{AbstractArray} @test ProjectTo(Any[1, 2])(1:2) == [1.0, 2.0] # projects each number. @test Tuple(ProjectTo(Any[1, 2 + 3im])(1:2)) === (1.0, 2.0 + 0.0im) @@ -140,6 +143,12 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial)) @test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent} end + @testset "Base: non-diff" begin + @test ProjectTo(:a)(1) == NoTangent() + @test ProjectTo('b')(2) == NoTangent() + @test ProjectTo("cde")(345) == NoTangent() + end + ##### ##### `LinearAlgebra` ##### @@ -301,6 +310,10 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial)) ##### `ChainRulesCore` ##### + @testset "pass-through" begin + @test ProjectTo(NoSuperType()) === identity + end + @testset "AbstractZero" begin pz = ProjectTo(ZeroTangent()) pz(0) == NoTangent()