Skip to content
This repository was archived by the owner on Jun 24, 2022. It is now read-only.

Commit 69e736e

Browse files
committed
remove Amb, add example test
1 parent 1cfad3a commit 69e736e

File tree

7 files changed

+161
-29
lines changed

7 files changed

+161
-29
lines changed

src/controlflow.jl

Lines changed: 121 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,110 @@
1-
using Amb
1+
#### Path
2+
3+
# First just do it for the case where there we assume
4+
# tainted gotoifnots do not go in a loop!
5+
# TODO: write a thing to detect this! (overdub predicates only in tainted ifs)
6+
# implement snapshotting function state as an optimization for branch exploration
7+
mutable struct Path
8+
path::BitVector
9+
cursor::Int
10+
end
11+
12+
Path() = Path([], 1)
13+
14+
function increment!(bitvec)
15+
for i=1:length(bitvec)
16+
if bitvec[i] === true
17+
bitvec[i] = false
18+
else
19+
bitvec[i] = true
20+
break
21+
end
22+
end
23+
end
24+
25+
function reset!(p::Path)
26+
p.cursor=1
27+
increment!(p.path)
28+
nothing
29+
end
30+
31+
function alldone(p::Path) # must be called at the end of the function!
32+
all(identity, p.path)
33+
end
34+
35+
function current_predicate!(p::Path)
36+
if p.cursor > length(p.path)
37+
push!(p.path, false)
38+
else
39+
p.path[p.cursor]
40+
end
41+
val = p.path[p.cursor]
42+
p.cursor+=1
43+
val
44+
end
45+
46+
alldone(c) = alldone(c.metadata[2])
47+
reset!(c) = reset!(c.metadata[2])
48+
current_predicate!(c) = current_predicate!(c.metadata[2])
49+
50+
#=
51+
julia> p=Path()
52+
Path(Bool[], 1)
53+
54+
julia> alldone(p) # must be called at the end of a full run
55+
true
56+
57+
julia> current_predicate!(p)
58+
false
59+
60+
julia> alldone(p) # must be called at the end of a full run
61+
false
62+
63+
julia> current_predicate!(p)
64+
false
65+
66+
julia> p
67+
Path(Bool[false, false], 3)
68+
69+
julia> alldone(p)
70+
false
71+
72+
julia> reset!(p)
73+
74+
julia> p
75+
Path(Bool[true, false], 1)
76+
77+
julia> current_predicate!(p)
78+
true
79+
80+
julia> current_predicate!(p)
81+
false
82+
83+
julia> alldone(p)
84+
false
85+
86+
julia> reset!(p)
87+
88+
julia> p
89+
Path(Bool[false, true], 1)
90+
91+
julia> current_predicate!(p)
92+
false
93+
94+
julia> current_predicate!(p)
95+
true
96+
97+
julia> reset!(p)
98+
99+
julia> current_predicate!(p)
100+
true
101+
102+
julia> current_predicate!(p)
103+
true
104+
105+
julia> alldone(p)
106+
true
107+
=#
2108

3109
"""
4110
`abstract_run(g, ctx, overdubbed_fn, args...)`
@@ -32,9 +138,17 @@ end
32138
# do something to merge metadata from all the runs
33139
```
34140
"""
35-
function abstract_run(acc, ctx::Cassette.Context, overdub_fn, args...)
36-
pass_ctx = Cassette.similarcontext(ctx, pass=AbsintPass)
37-
acc(Cassette.overdub(pass_ctx, overdub_fn, args...))
141+
function abstract_run(acc, ctx::Cassette.Context, overdub_fn, args...; verbose=true)
142+
path = Path()
143+
pass_ctx = Cassette.similarcontext(ctx, metadata=(ctx.metadata, path), pass=AbsintPass)
144+
145+
while true
146+
acc(Cassette.recurse(pass_ctx, ()->overdub_fn(args...)))
147+
148+
verbose && println("Explored path: ", path)
149+
alldone(path) && break
150+
reset!(path)
151+
end
38152
end
39153

