@@ -4,8 +4,8 @@ package com.azure.cosmos.spark
 
 import com.azure.cosmos.CosmosAsyncContainer
 import com.azure.cosmos.implementation.{TestConfigurations, Utils}
-import com.azure.cosmos.models.{PartitionKey, ThroughputProperties}
-import org.apache.spark.sql.SparkSession
+import com.azure.cosmos.models.{ModelBridgeInternal, PartitionKey, ThroughputProperties}
+import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.streaming.{StreamingQueryListener, Trigger}
 import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}
 
@@ -663,6 +663,106 @@ class SparkE2EStructuredStreamingITest
     targetContainer.delete().block()
   }
 
+  "spark change feed micro batch (incremental)" can
+    "filter by feedRange (Restrictive partitioning strategy)" taggedAs(Retryable) in {
+
+    runChangeFeedFeedRangeFilterTest("Restrictive")
+  }
+
+  "spark change feed micro batch (incremental)" can
+    "filter by feedRange (Default partitioning strategy)" taggedAs(Retryable) in {
+
+    runChangeFeedFeedRangeFilterTest("Default")
+  }
+
+  "spark change feed micro batch (incremental)" can
+    "filter by feedRange (Aggressive partitioning strategy)" taggedAs(Retryable) in {
+
+    runChangeFeedFeedRangeFilterTest("Aggressive")
+  }
+
+  private[this] def runChangeFeedFeedRangeFilterTest(partitioningStrategy: String): Unit = {
+    val processedRecordCount = new AtomicLong(0)
+    val forEachBatchRecordCount = new AtomicLong(0)
+    var spark = this.createSparkSession(processedRecordCount)
+    val cosmosEndpoint = TestConfigurations.HOST
+    val cosmosMasterKey = TestConfigurations.MASTER_KEY
+    val sourceContainer = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
+    val testId = UUID.randomUUID().toString
+
+    // Initially ingest 20 records
+    var lastId = ""
+    for (i <- 0 until 20) {
+      lastId = this.ingestTestDocument(sourceContainer, i)
+    }
+
+    Thread.sleep(2100)
+
+    val pkDefinition = sourceContainer.read().block().getProperties.getPartitionKeyDefinition
+    val pkDefinitionJson = ModelBridgeInternal.getJsonSerializable(pkDefinition).toJson
+
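+    // GetFeedRangeForPartitionKeyValue is the connector UDF used here to resolve the feed range
+    // covering a single partition key value - lastId, assuming the test container is partitioned on id.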
+    val feedRangeFilter = new GetFeedRangeForPartitionKeyValue().call(pkDefinitionJson, lastId)
+
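+    // The "spark.cosmos.partitioning.feedRangeFilter" option below should restrict micro-batch planning
+    // to that feed range, so only changes from the matching partition are expected to be read.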
+    val changeFeedCfg = Map(
+      "spark.cosmos.accountEndpoint" -> cosmosEndpoint,
+      "spark.cosmos.accountKey" -> cosmosMasterKey,
+      "spark.cosmos.database" -> cosmosDatabase,
+      "spark.cosmos.container" -> cosmosContainer,
+      "spark.cosmos.read.inferSchema.enabled" -> "false",
+      "spark.cosmos.read.partitioning.strategy" -> partitioningStrategy,
+      "spark.cosmos.partitioning.feedRangeFilter" -> feedRangeFilter
+    )
+
+    val changeFeedDF = spark
+      .readStream
+      .format("cosmos.oltp.changeFeed")
+      .options(changeFeedCfg)
+      .load()
+
+    val microBatchQuery = changeFeedDF
+      .writeStream
+      .foreachBatch { (batchDF: DataFrame, batchId: Long) =>
+        batchDF.persist()
+        val recordCount = batchDF.count()
+        forEachBatchRecordCount.addAndGet(recordCount)
+        println(s"BatchId: $batchId, Document count: $recordCount")
+        batchDF.unpersist()
+        ()
+      }
+      .trigger(Trigger.ProcessingTime("500 milliseconds"))
+      .queryName(testId)
+      .option("checkpointLocation", s"/tmp/$testId/")
+      .start()
+
+    Thread.sleep(5000)
+
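+    // With the feed-range filter applied, only the single document whose id matches lastId is expected
+    // to flow through the stream, while the source container still holds all 20 ingested documents.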
+    microBatchQuery.lastProgress should not be null
+    microBatchQuery.lastProgress.sources should not be null
+    microBatchQuery.lastProgress.sources(0) should not be null
+    microBatchQuery.lastProgress.sources(0).endOffset should not be null
+    getPartitionCountInOffset(microBatchQuery.lastProgress.sources(0).endOffset) >= 1 shouldEqual true
+
+    microBatchQuery.stop()
+
+    var sourceCount: Long = getRecordCountOfContainer(sourceContainer)
+    logInfo(s"RecordCount in source container after first execution: $sourceCount")
+
+    forEachBatchRecordCount.get() shouldEqual 1L
+    processedRecordCount.get() shouldEqual 1L
+    sourceCount shouldEqual 20L
+
+    // close and recreate spark session to validate
+    // that it is possible to recover the previous query
+    // from the commit log
+    spark.close()
+  }
+
   private[this] def ingestTestDocument
   (
     container: CosmosAsyncContainer,
@@ -690,7 +790,7 @@ class SparkE2EStructuredStreamingITest
       override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = {}
       override def onQueryTerminated(queryTerminated: QueryTerminatedEvent): Unit = {}
       override def onQueryProgress(queryProgress: QueryProgressEvent): Unit = {
-        processedRecordCount.addAndGet(queryProgress.progress.sink.numOutputRows)
+        processedRecordCount.addAndGet(queryProgress.progress.numInputRows)
       }
     })
 