
Commit 01a02ec

Fixes bug preventing usage of feedRangeFilter with change feed (Azure#29338)

* Fixed bug preventing usage of feedRangeFilter with change feed
* Fixed FeedRange split when the FeedRange is smaller than the requested split count

1 parent 0bd5113 · commit 01a02ec

File tree

3 files changed: +146 −33 lines changed

sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosPartitionPlanner.scala

Lines changed: 26 additions & 11 deletions
@@ -396,25 +396,40 @@ private object CosmosPartitionPlanner extends BasicLoggingTrait {
         Some(range)
       })
 
+    val result = new ArrayBuffer[PartitionMetadata]
     orderedRanges
-      .map(range => {
-        while (!SparkBridgeImplementationInternal.doRangesOverlap(range, startTokens(startTokensIndex)._1)) {
+      .foreach(range => {
+        logInfo(s"merging range $range")
+        val initialStartTokensIndex = startTokensIndex
+        val initialLatestTokensIndex = latestTokensIndex
+        while (startTokensIndex < startTokens.length &&
+          !SparkBridgeImplementationInternal.doRangesOverlap(range, startTokens(startTokensIndex)._1)) {
+
          startTokensIndex += 1
-          if (startTokensIndex >= startTokens.length) {
-            throw new IllegalStateException(s"No overlapping start token found for range '$range'.")
-          }
        }
 
-        while (!SparkBridgeImplementationInternal.doRangesOverlap(range, latestTokens(latestTokensIndex).feedRange)) {
+        while (startTokensIndex < startTokens.length &&
+          latestTokensIndex < latestTokens.length &&
+          !SparkBridgeImplementationInternal.doRangesOverlap(range, latestTokens(latestTokensIndex).feedRange)) {
+
          latestTokensIndex += 1
-          if (latestTokensIndex >= latestTokens.length) {
-            throw new IllegalStateException(s"No overlapping latest token found for range '$range'.")
-          }
        }
 
-        val startLsn: Long = startTokens(startTokensIndex)._2
-        latestTokens(latestTokensIndex).cloneForSubRange(range, startLsn)
+        if (startTokensIndex < startTokens.length &&
+          latestTokensIndex < latestTokens.length) {
+
+          val startLsn: Long = startTokens(startTokensIndex)._2
+          val newToken = latestTokens(latestTokensIndex).cloneForSubRange(range, startLsn)
+          result.append(newToken)
+        } else {
+          startTokensIndex = initialStartTokensIndex
+          latestTokensIndex = initialLatestTokensIndex
+        }
       })
+
+    assert(result.size > 0)
+
+    result.toArray
   }
   // scalastyle:on method.length
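In short, the planner no longer throws when an ordered feed range has no overlapping start or latest change feed token (as can happen once a feedRangeFilter restricts the ranges being planned); it skips ahead, and if no overlap is found it rewinds both indices and omits that range from the result. Below is a minimal, self-contained sketch of this skip-and-restore merge. The types are stand-ins: ranges are plain (min, max) string pairs and the overlap check is a simplified lexicographic test, not the connector's doRangesOverlap.

import scala.collection.mutable.ArrayBuffer

object MergeSketch {
  // Stand-in for the connector's token types: a feed range plus the LSN observed for it.
  final case class Token(range: (String, String), lsn: Long)

  // Simplified overlap check over lexicographically ordered half-open ranges.
  def overlaps(a: (String, String), b: (String, String)): Boolean =
    a._1 < b._2 && b._1 < a._2

  def merge(orderedRanges: Seq[(String, String)],
            startTokens: Array[Token],
            latestTokens: Array[Token]): Array[Token] = {
    val result = new ArrayBuffer[Token]
    var startIdx = 0
    var latestIdx = 0

    orderedRanges.foreach { range =>
      val (initialStart, initialLatest) = (startIdx, latestIdx)

      // Skip tokens that do not overlap the current range instead of throwing.
      while (startIdx < startTokens.length && !overlaps(range, startTokens(startIdx).range)) {
        startIdx += 1
      }
      while (startIdx < startTokens.length && latestIdx < latestTokens.length &&
        !overlaps(range, latestTokens(latestIdx).range)) {
        latestIdx += 1
      }

      if (startIdx < startTokens.length && latestIdx < latestTokens.length) {
        // A matching pair exists: emit a token for this sub-range, carrying the start LSN forward.
        result.append(Token(range, startTokens(startIdx).lsn))
      } else {
        // No overlap found: rewind so later ranges can still match earlier tokens.
        startIdx = initialStart
        latestIdx = initialLatest
      }
    }

    assert(result.nonEmpty)
    result.toArray
  }
}

The rewind step is what keeps a filtered-out range from permanently advancing the token cursors past entries that a later range still needs.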

sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/SparkE2EStructuredStreamingITest.scala

Lines changed: 97 additions & 3 deletions
@@ -4,8 +4,8 @@ package com.azure.cosmos.spark
 
 import com.azure.cosmos.CosmosAsyncContainer
 import com.azure.cosmos.implementation.{TestConfigurations, Utils}
-import com.azure.cosmos.models.{PartitionKey, ThroughputProperties}
-import org.apache.spark.sql.SparkSession
+import com.azure.cosmos.models.{ModelBridgeInternal, PartitionKey, ThroughputProperties}
+import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.streaming.{StreamingQueryListener, Trigger}
 import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}

@@ -663,6 +663,100 @@ class SparkE2EStructuredStreamingITest
     targetContainer.delete().block()
   }
 
+  "spark change feed micro batch (incremental)" can
+    "filter by feedRange (Restrictive partitioning strategy)" taggedAs(Retryable) in {
+
+    runChangeFeedFeedRangeFilterTest("Restrictive")
+  }
+
+  "spark change feed micro batch (incremental)" can
+    "filter by feedRange (Default partitioning strategy)" taggedAs(Retryable) in {
+
+    runChangeFeedFeedRangeFilterTest("Default")
+  }
+
+  "spark change feed micro batch (incremental)" can
+    "filter by feedRange (Aggressive partitioning strategy)" taggedAs(Retryable) in {
+
+    runChangeFeedFeedRangeFilterTest("Aggressive")
+  }
+
+  private[this] def runChangeFeedFeedRangeFilterTest(partitioningStrategy: String): Unit = {
+    val processedRecordCount = new AtomicLong(0)
+    val forEachBatchRecordCount = new AtomicLong(0)
+    var spark = this.createSparkSession(processedRecordCount)
+    val cosmosEndpoint = TestConfigurations.HOST
+    val cosmosMasterKey = TestConfigurations.MASTER_KEY
+    val sourceContainer = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
+    val testId = UUID.randomUUID().toString
+
+    // Initially ingest 20 records
+    var lastId = ""
+    for (i <- 0 until 20) {
+      lastId = this.ingestTestDocument(sourceContainer, i)
+    }
+
+    Thread.sleep(2100)
+
+    val pkDefinition = sourceContainer.read().block().getProperties.getPartitionKeyDefinition
+    val pkDefinitionJson = ModelBridgeInternal.getJsonSerializable(pkDefinition).toJson
+
+    val feedRangeFilter = new GetFeedRangeForPartitionKeyValue().call(pkDefinitionJson, lastId)
+
+    val changeFeedCfg = Map(
+      "spark.cosmos.accountEndpoint" -> cosmosEndpoint,
+      "spark.cosmos.accountKey" -> cosmosMasterKey,
+      "spark.cosmos.database" -> cosmosDatabase,
+      "spark.cosmos.container" -> cosmosContainer,
+      "spark.cosmos.read.inferSchema.enabled" -> "false",
+      "spark.cosmos.read.partitioning.strategy" -> partitioningStrategy,
+      "spark.cosmos.partitioning.feedRangeFilter" -> feedRangeFilter
+    )
+
+    val changeFeedDF = spark
+      .readStream
+      .format("cosmos.oltp.changeFeed")
+      .options(changeFeedCfg)
+      .load()
+
+    val microBatchQuery = changeFeedDF
+      .writeStream
+      .foreachBatch { (batchDF: DataFrame, batchId: Long) =>
+        batchDF.persist()
+        val recordCount = batchDF.count()
+        forEachBatchRecordCount.addAndGet(recordCount)
+        println(s"BatchId: $batchId, Document count: $recordCount")
+        batchDF.unpersist()
+        ()
+      }
+      .trigger(Trigger.ProcessingTime("500 milliseconds"))
+      .queryName(testId)
+      .option("checkpointLocation", s"/tmp/$testId/")
+      .start()
+
+    Thread.sleep(5000)
+
+    microBatchQuery.lastProgress should not be null
+    microBatchQuery.lastProgress.sources should not be null
+    microBatchQuery.lastProgress.sources should not be null
+    microBatchQuery.lastProgress.sources(0).endOffset should not be null
+    getPartitionCountInOffset(microBatchQuery.lastProgress.sources(0).endOffset) >= 1 shouldEqual true
+
+    microBatchQuery.stop()
+
+    var sourceCount: Long = getRecordCountOfContainer(sourceContainer)
+    logInfo(s"RecordCount in source container after first execution: $sourceCount")
+
+    forEachBatchRecordCount.get() shouldEqual 1L
+    processedRecordCount.get() shouldEqual 1L
+    sourceCount shouldEqual 20L
+
+    // close and recreate spark session to validate
+    // that it is possible to recover the previous query
+    // from the commit log
+    spark.close()
+  }
+
   private[this] def ingestTestDocument
     (
       container: CosmosAsyncContainer,
@@ -690,7 +784,7 @@ class SparkE2EStructuredStreamingITest
       override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = {}
       override def onQueryTerminated(queryTerminated: QueryTerminatedEvent): Unit = {}
       override def onQueryProgress(queryProgress: QueryProgressEvent): Unit = {
-        processedRecordCount.addAndGet(queryProgress.progress.sink.numOutputRows)
+        processedRecordCount.addAndGet(queryProgress.progress.numInputRows)
       }
     })
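For readers looking for the end-to-end usage rather than the test assertions, here is a condensed sketch of how the filter is wired into a change feed stream. The option keys, the "cosmos.oltp.changeFeed" format, and the GetFeedRangeForPartitionKeyValue helper are the ones exercised by the test above; the account, key, database, container, and feed range strings are placeholders, and the console sink is just for illustration.

import org.apache.spark.sql.SparkSession

object FeedRangeFilterUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("feedRangeFilter-sketch").getOrCreate()

    // Placeholder connection values; in the test these come from TestConfigurations.
    val changeFeedCfg = Map(
      "spark.cosmos.accountEndpoint" -> "https://<account>.documents.azure.com:443/",
      "spark.cosmos.accountKey" -> "<key>",
      "spark.cosmos.database" -> "<database>",
      "spark.cosmos.container" -> "<container>",
      "spark.cosmos.read.inferSchema.enabled" -> "false",
      "spark.cosmos.read.partitioning.strategy" -> "Default",
      // Feed range string, e.g. computed with GetFeedRangeForPartitionKeyValue as in the test.
      "spark.cosmos.partitioning.feedRangeFilter" -> "<feedRange>"
    )

    // Stream only the changes that fall inside the filtered feed range.
    val changeFeedDF = spark
      .readStream
      .format("cosmos.oltp.changeFeed")
      .options(changeFeedCfg)
      .load()

    changeFeedDF.writeStream
      .format("console")
      .option("checkpointLocation", "/tmp/feedRangeFilter-sketch/")
      .start()
      .awaitTermination()
  }
}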

sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/feedranges/FeedRangeInternal.java

Lines changed: 23 additions & 19 deletions
@@ -265,31 +265,35 @@ static List<FeedRangeEpkImpl> trySplitWithHashV1(
         String minRange = effectiveRange.getMin();
         long diff = max - min;
         List<FeedRangeEpkImpl> splitFeedRanges = new ArrayList<>(targetedSplitCount);
-        for (int i = 1; i < targetedSplitCount; i++) {
-            long splitPoint = min + (i * (diff / targetedSplitCount));
-            String maxRange = PartitionKeyInternalHelper.toHexEncodedBinaryString(
-                new NumberPartitionKeyComponent[] {
-                    new NumberPartitionKeyComponent(splitPoint)
-                });
+        if (diff < targetedSplitCount) {
+            splitFeedRanges.add(new FeedRangeEpkImpl(effectiveRange));
+        } else {
+            for (int i = 1; i < targetedSplitCount; i++) {
+                long splitPoint = min + (i * (diff / targetedSplitCount));
+                String maxRange = PartitionKeyInternalHelper.toHexEncodedBinaryString(
+                    new NumberPartitionKeyComponent[] {
+                        new NumberPartitionKeyComponent(splitPoint)
+                    });
+                splitFeedRanges.add(
+                    new FeedRangeEpkImpl(
+                        new Range<>(
+                            minRange,
+                            maxRange,
+                            i > 1 || effectiveRange.isMinInclusive(),
+                            false)));
+
+                minRange = maxRange;
+            }
+
             splitFeedRanges.add(
                 new FeedRangeEpkImpl(
                     new Range<>(
                         minRange,
-                        maxRange,
-                        i > 1 || effectiveRange.isMinInclusive(),
-                        false)));
-
-            minRange = maxRange;
+                        effectiveRange.getMax(),
+                        true,
+                        effectiveRange.isMaxInclusive())));
         }
 
-        splitFeedRanges.add(
-            new FeedRangeEpkImpl(
-                new Range<>(
-                    minRange,
-                    effectiveRange.getMax(),
-                    true,
-                    effectiveRange.isMaxInclusive())));
-
         return splitFeedRanges;
     }
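The new guard matters because the split arithmetic uses integer division: when the numeric width of the effective range is smaller than the requested split count, diff / targetedSplitCount is 0, every split point collapses onto min, and the loop would produce empty or duplicate sub-ranges. A small sketch of that arithmetic (the hex encoding and FeedRangeEpkImpl wrapping are omitted, and Scala is used here purely for illustration of the Java logic):

object SplitPointSketch {
  // Compute the raw split points the loop in trySplitWithHashV1 would generate.
  def splitPoints(min: Long, max: Long, targetedSplitCount: Int): Seq[Long] = {
    val diff = max - min
    (1 until targetedSplitCount).map(i => min + (i * (diff / targetedSplitCount)))
  }

  def main(args: Array[String]): Unit = {
    // Wide range: distinct, usable split points.
    println(splitPoints(min = 0L, max = 100L, targetedSplitCount = 4))  // Vector(25, 50, 75)

    // Narrow range (diff = 3 < 5): the integer step is 0, so every split
    // point equals min and the resulting sub-ranges would be degenerate.
    println(splitPoints(min = 10L, max = 13L, targetedSplitCount = 5))  // Vector(10, 10, 10, 10)
  }
}

With the fix, the narrow-range case returns the original effective range as a single FeedRangeEpkImpl instead of attempting the split.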
