diff --git a/common/src/main/java/org/apache/comet/vector/CometListVector.java b/common/src/main/java/org/apache/comet/vector/CometListVector.java
index 752495c0d8..93e8e8bf9f 100644
--- a/common/src/main/java/org/apache/comet/vector/CometListVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometListVector.java
@@ -45,6 +45,7 @@ public CometListVector(
 
   @Override
   public ColumnarArray getArray(int i) {
+    if (isNullAt(i)) return null;
     int start = listVector.getOffsetBuffer().getInt(i * ListVector.OFFSET_WIDTH);
     int end = listVector.getOffsetBuffer().getInt((i + 1) * ListVector.OFFSET_WIDTH);
 
diff --git a/common/src/main/java/org/apache/comet/vector/CometMapVector.java b/common/src/main/java/org/apache/comet/vector/CometMapVector.java
index 1d531ca903..c5984a4dcb 100644
--- a/common/src/main/java/org/apache/comet/vector/CometMapVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometMapVector.java
@@ -65,6 +65,7 @@ public CometMapVector(
 
   @Override
   public ColumnarMap getMap(int i) {
+    if (isNullAt(i)) return null;
     int start = mapVector.getOffsetBuffer().getInt(i * MapVector.OFFSET_WIDTH);
     int end = mapVector.getOffsetBuffer().getInt((i + 1) * MapVector.OFFSET_WIDTH);
 
diff --git a/common/src/main/java/org/apache/comet/vector/CometPlainVector.java b/common/src/main/java/org/apache/comet/vector/CometPlainVector.java
index f3803d53a9..2a30be1b1c 100644
--- a/common/src/main/java/org/apache/comet/vector/CometPlainVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometPlainVector.java
@@ -123,6 +123,7 @@ public double getDouble(int rowId) {
 
   @Override
   public UTF8String getUTF8String(int rowId) {
+    if (isNullAt(rowId)) return null;
     if (!isBaseFixedWidthVector) {
       BaseVariableWidthVector varWidthVector = (BaseVariableWidthVector) valueVector;
       long offsetBufferAddress = varWidthVector.getOffsetBuffer().memoryAddress();
@@ -147,6 +148,7 @@ public UTF8String getUTF8String(int rowId) {
 
   @Override
   public byte[] getBinary(int rowId) {
+    if (isNullAt(rowId)) return null;
     int offset;
     int length;
     if (valueVector instanceof BaseVariableWidthVector) {
diff --git a/common/src/main/java/org/apache/comet/vector/CometVector.java b/common/src/main/java/org/apache/comet/vector/CometVector.java
index 0c6fa8f12d..a1f75696f6 100644
--- a/common/src/main/java/org/apache/comet/vector/CometVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometVector.java
@@ -85,6 +85,7 @@ public boolean isFixedLength() {
 
   @Override
   public Decimal getDecimal(int i, int precision, int scale) {
+    if (isNullAt(i)) return null;
     if (!useDecimal128 && precision <= Decimal.MAX_INT_DIGITS() && type instanceof IntegerType) {
       return createDecimal(getInt(i), precision, scale);
    } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
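The Java changes share one pattern: every accessor that dereferences offset or value buffers now returns null when the slot itself is null, which matches Spark's `ArrowColumnVector`, whose accessors are null-safe in the same way. Without the guard, the offset arithmetic below each check reads whatever bytes happen to sit in the buffers for a null slot. A minimal spark-shell repro of the symptom (a sketch; the literal data is illustrative, see issue #2612 for the reported case):

```scala
import org.apache.spark.sql.functions.reverse
import spark.implicits._

// With Comet enabled, reversing an array with null elements could surface
// garbage values in the null slots, because the vector getters above did
// not consult the validity bitmap before doing offset arithmetic.
val df = Seq(Seq(Some(1), None, Some(3))).toDF("arr")
df.select(reverse($"arr")).show(false)
// Spark's answer, which Comet must reproduce: [3, null, 1]
```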
diff --git a/native/spark-expr/src/array_funcs/array_insert.rs b/native/spark-expr/src/array_funcs/array_insert.rs
index eb96fec12f..dcee441cea 100644
--- a/native/spark-expr/src/array_funcs/array_insert.rs
+++ b/native/spark-expr/src/array_funcs/array_insert.rs
@@ -16,11 +16,10 @@
 // under the License.
 
 use arrow::array::{make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait};
-use arrow::datatypes::{DataType, Field, Schema};
+use arrow::datatypes::{DataType, Schema};
 use arrow::{
     array::{as_primitive_array, Capacities, MutableArrayData},
     buffer::{NullBuffer, OffsetBuffer},
-    datatypes::ArrowNativeType,
     record_batch::RecordBatch,
 };
 use datafusion::common::{
@@ -198,114 +197,131 @@ fn array_insert<O: OffsetSizeTrait>(
     pos_array: &ArrayRef,
     legacy_mode: bool,
 ) -> DataFusionResult<ColumnarValue> {
-    // The code is based on the implementation of the array_append from the Apache DataFusion
-    // https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/concat.rs#L513
-    //
-    // This code is also based on the implementation of the array_insert from the Apache Spark
-    // https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L4713
+    // Implementation aligned with Arrow's half-open offset ranges and Spark semantics.
     let values = list_array.values();
     let offsets = list_array.offsets();
     let values_data = values.to_data();
     let item_data = items_array.to_data();
+
+    // Estimate capacity (original values + inserted items upper bound)
     let new_capacity = Capacities::Array(values_data.len() + item_data.len());
 
     let mut mutable_values =
         MutableArrayData::with_capacities(vec![&values_data, &item_data], true, new_capacity);
 
-    let mut new_offsets = vec![O::usize_as(0)];
-    let mut new_nulls = Vec::<bool>::with_capacity(list_array.len());
+    // New offsets and top-level list validity bitmap
+    let mut new_offsets = Vec::with_capacity(list_array.len() + 1);
+    new_offsets.push(O::usize_as(0));
+    let mut list_valid = Vec::<bool>::with_capacity(list_array.len());
 
-    let pos_data: &Int32Array = as_primitive_array(&pos_array); // Spark supports only i32 for positions
+    // Spark supports only Int32 position indices
+    let pos_data: &Int32Array = as_primitive_array(&pos_array);
 
-    for (row_index, offset_window) in offsets.windows(2).enumerate() {
-        let pos = pos_data.values()[row_index];
-        let start = offset_window[0].as_usize();
-        let end = offset_window[1].as_usize();
-        let is_item_null = items_array.is_null(row_index);
+    for (row_index, window) in offsets.windows(2).enumerate() {
+        let start = window[0].as_usize();
+        let end = window[1].as_usize();
+        let len = end - start;
+
+        // Return null for the entire row when pos is null (consistent with Spark's behavior)
+        if pos_data.is_null(row_index) {
+            new_offsets.push(new_offsets[row_index]);
+            list_valid.push(false);
+            continue;
+        }
+        let pos = pos_data.value(row_index);
 
         if list_array.is_null(row_index) {
-            // In Spark if value of the array is NULL than nothing happens
-            mutable_values.extend_nulls(1);
-            new_offsets.push(new_offsets[row_index] + O::one());
-            new_nulls.push(false);
+            // Top-level list row is NULL: do not write any child values and do not advance offset
+            new_offsets.push(new_offsets[row_index]);
+            list_valid.push(false);
             continue;
         }
 
         if pos == 0 {
             return Err(DataFusionError::Internal(
-                "Position for array_insert should be greter or less than zero".to_string(),
+                "Position for array_insert should be greater or less than zero".to_string(),
             ));
         }
 
-        if (pos > 0) || ((-pos).as_usize() < (end - start + 1)) {
-            let corrected_pos = if pos > 0 {
-                (pos - 1).as_usize()
-            } else {
-                end - start - (-pos).as_usize() + if legacy_mode { 0 } else { 1 }
-            };
-            let new_array_len = std::cmp::max(end - start + 1, corrected_pos);
-            if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
-                return Err(DataFusionError::Internal(format!(
-                    "Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH:?}, but got {new_array_len:?}"
-                )));
-            }
+        let final_len: usize;
 
-            if (start + corrected_pos) <= end {
-                mutable_values.extend(0, start, start + corrected_pos);
+        if pos > 0 {
+            // Positive index (1-based)
+            let pos1 = pos as usize;
+            if pos1 <= len + 1 {
+                // In-range insertion (including appending to end)
+                let corrected = pos1 - 1; // 0-based insertion point
+                mutable_values.extend(0, start, start + corrected);
                 mutable_values.extend(1, row_index, row_index + 1);
-                mutable_values.extend(0, start + corrected_pos, end);
-                new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
+                mutable_values.extend(0, start + corrected, end);
+                final_len = len + 1;
             } else {
+                // Beyond end: pad with nulls then insert
+                let corrected = pos1 - 1;
+                let padding = corrected - len;
                 mutable_values.extend(0, start, end);
-                mutable_values.extend_nulls(new_array_len - (end - start));
+                mutable_values.extend_nulls(padding);
                 mutable_values.extend(1, row_index, row_index + 1);
-                // In that case spark actualy makes array longer than expected;
-                // For example, if pos is equal to 5, len is eq to 3, than resulted len will be 5
-                new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len) + O::one());
+                final_len = corrected + 1; // equals pos1
             }
         } else {
-            // This comment is takes from the Apache Spark source code as is:
-            // special case- if the new position is negative but larger than the current array size
-            // place the new item at start of array, place the current array contents at the end
-            // and fill the newly created array elements inbetween with a null
-            let base_offset = if legacy_mode { 1 } else { 0 };
-            let new_array_len = (-pos + base_offset).as_usize();
-            if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
-                return Err(DataFusionError::Internal(format!(
-                    "Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH:?}, but got {new_array_len:?}"
-                )));
-            }
-            mutable_values.extend(1, row_index, row_index + 1);
-            mutable_values.extend_nulls(new_array_len - (end - start + 1));
-            mutable_values.extend(0, start, end);
-            new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
-        }
-        if is_item_null {
-            if (start == end) || (values.is_null(row_index)) {
-                new_nulls.push(false)
+            // Negative index (1-based from the end)
+            let k = (-pos) as usize;
+
+            if k <= len {
+                // In-range negative insertion
+                // Non-legacy: -1 behaves like append to end (corrected = len - k + 1)
+                // Legacy: -1 behaves like insert before the last element (corrected = len - k)
+                let base_offset = if legacy_mode { 0 } else { 1 };
+                let corrected = len - k + base_offset;
+                mutable_values.extend(0, start, start + corrected);
+                mutable_values.extend(1, row_index, row_index + 1);
+                mutable_values.extend(0, start + corrected, end);
+                final_len = len + 1;
             } else {
-                new_nulls.push(true)
+                // Negative index beyond the start (Spark-specific behavior):
+                // Place item first, then pad with nulls, then append the original array.
+                // Final length = k + base_offset, where base_offset = 1 in legacy mode, otherwise 0.
+                let base_offset = if legacy_mode { 1 } else { 0 };
+                let target_len = k + base_offset;
+                let padding = target_len.saturating_sub(len + 1);
+                mutable_values.extend(1, row_index, row_index + 1); // insert item first
+                mutable_values.extend_nulls(padding); // pad nulls
+                mutable_values.extend(0, start, end); // append original values
+                final_len = target_len;
             }
-        } else {
-            new_nulls.push(true)
         }
+
+        if final_len > MAX_ROUNDED_ARRAY_LENGTH {
+            return Err(DataFusionError::Internal(format!(
+                "Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH}, but got {final_len}"
+            )));
+        }
+
+        let prev = new_offsets[row_index].as_usize();
+        new_offsets.push(O::usize_as(prev + final_len));
+        list_valid.push(true);
     }
 
-    let data = make_array(mutable_values.freeze());
-    let data_type = match list_array.data_type() {
-        DataType::List(field) => field.data_type(),
-        DataType::LargeList(field) => field.data_type(),
+    let child = make_array(mutable_values.freeze());
+
+    // Reuse the original list element field (name/type/nullability)
+    let elem_field = match list_array.data_type() {
+        DataType::List(field) => Arc::clone(field),
+        DataType::LargeList(field) => Arc::clone(field),
         _ => unreachable!(),
     };
-    let new_array = GenericListArray::<O>::try_new(
-        Arc::new(Field::new("item", data_type.clone(), true)),
+
+    // Build the resulting list array
+    let new_list = GenericListArray::<O>::try_new(
+        elem_field,
         OffsetBuffer::new(new_offsets.into()),
-        data,
-        Some(NullBuffer::new(new_nulls.into())),
+        child,
+        Some(NullBuffer::new(list_valid.into())),
     )?;
 
-    Ok(ColumnarValue::Array(Arc::new(new_array)))
+    Ok(ColumnarValue::Array(Arc::new(new_list)))
 }
 
 impl Display for ArrayInsert {
@@ -442,4 +458,37 @@ mod test {
 
         Ok(())
     }
+
+    #[test]
+    fn test_array_insert_bug_repro_null_item_pos1_fixed() -> Result<()> {
+        use arrow::array::{Array, ArrayRef, Int32Array, ListArray};
+        use arrow::datatypes::Int32Type;
+
+        // row0 = [0, null, 0]
+        // row1 = [1, null, 1]
+        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+            Some(vec![Some(0), None, Some(0)]),
+            Some(vec![Some(1), None, Some(1)]),
+        ]);
+
+        let positions = Int32Array::from(vec![1, 1]);
+        let items = Int32Array::from(vec![None, None]);
+
+        let ColumnarValue::Array(result) = array_insert(
+            &list,
+            &(Arc::new(items) as ArrayRef),
+            &(Arc::new(positions) as ArrayRef),
+            false, // legacy_mode = false
+        )?
+        else {
+            unreachable!()
+        };
+
+        let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+            Some(vec![None, Some(0), None, Some(0)]),
+            Some(vec![None, Some(1), None, Some(1)]),
+        ]);
+        assert_eq!(&result.to_data(), &expected.to_data());
+        Ok(())
+    }
 }
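For reference, the positive/negative/null branches above track Spark's `array_insert` semantics. A spark-shell sketch of the behavior the kernel targets (expected results written from the Spark 3.5 documentation, assuming the default non-legacy negative-index behavior; not captured output):

```scala
// array_insert positions are 1-based; negative positions count from the end.
spark.sql("SELECT array_insert(array(1, 2, 3), 1, 0)").show()  // [0, 1, 2, 3]: in-range positive
spark.sql("SELECT array_insert(array(1, 2, 3), 5, 0)").show()  // [1, 2, 3, NULL, 0]: pad past the end
spark.sql("SELECT array_insert(array(1, 2, 3), -1, 0)").show() // [1, 2, 3, 0]: -1 appends (non-legacy)
spark.sql("SELECT array_insert(array(1, 2, 3), -5, 0)").show() // [0, NULL, 1, 2, 3]: item first, null padding
spark.sql("SELECT array_insert(array(1, 2, 3), CAST(NULL AS INT), 0)").show() // NULL: null position nulls the row
```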
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index c5060382e0..4d06baaa8d 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayRepeat, ArraysOverlap, ArrayUnion}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.ArrayType
 
 import org.apache.comet.CometSparkSessionExtensions.{isSpark35Plus, isSpark40Plus}
 import org.apache.comet.DataTypeSupport.isComplexType
@@ -210,11 +211,13 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper
         .withColumn("arrInsertResult", expr("array_insert(arr, 1, 1)"))
         .withColumn("arrInsertNegativeIndexResult", expr("array_insert(arr, -1, 1)"))
         .withColumn("arrPosGreaterThanSize", expr("array_insert(arr, 8, 1)"))
+        .withColumn("arrPosIsNull", expr("array_insert(arr, cast(null as int), 1)"))
         .withColumn("arrNegPosGreaterThanSize", expr("array_insert(arr, -8, 1)"))
         .withColumn("arrInsertNone", expr("array_insert(arr, 1, null)"))
       checkSparkAnswerAndOperator(df.select("arrInsertResult"))
       checkSparkAnswerAndOperator(df.select("arrInsertNegativeIndexResult"))
       checkSparkAnswerAndOperator(df.select("arrPosGreaterThanSize"))
+      checkSparkAnswerAndOperator(df.select("arrPosIsNull"))
       checkSparkAnswerAndOperator(df.select("arrNegPosGreaterThanSize"))
       checkSparkAnswerAndOperator(df.select("arrInsertNone"))
     })
@@ -802,4 +805,28 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper
         fallbackReason)
     }
   }
+
+  test("array_reverse 2") {
+    // This test validates data correctness for array columns with nullable elements.
+    // See https://github.com/apache/datafusion-comet/issues/2612
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        val schemaOptions =
+          SchemaGenOptions(generateArray = true, generateStruct = false, generateMap = false)
+        val dataOptions = DataGenOptions(allowNull = true, generateNegativeZero = false)
+        ParquetGenerator.makeParquetFile(random, spark, filename, 100, schemaOptions, dataOptions)
+      }
+      withTempView("t1") {
+        val table = spark.read.parquet(filename)
+        table.createOrReplaceTempView("t1")
+        for (field <- table.schema.fields.filter(_.dataType.isInstanceOf[ArrayType])) {
+          val sql = s"SELECT ${field.name}, reverse(${field.name}) FROM t1 ORDER BY ${field.name}"
+          checkSparkAnswer(sql)
+        }
+      }
+    }
+  }
 }
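The new Rust unit test and the existing `arrInsertNone` column pin down the same contract: a null *item* is inserted as a null element, while only a null *position* (the new `arrPosIsNull` column) nulls out the whole row. In spark-shell terms (a sketch; expected output written from the semantics above, not captured from a run):

```scala
// Null item: inserted as a null element; the row itself stays non-null.
spark.sql("SELECT array_insert(array(0, NULL, 0), 1, CAST(NULL AS INT))").show(false)
// Expected: [NULL, 0, NULL, 0] -- the same answer the Rust repro test asserts.
```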