Skip to content

Commit 075e704

Browse files
author
pevnak
committed
lecture 9
1 parent c1fda5b commit 075e704

File tree

1 file changed

+66
-29
lines changed

1 file changed

+66
-29
lines changed

docs/src/lectures/lecture_09/lecture_v2.md

Lines changed: 66 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,21 @@
11
# Source-to-Source Automatic Differentiation
22

3+
This lecture was tested with Julia version 1.7
4+
```julia
5+
julia> versioninfo()
6+
Julia Version 1.11.7
7+
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
8+
Build Info:
9+
Official https://julialang.org/ release
10+
Platform Info:
11+
OS: macOS (arm64-apple-darwin24.0.0)
12+
CPU: 14 × Apple M4 Pro
13+
WORD_SIZE: 64
14+
LLVM: libLLVM-16.0.6 (ORCJIT, apple-m1)
15+
Threads: 1 default, 0 interactive, 1 GC (on 10 virtual cores)
16+
```
17+
Importantly, it does not work with 1.12 due to changes in the compiler.
18+
319
Before diving into the next adventure in automatic differentiation (AD), we spent some time exploring the role of *rules* in AD.
420

521
The automatic differention libraries consists of two parts.
@@ -137,14 +153,14 @@ ir, _ = only(Base.code_ircode(foo, (Float64, Float64); optimize_until = "compact
137153

138154
We can ask for different level of optimization, but such a simple function, it will not make much difference:
139155
```julia
140-
@pass "convert" ir = convert_to_ircode(ci, sv)
141-
@pass "slot2reg" ir = slot2reg(ir, ci, sv)
142-
@pass "compact 1" ir = compact!(ir)
143-
@pass "Inlining" ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
144-
@pass "compact 2" ir = compact!(ir)
145-
@pass "SROA" ir = sroa_pass!(ir, sv.inlining)
146-
@pass "ADCE" ir = adce_pass!(ir, sv.inlining)
147-
@pass "compact 3" ir = compact!(ir)
156+
CC.@pass "convert" ir = convert_to_ircode(ci, sv)
157+
CC.@pass "slot2reg" ir = slot2reg(ir, ci, sv)
158+
CC.@pass "compact 1" ir = compact!(ir)
159+
CC.@pass "Inlining" ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
160+
CC.@pass "compact 2" ir = compact!(ir)
161+
CC.@pass "SROA" ir = sroa_pass!(ir, sv.inlining)
162+
CC.@pass "ADCE" ir = adce_pass!(ir, sv.inlining)
163+
CC.@pass "compact 3" ir = compact!(ir)
148164
```
149165

150166

@@ -164,35 +180,37 @@ The code is stored in `IRCode` data structure.
164180
stmts::InstructionStream
165181
argtypes::Vector{Any}
166182
sptypes::Vector{VarState}
167-
linetable::Vector{LineInfoNode}
183+
debuginfo::Compiler.DebugInfoStream
168184
cfg::CFG
169185
new_nodes::NewNodeStream
170186
meta::Vector{Expr}
187+
valid_worlds::Compiler.WorldRange
171188
end
172189
```
173190

174191
where
175192
* `stmts` is a stream of instruction (more in this below)
176193
* `argtypes` holds types of arguments of the function whose `IRCode` we have obtained
177194
* `sptypes` is a vector of `VarState`. It seems to be related to parameters of types
178-
* `linetable` is a table of unique lines in the source code from which statement came from
195+
* `debuginfo` is a table of unique lines in the source code from which statement came from
179196
* `cfg` holds control flow graph, which contains building blocks and jumps between them
180197
* `new_nodes` is an infrastructure that can be used to insert new instructions to the existing `IRCode` . The idea behind is that since insertion requires a renumbering all statements, they are put in a separate queue. They are put to correct position with a correct `SSANumber` by calling `compact!`.
181198
* `meta` is something.
199+
* `valid_worlds` specify a "time" span in which the world is valid
182200

183201
**InstructionStream**
184202

185203
```julia
186204
struct InstructionStream
187-
inst::Vector{Any}
205+
stmt::Vector{Any}
188206
type::Vector{Any}
189207
info::Vector{CallInfo}
190208
line::Vector{Int32}
191209
flag::Vector{UInt8}
192210
end
193211
```
194212
where
195-
* `inst` is a vector of instructions, stored as `Expr`essions. The allowed fields in `head` are described [here](https://docs.julialang.org/en/v1/devdocs/ast/#Expr-types)
213+
* `stmt` is a vector of instructions, stored as `Expr`essions. The allowed fields in `head` are described [here](https://docs.julialang.org/en/v1/devdocs/ast/#Expr-types)
196214
* `type` is the type of the value returned by the corresponding statement
197215
* `CallInfo` is ???some info???
198216
* `line` is an index into `IRCode.linetable` identifying from which line in source code the statement comes from
@@ -208,15 +226,16 @@ The code is stored in `IRCode` data structure.
208226
For the above `foo` function, the InstructionStream looks like
209227

210228
```julia
211-
julia DataFrame(flag = ir.stmts.flag, info = ir.stmts.info, inst = ir.stmts.inst, line = ir.stmts.line, type = ir.stmts.type)
229+
julia DataFrame(flag = ir.stmts.flag, info = ir.stmts.info, stmt = ir.stmts.stmt, type = ir.stmts.type)
212230
4×5 DataFrame
213-
Row │ flag info inst line type
214-
│ UInt8 CallInfo Any Int32 Any
215-
─────┼────────────────────────────────────────────────────────────────────────
216-
1 │ 112 MethodMatchInfo(MethodLookupResu… _2 * _3 1 Float64
217-
2 │ 80 MethodMatchInfo(MethodLookupResu… Main.sin(_2) 2 Float64
218-
3 │ 112 MethodMatchInfo(MethodLookupResu… %1 + %2 2 Float64
219-
4 │ 0 NoCallInfo() return %3 2 Any
231+
Row │ flag info stmt type
232+
│ UInt32 CallInfo Any Any
233+
─────┼──────────────────────────────────────────────────────────────────
234+
1 │ 9336 ConstCallInfo(MethodMatchInfo(Me… Float64
235+
2 │ 9336 MethodMatchInfo(MethodLookupResu… _2 * _3 Float64
236+
3 │ 9304 MethodMatchInfo(MethodLookupResu… Main.sin(_2) Float64
237+
4 │ 9336 MethodMatchInfo(MethodLookupResu… %2 + %3 Float64
238+
5 │ 131072 NoCallInfo() return %4 Any
220239
```
221240
We can index into the statements as `ir.stmts[1]`, which provides a "view" into the vector. To obtain the first instruction, we can do `ir.stmts[1][:inst]`.
222241

@@ -231,7 +250,7 @@ Let's now go back to the problem of automatic differentiation. Recall the IRCode
231250
└── return %3
232251
=> Float64
233252
```
234-
The forward part needs to replace each call of the function by a call to `rrule` and stode pullbacks.
253+
The forward part needs to replace each call of the function by a call to `rrule` and store pullbacks.
235254
So in pseudocode, we want something like
236255
```julia
237256
(%1, %2) = rrule(*, _2, _3)
@@ -249,7 +268,7 @@ To implement the code performing the above transformation, we initiate few varia
249268
```julia
250269
adinfo = [] # storage for informations about pullbacks, needed for the construction of the reverse pass
251270
new_insts = Any[] # storate for instructions
252-
new_line = Int32[] # Index of instruction we are differentiating
271+
new_line = Int32[] # Index of instruction we are differentiating
253272
ssamap = Dict{SSAValue,SSAValue}() # this maps old SSA values to new SSA values, since they need to be linearly ordered.
254273
```
255274

@@ -274,7 +293,7 @@ The main loop transforming the function looks like foollows
274293

275294
```julia
276295
for (i, stmt) in enumerate(ir.stmts)
277-
inst = stmt[:inst]
296+
inst = stmt[:stmt]
278297
if inst isa Expr && inst.head == :call
279298
new_inst = Expr(:call, GlobalRef(ChainRules, :rrule), remap_ssa(ssamap, inst.args)...)
280299
push!(new_insts, new_inst)
@@ -306,6 +325,13 @@ for (i, stmt) in enumerate(ir.stmts)
306325
push!(new_line, stmt[:line])
307326
continue
308327
end
328+
329+
if inst isa Nothing
330+
push!(new_insts, nothing)
331+
new_ssa = SSAValue(length(new_insts))
332+
push!(new_line, stmt[:line])
333+
continue
334+
end
309335
error("unknown node $(i)")
310336
end
311337
```
@@ -317,11 +343,11 @@ constructing Instruction stream as
317343

318344
```julia
319345
stmts = CC.InstructionStream(
320-
new_insts,
321-
fill(Any, length(new_insts)),
322-
fill(CC.NoCallInfo(), length(new_insts)),
323-
new_line,
324-
fill(CC.IR_FLAG_REFINED, length(new_insts)),
346+
new_insts,
347+
fill(Any, length(new_insts)),
348+
fill(CC.NoCallInfo(), length(new_insts)),
349+
new_line,
350+
fill(CC.IR_FLAG_REFINED, length(new_insts)),
325351
)
326352
```
327353

@@ -371,7 +397,7 @@ pullback[1](1.0) == (ChainRules.NoTangent(), 1.0, 1.0)
371397

372398
Given some IR, generates a MethodInstance suitable for passing to infer_ir!, if you don't
373399
already have one with the right argument types. [Credit to@oxinabox:
374-
(https://gist.github.com/oxinabox/cdcffc1392f91a2f6d80b2524726d802#file-example-jl-L54)
400+
(https://gist.github.com/oxinabox/cdcffc1392f91a2f6d80b2524726d802)
375401

376402
```julia
377403
_type(x::Type) = x
@@ -493,6 +519,12 @@ function construct_forward(ir)
493519
end
494520
```
495521

522+
We can try to run the above as
523+
```julia
524+
oc = Core.OpaqueClosure(forward_ir)
525+
value, pullback = oc(1.0, 1.0)
526+
```
527+
496528
### Implementing reverse pass
497529
The reverse parts of a linear IR is relatively simple. In a nutshell, the code
498530
code needs to iterate over the pullbacks in the reverse order, execute them, and
@@ -597,6 +629,11 @@ end
597629
bar(x) = 5 * x
598630
```
599631

632+
Test as
600633

634+
```julia
635+
gradient(CachedGrad(foo, 1.0, 1.0), 1.0, 1.0)
636+
gradient(CachedGrad(bar, 1.0), 1.0)
637+
```
601638

602639
[1] [Autodiff by G. Dalle](https://gdalle.github.io/JuliaOptimizationDays2024-AutoDiff/#/)

0 commit comments

Comments
 (0)