Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {

(fromType, toType) match {
case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => Compatible()
case (dt: ArrayType, DataTypes.StringType) if dt.elementType == DataTypes.BinaryType =>
Incompatible()
case (dt: ArrayType, DataTypes.StringType) =>
isSupported(dt.elementType, DataTypes.StringType, timeZoneId, evalMode)
case (dt: ArrayType, dt1: ArrayType) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class CometBitwiseExpressionSuite extends CometTestBase with AdaptiveSparkPlanHe

test("bitwise_get - throws exceptions") {
def checkSparkAndCometEqualThrows(query: String): Unit = {
checkSparkMaybeThrows(sql(query)) match {
checkSparkAnswerMaybeThrows(sql(query)) match {
case (Some(sparkExc), Some(cometExc)) =>
assert(sparkExc.getMessage == cometExc.getMessage)
case _ => fail("Exception should be thrown")
Expand Down
14 changes: 7 additions & 7 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType, StructField, StructType}
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType, StructField, StructType}

import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
import org.apache.comet.expressions.{CometCast, CometEvalMode}
Expand Down Expand Up @@ -1035,7 +1035,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {

test("cast between decimals with negative precision") {
// cast to negative scale
checkSparkMaybeThrows(
checkSparkAnswerMaybeThrows(
spark.sql("select a, cast(a as DECIMAL(10,-4)) from t order by a")) match {
case (expected, actual) =>
assert(expected.contains("PARSE_SYNTAX_ERROR") === actual.contains("PARSE_SYNTAX_ERROR"))
Expand All @@ -1062,11 +1062,11 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
IntegerType,
LongType,
ShortType,
// FloatType,
// DoubleType,
// FloatType,
// DoubleType,
// BinaryType
DecimalType(10, 2),
DecimalType(38, 18),
BinaryType).foreach { dt =>
DecimalType(38, 18)).foreach { dt =>
val input = generateArrays(100, dt)
castTest(input, StringType, hasIncompatibleType = hasIncompatibleType(input.schema))
}
Expand Down Expand Up @@ -1272,7 +1272,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {

// cast() should throw exception on invalid inputs when ansi mode is enabled
val df = data.withColumn("converted", col("a").cast(toType))
checkSparkMaybeThrows(df) match {
checkSparkAnswerMaybeThrows(df) match {
case (None, None) =>
// neither system threw an exception
case (None, Some(e)) =>
Expand Down
24 changes: 12 additions & 12 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
withParquetTable(path.toString, "tbl") {
val (sparkErr, cometErr) =
checkSparkMaybeThrows(sql(s"SELECT _20 + ${Int.MaxValue} FROM tbl"))
checkSparkAnswerMaybeThrows(sql(s"SELECT _20 + ${Int.MaxValue} FROM tbl"))
if (isSpark40Plus) {
assert(sparkErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
} else {
Expand Down Expand Up @@ -359,7 +359,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
withParquetTable(path.toString, "tbl") {
val (sparkErr, cometErr) =
checkSparkMaybeThrows(sql(s"SELECT _20 - ${Int.MaxValue} FROM tbl"))
checkSparkAnswerMaybeThrows(sql(s"SELECT _20 - ${Int.MaxValue} FROM tbl"))
if (isSpark40Plus) {
assert(sparkErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
assert(cometErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
Expand Down Expand Up @@ -2022,7 +2022,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val expectedDivideByZeroError =
"[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead."

checkSparkMaybeThrows(sql(query)) match {
checkSparkAnswerMaybeThrows(sql(query)) match {
case (Some(sparkException), Some(cometException)) =>
assert(sparkException.getMessage.contains(expectedDivideByZeroError))
assert(cometException.getMessage.contains(expectedDivideByZeroError))
Expand Down Expand Up @@ -2174,7 +2174,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

def checkOverflow(query: String, dtype: String): Unit = {
checkSparkMaybeThrows(sql(query)) match {
checkSparkAnswerMaybeThrows(sql(query)) match {
case (Some(sparkException), Some(cometException)) =>
assert(sparkException.getMessage.contains(dtype + " overflow"))
assert(cometException.getMessage.contains(dtype + " overflow"))
Expand Down Expand Up @@ -2700,7 +2700,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {

test("ListExtract") {
def assertBothThrow(df: DataFrame): Unit = {
checkSparkMaybeThrows(df) match {
checkSparkAnswerMaybeThrows(df) match {
case (Some(_), Some(_)) => ()
case (spark, comet) =>
fail(
Expand Down Expand Up @@ -2850,7 +2850,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
| from tbl
| """.stripMargin)

checkSparkMaybeThrows(res) match {
checkSparkAnswerMaybeThrows(res) match {
case (Some(sparkExc), Some(cometExc)) =>
assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
assert(sparkExc.getMessage.contains("overflow"))
Expand All @@ -2869,7 +2869,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
| _1 - _2
| from tbl
| """.stripMargin)
checkSparkMaybeThrows(res) match {
checkSparkAnswerMaybeThrows(res) match {
case (Some(sparkExc), Some(cometExc)) =>
assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
assert(sparkExc.getMessage.contains("overflow"))
Expand All @@ -2889,7 +2889,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
| from tbl
| """.stripMargin)

checkSparkMaybeThrows(res) match {
checkSparkAnswerMaybeThrows(res) match {
case (Some(sparkExc), Some(cometExc)) =>
assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
assert(sparkExc.getMessage.contains("overflow"))
Expand All @@ -2909,7 +2909,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
| from tbl
| """.stripMargin)

checkSparkMaybeThrows(res) match {
checkSparkAnswerMaybeThrows(res) match {
case (Some(sparkExc), Some(cometExc)) =>
assert(cometExc.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
assert(sparkExc.getMessage.contains("Division by zero"))
Expand All @@ -2929,7 +2929,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
| from tbl
| """.stripMargin)

checkSparkMaybeThrows(res) match {
checkSparkAnswerMaybeThrows(res) match {
case (Some(sparkExc), Some(cometExc)) =>
assert(cometExc.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
assert(sparkExc.getMessage.contains("Division by zero"))
Expand All @@ -2950,7 +2950,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
| from tbl
| """.stripMargin)

checkSparkMaybeThrows(res) match {
checkSparkAnswerMaybeThrows(res) match {
case (Some(sparkException), Some(cometException)) =>
assert(sparkException.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
assert(cometException.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
Expand Down Expand Up @@ -2985,7 +2985,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
Seq(true, false).foreach { ansi =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansi.toString) {
val res = spark.sql(s"SELECT round(_1, $scale) from tbl")
checkSparkMaybeThrows(res) match {
checkSparkAnswerMaybeThrows(res) match {
case (Some(sparkException), Some(cometException)) =>
assert(sparkException.getMessage.contains("ARITHMETIC_OVERFLOW"))
assert(cometException.getMessage.contains("ARITHMETIC_OVERFLOW"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class CometMathExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelpe
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
for (field <- df.schema.fields) {
val col = field.name
checkSparkMaybeThrows(sql(s"SELECT $col, abs($col) FROM tbl ORDER BY $col")) match {
checkSparkAnswerMaybeThrows(sql(s"SELECT $col, abs($col) FROM tbl ORDER BY $col")) match {
case (Some(sparkExc), Some(cometExc)) =>
val cometErrorPattern =
""".+[ARITHMETIC_OVERFLOW].+overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.""".r
Expand Down
6 changes: 3 additions & 3 deletions spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ abstract class CometTestBase
* This method does not check that Comet replaced any operators or that the results match in the
* case where the query is successful against both Spark and Comet.
*/
protected def checkSparkMaybeThrows(
protected def checkSparkAnswerMaybeThrows(
df: => DataFrame): (Option[Throwable], Option[Throwable]) = {
var expected: Try[Array[Row]] = null
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
Expand All @@ -316,8 +316,8 @@ abstract class CometTestBase

(expected, actual) match {
case (Success(_), Success(_)) =>
// TODO compare results and confirm that they match
// https://github.com/apache/datafusion-comet/issues/2657
// compare results and confirm that they match
checkSparkAnswer(df)
Comment on lines +319 to +320
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the fix. All other changes are for renaming the method.

(None, None)
case _ =>
(expected.failed.toOption, actual.failed.toOption)
Expand Down
Loading