Skip to content

Commit b7c352c

Browse files
committed
improved the lecture, added the partial inlining
1 parent 7a131ea commit b7c352c

File tree

3 files changed

+473
-289
lines changed

3 files changed

+473
-289
lines changed

docs/src/lecture_09/codeinfo.jl

Lines changed: 61 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,36 @@ function retrieve_code_info(sigtypes, world = Base.get_world_counter())
1010
return(nothing)
1111
end
1212
type_signature, raw_static_params, method = _methods[1]
13-
method_instance = Core.Compiler.specialize_method(method, type_signature, raw_static_params, false)
14-
code_info = Core.Compiler.retrieve_code_info(method_instance)
13+
mi = Core.Compiler.specialize_method(method, type_signature, raw_static_params, false)
14+
ci = Base.isgenerated(mi) ? Core.Compiler.get_staged(mi) : Base.uncompressed_ast(method)
15+
Base.Meta.partially_inline!(ci.code, [], method.sig, Any[raw_static_params...], 0, 0, :propagate)
16+
ci
1517
end
1618

19+
# # We migth consider this from IRTools. Importantly, it has partially_inline to get
20+
# # rid of static parameters
21+
# function meta(T; types = T, world = worldcounter())
22+
# F = T.parameters[1]
23+
# F == typeof(invoke) && return invoke_meta(T; world = world)
24+
# F isa DataType && (F.name.module === Core.Compiler ||
25+
# F <: Core.Builtin ||
26+
# F <: Core.Builtin) && return nothing
27+
# _methods = Base._methods_by_ftype(T, -1, world)
28+
# length(_methods) == 0 && return nothing
29+
# type_signature, sps, method = last(_methods)
30+
# sps = svec(map(untvar, sps)...)
31+
# @static if VERSION >= v"1.2-"
32+
# mi = Core.Compiler.specialize_method(method, types, sps)
33+
# ci = Base.isgenerated(mi) ? Core.Compiler.get_staged(mi) : Base.uncompressed_ast(method)
34+
# else
35+
# mi = Core.Compiler.code_for_method(method, types, sps, world, false)
36+
# ci = Base.isgenerated(mi) ? Core.Compiler.get_staged(mi) : Base.uncompressed_ast(mi)
37+
# end
38+
# Base.Meta.partially_inline!(ci.code, [], method.sig, Any[sps...], 0, 0, :propagate)
39+
# Meta(method, mi, ci, method.nargs, sps)
40+
# end
41+
42+
1743
function overdubbable(ex::Expr)
1844
ex.head != :call && return(false)
1945
length(ex.args) < 2 && return(false)
@@ -22,13 +48,8 @@ function overdubbable(ex::Expr)
2248
return(false)
2349
end
2450

25-
# overdubbable(ex::Expr) = false
2651
overdubbable(ex) = false
27-
# overdubbable(ctx::Context, ex::Expr) = ex.head == :call
28-
# overdubbable(ctx::Context, ex) = false
29-
# timable(ctx::Context{Nothing}, ex) = true
30-
timable(ex::Expr) = overdubbable(ex)
31-
timable(ex) = false
52+
timable(ex) = overdubbable(ex)
3253

3354
#
3455
remap(ex::Expr, maps) = Expr(ex.head, remap(ex.args, maps)...)
@@ -43,41 +64,32 @@ remap(a::GlobalRef, maps) = a
4364
remap(a::QuoteNode, maps) = a
4465
remap(ex, maps) = ex
4566

46-
# remove static parameters (see https://discourse.julialang.org/t/does-overdubbing-in-generated-function-inserts-inlined-code/71868)
47-
remove_static(ex) = ex
48-
function remove_static(ex::Expr)
49-
ex.head != :call && return(ex)
50-
length(ex.args) != 2 && return(ex)
51-
!(ex.args[1] isa Expr) && return(ex)
52-
(ex.args[1].head == :static_parameter) && return(ex.args[2])
53-
return(ex)
54-
end
55-
5667
exportname(ex::GlobalRef) = QuoteNode(ex.name)
5768
exportname(ex::Symbol) = QuoteNode(ex)
5869
exportname(ex::Expr) = exportname(ex.args[1])
5970
exportname(i::Int) = QuoteNode(Symbol("Int(",i,")"))
6071

