@@ -2,7 +2,8 @@ use std::collections::HashMap;
22use std:: iter:: once;
33use std:: num:: NonZero ;
44use std:: sync:: Arc ;
5- use tokio:: task:: JoinHandle ;
5+ use tokio:: runtime:: Handle ;
6+ use tokio:: task:: JoinSet ;
67
78use anyhow:: { Context , Result } ;
89use nuts_storable:: { ItemType , Value } ;
@@ -43,8 +44,9 @@ pub struct ZarrAsyncChainStorage {
4344 arrays : Arc < ArrayCollection > ,
4445 chain : u64 ,
4546 last_sample_was_warmup : bool ,
46- pending_writes : Vec < JoinHandle < Result < ( ) > > > ,
47+ pending_writes : Arc < tokio :: sync :: Mutex < JoinSet < Result < ( ) > > > > ,
4748 rt_handle : tokio:: runtime:: Handle ,
49+ max_queued_writes : usize ,
4850}
4951
5052/// Write a chunk of data to a Zarr array asynchronously
@@ -240,22 +242,26 @@ impl ZarrAsyncChainStorage {
240242 chain : u64 ,
241243 rt_handle : tokio:: runtime:: Handle ,
242244 ) -> Self {
243- let draw_buffers = draw_types
245+ let draw_buffers: HashMap < String , SampleBuffer > = draw_types
244246 . iter ( )
245247 . map ( |( name, item_type) | ( name. clone ( ) , SampleBuffer :: new ( * item_type, buffer_size) ) )
246248 . collect ( ) ;
247249
248- let stats_buffers = param_types
250+ let stats_buffers: HashMap < String , SampleBuffer > = param_types
249251 . iter ( )
250252 . map ( |( name, item_type) | ( name. clone ( ) , SampleBuffer :: new ( * item_type, buffer_size) ) )
251253 . collect ( ) ;
254+
255+ let num_arrays = draw_buffers. len ( ) + stats_buffers. len ( ) ;
256+
252257 Self {
253258 draw_buffers,
254259 stats_buffers,
255260 arrays,
256261 chain,
257262 last_sample_was_warmup : true ,
258- pending_writes : Vec :: new ( ) ,
263+ pending_writes : Arc :: new ( tokio:: sync:: Mutex :: new ( JoinSet :: new ( ) ) ) ,
264+ max_queued_writes : 2 * ( num_arrays. max ( 1 ) ) ,
259265 rt_handle,
260266 }
261267 }
@@ -275,10 +281,15 @@ impl ZarrAsyncChainStorage {
275281 self . arrays . sample_param_arrays [ name] . clone ( )
276282 } ;
277283 let chain = self . chain ;
278- let handle = self
279- . rt_handle
280- . spawn ( async move { store_zarr_chunk_async ( array, chunk, chain) . await } ) ;
281- self . pending_writes . push ( handle) ;
284+
285+ queue_write (
286+ & self . rt_handle ,
287+ self . pending_writes . clone ( ) ,
288+ self . max_queued_writes ,
289+ array,
290+ chunk,
291+ chain,
292+ ) ?;
282293 }
283294 Ok ( ( ) )
284295 }
@@ -298,15 +309,52 @@ impl ZarrAsyncChainStorage {
298309 self . arrays . sample_draw_arrays [ name] . clone ( )
299310 } ;
300311 let chain = self . chain ;
301- let handle = self
302- . rt_handle
303- . spawn ( async move { store_zarr_chunk_async ( array, chunk, chain) . await } ) ;
304- self . pending_writes . push ( handle) ;
312+
313+ queue_write (
314+ & self . rt_handle ,
315+ self . pending_writes . clone ( ) ,
316+ self . max_queued_writes ,
317+ array,
318+ chunk,
319+ chain,
320+ ) ?;
305321 }
306322 Ok ( ( ) )
307323 }
308324}
309325
326+ fn queue_write (
327+ handle : & Handle ,
328+ queue : Arc < tokio:: sync:: Mutex < JoinSet < Result < ( ) > > > > ,
329+ max_queued_writes : usize ,
330+ array : Array ,
331+ chunk : Chunk ,
332+ chain : u64 ,
333+ ) -> Result < ( ) > {
334+ let rt_handle = handle. clone ( ) ;
335+ let spawn_write_task = handle. spawn ( async move {
336+ let mut writes_guard = queue. lock ( ) . await ;
337+
338+ while writes_guard. len ( ) >= max_queued_writes {
339+ let out = writes_guard. join_next ( ) . await ;
340+ if let Some ( out) = out {
341+ out. context ( "Failed to await previous trace write operation" ) ?
342+ . context ( "Chunk write operation failed" ) ?;
343+ } else {
344+ break ;
345+ }
346+ }
347+ writes_guard. spawn_on (
348+ async move { store_zarr_chunk_async ( array, chunk, chain) . await } ,
349+ & rt_handle,
350+ ) ;
351+ Ok ( ( ) )
352+ } ) ;
353+ let res: Result < ( ) > = handle. block_on ( spawn_write_task) ?;
354+ res?;
355+ Ok ( ( ) )
356+ }
357+
310358impl ChainStorage for ZarrAsyncChainStorage {
311359 type Finalized = ( ) ;
312360
@@ -323,20 +371,30 @@ impl ChainStorage for ZarrAsyncChainStorage {
323371 if let Some ( chunk) = buffer. reset ( ) {
324372 let array = self . arrays . warmup_draw_arrays [ key] . clone ( ) ;
325373 let chain = self . chain ;
326- let handle = self
327- . rt_handle
328- . spawn ( async move { store_zarr_chunk_async ( array, chunk, chain) . await } ) ;
329- self . pending_writes . push ( handle) ;
374+
375+ queue_write (
376+ & self . rt_handle ,
377+ self . pending_writes . clone ( ) ,
378+ self . max_queued_writes ,
379+ array,
380+ chunk,
381+ chain,
382+ ) ?;
330383 }
331384 }
332385 for ( key, buffer) in self . stats_buffers . iter_mut ( ) {
333386 if let Some ( chunk) = buffer. reset ( ) {
334387 let array = self . arrays . warmup_param_arrays [ key] . clone ( ) ;
335388 let chain = self . chain ;
336- let handle = self
337- . rt_handle
338- . spawn ( async move { store_zarr_chunk_async ( array, chunk, chain) . await } ) ;
339- self . pending_writes . push ( handle) ;
389+
390+ queue_write (
391+ & self . rt_handle ,
392+ self . pending_writes . clone ( ) ,
393+ self . max_queued_writes ,
394+ array,
395+ chunk,
396+ chain,
397+ ) ?;
340398 }
341399 }
342400 self . last_sample_was_warmup = false ;
@@ -382,11 +440,12 @@ impl ChainStorage for ZarrAsyncChainStorage {
382440 }
383441
384442 // Join all pending writes
443+ let pending_writes = Arc :: into_inner ( self . pending_writes )
444+ . expect ( "Could not take ownership of pending writes queue" )
445+ . into_inner ( ) ;
385446 self . rt_handle . block_on ( async move {
386- for join_handle in self . pending_writes {
387- let _ = join_handle
388- . await
389- . context ( "Failed to await async chunk write operation" ) ?;
447+ for join_handle in pending_writes. join_all ( ) . await {
448+ let _ = join_handle. context ( "Failed to await async chunk write operation" ) ?;
390449 }
391450 Ok :: < ( ) , anyhow:: Error > ( ( ) )
392451 } ) ?;
@@ -420,6 +479,19 @@ impl ChainStorage for ZarrAsyncChainStorage {
420479 }
421480 }
422481
482+ // Join all pending writes
483+ let pending_writes = self . pending_writes . clone ( ) ;
484+ self . rt_handle . block_on ( async move {
485+ let mut pending_writes = pending_writes. lock ( ) . await ;
486+ loop {
487+ let Some ( join_handle) = pending_writes. join_next ( ) . await else {
488+ break ;
489+ } ;
490+ let _ = join_handle. context ( "Failed to await async chunk write operation" ) ?;
491+ }
492+ Ok :: < ( ) , anyhow:: Error > ( ( ) )
493+ } ) ?;
494+
423495 Ok ( ( ) )
424496 }
425497}
0 commit comments