Skip to content

Commit 1cda885

Browse files
committed
an example of timer with codeinfo works
1 parent a1391cd commit 1cda885

File tree

3 files changed

+195
-23
lines changed

3 files changed

+195
-23
lines changed

docs/src/lecture_09/codeinfo.jl

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# Generated functions
2+
using Dictionaries
3+
function retrieve_code_info(sigtypes, world = Base.get_world_counter())
4+
S = Tuple{map(s -> Core.Compiler.has_free_typevars(s) ? typeof(s.parameters[1]) : s, sigtypes)...}
5+
_methods = Base._methods_by_ftype(S, -1, world)
6+
if isempty(_methods)
7+
@info("method $(sigtypes) does not exist, may-be run it once")
8+
return(nothing)
9+
end
10+
type_signature, raw_static_params, method = _methods[1] # method is the same as we would get by invoking methods(+, (Int, Int)).ms[1]
11+
12+
# this provides us with the CodeInfo
13+
method_instance = Core.Compiler.specialize_method(method, type_signature, raw_static_params, false)
14+
code_info = Core.Compiler.retrieve_code_info(method_instance)
15+
end
16+
17+
struct Calls
18+
stamps::Vector{Float64} # contains the time stamps
19+
event::Vector{Symbol} # name of the function that is being recorded
20+
startstop::Vector{Symbol} # if the time stamp corresponds to start or to stop
21+
i::Ref{Int}
22+
end
23+
24+
function Calls(n::Int)
25+
Calls(Vector{Float64}(undef, n+1), Vector{Symbol}(undef, n+1), Vector{Symbol}(undef, n+1), Ref{Int}(0))
26+
end
27+
28+
function Base.show(io::IO, calls::Calls)
29+
offset = 0
30+
for i in 1:calls.i[]
31+
offset -= calls.startstop[i] == :stop
32+
foreach(_ -> print(io, " "), 1:max(offset, 0))
33+
rel_time = calls.stamps[i] - calls.stamps[1]
34+
println(io, calls.event[i], ": ", rel_time)
35+
offset += calls.startstop[i] == :start
36+
end
37+
end
38+
39+
global const to = Calls(100)
40+
41+
function record_start(ev::Symbol)
42+
calls = Main.to
43+
n = calls.i[] = calls.i[] + 1
44+
n > length(calls.stamps) && return
45+
calls.event[n] = ev
46+
calls.startstop[n] = :start
47+
calls.stamps[n] = time_ns()
48+
end
49+
50+
function record_end(ev::Symbol)
51+
t = time_ns()
52+
calls = Main.to
53+
n = calls.i[] = calls.i[] + 1
54+
n > length(calls.stamps) && return
55+
calls.event[n] = ev
56+
calls.startstop[n] = :stop
57+
calls.stamps[n] = t
58+
end
59+
60+
reset!(calls::Calls) = calls.i[] = 0
61+
62+
63+
function overdubbable(ex::Expr)
64+
ex.head != :call && return(false)
65+
length(ex.args) < 2 && return(false)
66+
(ex.args[1] isa Core.IntrinsicFunction) && return(false)
67+
return(true)
68+
end
69+
# overdubbable(ex::Expr) = false
70+
overdubbable(ex) = false
71+
# overdubbable(ctx::Context, ex::Expr) = ex.head == :call
72+
# overdubbable(ctx::Context, ex) = false
73+
# timable(ctx::Context{Nothing}, ex) = true
74+
timable(ex::Expr) = ex.head == :call
75+
timable(ex) = false
76+
77+
rename_args(ex, slotvar, ssamap) = ex
78+
rename_args(c::Core.GotoIfNot, slotvar, ssamap) = Core.GotoIfNot(rename_args(c.cond, slotvar, ssamap), ssamap[c.dest])
79+
rename_args(ex::Expr, slotvar, ssamap) = Expr(ex.head, rename_args(ex.args, slotvar, ssamap)...)
80+
rename_args(args::AbstractArray, slotvar, ssamap) = map(a -> rename_args(a, slotvar, ssamap), args)
81+
rename_args(r::Core.ReturnNode, slotvar, ssamap) = Core.ReturnNode(rename_args(r.val, slotvar, ssamap))
82+
rename_args(a::Core.SlotNumber, slotvar, ssamap) = slotvar[a.id]
83+
rename_args(a::Core.SSAValue, slotvar, ssamap) = Core.SSAValue(ssamap[a.id])
84+
85+
exportname(ex::GlobalRef) = QuoteNode(ex.name)
86+
exportname(ex::Expr) = exportname(ex.args[1])
87+
exportname(i::Int) = QuoteNode(Symbol("Int(",i,")"))
88+
89+
using Base: invokelatest
90+
dummy() = return
91+
92+
overdub(f::Core.IntrinsicFunction, args...) = f(args...)
93+
94+
@generated function overdub(f::F, args...) where {F}
95+
@show (F, args...)
96+
ci = retrieve_code_info((F, args...))
97+
if ci === nothing
98+
@show Expr(:call, :f, [:(args[$(i)]) for i in 1:length(args)]...)
99+
return(Expr(:call, :f, [:(args[$(i)]) for i in 1:length(args)]...))
100+
end
101+
# this is to initialize a new CodeInfo and fill it with values from the
102+
# overdubbed function
103+
new_ci = code_lowered(dummy, Tuple{})[1]
104+
empty!(new_ci.code)
105+
empty!(new_ci.slotnames)
106+
foreach(s -> push!(new_ci.slotnames, s), ci.slotnames)
107+
new_ci.slotnames = vcat([Symbol("#self#"), :f, :args], ci.slotnames[length(args)+2:end])
108+
new_ci.slotflags = vcat([0x00, 0x00, 0x00], ci.slotflags[length(args)+2:end])
109+
empty!(new_ci.linetable)
110+
foreach(s -> push!(new_ci.linetable, s), ci.linetable)
111+
empty!(new_ci.codelocs)
112+
113+
ssamap = Dict{Int, Int}()
114+
slotvar = Dict{Int, Any}()
115+
for i in 1:length(args)
116+
push!(new_ci.code, Expr(:call, Base.getindex, Core.SlotNumber(3), i))
117+
slotvar[i+1] = Core.SSAValue(i)
118+
push!(new_ci.codelocs, ci.codelocs[1])
119+
end
120+
slotvar[1] = Core.SlotNumber(1)
121+
foreach(i -> slotvar[i[2]] = Core.SlotNumber(i[1]+3), enumerate(length(args)+2:length(ci.slotnames)))
122+
123+
j = length(args)
124+
for (i, ex) in enumerate(ci.code)
125+
if timable(ex)
126+
fname = exportname(ex)
127+
push!(new_ci.code, Expr(:call, GlobalRef(Main, :record_start), fname))
128+
push!(new_ci.codelocs, ci.codelocs[i])
129+
j += 1
130+
# ex = overdubbable(ex) ? Expr(:call, :overdub, ex.args...) : ex
131+
# ex = overdubbable(ex) ? Expr(:call, GlobalRef(Main, :overdub), ex.args...) : ex
132+
push!(new_ci.code, ex)
133+
push!(new_ci.codelocs, ci.codelocs[i])
134+
j += 1
135+
ssamap[i] = j
136+
push!(new_ci.code, Expr(:call, GlobalRef(Main, :record_end), fname))
137+
push!(new_ci.codelocs, ci.codelocs[i])
138+
j += 1
139+
else
140+
push!(new_ci.code, ex)
141+
push!(new_ci.codelocs, ci.codelocs[i])
142+
j += 1
143+
ssamap[i] = j
144+
end
145+
end
146+
for i in length(args)+1:length(new_ci.code)
147+
new_ci.code[i] = rename_args(new_ci.code[i], slotvar, ssamap)
148+
end
149+
new_ci
150+
151+
# Core.Compiler.replace_code_newstyle!(ci, ir, length(ir.argtypes)-1)
152+
new_ci.inferred = false
153+
new_ci.ssavaluetypes = length(new_ci.code)
154+
# new_ci
155+
# new_ci
156+
return(new_ci)
157+
end
158+
159+
function foo(x, y)
160+
z = x * y
161+
z + sin(y)
162+
end
163+
164+
reset!(to)
165+
overdub(foo, 1.0, 1.0)
166+
overdub(sin, 1.0)
167+
to
168+
169+
function overdub2(::typeof(foo), args...)
170+
x = args[1]
171+
y = args[2]
172+
push!(to, :start, :fun)
173+
z = x * y
174+
push!(to, :stop, :fun)
175+
push!(to, :start, :fun)
176+
r = z + sin(y)
177+
push!(to, :stop, :fun)
178+
r
179+
end
180+
181+
ci = retrieve_code_info((typeof(overdub2), typeof(foo), Float64, Float64))
182+