61-
using Base: invokelatest
6272
dummy() = return
73+
function empty_codeinfo()
74+
new_ci = code_lowered(dummy, Tuple{})[1]
75+
empty!(new_ci.code)
76+
empty!(new_ci.slotnames)
77+
empty!(new_ci.linetable)
78+
empty!(new_ci.codelocs)
79+
new_ci
80+
end
81+
6382

6483
overdub(f::Core.IntrinsicFunction, args...) = f(args...)
6584

6685
@generated function overdub(f::F, args...) where {F}
6786
@show (F, args...)
6887
ci = retrieve_code_info((F, args...))
6988
if ci === nothing
70-
@show Expr(:call, :f, [:(args[$(i)]) for i in 1:length(args)]...)
7189
return(Expr(:call, :f, [:(args[$(i)]) for i in 1:length(args)]...))
7290
end
73-
# this is to initialize a new CodeInfo and fill it with values from the
74-
# overdubbed function
75-
new_ci = code_lowered(dummy, Tuple{})[1]
76-
empty!(new_ci.code)
77-
empty!(new_ci.slotnames)
78-
empty!(new_ci.linetable)
79-
empty!(new_ci.codelocs)
8091

92+
new_ci = empty_codeinfo()
8193
new_ci.slotnames = vcat([Symbol("#self#"), :f, :args], ci.slotnames[2:end])
8294
new_ci.slotflags = vcat([0x00, 0x00, 0x00], ci.slotflags[2:end])
8395
foreach(s -> push!(new_ci.linetable, s), ci.linetable)
@@ -97,44 +109,44 @@ overdub(f::Core.IntrinsicFunction, args...) = f(args...)
97109
#if somewhere the original parameters of the functions will be used
98110
#they needs to be remapped to an SSAValue from here, since the overdubbed
99111
# function has signatures overdub(f, args...) instead of f(x,y,z...)
100-
ssa_no = 0
112+
newci_no = 0
101113
for i in 1:length(args)
102-
ssa_no +=1
114+
newci_no +=1
103115
push!(new_ci.code, Expr(:call, Base.getindex, Core.SlotNumber(3), i))
104-
maps.slots[i+1] = Core.SSAValue(ssa_no)
116+
maps.slots[i+1] = Core.SSAValue(newci_no)
105117
push!(new_ci.codelocs, ci.codelocs[1])
106118
end
107119

108-
for (ci_line, ex) in enumerate(ci.code)
120+
for (ci_no, ex) in enumerate(ci.code)
109121
if timable(ex)
110122
fname = exportname(ex)
111123
push!(new_ci.code, Expr(:call, GlobalRef(Main, :record_start), fname))
112-
push!(new_ci.codelocs, ci.codelocs[ci_line])
113-
ssa_no += 1
114-
maps.goto[ci_line] = ssa_no
115-
ex = overdubbable(ex) ? Expr(:call, GlobalRef(Main, :overdub), ex.args...) : ex
124+
push!(new_ci.codelocs, ci.codelocs[ci_no])
125+
newci_no += 1
126+
maps.goto[ci_no] = newci_no
127+
# @show ex
128+
# @show Expr(:call, GlobalRef(Main, :overdub), ex.args...)
129+
# ex = overdubbable(ex) ? Expr(:call, GlobalRef(Main, :overdub), ex.args...) : ex
130+
# ex = overdubbable(ex) ? Expr(:call, GlobalRef(Main, :overdub), ex.args...) : ex
116131
push!(new_ci.code, ex)
117-
push!(new_ci.codelocs, ci.codelocs[ci_line])
118-
ssa_no += 1
119-
maps.ssa[ci_line] = ssa_no
132+
push!(new_ci.codelocs, ci.codelocs[ci_no])
133+
newci_no += 1
134+
maps.ssa[ci_no] = newci_no
120135
push!(new_ci.code, Expr(:call, GlobalRef(Main, :record_end), fname))
121-
push!(new_ci.codelocs, ci.codelocs[ci_line])
122-
ssa_no += 1
136+
push!(new_ci.codelocs, ci.codelocs[ci_no])
137+
newci_no += 1
123138
else
124139
push!(new_ci.code, ex)
125-
push!(new_ci.codelocs, ci.codelocs[ci_line])
126-
ssa_no += 1
127-
maps.ssa[ci_line] = ssa_no
140+
push!(new_ci.codelocs, ci.codelocs[ci_no])
141+
newci_no += 1
142+
maps.ssa[ci_no] = newci_no
128143
end
129144
end
130145

