
Commit bf9b025

added dissected petite zygote example
1 parent ad101a4 commit bf9b025

2 files changed: +223 −6 lines changed

docs/src/lecture_09/irtools.jl

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ function foo(x, y)
end

reset!(to)
-profile_fun(foo, 1.0, 1.0)
+@elapsed profile_fun(foo, 1.0, 1.0)
to

@record foo(1.0, 1.0) => profile_fun(foo, 1.0, 1.0)

docs/src/lecture_09/lecture.md

Lines changed: 222 additions & 5 deletions
@@ -604,11 +604,228 @@ to
where you should notice the long time the first execution of `profile_fun(foo, 1.0, 1.0)` takes. This is caused by the compiler specializing the instrumented code for every function into which we dive. The second execution of `profile_fun(foo, 1.0, 1.0)` is fast. It is also interesting to observe how the time of the compilation is logged by the profiler. The output of the profiler `to` is not shown here due to its length.
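A minimal way to observe this (assuming `foo`, `to`, and `profile_fun` from the snippet above are in scope; concrete timings will of course differ between machines):
```julia
reset!(to)
t1 = @elapsed profile_fun(foo, 1.0, 1.0)  # first call: includes compiling the instrumented code
t2 = @elapsed profile_fun(foo, 1.0, 1.0)  # second call: already compiled, much faster
t1 > t2
```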

## Petite zygote
`IRTools.jl` was created for `Zygote.jl` --- Julia's source-to-source AD system currently powering `Flux.jl`. An interesting insight behind `Zygote` is that TensorFlow is, in a nutshell, a compiler, while PyTorch is an interpreter. The idea was therefore to let Julia's compiler compile the gradient code and apply to it the optimizations it normally performs on ordinary code. Recall that a lot of research has gone into generating efficient code, and it is reasonable to reuse that research. `Zygote.jl` mainly provides reverse-mode AD; there was also experimental support for forward mode.

-### some thoughts
-ri is the variable
-∂i is a function to which I need to pass the ∂ri to get the gradients
### Strategy
We assume that we are provided with a set of AD rules (e.g. `ChainRules`), which for a given function return its evaluation and a pullback. When `Zygote.jl` is tasked with computing a gradient, it proceeds as follows (a hand-written sketch of the result follows the list):
1. if a rule exists for this function, return the rule directly;
2. if not, deconstruct the function into a sequence of function calls using its `CodeInfo` / IR representation;
3. replace each statement by a call returning both the evaluation of the statement and its pullback;
4. chain the pullbacks appropriately in reverse order;
5. return the function evaluation together with the chained pullback.
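To make the strategy concrete, here is a hand-written sketch of what steps 3--5 amount to for `foo(x, y) = x * y + sin(y)`, the example function used in this lecture. The helper names `foo_forward` and `foo_pullback` are ours; the machinery below automates exactly this pattern.
```julia
using ChainRules

function foo_forward(x, y)
  a, pb_mul = ChainRules.rrule(*, x, y)   # step 3: value and pullback of x * y
  b, pb_sin = ChainRules.rrule(sin, y)    # step 3: value and pullback of sin(y)
  c, pb_add = ChainRules.rrule(+, a, b)   # step 3: value and pullback of a + b
  function foo_pullback(Δ)                # step 4: chain the pullbacks in reverse order
    _, Δa, Δb = pb_add(Δ)
    _, Δx, Δy1 = pb_mul(Δa)
    _, Δy2 = pb_sin(Δb)
    return (Δx, Δy1 + Δy2)                # contributions to y are accumulated
  end
  return c, foo_pullback                  # step 5: value together with the chained pullback
end

v, pb = foo_forward(1.0, 1.0)
pb(1.0)                                   # (1.0, 1.5403023058681398)
```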

-# TODO
### Simplified implementation
The following code is adapted from [this example](https://github.com/FluxML/IRTools.jl/blob/master/examples/reverse.jl).

-* show constant gradient of linear function in LLVM
```julia
using IRTools, ChainRules
using IRTools: @dynamo, IR, Pipe, finish, substitute, return!, block, blocks,
  returnvalue, arguments, isexpr, xcall, self, stmt

# `Pullback` carries the pullbacks of all inner calls in `data`;
# the type parameter `S` remembers the signature of the primal call.
struct Pullback{S,T}
  data::T
end

Pullback{S}(data) where S = Pullback{S,typeof(data)}(data)

function primal(ir, T = Any)
  pr = Pipe(ir)
  calls = []
  ret = []
  for (v, st) in pr
    ex = st.expr
    if isexpr(ex, :call)
      # replace `f(args...)` by `forward(f, args...)`, which returns `(value, pullback)`
      t = insert!(pr, v, stmt(xcall(Main, :forward, ex.args...), line = st.line))
      pr[v] = xcall(:getindex, t, 1)        # the value of the original call
      J = push!(pr, xcall(:getindex, t, 2)) # its pullback
      push!(calls, v)
      push!(ret, J)
    end
  end
  pb = Expr(:call, Pullback{T}, xcall(:tuple, ret...))
  return!(pr, xcall(:tuple, returnvalue(block(ir, 1)), pb))
  return finish(pr), calls
end

@dynamo function forward(m...)
  ir = IR(m...)
  ir === nothing && return :(error("Non-differentiable function ", repr(args[1])))
  length(blocks(ir)) == 1 || error("control flow is not supported")
  return primal(ir, Tuple{m...})[1]
end
```
where
- the generated function `forward` calls `primal`, which manually applies the chain rule to the IR
- the actual chain rule is applied in the for loop
- every function call is replaced by `xcall(Main, :forward, ex.args...)`, which is the recursion we have observed above (`stmt` allows us to attach source-line information to the inserted statement)
- the output of `forward` is the value of the function and the *pullback*, the function calculating the gradient with respect to its inputs
- `pr[v] = xcall(:getindex, t, 1)` fixes the output of the overwritten function call to be the first item returned by `forward(...)`
- the next line stores the *pullback*
- `Expr(:call, Pullback{T}, xcall(:tuple, ret...))` constructs the `Pullback` object, which will later serve to generate the function assembling the pullbacks in the right order

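A side note on the building blocks used above: `xcall` merely constructs a call `Expr` with a `GlobalRef` into the given module (defaulting to `Base`), which is why the rewritten statements refer to `Main.forward` and `Base.getindex`. A small sketch:
```julia
using IRTools: xcall

xcall(Main, :forward, :f, :x)  # Expr(:call, GlobalRef(Main, :forward), :f, :x)
xcall(:getindex, :t, 1)        # module defaults to Base: Expr(:call, GlobalRef(Base, :getindex), :t, 1)
```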
Let's now observe how the IR of `foo` is transformed
```julia
ir = IR(typeof(foo), Float64, Float64)
julia> primal(ir)[1]
1: (%1, %2, %3)
  %4 = Main.forward(Main.:*, %2, %3)
  %5 = Base.getindex(%4, 1)
  %6 = Base.getindex(%4, 2)
  %7 = Main.forward(Main.sin, %3)
  %8 = Base.getindex(%7, 1)
  %9 = Base.getindex(%7, 2)
  %10 = Main.forward(Main.:+, %5, %8)
  %11 = Base.getindex(%10, 1)
  %12 = Base.getindex(%10, 2)
  %13 = Base.tuple(%6, %9, %12)
  %14 = (Pullback{Any, T} where T)(%13)
  %15 = Base.tuple(%11, %14)
  return %15
```
- Every function call was transformed into a call to `forward(...)` followed by retrieval of the first and second item of the returned tuple.
- Line `%14` constructs the `Pullback`, which (as will be seen shortly) allows generating the pullback of the whole function.
- Line `%15` assembles the returned tuple, where the first item is the function value (computed at line `%11`) and the second is the pullback (constructed at line `%14`).

We define a few AD rules by specializing `forward` with calls to `ChainRules`
```julia
forward(::typeof(sin), x) = ChainRules.rrule(sin, x)
forward(::typeof(*), x, y) = ChainRules.rrule(*, x, y)
forward(::typeof(+), x, y) = ChainRules.rrule(+, x, y)
```
Zygote implements this inside the generated function, such that whatever is added to `ChainRules` is automatically reflected. The process is not entirely trivial (see [`has_chain_rrule`](https://github.com/FluxML/Zygote.jl/blob/master/src/compiler/chainrules.jl)) and for brevity it is not shown here.
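The key mechanism this builds on is that `ChainRules` provides a generic fallback: `rrule` returns `nothing` when no rule is defined for the given call, so the presence of a rule can be tested programmatically, for instance:
```julia
ChainRules.rrule(sin, 1.0) === nothing      # false: a rule exists and returns (value, pullback)
ChainRules.rrule(foo, 1.0, 1.0) === nothing # true: no rule for foo, so we have to recurse into its IR
```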

We now obtain the value and the pullback of the function `foo` as
```julia
julia> v, pb = forward(foo, 1.0, 1.0);

julia> pb(1.0)
(0, 1.0, 1.5403023058681398)
```
The pullback stores in its `data` field the pullbacks of the individual calls of the primal, in the order in which they were recorded.
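As a quick sanity check, the result can be compared against the analytic gradient of `foo(x, y) = x * y + sin(y)`, namely `∂foo/∂x = y` and `∂foo/∂y = x + cos(y)`:
```julia
v, pb = forward(foo, 2.0, 3.0)
Δ = pb(1.0)
Δ[2] ≈ 3.0             # ∂foo/∂x = y
Δ[3] ≈ 2.0 + cos(3.0)  # ∂foo/∂y = x + cos(y)
```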

Let's now turn our attention to the reverse part, implemented as
```julia
_sum() = 0
_sum(x) = x
_sum(x...) = xcall(:+, x...)

function pullback(pr)
  ir = empty(pr)
  grads = Dict()
  grad(x) = _sum(get(grads, x, [])...)
  grad(x, x̄) = push!(get!(grads, x, []), x̄)
  grad(returnvalue(block(pr, 1)), IRTools.argument!(ir))
  data = push!(ir, xcall(:getfield, self, QuoteNode(:data)))
  _, pbs = primal(pr)
  pbs = Dict(pbs[i] => push!(ir, xcall(:getindex, data, i)) for i = 1:length(pbs))
  for v in reverse(keys(pr))
    ex = pr[v].expr
    isexpr(ex, :call) || continue
    Δs = push!(ir, Expr(:call, pbs[v], grad(v)))
    for (i, x) in enumerate(ex.args)
      grad(x, push!(ir, xcall(:getindex, Δs, i)))
    end
  end
  return!(ir, xcall(:tuple, [grad(x) for x in arguments(pr)]...))
end

@dynamo function (pb::Pullback{S})(Δ) where S
  return pullback(IR(S.parameters...))
end
```
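Before dissecting this code, it may help to see what the two `grad` helpers and `_sum` do when driven with plain symbols instead of IR variables; the following is only an illustration, not part of the implementation:
```julia
grads = Dict()
grad(x) = _sum(get(grads, x, [])...)
grad(x, x̄) = push!(get!(grads, x, []), x̄)

grad(:y, :Δ1); grad(:y, :Δ2)  # record two contributions to the gradient of y
grad(:y)                      # xcall(:+, :Δ1, :Δ2), an expression summing both contributions
grad(:z)                      # 0, since no contribution has been recorded for z
```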

The implementation is a bit convoluted. The function `pullback` obtains the `IR` of the primal function (stored in the `S` type parameter of the `Pullback` type). But since it generates the body of `(pb::Pullback)(Δ)`, the generated code can assume access to the stored pullbacks (the Jacobians) as
```julia
pb.data[1]
pb.data[2]
pb.data[3]
```
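This works because the `@dynamo` for `(pb::Pullback{S})(Δ)` only sees the *type* of `pb`, and `forward` stashed the primal's signature in the type parameter `S`. A small sketch of what is recoverable from the type alone:
```julia
v, pb = forward(foo, 1.0, 1.0)
S = typeof(pb).parameters[1]  # Tuple{typeof(foo), Float64, Float64}
IR(S.parameters...)           # the IR of the primal, exactly what the @dynamo reconstructs
```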

Let's walk through how the reverse pass is constructed for `pr = IR(typeof(foo), Float64, Float64)`. The first lines
```julia
ir = empty(pr)
grads = Dict()
grad(x) = _sum(get(grads, x, [])...)
grad(x, x̄) = push!(get!(grads, x, []), x̄)
```
construct the empty `ir` of the pullback being built, define a `Dict` in which the individual contributions to the gradient of each variable will be accumulated, and define two helpers: `grad(x)` sums the contributions recorded for `x`, and `grad(x, x̄)` records a new contribution `x̄`. The next statement
```julia
grad(returnvalue(block(pr, 1)), IRTools.argument!(ir))
```
records in `grads` that the gradient of the output of the primal `pr` is provided as the argument of the pullback (`IRTools.argument!(ir)`). The lines
```julia
data = push!(ir, xcall(:getfield, self, QuoteNode(:data)))
_, pbs = primal(pr)
pbs = Dict(pbs[i] => push!(ir, xcall(:getindex, data, i)) for i = 1:length(pbs))
```
set `data` to the `data` field of the `Pullback` structure containing the pullback functions. They then create a dictionary `pbs`, in which the output of each call in the primal (identified by its line) is mapped to the corresponding pullback, which is now a line in the constructed IR.
The IR so far looks as follows
```julia
1: (%1)
  %2 = Base.getfield(IRTools.Inner.Self(), :data)
  %3 = Base.getindex(%2, 1)
  %4 = Base.getindex(%2, 2)
  %5 = Base.getindex(%2, 3)
```
and `pbs` contains
```julia
julia> pbs
Dict{IRTools.Inner.Variable, IRTools.Inner.Variable} with 3 entries:
  %6 => %5
  %4 => %3
  %5 => %4
```
which says that the pullback of the call producing the variable at line `%6` of the primal is stored in variable `%5` of the constructed pullback.
The real work happens in the for loop
```julia
for v in reverse(keys(pr))
  ex = pr[v].expr
  isexpr(ex, :call) || continue
  Δs = push!(ir, Expr(:call, pbs[v], grad(v)))
  for (i, x) in enumerate(ex.args)
    grad(x, push!(ir, xcall(:getindex, Δs, i)))
  end
end
```
which iterates over the primal `pr` in reverse order and, for every call, inserts a statement invoking the appropriate pullback (`Δs = push!(ir, Expr(:call, pbs[v], grad(v)))`) and, in the inner loop `for (i, x) in enumerate(ex.args) ...`, records the gradients with respect to the call's inputs in the accumulator `grads`.
The last line
```julia
return!(ir, xcall(:tuple, [grad(x) for x in arguments(pr)]...))
```
adds to the IR the statements accumulating the gradients with respect to the arguments of the primal and returns them as a tuple.

The final generated IR code looks as follows
```julia
julia> pullback(IR(typeof(foo), Float64, Float64))
1: (%1)
  %2 = Base.getfield(IRTools.Inner.Self(), :data)
  %3 = Base.getindex(%2, 1)
  %4 = Base.getindex(%2, 2)
  %5 = Base.getindex(%2, 3)
  %6 = (%5)(%1)
  %7 = Base.getindex(%6, 1)
  %8 = Base.getindex(%6, 2)
  %9 = Base.getindex(%6, 3)
  %10 = (%4)(%9)
  %11 = Base.getindex(%10, 1)
  %12 = Base.getindex(%10, 2)
  %13 = (%3)(%8)
  %14 = Base.getindex(%13, 1)
  %15 = Base.getindex(%13, 2)
  %16 = Base.getindex(%13, 3)
  %17 = %12 + %16
  %18 = Base.tuple(0, %15, %17)
  return %18
```

and it calculates the gradient with respect to the inputs as
```julia
julia> pb(1.0)
(0, 1.0, 1.5403023058681398)
```
where the first item is the gradient with respect to the function itself (its parameters, if it had any) and the remaining two items are the gradients with respect to `x` and `y`.

## Conclusion
The above examples served to demonstrate that `@generated` functions offer an extremely powerful paradigm, especially when coupled with manipulation of the intermediate representation. Within a few lines of code, we have implemented a reasonably powerful profiler and a reverse-mode AD engine. Importantly, this has been done without any single-purpose engine or tooling.
