Skip to content

Commit ace5330

Browse files
committed
checkpoint for code rewriting
1 parent 9e09404 commit ace5330

File tree

1 file changed

+171
-0
lines changed

1 file changed

+171
-0
lines changed

docs/src/lecture_09/lecture.md

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,177 @@ In theory, we can do all the above by directly modifying the code or introducing
129129

130130
The technique we desire is called contextual dispatch, which means that under some context, we invoke a different function. The library `Casette.jl` provides a high-level api for overdubbing, but it is by no means interesting to see, how it works, as it shows, how we can "interact" with the lowered code before the code is typed.
131131

132+
## insertion of the code
133+
134+
Imagine that julia has compiled some function. For example
135+
```julia
136+
foo(x,y) = x * y + sin(x)
137+
```
138+
Can I get access to lowered form?
139+
```julia
140+
julia> @code_lowered foo(1.0, 1.0)
141+
CodeInfo(
142+
1%1 = x * y
143+
%2 = Main.sin(x)
144+
%3 = %1 + %2
145+
└── return %3
146+
)
147+
```
148+
The lowered form is very nice, because on the left hand, there is **always** one parameter. Such form would be very nice for example for computation of automatic gradients, because it is very close to a computation graph. It is built for you by the compiler. Swell.
149+
150+
The answer is affirmative. with a little bit of effort (copied from `Cassette.jl`), you can have it
151+
```julia
152+
function retrieve_code_info(sigtypes, world = Base.get_world_counter())
153+
S = Tuple{map(s -> Core.Compiler.has_free_typevars(s) ? typeof(s.parameters[1]) : s, sigtypes)...}
154+
_methods = Base._methods_by_ftype(S, -1, world)
155+
isempty(_methods) && @error("method $(sigtypes) does not exist, may-be run it once")
156+
type_signature, raw_static_params, method = _methods[1] # method is the same as we would get by invoking methods(+, (Int, Int)).ms[1]
157+
158+
# this provides us with the CodeInfo
159+
method_instance = Core.Compiler.specialize_method(method, type_signature, raw_static_params, false)
160+
code_info = Core.Compiler.retrieve_code_info(method_instance)
161+
end
162+
```
163+
164+
And look at the result
165+
```julia
166+
ci = retrieve_code_info((typeof(foo), Float64, Float64))
167+
ci.code
168+
```
169+
https://docs.julialang.org/en/v1/devdocs/ast/#Lowered-form
170+
171+
Scheme of overdubbing
172+
1. Let's define context `struct Context ... end`
173+
2. Let's define a generated function with signature
174+
```julia
175+
@generated function overdub(ctx::Context, ::typeof(f), args...)
176+
# find an IR representation of `f(args)` using retrieve_code_info
177+
ci = retrieve_code_info(f, args...)
178+
# modify the code of `f(args)` by replacing the calls with
179+
overdub(ctx, ...
180+
# perform user actions
181+
return the modified expressions
182+
end
183+
```
184+
185+
```julia
186+
struct Context{T<:Union{Nothing, Vector{Symbol}}}
187+
functions::T
188+
end
189+
Context() = Context(nothing)
190+
191+
ctx = Context()
192+
193+
overdubbable(ctx::Context, ex::Expr) = ex.head == :call
194+
overdubbable(ctx::Context, ex) = false
195+
timable(ctx::Context{Nothing}, ex) = true
196+
197+
exprs = []
198+
for (i, ex) in enumerate(ci.code)
199+
if !overdubbable(ctx, ex)
200+
push!(exprs, ex)
201+
continue
202+
end
203+
if timable(ctx, ex)
204+
push!(exprs, Expr(:call, :push!, :to, :start, ex.args[1]))
205+
push!(exprs, Expr(ex.head, :overdub, :ctx, ex.args...))
206+
push!(exprs, Expr(:call, :push!, :to, :stop, ex.args[1]))
207+
else
208+
push!(exprs, ex)
209+
end
210+
end
211+
```
212+
213+
A further complication is that we need to change variables back and also gives them names. if names have existed before. Recall that lowered form introduces additional variables while converting the code to SSA. The variables defined in the source code (argument names and user-defined variables) can be found in `ci.slotnames`,
214+
whereas the variables introduces during lowering to SSA are named by the line number.
215+
Observe a difference between
216+
```julia
217+
julia> foo(x,y) = x * y + sin(x)
218+
foo (generic function with 1 method)
219+
220+
julia> retrieve_code_info((typeof(foo), Float64, Float64)).slotnames
221+
3-element Vector{Symbol}:
222+
Symbol("#self#")
223+
:x
224+
:y
225+
```
226+
and
227+
```julia
228+
julia> function foo(x, y)
229+
z = x * y
230+
z + sin(y)
231+
end
232+
foo (generic function with 1 method)
233+
234+
julia> retrieve_code_info((typeof(foo), Float64, Float64)).slotnames
235+
4-element Vector{Symbol}:
236+
Symbol("#self#")
237+
:x
238+
:y
239+
:z
240+
```
241+
This allows to convert the names back to the original. Let's create a dictionary of assignments
242+
as
243+
```julia
244+
fun_vars = Dict(enumerate(ci.slotnames))
245+
lower_vars = Dict(i => gensym(:left) for i in 1:length(ci.code))
246+
247+
rename_args(ex::Expr, fun_vars, lower_vars) = Expr(ex.head, rename_args(ex.args, fun_vars, lower_vars))
248+
rename_args(args::AbstractArray, fun_vars, lower_vars) = map(a -> rename_args(a, fun_vars, lower_vars), args)
249+
rename_args(a::Core.SlotNumber, fun_vars, lower_vars) = lower_vars[a.id]
250+
rename_args(a::Core.SSAValue, fun_vars, lower_vars) = fun_vars[a.id]
251+
```
252+
and let's redo the rewriting once more while replacing the variables and inserting the left-hand equalities
253+
exprs = []
254+
for (i, ex) in enumerate(ci.code)
255+
ex = rename_args(ex, fun_vars, lower_vars)
256+
if !overdubbable(ctx, ex)
257+
push!(exprs, ex)
258+
continue
259+
end
260+
if timable(ctx, ex)
261+
push!(exprs, Expr(:call, :push!, :to, :start, ex.args[1]))
262+
#ex = Expr(ex.head, :overdub, :ctx, ex.args...)
263+
#push!(exprs, :($(lower_vars[i]) = $(ex))
264+
push!(exprs, Expr(ex.head, :overdub, :ctx, ex.args...))
265+
push!(exprs, Expr(:call, :push!, :to, :stop, ex.args[1]))
266+
else
267+
push!(exprs, ex)
268+
end
269+
end
270+
271+
```julia
272+
macro timeit(ex::Expr)
273+
ex.head != :call && error("timeit is implemented only for function calls")
274+
quote
275+
ctx = Context()
276+
277+
end
278+
end
279+
macro timeit(ex)
280+
error("timeit is implemented only for function calls")
281+
end
282+
```
283+
284+
285+
```julia
286+
struct Calls
287+
stamps::Vector{Float64} # contains the time stamps
288+
event::Vector{Symbol} # name of the function that is being recorded
289+
startstop::Vector{Symbol} # if the time stamp corresponds to start or to stop
290+
i::Ref{Int}
291+
end
292+
293+
function Calls(n::Int)
294+
Calls(Vector{Float64}(undef, n+1), Vector{Symbol}(undef, n+1), Vector{Symbol}(undef, n+1), Ref{Int}(0))
295+
end
296+
297+
global to = Calls(100)
298+
```
299+
300+
301+
302+
## Petite zygote
132303
```julia
133304
world = Base.get_world_counter()
134305
sigtypes = (typeof(+), Int, Int)

0 commit comments

Comments
 (0)