docs/src/lecture_09/lecture.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -313,14 +313,13 @@ end
313313
Expr(:block, exprs...)
314314
```
315315
316-
316+
Wrapping the code generation to the generated function `overdub`,
317317
```julia
318318
@generated function overdub(ctx::Context, f::F, args...) where {F}
319319
@show (F, args...)
320320
ci = retrieve_code_info((F, args...))
321321
slot_vars = Dict(enumerate(ci.slotnames))
322-
# ssa_vars = Dict(i => gensym(:left) for i in 1:length(ci.code))
323-
ssa_vars = Dict(i => Symbol(:L,i) for i in 1:length(ci.code))
322+
ssa_vars = Dict(i => gensym(:left) for i in 1:length(ci.code))
324323
used = assigned_vars(ci.code) |> distinct
325324
exprs = []
326325
for (i, ex) in enumerate(ci.code)
@@ -343,6 +342,8 @@ Expr(:block, exprs...)
343342
Expr(:block, exprs...)
344343
end
345344
```
345+
we can try to call it as `overdub(ctx, foo, 1.0, 1.0)`, but the code fails due to unknown variable `x`. Why is that? The original function `foo(x, y)` had arguments `x` and `y`, whereas our generated function does not know them. Instead, it knows `args...`. Therefore we will put assigning statems to the beggining of the generated function call
346+
```
346347

