Skip to content

Commit 09a0ed6

Browse files
committed
simplify, some doctests
1 parent 1d2e11b commit 09a0ed6

File tree

5 files changed

+29
-39
lines changed

5 files changed

+29
-39
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ version = "0.6.20"
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
77
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
9-
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
109
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
1110
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1211
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -27,7 +26,6 @@ AbstractFFTs = "0.5, 1.0"
2726
ChainRules = "1.5"
2827
ChainRulesCore = "1.3"
2928
ChainRulesTestUtils = "1"
30-
Compat = "2.2, 3"
3129
DiffRules = "1.0"
3230
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
3331
ForwardDiff = "0.10"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ julia> using Zygote
1818
julia> f(x) = 5x + 3
1919

2020
julia> f(10), f'(10)
21-
(53, 5)
21+
(53, 5.0)
2222

2323
julia> @code_llvm f'(10)
2424
define i64 @"julia_#625_38792"(i64) {

src/Zygote.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ using ChainRules: ChainRules, rrule, unthunk, canonicalize
1111
using IRTools
1212
using MacroTools, Requires
1313
using MacroTools: @forward
14-
using Compat # for Julia 1.3, need Compat 2.2
1514

1615
import Distributed: pmap, CachingPool, workers
1716
export Params, withgradient, gradient, withjacobian, jacobian, hessian, diaghessian, pullback, pushforward, @code_adjoint

src/compiler/interface.jl

Lines changed: 27 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,13 @@ julia> gradient([7, 11], 0, 1) do x, y, d
6868
p = size(x, d)
6969
sum(x.^p .+ y)
7070
end
71-
([14.0, 22.0], 2, nothing)
71+
([14.0, 22.0], 2.0, nothing)
7272
```
7373
"""
7474
function gradient(f, args...)
7575
y, back = pullback(f, args...)
7676
grad = back(sensitivity(y))
77-
isnothing(grad) && return nothing
78-
map(_project, args, grad)
77+
isnothing(grad) ? nothing : map(_project, args, grad)
7978
end
8079

8180
Base.adjoint(f::Function) = x -> gradient(f, x)[1]
@@ -98,31 +97,12 @@ true
9897
function withgradient(f, args...)
9998
y, back = pullback(f, args...)
10099
grad = back(sensitivity(y))
101-
isnothing(grad) && return (val=y, grad=nothing)
102-
(val = y, grad = map(_project, args, grad))
100+
results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
101+
(val=y, grad=results)
103102
end
104103

105104
# Param-style wrappers
106105

107-
"""
108-
Params([A, B])
109-
110-
Container for implicit parameters, used when differentiating
111-
a zero-argument funtion `() -> loss(A, B)` with respect to `A, B`.
112-
"""
113-
struct Params
114-
order::Buffer # {Any, Vector{Any}}
115-
params::IdSet{Any} # TODO store ids only
116-
end
117-
118-
Params() = Params(Buffer([], false), IdSet())
119-
Params(xs) = Params(Buffer(xs, false), IdSet(xs))
120-
Params(ps::Params) = ps
121-
Params(xs::Tuple) = Params(collect(xs))
122-
123-
@forward Params.order Base.iterate, Base.length, Base.getindex
124-
@forward Params.params Base.in
125-
126106
"""
127107
gradient(() -> loss(), ps::Params) -> Grads
128108
@@ -138,24 +118,37 @@ julia> g = gradient(Params([x, y])) do
138118
Grads(...)
139119
140120
julia> g[x]
141-
2×3 Matrix{Int64}:
142-
7 70 700
143-
8 80 800
121+
2×3 Matrix{Float64}:
122+
7.0 70.0 700.0
123+
8.0 80.0 800.0
144124
145125
julia> haskey(g, z) # only x and y are parameters
146126
false
147127
```
148128
"""
149-
function gradient(f, ps::Params)
150-
y, back = pullback(f, ps)
151-
back(sensitivity(y))
152-
end
129+
gradient
153130

154-
function withgradient(f, ps::Params)
155-
y, back = pullback(f, ps)
156-
(val = y, grad = back(sensitivity(y)))
131+
"""
132+
Params([A, B])
133+
134+
Container for implicit parameters, used when differentiating
135+
a zero-argument funtion `() -> loss(A, B)` with respect to `A, B`.
136+
"""
137+
struct Params
138+
order::Buffer # {Any, Vector{Any}}
139+
params::IdSet{Any} # TODO store ids only
157140
end
158141

142+
Params() = Params(Buffer([], false), IdSet())
143+
Params(xs) = Params(Buffer(xs, false), IdSet(xs))
144+
Params(ps::Params) = ps
145+
Params(xs::Tuple) = Params(collect(xs))
146+
147+
@forward Params.order Base.iterate, Base.length, Base.getindex
148+
@forward Params.params Base.in
149+
150+
Base.map(::typeof(_project), args::Tuple{Params}, grad) = grad # skip _project in gradient(f, ::Params)
151+
159152
function Base.union!(ps::Params, itrs...)
160153
foreach(itr -> foreach(x -> push!(ps, x), itr), itrs)
161154
return ps

src/lib/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function unbroadcast(x::AbstractArray, x̄)
5151
_project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors
5252
else
5353
tup = filter(d -> size(x, d) == 1, ntuple(identity, N))
54-
dims = length(tup) == 1 ? only(tup) : tup # avoid sum(xbar, dims=(1,)) as e.g. sum(SA[1 2; 3 4], dims=(1,)) fails
54+
dims = length(tup) == 1 ? first(tup) : tup # avoid sum(xbar, dims=(1,)) as e.g. sum(SA[1 2; 3 4], dims=(1,)) fails
5555
_project(x, accum_sum(x̄; dims = dims))
5656
end
5757
end

0 commit comments

Comments
 (0)