Skip to content

Commit be5b47f

Browse files
committed
Improved type stability with explicit params
In explicit mode, we can disable the accumulation of (implicit) parameters into the gradient cache. This can dramatically improve type stability, because otherwise `accum_param` returns a `Union{Nothing, [grad type]}`.
1 parent c822e9e commit be5b47f

File tree

5 files changed

+49
-21
lines changed

5 files changed

+49
-21
lines changed

src/compiler/interface.jl

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ using Core: Typeof
44
import Base: copy!, IdSet
55
import Base.Broadcast: broadcasted, materialize!
66

7-
mutable struct Context <: AContext
7+
mutable struct Context{I} <: AContext
88
cache::Union{IdDict{Any,Any},Nothing}
99
end
1010

11-
Context() = Context(nothing)
11+
Context() = Context{false}(nothing)
1212

1313
cache(cx::Context) = cx.cache === nothing ? (cx.cache = IdDict()) : cx.cache
1414

@@ -36,10 +36,28 @@ _pullback(f, args...) = _pullback(Context(), f, args...)
3636
tailmemaybe(::Nothing) = nothing
3737
tailmemaybe(x::Tuple) = Base.tail(x)
3838

39-
function pullback(f, args...)
40-
y, back = _pullback(f, args...)
39+
@inline pullback(f, args...) = pullback(f, Context(), args...)
40+
function pullback(f, cx::AContext, args...)
41+
y, back = _pullback(cx, f, args...)
4142
y, Δ -> tailmemaybe(back(Δ))
4243
end
44+
function pullback(cx::Context, f, args...)
45+
ChainRulesCore.ignore_derivatives() do
46+
@warn """
47+
Incorrect argument order for pullback, please use:
48+
49+
pullback(f, __context__::Context, args)
50+
51+
instead of:
52+
53+
pullback(__context__::Context, f, args)
54+
55+
This is usually caused by a call to pullback in a higher-order @adjoint.
56+
The above warning will become an error in Zygote 0.7.
57+
"""
58+
end
59+
return pullback(f, cx, args...)
60+
end
4361

4462
sensitivity(y::Number) = one(y)
4563
sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.")
@@ -334,21 +352,21 @@ function Base.map(f, gs1::Grads, gss::ADictOrGrads...)
334352
end
335353

336354
function Base.map!(f, gsout::Grads, gss::ADictOrGrads...)
337-
all(issetequal(gsout.params, keys(gs)) for gs in gss) ||
355+
all(issetequal(gsout.params, keys(gs)) for gs in gss) ||
338356
throw(ArgumentError("map! expects Grads objects with the same Params."))
339357
for p in gsout.params
340-
gsout[p] = f((_getformap(gs, p) for gs in gss)...)
358+
gsout[p] = f((_getformap(gs, p) for gs in gss)...)
341359
end
342360
return gsout
343361
end
344362

345363
function _getformap(gs, p)
346364
g = gs[p]
347-
isnothing(g) ? fill!(similar(p), 0) : g
365+
isnothing(g) ? fill!(similar(p), 0) : g
348366
end
349367

350368
function pullback(f, ps::Params)
351-
cx = Context()
369+
cx = Context{true}(nothing)
352370
y, back = _pullback(cx, f)
353371
y, function (Δ)
354372
for p in ps

src/lib/array.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,15 +310,15 @@ end
310310

311311
@adjoint function sum(f, xs::AbstractArray{<:AbstractArray}; kws...)
312312
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
313-
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
313+
return pullback((f, xs) -> sum(f.(xs); kws...), __context__, f, xs)
314314
end
315315

316316
@adjoint function sum(xs::AbstractArray{Bool}; dims = :)
317317
sum(xs, dims = dims), Δ -> (nothing,)
318318
end
319319

320320
function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
321-
y, back = pullback(cx, ((f, xs) -> prod(f.(xs))), f, xs)
321+
y, back = pullback((f, xs) -> prod(f.(xs)), cx, f, xs)
322322
y, ȳ -> (nothing, back(ȳ)...)
323323
end
324324

src/lib/broadcast.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize
3030
# Utilities
3131
# =========
3232

33+
# ChainRules already marks this non-differentiable,
34+
# but inference can still give up because of the Zygote -> CR wrapper layer
35+
@nograd Broadcast.combine_styles
36+
3337
accum_sum(xs; dims = :) = reduce(accum, xs, dims = dims)
3438