347348
```
348349
macro meta(ex)

docs/src/lecture_09/timer.jl

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ function Calls(n::Int)
2222
Calls(Vector{Float64}(undef, n+1), Vector{Symbol}(undef, n+1), Vector{Symbol}(undef, n+1), Ref{Int}(0))
2323
end
2424

25+
global const to = Calls(100)
26+
2527
function Base.show(io::IO, calls::Calls)
2628
for i in 1:calls.i[]
2729
println(io, calls.stamps[i] - calls.stamps[1]," ", calls.startstop[i]," ",calls.event[i])
@@ -43,27 +45,11 @@ struct Context{T<:Union{Nothing, Vector{Symbol}}}
4345
end
4446
Context() = Context(nothing)
4547

46-
ctx = Context()
47-
48-
function overdubbable(ex::Expr)
49-
ex.head != :call && return(false)
50-
length(ex.args) < 2 && return(false)
51-
(ex.args[1] isa Core.IntrinsicFunction) && return(false)
52-
return(true)
53-
end
5448
overdubbable(ex::Expr) = false
5549
overdubbable(ex) = false
56-
# overdubbable(ctx::Context, ex::Expr) = ex.head == :call
57-
# overdubbable(ctx::Context, ex) = false
58-
# timable(ctx::Context{Nothing}, ex) = true
5950
timable(ex::Expr) = ex.head == :call
6051
timable(ex) = false
6152

62-
function foo(x, y)
63-
z = x * y
64-
z + sin(y)
65-
end
66-
6753
rename_args(ex, slot_vars, ssa_vars) = ex
6854
rename_args(ex::Expr, slot_vars, ssa_vars) = Expr(ex.head, rename_args(ex.args, slot_vars, ssa_vars)...)
6955
rename_args(args::AbstractArray, slot_vars, ssa_vars) = map(a -> rename_args(a, slot_vars, ssa_vars), args)
@@ -95,7 +81,6 @@ overdub(ctx::Context, f::Core.IntrinsicFunction, args...) = f(args...)
9581
end
9682
for (i, ex) in enumerate(ci.code)
9783
ex = rename_args(ex, slot_vars, ssa_vars)
98-
@show ex
9984
if ex isa Core.ReturnNode
10085
push!(exprs, Expr(:return, ex.val))
10186
continue
@@ -120,10 +105,14 @@ overdub(ctx::Context, f::Core.IntrinsicFunction, args...) = f(args...)
120105
end
121106

122107

123-
global const to = Calls(100)
108+
function foo(x, y)
109+
z = x * y
110+
z + sin(y)
111+
end
124112
reset!(to)
125-
ctx = Context()
126-
overdub(ctx, foo, 1.0, 1.0)
113+
overdub(Context(), foo, 1.0, 1.0)
114+
to
115+
127116
reset!(to)
128117
overdub(ctx, Base.Math.sin_kernel, 1.0)
129118

0 commit comments

Comments
 (0)