@@ -2,7 +2,8 @@ use std::collections::HashMap;
22use std:: iter:: once;
33use std:: num:: NonZero ;
44use std:: sync:: Arc ;
5- use tokio:: task:: JoinHandle ;
5+ use tokio:: runtime:: Handle ;
6+ use tokio:: task:: JoinSet ;
67
78use anyhow:: { Context , Result } ;
89use nuts_storable:: { ItemType , Value } ;
@@ -43,8 +44,9 @@ pub struct ZarrAsyncChainStorage {
4344 arrays : Arc < ArrayCollection > ,
4445 chain : u64 ,
4546 last_sample_was_warmup : bool ,
46- pending_writes : Vec < JoinHandle < Result < ( ) > > > ,
47+ pending_writes : Arc < tokio :: sync :: Mutex < JoinSet < Result < ( ) > > > > ,
4748 rt_handle : tokio:: runtime:: Handle ,
49+ max_queued_writes : usize ,
4850}
4951
5052/// Write a chunk of data to a Zarr array asynchronously
@@ -240,22 +242,28 @@ impl ZarrAsyncChainStorage {
240242 chain : u64 ,
241243 rt_handle : tokio:: runtime:: Handle ,
242244 ) -> Self {
243- let draw_buffers = draw_types
245+ let draw_buffers: HashMap < String , SampleBuffer > = draw_types
244246 . iter ( )
245247 . map ( |( name, item_type) | ( name. clone ( ) , SampleBuffer :: new ( * item_type, buffer_size) ) )
246248 . collect ( ) ;
247249
248- let stats_buffers = param_types
250+ let stats_buffers: HashMap < String , SampleBuffer > = param_types
249251 . iter ( )
250252 . map ( |( name, item_type) | ( name. clone ( ) , SampleBuffer :: new ( * item_type, buffer_size) ) )
251253 . collect ( ) ;
254+
255+ let num_arrays = draw_buffers. len ( ) + stats_buffers. len ( ) ;
256+
252257 Self {
253258 draw_buffers,
254259 stats_buffers,
255260 arrays,
256261 chain,
257262 last_sample_was_warmup : true ,
258- pending_writes : Vec :: new ( ) ,
263+ pending_writes : Arc :: new ( tokio:: sync:: Mutex :: new ( JoinSet :: new ( ) ) ) ,
264+ // We allow up to the number of arrays in pending writes, so
265+ // that we queue one write per draw.
266+ max_queued_writes : num_arrays. max ( 1 ) ,
259267 rt_handle,
260268 }
261269 }
@@ -275,10 +283,15 @@ impl ZarrAsyncChainStorage {
275283 self . arrays . sample_param_arrays [ name] . clone ( )
276284 } ;
277285 let chain = self . chain ;
278- let handle = self
279- . rt_handle
280- . spawn ( async move { store_zarr_chunk_async ( array, chunk, chain) . await } ) ;
281- self . pending_writes . push ( handle) ;
286+
287+ queue_write (
288+ & self . rt_handle ,
289+ self . pending_writes . clone ( ) ,
290+ self . max_queued_writes ,
291+ array,
292+ chunk,
293+ chain,
294+ ) ?;
282295 }
283296 Ok ( ( ) )
284297 }
@@ -298,15 +311,57 @@ impl ZarrAsyncChainStorage {
298311 self . arrays . sample_draw_arrays [ name] . clone ( )
299312 } ;
300313 let chain = self . chain ;
301- let handle = self
302- . rt_handle
303- . spawn ( async move { store_zarr_chunk_async ( array, chunk, chain) . await } ) ;
304- self . pending_writes . push ( handle) ;
314+
315+ queue_write (
316+ & self . rt_handle ,
317+ self . pending_writes . clone ( ) ,
318+ self . max_queued_writes ,
319+ array,
320+ chunk,
321+ chain,
322+ ) ?;
305323 }
306324 Ok ( ( ) )
307325 }
308326}
309327
328+ fn queue_write (
329+ handle : & Handle ,
330+ queue : Arc < tokio:: sync:: Mutex < JoinSet < Result < ( ) > > > > ,
331+ max_queued_writes : usize ,
332+ array : Array ,
333+ chunk : Chunk ,
334+ chain : u64 ,
335+ ) -> Result < ( ) > {
336+ let rt_handle = handle. clone ( ) ;
337+ // We need an async task to interface with the async storage
338+ // and JoinSet API.
339+ let spawn_write_task = handle. spawn ( async move {
340+ // This should never actually block, because this lock
341+ // is only held in tasks that are spawned and immediately blocked_on
342+ // from the sampling thread.
343+ let mut writes_guard = queue. lock ( ) . await ;
344+
345+ while writes_guard. len ( ) >= max_queued_writes {
346+ let out = writes_guard. join_next ( ) . await ;
347+ if let Some ( out) = out {
348+ out. context ( "Failed to await previous trace write operation" ) ?
349+ . context ( "Chunk write operation failed" ) ?;
350+ } else {
351+ break ;
352+ }
353+ }
354+ writes_guard. spawn_on (
355+ async move { store_zarr_chunk_async ( array, chunk, chain) . await } ,
356+ & rt_handle,
357+ ) ;
358+ Ok ( ( ) )
359+ } ) ;
360+ let res: Result < ( ) > = handle. block_on ( spawn_write_task) ?;
361+ res?;
362+ Ok ( ( ) )
363+ }
364+
310365impl ChainStorage for ZarrAsyncChainStorage {
311366 type Finalized = ( ) ;
312367
@@ -323,20 +378,30 @@ impl ChainStorage for ZarrAsyncChainStorage {
323378 if let Some ( chunk) = buffer. reset ( ) {
324379 let array = self . arrays . warmup_draw_arrays [ key] . clone ( ) ;
325380 let chain = self . chain ;
326- let handle = self
327- . rt_handle
328- . spawn ( async move { store_zarr_chunk_async ( array, chunk, chain) . await } ) ;
329- self . pending_writes . push ( handle) ;
381+
382+ queue_write (
383+ & self . rt_handle ,
384+ self . pending_writes . clone ( ) ,
385+ self . max_queued_writes ,
386+ array,
387+ chunk,
388+ chain,
389+ ) ?;
330390 }
331391 }
332392 for ( key, buffer) in self . stats_buffers . iter_mut ( ) {
333393 if let Some ( chunk) = buffer. reset ( ) {
334394 let array = self . arrays . warmup_param_arrays [ key] . clone ( ) ;
335395 let chain = self . chain ;
336- let handle = self
337- . rt_handle
338- . spawn ( async move { store_zarr_chunk_async ( array, chunk, chain) . await } ) ;
339- self . pending_writes . push ( handle) ;
396+
397+ queue_write (
398+ & self . rt_handle ,
399+ self . pending_writes . clone ( ) ,
400+ self . max_queued_writes ,
401+ array,
402+ chunk,
403+ chain,
404+ ) ?;
340405 }
341406 }
342407 self . last_sample_was_warmup = false ;
@@ -382,11 +447,14 @@ impl ChainStorage for ZarrAsyncChainStorage {
382447 }
383448
384449 // Join all pending writes
450+ // All tasks that hold a reference to the queue are blocked_on
451+ // right away, so we hold the only reference to `self.pending_writes`.
452+ let pending_writes = Arc :: into_inner ( self . pending_writes )
453+ . expect ( "Could not take ownership of pending writes queue" )
454+ . into_inner ( ) ;
385455 self . rt_handle . block_on ( async move {
386- for join_handle in self . pending_writes {
387- let _ = join_handle
388- . await
389- . context ( "Failed to await async chunk write operation" ) ?;
456+ for join_handle in pending_writes. join_all ( ) . await {
457+ let _ = join_handle. context ( "Failed to await async chunk write operation" ) ?;
390458 }
391459 Ok :: < ( ) , anyhow:: Error > ( ( ) )
392460 } ) ?;
@@ -420,6 +488,21 @@ impl ChainStorage for ZarrAsyncChainStorage {
420488 }
421489 }
422490
491+ // Join all pending writes
492+ let pending_writes = self . pending_writes . clone ( ) ;
493+ self . rt_handle . block_on ( async move {
494+ let mut pending_writes = pending_writes. lock ( ) . await ;
495+ loop {
496+ let Some ( join_handle) = pending_writes. join_next ( ) . await else {
497+ break ;
498+ } ;
499+ join_handle
500+ . context ( "Failed to await async chunk write operation" ) ?
501+ . context ( "Chunk write operation failed" ) ?;
502+ }
503+ Ok :: < ( ) , anyhow:: Error > ( ( ) )
504+ } ) ?;
505+
423506 Ok ( ( ) )
424507 }
425508}
0 commit comments