11# Generated functions
22using Dictionaries
3+ include (" calls.jl" )
4+
35function retrieve_code_info (sigtypes, world = Base. get_world_counter ())
46 S = Tuple{map (s -> Core. Compiler. has_free_typevars (s) ? typeof (s. parameters[1 ]) : s, sigtypes)... }
57 _methods = Base. _methods_by_ftype (S, - 1 , world)
68 if isempty (_methods)
7- @info (" method $(sigtypes) does not exist, may-be run it once " )
9+ @info (" method $(sigtypes) does not exist" )
810 return (nothing )
911 end
10- type_signature, raw_static_params, method = _methods[1 ] # method is the same as we would get by invoking methods(+, (Int, Int)).ms[1]
11-
12- # this provides us with the CodeInfo
12+ type_signature, raw_static_params, method = _methods[1 ]
1313 method_instance = Core. Compiler. specialize_method (method, type_signature, raw_static_params, false )
1414 code_info = Core. Compiler. retrieve_code_info (method_instance)
1515end
1616
17- struct Calls
18- stamps:: Vector{Float64} # contains the time stamps
19- event:: Vector{Symbol} # name of the function that is being recorded
20- startstop:: Vector{Symbol} # if the time stamp corresponds to start or to stop
21- i:: Ref{Int}
22- end
23-
24- function Calls (n:: Int )
25- Calls (Vector {Float64} (undef, n+ 1 ), Vector {Symbol} (undef, n+ 1 ), Vector {Symbol} (undef, n+ 1 ), Ref {Int} (0 ))
26- end
27-
28- function Base. show (io:: IO , calls:: Calls )
29- offset = 0
30- for i in 1 : calls. i[]
31- offset -= calls. startstop[i] == :stop
32- foreach (_ -> print (io, " " ), 1 : max (offset, 0 ))
33- rel_time = calls. stamps[i] - calls. stamps[1 ]
34- println (io, calls. event[i], " : " , rel_time)
35- offset += calls. startstop[i] == :start
36- end
37- end
38-
39- global const to = Calls (100 )
40-
41- function record_start (ev:: Symbol )
42- calls = Main. to
43- n = calls. i[] = calls. i[] + 1
44- n > length (calls. stamps) && return
45- calls. event[n] = ev
46- calls. startstop[n] = :start
47- calls. stamps[n] = time_ns ()
48- end
49-
50- function record_end (ev:: Symbol )
51- t = time_ns ()
52- calls = Main. to
53- n = calls. i[] = calls. i[] + 1
54- n > length (calls. stamps) && return
55- calls. event[n] = ev
56- calls. startstop[n] = :stop
57- calls. stamps[n] = t
58- end
59-
60- reset! (calls:: Calls ) = calls. i[] = 0
61-
62-
6317function overdubbable (ex:: Expr )
6418 ex. head != :call && return (false )
6519 length (ex. args) < 2 && return (false )
66- (ex. args[1 ] isa Core. IntrinsicFunction) && return (false )
67- return (true )
20+ ex. args[1 ] isa Core. GlobalRef && return (true )
21+ ex. args[1 ] isa Symbol && return (true )
22+ return (false )
6823end
24+
6925# overdubbable(ex::Expr) = false
7026overdubbable (ex) = false
7127# overdubbable(ctx::Context, ex::Expr) = ex.head == :call
7228# overdubbable(ctx::Context, ex) = false
7329# timable(ctx::Context{Nothing}, ex) = true
74- timable (ex:: Expr ) = ex . head == :call
30+ timable (ex:: Expr ) = overdubbable (ex)
7531timable (ex) = false
7632
77- rename_args (ex, slotvar, ssamap) = ex
78- rename_args (c:: Core.GotoIfNot , slotvar, ssamap) = Core. GotoIfNot (rename_args (c. cond, slotvar, ssamap), ssamap[c. dest])
79- rename_args (ex:: Expr , slotvar, ssamap) = Expr (ex. head, rename_args (ex. args, slotvar, ssamap)... )
80- rename_args (args:: AbstractArray , slotvar, ssamap) = map (a -> rename_args (a, slotvar, ssamap), args)
81- rename_args (r:: Core.ReturnNode , slotvar, ssamap) = Core. ReturnNode (rename_args (r. val, slotvar, ssamap))
82- rename_args (a:: Core.SlotNumber , slotvar, ssamap) = slotvar[a. id]
83- rename_args (a:: Core.SSAValue , slotvar, ssamap) = Core. SSAValue (ssamap[a. id])
33+ #
34+ remap (ex:: Expr , maps) = Expr (ex. head, remap (ex. args, maps)... )
35+ remap (args:: AbstractArray , maps) = map (a -> remap (a, maps), args)
36+ remap (c:: Core.GotoNode , maps) = Core. GotoNode (maps. goto[c. label])
37+ remap (c:: Core.GotoIfNot , maps) = Core. GotoIfNot (remap (c. cond, maps), maps. goto[c. dest])
38+ remap (r:: Core.ReturnNode , maps) = Core. ReturnNode (remap (r. val, maps))
39+ remap (a:: Core.SlotNumber , maps) = maps. slots[a. id]
40+ remap (a:: Core.SSAValue , maps) = Core. SSAValue (maps. ssa[a. id])
41+ remap (a:: Core.NewvarNode , maps) = Core. NewvarNode (maps. slots[a. slot. id])
42+ remap (a:: GlobalRef , maps) = a
43+ remap (a:: QuoteNode , maps) = a
44+ remap (ex, maps) = ex
45+
46+ # remove static parameters (see https://discourse.julialang.org/t/does-overdubbing-in-generated-function-inserts-inlined-code/71868)
47+ remove_static (ex) = ex
48+ function remove_static (ex:: Expr )
49+ ex. head != :call && return (ex)
50+ length (ex. args) != 2 && return (ex)
51+ ! (ex. args[1 ] isa Expr) && return (ex)
52+ (ex. args[1 ]. head == :static_parameter ) && return (ex. args[2 ])
53+ return (ex)
54+ end
8455
8556exportname (ex:: GlobalRef ) = QuoteNode (ex. name)
57+ exportname (ex:: Symbol ) = QuoteNode (ex)
8658exportname (ex:: Expr ) = exportname (ex. args[1 ])
8759exportname (i:: Int ) = QuoteNode (Symbol (" Int(" ,i," )" ))
8860
@@ -103,69 +75,87 @@ overdub(f::Core.IntrinsicFunction, args...) = f(args...)
10375 new_ci = code_lowered (dummy, Tuple{})[1 ]
10476 empty! (new_ci. code)
10577 empty! (new_ci. slotnames)
106- foreach (s -> push! (new_ci. slotnames, s), ci. slotnames)
107- new_ci. slotnames = vcat ([Symbol (" #self#" ), :f , :args ], ci. slotnames[length (args)+ 2 : end ])
108- new_ci. slotflags = vcat ([0x00 , 0x00 , 0x00 ], ci. slotflags[length (args)+ 2 : end ])
10978 empty! (new_ci. linetable)
110- foreach (s -> push! (new_ci. linetable, s), ci. linetable)
11179 empty! (new_ci. codelocs)
11280
113- ssamap = Dict {Int, Int} ()
114- slotvar = Dict {Int, Any} ()
81+ new_ci. slotnames = vcat ([Symbol (" #self#" ), :f , :args ], ci. slotnames[2 : end ])
82+ new_ci. slotflags = vcat ([0x00 , 0x00 , 0x00 ], ci. slotflags[2 : end ])
83+ foreach (s -> push! (new_ci. linetable, s), ci. linetable)
84+
85+ maps = (
86+ ssa = Dict {Int, Int} (),
87+ slots = Dict {Int, Any} (),
88+ goto = Dict {Int,Int} (),
89+ )
90+
91+ # we need to map indexes of slot-variables from ci to their new values.
92+ # except the first one, we just remap them
93+ maps. slots[1 ] = Core. SlotNumber (1 )
94+ foreach (i -> maps. slots[i] = Core. SlotNumber (i + 2 ), 2 : length (ci. slotnames)) # they are shifted by 2 accomondating inserted `f` and `args`
95+ @assert all (ci. slotnames[i] == new_ci. slotnames[maps. slots[i]. id] for i in 1 : length (ci. slotnames)) # test that the remapping is right
96+
97+ # if somewhere the original parameters of the functions will be used
98+ # they needs to be remapped to an SSAValue from here, since the overdubbed
99+ # function has signatures overdub(f, args...) instead of f(x,y,z...)
100+ ssa_no = 0
115101 for i in 1 : length (args)
102+ ssa_no += 1
116103 push! (new_ci. code, Expr (:call , Base. getindex, Core. SlotNumber (3 ), i))
117- slotvar [i+ 1 ] = Core. SSAValue (i )
104+ maps . slots [i+ 1 ] = Core. SSAValue (ssa_no )
118105 push! (new_ci. codelocs, ci. codelocs[1 ])
119106 end
120- slotvar[1 ] = Core. SlotNumber (1 )
121- foreach (i -> slotvar[i[2 ]] = Core. SlotNumber (i[1 ]+ 3 ), enumerate (length (args)+ 2 : length (ci. slotnames)))
122107
123- j = length (args)
124- for (i, ex) in enumerate (ci. code)
108+ for (ci_line, ex) in enumerate (ci. code)
125109 if timable (ex)
126110 fname = exportname (ex)
127111 push! (new_ci. code, Expr (:call , GlobalRef (Main, :record_start ), fname))
128- push! (new_ci. codelocs, ci. codelocs[i ])
129- j += 1
130- # ex = overdubbable(ex) ? Expr(:call, :overdub, ex.args...) : ex
131- # ex = overdubbable(ex) ? Expr(:call, GlobalRef(Main, :overdub), ex.args...) : ex
112+ push! (new_ci. codelocs, ci. codelocs[ci_line ])
113+ ssa_no += 1
114+ maps . goto[ci_line] = ssa_no
115+ ex = overdubbable (ex) ? Expr (:call , GlobalRef (Main, :overdub ), ex. args... ) : ex
132116 push! (new_ci. code, ex)
133- push! (new_ci. codelocs, ci. codelocs[i ])
134- j += 1
135- ssamap[i ] = j
117+ push! (new_ci. codelocs, ci. codelocs[ci_line ])
118+ ssa_no += 1
119+ maps . ssa[ci_line ] = ssa_no
136120 push! (new_ci. code, Expr (:call , GlobalRef (Main, :record_end ), fname))
137- push! (new_ci. codelocs, ci. codelocs[i ])
138- j += 1
121+ push! (new_ci. codelocs, ci. codelocs[ci_line ])
122+ ssa_no += 1
139123 else
140124 push! (new_ci. code, ex)
141- push! (new_ci. codelocs, ci. codelocs[i ])
142- j += 1
143- ssamap[i ] = j
125+ push! (new_ci. codelocs, ci. codelocs[ci_line ])
126+ ssa_no += 1
127+ maps . ssa[ci_line ] = ssa_no
144128 end
145129 end
130+
146131 for i in length (args)+ 1 : length (new_ci. code)
147- new_ci. code[i] = rename_args (new_ci. code[i], slotvar, ssamap)
132+ ex = remove_static (new_ci. code[i])
133+ new_ci. code[i] = remap (ex, maps)
148134 end
149135 new_ci
150136
151137 # Core.Compiler.replace_code_newstyle!(ci, ir, length(ir.argtypes)-1)
152138 new_ci. inferred = false
153139 new_ci. ssavaluetypes = length (new_ci. code)
154140 # new_ci
155- # new_ci
156141 return (new_ci)
157142end
158143
159144function foo (x, y)
160- z = x * y
145+ z = x * y
161146 z + sin (y)
162147end
163148
149+
164150reset! (to)
165151overdub (foo, 1.0 , 1.0 )
166- overdub (sin, 1.0 )
152+ new_ci = overdub (sin, 1.0 )
167153to
168154
155+ # Seems like now, I am crashing here
156+ # typeof(Base._promote), Irrational{:π}, Int64)
157+
158+
169159function overdub2 (:: typeof (foo), args... )
170160 x = args[1 ]
171161 y = args[2 ]
@@ -177,6 +167,20 @@ function overdub2(::typeof(foo), args...)
177167 push! (to, :stop , :fun )
178168 r
179169end
180-
181170ci = retrieve_code_info ((typeof (overdub2), typeof (foo), Float64, Float64))
182171
172+
173+ function test (x:: T ) where T<: Union{Float64, Float32}
174+ x < T (pi )
175+ end
176+ ci = retrieve_code_info ((typeof (test), Float64))
177+
178+
179+ function overdub_test (:: typeof (test), args... )
180+ x = args[1 ]
181+ T = eltype (x)
182+ x < T (pi )
183+ end
184+
185+ ci = retrieve_code_info ((typeof (overdub_test), typeof (test), Float64))
186+
0 commit comments