3539
# Work around reducedim_init issue
@@ -82,16 +86,16 @@ _minus(::Nothing) = nothing
8286
@adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y,
8387
Δ -> (nothing, unbroadcast(x, Δ .* conj.(y)), unbroadcast(y, Δ .* conj.(x)))
8488
@adjoint broadcasted(::typeof(*), x::Number, y::AbstractArray{<:Number}) =
85-
_pullback(*, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y))
89+
_pullback(__context__, *, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y))
8690
@adjoint broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number) =
87-
_pullback(*, x, y)
91+
_pullback(__context__, *, x, y)
8892

8993
@adjoint function broadcasted(::typeof(/), x::Numeric, y::Numeric)
9094
res = x ./ y
9195
res, Δ -> (nothing, unbroadcast(x, Δ ./ conj.(y)), unbroadcast(y, .-Δ .* conj.(res ./ y)))
9296
end
9397
@adjoint broadcasted(::typeof(/), x::AbstractArray{<:Number}, y::Number) =
94-
_pullback(/, x, y)
98+
_pullback(__context__, /, x, y)
9599

96100
@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p
97101
y = Base.literal_pow.(^, x, exp)
@@ -273,7 +277,7 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve
273277
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
274278
@adjoint function sum(f, xs::AbstractGPUArray; kws...)
275279
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
276-
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
280+
return pullback((f, xs) -> sum(f.(xs); kws...), __context__, f, xs)
277281
end
278282

279283
@adjoint function Base.convert(::Type{T}, xs::Array) where {T<:AbstractGPUArray}

src/lib/lib.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ accum(x, y) =
2121

2222
accum(x, y, zs...) = accum(accum(x, y), zs...)
2323

24-
accum(x::Tuple, ys::Tuple...) = accum.(x, ys...)
24+
accum(x::Tuple, ys::Tuple...) = map(accum, x, ys...)
2525
accum(x::AbstractArray, ys::AbstractArray...) = accum.(x, ys...)
2626

2727
@generated function accum(x::NamedTuple, y::NamedTuple)
@@ -48,6 +48,7 @@ end
4848

4949
@adjoint Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing)
5050

51+
accum_param(::Context{false}, _, Δ) = Δ
5152
@generated function accum_param(cx::Context, x, Δ)
5253
isbitstype(x) && return :(Δ)
5354
quote

test/compiler.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using Zygote, Test
2-
using Zygote: pullback, @adjoint
2+
using Zygote: pullback, @adjoint, Context
33

44
macro test_inferred(ex)
55
:(let res = nothing
@@ -160,13 +160,18 @@ end
160160
@testset "inference for `getproperty`" begin
161161
Gaussian = _Gaussian(:getproperty)
162162
g = Gaussian(randn(3), randn(3, 3))
163-
y, back = @inferred pullback(x -> x.m, g)
164-
@test y == getfield(g, :m)
165-
# This type instability is due to the handling of non-bitstypes in `accum_param`
163+
y_explicit, back_explicit = @inferred pullback(x -> x.m, g)
164+
y_implicit, back_implicit = @inferred pullback(x -> x.m, Context{true}(nothing), g)
165+
@test y_explicit == y_implicit == getfield(g, :m)
166+
167+
∇args = ((m = [1.0, 0.0, 0.0], P = nothing),)
166168
if VERSION > v"1.7-"
167-
@test Base.return_types(back, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(((m = [1.0, 0.0, 0.0], P = nothing),))}]
169+
# This type instability is due to the handling of non-bitstypes in `accum_param`
170+
@test Base.return_types(back_implicit, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(∇args)}]
171+
# But the same should infer if implicit parameters are disabled
172+
@test Base.return_types(back_explicit, Tuple{Vector{Float64}}) == Any[typeof(∇args)]
168173
end
169-
@test back([1., 0, 0]) == ((m = [1.0, 0.0, 0.0], P = nothing),)
174+
@test back_explicit([1., 0, 0]) == back_implicit([1., 0, 0]) == ∇args
170175

171176
Base.getproperty(g::Gaussian, s::Symbol) = 2getfield(g, s)
172177
y, back = pullback(x -> x.m, g)

0 commit comments

Comments
 (0)