Skip to content
This repository was archived by the owner on Jun 24, 2022. It is now read-only.

Commit a3843f7

Browse files
authored
Merge pull request #28 from SciML/s/fix-rebase
Cherry-pick into Refactor plus Fix a promote bug
2 parents 69c416a + e687ae4 commit a3843f7

23 files changed

+948
-989
lines changed

src/SparsityDetection.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,17 @@ using Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged, canrecurse
66
using Cassette: tagged_new_tuple, ContextTagged, BindingMeta, DisableHooks, nametype
77
using Core: SSAValue
88

9-
export Sparsity, hsparsity, sparsity!
9+
export Sparsity, jacobian_sparsity, hessian_sparsity, hsparsity, sparsity!
1010

11-
include("program_sparsity.jl")
12-
include("sparsity_tracker.jl")
13-
include("path.jl")
14-
include("take_all_branches.jl")
15-
include("terms.jl")
11+
include("util.jl")
12+
include("controlflow.jl")
13+
include("propagate_tags.jl")
1614
include("linearity.jl")
15+
include("jacobian.jl")
1716
include("hessian.jl")
1817
include("blas.jl")
19-
include("linearity_special.jl")
18+
19+
sparsity!(args...; kwargs...) = jacobian_sparsity(args...; kwargs...)
20+
hsparsity(args...; kwargs...) = hessian_sparsity(args...; kwargs...)
2021

2122
end

src/blas.jl

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,12 @@
1-
# generic implementations
2-
3-
_name(x::Symbol) = x
4-
_name(x::Expr) = (@assert x.head == :(::); x.args[1])
5-
macro reroute(f, g)
6-
fname = f.args[1]
7-
fargs = f.args[2:end]
8-
gname = g.args[1]
9-
gargs = g.args[2:end]
10-
quote
11-
@inline function Cassette.overdub(ctx::SparsityContext,
12-
f::typeof($(esc(fname))),
13-
$(fargs...))
14-
Cassette.recurse(
15-
ctx,
16-
invoke,
17-
$(esc(gname)),
18-
$(esc(:(Tuple{$(gargs...)}))),
19-
$(map(_name, fargs)...))
20-
end
1+
# Forward BLAS calls to generic implementation
2+
#
3+
using LinearAlgebra
4+
import LinearAlgebra.BLAS
215

22-
@inline function Cassette.overdub(ctx::HessianSparsityContext,
23-
f::typeof($(esc(fname))),
24-
$(fargs...))
25-
Cassette.recurse(
26-
ctx,
27-
invoke,
28-
$(esc(gname)),
29-
$(esc(:(Tuple{$(gargs...)}))),
30-
$(map(_name, fargs)...))
31-
end
32-
end
33-
end
6+
# generic implementations
347