131146
for i in length(args)+1:length(new_ci.code)
132-
ex = remove_static(new_ci.code[i])
133-
new_ci.code[i] = remap(ex, maps)
147+
new_ci.code[i] = remap(new_ci.code[i], maps)
134148
end
135149
new_ci
136-
137-
# Core.Compiler.replace_code_newstyle!(ci, ir, length(ir.argtypes)-1)
138150
new_ci.inferred = false
139151
new_ci.ssavaluetypes = length(new_ci.code)
140152
# new_ci
@@ -149,38 +161,7 @@ end
149161

150162
reset!(to)
151163
overdub(foo, 1.0, 1.0)
164+
to
165+
reset!(to)
152166
new_ci = overdub(sin, 1.0)
153167
to
154-
155-
# Seems like now, I am crashing here
156-
# typeof(Base._promote), Irrational{:π}, Int64)
157-
158-
159-
function overdub2(::typeof(foo), args...)
160-
x = args[1]
161-
y = args[2]
162-
push!(to, :start, :fun)
163-
z = x * y
164-
push!(to, :stop, :fun)
165-
push!(to, :start, :fun)
166-
r = z + sin(y)
167-
push!(to, :stop, :fun)
168-
r
169-
end
170-
ci = retrieve_code_info((typeof(overdub2), typeof(foo), Float64, Float64))
171-
172-
173-
function test(x::T) where T<:Union{Float64, Float32}
174-
x < T(pi)
175-
end
176-
ci = retrieve_code_info((typeof(test), Float64))
177-
178-
179-
function overdub_test(::typeof(test), args...)
180-
x = args[1]
181-
T = eltype(x)
182-
x < T(pi)
183-
end
184-
185-
ci = retrieve_code_info((typeof(overdub_test), typeof(test), Float64))
186-

docs/src/lecture_09/irtools.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Generated functions
2+
using Dictionaries, IRTools
3+
include("calls.jl")
4+
using IRTools: var, xcall, insert!, insertafter!, func
5+
6+
function timable(ex::Expr)
7+
ex.head != :call && return(false)
8+
length(ex.args) < 2 && return(false)
9+
ex.args[1] isa Core.GlobalRef && return(true)
10+
ex.args[1] isa Symbol && return(true)
11+
return(false)
12+
end
13+
timable(ex) = false
14+
15+
exportname(ex::GlobalRef) = QuoteNode(ex.name)
16+
exportname(ex::Symbol) = QuoteNode(ex)
17+
exportname(ex::Expr) = exportname(ex.args[1])
18+
exportname(i::Int) = QuoteNode(Symbol("Int(",i,")"))
19+
20+
function foo(x, y)
21+
z = x * y
22+
z + sin(y)
23+
end
24+
25+
# ir = @code_ir foo(1.0, 1.0)
26+
ir = @code_ir sin(1.0)
27+
28+
# writing our profiler would be relatively
29+
# we will iterate over the ir code and inserts appropriate logs
30+
for b in IRTools.blocks(ir)
31+
for (v, ex) in b
32+
if timable(ex.expr)
33+
fname = exportname(ex.expr)
34+
insert!(b, v, xcall(Main, :record_start, fname))
35+
insertafter!(b, v, xcall(Main, :record_end, fname))
36+
end
37+
end
38+
end
39+
40+
@generated function profile_fun(f, args...)
41+
m = IRTools.Inner.meta(Tuple{f,args...})
42+
ir = IRTools.Inner.IR(m)
43+
for b in IRTools.blocks(ir)
44+
for (v, ex) in b
45+
if timable(ex.expr)
46+
fname = exportname(ex.expr)
47+
insert!(b, v, xcall(Main, :record_start, fname))
48+
insertafter!(b, v, xcall(Main, :record_end, fname))
49+
end
50+
end
51+
end
52+
53+
# we need to deal with the problem that ir has different set so f arguments than profile_fun(f, args...)
54+
# THis is what a dynamo does for us
55+
return(IRTools.Inner.build_codeinfo(ir))
56+
end
57+
58+
f = func(ir)
59+
# f(nothing, 1.0, 1.0)
60+
# f = func(ir)
61+
f(nothing, 1.0)
62+
63+
function func(m::Module, ir::IR)
64+
@eval @generated function $(gensym())($([Symbol(:arg, i) for i = 1:length(arguments(ir))]...))
65+
return build_codeinfo($ir)
66+
end
67+
end

0 commit comments

Comments
 (0)