Fix: Fix null handling in CometVector implementations #2643
Conversation
Codecov Report

@@ Coverage Diff @@
## main #2643 +/- ##
============================================
+ Coverage 56.12% 58.36% +2.23%
- Complexity 976 1424 +448
============================================
Files 119 162 +43
Lines 11743 14150 +2407
Branches 2251 2366 +115
============================================
+ Hits 6591 8258 +1667
- Misses 4012 4700 +688
- Partials 1140 1192 +52
CometPlainVector#getBinary
Yes, I completely agree that unnecessary null checking should be avoided for performance reasons.
The failing unit tests appear to be related to the
I stand corrected. This does make the implementations consistent with Spark.
//
// This code is also based on the implementation of the array_insert from the Apache Spark
// https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L4713
// Implementation aligned with Arrow's half-open offset ranges and Spark semantics.
The version fixed by ChatGPT :)
let start = window[0].as_usize();
let end = window[1].as_usize();
let len = end - start;
let pos = pos_data.value(row_index);
Should pos_data[row_index] be checked for null before using the .value(usize) method?
https://github.com/apache/arrow-rs/blob/a0db1985c3a0f3190cfc5166b428933a28c740f9/arrow-array/src/array/primitive_array.rs#L766-L767
Addressed; we now return null for the entire row when pos is null.
test("array_reverse 2") {
  // This test validates data correctness for array<binary> columns with nullable elements.
  // See https://github.com/apache/datafusion-comet/issues/2612
  withTempDir { dir =>
Remind me again why this is a good test for the changes in this PR?
It exercises a long-standing null-handling issue in Comet.
Minimal reproducible snippet:
sql("select cast(array(null) as array<binary>) as c1").write
.mode("overwrite")
.save("/tmp/parquet/t1")
sql("select c1, reverse(c1) from parquet.`/tmp/parquet/t1`").show# current output
+------+-----------+
| c1|reverse(c1)|
+------+-----------+
|[NULL]| [[]]|
+------+-----------+
# expected output
+------+-----------+
| c1|reverse(c1)|
+------+-----------+
|[NULL]| [NULL]|
+------+-----------+
Why this happens:
- reverse for array<binary> isn't implemented natively by Comet, so the operator falls back to vanilla Spark execution. The root scan, however, is still CometNativeScan.
== Physical Plan ==
CollectLimit (4)
+- * Project (3)
+- * ColumnarToRow (2)
+- CometNativeScan parquet (1)
Note the scan is CometNativeScan. The bug is the mismatch between Comet’s columnar getters and Spark’s expectations when nulls are present.
Relevant generated code (codegenStageId=1) highlights:
- The reverse logic is here:
/* 047 */ for (int project_i_0 = 0; project_i_0 < project_numElements_1; project_i_0++) {
/* 048 */ int project_j_0 = project_numElements_1 - project_i_0 - 1;
/* 049 */ project_arrayData_0.update(project_i_0, project_expr_0_0.getBinary(project_j_0));
/* 050 */ }
/* 051 */ project_value_2 = project_arrayData_0;

Observation 1: When constructing the reversed array, Spark's code directly calls getBinary(j) and does not check element nullability at this point. It relies on getBinary(j) returning null for null elements.
- When writing out the array, Spark does distinguish nulls:
/* 099 */ for (int project_index_2 = 0; project_index_2 < project_numElements_3; project_index_2++) {
/* 100 */ if (project_tmpInput_2.isNullAt(project_index_2)) {
/* 101 */ columnartorow_mutableStateArray_4[3].setNull8Bytes(project_index_2);
/* 102 */ } else {
/* 103 */ columnartorow_mutableStateArray_4[3].write(project_index_2, project_tmpInput_2.getBinary(project_index_2));
/* 104 */ }

Observation 2: Spark uses isNullAt to mark nulls, and only calls getBinary(i) for non-null elements. Therefore, Comet must return null from getBinary(i) when the element is null; returning an empty byte array leads to [[]] instead of [NULL].
To fix this bug, this PR makes Comet's ColumnVector getters (getBinary, getUTF8String, getArray, getMap, getDecimal) return null when isNullAt(i) is true.
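For illustration, here is a minimal, self-contained sketch of that getter contract. The class and field names are hypothetical and stand in for Comet's actual Arrow-backed vectors; the point is only that the getter itself must propagate the validity mask rather than hand back an empty payload.

// Toy sketch of the null-aware getter contract (hypothetical class, not
// Comet's implementation): getBinary must return null, never an empty
// byte[], for a null element, because Spark's reverse codegen copies
// values via getBinary() and only consults isNullAt() when writing out.
final class ToyBinaryVector {
  private final byte[][] values;   // element payloads; ignored where isNull is true
  private final boolean[] isNull;  // validity mask

  ToyBinaryVector(byte[][] values, boolean[] isNull) {
    this.values = values;
    this.isNull = isNull;
  }

  boolean isNullAt(int rowId) {
    return isNull[rowId];
  }

  byte[] getBinary(int rowId) {
    // The fix in a nutshell: propagate null instead of returning new byte[0].
    return isNull[rowId] ? null : values[rowId];
  }
}

With a vector built as new ToyBinaryVector(new byte[][] {null}, new boolean[] {true}), getBinary(0) yields null, which is what lets the reversed array print as [NULL] instead of [[]].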
FYI, the full generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIteratorForCodegenStage1(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=1
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */ private Object[] references;
/* 008 */ private scala.collection.Iterator[] inputs;
/* 009 */ private int columnartorow_batchIdx_0;
/* 010 */ private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] columnartorow_mutableStateArray_2 = new org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[1];
/* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] columnartorow_mutableStateArray_3 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];
/* 012 */ private org.apache.spark.sql.vectorized.ColumnarBatch[] columnartorow_mutableStateArray_1 = new org.apache.spark.sql.vectorized.ColumnarBatch[1];
/* 013 */ private scala.collection.Iterator[] columnartorow_mutableStateArray_0 = new scala.collection.Iterator[1];
/* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter[] columnartorow_mutableStateArray_4 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter[4];
/* 015 */
/* 016 */ public GeneratedIteratorForCodegenStage1(Object[] references) {
/* 017 */ this.references = references;
/* 018 */ }
/* 019 */
/* 020 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 021 */ partitionIndex = index;
/* 022 */ this.inputs = inputs;
/* 023 */ columnartorow_mutableStateArray_0[0] = inputs[0];
/* 024 */
/* 025 */ columnartorow_mutableStateArray_3[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
/* 026 */ columnartorow_mutableStateArray_4[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(columnartorow_mutableStateArray_3[0], 8);
/* 027 */ columnartorow_mutableStateArray_3[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
/* 028 */ columnartorow_mutableStateArray_4[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(columnartorow_mutableStateArray_3[1], 8);
/* 029 */ columnartorow_mutableStateArray_3[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 64);
/* 030 */ columnartorow_mutableStateArray_4[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(columnartorow_mutableStateArray_3[2], 8);
/* 031 */ columnartorow_mutableStateArray_4[3] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(columnartorow_mutableStateArray_3[2], 8);
/* 032 */
/* 033 */ }
/* 034 */
/* 035 */ private void project_doConsume_0(ArrayData project_expr_0_0, boolean project_exprIsNull_0_0) throws java.io.IOException {
/* 036 */ // common sub-expressions
/* 037 */
/* 038 */ boolean project_isNull_2 = project_exprIsNull_0_0;
/* 039 */ ArrayData project_value_2 = null;
/* 040 */
/* 041 */ if (!project_exprIsNull_0_0) {
/* 042 */ final int project_numElements_1 = project_expr_0_0.numElements();
/* 043 */
/* 044 */ ArrayData project_arrayData_0 = ArrayData.allocateArrayData(
/* 045 */ -1, project_numElements_1, " reverse failed.");
/* 046 */
/* 047 */ for (int project_i_0 = 0; project_i_0 < project_numElements_1; project_i_0++) {
/* 048 */ int project_j_0 = project_numElements_1 - project_i_0 - 1;
/* 049 */ project_arrayData_0.update(project_i_0, project_expr_0_0.getBinary(project_j_0));
/* 050 */ }
/* 051 */ project_value_2 = project_arrayData_0;
/* 052 */
/* 053 */ }
/* 054 */ columnartorow_mutableStateArray_3[2].reset();
/* 055 */
/* 056 */ columnartorow_mutableStateArray_3[2].zeroOutNullBytes();
/* 057 */
/* 058 */ if (project_exprIsNull_0_0) {
/* 059 */ columnartorow_mutableStateArray_3[2].setNullAt(0);
/* 060 */ } else {
/* 061 */ // Remember the current cursor so that we can calculate how many bytes are
/* 062 */ // written later.
/* 063 */ final int project_previousCursor_1 = columnartorow_mutableStateArray_3[2].cursor();
/* 064 */
/* 065 */ final ArrayData project_tmpInput_1 = project_expr_0_0;
/* 066 */ if (project_tmpInput_1 instanceof UnsafeArrayData) {
/* 067 */ columnartorow_mutableStateArray_3[2].write((UnsafeArrayData) project_tmpInput_1);
/* 068 */ } else {
/* 069 */ final int project_numElements_2 = project_tmpInput_1.numElements();
/* 070 */ columnartorow_mutableStateArray_4[2].initialize(project_numElements_2);
/* 071 */
/* 072 */ for (int project_index_1 = 0; project_index_1 < project_numElements_2; project_index_1++) {
/* 073 */ if (project_tmpInput_1.isNullAt(project_index_1)) {
/* 074 */ columnartorow_mutableStateArray_4[2].setNull8Bytes(project_index_1);
/* 075 */ } else {
/* 076 */ columnartorow_mutableStateArray_4[2].write(project_index_1, project_tmpInput_1.getBinary(project_index_1));
/* 077 */ }
/* 078 */
/* 079 */ }
/* 080 */ }
/* 081 */
/* 082 */ columnartorow_mutableStateArray_3[2].setOffsetAndSizeFromPreviousCursor(0, project_previousCursor_1);
/* 083 */ }
/* 084 */
/* 085 */ if (project_isNull_2) {
/* 086 */ columnartorow_mutableStateArray_3[2].setNullAt(1);
/* 087 */ } else {
/* 088 */ // Remember the current cursor so that we can calculate how many bytes are
/* 089 */ // written later.
/* 090 */ final int project_previousCursor_2 = columnartorow_mutableStateArray_3[2].cursor();
/* 091 */
/* 092 */ final ArrayData project_tmpInput_2 = project_value_2;
/* 093 */ if (project_tmpInput_2 instanceof UnsafeArrayData) {
/* 094 */ columnartorow_mutableStateArray_3[2].write((UnsafeArrayData) project_tmpInput_2);
/* 095 */ } else {
/* 096 */ final int project_numElements_3 = project_tmpInput_2.numElements();
/* 097 */ columnartorow_mutableStateArray_4[3].initialize(project_numElements_3);
/* 098 */
/* 099 */ for (int project_index_2 = 0; project_index_2 < project_numElements_3; project_index_2++) {
/* 100 */ if (project_tmpInput_2.isNullAt(project_index_2)) {
/* 101 */ columnartorow_mutableStateArray_4[3].setNull8Bytes(project_index_2);
/* 102 */ } else {
/* 103 */ columnartorow_mutableStateArray_4[3].write(project_index_2, project_tmpInput_2.getBinary(project_index_2));
/* 104 */ }
/* 105 */
/* 106 */ }
/* 107 */ }
/* 108 */
/* 109 */ columnartorow_mutableStateArray_3[2].setOffsetAndSizeFromPreviousCursor(1, project_previousCursor_2);
/* 110 */ }
/* 111 */ append((columnartorow_mutableStateArray_3[2].getRow()));
/* 112 */
/* 113 */ }
/* 114 */
/* 115 */ private void columnartorow_nextBatch_0() throws java.io.IOException {
/* 116 */ if (columnartorow_mutableStateArray_0[0].hasNext()) {
/* 117 */ columnartorow_mutableStateArray_1[0] = (org.apache.spark.sql.vectorized.ColumnarBatch)columnartorow_mutableStateArray_0[0].next();
/* 118 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* numInputBatches */).add(1);
/* 119 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(columnartorow_mutableStateArray_1[0].numRows());
/* 120 */ columnartorow_batchIdx_0 = 0;
/* 121 */ columnartorow_mutableStateArray_2[0] = (org.apache.spark.sql.execution.vectorized.OnHeapColumnVector) columnartorow_mutableStateArray_1[0].column(0);
/* 122 */
/* 123 */ }
/* 124 */ }
/* 125 */
/* 126 */ protected void processNext() throws java.io.IOException {
/* 127 */ if (columnartorow_mutableStateArray_1[0] == null) {
/* 128 */ columnartorow_nextBatch_0();
/* 129 */ }
/* 130 */ while ( columnartorow_mutableStateArray_1[0] != null) {
/* 131 */ int columnartorow_numRows_0 = columnartorow_mutableStateArray_1[0].numRows();
/* 132 */ int columnartorow_localEnd_0 = columnartorow_numRows_0 - columnartorow_batchIdx_0;
/* 133 */ for (int columnartorow_localIdx_0 = 0; columnartorow_localIdx_0 < columnartorow_localEnd_0; columnartorow_localIdx_0++) {
/* 134 */ int columnartorow_rowIdx_0 = columnartorow_batchIdx_0 + columnartorow_localIdx_0;
/* 135 */ boolean columnartorow_isNull_0 = columnartorow_mutableStateArray_2[0].isNullAt(columnartorow_rowIdx_0);
/* 136 */ ArrayData columnartorow_value_0 = columnartorow_isNull_0 ? null : (columnartorow_mutableStateArray_2[0].getArray(columnartorow_rowIdx_0));
/* 137 */
/* 138 */ project_doConsume_0(columnartorow_value_0, columnartorow_isNull_0);
/* 139 */ if (shouldStop()) { columnartorow_batchIdx_0 = columnartorow_rowIdx_0 + 1; return; }
/* 140 */ }
/* 141 */ columnartorow_batchIdx_0 = columnartorow_numRows_0;
/* 142 */ columnartorow_mutableStateArray_1[0] = null;
/* 143 */ columnartorow_nextBatch_0();
/* 144 */ }
/* 145 */ }
/* 146 */
/* 147 */ }
Great explanation! I wonder if we can add this knowledge to the developer docs?
Happy to add this to the dev docs. Should I document this explanation specifically, and where should it live?
How about a 'Null handling in comet vectors' doc in the contributor-guide section? Probably in a different PR.
+1 for adding docs in a separate PR.
96233d9 to 622065d
parthchandra left a comment
lgtm
Thanks @cfmcgrady and @parthchandra!
Which issue does this PR close?
Closes #2612.
Rationale for this change
What changes are included in this PR?
How are these changes tested?
Added UT.