Skip to content

Commit c40390a

Browse files
committed
fixup! SumTree tests
1 parent 3cf057f commit c40390a

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

test/sum_tree.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
function gen_rand_sumtree(n, seed, type::DataType=Float32)
2+
rng = StableRNG(seed)
3+
a = SumTree(type, n)
4+
append!(a, rand(rng, type, n))
5+
return a
6+
end
7+
8+
function gen_sumtree_with_zeros(n, seed, type::DataType=Float32)
9+
a = gen_rand_sumtree(n, seed, type)
10+
b = rand(StableRNG(seed), Bool, n)
11+
return copy_multiply(a, b)
12+
end
13+
14+
function copy_multiply(stree, m)
15+
new_tree = deepcopy(stree)
16+
new_tree .*= m
17+
return new_tree
18+
end
19+
20+
function sumtree_nozero(t::SumTree, rng::AbstractRNG, iters=1)
21+
for _ in iters
22+
(_, p) = rand(rng, t)
23+
p == 0 && return false
24+
end
25+
return true
26+
end
27+
sumtree_nozero(n::Integer, seed::Integer, iters=1) = sumtree_nozero(gen_sumtree_with_zeros(n, seed), StableRNG(seed), iters)
28+
sumtree_nozero(n, seeds::AbstractVector, iters=1) = all(sumtree_nozero(n, seed, iters) for seed in seeds)
29+
30+
31+
function sumtree_distribution!(indices, priorities, t::SumTree, rng::AbstractRNG, iters=1000*t.length)
32+
for i = 1:iters
33+
indices[i], priorities[i] = rand(rng, t)
34+
end
35+
imap = countmap(indices)
36+
est_pdf = Dict(k=>v/length(indices) for (k, v) in imap)
37+
ex_pdf = Dict(k=>v/t.tree[1] for (k, v) in Dict(1:length(t) .=> t))
38+
abserrs = [est_pdf[k] - ex_pdf[k] for k in keys(est_pdf)]
39+
return abserrs
40+
end
41+
sumtree_distribution!(indices, priorities, n, seed, iters=1000*n) = sumtree_distribution!(indices, priorities, gen_rand_sumtree(n, seed), StableRNG(seed), iters)
42+
function sumtree_distribution(n, seeds::AbstractVector, iters=1000*n)
43+
p = [zeros(Float32, iters) for _ = 1:Threads.nthreads()]
44+
i = [zeros(Float32, iters) for _ = 1:Threads.nthreads()]
45+
results = Vector{Vector{Float64}}(undef, length(seeds))
46+
Threads.@threads for ix = 1:length(seeds)
47+
results[ix] = sumtree_distribution!(i[Threads.threadid()], p[Threads.threadid()], gen_rand_sumtree(n, seeds[ix]), StableRNG(seeds[ix]), iters)
48+
end
49+
return results
50+
end
51+
52+
n = 1024
53+
seeds = 1:100
54+
nozero_iters=1024
55+
distr_iters=1024*10_000
56+
abstol = 0.05
57+
maxerr=0.01
58+
sumtree_distribution(n, seeds, distr_iters)
59+
60+
# @test sumtree_nozero(n, seeds, nozero_iters)
61+
# @test all(x->all(x .< maxerr) && sum(abs2, x) < abstol,
62+
# sumtree_distribution(n, seeds, distr_iters))

0 commit comments

Comments
 (0)