40154
"""
@@ -47,8 +161,6 @@ function istainted(ctx, cond)
47161
" See docs for `istainted`.")
48162
end
49163

50-
_choice() = (@amb true false)
51-
52164
# Must return 7 exprs
53165
function rewrite_branch(ctx, stmt, extraslot, i)
54166
# turn
@@ -70,9 +182,9 @@ function rewrite_branch(ctx, stmt, extraslot, i)
70182
# not tainted? jump to the penultimate statement
71183
push!(exprs, Expr(:gotoifnot, istainted_ssa, i+5))
72184

73-
# tainted? then use this_here_predicate!(SSAValue(1))
185+
# tainted? then use current_predicate!(SSAValue(1))
74186
current_pred = i+2
75-
push!(exprs, :($_choice()))
187+
push!(exprs, :($(Expr(:nooverdub, current_predicate!))($(Expr(:contextslot)))))
76188

77189
# Store the interpreter-provided predicate in the slot
78190
push!(exprs, Expr(:(=), extraslot, SSAValue(i+2)))
@@ -93,7 +205,7 @@ function rewrite_ir(ctx, ref)
93205
# turn
94206
# <val> ? t : f
95207
# into
96-
# istainted(<val>) ? this_here_predicate!(p) : <val> ? t : f
208+
# istainted(<val>) ? current_predicate!(p) : <val> ? t : f
97209

98210
ir = ref.code_info
99211
ir = copy(ir)

src/jacobian.jl

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,16 @@ function Base.push!(S::Sparsity, i::Int, j::Int)
2525
push!(S.J, j)
2626
end
2727

28-
struct ProvinanceSet{T}
29-
set::T # Set, Array, Int, Tuple, anything!
28+
struct ProvinanceSet
29+
set::Set{Int}
30+
ProvinanceSet(s::Set) = new(s)
31+
ProvinanceSet(s) = new(Set(s))
3032
end
3133

3234
# note: this is not strictly set union, just some efficient way of concating
3335
Base.union(p::ProvinanceSet, ::Cassette.NoMetaData) = p
3436
Base.union(::Cassette.NoMetaData, p::ProvinanceSet) = p
3537

36-
Base.union(p::ProvinanceSet{<:Tuple},
37-
q::ProvinanceSet{<:Integer}) = ProvinanceSet((p.set..., q.set,))
38-
Base.union(p::ProvinanceSet{<:Integer},
39-
q::ProvinanceSet{<:Tuple}) = ProvinanceSet((p.set, q.set...,))
40-
Base.union(p::ProvinanceSet{<:Integer},
41-
q::ProvinanceSet{<:Integer}) = ProvinanceSet((p.set, q.set,))
42-
Base.union(p::ProvinanceSet{<:Tuple},
43-
q::ProvinanceSet{<:Tuple}) = ProvinanceSet((p.set..., q.set...,))
4438
Base.union(p::ProvinanceSet,
4539
q::ProvinanceSet) = ProvinanceSet(union(p.set, q.set))
4640
Base.union(p::ProvinanceSet,
@@ -115,7 +109,7 @@ function Cassette.overdub(ctx::JacobianSparsityContext,
115109
Y::Tagged,
116110
val::Tagged,
117111
idx::Int...)
118-
S = ctx.metadata
112+
S = ctx.metadata[1]
119113
if metatype(Y, ctx) <: JacOutput
120114
set = metadata(val, ctx)
121115
if set isa ProvinanceSet
@@ -159,7 +153,7 @@ function Cassette.overdub(ctx::JacobianSparsityContext,
159153
Y::Tagged,
160154
ystart,
161155
len)
162-
S = ctx.metadata
156+
S = ctx.metadata[1]
163157
if metatype(Y, ctx) <: JacInput && metatype(X, ctx) <: JacOutput
164158
# Write directly to the output sparsity
165159
val = Cassette.fallback(ctx, f, X, xstart, Y, ystart, len)

src/propagate_tags.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11

2+
@inline anytagged() = false
3+
@inline anytagged(x::Tagged, args...) = true
4+
@inline anytagged(x, args...) = anytagged(args...)
5+
6+
27
macro proptagcontext(name)
38
quote
49
Cassette.@context($name)
510

611
function Cassette.overdub(ctx::$name, f, args...)
712
# this check can be inferred (in theory)
8-
if any(x->x isa Tagged, args)
13+
if anytagged(args...)
914
# This is a slower check
1015
if !any(x->!(metatype(x, ctx) <: Cassette.NoMetaData), args)
1116
return Cassette.recurse(ctx, f, args...)

test/common.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@ function jactester(f, Y, X, args...)
1010
ctx, val = jacobian_sparsity(f, Y, X, args...; raw=true)
1111
end
1212

13-
jactestmeta(args...) = jactester(args...)[1].metadata
13+
jactestmeta(args...) = jactester(args...)[1].metadata[1]
1414
jactestval(args...) = jactester(args...) |> ((ctx,val),) -> untag(val, ctx)
1515
jactesttag(args...) = jactester(args...) |> ((ctx,val),) -> metadata(val, ctx)
1616

1717
function hesstester(f, X, args...)
1818
ctx, val = hessian_sparsity(f, X, args...; raw=true)
1919
end
2020

21-
hesstestmeta(args...) = hesstester(args...)[1].metadata
21+
hesstestmeta(args...) = hesstester(args...)[1].metadata[1]
2222
hesstestval(args...) = hesstester(args...) |> ((ctx,val),) -> untag(val, ctx)
2323
hesstesttag(args...) = hesstester(args...) |> ((ctx,val),) -> metadata(val, ctx)
2424

test/examples.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
fcalls = 0
2+
function f(dx,x)
3+
global fcalls += 1
4+
for i in 2:length(x)-1
5+
dx[i] = x[i-1] - 2x[i] + x[i+1]
6+
end
7+
dx[1] = -2x[1] + x[2]
8+
dx[end] = x[end-1] - 2x[end]
9+
nothing
10+
end
11+
12+
using SparsityDetection, SparseArrays
13+
input = rand(10)
14+
output = similar(input)
15+
sparsity_pattern = sparsity!(f,output,input)
16+
jac = Float64.(sparse(sparsity_pattern))

test/jacobian.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,15 @@ let
1818
g(y,x) = y[:] .= x .+ 1
1919
#g(y,x) = y .= x .+ 1 -- memove
2020

21-
@test sparse(jactestmeta(g, [1], [2])[1]) == sparse([1], [1], true)
21+
println("Broadcast timings")
22+
println(" y .= x")
2223
# test path of unsafe_copy from Input to Output
23-
@test sparsity!((y,x) -> y .= x, [1,2,3], [1,2,3]) == sparse([1,2,3], [1,2,3], true)
24+
@test @time jacobian_sparsity((y,x) -> y .= x, [1,2,3], [1,2,3]) == sparse([1,2,3], [1,2,3], true)
25+
println(" y[:] .= x .+ 1")
26+
@test @time sparse(jactestmeta(g, [1], [2])) == sparse([1], [1], true)
27+
println(" y[1:2] .= x[2:3]")
2428
# test path of unsafe_copy from Input to an intermediary
25-
@test sparsity!((y,x) -> y[1:2] .= x[2:3], [1,2,3], [1,2,3]) == sparse([1,2],[2,3],true, 3,3)
29+
@test @time jacobian_sparsity((y,x) -> y[1:2] .= x[2:3], [1,2,3], [1,2,3]) == sparse([1,2],[2,3],true, 3,3)
2630

2731
using LinearAlgebra, SparsityDetection
2832

@@ -31,7 +35,7 @@ let
3135
mul!(out, A, x)
3236
end
3337
x = [1:4;]; out = similar(x);
34-
@test sparsity!(testsparse!, out, x) == sparse([1,2,1,2,3,2,3,4,3,4],
38+
@test jacobian_sparsity(testsparse!, out, x) == sparse([1,2,1,2,3,2,3,4,3,4],
3539
[1,1,2,2,2,3,3,3,4,4], true)
3640
end
3741

@@ -44,5 +48,5 @@ end
4448

4549
x = [1.0:10;]
4650
out = similar(x)
47-
@test all(sparsity!(f, out, x) .== 1)
51+
@test all(jacobian_sparsity(f, out, x) .== 1)
4852
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ include("common.jl")
55
@testset "Paraboloid example" begin include("paraboloid.jl") end
66

77
@testset "Exploration" begin include("ifsandbuts.jl") end
8+
@testset "Examples" begin include("examples.jl") end

0 commit comments

Comments
 (0)