|
| 1 | +function gen_rand_sumtree(n, seed, type::DataType=Float32) |
| 2 | + rng = StableRNG(seed) |
| 3 | + a = SumTree(type, n) |
| 4 | + append!(a, rand(rng, type, n)) |
| 5 | + return a |
| 6 | +end |
| 7 | + |
| 8 | +function gen_sumtree_with_zeros(n, seed, type::DataType=Float32) |
| 9 | + a = gen_rand_sumtree(n, seed, type) |
| 10 | + b = rand(StableRNG(seed), Bool, n) |
| 11 | + return copy_multiply(a, b) |
| 12 | +end |
| 13 | + |
| 14 | +function copy_multiply(stree, m) |
| 15 | + new_tree = deepcopy(stree) |
| 16 | + new_tree .*= m |
| 17 | + return new_tree |
| 18 | +end |
| 19 | + |
| 20 | +function sumtree_nozero(t::SumTree, rng::AbstractRNG, iters=1) |
| 21 | + for _ in iters |
| 22 | + (_, p) = rand(rng, t) |
| 23 | + p == 0 && return false |
| 24 | + end |
| 25 | + return true |
| 26 | +end |
| 27 | +sumtree_nozero(n::Integer, seed::Integer, iters=1) = sumtree_nozero(gen_sumtree_with_zeros(n, seed), StableRNG(seed), iters) |
| 28 | +sumtree_nozero(n, seeds::AbstractVector, iters=1) = all(sumtree_nozero(n, seed, iters) for seed in seeds) |
| 29 | + |
| 30 | + |
| 31 | +function sumtree_distribution!(indices, priorities, t::SumTree, rng::AbstractRNG, iters=1000*t.length) |
| 32 | + for i = 1:iters |
| 33 | + indices[i], priorities[i] = rand(rng, t) |
| 34 | + end |
| 35 | + imap = countmap(indices) |
| 36 | + est_pdf = Dict(k=>v/length(indices) for (k, v) in imap) |
| 37 | + ex_pdf = Dict(k=>v/t.tree[1] for (k, v) in Dict(1:length(t) .=> t)) |
| 38 | + abserrs = [est_pdf[k] - ex_pdf[k] for k in keys(est_pdf)] |
| 39 | + return abserrs |
| 40 | +end |
| 41 | +sumtree_distribution!(indices, priorities, n, seed, iters=1000*n) = sumtree_distribution!(indices, priorities, gen_rand_sumtree(n, seed), StableRNG(seed), iters) |
| 42 | +function sumtree_distribution(n, seeds::AbstractVector, iters=1000*n) |
| 43 | + p = [zeros(Float32, iters) for _ = 1:Threads.nthreads()] |
| 44 | + i = [zeros(Float32, iters) for _ = 1:Threads.nthreads()] |
| 45 | + results = Vector{Vector{Float64}}(undef, length(seeds)) |
| 46 | + Threads.@threads for ix = 1:length(seeds) |
| 47 | + results[ix] = sumtree_distribution!(i[Threads.threadid()], p[Threads.threadid()], gen_rand_sumtree(n, seeds[ix]), StableRNG(seeds[ix]), iters) |
| 48 | + end |
| 49 | + return results |
| 50 | +end |
| 51 | + |
| 52 | +n = 1024 |
| 53 | +seeds = 1:100 |
| 54 | +nozero_iters=1024 |
| 55 | +distr_iters=1024*10_000 |
| 56 | +abstol = 0.05 |
| 57 | +maxerr=0.01 |
| 58 | +sumtree_distribution(n, seeds, distr_iters) |
| 59 | + |
| 60 | +# @test sumtree_nozero(n, seeds, nozero_iters) |
| 61 | +# @test all(x->all(x .< maxerr) && sum(abs2, x) < abstol, |
| 62 | +# sumtree_distribution(n, seeds, distr_iters)) |
0 commit comments