
Commit 0564bd7

feat: Add EngineDataArrowExt and use it everywhere (#1516)
## What changes are proposed in this pull request?

We had tons of places in the code where we would downcast `EngineData` into `ArrowEngineData` and then pull out the `RecordBatch`. In addition, trying to write a simple example that uses the default engine is ugly because user code needs the same downcast. So this adds a new trait which lets callers just write `try_into_record_batch()`, hiding the downcast ugliness.

Question to reviewers: in some places we can use this in either of two ways. Like:

```rust
scan_results.map(EngineDataArrowExt::try_into_record_batch).try_collect()
```

or like:

```rust
scan_results.map(|data| data?.try_into_record_batch()).try_collect()
```

Do you have a preference?

### This PR affects the following public APIs

Adds a new extension trait which provides a new public `try_into_record_batch()` method.

## How was this change tested?

Existing tests.
1 parent def7919 commit 0564bd7
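For context, the shape the new trait takes can be inferred from the call sites in the diffs below. The sketch here is an approximation, not the shipped definition: the real trait lives in `delta_kernel::engine::arrow_data`, and the impl for `DeltaResult<Box<dyn EngineData>>` in particular is an assumption based on the `.map(EngineDataArrowExt::try_into_record_batch)` call sites.

```rust
use delta_kernel::arrow::array::RecordBatch;
use delta_kernel::engine::arrow_data::ArrowEngineData;
use delta_kernel::{DeltaResult, EngineData, Error};

// Sketch only: the shipped trait lives in delta_kernel::engine::arrow_data.
pub trait EngineDataArrowExt {
    /// Downcast this engine data to `ArrowEngineData` and unwrap the inner
    /// `RecordBatch`, failing with `Error::EngineDataType` otherwise.
    fn try_into_record_batch(self) -> DeltaResult<RecordBatch>;
}

impl EngineDataArrowExt for Box<dyn EngineData> {
    fn try_into_record_batch(self) -> DeltaResult<RecordBatch> {
        Ok(self
            .into_any()
            .downcast::<ArrowEngineData>()
            .map_err(|_| Error::EngineDataType("ArrowEngineData".to_string()))?
            .into())
    }
}

// Assumed impl: call sites below invoke try_into_record_batch() directly on a
// DeltaResult, which suggests a forwarding impl for fallible values.
impl EngineDataArrowExt for DeltaResult<Box<dyn EngineData>> {
    fn try_into_record_batch(self) -> DeltaResult<RecordBatch> {
        self?.try_into_record_batch()
    }
}
```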

File tree

16 files changed: +88 −163 lines

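As a reading aid for the per-file diffs, here is a hedged sketch of the caller-side pattern the trait enables; `collect_batches` is a hypothetical helper, not code from this PR:

```rust
use delta_kernel::arrow::array::RecordBatch;
use delta_kernel::engine::arrow_data::EngineDataArrowExt;
use delta_kernel::{DeltaResult, EngineData};
use itertools::Itertools;

/// Hypothetical helper: collect a stream of engine data into RecordBatches
/// without any manual downcasting at the call site.
fn collect_batches(
    results: impl Iterator<Item = DeltaResult<Box<dyn EngineData>>>,
) -> DeltaResult<Vec<RecordBatch>> {
    results
        // equivalently: .map(|data| data?.try_into_record_batch())
        .map(EngineDataArrowExt::try_into_record_batch)
        .try_collect()
}
```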

acceptance/src/data.rs

Lines changed: 3 additions & 6 deletions
```diff
@@ -4,11 +4,12 @@ use delta_kernel::arrow::array::{Array, RecordBatch};
 use delta_kernel::arrow::compute::{concat_batches, lexsort_to_indices, take, SortColumn};
 use delta_kernel::arrow::datatypes::{DataType, Schema};
 
+use delta_kernel::engine::arrow_data::EngineDataArrowExt as _;
 use delta_kernel::parquet::arrow::async_reader::{
     ParquetObjectReader, ParquetRecordBatchStreamBuilder,
 };
 use delta_kernel::snapshot::Snapshot;
-use delta_kernel::{engine::arrow_data::ArrowEngineData, DeltaResult, Engine, Error};
+use delta_kernel::{DeltaResult, Engine, Error};
 use futures::{stream::TryStreamExt, StreamExt};
 use itertools::Itertools;
 use object_store::{local::LocalFileSystem, ObjectStore};
@@ -119,11 +120,7 @@ pub async fn assert_scan_metadata(
     let batches: Vec<RecordBatch> = scan
         .execute(engine)?
         .map(|data| -> DeltaResult<_> {
-            let record_batch: RecordBatch = data?
-                .into_any()
-                .downcast::<ArrowEngineData>()
-                .unwrap()
-                .into();
+            let record_batch = data?.try_into_record_batch()?;
             if schema.is_none() {
                 schema = Some(record_batch.schema());
             }
```

ffi/src/engine_data.rs

Lines changed: 2 additions & 7 deletions
```diff
@@ -1,5 +1,4 @@
 //! EngineData related ffi code
-
 #[cfg(feature = "default-engine-base")]
 use delta_kernel::arrow;
 #[cfg(feature = "default-engine-base")]
@@ -8,7 +7,7 @@ use delta_kernel::arrow::array::{
     ArrayData, RecordBatch, StructArray,
 };
 #[cfg(feature = "default-engine-base")]
-use delta_kernel::engine::arrow_data::ArrowEngineData;
+use delta_kernel::engine::arrow_data::{ArrowEngineData, EngineDataArrowExt as _};
 #[cfg(feature = "default-engine-base")]
 use delta_kernel::DeltaResult;
 use delta_kernel::EngineData;
@@ -96,11 +95,7 @@ pub unsafe extern "C" fn get_raw_arrow_data(
 // TODO: This method leaks the returned pointer memory. How will the engine free it?
 #[cfg(feature = "default-engine-base")]
 fn get_raw_arrow_data_impl(data: Box<dyn EngineData>) -> DeltaResult<*mut ArrowFFIData> {
-    let record_batch: delta_kernel::arrow::array::RecordBatch = data
-        .into_any()
-        .downcast::<ArrowEngineData>()
-        .map_err(|_| delta_kernel::Error::EngineDataType("ArrowEngineData".to_string()))?
-        .into();
+    let record_batch = data.try_into_record_batch()?;
     let sa: StructArray = record_batch.into();
     let array_data: ArrayData = sa.into();
     // these call `clone`. is there a way to not copy anything and what exactly are they cloning?
```
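The FFI path above leans on Arrow's C Data Interface: a `RecordBatch` converts losslessly into a `StructArray`, whose `ArrayData` is then exported via `arrow::ffi::to_ffi`. A rough sketch of that pipeline, assuming delta_kernel's re-exported `arrow` crate (`export_batch` is a hypothetical helper):

```rust
use delta_kernel::arrow::array::{ArrayData, RecordBatch, StructArray};
use delta_kernel::arrow::error::ArrowError;
use delta_kernel::arrow::ffi::{to_ffi, FFI_ArrowArray, FFI_ArrowSchema};

/// Hypothetical helper: export a RecordBatch through the Arrow C Data Interface.
fn export_batch(batch: RecordBatch) -> Result<(FFI_ArrowArray, FFI_ArrowSchema), ArrowError> {
    let sa: StructArray = batch.into(); // RecordBatch -> StructArray is lossless
    let array_data: ArrayData = sa.into(); // take ownership of the underlying buffers
    // to_ffi produces the C-ABI structs that a consumer must release per the spec
    to_ffi(&array_data)
}
```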

ffi/src/table_changes.rs

Lines changed: 7 additions & 10 deletions
```diff
@@ -3,9 +3,9 @@
 use std::sync::Arc;
 use std::sync::Mutex;
 
-use delta_kernel::arrow::array::{Array, ArrayData, RecordBatch, StructArray};
+use delta_kernel::arrow::array::{Array, ArrayData, StructArray};
 use delta_kernel::arrow::ffi::to_ffi;
-use delta_kernel::engine::arrow_data::ArrowEngineData;
+use delta_kernel::engine::arrow_data::EngineDataArrowExt;
 use delta_kernel::table_changes::scan::TableChangesScan;
 use delta_kernel::table_changes::TableChanges;
 use delta_kernel::EngineData;
@@ -319,11 +319,7 @@ fn scan_table_changes_next_impl(data: &ScanTableChangesIterator) -> DeltaResult<
         return Ok(ArrowFFIData::empty());
     };
 
-    let record_batch: RecordBatch = data
-        .into_any()
-        .downcast::<ArrowEngineData>()
-        .map_err(|_| delta_kernel::Error::EngineDataType("ArrowEngineData".to_string()))?
-        .into();
+    let record_batch = data.try_into_record_batch()?;
 
     let batch_struct_array: StructArray = record_batch.into();
     let array_data: ArrayData = batch_struct_array.into_data();
@@ -346,6 +342,7 @@ mod tests {
     use delta_kernel::arrow::record_batch::RecordBatch;
     use delta_kernel::arrow::util::pretty::pretty_format_batches;
     use delta_kernel::engine::arrow_conversion::TryIntoArrow as _;
+    use delta_kernel::engine::arrow_data::ArrowEngineData;
     use delta_kernel::engine::default::DefaultEngine;
     use delta_kernel::schema::{DataType, StructField, StructType};
     use delta_kernel::Engine;
@@ -355,7 +352,7 @@ mod tests {
     use std::sync::Arc;
     use test_utils::{
         actions_to_string_with_metadata, add_commit, generate_batch, record_batch_to_bytes,
-        to_arrow, IntoArray as _, TestAction,
+        IntoArray as _, TestAction,
     };
 
     const PARQUET_FILE1: &str =
@@ -480,7 +477,7 @@ mod tests {
     ) -> DeltaResult<Vec<RecordBatch>> {
         let scan_results = scan.execute(engine)?;
         scan_results
-            .map(|data| -> DeltaResult<_> { to_arrow(data?) })
+            .map(EngineDataArrowExt::try_into_record_batch)
            .try_collect()
     }
 
@@ -699,7 +696,7 @@ mod tests {
         }
        let engine_data =
            ok_or_panic(unsafe { get_engine_data(data.array, &data.schema, allocate_err) });
-        let record_batch = unsafe { to_arrow(engine_data.into_inner()) }?;
+        let record_batch = unsafe { engine_data.into_inner().try_into_record_batch() }?;
 
        println!("Batch ({i}) num rows {:?}", record_batch.num_rows());
        batches.push(record_batch);
```

kernel/examples/read-table-changes/src/main.rs

Lines changed: 2 additions & 9 deletions
```diff
@@ -4,7 +4,7 @@ use clap::Parser;
 use common::{LocationArgs, ParseWithExamples};
 use delta_kernel::arrow::array::RecordBatch;
 use delta_kernel::arrow::util::pretty::print_batches;
-use delta_kernel::engine::arrow_data::ArrowEngineData;
+use delta_kernel::engine::arrow_data::EngineDataArrowExt;
 use delta_kernel::table_changes::TableChanges;
 use delta_kernel::DeltaResult;
 use itertools::Itertools;
@@ -38,14 +38,7 @@ fn main() -> DeltaResult<()> {
     let table_changes_scan = table_changes.into_scan_builder().build()?;
     let batches: Vec<RecordBatch> = table_changes_scan
         .execute(Arc::new(engine))?
-        .map(|data| -> DeltaResult<_> {
-            let record_batch: RecordBatch = data?
-                .into_any()
-                .downcast::<ArrowEngineData>()
-                .map_err(|_| delta_kernel::Error::EngineDataType("ArrowEngineData".to_string()))?
-                .into();
-            Ok(record_batch)
-        })
+        .map(EngineDataArrowExt::try_into_record_batch)
         .try_collect()?;
     print_batches(&batches)?;
     Ok(())
```

kernel/examples/read-table-multi-threaded/src/main.rs

Lines changed: 3 additions & 12 deletions
```diff
@@ -9,10 +9,10 @@ use arrow::record_batch::RecordBatch;
 use arrow::util::pretty::print_batches;
 use common::{LocationArgs, ParseWithExamples, ScanArgs};
 use delta_kernel::actions::deletion_vector::split_vector;
-use delta_kernel::engine::arrow_data::ArrowEngineData;
+use delta_kernel::engine::arrow_data::EngineDataArrowExt as _;
 use delta_kernel::scan::state::{transform_to_logical, DvInfo, Stats};
 use delta_kernel::schema::SchemaRef;
-use delta_kernel::{DeltaResult, Engine, EngineData, ExpressionRef, FileMeta, Snapshot};
+use delta_kernel::{DeltaResult, Engine, ExpressionRef, FileMeta, Snapshot};
 
 use clap::Parser;
 use url::Url;
@@ -59,15 +59,6 @@ struct ScanFile {
     dv_info: DvInfo,
 }
 
-// we know we're using arrow under the hood, so cast an EngineData into something we can work with
-fn to_arrow(data: Box<dyn EngineData>) -> DeltaResult<RecordBatch> {
-    Ok(data
-        .into_any()
-        .downcast::<ArrowEngineData>()
-        .map_err(|_| delta_kernel::Error::EngineDataType("ArrowEngineData".to_string()))?
-        .into())
-}
-
 // This is the callback that will be called for each valid scan row
 fn send_scan_file(
     scan_tx: &mut spmc::Sender<ScanFile>,
@@ -231,7 +222,7 @@ fn do_work(
     )
     .unwrap();
 
-    let record_batch = to_arrow(logical).unwrap();
+    let record_batch = logical.try_into_record_batch().unwrap();
 
     // need to split the dv_mask. what's left in dv_mask covers this result, and rest
     // will cover the following results
```

kernel/examples/read-table-single-threaded/src/main.rs

Lines changed: 2 additions & 10 deletions
```diff
@@ -4,7 +4,7 @@ use std::sync::Arc;
 use arrow::record_batch::RecordBatch;
 use arrow::util::pretty::print_batches;
 use common::{LocationArgs, ParseWithExamples, ScanArgs};
-use delta_kernel::engine::arrow_data::ArrowEngineData;
+use delta_kernel::engine::arrow_data::EngineDataArrowExt;
 use delta_kernel::{DeltaResult, Snapshot};
 
 use clap::Parser;
@@ -49,15 +49,7 @@ fn try_main() -> DeltaResult<()> {
     let mut rows_so_far = 0;
     let batches: Vec<RecordBatch> = scan
         .execute(Arc::new(engine))?
-        .map(|data| -> DeltaResult<_> {
-            // extract the batches and filter them if they have deletion vectors
-            let record_batch: RecordBatch = data?
-                .into_any()
-                .downcast::<ArrowEngineData>()
-                .map_err(|_| delta_kernel::Error::EngineDataType("ArrowEngineData".to_string()))?
-                .into();
-            Ok(record_batch)
-        })
+        .map(EngineDataArrowExt::try_into_record_batch)
         .scan(&mut rows_so_far, |rows_so_far, record_batch| {
             // handle truncation if we've specified a limit
             let Ok(batch) = record_batch else {
```

kernel/examples/write-table/src/main.rs

Lines changed: 2 additions & 9 deletions
```diff
@@ -16,7 +16,7 @@ use uuid::Uuid;
 use delta_kernel::arrow::array::TimestampMicrosecondArray;
 use delta_kernel::committer::FileSystemCommitter;
 use delta_kernel::engine::arrow_conversion::TryIntoArrow;
-use delta_kernel::engine::arrow_data::ArrowEngineData;
+use delta_kernel::engine::arrow_data::{ArrowEngineData, EngineDataArrowExt};
 use delta_kernel::engine::default::executor::tokio::TokioBackgroundExecutor;
 use delta_kernel::engine::default::DefaultEngine;
 use delta_kernel::schema::{DataType, SchemaRef, StructField, StructType};
@@ -317,14 +317,7 @@ async fn read_and_display_data(
 
     let batches: Vec<RecordBatch> = scan
         .execute(Arc::new(engine))?
-        .map(|data| -> DeltaResult<_> {
-            let record_batch: RecordBatch = data?
-                .into_any()
-                .downcast::<ArrowEngineData>()
-                .map_err(|_| Error::EngineDataType("ArrowEngineData".to_string()))?
-                .into();
-            Ok(record_batch)
-        })
+        .map(EngineDataArrowExt::try_into_record_batch)
         .try_collect()?;
 
     print_batches(&batches)?;
```

kernel/src/actions/mod.rs

Lines changed: 21 additions & 60 deletions
```diff
@@ -1017,13 +1017,15 @@ impl DomainMetadata {
 mod tests {
     use super::*;
     use crate::{
-        arrow::array::{
-            Array, BooleanArray, Int32Array, Int64Array, ListArray, ListBuilder, MapBuilder,
-            MapFieldNames, RecordBatch, StringArray, StringBuilder, StructArray,
+        arrow::{
+            array::{
+                Array, BooleanArray, Int32Array, Int64Array, ListArray, ListBuilder, MapBuilder,
+                MapFieldNames, RecordBatch, StringArray, StringBuilder, StructArray,
+            },
+            datatypes::{DataType as ArrowDataType, Field, Schema},
+            json::ReaderBuilder,
         },
-        arrow::datatypes::{DataType as ArrowDataType, Field, Schema},
-        arrow::json::ReaderBuilder,
-        engine::{arrow_data::ArrowEngineData, arrow_expression::ArrowEvaluationHandler},
+        engine::{arrow_data::EngineDataArrowExt as _, arrow_expression::ArrowEvaluationHandler},
         schema::{ArrayType, DataType, MapType, StructField},
         Engine, EvaluationHandler, JsonHandler, ParquetHandler, StorageHandler,
     };
@@ -1643,13 +1645,7 @@
 
         let engine_data =
             set_transaction.into_engine_data(SetTransaction::to_schema().into(), &engine);
-
-        let record_batch: RecordBatch = engine_data
-            .unwrap()
-            .into_any()
-            .downcast::<ArrowEngineData>()
-            .unwrap()
-            .into();
+        let record_batch = engine_data.try_into_record_batch().unwrap();
 
         let schema = Arc::new(Schema::new(vec![
             Field::new("appId", ArrowDataType::Utf8, false),
@@ -1678,13 +1674,7 @@
         let commit_info_txn_id = commit_info.txn_id.clone();
 
         let engine_data = commit_info.into_engine_data(CommitInfo::to_schema().into(), &engine);
-
-        let record_batch: RecordBatch = engine_data
-            .unwrap()
-            .into_any()
-            .downcast::<ArrowEngineData>()
-            .unwrap()
-            .into();
+        let record_batch = engine_data.try_into_record_batch().unwrap();
 
         let mut map_builder = create_string_map_builder(false);
         map_builder.append(true).unwrap();
@@ -1719,13 +1709,7 @@
 
         let engine_data =
             domain_metadata.into_engine_data(DomainMetadata::to_schema().into(), &engine);
-
-        let record_batch: RecordBatch = engine_data
-            .unwrap()
-            .into_any()
-            .downcast::<ArrowEngineData>()
-            .unwrap()
-            .into();
+        let record_batch = engine_data.try_into_record_batch().unwrap();
 
         let expected = RecordBatch::try_new(
             record_batch.schema(),
@@ -1878,14 +1862,11 @@
 
         // have to get the id since it's random
         let test_id = test_metadata.id.clone();
-
-        let actual: RecordBatch = test_metadata
+        let actual = test_metadata
             .into_engine_data(Metadata::to_schema().into(), &engine)
             .unwrap()
-            .into_any()
-            .downcast::<ArrowEngineData>()
-            .unwrap()
-            .into();
+            .try_into_record_batch()
+            .unwrap();
 
         let expected_json = json!({
             "id": test_id,
@@ -1931,13 +1912,11 @@
 
         // test with the full log schema that wraps metadata in a "metaData" field
         let commit_schema = get_commit_schema().project(&[METADATA_NAME]).unwrap();
-        let actual: RecordBatch = metadata
+        let actual = metadata
             .into_engine_data(commit_schema, &engine)
             .unwrap()
-            .into_any()
-            .downcast::<ArrowEngineData>()
-            .unwrap()
-            .into();
+            .try_into_record_batch()
+            .unwrap();
 
         let expected_json = json!({
             "metaData": {
@@ -1977,12 +1956,7 @@
         let engine_data = protocol
             .clone()
             .into_engine_data(Protocol::to_schema().into(), &engine);
-        let record_batch: RecordBatch = engine_data
-            .unwrap()
-            .into_any()
-            .downcast::<ArrowEngineData>()
-            .unwrap()
-            .into();
+        let record_batch = engine_data.try_into_record_batch().unwrap();
 
         let list_field = Arc::new(Field::new("element", ArrowDataType::Utf8, false));
         let protocol_fields = vec![
@@ -2069,12 +2043,7 @@
         )
         .unwrap();
 
-        let record_batch: RecordBatch = engine_data
-            .unwrap()
-            .into_any()
-            .downcast::<ArrowEngineData>()
-            .unwrap()
-            .into();
+        let record_batch = engine_data.try_into_record_batch().unwrap();
 
         assert_eq!(record_batch, expected);
     }
@@ -2089,11 +2058,7 @@
         let engine_data = protocol
             .into_engine_data(Protocol::to_schema().into(), &engine)
             .unwrap();
-        let record_batch: RecordBatch = engine_data
-            .into_any()
-            .downcast::<ArrowEngineData>()
-            .unwrap()
-            .into();
+        let record_batch = engine_data.try_into_record_batch().unwrap();
 
         assert_eq!(record_batch.num_rows(), 1);
         assert_eq!(record_batch.num_columns(), 4);
@@ -2123,11 +2088,7 @@
         let engine_data = protocol
             .into_engine_data(Protocol::to_schema().into(), &engine)
             .unwrap();
-        let record_batch: RecordBatch = engine_data
-            .into_any()
-            .downcast::<ArrowEngineData>()
-            .unwrap()
-            .into();
+        let record_batch = engine_data.try_into_record_batch().unwrap();
 
         assert_eq!(record_batch.num_rows(), 1);
         assert_eq!(record_batch.num_columns(), 4);
```
