Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.4.0"
version = "1.5.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
10 changes: 8 additions & 2 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ Returns a `ProjectTo{T}` functor which projects a differential `dx` onto the
relevant tangent space for `x`.

At present this undersands only `x::Number`, `x::AbstractArray` and `x::Ref`.
It should not be called on arguments of an `rrule` method which accepts other types.
Called on unknown types it will now simply return `identity`, thus can be safely
applied to arbitrary `rrule` arguments.

# Examples
```jldoctest
Expand Down Expand Up @@ -112,7 +113,7 @@ julia> ProjectTo([1 2; 3 4]') # no special structure, integers are promoted to
ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(2), Base.OneTo(2)))
```
"""
ProjectTo(::Any) # just to attach docstring
ProjectTo(::Any) = identity

# Generic
(::ProjectTo{T})(dx::AbstractZero) where {T} = dx
Expand Down Expand Up @@ -143,6 +144,11 @@ ProjectTo{P}(::NamedTuple{T, <:Tuple{_PZ, Vararg{<:_PZ}}}) where {P,T} = Project
# Bool
ProjectTo(::Bool) = ProjectTo{NoTangent}() # same projector as ProjectTo(::AbstractZero) above
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Combine with below?


# Other never-differentiable types
for T in [:Symbol, :Char, :AbstractString, :RoundingMode, :IndexStyle]
@eval ProjectTo(::$T) = ProjectTo{NoTangent}()
end

# Numbers
ProjectTo(::Real) = ProjectTo{Real}()
ProjectTo(::Complex) = ProjectTo{Complex}()
Expand Down
21 changes: 17 additions & 4 deletions test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ Base.real(x::Dual) = x
Base.float(x::Dual) = Dual(float(x.value), float(x.partial))
Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))

# Trivial struct
struct NoSuperType end

@testset "projection" begin

#####
Expand All @@ -24,7 +27,6 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
@test ProjectTo(2.0+3.0im)(1+1im) === 1.0+1.0im
@test ProjectTo(2.0)(1+1im) === 1.0


# storage
@test ProjectTo(1)(pi) === pi
@test ProjectTo(1 + im)(pi) === ComplexF64(pi)
Expand Down Expand Up @@ -94,9 +96,10 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
@test y1[1] == [1 2]
@test !(y1 isa Adjoint) && !(y1[1] isa Adjoint)

# arrays of unknown things
@test_throws MethodError ProjectTo([:x, :y])
@test_throws MethodError ProjectTo(Any[:x, :y])
# arrays of other things
@test ProjectTo([:x, :y]) isa ProjectTo{NoTangent}
@test ProjectTo(Any['x', "y"]) isa ProjectTo{NoTangent}
@test ProjectTo([(1,2), (3,4), (5,6)]) isa ProjectTo{AbstractArray}

@test ProjectTo(Any[1, 2])(1:2) == [1.0, 2.0] # projects each number.
@test Tuple(ProjectTo(Any[1, 2 + 3im])(1:2)) === (1.0, 2.0 + 0.0im)
Expand Down Expand Up @@ -140,6 +143,12 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
@test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent}
end

@testset "Base: non-diff" begin
@test ProjectTo(:a)(1) == NoTangent()
@test ProjectTo('b')(2) == NoTangent()
@test ProjectTo("cde")(345) == NoTangent()
end

#####
##### `LinearAlgebra`
#####
Expand Down Expand Up @@ -301,6 +310,10 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
##### `ChainRulesCore`
#####

@testset "pass-through" begin
@test ProjectTo(NoSuperType()) === identity
end

@testset "AbstractZero" begin
pz = ProjectTo(ZeroTangent())
pz(0) == NoTangent()
Expand Down