@@ -197,4 +197,202 @@ using Test
197197 @test eb. step_numbers == [1 : 16 ;1 : 16 ]
198198 @test length (eb) == 31
199199 end
200+ @testset " with elastic traces" begin
201+ t = ElasticArraySARTSTraces (;
202+ state= Int => (),
203+ action= Int => (),
204+ reward= Float32 => (),
205+ terminal= Bool => ()
206+ )
207+
208+ eb = EpisodesBuffer (t)
209+ push! (eb, (state = 1 ,)) # partial inserting
210+ for i = 1 : 15
211+ push! (eb, (state = i+ 1 , reward = i))
212+ end
213+ @test length (eb. traces) == 15
214+ @test eb. sampleable_inds == [fill (true , 15 ); [false ]]
215+ @test all (== (15 ), eb. episodes_lengths)
216+ @test eb. step_numbers == [1 : 16 ;]
217+ push! (eb, (state = 1 ,)) # partial inserting
218+ for i = 1 : 15
219+ push! (eb, (state = i+ 1 , reward = i))
220+ end
221+ @test eb. sampleable_inds == [fill (true , 15 ); [false ];fill (true , 15 ); [false ]]
222+ @test all (== (15 ), eb. episodes_lengths)
223+ @test eb. step_numbers == [1 : 16 ;1 : 16 ]
224+ @test length (eb) == 31
225+ end
226+ @testset " with circular traces" begin
227+ eb = EpisodesBuffer (
228+ CircularArraySARTSTraces (;
229+ capacity= 10 )
230+ )
231+ # push a first episode l=5
232+ push! (eb, (state = 1 ,))
233+ @test eb. sampleable_inds[end ] == 0
234+ @test eb. episodes_lengths[end ] == 0
235+ @test eb. step_numbers[end ] == 1
236+ for i = 1 : 5
237+ push! (eb, (state = i+ 1 , action = i, reward = i, terminal = false ))
238+ @test eb. sampleable_inds[end ] == 0
239+ @test eb. sampleable_inds[end - 1 ] == 1
240+ @test eb. step_numbers[end ] == i + 1
241+ @test eb. episodes_lengths[end - i: end ] == fill (i, i+ 1 )
242+ end
243+ @test eb. sampleable_inds == [1 ,1 ,1 ,1 ,1 ,0 ]
244+ @test length (eb. traces) == 5
245+ # start new episode of 6 periods.
246+ push! (eb, (state = 7 ,))
247+ @test eb. sampleable_inds[end ] == 0
248+ @test eb. sampleable_inds[end - 1 ] == 0
249+ @test eb. episodes_lengths[end ] == 0
250+ @test eb. step_numbers[end ] == 1
251+ @test eb. sampleable_inds == [1 ,1 ,1 ,1 ,1 ,0 ,0 ]
252+ @test eb[6 ][:reward ] == 0 # 6 is not a valid index, the reward there is filled as zero
253+ ep2_len = 0
254+ for (j,i) = enumerate (8 : 11 )
255+ ep2_len += 1
256+ push! (eb, (state = i, action = i- 1 , reward = i- 1 , terminal = false ))
257+ @test eb. sampleable_inds[end ] == 0
258+ @test eb. sampleable_inds[end - 1 ] == 1
259+ @test eb. step_numbers[end ] == j + 1
260+ @test eb. episodes_lengths[end - j: end ] == fill (ep2_len, ep2_len + 1 )
261+ end
262+ @test eb. sampleable_inds == [1 ,1 ,1 ,1 ,1 ,0 ,1 ,1 ,1 ,1 ,0 ]
263+ @test length (eb. traces) == 10
264+ # three last steps replace oldest steps in the buffer.
265+ for (i, s) = enumerate (12 : 13 )
266+ ep2_len += 1
267+ push! (eb, (state = s, action = s- 1 , reward = s- 1 , terminal = false ))
268+ @test eb. sampleable_inds[end ] == 0
269+ @test eb. sampleable_inds[end - 1 ] == 1
270+ @test eb. step_numbers[end ] == i + 1 + 4
271+ @test eb. episodes_lengths[end - ep2_len: end ] == fill (ep2_len, ep2_len + 1 )
272+ end
273+ # episode 1
274+ for (i,s) in enumerate (3 : 13 )
275+ if i in (4 , 11 )
276+ @test eb. sampleable_inds[i] == 0
277+ continue
278+ else
279+ @test eb. sampleable_inds[i] == 1
280+ end
281+ b = eb[i]
282+ @test b[:state ] == b[:action ] == b[:reward ] == s
283+ @test b[:next_state ] == s + 1
284+ end
285+ # episode 2
286+ # start a third episode
287+ push! (eb, (state = 14 , ))
288+ @test eb. sampleable_inds[end ] == 0
289+ @test eb. sampleable_inds[end - 1 ] == 0
290+ @test eb. episodes_lengths[end ] == 0
291+ @test eb. step_numbers[end ] == 1
292+ # push until it reaches it own start
293+ for (i,s) in enumerate (15 : 26 )
294+ push! (eb, (state = s, action = s- 1 , reward = s- 1 , terminal = false ))
295+ end
296+ @test eb. sampleable_inds == [fill (true , 10 ); [false ]]
297+ @test eb. episodes_lengths == fill (length (15 : 26 ), 11 )
298+ @test eb. step_numbers == [3 : 13 ;]
299+ step = popfirst! (eb)
300+ @test length (eb) == length (eb. sampleable_inds) - 1 == length (eb. step_numbers) - 1 == length (eb. episodes_lengths) - 1 == 9
301+ @test first (eb. step_numbers) == 4
302+ step = pop! (eb)
303+ @test length (eb) == length (eb. sampleable_inds) - 1 == length (eb. step_numbers) - 1 == length (eb. episodes_lengths) - 1 == 8
304+ @test last (eb. step_numbers) == 12
305+ @test size (eb) == size (eb. traces) == (8 ,)
306+ empty! (eb)
307+ @test size (eb) == (0 ,) == size (eb. traces) == size (eb. sampleable_inds) == size (eb. episodes_lengths) == size (eb. step_numbers)
308+ show (eb);
309+ end
310+ @testset " with PartialNamedTuple" begin
311+ eb = EpisodesBuffer (
312+ CircularArraySARTSATraces (;
313+ capacity= 10 )
314+ )
315+ # push a first episode l=5
316+ push! (eb, (state = 1 ,))
317+ @test eb. sampleable_inds[end ] == 0
318+ @test eb. episodes_lengths[end ] == 0
319+ @test eb. step_numbers[end ] == 1
320+ for i = 1 : 5
321+ push! (eb, (state = i+ 1 , action = i, reward = i, terminal = false ))
322+ @test eb. sampleable_inds[end ] == 0
323+ @test eb. sampleable_inds[end - 1 ] == 1
324+ @test eb. step_numbers[end ] == i + 1
325+ @test eb. episodes_lengths[end - i: end ] == fill (i, i+ 1 )
326+ end
327+ push! (eb, PartialNamedTuple ((action = 6 ,)))
328+ @test eb. sampleable_inds == [1 ,1 ,1 ,1 ,1 ,0 ]
329+ @test length (eb. traces) == 5
330+ # start new episode of 6 periods.
331+ push! (eb, (state = 7 ,))
332+ @test eb. sampleable_inds[end ] == 0
333+ @test eb. sampleable_inds[end - 1 ] == 0
334+ @test eb. episodes_lengths[end ] == 0
335+ @test eb. step_numbers[end ] == 1
336+ @test eb. sampleable_inds == [1 ,1 ,1 ,1 ,1 ,0 ,0 ]
337+ @test eb[6 ][:reward ] == 0 # 6 is not a valid index, the reward there is dummy, filled as zero
338+ ep2_len = 0
339+ for (j,i) = enumerate (8 : 11 )
340+ ep2_len += 1
341+ push! (eb, (state = i, action = i- 1 , reward = i- 1 , terminal = false ))
342+ @test eb. sampleable_inds[end ] == 0
343+ @test eb. sampleable_inds[end - 1 ] == 1
344+ @test eb. step_numbers[end ] == j + 1
345+ @test eb. episodes_lengths[end - j: end ] == fill (ep2_len, ep2_len + 1 )
346+ end
347+ @test eb. sampleable_inds == [1 ,1 ,1 ,1 ,1 ,0 ,1 ,1 ,1 ,1 ,0 ]
348+ @test length (eb. traces) == 9 # an action is missing at this stage
349+ # three last steps replace oldest steps in the buffer.
350+ for (i, s) = enumerate (12 : 13 )
351+ ep2_len += 1
352+ push! (eb, (state = s, action = s- 1 , reward = s- 1 , terminal = false ))
353+ @test eb. sampleable_inds[end ] == 0
354+ @test eb. sampleable_inds[end - 1 ] == 1
355+ @test eb. step_numbers[end ] == i + 1 + 4
356+ @test eb. episodes_lengths[end - ep2_len: end ] == fill (ep2_len, ep2_len + 1 )
357+ end
358+ push! (eb, PartialNamedTuple ((action = 13 ,)))
359+ @test length (eb. traces) == 10
360+ # episode 1
361+ for (i,s) in enumerate (3 : 13 )
362+ if i in (4 , 11 )
363+ @test eb. sampleable_inds[i] == 0
364+ continue
365+ else
366+ @test eb. sampleable_inds[i] == 1
367+ end
368+ b = eb[i]
369+ @test b[:state ] == b[:action ] == b[:reward ] == s
370+ @test b[:next_state ] == b[:next_action ] == s + 1
371+ end
372+ # episode 2
373+ # start a third episode
374+ push! (eb, (state = 14 ,))
375+ @test eb. sampleable_inds[end ] == 0
376+ @test eb. sampleable_inds[end - 1 ] == 0
377+ @test eb. episodes_lengths[end ] == 0
378+ @test eb. step_numbers[end ] == 1
379+ # push until it reaches it own start
380+ for (i,s) in enumerate (15 : 26 )
381+ push! (eb, (state = s, action = s- 1 , reward = s- 1 , terminal = false ))
382+ end
383+ push! (eb, PartialNamedTuple ((action = 26 ,)))
384+ @test eb. sampleable_inds == [fill (true , 10 ); [false ]]
385+ @test eb. episodes_lengths == fill (length (15 : 26 ), 11 )
386+ @test eb. step_numbers == [3 : 13 ;]
387+ step = popfirst! (eb)
388+ @test length (eb) == length (eb. sampleable_inds) - 1 == length (eb. step_numbers) - 1 == length (eb. episodes_lengths) - 1 == 9
389+ @test first (eb. step_numbers) == 4
390+ step = pop! (eb)
391+ @test length (eb) == length (eb. sampleable_inds) - 1 == length (eb. step_numbers) - 1 == length (eb. episodes_lengths) - 1 == 8
392+ @test last (eb. step_numbers) == 12
393+ @test size (eb) == size (eb. traces) == (8 ,)
394+ empty! (eb)
395+ @test size (eb) == (0 ,) == size (eb. traces) == size (eb. sampleable_inds) == size (eb. episodes_lengths) == size (eb. step_numbers)
396+ show (eb);
397+ end
200398end
0 commit comments