Skip to content

Commit 14b52c4

Browse files
author
Jeremiah Lewis
committed
tweaks
1 parent 996484c commit 14b52c4

File tree

6 files changed

+42
-76
lines changed

6 files changed

+42
-76
lines changed

src/common/ElasticArraySARTSATraces.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export ElasticArraySARTSTraces
1+
export ElasticArraySARTSATraces
22

33
const ElasticArraySARTSATraces = Traces{
44
SS′AA′RT,

src/common/ElasticPrioritizedTraces.jl

Lines changed: 0 additions & 62 deletions
This file was deleted.

src/common/common.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,3 @@ include("common_elastic_array.jl")
1818
include("ElasticArraySARTSTraces.jl")
1919
include("ElasticArraySARTSATraces.jl")
2020
include("ElasticArraySLARTTraces.jl")
21-
include("ElasticPrioritizedTraces.jl")

src/episodes.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ end
8585
ispartial_insert(traces::Traces, xs) = length(xs) < length(traces.traces) #this is the number of traces it contains not the number of steps.
8686
ispartial_insert(es::EpisodesBuffer, xs) = ispartial_insert(es.traces, xs)
8787
ispartial_insert(traces::CircularPrioritizedTraces, xs) = ispartial_insert(traces.traces, xs)
88+
ispartial_insert(traces::ElasticPrioritizedTraces, xs) = ispartial_insert(traces.traces, xs)
8889

8990
function pad!(trace::Trace)
9091
pad!(trace.parent)
@@ -130,6 +131,8 @@ fill_multiplex(es::EpisodesBuffer) = fill_multiplex(es.traces)
130131

131132
fill_multiplex(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(es.traces.traces)
132133

134+
fill_multiplex(es::EpisodesBuffer{<:Any,<:Any,<:ElasticPrioritizedTraces}) = fill_multiplex(es.traces.traces)
135+
133136
function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
134137
push!(eb.traces, xs)
135138
partial = ispartial_insert(eb, xs)

test/episodes.jl

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ using Test
114114
@test eb.episodes_lengths[end] == 0
115115
@test eb.step_numbers[end] == 1
116116
@test eb.sampleable_inds == [1,1,1,1,1,0,0]
117+
@test eb[:action][6] == 6
118+
@test eb[:next_action][6] == 6
117119
@test eb[6][:reward] == 0 #6 is not a valid index, the reward there is dummy, filled as zero
118120
ep2_len = 0
119121
for (j,i) = enumerate(8:11)
@@ -235,7 +237,7 @@ using Test
235237
end
236238
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0]
237239
@test length(eb.traces) == 10
238-
#three last steps replace oldest steps in the buffer.
240+
239241
for (i, s) = enumerate(12:13)
240242
ep2_len += 1
241243
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
@@ -245,16 +247,16 @@ using Test
245247
@test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1)
246248
end
247249
#episode 1
248-
for (i,s) in enumerate(3:13)
249-
if i in (4, 11)
250+
for i in 3:13
251+
if i in (6, 13)
250252
@test eb.sampleable_inds[i] == 0
251253
continue
252254
else
253255
@test eb.sampleable_inds[i] == 1
254256
end
255257
b = eb[i]
256-
@test b[:state] == b[:action] == b[:reward] == s
257-
@test b[:next_state] == s + 1
258+
@test b[:state] == b[:action] == b[:reward] == i
259+
@test b[:next_state] == i + 1
258260
end
259261
#episode 2
260262
#start a third episode
@@ -263,13 +265,14 @@ using Test
263265
@test eb.sampleable_inds[end-1] == 0
264266
@test eb.episodes_lengths[end] == 0
265267
@test eb.step_numbers[end] == 1
266-
#push until it reaches it own start
268+
267269
for (i,s) in enumerate(15:26)
268270
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
269271
end
270-
@test eb.sampleable_inds == [fill(true, 10); [false]]
271-
@test eb.episodes_lengths == fill(length(15:26), 11)
272-
@test eb.step_numbers == [3:13;]
272+
@test eb.sampleable_inds[end-5:end] == [fill(true, 5); [false]]
273+
@test eb.episodes_lengths[end-10:end] == fill(length(15:26), 11)
274+
@test eb.step_numbers[end-10:end] == [3:13;]
275+
#= Deactivated until https://github.com/JuliaArrays/ElasticArrays.jl/pull/56/files merged and pop!/popfirst! added to ElasticArrays
273276
step = popfirst!(eb)
274277
@test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9
275278
@test first(eb.step_numbers) == 4
@@ -280,11 +283,11 @@ using Test
280283
empty!(eb)
281284
@test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers)
282285
show(eb);
286+
=#
283287
end
284-
@testset "with PartialNamedTuple" begin
288+
@testset "ElasticArraySARTSATraces with PartialNamedTuple" begin
285289
eb = EpisodesBuffer(
286-
ElasticArraySARTSATraces(;
287-
capacity=10)
290+
ElasticArraySARTSATraces()
288291
)
289292
#push a first episode l=5
290293
push!(eb, (state = 1,))
@@ -308,6 +311,8 @@ using Test
308311
@test eb.episodes_lengths[end] == 0
309312
@test eb.step_numbers[end] == 1
310313
@test eb.sampleable_inds == [1,1,1,1,1,0,0]
314+
@test eb[:action][6] == 6
315+
@test eb[:next_action][6] == 6
311316
@test eb[6][:reward] == 0 #6 is not a valid index, the reward there is dummy, filled as zero
312317
ep2_len = 0
313318
for (j,i) = enumerate(8:11)
@@ -358,6 +363,7 @@ using Test
358363
@test eb.sampleable_inds == [fill(true, 10); [false]]
359364
@test eb.episodes_lengths == fill(length(15:26), 11)
360365
@test eb.step_numbers == [3:13;]
366+
#= Deactivated until https://github.com/JuliaArrays/ElasticArrays.jl/pull/56/files merged and pop!/popfirst! added to ElasticArrays
361367
step = popfirst!(eb)
362368
@test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9
363369
@test first(eb.step_numbers) == 4
@@ -368,5 +374,6 @@ using Test
368374
empty!(eb)
369375
@test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers)
370376
show(eb);
377+
=#
371378
end
372379
end

test/traces.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,25 @@ using ReinforcementLearningTrajectories: build_trace_index
137137
build_trace_index(typeof(t2).parameters[1], typeof(t2).parameters[2])
138138
end
139139

140+
@testset "build_trace_index ElasticArraySARTSATraces" begin
141+
t1 = ElasticArraySARTSATraces(;
142+
capacity=3,
143+
state=Float32 => (2, 3),
144+
action=Float32 => (2,),
145+
reward=Float32 => (),
146+
terminal=Bool => ()
147+
)
148+
@test build_trace_index(typeof(t1).parameters[1], typeof(t1).parameters[2]) == Dict(:reward => 3,
149+
:next_state => 1,
150+
:state => 1,
151+
:action => 2,
152+
:next_action => 2,
153+
:terminal => 4)
154+
155+
t2 = Traces(; a=[2, 3], b=[false, true])
156+
build_trace_index(typeof(t2).parameters[1], typeof(t2).parameters[2])
157+
end
158+
140159
@testset "push!(ts::Traces{names,Trs,N,E}, ::Val{k}, v)" begin
141160
t1 = CircularArraySARTSATraces(;
142161
capacity=3,

0 commit comments

Comments
 (0)