358
@reroute LinearAlgebra.BLAS.dot(x,y) LinearAlgebra.dot(Any, Any)
36-
@reroute LinearAlgebra.BLAS.axpy!(x, y) LinearAlgebra.axpy!(Any,
9+
@reroute LinearAlgebra.BLAS.axpy!(a, x, y) LinearAlgebra.axpy!(Any,
3710
AbstractArray,
3811
AbstractArray)
3912

src/controlflow.jl

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
#### Path
2+
3+
# First just do it for the case where there we assume
4+
# tainted gotoifnots do not go in a loop!
5+
# TODO: write a thing to detect this! (overdub predicates only in tainted ifs)
6+
# implement snapshotting function state as an optimization for branch exploration
7+
mutable struct Path
8+
path::BitVector
9+
cursor::Int
10+
end
11+
12+
Path() = Path([], 1)
13+
14+
function increment!(bitvec)
15+
for i=1:length(bitvec)
16+
if bitvec[i] === true
17+
bitvec[i] = false
18+
else
19+
bitvec[i] = true
20+
break
21+
end
22+
end
23+
end
24+
25+
function reset!(p::Path)
26+
p.cursor=1
27+
increment!(p.path)
28+
nothing
29+
end
30+
31+
function alldone(p::Path) # must be called at the end of the function!
32+
all(identity, p.path)
33+
end
34+
35+
function current_predicate!(p::Path)
36+
if p.cursor > length(p.path)
37+
push!(p.path, false)
38+
else
39+
p.path[p.cursor]
40+
end
41+
val = p.path[p.cursor]
42+
p.cursor+=1
43+
val
44+
end
45+
46+
alldone(c) = alldone(c.metadata[2])
47+
reset!(c) = reset!(c.metadata[2])
48+
current_predicate!(c) = current_predicate!(c.metadata[2])
49+
50+
#=
51+
julia> p=Path()
52+
Path(Bool[], 1)
53+
54+
julia> alldone(p) # must be called at the end of a full run
55+
true
56+
57+
julia> current_predicate!(p)
58+
false
59+
60+
julia> alldone(p) # must be called at the end of a full run
61+
false
62+
63+
julia> current_predicate!(p)
64+
false
65+
66+
julia> p
67+
Path(Bool[false, false], 3)
68+
69+
julia> alldone(p)
70+
false
71+
72+
julia> reset!(p)
73+
74+
julia> p
75+
Path(Bool[true, false], 1)
76+
77+
julia> current_predicate!(p)
78+
true
79+
80+
julia> current_predicate!(p)
81+
false
82+
83+
julia> alldone(p)
84+
false
85+
86+
julia> reset!(p)
87+
88+
julia> p
89+
Path(Bool[false, true], 1)
90+
91+
julia> current_predicate!(p)
92+
false
93+
94+
julia> current_predicate!(p)
95+
true
96+
97+
julia> reset!(p)
98+
99+
julia> current_predicate!(p)
100+
true
101+
102+
julia> current_predicate!(p)
103+
true
104+
105+
julia> alldone(p)
106+
true
107+
=#
108+
109+
"""
110+
`abstract_run(g, ctx, overdubbed_fn, args...)`
111+
112+
First rewrites every if statement
113+
114+
```julia
115+
if <expr>
116+
...
117+
end
118+
119+
as
120+
121+
```julia
122+
cond = <expr>
123+
if istainted(ctx, cond) ? @amb(true, false) : cond
124+
...
125+
end
126+
```
127+
128+
Then runs `g(Cassette.overdub(ctx, overdubbed_fn, args...)`
129+
as many times as there are available paths. i.e. `2^n` ways
130+
where `n` is the number of tainted branch conditions.
131+
132+
# Examples:
133+
```
134+
meta = Any[]
135+
abstract_run(ctx, f. args...) do result
136+
push!(meta, metadata(result, ctx))
137+
end
138+
# do something to merge metadata from all the runs
139+
```
140+
"""
141+
function abstract_run(acc, ctx::Cassette.Context, overdub_fn, args...; verbose=true)
142+
path = Path()
143+
pass_ctx = Cassette.similarcontext(ctx, metadata=(ctx.metadata, path), pass=AbsintPass)
144+
145+
while true
146+
acc(Cassette.recurse(pass_ctx, ()->overdub_fn(args...)))
147+
148+
verbose && println("Explored path: ", path)
149+
alldone(path) && break
150+
reset!(path)
151+
end
152+
end
153+
154+
"""
155+
`istainted(ctx, cond)`
156+
157+
Does `cond` have any metadata?
158+
"""
159+
function istainted(ctx, cond)
160+
error("Method needed: istainted(::$(typeof(ctx)), ::Bool)." *
161+
" See docs for `istainted`.")
162+
end
163+
164+
# Must return 7 exprs
165+
function rewrite_branch(ctx, stmt, extraslot, i)
166+
# turn
167+
# gotoifnot %p #g
168+
# into
169+
# %t = istainted(%p)
170+
# gotoifnot %t #orig
171+
# %rec = @amb true false
172+
# gotoifnot %rec #orig+1 (the next statement after gotoifnot)
173+
174+
exprs = Any[]
175+
cond = stmt.args[1] # already an SSAValue
176+
177+
# insert a check to see if SSAValue(i) isa Tainted
178+
istainted_ssa = Core.SSAValue(i)
179+
push!(exprs, :($(Expr(:nooverdub, istainted))($(Expr(:contextslot)),
180+
$cond)))
181+
182+
# not tainted? jump to the penultimate statement
183+
push!(exprs, Expr(:gotoifnot, istainted_ssa, i+5))
184+
185+
# tainted? then use current_predicate!(SSAValue(1))
186+
current_pred = i+2
187+
push!(exprs, :($(Expr(:nooverdub, current_predicate!))($(Expr(:contextslot)))))
188+
189+
# Store the interpreter-provided predicate in the slot
190+
push!(exprs, Expr(:(=), extraslot, SSAValue(i+2)))
191+
192+
push!(exprs, Core.GotoNode(i+6))
193+
194+
push!(exprs, Expr(:(=), extraslot, cond))
195+
196+
# here we put in the original code
197+
stmt1 = copy(stmt)
198+
stmt.args[1] = extraslot
199+
push!(exprs, stmt)
200+
201+
exprs
202+
end
203+
204+
function rewrite_ir(ctx, ref)
205+
# turn
206+
# <val> ? t : f
207+
# into
208+
# istainted(<val>) ? current_predicate!(p) : <val> ? t : f
209+
210+
ir = ref.code_info
211+
ir = copy(ir)
212+
213+
extraslot = gensym("tmp")
214+
push!(ir.slotnames, extraslot)
215+
push!(ir.slotflags, 0x00)
216+
extraslot = Core.SlotNumber(length(ir.slotnames))
217+
218+
Cassette.insert_statements!(ir.code, ir.codelocs,
219+
(stmt, i) -> Base.Meta.isexpr(stmt, :gotoifnot) ? 7 : nothing,
220+
(stmt, i) -> rewrite_branch(ctx, stmt, extraslot, i))
221+
222+
ir.ssavaluetypes = length(ir.code)
223+
# Core.Compiler.validate_code(ir)
224+
#@show ref.method
225+
#@show ir
226+
return ir
227+
end
228+
229+
const AbsintPass = Cassette.@pass rewrite_ir

0 commit comments

Comments
 (0)