88//! * Efficient handling of distinct partition values
99//! * Automatic management of partition streams and channels
1010
11- use std:: {
12- collections:: { HashSet , VecDeque } ,
13- hash:: Hash ,
14- pin:: Pin ,
15- task:: { Context , Poll } ,
16- thread:: available_parallelism,
17- } ;
11+ use std:: { collections:: HashSet , hash:: Hash } ;
1812
1913use arrow:: {
2014 array:: {
@@ -29,136 +23,12 @@ use arrow::{
2923 error:: ArrowError ,
3024 record_batch:: RecordBatch ,
3125} ;
32- use futures:: {
33- channel:: mpsc:: { channel, Receiver , Sender } ,
34- Stream ,
35- } ;
3626use itertools:: { iproduct, Itertools } ;
3727
3828use iceberg_rust_spec:: { partition:: BoundPartitionField , spec:: values:: Value } ;
39- use lru:: LruCache ;
40- use pin_project_lite:: pin_project;
41-
42- use crate :: error:: Error ;
4329
4430use super :: transform:: transform_arrow;
4531
46- type RecordBatchSender = Sender < Result < RecordBatch , ArrowError > > ;
47- type RecordBatchReceiver = Receiver < Result < RecordBatch , ArrowError > > ;
48-
49- pin_project ! {
50- /// A stream that partitions Arrow record batches according to partition field specifications
51- ///
52- /// This struct implements Stream to process record batches asynchronously, splitting them into
53- /// separate streams based on partition values. It maintains internal state to:
54- /// * Track active partition streams
55- /// * Buffer pending record batches
56- /// * Manage channel senders/receivers for each partition
57- ///
58- /// # Type Parameters
59- /// * `'a` - Lifetime of the partition field specifications
60- pub ( crate ) struct PartitionStream <' a> {
61- #[ pin]
62- record_batches: Pin <Box <dyn Stream <Item = Result <RecordBatch , ArrowError >> + Send >>,
63- partition_fields: & ' a [ BoundPartitionField <' a>] ,
64- partition_streams: LruCache <Vec <Value >, RecordBatchSender >,
65- queue: VecDeque <Result <( Vec <Value >, RecordBatchReceiver ) , Error >>,
66- sends: Vec <( RecordBatchSender , RecordBatch ) >,
67- }
68- }
69-
70- impl < ' a > PartitionStream < ' a > {
71- pub ( crate ) fn new (
72- record_batches : Pin < Box < dyn Stream < Item = Result < RecordBatch , ArrowError > > + Send > > ,
73- partition_fields : & ' a [ BoundPartitionField < ' a > ] ,
74- ) -> Self {
75- Self {
76- record_batches,
77- partition_fields,
78- partition_streams : LruCache :: unbounded ( ) ,
79- queue : VecDeque :: new ( ) ,
80- sends : Vec :: new ( ) ,
81- }
82- }
83- }
84-
85- impl Stream for PartitionStream < ' _ > {
86- type Item = Result < ( Vec < Value > , RecordBatchReceiver ) , Error > ;
87- fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
88- let mut this = self . project ( ) ;
89-
90- loop {
91- if let Some ( result) = this. queue . pop_front ( ) {
92- break Poll :: Ready ( Some ( result) ) ;
93- }
94-
95- if !this. sends . is_empty ( ) {
96- let mut new_sends = Vec :: with_capacity ( this. sends . len ( ) ) ;
97- while let Some ( ( mut sender, batch) ) = this. sends . pop ( ) {
98- match sender. poll_ready ( cx) {
99- Poll :: Pending => {
100- new_sends. push ( ( sender, batch) ) ;
101- }
102- Poll :: Ready ( Err ( err) ) => return Poll :: Ready ( Some ( Err ( err. into ( ) ) ) ) ,
103- Poll :: Ready ( Ok ( ( ) ) ) => {
104- sender. start_send ( Ok ( batch) ) ?;
105- }
106- }
107- }
108- * this. sends = new_sends;
109-
110- if !this. sends . is_empty ( ) {
111- break Poll :: Pending ;
112- }
113- }
114-
115- // Limit the number of open partition_streams by available parallelism
116- if this. partition_streams . len ( ) > available_parallelism ( ) . unwrap ( ) . get ( ) {
117- if let Some ( ( _, mut sender) ) = this. partition_streams . pop_lru ( ) {
118- sender. close_channel ( ) ;
119- }
120- }
121-
122- match this. record_batches . as_mut ( ) . poll_next ( cx) {
123- Poll :: Pending => {
124- break Poll :: Pending ;
125- }
126- Poll :: Ready ( None ) => {
127- while let Some ( ( _, sender) ) = this. partition_streams . pop_lru ( ) {
128- if !sender. is_closed ( ) {
129- sender. clone ( ) . close_channel ( ) ;
130- }
131- }
132- break Poll :: Ready ( None ) ;
133- }
134- Poll :: Ready ( Some ( Err ( err) ) ) => {
135- break Poll :: Ready ( Some ( Err ( err. into ( ) ) ) ) ;
136- }
137- Poll :: Ready ( Some ( Ok ( batch) ) ) => {
138- for result in partition_record_batch ( & batch, this. partition_fields ) ? {
139- let ( partition_values, batch) = result?;
140-
141- let sender = if let Some ( sender) =
142- this. partition_streams . get ( & partition_values) . cloned ( )
143- {
144- sender
145- } else {
146- let ( sender, reciever) = channel ( 1 ) ;
147- this. queue
148- . push_back ( Ok ( ( partition_values. clone ( ) , reciever) ) ) ;
149- this. partition_streams
150- . push ( partition_values, sender. clone ( ) ) ;
151- sender
152- } ;
153-
154- this. sends . push ( ( sender, batch) ) ;
155- }
156- }
157- }
158- }
159- }
160- }
161-
16232/// Partitions a record batch according to the given partition fields.
16333///
16434/// This function takes a record batch and partition field specifications, then splits the batch into
@@ -178,7 +48,7 @@ impl Stream for PartitionStream<'_> {
17848/// * Required columns are missing from the record batch
17949/// * Transformation operations fail
18050/// * Data type conversions fail
181- fn partition_record_batch < ' a > (
51+ pub ( crate ) fn partition_record_batch < ' a > (
18252 record_batch : & ' a RecordBatch ,
18353 partition_fields : & [ BoundPartitionField < ' _ > ] ,
18454) -> Result < impl Iterator < Item = Result < ( Vec < Value > , RecordBatch ) , ArrowError > > + ' a , ArrowError > {
@@ -369,122 +239,3 @@ enum DistinctValues {
369239 Long ( HashSet < i64 > ) ,
370240 String ( HashSet < String > ) ,
371241}
372-
#[cfg(test)]
mod tests {
    use futures::{stream, StreamExt};
    use std::sync::Arc;
    use tokio::task::JoinSet;

