Skip to content

Commit 2de532a

Browse files
author
Jeremiah Lewis
committed
Update ElasticArraySARTSTraces and add new files
1 parent a51047f commit 2de532a

File tree

8 files changed

+155
-52
lines changed

8 files changed

+155
-52
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
export ElasticArraySARTSTraces
2+
3+
const ElasticArraySARTSATraces = Traces{
4+
SS′AA′RT,
5+
<:Tuple{
6+
<:MultiplexTraces{SS′,<:Trace{<:ElasticArray}},
7+
<:MultiplexTraces{AA′,<:Trace{<:ElasticArray}},
8+
<:Trace{<:ElasticArray},
9+
<:Trace{<:ElasticArray},
10+
}
11+
}
12+
13+
function ElasticArraySARTSATraces(;
14+
state=Int => (),
15+
action=Int => (),
16+
reward=Float32 => (),
17+
terminal=Bool => ()
18+
)
19+
state_eltype, state_size = state
20+
action_eltype, action_size = action
21+
reward_eltype, reward_size = reward
22+
terminal_eltype, terminal_size = terminal
23+
24+
MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
25+
MultiplexTraces{AA′}(ElasticArray{action_eltype}(undef, action_size..., 0)) +
26+
Traces(
27+
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
28+
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
29+
)
30+
end
31+
Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,30 @@
11
export ElasticArraySARTSTraces
22

3-
using ElasticArrays: ElasticArray, resize_lastdim!
4-
53
const ElasticArraySARTSTraces = Traces{
6-
SS′AA′RT,
4+
SS′ART,
75
<:Tuple{
8-
<:MultiplexTraces{SS′,<:Trace{<:ElasticArray}},
9-
<:MultiplexTraces{AA′,<:Trace{<:ElasticArray}},
10-
<:Trace{<:ElasticArray},
11-
<:Trace{<:ElasticArray},
6+
<:MultiplexTraces{SS′,<:Trace{<:ElasticArrayBuffer}},
7+
<:Trace{<:ElasticArrayBuffer},
8+
<:Trace{<:ElasticArrayBuffer},
9+
<:Trace{<:ElasticArrayBuffer},
1210
}
1311
}
1412

1513
function ElasticArraySARTSTraces(;
1614
state=Int => (),
1715
action=Int => (),
1816
reward=Float32 => (),
19-
terminal=Bool => ()
20-
)
17+
terminal=Bool => ())
18+
2119
state_eltype, state_size = state
2220
action_eltype, action_size = action
2321
reward_eltype, reward_size = reward
2422
terminal_eltype, terminal_size = terminal
2523

