Skip to content

Commit cb59b6c

Browse files
authored
Merge pull request #1249 from mzgubic/mz/deprecate_nograd
deprecate `Zygote.@nograd`
2 parents 5ffbd43 + af434d6 commit cb59b6c

File tree

11 files changed

+23
-35
lines changed

11 files changed

+23
-35
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2727

2828
[compat]
2929
AbstractFFTs = "0.5, 1.0"
30-
ChainRules = "1.36.2"
30+
ChainRules = "1.37"
3131
ChainRulesCore = "1.9"
3232
ChainRulesTestUtils = "1"
3333
DiffRules = "1.4"

src/Zygote.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ include("profiler/Profile.jl")
5858
end
5959

6060
@init @require Colors="5ae59095-9a9b-59fe-a467-6f913c188581" begin
61-
@nograd Colors.ColorTypes._parameter_upper_bound
61+
@non_differentiable Colors.ColorTypes._parameter_upper_bound(::Any...)
6262
end
6363

6464
using InteractiveUtils

src/deprecated.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,19 @@ macro ignore(ex)
4949
$(esc(ex))
5050
end)
5151
end
52+
53+
using MacroTools: @q
54+
55+
macro nograd(ex)
56+
Base.depwarn(
57+
"`Zygote.@nograd myfunc` is deprecated, use `ChainRulesCore.@non_differentiable myfunc(::Any...)` instead.",
58+
:nograd
59+
)
60+
isexpr(ex, :tuple) || (ex = Expr(:tuple, ex))
61+
blk = @q begin end
62+
for f in ex.args
63+
back = MacroTools.@q _ -> ($__source__; nothing)
64+
push!(blk.args, :(@inline Zygote._pullback(::Context, ::Core.Typeof($(esc(f))), args...) = $(esc(f))(args...), $back))
65+
end
66+
return blk
67+
end

src/forward/lib.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ end
99
# TODO figure out why this made a test fail
1010
zerolike(x::Union{Module,Type}) = nothing
1111

12-
# TODO: `@nograd` and `@linear`
12+
# TODO: `@non_differentiable` and `@linear`
1313

1414
@tangent zerolike(x) = zerolike(x), _ -> zerolike(x)
1515
@tangent one(x::Number) = one(x), _ -> zero(x)

src/lib/array.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ using Distributed: pmap, AbstractWorkerPool
66
@adjoint Array(xs::AbstractArray) = Array(xs), ȳ -> (ȳ,)
77
@adjoint Array(xs::Array) = Array(xs), ȳ -> (ȳ,)
88

9-
@nograd ones, zeros, Base.OneTo, Colon(), one, zero, sizehint!, count
10-
119
@adjoint copy(x::AbstractArray) = copy(x), ȳ -> (ȳ,)
1210

1311
@adjoint collect(x::Tuple) = collect(x), dy -> (Tuple(dy),)
@@ -233,11 +231,6 @@ end
233231
end
234232
end
235233

236-
for t in subtypes(AbstractWorkerPool)
237-
@nograd t
238-
end
239-
@nograd workers
240-
241234
function _pullback(cx::AContext, ::typeof(collect), g::Base.Generator)
242235
y, b = ∇map(cx, g.f, g.iter)
243236
back(::Nothing) = nothing

src/lib/base.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,6 @@ end
4949

5050
# Channels
5151

52-
@nograd Channel
53-
5452
grad_mut(ch::Channel) = Channel(ch.sz_max)
5553

5654
@adjoint! function put!(ch::Channel, x)
@@ -157,8 +155,6 @@ end
157155

158156
@adjoint Base.nameof(x::UnionAll) = nameof(x), _ -> (nothing,)
159157

160-
@nograd typeintersect
161-
162158
# Base.Fix1 and Base.Fix2: https://github.com/FluxML/Zygote.jl/issues/957
163159
@adjoint function (g::Base.Fix1)(y)
164160
f = g.f

src/lib/buffer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
grad_mut(b::Buffer) = fill!(similar(b.data, Any), nothing)
22
grad_mut(b::Buffer{T}) where T<:Number = fill!(similar(b.data, float(T)), 0)
33

4-
@nograd Buffer
4+
@non_differentiable Buffer(::Any...)
55

66
@adjoint function getindex(b::Buffer, i...)
77
b[i...], function (Δ)

src/lib/grad.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,3 @@
1-
using MacroTools: @q
2-
3-
macro nograd(ex)
4-
isexpr(ex, :tuple) || (ex = Expr(:tuple, ex))
5-
blk = @q begin end
6-
for f in ex.args
7-
back = MacroTools.@q _ -> ($__source__; nothing)
8-
push!(blk.args, :(@inline Zygote._pullback(::Context, ::Core.Typeof($(esc(f))), args...) = $(esc(f))(args...), $back))
9-
end
10-
return blk
11-
end
12-
131
macro which(ex)
142
@capture(ex, f_(args__)) || error("Zygote.@which f(args...)")
153
:(InteractiveUtils.@which adjoint(Context(), $(esc(f)), $(esc.(args)...)))

src/lib/lib.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ function accum(x::RefValue, y::RefValue)
3838
end
3939

4040
# Core functions
41-
@nograd eps, Base.eval, Core.TypeVar, Core.UnionAll, Symbol
42-
4341
@adjoint deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)
4442

4543
@adjoint (::Type{V})(x...) where V<:Val = V(x...), _ -> nothing

src/lib/number.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
2-
@nograd floor, ceil, trunc, round, div
3-
41
@adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} =
52
Base.literal_pow(^,x,Val(p)),
63
Δ -> (nothing, Δ * conj(p * Base.literal_pow(^,x,Val(p-1))), nothing)

0 commit comments

Comments
 (0)