    use arrow::{
        array::{ArrayRef, Int64Array, StringArray},
        error::ArrowError,
        record_batch::RecordBatch,
    };

    use iceberg_rust_spec::{
        partition::BoundPartitionField,
        spec::{
            partition::{PartitionField, PartitionSpec, Transform},
            schema::Schema,
            types::{PrimitiveType, StructField, Type},
        },
    };

    use crate::{arrow::partition::PartitionStream, error::Error};

    /// Builds a three-column (x: i64, y: i64, z: utf8) record batch.
    fn make_batch(x: Vec<i64>, y: Vec<i64>, z: Vec<&str>) -> RecordBatch {
        RecordBatch::try_from_iter(vec![
            ("x", Arc::new(Int64Array::from(x)) as ArrayRef),
            ("y", Arc::new(Int64Array::from(y)) as ArrayRef),
            ("z", Arc::new(StringArray::from(z)) as ArrayRef),
        ])
        .unwrap()
    }

    #[tokio::test]
    async fn test_partition() {
        let first = make_batch(
            vec![1, 1, 1, 1, 2, 3],
            vec![1, 2, 2, 2, 1, 1],
            vec!["A", "B", "C", "D", "E", "F"],
        );
        let second = make_batch(
            vec![1, 1, 2, 2, 2, 3],
            vec![1, 2, 2, 2, 1, 1],
            vec!["A", "B", "C", "D", "E", "F"],
        );
        let record_batches = stream::iter(
            vec![Ok::<_, ArrowError>(first), Ok::<_, ArrowError>(second)].into_iter(),
        );

        // Schema with three required primitive columns.
        let schema = Schema::builder()
            .with_schema_id(0)
            .with_struct_field(StructField {
                id: 1,
                name: "x".to_string(),
                field_type: Type::Primitive(PrimitiveType::Int),
                required: true,
                doc: None,
            })
            .with_struct_field(StructField {
                id: 2,
                name: "y".to_string(),
                field_type: Type::Primitive(PrimitiveType::Int),
                required: true,
                doc: None,
            })
            .with_struct_field(StructField {
                id: 3,
                name: "z".to_string(),
                field_type: Type::Primitive(PrimitiveType::String),
                required: true,
                doc: None,
            })
            .build()
            .unwrap();

        // Partition by the identity transform of column "x".
        let partition_spec = PartitionSpec::builder()
            .with_partition_field(PartitionField::new(1, 1001, "x", Transform::Identity))
            .build()
            .unwrap();

        // Bind each partition field to its schema field.
        let partition_fields = partition_spec
            .fields()
            .iter()
            .map(|partition_field| {
                let field = schema
                    .get(*partition_field.source_id() as usize)
                    .ok_or(Error::NotFound(format!(
                        "Schema field with id {}",
                        partition_field.source_id(),
                    )))?;
                Ok(BoundPartitionField::new(partition_field, field))
            })
            .collect::<Result<Vec<_>, Error>>()
            .unwrap();

        // Drain the partition stream, collecting each partition concurrently.
        let mut streams = PartitionStream::new(Box::pin(record_batches), &partition_fields);
        let mut set = JoinSet::new();
        while let Some(Ok((_, receiver))) = streams.next().await {
            set.spawn(async move { receiver.collect::<Vec<_>>().await });
        }

        // Every batch in every partition must have been delivered intact.
        for partition in set.join_all().await {
            for batch in partition {
                batch.unwrap();
            }
        }
    }
}
0 commit comments