26-
MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
27-
MultiplexTraces{AA′}(ElasticArray{action_eltype}(undef, action_size..., 0)) +
24+
MultiplexTraces{SS′}(ElasticArrayBuffer{state_eltype}(state_size..., capacity+1)) +
2825
Traces(
29-
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
30-
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
26+
action = ElasticArrayBuffer{action_eltype}(action_size..., capacity),
27+
reward=ElasticArrayBuffer{reward_eltype}(reward_size..., capacity),
28+
terminal=ElasticArrayBuffer{terminal_eltype}(terminal_size..., capacity),
3129
)
3230
end
33-
34-
#####
35-
# extensions for ElasticArrays
36-
#####
37-
38-
Base.push!(a::ElasticArray, x) = append!(a, x)
39-
Base.push!(a::ElasticArray{T,1}, x) where {T} = append!(a, [x])
40-
Base.empty!(a::ElasticArray) = resize_lastdim!(a, 0)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
export ElasticArraySLARTTraces
2+
3+
const ElasticArraySLARTTraces = Traces{
4+
SS′LL′AA′RT,
5+
<:Tuple{
6+
<:MultiplexTraces{SS′,<:Trace{<:ElasticArrayBuffer}},
7+
<:MultiplexTraces{LL′,<:Trace{<:ElasticArrayBuffer}},
8+
<:MultiplexTraces{AA′,<:Trace{<:ElasticArrayBuffer}},
9+
<:Trace{<:ElasticArrayBuffer},
10+
<:Trace{<:ElasticArrayBuffer},
11+
}
12+
}
13+
14+
function ElasticArraySLARTTraces(;
15+
capacity::Int,
16+
state=Int => (),
17+
legal_actions_mask=Bool => (),
18+
action=Int => (),
19+
reward=Float32 => (),
20+
terminal=Bool => ()
21+
)
22+
state_eltype, state_size = state
23+
action_eltype, action_size = action
24+
legal_actions_mask_eltype, legal_actions_mask_size = legal_actions_mask
25+
reward_eltype, reward_size = reward
26+
terminal_eltype, terminal_size = terminal
27+
28+
MultiplexTraces{SS′}(ElasticArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
29+
MultiplexTraces{LL′}(ElasticArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) +
30+
MultiplexTraces{AA′}(ElasticArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
31+
Traces(
32+
reward=ElasticArrayBuffer{reward_eltype}(reward_size..., capacity),
33+
terminal=ElasticArrayBuffer{terminal_eltype}(terminal_size..., capacity),
34+
)
35+
end
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
export ElasticPrioritizedTraces
2+
3+
struct ElasticPrioritizedTraces{T,names,Ts} <: AbstractTraces{names,Ts}
4+
keys::ElasticVectorBuffer{Int,Vector{Int}}
5+
priorities::SumTree{Float32}
6+
traces::T
7+
default_priority::Float32
8+
end
9+
10+
function ElasticPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_priority) where {names,Ts}
11+
new_names = (:key, :priority, names...)
12+
new_Ts = Tuple{Int,Float32,Ts.parameters...}
13+
c = capacity(traces)
14+
ElasticPrioritizedTraces{typeof(traces),new_names,new_Ts}(
15+
ElasticVectorBuffer{Int}(c),
16+
SumTree(c),
17+
traces,
18+
default_priority
19+
)
20+
end
21+
22+
function Base.push!(t::ElasticPrioritizedTraces, x)
23+
push!(t.traces, x)
24+
if length(t.traces) == 1
25+
push!(t.keys, 1)
26+
push!(t.priorities, t.default_priority)
27+
elseif length(t.traces) > 1
28+
push!(t.keys, t.keys[end] + 1)
29+
push!(t.priorities, t.default_priority)
30+
else
31+
# may be partial inserting at the first step, ignore it
32+
end
33+
end
34+
35+
function Base.setindex!(t::ElasticPrioritizedTraces, vs, k::Symbol, keys)
36+
if k === :priority
37+
@assert length(vs) == length(keys)
38+
for (i, v) in zip(keys, vs)
39+
if t.keys[1] <= i <= t.keys[end]
40+
t.priorities[i-t.keys[1]+1] = v
41+
end
42+
end
43+
else
44+
@error "unsupported yet"
45+
end
46+
end
47+
48+
Base.size(t::ElasticPrioritizedTraces) = size(t.traces)
49+
50+
function Base.getindex(ts::ElasticPrioritizedTraces, s::Symbol)
51+
if s === :priority
52+
Trace(ts.priorities)
53+
elseif s === :key
54+
Trace(ts.keys)
55+
else
56+
ts.traces[s]
57+
end
58+
end
59+
60+
Base.getindex(t::ElasticPrioritizedTraces{<:Any,names}, i) where {names} = NamedTuple{names}(map(k -> t[k][i], names))

src/common/common.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,8 @@ include("CircularArraySARTSTraces.jl")
1414
include("CircularArraySARTSATraces.jl")
1515
include("CircularArraySLARTTraces.jl")
1616
include("CircularPrioritizedTraces.jl")
17+
include("common_elastic_array.jl")
1718
include("ElasticArraySARTSTraces.jl")
19+
include("ElasticArraySARTSATraces.jl")
20+
include("ElasticArraySLARTTraces.jl")
21+
include("ElasticPrioritizedTraces.jl")

src/common/common_elastic_array.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
using ElasticArrays: ElasticArray, resize_lastdim!
2+
3+
#####
4+
# extensions for ElasticArrays
5+
#####
6+
7+
Base.push!(a::ElasticArray, x) = append!(a, x)
8+
Base.push!(a::ElasticArray{T,1}, x) where {T} = append!(a, [x])
9+
Base.empty!(a::ElasticArray) = resize_lastdim!(a, 0)

src/episodes.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
export EpisodesBuffer, PartialNamedTuple
22
import DataStructures.CircularBuffer
3-
using ElasticArrays: ElasticArray, resize_lastdim!
3+
using ElasticArrays: ElasticArray, ElasticVector, resize_lastdim!
44

55
"""
66
EpisodesBuffer(traces::AbstractTraces)
@@ -91,7 +91,8 @@ function pad!(trace::Trace)
9191
return nothing
9292
end
9393

94-
pad!(vect::ElasticArray{T, Vector{T}}) where {T} = pad!(vect, zero(T))
94+
pad!(vect::ElasticArray{T, Vector{T}}) where {T} = push!(vect, zero(T))
95+
pad!(vect::ElasticVector{T, Vector{T}}) where {T} = push!(vect, zero(T))
9596
pad!(buf::CircularArrayBuffer{T,N,A}) where {T,N,A} = push!(buf, zero(T))
9697
pad!(vect::Vector{T}) where {T} = push!(vect, zero(T))
9798

test/episodes.jl

Lines changed: 2 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -197,36 +197,9 @@ using Test
197197
@test eb.step_numbers == [1:16;1:16]
198198
@test length(eb) == 31
199199
end
200-
@testset "with elastic traces" begin
201-
t = ElasticArraySARTSTraces(;
202-
state=Int => (),
203-
action=Int => (),
204-
reward=Float32 => (),
205-
terminal=Bool => ()
206-
)
207-
208-
eb = EpisodesBuffer(t)
209-
push!(eb, (state = 1,)) #partial inserting
210-
for i = 1:15
211-
push!(eb, (state = i+1, reward =i))
212-
end
213-
@test length(eb.traces) == 15
214-
@test eb.sampleable_inds == [fill(true, 15); [false]]
215-
@test all(==(15), eb.episodes_lengths)
216-
@test eb.step_numbers == [1:16;]
217-
push!(eb, (state = 1,)) #partial inserting
218-
for i = 1:15
219-
push!(eb, (state = i+1, reward =i))
220-
end
221-
@test eb.sampleable_inds == [fill(true, 15); [false];fill(true, 15); [false]]
222-
@test all(==(15), eb.episodes_lengths)
223-
@test eb.step_numbers == [1:16;1:16]
224-
@test length(eb) == 31
225-
end
226-
@testset "with circular traces" begin
200+
@testset "with ElasticArraySARTSTraces traces" begin
227201
eb = EpisodesBuffer(
228-
CircularArraySARTSTraces(;
229-
capacity=10)
202+
ElasticArraySARTSTraces()
230203
)
231204
#push a first episode l=5
232205
push!(eb, (state = 1,))

0 commit comments

Comments
 (0)