Skip to content

Commit 4c362bb

Browse files
committed
overdubbing sin example works by now
1 parent 1cda885 commit 4c362bb

File tree

2 files changed

+150
-88
lines changed

2 files changed

+150
-88
lines changed

docs/src/lecture_09/calls.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
2+
struct Calls
3+
stamps::Vector{Float64} # contains the time stamps
4+
event::Vector{Symbol} # name of the function that is being recorded
5+
startstop::Vector{Symbol} # if the time stamp corresponds to start or to stop
6+
i::Ref{Int}
7+
end
8+
9+
function Calls(n::Int)
10+
Calls(Vector{Float64}(undef, n+1), Vector{Symbol}(undef, n+1), Vector{Symbol}(undef, n+1), Ref{Int}(0))
11+
end
12+
13+
function Base.show(io::IO, calls::Calls)
14+
offset = 0
15+
for i in 1:calls.i[]
16+
offset -= calls.startstop[i] == :stop
17+
foreach(_ -> print(io, " "), 1:max(offset, 0))
18+
rel_time = calls.stamps[i] - calls.stamps[1]
19+
println(io, calls.event[i], ": ", rel_time)
20+
offset += calls.startstop[i] == :start
21+
end
22+
end
23+
24+
global const to = Calls(100)
25+
26+
"""
27+
record_start(ev::Symbol)
28+
29+
record the start of the event, the time stamp is recorded after all counters are
30+
appropriately increased
31+
"""
32+
function record_start(ev::Symbol)
33+
calls = Main.to
34+
n = calls.i[] = calls.i[] + 1
35+
n > length(calls.stamps) && return
36+
calls.event[n] = ev
37+
calls.startstop[n] = :start
38+
calls.stamps[n] = time_ns()
39+
end
40+
41+
"""
42+
record_end(ev::Symbol)
43+
44+
record the end of the event, the time stamp is recorded before all counters are
45+
appropriately increased
46+
"""
47+
function record_end(ev::Symbol)
48+
t = time_ns()
49+
calls = Main.to
50+
n = calls.i[] = calls.i[] + 1
51+
n > length(calls.stamps) && return
52+
calls.event[n] = ev
53+
calls.startstop[n] = :stop
54+
calls.stamps[n] = t
55+
end
56+
57+
reset!(calls::Calls) = calls.i[] = 0
58+

docs/src/lecture_09/codeinfo.jl

Lines changed: 92 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,60 @@
11
# Generated functions
22
using Dictionaries
3+
include("calls.jl")
4+
35
function retrieve_code_info(sigtypes, world = Base.get_world_counter())
46
S = Tuple{map(s -> Core.Compiler.has_free_typevars(s) ? typeof(s.parameters[1]) : s, sigtypes)...}
57
_methods = Base._methods_by_ftype(S, -1, world)
68
if isempty(_methods)
7-
@info("method $(sigtypes) does not exist, may-be run it once")
9+
@info("method $(sigtypes) does not exist")
810
return(nothing)
911
end
10-
type_signature, raw_static_params, method = _methods[1] # method is the same as we would get by invoking methods(+, (Int, Int)).ms[1]
11-
12-
# this provides us with the CodeInfo
12+
type_signature, raw_static_params, method = _methods[1]
1313
method_instance = Core.Compiler.specialize_method(method, type_signature, raw_static_params, false)
1414
code_info = Core.Compiler.retrieve_code_info(method_instance)
1515
end
1616

17-
struct Calls
18-
stamps::Vector{Float64} # contains the time stamps
19-
event::Vector{Symbol} # name of the function that is being recorded
20-
startstop::Vector{Symbol} # if the time stamp corresponds to start or to stop
21-
i::Ref{Int}
22-
end
23-
24-
function Calls(n::Int)
25-
Calls(Vector{Float64}(undef, n+1), Vector{Symbol}(undef, n+1), Vector{Symbol}(undef, n+1), Ref{Int}(0))
26-
end
27-
28-
function Base.show(io::IO, calls::Calls)
29-
offset = 0
30-
for i in 1:calls.i[]
31-
offset -= calls.startstop[i] == :stop
32-
foreach(_ -> print(io, " "), 1:max(offset, 0))
33-
rel_time = calls.stamps[i] - calls.stamps[1]
34-
println(io, calls.event[i], ": ", rel_time)
35-
offset += calls.startstop[i] == :start
36-
end
37-
end
38-
39-
global const to = Calls(100)
40-
41-
function record_start(ev::Symbol)
42-
calls = Main.to
43-
n = calls.i[] = calls.i[] + 1
44-
n > length(calls.stamps) && return
45-
calls.event[n] = ev
46-
calls.startstop[n] = :start
47-
calls.stamps[n] = time_ns()
48-
end
49-
50-
function record_end(ev::Symbol)
51-
t = time_ns()
52-
calls = Main.to
53-
n = calls.i[] = calls.i[] + 1
54-
n > length(calls.stamps) && return
55-
calls.event[n] = ev
56-
calls.startstop[n] = :stop
57-
calls.stamps[n] = t
58-
end
59-
60-
reset!(calls::Calls) = calls.i[] = 0
61-
62-
6317
function overdubbable(ex::Expr)
6418
ex.head != :call && return(false)
6519
length(ex.args) < 2 && return(false)
66-
(ex.args[1] isa Core.IntrinsicFunction) && return(false)
67-
return(true)
20+
ex.args[1] isa Core.GlobalRef && return(true)
21+
ex.args[1] isa Symbol && return(true)
22+
return(false)
6823
end
24+
6925
# overdubbable(ex::Expr) = false
7026
overdubbable(ex) = false
7127
# overdubbable(ctx::Context, ex::Expr) = ex.head == :call
7228
# overdubbable(ctx::Context, ex) = false
7329
# timable(ctx::Context{Nothing}, ex) = true
74-
timable(ex::Expr) = ex.head == :call
30+
timable(ex::Expr) = overdubbable(ex)
7531
timable(ex) = false
7632

