Skip to content

Commit c098f37

Browse files
author
Miha Zgubic
committed
number adjoints to rrules
1 parent c822e9e commit c098f37

File tree

2 files changed

+91
-23
lines changed

2 files changed

+91
-23
lines changed

src/lib/number.jl

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,60 @@
1-
@adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} =
2-
Base.literal_pow(^,x,Val(p)),
3-
Δ -> (nothing, Δ * conj(p * Base.literal_pow(^,x,Val(p-1))), nothing)
1+
function ChainRulesCore.rrule(
2+
::ZygoteRuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{p}
3+
) where {p}
4+
function literal_pow_pullback(Δ)
5+
dx = Δ * conj(p * Base.literal_pow(^,x,Val(p-1)))
6+
return (NoTangent(), NoTangent(), dx, NoTangent())
7+
end
8+
return Base.literal_pow(^,x,Val(p)), literal_pow_pullback
9+
end
410

5-
@adjoint Base.convert(T::Type{<:Real}, x::Real) = convert(T, x), ȳ -> (nothing, ȳ)
6-
@adjoint (T::Type{<:Real})(x::Real) = T(x), ȳ -> (nothing, ȳ)
11+
function ChainRulesCore.rrule(::ZygoteRuleConfig, T::Type{<:Real}, x::Real)
12+
Real_pullback(Δ) = (NoTangent(), Δ)
13+
return T(x), Real_pullback
14+
end
715

816
for T in Base.uniontypes(Core.BuiltinInts)
9-
@adjoint (::Type{T})(x::Core.BuiltinInts) = T(x), Δ -> (Δ,)
17+
@eval function ChainRulesCore.rrule(::ZygoteRuleConfig, ::Type{$T}, x::Core.BuiltinInts)
18+
IntX_pullback(Δ) = (NoTangent(), Δ)
19+
return $T(x), IntX_pullback
20+
end
1021
end
1122

12-
@adjoint Base.:+(xs::Number...) = +(xs...), Δ -> map(_ -> Δ, xs)
23+
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(+), xs::Number...)
24+
plus_pullback(Δ) = (NoTangent(), map(_ -> Δ, xs)...)
25+
return +(xs...), plus_pullback
26+
end
1327

14-
@adjoint a // b = (a // b, c̄ -> (c̄ * 1//b, -* a // b // b))
28+
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(//), a, b)
29+
divide_pullback(r̄) = (NoTangent(), r̄ * 1//b, -* a // b // b)
30+
return a // b, divide_pullback
31+
end
1532

1633
# Complex Numbers
1734

18-
@adjoint (T::Type{<:Complex})(re, im) = T(re, im), c̄ -> (nothing, real(c̄), imag(c̄))
35+
function ChainRulesCore.rrule(::ZygoteRuleConfig, T::Type{<:Complex}, r, i)
36+
Complex_pullback(c̄) = (NoTangent(), real(c̄), imag(c̄))
37+
return T(r, i), Complex_pullback
38+
end
1939

2040
# we define these here because ChainRules.jl only defines them for x::Union{Real,Complex}
2141

22-
@adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),)
23-
@adjoint real(x::Number) = real(x), -> (real(r̄),)
24-
@adjoint conj(x::Number) = conj(x), r̄ -> (conj(r̄),)
25-
@adjoint imag(x::Number) = imag(x), ī -> (real(ī)*im,)
42+
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(abs2), x::Number)
43+
abs2_pullback) = (NoTangent(), real(Δ)*(x + x))
44+
return abs2(x), abs2_pullback
45+
end
2646

27-
# for real x, ChainRules pulls back a zero real adjoint, whereas we treat x
28-
# as embedded in the complex numbers and pull back a pure imaginary adjoint
29-
@adjoint imag(x::Real) = zero(x), ī -> (real(ī)*im,)
47+
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(real), x::Number)
48+
real_pullback(r̄) = (NoTangent(), real(r̄))
49+
return real(x), real_pullback
50+
end
51+
52+
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(conj), x::Number)
53+
conj_pullback(c̄) = (NoTangent(), conj(c̄))
54+
return conj(x), conj_pullback
55+
end
56+
57+
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(imag), x::Number)
58+
imag_pullback(ī) = (NoTangent(), real(ī)*im)
59+
return imag(x), imag_pullback
60+
end

test/lib/number.jl

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,44 @@
1-
@testset "nograds" begin
2-
@test gradient(floor, 1) === (0.0,)
3-
@test gradient(ceil, 1) === (0.0,)
4-
@test gradient(round, 1) === (0.0,)
5-
@test gradient(hash, 1) === nothing
6-
@test gradient(div, 1, 2) === nothing
7-
end #testset
1+
@testset "number.jl" begin
2+
@testset "nograds" begin
3+
@test gradient(floor, 1) === (0.0,)
4+
@test gradient(ceil, 1) === (0.0,)
5+
@test gradient(round, 1) === (0.0,)
6+
@test gradient(hash, 1) === nothing
7+
@test gradient(div, 1, 2) === nothing
8+
end
9+
10+
@testset "basics" begin
11+
@test gradient(Base.literal_pow, ^, 3//2, Val(-5))[2] isa Rational
12+
13+
@test gradient(convert, Rational, 3.14) == (nothing, 1.0)
14+
@test gradient(convert, Rational, 2.3) == (nothing, 1.0)
15+
@test gradient(convert, UInt64, 2) == (nothing, 1.0)
16+
@test gradient(convert, BigFloat, π) == (nothing, 1.0)
17+
18+
@test gradient(Rational, 2) == (1//1,)
19+
20+
@test gradient(Bool, 1) == (1.0,)
21+
@test gradient(Int32, 2) == (1.0,)
22+
@test gradient(UInt16, 2) == (1.0,)
23+
24+
@test gradient(+, 2.0, 3, 4.0, 5.0) == (1.0, 1.0, 1.0, 1.0)
25+
26+
@test gradient(//, 3, 2) == (1//2, -3//4)
27+
end
28+
29+
@testset "Complex numbers" begin
30+
@test gradient(imag, 3.0) == (0.0,)
31+
@test gradient(imag, 3.0 + 3.0im) == (0.0 + 1.0im,)
32+
33+
@test gradient(conj, 3.0) == (1.0,)
34+
@test gradient(real conj, 3.0 + 1im) == (1.0 + 0im,)
35+
36+
@test gradient(real, 3.0) == (1.0,)
37+
@test gradient(real, 3.0 + 1im) == (1.0 + 0im,)
38+
39+
@test gradient(abs2, 3.0) == (2*3.0,)
40+
@test gradient(abs2, 3.0+2im) == (2*3.0 + 2*2.0im,)
41+
42+
@test gradient(real Complex, 3.0, 2.0) == (1.0, 0.0)
43+
end
44+
end

0 commit comments

Comments
 (0)