Skip to content

Commit a51047f

Browse files
author
Jeremiah Lewis
committed
fix naming
1 parent 9625f16 commit a51047f

File tree

6 files changed

+211
-9
lines changed

6 files changed

+211
-9
lines changed
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
export ElasticArraySARTTraces
1+
export ElasticArraySARTSTraces
22

33
using ElasticArrays: ElasticArray, resize_lastdim!
44

5-
const ElasticArraySARTTraces = Traces{
5+
const ElasticArraySARTSTraces = Traces{
66
SS′AA′RT,
77
<:Tuple{
88
<:MultiplexTraces{SS′,<:Trace{<:ElasticArray}},
@@ -12,7 +12,7 @@ const ElasticArraySARTTraces = Traces{
1212
}
1313
}
1414

15-
function ElasticArraySARTTraces(;
15+
function ElasticArraySARTSTraces(;
1616
state=Int => (),
1717
action=Int => (),
1818
reward=Float32 => (),
@@ -37,4 +37,4 @@ end
3737

3838
# `push!` on an ElasticArray of any dimensionality forwards the value to `append!`.
function Base.push!(a::ElasticArray, x)
    return append!(a, x)
end
3939
# One-dimensional case: the single value is wrapped in a one-element vector
# before being handed to `append!`.
function Base.push!(a::ElasticArray{T,1}, x) where {T}
    wrapped = [x]
    return append!(a, wrapped)
end
40-
Base.empty!(a::ElasticArray) = resize_lastdim!(a, 0)
40+
# Emptying an ElasticArray resizes its last dimension to zero.
function Base.empty!(a::ElasticArray)
    return resize_lastdim!(a, 0)
end

src/common/common.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ include("CircularArraySARTSTraces.jl")
1414
include("CircularArraySARTSATraces.jl")
1515
include("CircularArraySLARTTraces.jl")
1616
include("CircularPrioritizedTraces.jl")
17-
include("ElasticArraySARTTraces.jl")
17+
include("ElasticArraySARTSTraces.jl")

src/episodes.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
export EpisodesBuffer, PartialNamedTuple
22
import DataStructures.CircularBuffer
3+
using ElasticArrays: ElasticArray, resize_lastdim!
34

45
"""
56
EpisodesBuffer(traces::AbstractTraces)
@@ -90,6 +91,7 @@ function pad!(trace::Trace)
9091
return nothing
9192
end
9293

94+
# Pad a one-dimensional ElasticArray with a single `zero(T)` element, in the same
# way the `CircularArrayBuffer` and `Vector` methods of `pad!` do.
# The second type parameter of `ElasticArray` is its dimensionality, so the method
# dispatches on `ElasticArray{T,1}`; a signature of `ElasticArray{T, Vector{T}}`
# can never match an instance. The padding itself is done with `push!`.
pad!(vect::ElasticArray{T,1}) where {T} = push!(vect, zero(T))
9395
# Pad a circular buffer by pushing one zero of its element type.
function pad!(buf::CircularArrayBuffer{T,N,A}) where {T,N,A}
    filler = zero(T)
    return push!(buf, filler)
end
9496
# Pad a plain vector by pushing one zero of its element type; returns the vector.
function pad!(vect::Vector{T}) where {T}
    filler = zero(T)
    return push!(vect, filler)
end
9597

src/traces.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ export Trace, Traces, MultiplexTraces
33
import MacroTools: @forward
44

55
import CircularArrayBuffers.CircularArrayBuffer
6+
using ElasticArrays: ElasticArray
67
import Adapt
78

89
#####
@@ -55,6 +56,7 @@ Base.setindex!(s::Trace, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s
5556
# The capacity of a trace is the capacity of the container stored in its `parent` field.
function capacity(t::AbstractTrace)
    wrapped = t.parent
    return ReinforcementLearningTrajectories.capacity(wrapped)
end
5657
# Circular buffers delegate to the `capacity` function of CircularArrayBuffers.
function capacity(t::CircularArrayBuffer)
    return CircularArrayBuffers.capacity(t)
end
5758
# Any abstract vector reports an infinite capacity.
function capacity(::AbstractVector)
    return Inf
end
59+
# ElasticArrays report an infinite capacity.
function capacity(::ElasticArray)
    return Inf
end
5860

5961
#####
6062

test/common.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,15 @@ end
9494
@test batch.terminal == Bool[0, 0, 0] |> gpu
9595
end
9696

97-
@testset "ElasticArraySARTTraces" begin
98-
t = ElasticArraySARTTraces(;
97+
@testset "ElasticArraySARTSTraces" begin
98+
t = ElasticArraySARTSTraces(;
9999
state=Float32 => (2, 3),
100100
action=Int => (),
101101
reward=Float32 => (),
102102
terminal=Bool => ()
103103
)
104104

105-
@test t isa ElasticArraySARTTraces
105+
@test t isa ElasticArraySARTSTraces
106106

107107
push!(t, (state=ones(Float32, 2, 3), action=1))
108108
push!(t, (reward=1.0f0, terminal=false, state=ones(Float32, 2, 3) * 2, action=2))
@@ -185,4 +185,4 @@ end
185185

186186
eb[:priority, [1, 2]] = [0, 0]
187187
@test eb[:priority] == [zeros(2);ones(8)]
188-
end
188+
end

test/episodes.jl

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,4 +197,202 @@ using Test
197197
@test eb.step_numbers == [1:16;1:16]
198198
@test length(eb) == 31
199199
end
200+
@testset "with elastic traces" begin
201+
t = ElasticArraySARTSTraces(;
202+
state=Int => (),
203+
action=Int => (),
204+
reward=Float32 => (),
205+
terminal=Bool => ()
206+
)
207+
208+
eb = EpisodesBuffer(t)
209+
push!(eb, (state = 1,)) #partial inserting
210+
for i = 1:15
211+
push!(eb, (state = i+1, reward =i))
212+
end
213+
@test length(eb.traces) == 15
214+
@test eb.sampleable_inds == [fill(true, 15); [false]]
215+
@test all(==(15), eb.episodes_lengths)
216+
@test eb.step_numbers == [1:16;]
217+
push!(eb, (state = 1,)) #partial inserting
218+
for i = 1:15
219+
push!(eb, (state = i+1, reward =i))
220+
end
221+
@test eb.sampleable_inds == [fill(true, 15); [false];fill(true, 15); [false]]
222+
@test all(==(15), eb.episodes_lengths)
223+
@test eb.step_numbers == [1:16;1:16]
224+
@test length(eb) == 31
225+
end
226+
@testset "with circular traces" begin
227+
eb = EpisodesBuffer(
228+
CircularArraySARTSTraces(;
229+
capacity=10)
230+
)
231+
#push a first episode l=5
232+
push!(eb, (state = 1,))
233+
@test eb.sampleable_inds[end] == 0
234+
@test eb.episodes_lengths[end] == 0
235+
@test eb.step_numbers[end] == 1
236+
for i = 1:5
237+
push!(eb, (state = i+1, action =i, reward = i, terminal = false))
238+
@test eb.sampleable_inds[end] == 0
239+
@test eb.sampleable_inds[end-1] == 1
240+
@test eb.step_numbers[end] == i + 1
241+
@test eb.episodes_lengths[end-i:end] == fill(i, i+1)
242+
end
243+
@test eb.sampleable_inds == [1,1,1,1,1,0]
244+
@test length(eb.traces) == 5
245+
#start new episode of 6 periods.
246+
push!(eb, (state = 7,))
247+
@test eb.sampleable_inds[end] == 0
248+
@test eb.sampleable_inds[end-1] == 0
249+
@test eb.episodes_lengths[end] == 0
250+
@test eb.step_numbers[end] == 1
251+
@test eb.sampleable_inds == [1,1,1,1,1,0,0]
252+
@test eb[6][:reward] == 0 #6 is not a valid index, the reward there is filled as zero
253+
ep2_len = 0
254+
for (j,i) = enumerate(8:11)
255+
ep2_len += 1
256+
push!(eb, (state = i, action =i-1, reward = i-1, terminal = false))
257+
@test eb.sampleable_inds[end] == 0
258+
@test eb.sampleable_inds[end-1] == 1
259+
@test eb.step_numbers[end] == j + 1
260+
@test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1)
261+
end
262+
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0]
263+
@test length(eb.traces) == 10
264+
#three last steps replace oldest steps in the buffer.
265+
for (i, s) = enumerate(12:13)
266+
ep2_len += 1
267+
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
268+
@test eb.sampleable_inds[end] == 0
269+
@test eb.sampleable_inds[end-1] == 1
270+
@test eb.step_numbers[end] == i + 1 + 4
271+
@test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1)
272+
end
273+
#episode 1
274+
for (i,s) in enumerate(3:13)
275+
if i in (4, 11)
276+
@test eb.sampleable_inds[i] == 0
277+
continue
278+
else
279+
@test eb.sampleable_inds[i] == 1
280+
end
281+
b = eb[i]
282+
@test b[:state] == b[:action] == b[:reward] == s
283+
@test b[:next_state] == s + 1
284+
end
285+
#episode 2
286+
#start a third episode
287+
push!(eb, (state = 14, ))
288+
@test eb.sampleable_inds[end] == 0
289+
@test eb.sampleable_inds[end-1] == 0
290+
@test eb.episodes_lengths[end] == 0
291+
@test eb.step_numbers[end] == 1
292+
#push until it reaches its own start
293+
for (i,s) in enumerate(15:26)
294+
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
295+
end
296+
@test eb.sampleable_inds == [fill(true, 10); [false]]
297+
@test eb.episodes_lengths == fill(length(15:26), 11)
298+
@test eb.step_numbers == [3:13;]
299+
step = popfirst!(eb)
300+
@test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9
301+
@test first(eb.step_numbers) == 4
302+
step = pop!(eb)
303+
@test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 8
304+
@test last(eb.step_numbers) == 12
305+
@test size(eb) == size(eb.traces) == (8,)
306+
empty!(eb)
307+
@test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers)
308+
show(eb);
309+
end
310+
@testset "with PartialNamedTuple" begin
311+
eb = EpisodesBuffer(
312+
CircularArraySARTSATraces(;
313+
capacity=10)
314+
)
315+
#push a first episode l=5
316+
push!(eb, (state = 1,))
317+
@test eb.sampleable_inds[end] == 0
318+
@test eb.episodes_lengths[end] == 0
319+
@test eb.step_numbers[end] == 1
320+
for i = 1:5
321+
push!(eb, (state = i+1, action =i, reward = i, terminal = false))
322+
@test eb.sampleable_inds[end] == 0
323+
@test eb.sampleable_inds[end-1] == 1
324+
@test eb.step_numbers[end] == i + 1
325+
@test eb.episodes_lengths[end-i:end] == fill(i, i+1)
326+
end
327+
push!(eb, PartialNamedTuple((action = 6,)))
328+
@test eb.sampleable_inds == [1,1,1,1,1,0]
329+
@test length(eb.traces) == 5
330+
#start new episode of 6 periods.
331+
push!(eb, (state = 7,))
332+
@test eb.sampleable_inds[end] == 0
333+
@test eb.sampleable_inds[end-1] == 0
334+
@test eb.episodes_lengths[end] == 0
335+
@test eb.step_numbers[end] == 1
336+
@test eb.sampleable_inds == [1,1,1,1,1,0,0]
337+
@test eb[6][:reward] == 0 #6 is not a valid index, the reward there is dummy, filled as zero
338+
ep2_len = 0
339+
for (j,i) = enumerate(8:11)
340+
ep2_len += 1
341+
push!(eb, (state = i, action =i-1, reward = i-1, terminal = false))
342+
@test eb.sampleable_inds[end] == 0
343+
@test eb.sampleable_inds[end-1] == 1
344+
@test eb.step_numbers[end] == j + 1
345+
@test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1)
346+
end
347+
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0]
348+
@test length(eb.traces) == 9 #an action is missing at this stage
349+
#three last steps replace oldest steps in the buffer.
350+
for (i, s) = enumerate(12:13)
351+
ep2_len += 1
352+
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
353+
@test eb.sampleable_inds[end] == 0
354+
@test eb.sampleable_inds[end-1] == 1
355+
@test eb.step_numbers[end] == i + 1 + 4
356+
@test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1)
357+
end
358+
push!(eb, PartialNamedTuple((action = 13,)))
359+
@test length(eb.traces) == 10
360+
#episode 1
361+
for (i,s) in enumerate(3:13)
362+
if i in (4, 11)
363+
@test eb.sampleable_inds[i] == 0
364+
continue
365+
else
366+
@test eb.sampleable_inds[i] == 1
367+
end
368+
b = eb[i]
369+
@test b[:state] == b[:action] == b[:reward] == s
370+
@test b[:next_state] == b[:next_action] == s + 1
371+
end
372+
#episode 2
373+
#start a third episode
374+
push!(eb, (state = 14,))
375+
@test eb.sampleable_inds[end] == 0
376+
@test eb.sampleable_inds[end-1] == 0
377+
@test eb.episodes_lengths[end] == 0
378+
@test eb.step_numbers[end] == 1
379+
#push until it reaches its own start
380+
for (i,s) in enumerate(15:26)
381+
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
382+
end
383+
push!(eb, PartialNamedTuple((action = 26,)))
384+
@test eb.sampleable_inds == [fill(true, 10); [false]]
385+
@test eb.episodes_lengths == fill(length(15:26), 11)
386+
@test eb.step_numbers == [3:13;]
387+
step = popfirst!(eb)
388+
@test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9
389+
@test first(eb.step_numbers) == 4
390+
step = pop!(eb)
391+
@test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 8
392+
@test last(eb.step_numbers) == 12
393+
@test size(eb) == size(eb.traces) == (8,)
394+
empty!(eb)
395+
@test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers)
396+
show(eb);
397+
end
200398
end

0 commit comments

Comments
 (0)