Commit d0d1120

fix: Handle backpressure in async zarr storage

1 parent 4f0d50e commit d0d1120

2 files changed, +109 -26 lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@ nuts-derive = { path = "./nuts-derive", version = "0.1.0" }
 nuts-storable = { path = "./nuts-storable", version = "0.2.0" }
 serde = { version = "1.0.219", features = ["derive"] }
 serde_json = "1.0"
-tokio = { version = "1.0", features = ["rt"], optional = true }
+tokio = { version = "1.0", features = ["rt", "sync", "fs"], optional = true }

 [dev-dependencies]
 proptest = "1.6.0"

src/storage/zarr/async_impl.rs

Lines changed: 108 additions & 25 deletions
@@ -2,7 +2,8 @@ use std::collections::HashMap;
 use std::iter::once;
 use std::num::NonZero;
 use std::sync::Arc;
-use tokio::task::JoinHandle;
+use tokio::runtime::Handle;
+use tokio::task::JoinSet;

 use anyhow::{Context, Result};
 use nuts_storable::{ItemType, Value};
@@ -43,8 +44,9 @@ pub struct ZarrAsyncChainStorage {
     arrays: Arc<ArrayCollection>,
     chain: u64,
     last_sample_was_warmup: bool,
-    pending_writes: Vec<JoinHandle<Result<()>>>,
+    pending_writes: Arc<tokio::sync::Mutex<JoinSet<Result<()>>>>,
     rt_handle: tokio::runtime::Handle,
+    max_queued_writes: usize,
 }

 /// Write a chunk of data to a Zarr array asynchronously
@@ -240,22 +242,28 @@ impl ZarrAsyncChainStorage {
         chain: u64,
         rt_handle: tokio::runtime::Handle,
     ) -> Self {
-        let draw_buffers = draw_types
+        let draw_buffers: HashMap<String, SampleBuffer> = draw_types
             .iter()
             .map(|(name, item_type)| (name.clone(), SampleBuffer::new(*item_type, buffer_size)))
             .collect();

-        let stats_buffers = param_types
+        let stats_buffers: HashMap<String, SampleBuffer> = param_types
             .iter()
             .map(|(name, item_type)| (name.clone(), SampleBuffer::new(*item_type, buffer_size)))
             .collect();
+
+        let num_arrays = draw_buffers.len() + stats_buffers.len();
+
         Self {
             draw_buffers,
             stats_buffers,
             arrays,
             chain,
             last_sample_was_warmup: true,
-            pending_writes: Vec::new(),
+            pending_writes: Arc::new(tokio::sync::Mutex::new(JoinSet::new())),
+            // We allow up to the number of arrays in pending writes, so
+            // that we queue one write per draw.
+            max_queued_writes: num_arrays.max(1),
             rt_handle,
         }
     }
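
The cap set above means a single draw can flush every buffer without blocking, and only the next flush has to wait on earlier writes. A toy check of the formula, with variable counts made up purely for illustration:

fn main() {
    // Hypothetical model: 3 posterior variables plus 2 sampler stats.
    let num_arrays = 3 + 2;
    // Same formula as in the constructor above; `.max(1)` guards against
    // a zero-capacity queue when there are no arrays at all.
    let max_queued_writes = usize::max(num_arrays, 1);
    assert_eq!(max_queued_writes, 5);
    // A sixth concurrent write request would first have to wait for an
    // earlier chunk write to complete.
}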
@@ -275,10 +283,15 @@ impl ZarrAsyncChainStorage {
                 self.arrays.sample_param_arrays[name].clone()
             };
             let chain = self.chain;
-            let handle = self
-                .rt_handle
-                .spawn(async move { store_zarr_chunk_async(array, chunk, chain).await });
-            self.pending_writes.push(handle);
+
+            queue_write(
+                &self.rt_handle,
+                self.pending_writes.clone(),
+                self.max_queued_writes,
+                array,
+                chunk,
+                chain,
+            )?;
         }
         Ok(())
     }
@@ -298,15 +311,57 @@ impl ZarrAsyncChainStorage {
                 self.arrays.sample_draw_arrays[name].clone()
             };
             let chain = self.chain;
-            let handle = self
-                .rt_handle
-                .spawn(async move { store_zarr_chunk_async(array, chunk, chain).await });
-            self.pending_writes.push(handle);
+
+            queue_write(
+                &self.rt_handle,
+                self.pending_writes.clone(),
+                self.max_queued_writes,
+                array,
+                chunk,
+                chain,
+            )?;
         }
         Ok(())
     }
 }

+fn queue_write(
+    handle: &Handle,
+    queue: Arc<tokio::sync::Mutex<JoinSet<Result<()>>>>,
+    max_queued_writes: usize,
+    array: Array,
+    chunk: Chunk,
+    chain: u64,
+) -> Result<()> {
+    let rt_handle = handle.clone();
+    // We need an async task to interface with the async storage
+    // and JoinSet API.
+    let spawn_write_task = handle.spawn(async move {
+        // This should never actually block, because this lock
+        // is only held in tasks that are spawned and immediately blocked_on
+        // from the sampling thread.
+        let mut writes_guard = queue.lock().await;
+
+        while writes_guard.len() >= max_queued_writes {
+            let out = writes_guard.join_next().await;
+            if let Some(out) = out {
+                out.context("Failed to await previous trace write operation")?
+                    .context("Chunk write operation failed")?;
+            } else {
+                break;
+            }
+        }
+        writes_guard.spawn_on(
+            async move { store_zarr_chunk_async(array, chunk, chain).await },
+            &rt_handle,
+        );
+        Ok(())
+    });
+    let res: Result<()> = handle.block_on(spawn_write_task)?;
+    res?;
+    Ok(())
+}
+
 impl ChainStorage for ZarrAsyncChainStorage {
     type Finalized = ();

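The new queue_write is the heart of the fix: it hops onto the runtime, retires finished tasks until the JoinSet holds fewer than max_queued_writes entries, and only then spawns the next write, so a slow store stalls the sampling thread instead of letting writes pile up without bound. A minimal self-contained sketch of the same pattern, assuming tokio's multi-thread runtime (with the "rt-multi-thread" and "sync" features); the name submit_bounded and the job * 2 payload are illustrative stand-ins, not part of the commit:

use std::sync::Arc;
use tokio::runtime::{Handle, Runtime};
use tokio::sync::Mutex;
use tokio::task::JoinSet;

// Enqueue one job, but first wait until fewer than `cap` tasks are in flight.
fn submit_bounded(rt: &Handle, queue: Arc<Mutex<JoinSet<u64>>>, cap: usize, job: u64) {
    let rt_handle = rt.clone();
    // Hop onto the runtime to use the async Mutex and JoinSet APIs, and
    // block the calling thread until the job is actually enqueued. A full
    // queue therefore stalls the producer: that is the backpressure.
    rt.block_on(rt.spawn(async move {
        let mut set = queue.lock().await;
        // Retire finished tasks until there is room for one more.
        while set.len() >= cap {
            if set.join_next().await.is_none() {
                break;
            }
        }
        // Stand-in for `store_zarr_chunk_async`.
        set.spawn_on(async move { job * 2 }, &rt_handle);
    }))
    .expect("enqueue task panicked");
}

fn main() {
    let rt = Runtime::new().expect("failed to build tokio runtime");
    let queue = Arc::new(Mutex::new(JoinSet::new()));
    for job in 0..16 {
        // At most 4 simulated writes are ever in flight.
        submit_bounded(rt.handle(), queue.clone(), 4, job);
    }
    // Drain whatever is still queued at shutdown.
    rt.block_on(async {
        while queue.lock().await.join_next().await.is_some() {}
    });
}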
@@ -323,20 +378,30 @@ impl ChainStorage for ZarrAsyncChainStorage {
             if let Some(chunk) = buffer.reset() {
                 let array = self.arrays.warmup_draw_arrays[key].clone();
                 let chain = self.chain;
-                let handle = self
-                    .rt_handle
-                    .spawn(async move { store_zarr_chunk_async(array, chunk, chain).await });
-                self.pending_writes.push(handle);
+
+                queue_write(
+                    &self.rt_handle,
+                    self.pending_writes.clone(),
+                    self.max_queued_writes,
+                    array,
+                    chunk,
+                    chain,
+                )?;
             }
         }
         for (key, buffer) in self.stats_buffers.iter_mut() {
             if let Some(chunk) = buffer.reset() {
                 let array = self.arrays.warmup_param_arrays[key].clone();
                 let chain = self.chain;
-                let handle = self
-                    .rt_handle
-                    .spawn(async move { store_zarr_chunk_async(array, chunk, chain).await });
-                self.pending_writes.push(handle);
+
+                queue_write(
+                    &self.rt_handle,
+                    self.pending_writes.clone(),
+                    self.max_queued_writes,
+                    array,
+                    chunk,
+                    chain,
+                )?;
             }
         }
         self.last_sample_was_warmup = false;
@@ -382,11 +447,14 @@ impl ChainStorage for ZarrAsyncChainStorage {
         }

         // Join all pending writes
+        // All tasks that hold a reference to the queue are blocked_on
+        // right away, so we hold the only reference to `self.pending_writes`.
+        let pending_writes = Arc::into_inner(self.pending_writes)
+            .expect("Could not take ownership of pending writes queue")
+            .into_inner();
         self.rt_handle.block_on(async move {
-            for join_handle in self.pending_writes {
-                let _ = join_handle
-                    .await
-                    .context("Failed to await async chunk write operation")?;
+            for join_handle in pending_writes.join_all().await {
+                let _ = join_handle.context("Failed to await async chunk write operation")?;
             }
             Ok::<(), anyhow::Error>(())
         })?;
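
This drain site moves self.pending_writes out of the storage, and the new comment explains why that is safe: every task that cloned the Arc has already been blocked on, so exactly one owner remains. A reduced sketch of that ownership hand-off; the Ok::<(), String> payload stands in for the real write results:

use std::sync::Arc;
use tokio::runtime::Runtime;
use tokio::sync::Mutex;
use tokio::task::JoinSet;

fn main() {
    let rt = Runtime::new().expect("failed to build tokio runtime");
    let queue = Arc::new(Mutex::new(JoinSet::new()));

    rt.block_on(async {
        // Stand-in for a queued chunk write.
        queue.lock().await.spawn(async { Ok::<(), String>(()) });
    });

    // This is the last live clone, so the Arc and Mutex wrappers can be
    // peeled off without awaiting anything.
    let pending = Arc::into_inner(queue)
        .expect("queue is still shared")
        .into_inner();

    // `join_all` awaits every remaining task and collects the outputs.
    let results: Vec<Result<(), String>> = rt.block_on(pending.join_all());
    assert!(results.iter().all(|r| r.is_ok()));
}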
@@ -420,6 +488,21 @@ impl ChainStorage for ZarrAsyncChainStorage {
             }
         }

+        // Join all pending writes
+        let pending_writes = self.pending_writes.clone();
+        self.rt_handle.block_on(async move {
+            let mut pending_writes = pending_writes.lock().await;
+            loop {
+                let Some(join_handle) = pending_writes.join_next().await else {
+                    break;
+                };
+                join_handle
+                    .context("Failed to await async chunk write operation")?
+                    .context("Chunk write operation failed")?;
+            }
+            Ok::<(), anyhow::Error>(())
+        })?;
+
         Ok(())
     }
 }
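
The second drain site, by contrast, still works through a shared reference: it clones the Arc rather than consuming it, so it drains the queue through the lock with join_next instead of join_all. A reduced version of that loop; drain_in_place is an illustrative name, not an identifier from the commit:

use std::sync::Arc;

use anyhow::{Context, Result};
use tokio::runtime::{Handle, Runtime};
use tokio::sync::Mutex;
use tokio::task::JoinSet;

// Await every queued write without taking ownership of the queue.
fn drain_in_place(rt: &Handle, queue: &Arc<Mutex<JoinSet<Result<()>>>>) -> Result<()> {
    let queue = Arc::clone(queue);
    rt.block_on(async move {
        let mut pending = queue.lock().await;
        // `join_next` yields results as tasks finish; `None` means empty.
        while let Some(joined) = pending.join_next().await {
            joined
                .context("Failed to await async chunk write operation")?
                .context("Chunk write operation failed")?;
        }
        Ok(())
    })
}

fn main() -> Result<()> {
    let rt = Runtime::new()?;
    let queue = Arc::new(Mutex::new(JoinSet::new()));
    rt.block_on(async {
        // Stand-in for a queued chunk write.
        queue.lock().await.spawn(async { anyhow::Ok(()) });
    });
    drain_in_place(rt.handle(), &queue)
}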
