Skip to content

Commit d09baec

Browse files
authored
fix: Pass all Comet configs to native plan (#2801)
1 parent 80bef43 commit d09baec

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

spark/src/main/scala/org/apache/comet/CometExecIterator.scala

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import org.apache.spark.broadcast.Broadcast
3030
import org.apache.spark.internal.Logging
3131
import org.apache.spark.network.util.ByteUnit
3232
import org.apache.spark.sql.comet.CometMetricNode
33+
import org.apache.spark.sql.internal.SQLConf
3334
import org.apache.spark.sql.vectorized._
3435
import org.apache.spark.util.SerializableConfiguration
3536

@@ -87,11 +88,7 @@ class CometExecIterator(
8788
val localDiskDirs = SparkEnv.get.blockManager.getLocalDiskDirs
8889

8990
// serialize Comet related Spark configs in protobuf format
90-
val builder = ConfigMap.newBuilder()
91-
conf.getAll.filter(_._1.startsWith(CometConf.COMET_PREFIX)).foreach { case (k, v) =>
92-
builder.putEntries(k, v)
93-
}
94-
val protobufSparkConfigs = builder.build().toByteArray
91+
val protobufSparkConfigs = CometExecIterator.serializeCometSQLConfs()
9592

9693
// Create keyUnwrapper if encryption is enabled
9794
val keyUnwrapper = if (encryptedFilePaths.nonEmpty) {
@@ -265,6 +262,17 @@ class CometExecIterator(
265262

266263
object CometExecIterator extends Logging {
267264

265+
private def cometSqlConfs: Map[String, String] =
266+
SQLConf.get.getAllConfs.filter(_._1.startsWith(CometConf.COMET_PREFIX))
267+
268+
def serializeCometSQLConfs(): Array[Byte] = {
269+
val builder = ConfigMap.newBuilder()
270+
cometSqlConfs.foreach { case (k, v) =>
271+
builder.putEntries(k, v)
272+
}
273+
builder.build().toByteArray
274+
}
275+
268276
def getMemoryConfig(conf: SparkConf): MemoryConfig = {
269277
val numCores = numDriverOrExecutorCores(conf)
270278
val coresPerTask = conf.get("spark.task.cpus", "1").toInt

spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ import org.apache.spark.sql.internal.SQLConf
4747
import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
4848
import org.apache.spark.unsafe.types.UTF8String
4949

50-
import org.apache.comet.{CometConf, ExtendedExplainInfo}
50+
import org.apache.comet.{CometConf, CometExecIterator, ExtendedExplainInfo}
5151
import org.apache.comet.CometSparkSessionExtensions.{isSpark35Plus, isSpark40Plus}
52+
import org.apache.comet.serde.Config.ConfigMap
5253
import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions}
5354

5455
class CometExecSuite extends CometTestBase {
@@ -66,6 +67,27 @@ class CometExecSuite extends CometTestBase {
6667
}
6768
}
6869

70+
test("SQLConf serde") {
71+
72+
def roundtrip = {
73+
val protobuf = CometExecIterator.serializeCometSQLConfs()
74+
ConfigMap.parseFrom(protobuf)
75+
}
76+
77+
// test not setting the config
78+
val deserialized: ConfigMap = roundtrip
79+
assert(null == deserialized.getEntriesMap.get(CometConf.COMET_EXPLAIN_NATIVE_ENABLED.key))
80+
81+
// test explicitly setting the config
82+
for (value <- Seq("true", "false")) {
83+
withSQLConf(CometConf.COMET_EXPLAIN_NATIVE_ENABLED.key -> value) {
84+
val deserialized: ConfigMap = roundtrip
85+
assert(
86+
value == deserialized.getEntriesMap.get(CometConf.COMET_EXPLAIN_NATIVE_ENABLED.key))
87+
}
88+
}
89+
}
90+
6991
test("TopK operator should return correct results on dictionary column with nulls") {
7092
withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
7193
withTable("test_data") {

0 commit comments

Comments (0)