Skip to content

Commit 36d05b3

Browse files
committed
fix: Handle backpressure in async zarr storage
1 parent 4f0d50e commit 36d05b3

File tree

1 file changed

+97
-25
lines changed

1 file changed

+97
-25
lines changed

src/storage/zarr/async_impl.rs

Lines changed: 97 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ use std::collections::HashMap;
22
use std::iter::once;
33
use std::num::NonZero;
44
use std::sync::Arc;
5-
use tokio::task::JoinHandle;
5+
use tokio::runtime::Handle;
6+
use tokio::task::JoinSet;
67

78
use anyhow::{Context, Result};
89
use nuts_storable::{ItemType, Value};
@@ -43,8 +44,9 @@ pub struct ZarrAsyncChainStorage {
4344
arrays: Arc<ArrayCollection>,
4445
chain: u64,
4546
last_sample_was_warmup: bool,
46-
pending_writes: Vec<JoinHandle<Result<()>>>,
47+
pending_writes: Arc<tokio::sync::Mutex<JoinSet<Result<()>>>>,
4748
rt_handle: tokio::runtime::Handle,
49+
max_queued_writes: usize,
4850
}
4951

5052
/// Write a chunk of data to a Zarr array asynchronously
@@ -240,22 +242,26 @@ impl ZarrAsyncChainStorage {
240242
chain: u64,
241243
rt_handle: tokio::runtime::Handle,
242244
) -> Self {
243-
let draw_buffers = draw_types
245+
let draw_buffers: HashMap<String, SampleBuffer> = draw_types
244246
.iter()
245247
.map(|(name, item_type)| (name.clone(), SampleBuffer::new(*item_type, buffer_size)))
246248
.collect();
247249

248-
let stats_buffers = param_types
250+
let stats_buffers: HashMap<String, SampleBuffer> = param_types
249251
.iter()
250252
.map(|(name, item_type)| (name.clone(), SampleBuffer::new(*item_type, buffer_size)))
251253
.collect();
254+
255+
let num_arrays = draw_buffers.len() + stats_buffers.len();
256+
252257
Self {
253258
draw_buffers,
254259
stats_buffers,
255260
arrays,
256261
chain,
257262
last_sample_was_warmup: true,
258-
pending_writes: Vec::new(),
263+
pending_writes: Arc::new(tokio::sync::Mutex::new(JoinSet::new())),
264+
max_queued_writes: 2 * (num_arrays.max(1)),
259265
rt_handle,
260266
}
261267
}
@@ -275,10 +281,15 @@ impl ZarrAsyncChainStorage {
275281
self.arrays.sample_param_arrays[name].clone()
276282
};
277283
let chain = self.chain;
278-
let handle = self
279-
.rt_handle
280-
.spawn(async move { store_zarr_chunk_async(array, chunk, chain).await });
281-
self.pending_writes.push(handle);
284+
285+
queue_write(
286+
&self.rt_handle,
287+
self.pending_writes.clone(),
288+
self.max_queued_writes,
289+
array,
290+
chunk,
291+
chain,
292+
)?;
282293
}
283294
Ok(())
284295
}
@@ -298,15 +309,52 @@ impl ZarrAsyncChainStorage {
298309
self.arrays.sample_draw_arrays[name].clone()
299310
};
300311
let chain = self.chain;
301-
let handle = self
302-
.rt_handle
303-
.spawn(async move { store_zarr_chunk_async(array, chunk, chain).await });
304-
self.pending_writes.push(handle);
312+
313+
queue_write(
314+
&self.rt_handle,
315+
self.pending_writes.clone(),
316+
self.max_queued_writes,
317+
array,
318+
chunk,
319+
chain,
320+
)?;
305321
}
306322
Ok(())
307323
}
308324
}
309325

326+
fn queue_write(
327+
handle: &Handle,
328+
queue: Arc<tokio::sync::Mutex<JoinSet<Result<()>>>>,
329+
max_queued_writes: usize,
330+
array: Array,
331+
chunk: Chunk,
332+
chain: u64,
333+
) -> Result<()> {
334+
let rt_handle = handle.clone();
335+
let spawn_write_task = handle.spawn(async move {
336+
let mut writes_guard = queue.lock().await;
337+
338+
while writes_guard.len() >= max_queued_writes {
339+
let out = writes_guard.join_next().await;
340+
if let Some(out) = out {
341+
out.context("Failed to await previous trace write operation")?
342+
.context("Chunk write operation failed")?;
343+
} else {
344+
break;
345+
}
346+
}
347+
writes_guard.spawn_on(
348+
async move { store_zarr_chunk_async(array, chunk, chain).await },
349+
&rt_handle,
350+
);
351+
Ok(())
352+
});
353+
let res: Result<()> = handle.block_on(spawn_write_task)?;
354+
res?;
355+
Ok(())
356+
}
357+
310358
impl ChainStorage for ZarrAsyncChainStorage {
311359
type Finalized = ();
312360

@@ -323,20 +371,30 @@ impl ChainStorage for ZarrAsyncChainStorage {
323371
if let Some(chunk) = buffer.reset() {
324372
let array = self.arrays.warmup_draw_arrays[key].clone();
325373
let chain = self.chain;
326-
let handle = self
327-
.rt_handle
328-
.spawn(async move { store_zarr_chunk_async(array, chunk, chain).await });
329-
self.pending_writes.push(handle);
374+
375+
queue_write(
376+
&self.rt_handle,
377+
self.pending_writes.clone(),
378+
self.max_queued_writes,
379+
array,
380+
chunk,
381+
chain,
382+
)?;
330383
}
331384
}
332385
for (key, buffer) in self.stats_buffers.iter_mut() {
333386
if let Some(chunk) = buffer.reset() {
334387
let array = self.arrays.warmup_param_arrays[key].clone();
335388
let chain = self.chain;
336-
let handle = self
337-
.rt_handle
338-
.spawn(async move { store_zarr_chunk_async(array, chunk, chain).await });
339-
self.pending_writes.push(handle);
389+
390+
queue_write(
391+
&self.rt_handle,
392+
self.pending_writes.clone(),
393+
self.max_queued_writes,
394+
array,
395+
chunk,
396+
chain,
397+
)?;
340398
}
341399
}
342400
self.last_sample_was_warmup = false;
@@ -382,11 +440,12 @@ impl ChainStorage for ZarrAsyncChainStorage {
382440
}
383441

384442
// Join all pending writes
443+
let pending_writes = Arc::into_inner(self.pending_writes)
444+
.expect("Could not take ownership of pending writes queue")
445+
.into_inner();
385446
self.rt_handle.block_on(async move {
386-
for join_handle in self.pending_writes {
387-
let _ = join_handle
388-
.await
389-
.context("Failed to await async chunk write operation")?;
447+
for join_handle in pending_writes.join_all().await {
448+
let _ = join_handle.context("Failed to await async chunk write operation")?;
390449
}
391450
Ok::<(), anyhow::Error>(())
392451
})?;
@@ -420,6 +479,19 @@ impl ChainStorage for ZarrAsyncChainStorage {
420479
}
421480
}
422481

482+
// Join all pending writes
483+
let pending_writes = self.pending_writes.clone();
484+
self.rt_handle.block_on(async move {
485+
let mut pending_writes = pending_writes.lock().await;
486+
loop {
487+
let Some(join_handle) = pending_writes.join_next().await else {
488+
break;
489+
};
490+
let _ = join_handle.context("Failed to await async chunk write operation")?;
491+
}
492+
Ok::<(), anyhow::Error>(())
493+
})?;
494+
423495
Ok(())
424496
}
425497
}

0 commit comments

Comments
 (0)