77-
rename_args(ex, slotvar, ssamap) = ex
78-
rename_args(c::Core.GotoIfNot, slotvar, ssamap) = Core.GotoIfNot(rename_args(c.cond, slotvar, ssamap), ssamap[c.dest])
79-
rename_args(ex::Expr, slotvar, ssamap) = Expr(ex.head, rename_args(ex.args, slotvar, ssamap)...)
80-
rename_args(args::AbstractArray, slotvar, ssamap) = map(a -> rename_args(a, slotvar, ssamap), args)
81-
rename_args(r::Core.ReturnNode, slotvar, ssamap) = Core.ReturnNode(rename_args(r.val, slotvar, ssamap))
82-
rename_args(a::Core.SlotNumber, slotvar, ssamap) = slotvar[a.id]
83-
rename_args(a::Core.SSAValue, slotvar, ssamap) = Core.SSAValue(ssamap[a.id])
33+
#
34+
remap(ex::Expr, maps) = Expr(ex.head, remap(ex.args, maps)...)
35+
remap(args::AbstractArray, maps) = map(a -> remap(a, maps), args)
36+
remap(c::Core.GotoNode, maps) = Core.GotoNode(maps.goto[c.label])
37+
remap(c::Core.GotoIfNot, maps) = Core.GotoIfNot(remap(c.cond, maps), maps.goto[c.dest])
38+
remap(r::Core.ReturnNode, maps) = Core.ReturnNode(remap(r.val, maps))
39+
remap(a::Core.SlotNumber, maps) = maps.slots[a.id]
40+
remap(a::Core.SSAValue, maps) = Core.SSAValue(maps.ssa[a.id])
41+
remap(a::Core.NewvarNode, maps) = Core.NewvarNode(maps.slots[a.slot.id])
42+
remap(a::GlobalRef, maps) = a
43+
remap(a::QuoteNode, maps) = a
44+
remap(ex, maps) = ex
45+
46+
# remove static parameters (see https://discourse.julialang.org/t/does-overdubbing-in-generated-function-inserts-inlined-code/71868)
47+
remove_static(ex) = ex
48+
function remove_static(ex::Expr)
49+
ex.head != :call && return(ex)
50+
length(ex.args) != 2 && return(ex)
51+
!(ex.args[1] isa Expr) && return(ex)
52+
(ex.args[1].head == :static_parameter) && return(ex.args[2])
53+
return(ex)
54+
end
8455

8556
exportname(ex::GlobalRef) = QuoteNode(ex.name)
57+
exportname(ex::Symbol) = QuoteNode(ex)
8658
exportname(ex::Expr) = exportname(ex.args[1])
8759
exportname(i::Int) = QuoteNode(Symbol("Int(",i,")"))
8860

@@ -103,69 +75,87 @@ overdub(f::Core.IntrinsicFunction, args...) = f(args...)
10375
new_ci = code_lowered(dummy, Tuple{})[1]
10476
empty!(new_ci.code)
10577
empty!(new_ci.slotnames)
106-
foreach(s -> push!(new_ci.slotnames, s), ci.slotnames)
107-
new_ci.slotnames = vcat([Symbol("#self#"), :f, :args], ci.slotnames[length(args)+2:end])
108-
new_ci.slotflags = vcat([0x00, 0x00, 0x00], ci.slotflags[length(args)+2:end])
10978
empty!(new_ci.linetable)
110-
foreach(s -> push!(new_ci.linetable, s), ci.linetable)
11179
empty!(new_ci.codelocs)
11280

