Skip to content

Commit 1905567

Browse files
authored
Cosmos Spark: Support CustomQuery to be used for inference (Azure#22079)
* Updating name * Using custom query if available * Tests * Tests * Adding doc
1 parent ef0f607 commit 1905567

File tree

5 files changed

+62
-14
lines changed

5 files changed

+62
-14
lines changed

sdk/cosmos/azure-cosmos-spark_3-1_2-12/docs/configuration-reference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ Configuration Reference:
3535
### Query Config
3636
| Config Property Name | Default | Description |
3737
| :--- | :---- | :--- |
38-
| `spark.cosmos.read.customQuery` | None | When provided the custom query will be processed against the Cosmos endpoint instead of dynamically generating the query via predicate push down. Usually it is recommended to rely on Spark's predicate push down because that will allow to generate the most efficient set of filters based on the query plan. But there are a couple of of predicates like aggregates (count, group by, avg, sum etc.) that cannot be pushed down yet (at least in Spark 3.1) - so the custom query is a fallback to allow them to be pushed into the query sent to Cosmos. |
38+
| `spark.cosmos.read.customQuery` | None | When provided the custom query will be processed against the Cosmos endpoint instead of dynamically generating the query via predicate push down. Usually it is recommended to rely on Spark's predicate push down because that will allow to generate the most efficient set of filters based on the query plan. But there are a couple of predicates like aggregates (count, group by, avg, sum etc.) that cannot be pushed down yet (at least in Spark 3.1) - so the custom query is a fallback to allow them to be pushed into the query sent to Cosmos. If specified, with schema inference enabled, the custom query will also be used to infer the schema. |
3939

4040
#### Schema Inference Config
4141

sdk/cosmos/azure-cosmos-spark_3-1_2-12/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ private object CosmosSchemaInferenceConfig {
549549
parseFromStringFunction = query => query,
550550
helpMessage = "When schema inference is enabled, used as custom query to infer it")
551551

552-
def parseCosmosReadConfig(cfg: Map[String, String]): CosmosSchemaInferenceConfig = {
552+
def parseCosmosInferenceConfig(cfg: Map[String, String]): CosmosSchemaInferenceConfig = {
553553
val samplingSize = CosmosConfigEntry.parse(cfg, inferSchemaSamplingSize)
554554
val enabled = CosmosConfigEntry.parse(cfg, inferSchemaEnabled)
555555
val query = CosmosConfigEntry.parse(cfg, inferSchemaQuery)

sdk/cosmos/azure-cosmos-spark_3-1_2-12/src/main/scala/com/azure/cosmos/spark/CosmosTableSchemaInferrer.scala

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,29 +75,34 @@ private object CosmosTableSchemaInferrer
7575
private[spark] def inferSchema(client: CosmosAsyncClient,
7676
userConfig: Map[String, String],
7777
defaultSchema: StructType): StructType = {
78-
val cosmosReadConfig = CosmosSchemaInferenceConfig.parseCosmosReadConfig(userConfig)
79-
if (cosmosReadConfig.inferSchemaEnabled) {
78+
val cosmosInferenceConfig = CosmosSchemaInferenceConfig.parseCosmosInferenceConfig(userConfig)
79+
val cosmosReadConfig = CosmosReadConfig.parseCosmosReadConfig(userConfig)
80+
if (cosmosInferenceConfig.inferSchemaEnabled) {
8081
val cosmosContainerConfig = CosmosContainerConfig.parseCosmosContainerConfig(userConfig)
8182
val sourceContainer = ThroughputControlHelper.getContainer(userConfig, cosmosContainerConfig, client)
8283
val queryOptions = new CosmosQueryRequestOptions()
83-
queryOptions.setMaxBufferedItemCount(cosmosReadConfig.inferSchemaSamplingSize)
84-
val queryText = cosmosReadConfig.inferSchemaQuery match {
85-
case None => s"select TOP ${cosmosReadConfig.inferSchemaSamplingSize} * from c"
86-
case _ => cosmosReadConfig.inferSchemaQuery.get
84+
queryOptions.setMaxBufferedItemCount(cosmosInferenceConfig.inferSchemaSamplingSize)
85+
val queryText = cosmosInferenceConfig.inferSchemaQuery match {
86+
case None =>
87+
cosmosReadConfig.customQuery match {
88+
case None => s"select TOP ${cosmosInferenceConfig.inferSchemaSamplingSize} * from c"
89+
case _ => cosmosReadConfig.customQuery.get.queryText
90+
}
91+
case _ => cosmosInferenceConfig.inferSchemaQuery.get
8792
}
8893

8994
val pagedFluxResponse =
9095
sourceContainer.queryItems(queryText, queryOptions, classOf[ObjectNode])
9196

9297
val feedResponseList = pagedFluxResponse
93-
.take(cosmosReadConfig.inferSchemaSamplingSize)
98+
.take(cosmosInferenceConfig.inferSchemaSamplingSize)
9499
.collectList
95100
.block
96101

97102
inferSchema(feedResponseList.asScala,
98-
cosmosReadConfig.inferSchemaQuery.isDefined || cosmosReadConfig.includeSystemProperties,
99-
cosmosReadConfig.inferSchemaQuery.isDefined || cosmosReadConfig.includeTimestamp,
100-
cosmosReadConfig.allowNullForInferredProperties)
103+
cosmosInferenceConfig.inferSchemaQuery.isDefined || cosmosInferenceConfig.includeSystemProperties,
104+
cosmosInferenceConfig.inferSchemaQuery.isDefined || cosmosInferenceConfig.includeTimestamp,
105+
cosmosInferenceConfig.allowNullForInferredProperties)
101106
} else {
102107
defaultSchema
103108
}

sdk/cosmos/azure-cosmos-spark_3-1_2-12/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ class CosmosConfigSpec extends UnitSpec {
222222
"spark.cosmos.read.inferSchema.query" -> customQuery
223223
)
224224

225-
val config = CosmosSchemaInferenceConfig.parseCosmosReadConfig(userConfig)
225+
val config = CosmosSchemaInferenceConfig.parseCosmosInferenceConfig(userConfig)
226226
config.inferSchemaSamplingSize shouldEqual 50
227227
config.inferSchemaEnabled shouldBe false
228228
config.includeSystemProperties shouldBe true
@@ -233,7 +233,7 @@ class CosmosConfigSpec extends UnitSpec {
233233
it should "provide default schema inference config" in {
234234
val userConfig = Map[String, String]()
235235

236-
val config = CosmosSchemaInferenceConfig.parseCosmosReadConfig(userConfig)
236+
val config = CosmosSchemaInferenceConfig.parseCosmosInferenceConfig(userConfig)
237237

238238
config.inferSchemaSamplingSize shouldEqual 1000
239239
config.inferSchemaEnabled shouldBe true

sdk/cosmos/azure-cosmos-spark_3-1_2-12/src/test/scala/com/azure/cosmos/spark/SparkE2EQueryITest.scala

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,49 @@ class SparkE2EQueryITest
609609
item.getAs[String]("id") shouldEqual id
610610
}
611611

612+
"spark query" should "use Custom Query also for inference" in {
613+
val cosmosEndpoint = TestConfigurations.HOST
614+
val cosmosMasterKey = TestConfigurations.MASTER_KEY
615+
616+
val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
617+
for (state <- Array(true, false)) {
618+
val objectNode = Utils.getSimpleObjectMapper.createObjectNode()
619+
objectNode.put("name", "Shrodigner's duck")
620+
objectNode.put("type", "duck")
621+
objectNode.put("age", 20)
622+
objectNode.put("isAlive", state)
623+
objectNode.put("id", UUID.randomUUID().toString)
624+
container.createItem(objectNode).block()
625+
}
626+
627+
val cfgWithInference = Map("spark.cosmos.accountEndpoint" -> cosmosEndpoint,
628+
"spark.cosmos.accountKey" -> cosmosMasterKey,
629+
"spark.cosmos.database" -> cosmosDatabase,
630+
"spark.cosmos.container" -> cosmosContainer,
631+
"spark.cosmos.read.inferSchema.enabled" -> "true",
632+
"spark.cosmos.read.customQuery" ->
633+
"SELECT c.type, c.age, c.isAlive FROM c where c.type = 'duck' and c.isAlive = true",
634+
"spark.cosmos.read.partitioning.strategy" -> "Restrictive"
635+
)
636+
637+
// Not passing schema, letting inference work
638+
val dfWithInference = spark.read.format("cosmos.oltp").options(cfgWithInference).load()
639+
val rowsArrayWithInference = dfWithInference.collect()
640+
rowsArrayWithInference should have size 1
641+
642+
val rowWithInference = rowsArrayWithInference(0)
643+
rowWithInference.getAs[String]("type") shouldEqual "duck"
644+
rowWithInference.getAs[Integer]("age") shouldEqual 20
645+
rowWithInference.getAs[Boolean]("isAlive") shouldEqual true
646+
647+
val fieldNames = rowWithInference.schema.fields.map(field => field.name)
648+
fieldNames.contains(CosmosTableSchemaInferrer.SelfAttributeName) shouldBe false
649+
fieldNames.contains(CosmosTableSchemaInferrer.TimestampAttributeName) shouldBe false
650+
fieldNames.contains(CosmosTableSchemaInferrer.ResourceIdAttributeName) shouldBe false
651+
fieldNames.contains(CosmosTableSchemaInferrer.ETagAttributeName) shouldBe false
652+
fieldNames.contains(CosmosTableSchemaInferrer.AttachmentsAttributeName) shouldBe false
653+
}
654+
612655
//scalastyle:on magic.number
613656
//scalastyle:on multiple.string.literals
614657
}

0 commit comments

Comments (0)