Skip to content

Commit 8d593aa

Browse files
author
Jan Kaul
committed
directly spawn consumers
1 parent 414a7f4 commit 8d593aa

File tree

3 files changed

+69
-294
lines changed

3 files changed

+69
-294
lines changed

datafusion_iceberg/tests/insert_csv.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,4 @@ async fn test_insert_csv() {
163163
}
164164

165165
assert!(once);
166-
panic!();
167166
}
168-

iceberg-rust/src/arrow/partition.rs

Lines changed: 2 additions & 251 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,7 @@
88
//! * Efficient handling of distinct partition values
99
//! * Automatic management of partition streams and channels
1010
11-
use std::{
12-
collections::{HashSet, VecDeque},
13-
hash::Hash,
14-
pin::Pin,
15-
task::{Context, Poll},
16-
thread::available_parallelism,
17-
};
11+
use std::{collections::HashSet, hash::Hash};
1812

1913
use arrow::{
2014
array::{
@@ -29,136 +23,12 @@ use arrow::{
2923
error::ArrowError,
3024
record_batch::RecordBatch,
3125
};
32-
use futures::{
33-
channel::mpsc::{channel, Receiver, Sender},
34-
Stream,
35-
};
3626
use itertools::{iproduct, Itertools};
3727

3828
use iceberg_rust_spec::{partition::BoundPartitionField, spec::values::Value};
39-
use lru::LruCache;
40-
use pin_project_lite::pin_project;
41-
42-
use crate::error::Error;
4329

4430
use super::transform::transform_arrow;
4531

46-
type RecordBatchSender = Sender<Result<RecordBatch, ArrowError>>;
47-
type RecordBatchReceiver = Receiver<Result<RecordBatch, ArrowError>>;
48-
49-
pin_project! {
50-
/// A stream that partitions Arrow record batches according to partition field specifications
51-
///
52-
/// This struct implements Stream to process record batches asynchronously, splitting them into
53-
/// separate streams based on partition values. It maintains internal state to:
54-
/// * Track active partition streams
55-
/// * Buffer pending record batches
56-
/// * Manage channel senders/receivers for each partition
57-
///
58-
/// # Type Parameters
59-
/// * `'a` - Lifetime of the partition field specifications
60-
pub(crate) struct PartitionStream<'a> {
61-
#[pin]
62-
record_batches: Pin<Box<dyn Stream<Item = Result<RecordBatch, ArrowError>> + Send>>,
63-
partition_fields: &'a [BoundPartitionField<'a>],
64-
partition_streams: LruCache<Vec<Value>, RecordBatchSender>,
65-
queue: VecDeque<Result<(Vec<Value>, RecordBatchReceiver), Error>>,
66-
sends: Vec<(RecordBatchSender, RecordBatch)>,
67-
}
68-
}
69-
70-
impl<'a> PartitionStream<'a> {
71-
pub(crate) fn new(
72-
record_batches: Pin<Box<dyn Stream<Item = Result<RecordBatch, ArrowError>> + Send>>,
73-
partition_fields: &'a [BoundPartitionField<'a>],
74-
) -> Self {
75-
Self {
76-
record_batches,
77-
partition_fields,
78-
partition_streams: LruCache::unbounded(),
79-
queue: VecDeque::new(),
80-
sends: Vec::new(),
81-
}
82-
}
83-
}
84-
85-
impl Stream for PartitionStream<'_> {
86-
type Item = Result<(Vec<Value>, RecordBatchReceiver), Error>;
87-
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
88-
let mut this = self.project();
89-
90-
loop {
91-
if let Some(result) = this.queue.pop_front() {
92-
break Poll::Ready(Some(result));
93-
}
94-
95-
if !this.sends.is_empty() {
96-
let mut new_sends = Vec::with_capacity(this.sends.len());
97-
while let Some((mut sender, batch)) = this.sends.pop() {
98-
match sender.poll_ready(cx) {
99-
Poll::Pending => {
100-
new_sends.push((sender, batch));
101-
}
102-
Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
103-
Poll::Ready(Ok(())) => {
104-
sender.start_send(Ok(batch))?;
105-
}
106-
}
107-
}
108-
*this.sends = new_sends;
109-
110-
if !this.sends.is_empty() {
111-
break Poll::Pending;
112-
}
113-
}
114-
115-
// Limit the number of open partition_streams by available parallelism
116-
if this.partition_streams.len() > available_parallelism().unwrap().get() {
117-
if let Some((_, mut sender)) = this.partition_streams.pop_lru() {
118-
sender.close_channel();
119-
}
120-
}
121-
122-
match this.record_batches.as_mut().poll_next(cx) {
123-
Poll::Pending => {
124-
break Poll::Pending;
125-
}
126-
Poll::Ready(None) => {
127-
while let Some((_, sender)) = this.partition_streams.pop_lru() {
128-
if !sender.is_closed() {
129-
sender.clone().close_channel();
130-
}
131-
}
132-
break Poll::Ready(None);
133-
}
134-
Poll::Ready(Some(Err(err))) => {
135-
break Poll::Ready(Some(Err(err.into())));
136-
}
137-
Poll::Ready(Some(Ok(batch))) => {
138-
for result in partition_record_batch(&batch, this.partition_fields)? {
139-
let (partition_values, batch) = result?;
140-
141-
let sender = if let Some(sender) =
142-
this.partition_streams.get(&partition_values).cloned()
143-
{
144-
sender
145-
} else {
146-
let (sender, reciever) = channel(1);
147-
this.queue
148-
.push_back(Ok((partition_values.clone(), reciever)));
149-
this.partition_streams
150-
.push(partition_values, sender.clone());
151-
sender
152-
};
153-
154-
this.sends.push((sender, batch));
155-
}
156-
}
157-
}
158-
}
159-
}
160-
}
161-
16232
/// Partitions a record batch according to the given partition fields.
16333
///
16434
/// This function takes a record batch and partition field specifications, then splits the batch into
@@ -178,7 +48,7 @@ impl Stream for PartitionStream<'_> {
17848
/// * Required columns are missing from the record batch
17949
/// * Transformation operations fail
18050
/// * Data type conversions fail
181-
fn partition_record_batch<'a>(
51+
pub(crate) fn partition_record_batch<'a>(
18252
record_batch: &'a RecordBatch,
18353
partition_fields: &[BoundPartitionField<'_>],
18454
) -> Result<impl Iterator<Item = Result<(Vec<Value>, RecordBatch), ArrowError>> + 'a, ArrowError> {
@@ -369,122 +239,3 @@ enum DistinctValues {
369239
Long(HashSet<i64>),
370240
String(HashSet<String>),
371241
}
372-
373-
#[cfg(test)]
374-
mod tests {
375-
use futures::{stream, StreamExt};
376-
use std::sync::Arc;
377-
use tokio::task::JoinSet;
378-
379-
use arrow::{
380-
array::{ArrayRef, Int64Array, StringArray},
381-
error::ArrowError,
382-
record_batch::RecordBatch,
383-
};
384-
385-
use iceberg_rust_spec::{
386-
partition::BoundPartitionField,
387-
spec::{
388-
partition::{PartitionField, PartitionSpec, Transform},
389-
schema::Schema,
390-
types::{PrimitiveType, StructField, Type},
391-
},
392-
};
393-
394-
use crate::{arrow::partition::PartitionStream, error::Error};
395-
396-
#[tokio::test]
397-
async fn test_partition() {
398-
let batch1 = RecordBatch::try_from_iter(vec![
399-
(
400-
"x",
401-
Arc::new(Int64Array::from(vec![1, 1, 1, 1, 2, 3])) as ArrayRef,
402-
),
403-
(
404-
"y",
405-
Arc::new(Int64Array::from(vec![1, 2, 2, 2, 1, 1])) as ArrayRef,
406-
),
407-
(
408-
"z",
409-
Arc::new(StringArray::from(vec!["A", "B", "C", "D", "E", "F"])) as ArrayRef,
410-
),
411-
])
412-
.unwrap();
413-
let batch2 = RecordBatch::try_from_iter(vec![
414-
(
415-
"x",
416-
Arc::new(Int64Array::from(vec![1, 1, 2, 2, 2, 3])) as ArrayRef,
417-
),
418-
(
419-
"y",
420-
Arc::new(Int64Array::from(vec![1, 2, 2, 2, 1, 1])) as ArrayRef,
421-
),
422-
(
423-
"z",
424-
Arc::new(StringArray::from(vec!["A", "B", "C", "D", "E", "F"])) as ArrayRef,
425-
),
426-
])
427-
.unwrap();
428-
let record_batches = stream::iter(
429-
vec![Ok::<_, ArrowError>(batch1), Ok::<_, ArrowError>(batch2)].into_iter(),
430-
);
431-
432-
let schema = Schema::builder()
433-
.with_schema_id(0)
434-
.with_struct_field(StructField {
435-
id: 1,
436-
name: "x".to_string(),
437-
field_type: Type::Primitive(PrimitiveType::Int),
438-
required: true,
439-
doc: None,
440-
})
441-
.with_struct_field(StructField {
442-
id: 2,
443-
name: "y".to_string(),
444-
field_type: Type::Primitive(PrimitiveType::Int),
445-
required: true,
446-
doc: None,
447-
})
448-
.with_struct_field(StructField {
449-
id: 3,
450-
name: "z".to_string(),
451-
field_type: Type::Primitive(PrimitiveType::String),
452-
required: true,
453-
doc: None,
454-
})
455-
.build()
456-
.unwrap();
457-
458-
let partition_spec = PartitionSpec::builder()
459-
.with_partition_field(PartitionField::new(1, 1001, "x", Transform::Identity))
460-
.build()
461-
.unwrap();
462-
let partition_fields = partition_spec
463-
.fields()
464-
.iter()
465-
.map(|partition_field| {
466-
let field =
467-
schema
468-
.get(*partition_field.source_id() as usize)
469-
.ok_or(Error::NotFound(format!(
470-
"Schema field with id {}",
471-
partition_field.source_id(),
472-
)))?;
473-
Ok(BoundPartitionField::new(partition_field, field))
474-
})
475-
.collect::<Result<Vec<_>, Error>>()
476-
.unwrap();
477-
let mut streams = PartitionStream::new(Box::pin(record_batches), &partition_fields);
478-
let mut set = JoinSet::new();
479-
while let Some(Ok((_, receiver))) = streams.next().await {
480-
set.spawn(async move { receiver.collect::<Vec<_>>().await });
481-
}
482-
let output = set.join_all().await;
483-
484-
for x in output {
485-
for y in x {
486-
y.unwrap();
487-
}
488-
}
489-
}
490-
}

0 commit comments

Comments
 (0)