@@ -47,7 +47,15 @@ use async_trait::async_trait;
4747use futures:: { ready, Stream , StreamExt , TryStreamExt } ;
4848
4949/// Data of the left side
50- type JoinLeftData = ( RecordBatch , MemoryReservation ) ;
50+ #[ derive( Debug ) ]
51+ struct JoinLeftData {
52+ /// Single RecordBatch with all rows from the left side (built by `concat_batches` in `load_left_input`)
53+ merged_batch : RecordBatch ,
54+ /// Memory reservation tracking merged_batch. The field is never read (hence the `allow(dead_code)`
55+ /// below); it relies on drop semantics to release the reservation when JoinLeftData is dropped.
56+ #[ allow( dead_code) ]
57+ reservation : MemoryReservation ,
58+ }
5159
5260#[ allow( rustdoc:: private_intra_doc_links) ]
5361/// executes partitions in parallel and combines them into a set of
@@ -185,7 +193,10 @@ async fn load_left_input(
185193
186194 let merged_batch = concat_batches ( & left_schema, & batches) ?;
187195
188- Ok ( ( merged_batch, reservation) )
196+ Ok ( JoinLeftData {
197+ merged_batch,
198+ reservation,
199+ } )
189200}
190201
191202impl DisplayAs for CrossJoinExec {
@@ -357,7 +368,7 @@ struct CrossJoinStream<T> {
357368 join_metrics : BuildProbeJoinMetrics ,
358369 /// State of the stream
359370 state : CrossJoinStreamState ,
360- /// Left data
371+ /// Left data (copy of the entire buffered left side)
361372 left_data : RecordBatch ,
362373 /// Batch transformer
363374 batch_transformer : T ,
@@ -457,16 +468,17 @@ impl<T: BatchTransformer> CrossJoinStream<T> {
457468 cx : & mut std:: task:: Context < ' _ > ,
458469 ) -> Poll < Result < StatefulStreamResult < Option < RecordBatch > > > > {
459470 let build_timer = self . join_metrics . build_time . timer ( ) ;
460- let ( left_data, _ ) = match ready ! ( self . left_fut. get( cx) ) {
471+ let left_data = match ready ! ( self . left_fut. get( cx) ) {
461472 Ok ( left_data) => left_data,
462473 Err ( e) => return Poll :: Ready ( Err ( e) ) ,
463474 } ;
464475 build_timer. done ( ) ;
465476
477+ let left_data = left_data. merged_batch . clone ( ) ;
466478 let result = if left_data. num_rows ( ) == 0 {
467479 StatefulStreamResult :: Ready ( None )
468480 } else {
469- self . left_data = left_data. clone ( ) ;
481+ self . left_data = left_data;
470482 self . state = CrossJoinStreamState :: FetchProbeBatch ;
471483 StatefulStreamResult :: Continue
472484 } ;
0 commit comments