Skip to content

Commit 2c1252b

Browse files
committed
simplify, some doctests
1 parent 1d2e11b commit 2c1252b

File tree

5 files changed

+11
-21
lines changed

5 files changed

+11
-21
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ version = "0.6.20"
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
77
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
9-
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
109
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
1110
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1211
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -27,7 +26,6 @@ AbstractFFTs = "0.5, 1.0"
2726
ChainRules = "1.5"
2827
ChainRulesCore = "1.3"
2928
ChainRulesTestUtils = "1"
30-
Compat = "2.2, 3"
3129
DiffRules = "1.0"
3230
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
3331
ForwardDiff = "0.10"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ julia> using Zygote
1818
julia> f(x) = 5x + 3
1919

2020
julia> f(10), f'(10)
21-
(53, 5)
21+
(53, 5.0)
2222

2323
julia> @code_llvm f'(10)
2424
define i64 @"julia_#625_38792"(i64) {

src/Zygote.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ using ChainRules: ChainRules, rrule, unthunk, canonicalize
1111
using IRTools
1212
using MacroTools, Requires
1313
using MacroTools: @forward
14-
using Compat # for Julia 1.3, need Compat 2.2
1514

1615
import Distributed: pmap, CachingPool, workers
1716
export Params, withgradient, gradient, withjacobian, jacobian, hessian, diaghessian, pullback, pushforward, @code_adjoint

src/compiler/interface.jl

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,13 @@ julia> gradient([7, 11], 0, 1) do x, y, d
6868
p = size(x, d)
6969
sum(x.^p .+ y)
7070
end
71-
([14.0, 22.0], 2, nothing)
71+
([14.0, 22.0], 2.0, nothing)
7272
```
7373
"""
7474
function gradient(f, args...)
7575
y, back = pullback(f, args...)
7676
grad = back(sensitivity(y))
77-
isnothing(grad) && return nothing
78-
map(_project, args, grad)
77+
isnothing(grad) ? nothing : map(_project, args, grad)
7978
end
8079

8180
Base.adjoint(f::Function) = x -> gradient(f, x)[1]
@@ -98,8 +97,8 @@ true
9897
function withgradient(f, args...)
9998
y, back = pullback(f, args...)
10099
grad = back(sensitivity(y))
101-
isnothing(grad) && return (val=y, grad=nothing)
102-
(val = y, grad = map(_project, args, grad))
100+
results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
101+
(val=y, grad=results)
103102
end
104103

105104
# Param-style wrappers
@@ -138,23 +137,17 @@ julia> g = gradient(Params([x, y])) do
138137
Grads(...)
139138
140139
julia> g[x]
141-
2×3 Matrix{Int64}:
142-
7 70 700
143-
8 80 800
140+
2×3 Matrix{Float64}:
141+
7.0 70.0 700.0
142+
8.0 80.0 800.0
144143
145144
julia> haskey(g, z) # only x and y are parameters
146145
false
147146
```
148147
"""
149-
function gradient(f, ps::Params)
150-
y, back = pullback(f, ps)
151-
back(sensitivity(y))
152-
end
148+
gradient
153149

154-
function withgradient(f, ps::Params)
155-
y, back = pullback(f, ps)
156-
(val = y, grad = back(sensitivity(y)))
157-
end
150+
Base.map(::typeof(_project), args::Tuple{Params}, grad) = grad # skip _project in gradient(f, ::Params)
158151

159152
function Base.union!(ps::Params, itrs...)
160153
foreach(itr -> foreach(x -> push!(ps, x), itr), itrs)

src/lib/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function unbroadcast(x::AbstractArray, x̄)
5151
_project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors
5252
else
5353
tup = filter(d -> size(x, d) == 1, ntuple(identity, N))
54-
dims = length(tup) == 1 ? only(tup) : tup # avoid sum(xbar, dims=(1,)) as e.g. sum(SA[1 2; 3 4], dims=(1,)) fails
54+
dims = length(tup) == 1 ? first(tup) : tup # avoid sum(xbar, dims=(1,)) as e.g. sum(SA[1 2; 3 4], dims=(1,)) fails
5555
_project(x, accum_sum(x̄; dims = dims))
5656
end
5757
end

0 commit comments

Comments
 (0)