Skip to content

Commit 9e09404

Browse files
committed
niklas stuff in lecture.md moved to irtools.md
1 parent bda6544 commit 9e09404

File tree

2 files changed

+340
-338
lines changed

2 files changed

+340
-338
lines changed

docs/src/lecture_09/irtools.md

Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,338 @@
1+
2+
## Zygote internals
3+
4+
```julia
5+
function pullback(f, args...)
6+
y, back = _pullback(f, args...)
7+
y, Δ -> tailmemaybe(back(Δ))
8+
end
9+
10+
function gradient(f, args...)
11+
y, back = pullback(f, args...)
12+
grad = back(sensitivity(y))
13+
isnothing(grad) ? nothing : map(_project, args, grad)
14+
end
15+
16+
_pullback(f, args...) = _pullback(Context(), f, args...)
17+
18+
@generated function _pullback(ctx::AContext, f, args...)
19+
# Try using ChainRulesCore
20+
if is_kwfunc(f, args...)
21+
# if it is_kw then `args[1]` are the keyword args, `args[2]` is actual function
22+
cr_T = Tuple{ZygoteRuleConfig{ctx}, args[2:end]...}
23+
chain_rrule_f = :chain_rrule_kw
24+
else
25+
cr_T = Tuple{ZygoteRuleConfig{ctx}, f, args...}
26+
chain_rrule_f = :chain_rrule
27+
end
28+
29+
hascr, cr_edge = has_chain_rrule(cr_T)
30+
hascr && return :($chain_rrule_f(ZygoteRuleConfig(ctx), f, args...))
31+
32+
# No ChainRule, going to have to work it out.
33+
T = Tuple{f,args...}
34+
ignore_sig(T) && return :(f(args...), Pullback{$T}(()))
35+
36+
g = try
37+
_generate_pullback_via_decomposition(T)
38+
catch e
39+
rethrow(CompileError(T,e))
40+
end
41+
g === nothing && return :(f(args...), Pullback{$T}((f,)))
42+
meta, forw, _ = g
43+
argnames!(meta, Symbol("#self#"), :ctx, :f, :args)
44+
forw = varargs!(meta, forw, 3)
45+
# IRTools.verify(forw)
46+
forw = slots!(pis!(inlineable!(forw)))
47+
# be ready to swap to using chainrule if one is declared
48+
cr_edge != nothing && edge!(meta, cr_edge)
49+
return update!(meta.code, forw)
50+
end
51+
```
52+
## Source Code Transformation
53+
54+
The most recent approach to Reverse Mode AD is **_Source-to-Source_**
55+
transformation adopted by packages like **_JAX_** and **_Zygote.jl_**.
56+
Transforming code promises to eliminate the problems of tracing-based AD.
57+
`Tracked` types are not needed anymore, which reduces memory usage, promising
58+
significant speedups. Additionally, the reverse pass becomes a *compiler
59+
problem*, which makes it possible to leverage highly optimized compilers like
60+
LLVM.
61+
62+
Source-to-source AD uses meta-programming to produce `rrule`s for any function
63+
that is a composition of available `rrule`s. The code for `foo`
64+
```@example lec08
65+
foo(x) = h(g(f(x)))
66+
67+
f(x) = x^2
68+
g(x) = sin(x)
69+
h(x) = 5x
70+
nothing # hide
71+
```
72+
is transformed into
73+
```julia eval=false
74+
function rrule(::typeof(foo), x)
75+
a, Ja = rrule(f, x)
76+
b, Jb = rrule(g, a)
77+
y, Jy = rrule(h, b)
78+
79+
function dfoo(Δy)
80+
Δb = Jy(Δy)
81+
Δa = Jb(Δb)
82+
Δx = Ja(Δa)
83+
return Δx
84+
end
85+
86+
return y, dfoo
87+
end
88+
```
89+
For this simple example we can define the three `rrule`s by hand:
90+
```@example lec08
91+
rrule(::typeof(f), x) = f(x), Δ -> 2x*Δ
92+
rrule(::typeof(g), x) = g(x), Δ -> cos(x)*Δ
93+
rrule(::typeof(h), x) = h(x), Δ -> 5*Δ
94+
```
95+
Remember that this is a very artificial example. In real AD code you would
96+
overload functions like `+`, `*`, etc, such that you don't have to define a
97+
`rrule` for something like `5x`.
98+
99+
In order to transform our functions safely we will make use of `IRTools.jl`
100+
(*Intermediate Representation Tools*) which provide some convenience functions
101+
for inspecting and manipulating code snippets. The IR for `foo` looks like this:
102+
```@example lec08
103+
using IRTools: @code_ir, evalir
104+
ir = @code_ir foo(2.)
105+
```
106+
```@setup lec08
107+
msg = """
108+
ir = 1: (%1, %2) ## rrule(foo, x)
109+
%3 = Main.f(%2) ## a = f(x)
110+
%4 = Main.g(%3) ## b = g(a)
111+
%5 = Main.h(%4) ## y = h(b)
112+
return %5 ## return y
113+
"""
114+
```
115+
Variable names are replaced by `%N` and each function gets is own line.
116+
We can evalulate the IR (to actually run it) like this
117+
```@example lec08
118+
evalir(ir, nothing, 2.)
119+
```
120+
As a first step, lets transform the function calls to `rrule` calls. For
121+
this, all we need to do is iterate through the IR line by line and replace each
122+
statement with `(Main.rrule)(Main.func, %N)`, where `Main` just stand for the
123+
gobal main module in which we just defined our functions.
124+
But remember that the `rrule` returns
125+
the value `v` *and* the pullback `J` to compute the gradient. Just
126+
replacing the statements would alter our forward pass. Instead we can insert
127+
each statement *before* the one we want to change. Then we can replace the the
128+
original statement with `v = rr[1]` to use only `v` and not `J` in the
129+
subsequent computation.
130+
```@example lec08
131+
using IRTools
132+
using IRTools: xcall, stmt
133+
134+
xgetindex(x, i...) = xcall(Base, :getindex, x, i...)
135+
136+
ir = @code_ir foo(2.)
137+
pr = IRTools.Pipe(ir)
138+
139+
for (v,statement) in pr
140+
ex = statement.expr
141+
rr = xcall(rrule, ex.args...)
142+
# pr[v] = stmt(rr, line=ir[v].line)
143+
vJ = insert!(pr, v, stmt(rr, line = ir[v].line))
144+
pr[v] = xgetindex(vJ,1)
145+
end
146+
ir = IRTools.finish(pr)
147+
#
148+
#msg = """
149+
#ir = 1: (%1, %2) ## rrule(foo, x)
150+
# %3 = (Main.rrule)(Main.f, %2) ## ra = rrule(f,x)
151+
# %4 = Base.getindex(%3, 1) ## a = ra[1]
152+
# %5 = (Main.rrule)(Main.g, %4) ## rb = rrule(g,a)
153+
# %6 = Base.getindex(%5, 1) ## b = rb[1]
154+
# %7 = (Main.rrule)(Main.h, %6) ## ry = rrule(h,b)
155+
# %8 = Base.getindex(%7, 1) ## y = ry[1]
156+
# return %8 ## return y
157+
#"""
158+
#println(msg)
159+
```
160+
Evaluation of this transformed IR should still give us the same value
161+
```@example lec08
162+
evalir(ir, nothing, 2.)
163+
```
164+
165+
The only thing that is left to do now is collect the `Js` and return
166+
a tuple of our forward value and the `Js`.
167+
```@example lec08
168+
using IRTools: insertafter!, substitute, xcall, stmt
169+
170+
xtuple(xs...) = xcall(Core, :tuple, xs...)
171+
172+
ir = @code_ir foo(2.)
173+
pr = IRTools.Pipe(ir)
174+
Js = IRTools.Variable[]
175+
176+
for (v,statement) in pr
177+
ex = statement.expr
178+
rr = xcall(rrule, ex.args...) # ex.args = (f,x)
179+
vJ = insert!(pr, v, stmt(rr, line = ir[v].line))
180+
pr[v] = xgetindex(vJ,1)
181+
182+
# collect Js
183+
J = insertafter!(pr, v, stmt(xgetindex(vJ,2), line=ir[v].line))
184+
push!(Js, substitute(pr, J))
185+
end
186+
ir = IRTools.finish(pr)
187+
# add the collected `Js` to `ir`
188+
Js = push!(ir, xtuple(Js...))
189+
# return a tuple of the last `v` and `Js`
190+
ret = ir.blocks[end].branches[end].args[1]
191+
IRTools.return!(ir, xtuple(ret, Js))
192+
ir
193+
#msg = """
194+
#ir = 1: (%1, %2) ## rrule(foo, x)
195+
# %3 = (Main.rrule)(Main.f, %2) ## ra = rrule(f,x)
196+
# %4 = Base.getindex(%3, 1) ## a = ra[1]
197+
# %5 = Base.getindex(%3, 2) ## Ja = ra[2]
198+
# %6 = (Main.rrule)(Main.g, %4) ## rb = rrule(g,a)
199+
# %7 = Base.getindex(%6, 1) ## b = rb[1]
200+
# %8 = Base.getindex(%6, 2) ## Jb = rb[2]
201+
# %9 = (Main.rrule)(Main.h, %7) ## ry = rrule(h,b)
202+
# %10 = Base.getindex(%9, 1) ## y = ry[1]
203+
# %11 = Base.getindex(%9, 2) ## Jy = ry[2]
204+
# %12 = Core.tuple(%5, %8, %11) ## Js = (Ja,Jb,Jy)
205+
# %13 = Core.tuple(%10, %12) ## rr = (y, Js)
206+
# return %13 ## return rr
207+
#"""
208+
#println(msg)
209+
```
210+
The resulting IR can be evaluated to the forward pass value and the Jacobians:
211+
```@repl lec08
212+
(y, Js) = evalir(ir, foo, 2.)
213+
```
214+
To compute the derivative given the tuple of `Js` we just need to compose them
215+
and set the initial gradient to one:
216+
```@repl lec08
217+
reduce(|>, Js, init=1) # Ja(Jb(Jy(1)))
218+
```
219+
The code for transforming the IR as described above looks like this.
220+
```@example lec08
221+
function transform(ir, x)
222+
pr = IRTools.Pipe(ir)
223+
Js = IRTools.Variable[]
224+
225+
# loop over each line in the IR
226+
for (v,statement) in pr
227+
ex = statement.expr
228+
# insert the rrule
229+
rr = xcall(rrule, ex.args...) # ex.args = (f,x)
230+
vJ = insert!(pr, v, stmt(rr, line = ir[v].line))
231+
# replace original line with f(x) from rrule
232+
pr[v] = xgetindex(vJ,1)
233+
234+
# save jacobian in a variable
235+
J = insertafter!(pr, v, stmt(xgetindex(vJ,2), line=ir[v].line))
236+
# add it to a list of jacobians
237+
push!(Js, substitute(pr, J))
238+
end
239+
ir = IRTools.finish(pr)
240+
# add the collected `Js` to `ir`
241+
Js = push!(ir, xtuple(Js...))
242+
# return a tuple of the foo(x) and `Js`
243+
ret = ir.blocks[end].branches[end].args[1]
244+
IRTools.return!(ir, xtuple(ret, Js))
245+
return ir
246+
end
247+
248+
xgetindex(x, i...) = xcall(Base, :getindex, x, i...)
249+
xtuple(xs...) = xcall(Core, :tuple, xs...)
250+
nothing # hide
251+
```
252+
Now we can write a general `rrule` that can differentiate any function
253+
composed of our defined `rrule`s
254+
```@example lec08
255+
function rrule(f, x)
256+
ir = @code_ir f(x)
257+
ir_derived = transform(ir,x)
258+
y, Js = evalir(ir_derived, nothing, x)
259+
df(Δ) = reduce(|>, Js, init=Δ)
260+
return y, df
261+
end
262+
263+
264+
reverse(f,x) = rrule(f,x)[2](one(x))
265+
nothing # hide
266+
```
267+
Finally, we just have to use `reverse` to compute the gradient
268+
```@example lec08
269+
plot(-2:0.1:2, foo, label="f(x) = 5sin(x^2)", lw=3)
270+
plot!(-2:0.1:2, x->10x*cos(x^2), label="Analytic f'", ls=:dot, lw=3)
271+
plot!(-2:0.1:2, x->reverse(foo,x), label="Dual Forward Mode f'", lw=3, ls=:dash)
272+
```
273+
274+
---
275+
- Efficiency of the forward pass becomes essentially a compiler problem
276+
- If we define specialized rules we will gain performance
277+
---
278+
279+
# Performance Forward vs. Reverse
280+
281+
This section compares the performance of three different, widely used Julia AD
282+
systems `ForwardDiff.jl` (forward mode), `ReverseDiff.jl` (tracing-based
283+
reverse mode), and `Zygote.jl` (source-to-source reverse mode), as well as JAX
284+
forward/reverse modes.
285+
286+
As a benchmark function we can compute the Jacobian of $f:\mathbb R^N
287+
\rightarrow \mathbb R^M$ with respect to $\bm x$.
288+
In the benchmark we test various different values of $N$ and $M$ to show the
289+
differences between the backends.
290+
```math
291+
f(\bm x) = (\bm W \bm x + \bm b)^2
292+
```
293+
294+
```@setup lec08
295+
using DataFrames
296+
using DrWatson
297+
using Glob
298+
299+
300+
julia_res = map(glob("julia-*.txt")) do fname
301+
d = parse_savename(replace(fname, "julia-"=>""))[2]
302+
@unpack N, M = d
303+
lines = open(fname) |> eachline
304+
map(lines) do line
305+
s = split(line, ":")
306+
backend = s[1]
307+
time = parse(Float32, s[2]) / 10^6
308+
(backend, time, "$(N)x$(M)")
309+
end
310+
end
311+
312+
jax_res = map(glob("jax-*.txt")) do fname
313+
d = parse_savename(replace(fname, "jax-"=>""))[2]
314+
@unpack N, M = d
315+
lines = open(fname) |> eachline
316+
map(lines) do line
317+
s = split(line, ":")
318+
backend = s[1]
319+
time = parse(Float32, s[2]) * 10^3
320+
(backend, time, "$(N)x$(M)")
321+
end
322+
end
323+
324+
res = vcat(julia_res, jax_res)
325+
326+
df = DataFrame(reduce(vcat, res))
327+
df = unstack(df, 3, 1, 2)
328+
ns = names(df)
329+
ns[1] = "N x M"
330+
rename!(df, ns)
331+
df = DataFrame([[names(df)]; collect.(eachrow(df))], [:column; Symbol.(axes(df, 1))])
332+
333+
ns = df[1,:] |> values |> collect
334+
rename!(df, ns)
335+
```
336+
```@example lec08
337+
df[2:end,:] # hide
338+
```

0 commit comments

Comments
 (0)