113-
ssamap = Dict{Int, Int}()
114-
slotvar = Dict{Int, Any}()
81+
new_ci.slotnames = vcat([Symbol("#self#"), :f, :args], ci.slotnames[2:end])
82+
new_ci.slotflags = vcat([0x00, 0x00, 0x00], ci.slotflags[2:end])
83+
foreach(s -> push!(new_ci.linetable, s), ci.linetable)
84+
85+
maps = (
86+
ssa = Dict{Int, Int}(),
87+
slots = Dict{Int, Any}(),
88+
goto = Dict{Int,Int}(),
89+
)
90+
91+
#we need to map indexes of slot-variables from ci to their new values.
92+
# except the first one, we just remap them
93+
maps.slots[1] = Core.SlotNumber(1)
94+
foreach(i -> maps.slots[i] = Core.SlotNumber(i + 2), 2:length(ci.slotnames)) # they are shifted by 2 accomondating inserted `f` and `args`
95+
@assert all(ci.slotnames[i] == new_ci.slotnames[maps.slots[i].id] for i in 1:length(ci.slotnames)) #test that the remapping is right
96+
97+
#if somewhere the original parameters of the functions will be used
98+
#they needs to be remapped to an SSAValue from here, since the overdubbed
99+
# function has signatures overdub(f, args...) instead of f(x,y,z...)
100+
ssa_no = 0
115101
for i in 1:length(args)
102+
ssa_no +=1
116103
push!(new_ci.code, Expr(:call, Base.getindex, Core.SlotNumber(3), i))
117-
slotvar[i+1] = Core.SSAValue(i)
104+
maps.slots[i+1] = Core.SSAValue(ssa_no)
118105
push!(new_ci.codelocs, ci.codelocs[1])
119106
end
120-
slotvar[1] = Core.SlotNumber(1)
121-
foreach(i -> slotvar[i[2]] = Core.SlotNumber(i[1]+3), enumerate(length(args)+2:length(ci.slotnames)))
122107

123-
j = length(args)
124-
for (i, ex) in enumerate(ci.code)
108+
for (ci_line, ex) in enumerate(ci.code)
125109
if timable(ex)
126110
fname = exportname(ex)
127111
push!(new_ci.code, Expr(:call, GlobalRef(Main, :record_start), fname))
128-
push!(new_ci.codelocs, ci.codelocs[i])
129-
j += 1
130-
# ex = overdubbable(ex) ? Expr(:call, :overdub, ex.args...) : ex
131-
# ex = overdubbable(ex) ? Expr(:call, GlobalRef(Main, :overdub), ex.args...) : ex
112+
push!(new_ci.codelocs, ci.codelocs[ci_line])
113+
ssa_no += 1
114+
maps.goto[ci_line] = ssa_no
115+
ex = overdubbable(ex) ? Expr(:call, GlobalRef(Main, :overdub), ex.args...) : ex
132116
push!(new_ci.code, ex)
133-
push!(new_ci.codelocs, ci.codelocs[i])
134-
j += 1
135-
ssamap[i] = j
117+
push!(new_ci.codelocs, ci.codelocs[ci_line])
118+
ssa_no += 1
119+
maps.ssa[ci_line] = ssa_no
136120
push!(new_ci.code, Expr(:call, GlobalRef(Main, :record_end), fname))
137-
push!(new_ci.codelocs, ci.codelocs[i])
138-
j += 1
121+
push!(new_ci.codelocs, ci.codelocs[ci_line])
122+
ssa_no += 1
139123
else
140124
push!(new_ci.code, ex)
141-
push!(new_ci.codelocs, ci.codelocs[i])
142-
j += 1
143-
ssamap[i] = j
125+
push!(new_ci.codelocs, ci.codelocs[ci_line])
126+
ssa_no += 1
127+
maps.ssa[ci_line] = ssa_no
144128
end
145129
end
130+
146131
for i in length(args)+1:length(new_ci.code)
147-
new_ci.code[i] = rename_args(new_ci.code[i], slotvar, ssamap)
132+
ex = remove_static(new_ci.code[i])
133+
new_ci.code[i] = remap(ex, maps)
148134
end
149135
new_ci
150136

151137
# Core.Compiler.replace_code_newstyle!(ci, ir, length(ir.argtypes)-1)
152138
new_ci.inferred = false
153139
new_ci.ssavaluetypes = length(new_ci.code)
154140
# new_ci
155-
# new_ci
156141
return(new_ci)
157142
end
158143

159144
function foo(x, y)
160-
z = x * y
145+
z = x * y
161146
z + sin(y)
162147
end
163148

149+
164150
reset!(to)
165151
overdub(foo, 1.0, 1.0)
166-
overdub(sin, 1.0)
152+
new_ci = overdub(sin, 1.0)
167153
to
168154

155+
# Seems like now, I am crashing here
156+
# typeof(Base._promote), Irrational{:π}, Int64)
157+
158+
169159
function overdub2(::typeof(foo), args...)
170160
x = args[1]
171161
y = args[2]
@@ -177,6 +167,20 @@ function overdub2(::typeof(foo), args...)
177167
push!(to, :stop, :fun)
178168
r
179169
end
180-
181170
ci = retrieve_code_info((typeof(overdub2), typeof(foo), Float64, Float64))
182171

172+
173+
function test(x::T) where T<:Union{Float64, Float32}
174+
x < T(pi)
175+
end
176+
ci = retrieve_code_info((typeof(test), Float64))
177+
178+
179+
function overdub_test(::typeof(test), args...)
180+
x = args[1]
181+
T = eltype(x)
182+
x < T(pi)
183+
end
184+
185+
ci = retrieve_code_info((typeof(overdub_test), typeof(test), Float64))
186+

0 commit comments

Comments
 (0)