|
| 1 | + |
| 2 | +## Zygote internals |
| 3 | + |
| 4 | +```julia |
| 5 | +function pullback(f, args...) |
| 6 | + y, back = _pullback(f, args...) |
| 7 | + y, Δ -> tailmemaybe(back(Δ)) |
| 8 | +end |
| 9 | + |
| 10 | +function gradient(f, args...) |
| 11 | + y, back = pullback(f, args...) |
| 12 | + grad = back(sensitivity(y)) |
| 13 | + isnothing(grad) ? nothing : map(_project, args, grad) |
| 14 | +end |
| 15 | + |
| 16 | +_pullback(f, args...) = _pullback(Context(), f, args...) |
| 17 | + |
| 18 | +@generated function _pullback(ctx::AContext, f, args...) |
| 19 | + # Try using ChainRulesCore |
| 20 | + if is_kwfunc(f, args...) |
| 21 | + # if it is_kw then `args[1]` are the keyword args, `args[2]` is actual function |
| 22 | + cr_T = Tuple{ZygoteRuleConfig{ctx}, args[2:end]...} |
| 23 | + chain_rrule_f = :chain_rrule_kw |
| 24 | + else |
| 25 | + cr_T = Tuple{ZygoteRuleConfig{ctx}, f, args...} |
| 26 | + chain_rrule_f = :chain_rrule |
| 27 | + end |
| 28 | + |
| 29 | + hascr, cr_edge = has_chain_rrule(cr_T) |
| 30 | + hascr && return :($chain_rrule_f(ZygoteRuleConfig(ctx), f, args...)) |
| 31 | + |
| 32 | + # No ChainRule, going to have to work it out. |
| 33 | + T = Tuple{f,args...} |
| 34 | + ignore_sig(T) && return :(f(args...), Pullback{$T}(())) |
| 35 | + |
| 36 | + g = try |
| 37 | + _generate_pullback_via_decomposition(T) |
| 38 | + catch e |
| 39 | + rethrow(CompileError(T,e)) |
| 40 | + end |
| 41 | + g === nothing && return :(f(args...), Pullback{$T}((f,))) |
| 42 | + meta, forw, _ = g |
| 43 | + argnames!(meta, Symbol("#self#"), :ctx, :f, :args) |
| 44 | + forw = varargs!(meta, forw, 3) |
| 45 | + # IRTools.verify(forw) |
| 46 | + forw = slots!(pis!(inlineable!(forw))) |
| 47 | + # be ready to swap to using chainrule if one is declared |
| 48 | + cr_edge != nothing && edge!(meta, cr_edge) |
| 49 | + return update!(meta.code, forw) |
| 50 | +end |
| 51 | +``` |
| 52 | +## Source Code Transformation |
| 53 | + |
| 54 | +The most recent approach to Reverse Mode AD is **_Source-to-Source_** |
| 55 | +transformation adopted by packages like **_JAX_** and **_Zygote.jl_**. |
| 56 | +Transforming code promises to eliminate the problems of tracing-based AD. |
| 57 | +`Tracked` types are no longer needed, which reduces memory usage and can give |
| 58 | +significant speedups. Additionally, the reverse pass becomes a *compiler |
| 59 | +problem*, which makes it possible to leverage highly optimized compilers like |
| 60 | +LLVM. |
| 61 | + |
| 62 | +Source-to-source AD uses meta-programming to produce `rrule`s for any function |
| 63 | +that is a composition of available `rrule`s. The code for `foo` |
| 64 | +```@example lec08 |
| 65 | +foo(x) = h(g(f(x))) |
| 66 | +
|
| 67 | +f(x) = x^2 |
| 68 | +g(x) = sin(x) |
| 69 | +h(x) = 5x |
| 70 | +nothing # hide |
| 71 | +``` |
| 72 | +is transformed into |
| 73 | +```julia eval=false |
| 74 | +function rrule(::typeof(foo), x) |
| 75 | + a, Ja = rrule(f, x) |
| 76 | + b, Jb = rrule(g, a) |
| 77 | + y, Jy = rrule(h, b) |
| 78 | + |
| 79 | + function dfoo(Δy) |
| 80 | + Δb = Jy(Δy) |
| 81 | + Δa = Jb(Δb) |
| 82 | + Δx = Ja(Δa) |
| 83 | + return Δx |
| 84 | + end |
| 85 | + |
| 86 | + return y, dfoo |
| 87 | +end |
| 88 | +``` |
| 89 | +For this simple example we can define the three `rrule`s by hand: |
| 90 | +```@example lec08 |
| 91 | +rrule(::typeof(f), x) = f(x), Δ -> 2x*Δ |
| 92 | +rrule(::typeof(g), x) = g(x), Δ -> cos(x)*Δ |
| 93 | +rrule(::typeof(h), x) = h(x), Δ -> 5*Δ |
| 94 | +``` |
| 95 | +Remember that this is a very artificial example. In real AD code you would |
| 96 | +define `rrule`s for primitives like `+`, `*`, etc., so that you don't have to define an |
| 97 | +`rrule` for something like `5x`. |
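
As a quick sanity check of the rules above, the following minimal sketch composes the three hand-written `rrule`s exactly as in the transformed `rrule` for `foo` and compares the result with the analytic derivative `10x*cos(x^2)`. The helper name `rrule_foo_by_hand` is hypothetical and only used for this illustration; the block repeats the definitions so that it runs on its own.

```julia
# Self-contained copy of the toy functions and their hand-written rules.
f(x) = x^2
g(x) = sin(x)
h(x) = 5x
foo(x) = h(g(f(x)))

rrule(::typeof(f), x) = f(x), Δ -> 2x * Δ
rrule(::typeof(g), x) = g(x), Δ -> cos(x) * Δ
rrule(::typeof(h), x) = h(x), Δ -> 5 * Δ

# Hand-written equivalent of the transformed code (hypothetical helper name).
function rrule_foo_by_hand(x)
    a, Ja = rrule(f, x)
    b, Jb = rrule(g, a)
    y, Jy = rrule(h, b)
    # the pullbacks are applied last to first
    return y, Δy -> Ja(Jb(Jy(Δy)))
end

x = 2.0
y, dfoo = rrule_foo_by_hand(x)
analytic = 10x * cos(x^2)   # derivative of 5sin(x^2)
println((y, dfoo(1.0), analytic))
```

The forward value `y` agrees with `foo(x)`, and seeding the pullback with `1.0` reproduces the analytic derivative up to floating-point error.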
| 98 | + |
| 99 | +In order to transform our functions safely we will make use of `IRTools.jl` |
| 100 | +(*Intermediate Representation Tools*) which provide some convenience functions |
| 101 | +for inspecting and manipulating code snippets. The IR for `foo` looks like this: |
| 102 | +```@example lec08 |
| 103 | +using IRTools: @code_ir, evalir |
| 104 | +ir = @code_ir foo(2.) |
| 105 | +``` |
| 106 | +```@setup lec08 |
| 107 | +msg = """ |
| 108 | +ir = 1: (%1, %2) ## rrule(foo, x) |
| 109 | + %3 = Main.f(%2) ## a = f(x) |
| 110 | + %4 = Main.g(%3) ## b = g(a) |
| 111 | + %5 = Main.h(%4) ## y = h(b) |
| 112 | + return %5 ## return y |
| 113 | +""" |
| 114 | +``` |
| 115 | +Variable names are replaced by `%N` and each function call gets its own line. |
| 116 | +We can evaluate the IR (to actually run it) like this: |
| 117 | +```@example lec08 |
| 118 | +evalir(ir, nothing, 2.) |
| 119 | +``` |
| 120 | +As a first step, let's transform the function calls into `rrule` calls. For |
| 121 | +this, all we need to do is iterate through the IR line by line and replace each |
| 122 | +statement with `(Main.rrule)(Main.func, %N)`, where `Main` just stands for the |
| 123 | +global main module in which we just defined our functions. |
| 124 | +But remember that the `rrule` returns |
| 125 | +the value `v` *and* the pullback `J` to compute the gradient. Just |
| 126 | +replacing the statements would alter our forward pass. Instead we can insert |
| 127 | +each `rrule` call *before* the statement we want to change. Then we can replace the |
| 128 | +original statement with `v = rr[1]` to use only `v` and not `J` in the |
| 129 | +subsequent computation. |
| 130 | +```@example lec08 |
| 131 | +using IRTools |
| 132 | +using IRTools: xcall, stmt |
| 133 | +
|
| 134 | +xgetindex(x, i...) = xcall(Base, :getindex, x, i...) |
| 135 | +
|
| 136 | +ir = @code_ir foo(2.) |
| 137 | +pr = IRTools.Pipe(ir) |
| 138 | +
|
| 139 | +for (v,statement) in pr |
| 140 | + ex = statement.expr |
| 141 | + rr = xcall(rrule, ex.args...) |
| 142 | + # pr[v] = stmt(rr, line=ir[v].line) |
| 143 | + vJ = insert!(pr, v, stmt(rr, line = ir[v].line)) |
| 144 | + pr[v] = xgetindex(vJ,1) |
| 145 | +end |
| 146 | +ir = IRTools.finish(pr) |
| 147 | +# |
| 148 | +#msg = """ |
| 149 | +#ir = 1: (%1, %2) ## rrule(foo, x) |
| 150 | +# %3 = (Main.rrule)(Main.f, %2) ## ra = rrule(f,x) |
| 151 | +# %4 = Base.getindex(%3, 1) ## a = ra[1] |
| 152 | +# %5 = (Main.rrule)(Main.g, %4) ## rb = rrule(g,a) |
| 153 | +# %6 = Base.getindex(%5, 1) ## b = rb[1] |
| 154 | +# %7 = (Main.rrule)(Main.h, %6) ## ry = rrule(h,b) |
| 155 | +# %8 = Base.getindex(%7, 1) ## y = ry[1] |
| 156 | +# return %8 ## return y |
| 157 | +#""" |
| 158 | +#println(msg) |
| 159 | +``` |
| 160 | +Evaluation of this transformed IR should still give us the same value |
| 161 | +```@example lec08 |
| 162 | +evalir(ir, nothing, 2.) |
| 163 | +``` |
| 164 | + |
| 165 | +The only thing that is left to do now is collect the `Js` and return |
| 166 | +a tuple of our forward value and the `Js`. |
| 167 | +```@example lec08 |
| 168 | +using IRTools: insertafter!, substitute, xcall, stmt |
| 169 | +
|
| 170 | +xtuple(xs...) = xcall(Core, :tuple, xs...) |
| 171 | +
|
| 172 | +ir = @code_ir foo(2.) |
| 173 | +pr = IRTools.Pipe(ir) |
| 174 | +Js = IRTools.Variable[] |
| 175 | +
|
| 176 | +for (v,statement) in pr |
| 177 | + ex = statement.expr |
| 178 | + rr = xcall(rrule, ex.args...) # ex.args = (f,x) |
| 179 | + vJ = insert!(pr, v, stmt(rr, line = ir[v].line)) |
| 180 | + pr[v] = xgetindex(vJ,1) |
| 181 | +
|
| 182 | + # collect Js |
| 183 | + J = insertafter!(pr, v, stmt(xgetindex(vJ,2), line=ir[v].line)) |
| 184 | + push!(Js, substitute(pr, J)) |
| 185 | +end |
| 186 | +ir = IRTools.finish(pr) |
| 187 | +# add the collected `Js` to `ir` |
| 188 | +Js = push!(ir, xtuple(Js...)) |
| 189 | +# return a tuple of the last `v` and `Js` |
| 190 | +ret = ir.blocks[end].branches[end].args[1] |
| 191 | +IRTools.return!(ir, xtuple(ret, Js)) |
| 192 | +ir |
| 193 | +#msg = """ |
| 194 | +#ir = 1: (%1, %2) ## rrule(foo, x) |
| 195 | +# %3 = (Main.rrule)(Main.f, %2) ## ra = rrule(f,x) |
| 196 | +# %4 = Base.getindex(%3, 1) ## a = ra[1] |
| 197 | +# %5 = Base.getindex(%3, 2) ## Ja = ra[2] |
| 198 | +# %6 = (Main.rrule)(Main.g, %4) ## rb = rrule(g,a) |
| 199 | +# %7 = Base.getindex(%6, 1) ## b = rb[1] |
| 200 | +# %8 = Base.getindex(%6, 2) ## Jb = rb[2] |
| 201 | +# %9 = (Main.rrule)(Main.h, %7) ## ry = rrule(h,b) |
| 202 | +# %10 = Base.getindex(%9, 1) ## y = ry[1] |
| 203 | +# %11 = Base.getindex(%9, 2) ## Jy = ry[2] |
| 204 | +# %12 = Core.tuple(%5, %8, %11) ## Js = (Ja,Jb,Jy) |
| 205 | +# %13 = Core.tuple(%10, %12) ## rr = (y, Js) |
| 206 | +# return %13 ## return rr |
| 207 | +#""" |
| 208 | +#println(msg) |
| 209 | +``` |
| 210 | +Evaluating the resulting IR returns the forward-pass value and the pullbacks `Js`: |
| 211 | +```@repl lec08 |
| 212 | +(y, Js) = evalir(ir, foo, 2.) |
| 213 | +``` |
| 214 | +To compute the derivative given the tuple of `Js` we just need to apply them |
| 215 | +in reverse order, starting from an initial gradient of one: |
| 216 | +```@repl lec08 |
| 217 | +foldr((J, d) -> J(d), Js, init=1) # Ja(Jb(Jy(1))) |
| 218 | +``` |
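
The order in which the pullbacks are applied does not change the result for the scalar example, because scalar multiplication commutes. A minimal sketch with two assumed toy linear maps (the matrices `A` and `B` are made up for illustration) shows that the order matters as soon as the pullbacks are matrix products:

```julia
using LinearAlgebra

# Toy linear maps y = B * (A * x); A and B are assumed values for illustration.
A = [1.0 2.0; 3.0 4.0]
B = [0.0 1.0; 1.0 1.0]

# Pullbacks collected in forward order, as in `Js` above.
Ja = Δ -> A' * Δ
Jb = Δ -> B' * Δ
Js = (Ja, Jb)

Δy = [1.0, 0.0]

# foldr applies the last pullback first: Ja(Jb(Δy))
correct = foldr((J, d) -> J(d), Js; init = Δy)
# foldl with |> applies the first pullback first: Jb(Ja(Δy))
swapped = foldl(|>, Js; init = Δy)

println(correct)
println(swapped)
```

Only the `foldr` variant agrees with the vector-Jacobian product `(B * A)' * Δy`, which is what reverse mode is supposed to compute.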
| 219 | +The code for transforming the IR as described above looks like this. |
| 220 | +```@example lec08 |
| 221 | +function transform(ir, x) |
| 222 | + pr = IRTools.Pipe(ir) |
| 223 | + Js = IRTools.Variable[] |
| 224 | + |
| 225 | + # loop over each line in the IR |
| 226 | + for (v,statement) in pr |
| 227 | + ex = statement.expr |
| 228 | + # insert the rrule |
| 229 | + rr = xcall(rrule, ex.args...) # ex.args = (f,x) |
| 230 | + vJ = insert!(pr, v, stmt(rr, line = ir[v].line)) |
| 231 | + # replace original line with f(x) from rrule |
| 232 | + pr[v] = xgetindex(vJ,1) |
| 233 | + |
| 234 | + # save jacobian in a variable |
| 235 | + J = insertafter!(pr, v, stmt(xgetindex(vJ,2), line=ir[v].line)) |
| 236 | + # add it to a list of jacobians |
| 237 | + push!(Js, substitute(pr, J)) |
| 238 | + end |
| 239 | + ir = IRTools.finish(pr) |
| 240 | + # add the collected `Js` to `ir` |
| 241 | + Js = push!(ir, xtuple(Js...)) |
| 242 | + # return a tuple of the foo(x) and `Js` |
| 243 | + ret = ir.blocks[end].branches[end].args[1] |
| 244 | + IRTools.return!(ir, xtuple(ret, Js)) |
| 245 | + return ir |
| 246 | +end |
| 247 | +
|
| 248 | +xgetindex(x, i...) = xcall(Base, :getindex, x, i...) |
| 249 | +xtuple(xs...) = xcall(Core, :tuple, xs...) |
| 250 | +nothing # hide |
| 251 | +``` |
| 252 | +Now we can write a general `rrule` that can differentiate any function |
| 253 | +composed of our defined `rrule`s |
| 254 | +```@example lec08 |
| 255 | +function rrule(f, x) |
| 256 | + ir = @code_ir f(x) |
| 257 | + ir_derived = transform(ir,x) |
| 258 | + y, Js = evalir(ir_derived, nothing, x) |
| 259 | + df(Δ) = foldr((J, d) -> J(d), Js, init=Δ) # apply pullbacks last to first |
| 260 | + return y, df |
| 261 | +end |
| 262 | +
|
| 263 | +
|
| 264 | +reverse(f,x) = rrule(f,x)[2](one(x)) |
| 265 | +nothing # hide |
| 266 | +``` |
| 267 | +Finally, we just have to use `reverse` to compute the gradient |
| 268 | +```@example lec08 |
| 269 | +plot(-2:0.1:2, foo, label="f(x) = 5sin(x^2)", lw=3) |
| 270 | +plot!(-2:0.1:2, x->10x*cos(x^2), label="Analytic f'", ls=:dot, lw=3) |
| 271 | +plot!(-2:0.1:2, x->reverse(foo,x), label="Source-to-source reverse mode f'", lw=3, ls=:dash) |
| 272 | +``` |
| 273 | + |
| 274 | +--- |
| 275 | +- Efficiency of the forward pass becomes essentially a compiler problem |
| 276 | +- If we define specialized rules we will gain performance |
| 277 | +--- |
| 278 | + |
| 279 | +# Performance Forward vs. Reverse |
| 280 | + |
| 281 | +This section compares the performance of three widely used Julia AD |
| 282 | +systems, `ForwardDiff.jl` (forward mode), `ReverseDiff.jl` (tracing-based |
| 283 | +reverse mode), and `Zygote.jl` (source-to-source reverse mode), as well as JAX |
| 284 | +forward/reverse modes. |
| 285 | + |
| 286 | +As a benchmark we compute the Jacobian of $f:\mathbb R^N |
| 287 | +\rightarrow \mathbb R^M$ with respect to $\bm x$. |
| 288 | +In the benchmark we test various values of $N$ and $M$ to show the |
| 289 | +differences between the backends. The square is applied elementwise. |
| 290 | +```math |
| 291 | +f(\bm x) = (\bm W \bm x + \bm b)^2 |
| 292 | +``` |
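
For reference, the Jacobian of this benchmark function has the closed form $2\,\mathrm{diag}(\bm W \bm x + \bm b)\,\bm W$, assuming the square is taken elementwise. The sketch below is not part of the benchmark itself; the names `fbench`, `jacobian_analytic`, and `jacobian_fd` are hypothetical helpers that check the closed form against central finite differences on a small example.

```julia
using LinearAlgebra

# Benchmark function with an elementwise square (assumed interpretation of ^2).
fbench(W, b, x) = (W * x .+ b) .^ 2

# Analytic Jacobian with respect to x: an M×N matrix.
jacobian_analytic(W, b, x) = 2 .* Diagonal(W * x .+ b) * W

# Central finite differences, one column per input dimension (for checking only).
function jacobian_fd(W, b, x; ϵ = 1e-6)
    cols = map(eachindex(x)) do i
        e = zeros(length(x))
        e[i] = ϵ
        (fbench(W, b, x .+ e) .- fbench(W, b, x .- e)) ./ (2ϵ)
    end
    return reduce(hcat, cols)
end

W = [1.0 2.0 3.0; 4.0 5.0 6.0]   # M = 2, N = 3
b = [0.5, -1.0]
x = [1.0, 0.0, -1.0]

Jan = jacobian_analytic(W, b, x)
Jfd = jacobian_fd(W, b, x)
println(size(Jan))
```

Forward mode builds this matrix one column (input direction) at a time and reverse mode one row (output component) at a time, which is why the ratio of $N$ to $M$ matters in the comparison.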
| 293 | + |
| 294 | +```@setup lec08 |
| 295 | +using DataFrames |
| 296 | +using DrWatson |
| 297 | +using Glob |
| 298 | +
|
| 299 | +
|
| 300 | +julia_res = map(glob("julia-*.txt")) do fname |
| 301 | + d = parse_savename(replace(fname, "julia-"=>""))[2] |
| 302 | + @unpack N, M = d |
| 303 | + lines = open(fname) |> eachline |
| 304 | + map(lines) do line |
| 305 | + s = split(line, ":") |
| 306 | + backend = s[1] |
| 307 | + time = parse(Float32, s[2]) / 10^6 |
| 308 | + (backend, time, "$(N)x$(M)") |
| 309 | + end |
| 310 | +end |
| 311 | +
|
| 312 | +jax_res = map(glob("jax-*.txt")) do fname |
| 313 | + d = parse_savename(replace(fname, "jax-"=>""))[2] |
| 314 | + @unpack N, M = d |
| 315 | + lines = open(fname) |> eachline |
| 316 | + map(lines) do line |
| 317 | + s = split(line, ":") |
| 318 | + backend = s[1] |
| 319 | + time = parse(Float32, s[2]) * 10^3 |
| 320 | + (backend, time, "$(N)x$(M)") |
| 321 | + end |
| 322 | +end |
| 323 | +
|
| 324 | +res = vcat(julia_res, jax_res) |
| 325 | +
|
| 326 | +df = DataFrame(reduce(vcat, res)) |
| 327 | +df = unstack(df, 3, 1, 2) |
| 328 | +ns = names(df) |
| 329 | +ns[1] = "N x M" |
| 330 | +rename!(df, ns) |
| 331 | +df = DataFrame([[names(df)]; collect.(eachrow(df))], [:column; Symbol.(axes(df, 1))]) |
| 332 | +
|
| 333 | +ns = df[1,:] |> values |> collect |
| 334 | +rename!(df, ns) |
| 335 | +``` |
| 336 | +```@example lec08 |
| 337 | +df[2:end,:] # hide |
| 338 | +``` |