Commit 528e0be

mcabbott and oxinabox authored
Use ProjectTo in broadcasting & gradient (#1044)
* use ProjectTo in broadcasting, etc
* separate methods for Params
* move after defn
* better dims handling in unbroadcast
* tidier
* tests
* more wrapping
* fix a test
* handle a few nothings
* fix more, including FFT tests
* tests
* one test
* tests
* tests
* tests
* these are fixed
* add Compat
* tests
* add tests for issues closed
* simplify, some doctests
* fix some tests
* less piracy
* adjoint
* piract
* skip a test
* splat tests
* skip on 1.3
* simplify _project
* a typo
* tweak
* broken GPU test, unrelated
* unexpected pass
* only broken on 1.6
* let nothing through
* rm some broken things
* target 1.3 fix
* comments
* update for ProjectTo(::Any)
* fix a test
* Update test/utils.jl Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
* Update src/lib/broadcast.jl
* cu tests
* v0.6.22

Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
1 parent b33988e commit 528e0be
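
A minimal sketch of the user-visible change, taken from the README and test updates in this commit: gradients now pass through ChainRulesCore.ProjectTo, so they come back matching the type and shape of the input (floating-point values for integer inputs, a plain Complex instead of a (re = ..., im = ...) NamedTuple, and so on).

julia> using Zygote

julia> f(x) = 5x + 3;

julia> f(10), f'(10)                # gradient is now projected to Float64; was (53, 5)
(53, 5.0)

julia> gradient(x -> x.re, 2+3im)   # was ((re = 1, im = nothing),)
(1.0 + 0.0im,)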

File tree

13 files changed: 214 additions & 78 deletions

Project.toml

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
 name = "Zygote"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.21"
+version = "0.6.22"

 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 [compat]
 AbstractFFTs = "0.5, 1.0"
 ChainRules = "1.5"
-ChainRulesCore = "1.1"
+ChainRulesCore = "1.6"
 ChainRulesTestUtils = "1"
 DiffRules = "1.0"
 FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"

README.md

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ julia> using Zygote
 julia> f(x) = 5x + 3

 julia> f(10), f'(10)
-(53, 5)
+(53, 5.0)

 julia> @code_llvm f'(10)
 define i64 @"julia_#625_38792"(i64) {

src/compiler/chainrules.jl

Lines changed: 22 additions & 0 deletions

@@ -123,11 +123,33 @@ Convert `x` from the format Zygote uses internally to differentials types ChainRules
 """
 @inline wrap_chainrules_input(x) = x
 @inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent()
+@inline wrap_chainrules_input(::AbstractArray{Nothing}) = ChainRules.ZeroTangent()
 @inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple})
   xp = map(wrap_chainrules_input, xs)
   ChainRules.Tangent{Any, typeof(xp)}(xp)
 end

+"""
+    _project(x, dx)
+
+Uses `ChainRulesCore.ProjectTo` to standardise the gradient `dx` for type & shape.
+Also handles some Zygote-specific corrections, such as `x::Array, dx::Tuple`.
+Safe to apply to arbitrary input.
+"""
+@inline function _project(x, dx)
+  wrap_chainrules_output(ProjectTo(x)(wrap_chainrules_input(dx)))
+end
+
+# Restore splatted arrays
+_project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x)))
+
+# Piracy:
+# wrap_chainrules_input doesn't handle array of Union{Int,Nothing}
+(::ChainRulesCore.ProjectTo)(::Nothing) = ChainRulesCore.NoTangent()
+
+# CRC likes Tangent{<:Complex}, but Zygote makes Tangent{Any}
+(project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im))
+
 """
     ZBack{F}(back) <: Function
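
Roughly, the new `_project(x, dx)` converts `dx` to ChainRules form, applies `ProjectTo(x)`, and converts the result back to Zygote's conventions. A sketch of the intended behaviour of this internal helper (illustrative values, not doctests from the repository):

julia> using Zygote

julia> Zygote._project(2.0, 3 + 4im)        # real primal, complex gradient: keep the real part
3.0

julia> Zygote._project([1.0, 2.0], (3, 4))  # splatted array: a Tuple gradient is reshaped back
2-element Vector{Float64}:
 3.0
 4.0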

src/compiler/interface.jl

Lines changed: 16 additions & 7 deletions

@@ -68,15 +68,20 @@ julia> gradient([7, 11], 0, 1) do x, y, d
          p = size(x, d)
          sum(x.^p .+ y)
        end
-([14.0, 22.0], 2, nothing)
+([14.0, 22.0], 2.0, nothing)
 ```
 """
 function gradient(f, args...)
   y, back = pullback(f, args...)
-  return back(sensitivity(y))
+  grad = back(sensitivity(y))
+  isnothing(grad) ? nothing : map(_project, args, grad)
 end

-Base.adjoint(f::Function) = x -> gradient(f, x)[1]
+# Base.adjoint(f::Function) = x -> gradient(f, x)[1]  # piracy!
+Base.adjoint(f::Function) = x -> begin  # still piracy! avoids projection for legacy reasons
+  y, back = pullback(f, x)
+  back(sensitivity(y))[1]
+end

 """
     withgradient(f, args...)
@@ -95,7 +100,9 @@ true
 """
 function withgradient(f, args...)
   y, back = pullback(f, args...)
-  (val = y, grad = back(sensitivity(y)))
+  grad = back(sensitivity(y))
+  results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
+  (val=y, grad=results)
 end

 # Param-style wrappers
@@ -115,9 +122,9 @@ julia> g = gradient(Params([x, y])) do
 Grads(...)

 julia> g[x]
-2×3 Matrix{Int64}:
- 7  70  700
- 8  80  800
+2×3 Matrix{Float64}:
+ 7.0  70.0  700.0
+ 8.0  80.0  800.0

 julia> haskey(g, z) # only x and y are parameters
 false
@@ -144,6 +151,8 @@ Params(xs::Tuple) = Params(collect(xs))
 @forward Params.order Base.iterate, Base.length, Base.getindex
 @forward Params.params Base.in

+Base.map(::typeof(_project), args::Tuple{Params}, grad) = grad  # skip _project in gradient(f, ::Params)
+
 function Base.union!(ps::Params, itrs...)
   foreach(itr -> foreach(x -> push!(ps, x), itr), itrs)
   return ps
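
With this change `gradient` and `withgradient` map `_project` over the arguments before returning, so each returned gradient is standardised against the corresponding input; `Params`-style calls skip this via the `Base.map` method above. A small sketch of the effect (integer inputs now give floating-point gradients):

julia> Zygote.withgradient(x -> sum(abs2, x), [1, 2, 3])
(val = 14, grad = ([2.0, 4.0, 6.0],))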

src/lib/array.jl

Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ end
     dxv = view(dx, inds...)
     dxv .= accum.(dxv, _droplike(dy, dxv))
   end
-  return (dx, map(_->nothing, inds)...)
+  return (_project(x, dx), map(_->nothing, inds)...)
 end

 """

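Projecting in the `getindex` pullback means structured matrices keep their structure: indexing a `Diagonal` now returns a `Diagonal` cotangent rather than a dense `Matrix`, as the "issue 402" test added in test/complex.jl below checks. A sketch of that behaviour:

julia> using Zygote, LinearAlgebra

julia> _, back = Zygote.pullback(x -> x[2, 1], Diagonal([1, 2, 3.0]));

julia> back(1.0)[1] isa Diagonal    # the off-diagonal seed is projected away, leaving all zeros
true
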
src/lib/broadcast.jl

Lines changed: 10 additions & 9 deletions

@@ -45,18 +45,19 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
   Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)})
 end

-trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
-trim(x::Tuple, Δ) = NTuple{length(x)}(Δ)
-
-unbroadcast(x::AbstractArray, x̄) =
-  size(x) == size(x̄) ? x̄ :
-  length(x) == length(x̄) ? trim(x, x̄) :
-    trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))
-
+function unbroadcast(x::AbstractArray, x̄)
+  N = ndims(x̄)
+  if length(x) == length(x̄)
+    _project(x, x̄)  # ProjectTo handles reshape, offsets, structured matrices, row vectors
+  else
+    dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄)+1, ndims(x̄))
+    _project(x, accum_sum(x̄; dims = dims))
+  end
+end
 unbroadcast(x::Number, x̄) = accum_sum(x̄)
 unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
 unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
-unbroadcast(x::Tuple, x̄) = trim(x, length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄)))  # case length(x) > 1
+unbroadcast(x::Tuple, x̄) = NTuple{length(x)}(length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄)))  # case length(x) > 1

 unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
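
Since `unbroadcast` now ends in `_project`, the gradient of a broadcast keeps the eltype and, where possible, the structure of that argument, and non-differentiable arguments come back as `nothing`. A CPU sketch of the behaviour the GPU tests below assert (illustrative, not taken verbatim from the test suite):

julia> using Zygote

julia> eltype(gradient(x -> sum(x .* 5.6), Float32[1, 2, 3])[1])   # not promoted to Float64
Float32

julia> gradient((x, y) -> sum(x .^ 2 .+ y), [1.0, 2.0], [true, false])[2] === nothing  # Bool array is non-differentiable
true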

test/complex.jl

Lines changed: 31 additions & 2 deletions

@@ -1,9 +1,13 @@
 using Zygote, Test, LinearAlgebra

+@testset "basic" begin
+
 @test gradient(x -> real(abs(x)*exp(im*angle(x))), 10+20im)[1] ≈ 1
 @test gradient(x -> imag(real(x)+0.3im), 0.3)[1] ≈ 0
-@test gradient(x -> imag(conj(x)+0.3im), 0.3)[1] ≈ -1im
-@test gradient(x -> abs((imag(x)+0.3)), 0.3)[1] == 1im
+@test gradient(x -> imag(conj(x)+0.3im), 0.3 + 0im)[1] ≈ -1im
+@test gradient(x -> imag(conj(x)+0.3im), 0.3)[1] ≈ 0 # projected to zero
+@test gradient(x -> abs((imag(x)+0.3)), 0.3 + 0im)[1] ≈ 1im
+@test gradient(x -> abs((imag(x)+0.3)), 0.3)[1] ≈ 0

 @test gradient(a -> real((a*conj(a))), 0.3im)[1] == 0.6im
 @test gradient(a -> real((a.*conj(a))), 0.3im)[1] == 0.6im
@@ -21,6 +25,8 @@ using Zygote, Test, LinearAlgebra
 @test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ real(im .* exp.(1:3))
 @test gradient(x -> imag(sum(exp, x)), [1+0im,2,3])[1] ≈ im .* exp.(1:3)

+end # @testset
+
 fs_C_to_R = (real,
              imag,
              abs,
@@ -81,3 +87,26 @@ fs_C_to_C_non_holomorphic = (conj,
     end
   end
 end
+
+@testset "issue 342" begin
+  @test Zygote.gradient(x->real(x + 2.0*im), 3.0) == (1.0,)
+  @test Zygote.gradient(x->imag(x + 2.0*im), 3.0) == (0.0,)
+end
+
+@testset "issue 402" begin
+  A = [1,2,3.0]
+  y, B_getindex = Zygote.pullback(x->getindex(x,2,1),Diagonal(A))
+  bA = B_getindex(1)[1]
+  @test bA isa Diagonal
+  @test bA == [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0]
+end
+
+@testset "issue #917" begin
+  function fun(v)
+    c = v[1:3] + v[4:6]*im
+    r = v[7:9]
+    sum(r .* abs2.(c)) # This would be calling my actual function depending on r and c
+  end
+  @test Zygote.hessian(fun, collect(1:9)) ≈ [14 0 0 0 0 0 2 0 0; 0 16 0 0 0 0 0 4 0; 0 0 18 0 0 0 0 0 6; 0 0 0 14 0 0 8 0 0; 0 0 0 0 16 0 0 10 0; 0 0 0 0 0 18 0 0 12; 2 0 0 8 0 0 0 0 0; 0 4 0 0 10 0 0 0 0; 0 0 6 0 0 12 0 0 0]
+end
+

test/cuda.jl

Lines changed: 37 additions & 2 deletions

@@ -1,12 +1,20 @@
 using CUDA
 using Zygote: Grads
+using LinearAlgebra
 using Random: randn!
 CUDA.allowscalar(false)

 # Test GPU movement inside the call to `gradient`
 @testset "GPU movement" begin
   r = rand(Float32, 3,3)
-  @test gradient(x -> sum(cu(x)), r)[1] isa Array{Float32, 2}
+  @test gradient(x -> sum(cu(x)), r)[1] isa Matrix{Float32}
+  @test gradient(x -> sum(x->log(x), cu(x)), r)[1] isa Matrix
+  @test gradient((x,cy) -> sum(cu(x) * cy) + sum(cy'), r, cu(r))[2] isa CUDA.CuArray
+  @test_skip gradient((x,cy) -> sum(cu(x[:,1])' * cy), r, cu(r))[2] isa CUDA.CuArray # generic_matmatmul!
+
+  # Other direction:
+  @test_skip gradient(x -> sum(Array(x)), cu(r))[1] isa CUDA.CuArray
+  @test_skip gradient((x,cy) -> sum(x * Array(cy)) + sum(cy'), r, cu(r))[2] isa CUDA.CuArray
 end

 @testset "broadcasting" begin
@@ -31,17 +39,38 @@ end
   g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression
   @test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- not fixed by PR #1018
   @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1]
+
+  # Projection: eltype preservation:
+  @test gradient(x -> 2.3 * sum(x.^4), a_gpu)[1] isa CuArray{Float32}
+  @test_skip gradient(x -> sum(x .* 5.6), a_gpu)[1] isa CUDA.CuArray{Float32} # dot(x::CuArray{Float64}, y::CuArray{Float32}) fallback
+  # structure restoration:
+  @test gradient(x -> sum(sqrt.(x)), a_gpu')[1] isa Adjoint # previously a matrix
+  @test gradient(x -> sum(exp.(x)), Diagonal(a_gpu))[1] isa Diagonal
+  # non-differentiables
+  @test gradient((x,y) -> sum(x.^2 .+ y'), a_gpu, a_gpu .> 0)[2] === nothing
 end

 @testset "sum(f, x)" begin
-  a = Float32.([-1.5, -9.0, 2.4, -1.3, 0.01])
+  a = Float32[-1.5, -9.0, 2.4, -1.3, 0.01]
   a_gpu = a |> cu

   f(x) = sum(abs, x)
   g = gradient(f, a)[1]
   g_gpu = gradient(f, a_gpu)[1]
   @test g_gpu isa CuArray
   @test g_gpu |> collect ≈ g
+
+  f2(x) = sum(abs2, x) # sum(abs2, x) has its own rrule
+  g2 = gradient(f2, a)[1]
+  g2_gpu = gradient(f2, a_gpu)[1]
+  @test g2_gpu isa CuArray
+  @test g2_gpu |> collect ≈ g2
+
+  f3(x) = sum(y->y^3, x') # anonymous function
+  g3 = gradient(f3, a')[1]
+  g3_gpu = gradient(f3, a_gpu')[1]
+  @test g3_gpu isa Adjoint{Float32, <:CuArray{Float32, 1}} # preserves structure
+  @test g3_gpu |> collect ≈ g3
 end

 @testset "jacobian" begin
@@ -103,5 +132,11 @@ end
   r = cu(rand(Float32, 3))
   grads = (cu(ones(Float32, 3)), 1.f0)
   @test gradient((x,y) -> sum(vcat(x,y)), r, 5) == grads
+
+  @test gradient((x,y) -> sum(vcat(x,y)), r, Float64(5))[1] isa CUDA.CuArray{Float32}
+  @test gradient((x,y) -> sum(vcat(x,y)), r, Float64(5))[2] isa Float64 # projection
+
+  @test_skip gradient((x,y) -> sum(vcat(x,y)), 5f0, r)[2] isa CUDA.CuArray{Float32} # wrong order
+  @test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32}
 end

test/features.jl

Lines changed: 22 additions & 2 deletions

@@ -176,9 +176,9 @@ end

 @test gradient(t -> t[1]*t[2], (2, 3)) == ((3, 2),)

-@test gradient(x -> x.re, 2+3im) == ((re = 1, im = nothing),)
+@test gradient(x -> x.re, 2+3im) === (1.0 + 0.0im,)

-@test gradient(x -> x.re*x.im, 2+3im) == ((re = 3, im = 2),)
+@test gradient(x -> x.re*x.im, 2+3im) == (3.0 + 2.0im,)

 struct Bar{T}
   a::T
@@ -262,6 +262,7 @@ D(f, x) = grad(f, x)[1]
 @test D(x -> x*D(y -> x+y, 1), 1) == 1
 @test D(x -> x*D(y -> x*y, 1), 4) == 8

+@test sin''(1.0) == -sin(1.0)
 @test sin'''(1.0) == -cos(1.0)

 f(x) = throw(DimensionMismatch("fubar"))
@@ -499,6 +500,25 @@ end
   @test x[1] == x[2]
 end

+@testset "splats" begin
+  @test gradient(x -> max(x...), [1,2,3])[1] == [0,0,1]
+  @test gradient(x -> min(x...), (1,2,3))[1] === (1.0, 0.0, 0.0)
+
+  @test gradient(x -> max(x...), [1 2; 3 4])[1] == [0 0; 0 1]
+  @test gradient(x -> max(x...), [1,2,3]')[1] == [0 0 1]
+
+  # https://github.com/FluxML/Zygote.jl/issues/599
+  @test gradient(w -> sum([w...]), [1,1])[1] isa AbstractVector
+
+  # https://github.com/FluxML/Zygote.jl/issues/866
+  f866(x) = reshape(x, fill(2, 2)...)
+  @test gradient(x->sum(f866(x)), rand(4))[1] == [1,1,1,1]
+
+  # https://github.com/FluxML/Zygote.jl/issues/731
+  f731(x) = sum([x' * x, x...])
+  @test_broken gradient(f731, ones(3)) # MethodError: no method matching +(::Tuple{Float64, Float64, Float64}, ::Vector{Float64})
+end
+
 @testset "accumulation" begin
   # from https://github.com/FluxML/Zygote.jl/issues/905
   function net(x1)

test/forward/forward.jl

Lines changed: 2 additions & 1 deletion

@@ -36,7 +36,8 @@ end == 1
   x
 end == 0

-@test D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1)[1]
+@test D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1+0im)[1]
+@test real(D(x -> abs(x+2im), 1)) == gradient(x -> abs(x+2im), 1)[1] # ProjectTo means gradient here is real

 using LinearAlgebra

0 commit comments