Commit ddafb5f
Adjust MicroBatchSize dynamically based on throttling rate in BulkExecutor (Azure#22290)
* Temp snapshot
* Adjusting MicroBatchSize dynamically in BulkExecutor.java
* Making sure Bulk Request 429 bubble up to the BulkExecutor so they are accounted for in dynamic MicroBatchSize adjustment
* Adjusting targeted bulk throttling retry rate to be a range
* Reducing lock contention in PartitionScopeThresholds.java
* Adding unit test coverage for dynamically changing micro batch size in BulkExecutor
* Adjusting log level in PartitionScopeThresholds
* Moving new API to V4_17_0 Beta annotation
* Adding missing copyright header
* Removing 408 special-casing
* Reacting to code review feedback
* Reacting to code review feedback
* Reenabling Direct tests
* Fixing a bug in the new buffering logic causing 400-BadRequest when the Batch request contains no actual operations after filtering out the dummy FlushOperations
* Fixing type
* Fixes for merge conflicts
* Dummy
* Update BulkWriter.scala
* Update BulkProcessingThresholds.java
* Reverting BridgeInternal changes
* Update BridgeInternal.java
* Update BulkProcessingOptionsTest.java
* Triggering flush on completion of input flux
* Self-code review feedback :-)
* Update BulkProcessingThresholds.java
* Fixing Nullref in BulkWriterTest
* Making FlushBuffersItemOperation a singleton
* Fixing build break
* Fixing test failure
1 parent ab037df commit ddafb5f

24 files changed (+887, -155 lines)
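Note: the tuning logic itself lives in BulkExecutor.java and PartitionScopeThresholds.java, which are not among the hunks excerpted below; only the Spark-connector side of the change is shown. As a rough, non-authoritative sketch of the feedback loop the commit message describes (shrink the micro batch size while the observed 429/throttling retry rate is above a targeted range, grow it back once the rate falls below that range), the idea can be pictured in Scala as follows; every name, bound, and constant here is an assumption for illustration, not the actual implementation:

object MicroBatchSizeTuner {
  // Hypothetical bounds and target range; the real values live in the
  // BulkExecutor/PartitionScopeThresholds code, which is not part of this excerpt.
  private val MinMicroBatchSize = 1
  private val MaxMicroBatchSize = 100
  private val TargetedThrottlingRateLowerBound = 0.01
  private val TargetedThrottlingRateUpperBound = 0.10

  // Shrink the micro batch size when the observed 429 rate is above the targeted range,
  // grow it back gradually when the rate is below the range, and leave it unchanged
  // while the rate stays inside the range.
  def nextMicroBatchSize(currentSize: Int, operationsSent: Long, operationsThrottled: Long): Int = {
    if (operationsSent == 0) {
      currentSize
    } else {
      val throttlingRate = operationsThrottled.toDouble / operationsSent
      if (throttlingRate > TargetedThrottlingRateUpperBound) {
        math.max(MinMicroBatchSize, currentSize / 2)
      } else if (throttlingRate < TargetedThrottlingRateLowerBound) {
        math.min(MaxMicroBatchSize, currentSize + 1)
      } else {
        currentSize
      }
    }
  }
}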

sdk/cosmos/azure-cosmos-spark_3-1_2-12/src/main/scala/com/azure/cosmos/spark/BulkWriter.scala

Lines changed: 69 additions & 40 deletions
@@ -2,16 +2,23 @@
 // Licensed under the MIT License.
 package com.azure.cosmos.spark
 
+import com.azure.cosmos.
+{
+  BulkItemRequestOptions,
+  BulkOperations,
+  BulkProcessingOptions,
+  BulkProcessingThresholds,
+  CosmosAsyncContainer,
+  CosmosBulkOperationResponse,
+  CosmosException,
+  CosmosItemOperation
+}
 import com.azure.cosmos.implementation.ImplementationBridgeHelpers
 import com.azure.cosmos.implementation.guava25.base.Preconditions
 import com.azure.cosmos.implementation.spark.{OperationContextAndListenerTuple, OperationListener}
 import com.azure.cosmos.models.PartitionKey
 import com.azure.cosmos.spark.BulkWriter.{DefaultMaxPendingOperationPerCore, emitFailureHandler}
 import com.azure.cosmos.spark.diagnostics.{DiagnosticsContext, DiagnosticsLoader, LoggerHelper, SparkTaskContext}
-import com.azure.cosmos.{
-  BulkItemRequestOptions,
-  BulkOperations, BulkProcessingOptions, CosmosAsyncContainer, CosmosBulkOperationResponse, CosmosException, CosmosItemOperation
-}
 import com.fasterxml.jackson.databind.node.ObjectNode
 import org.apache.spark.TaskContext
 import reactor.core.Disposable
@@ -69,34 +76,46 @@ class BulkWriter(container: CosmosAsyncContainer,
   private val totalScheduledMetrics = new AtomicLong(0)
   private val totalSuccessfulIngestionMetrics = new AtomicLong(0)
 
-  private val bulkOptions = new BulkProcessingOptions[Object]()
-  initializeDiagnosticsIfConfigured()
+  private val bulkOptions = new BulkProcessingOptions[Object](null, BulkWriter.bulkProcessingThresholds)
+  private val operationContext = initializeOperationContext()
 
-  private def initializeDiagnosticsIfConfigured(): Unit = {
-    if (diagnosticsConfig.mode.isDefined) {
-      val taskContext = TaskContext.get
-      assert(taskContext != null)
+  private def initializeOperationContext(): SparkTaskContext = {
+    val taskContext = TaskContext.get
 
-      val diagnosticsContext: DiagnosticsContext = DiagnosticsContext(UUID.randomUUID().toString, "BulkWriter")
+    val diagnosticsContext: DiagnosticsContext = DiagnosticsContext(UUID.randomUUID().toString, "BulkWriter")
 
+    if (taskContext != null) {
       val taskDiagnosticsContext = SparkTaskContext(diagnosticsContext.correlationActivityId,
         taskContext.stageId(),
         taskContext.partitionId(),
+        taskContext.taskAttemptId(),
         "")
 
       val listener: OperationListener =
        DiagnosticsLoader.getDiagnosticsProvider(diagnosticsConfig).getLogger(this.getClass)
 
      val operationContextAndListenerTuple = new OperationContextAndListenerTuple(taskDiagnosticsContext, listener)
      ImplementationBridgeHelpers.CosmosBulkProcessingOptionsHelper
-        .getCosmosBulkProcessingOptionAccessor()
+        .getCosmosBulkProcessingOptionAccessor
        .setOperationContext(bulkOptions, operationContextAndListenerTuple)
+
+      taskDiagnosticsContext
+    } else{
+      SparkTaskContext(diagnosticsContext.correlationActivityId,
+        -1,
+        -1,
+        -1,
+        "")
    }
  }
 
  private val subscriptionDisposable: Disposable = {
    val bulkOperationResponseFlux: SFlux[CosmosBulkOperationResponse[Object]] =
-      container.processBulkOperations[Object](bulkInputEmitter.asFlux(), bulkOptions).asScala
+      container
+        .processBulkOperations[Object](
+          bulkInputEmitter.asFlux(),
+          bulkOptions)
+        .asScala
 
    bulkOperationResponseFlux.subscribe(
      resp => {
@@ -109,18 +128,18 @@ class BulkWriter(container: CosmosAsyncContainer,
 
        if (resp.getException != null) {
          Option(resp.getException) match {
-            case Some(cosmosException: CosmosException) => {
-              log.logDebug(s"encountered ${cosmosException.getStatusCode}")
+            case Some(cosmosException: CosmosException) =>
+              log.logDebug(s"encountered ${cosmosException.getStatusCode}, Context: ${operationContext.toString}")
              if (shouldIgnore(cosmosException)) {
                log.logDebug(s"for itemId=[${context.itemId}], partitionKeyValue=[${context.partitionKeyValue}], " +
-                  s"ignored encountered ${cosmosException.getStatusCode}")
+                  s"ignored encountered ${cosmosException.getStatusCode}, Context: ${operationContext.toString}")
                totalSuccessfulIngestionMetrics.getAndIncrement()
                // work done
              } else if (shouldRetry(cosmosException, contextOpt.get)) {
                // requeue
                log.logWarning(s"for itemId=[${context.itemId}], partitionKeyValue=[${context.partitionKeyValue}], " +
                  s"encountered ${cosmosException.getStatusCode}, will retry! " +
-                  s"attemptNumber=${context.attemptNumber}, exceptionMessage=${cosmosException.getMessage}")
+                  s"attemptNumber=${context.attemptNumber}, exceptionMessage=${cosmosException.getMessage}, Context: {${operationContext.toString}}")
 
                // this is to ensure the submission will happen on a different thread in background
                // and doesn't block the active thread
@@ -136,14 +155,14 @@ class BulkWriter(container: CosmosAsyncContainer,
              } else {
                log.logWarning(s"for itemId=[${context.itemId}], partitionKeyValue=[${context.partitionKeyValue}], " +
                  s"encountered ${cosmosException.getStatusCode}, all retries exhausted! " +
-                  s"attemptNumber=${context.attemptNumber}, exceptionMessage=${cosmosException.getMessage}")
+                  s"attemptNumber=${context.attemptNumber}, exceptionMessage=${cosmosException.getMessage}, Context: {${operationContext.toString}")
                captureIfFirstFailure(cosmosException)
                cancelWork()
              }
-            }
            case _ =>
              log.logWarning(s"unexpected failure: itemId=[${context.itemId}], partitionKeyValue=[${context.partitionKeyValue}], " +
-                s"encountered , attemptNumber=${context.attemptNumber}, exceptionMessage=${resp.getException.getMessage}", resp.getException)
+                s"encountered , attemptNumber=${context.attemptNumber}, exceptionMessage=${resp.getException.getMessage}, " +
+                s"Context: ${operationContext.toString}", resp.getException)
              captureIfFirstFailure(resp.getException)
              cancelWork()
          }
@@ -163,7 +182,7 @@ class BulkWriter(container: CosmosAsyncContainer,
      },
      errorConsumer = Option.apply(
        ex => {
-          log.logError("Unexpected failure code path in Bulk ingestion", ex)
+          log.logError(s"Unexpected failure code path in Bulk ingestion, Context: ${operationContext.toString}", ex)
          // if there is any failure this closes the bulk.
          // at this point bulk api doesn't allow any retrying
          // we don't know the list of failed item-operations
@@ -182,21 +201,21 @@ class BulkWriter(container: CosmosAsyncContainer,
  override def scheduleWrite(partitionKeyValue: PartitionKey, objectNode: ObjectNode): Unit = {
    Preconditions.checkState(!closed.get())
    if (errorCaptureFirstException.get() != null) {
-      log.logWarning("encountered failure earlier, rejecting new work")
+      log.logWarning(s"encountered failure earlier, rejecting new work, Context: ${operationContext.toString}")
      throw errorCaptureFirstException.get()
    }
 
    semaphore.acquire()
    val cnt = totalScheduledMetrics.getAndIncrement()
-    log.logDebug(s"total scheduled ${cnt}")
+    log.logDebug(s"total scheduled $cnt, Context: ${operationContext.toString}")
 
    scheduleWriteInternal(partitionKeyValue, objectNode, OperationContext(getId(objectNode), partitionKeyValue, getETag(objectNode), 1))
  }
 
  private def scheduleWriteInternal(partitionKeyValue: PartitionKey, objectNode: ObjectNode, operationContext: OperationContext): Unit = {
    activeTasks.incrementAndGet()
    if (operationContext.attemptNumber > 1) {
-      log.logInfo(s"bulk scheduleWrite attemptCnt: ${operationContext.attemptNumber}")
+      log.logInfo(s"bulk scheduleWrite attemptCnt: ${operationContext.attemptNumber}, Context: ${operationContext.toString}")
    }
 
    val bulkItemOperation = writeConfig.itemWriteStrategy match {
@@ -226,46 +245,49 @@ class BulkWriter(container: CosmosAsyncContainer,
 
  // the caller has to ensure that after invoking this method scheduleWrite doesn't get invoked
  override def flushAndClose(): Unit = {
-    this.synchronized{
+    this.synchronized {
      try {
        if (closed.get()) {
          // scalastyle:off return
          return
          // scalastyle:on return
        }
-
-        log.logInfo("flushAndClose invoked")
-
-        log.logInfo(s"completed so far ${totalSuccessfulIngestionMetrics.get()}, pending tasks ${activeOperations.size}")
+        log.logInfo(s"flushAndClose invoked, Context: ${operationContext.toString}")
+        log.logInfo(s"completed so far ${totalSuccessfulIngestionMetrics.get()}, pending tasks ${activeOperations.size}, Context: ${operationContext.toString}")
 
        // error handling, if there is any error and the subscription is cancelled
        // the remaining tasks will not be processed hence we never reach activeTasks = 0
        // once we do error handling we should think how to cover the scenario.
        lock.lock()
        try {
-          while (activeTasks.get() > 0 || errorCaptureFirstException.get != null) {
+          var activeTasksSnapshot = activeTasks.get()
+          while (activeTasksSnapshot > 0 || errorCaptureFirstException.get != null) {
+            log.logDebug(s"Waiting for pending activeTasks $activeTasksSnapshot, Context: ${operationContext.toString}")
            pendingTasksCompleted.await()
+            activeTasksSnapshot = activeTasks.get()
+            log.logDebug(s"Waiting completed for pending activeTasks $activeTasksSnapshot, Context: ${operationContext.toString}")
          }
        } finally {
          lock.unlock()
        }
 
-        log.logInfo("invoking bulkInputEmitter.onComplete()")
+        log.logInfo(s"invoking bulkInputEmitter.onComplete(), Context: ${operationContext.toString}")
        semaphore.release(activeTasks.get())
        bulkInputEmitter.tryEmitComplete()
 
        // which error to report?
        if (errorCaptureFirstException.get() != null) {
-          log.logError(s"flushAndClose throw captured error ${errorCaptureFirstException.get().getMessage}")
+          log.logError(s"flushAndClose throw captured error ${errorCaptureFirstException.get().getMessage}, " +
+            s"Context: ${operationContext.toString}")
          throw errorCaptureFirstException.get()
        }
 
        assume(activeTasks.get() == 0)
        assume(activeOperations.isEmpty)
        assume(semaphore.availablePermits() == maxPendingOperations)
-
        log.logInfo(s"flushAndClose completed with no error. " +
-          s"totalSuccessfulIngestionMetrics=${totalSuccessfulIngestionMetrics.get()}, totalScheduled=${totalScheduledMetrics}")
+          s"totalSuccessfulIngestionMetrics=${totalSuccessfulIngestionMetrics.get()}, " +
+          s"totalScheduled=$totalScheduledMetrics, Context: ${operationContext.toString}")
        assume(totalScheduledMetrics.get() == totalSuccessfulIngestionMetrics.get)
      } finally {
        closed.set(true)
@@ -276,16 +298,20 @@ class BulkWriter(container: CosmosAsyncContainer,
  private def markTaskCompletion(): Unit = {
    lock.lock()
    try {
-      if (activeTasks.decrementAndGet() == 0 || errorCaptureFirstException.get() != null) {
+      val activeTasksLeftSnapshot = activeTasks.decrementAndGet()
+      val exceptionSnapshot = errorCaptureFirstException.get()
+      if (activeTasksLeftSnapshot == 0 || exceptionSnapshot != null) {
+        log.logDebug(s"MarkTaskCompletion, Active tasks left: $activeTasksLeftSnapshot, " +
+          s"error: $exceptionSnapshot, Context: ${operationContext.toString}")
        pendingTasksCompleted.signal()
      }
    } finally {
      lock.unlock()
    }
  }
 
-  private def captureIfFirstFailure(throwable: Throwable) = {
-    log.logError("capture failure", throwable)
+  private def captureIfFirstFailure(throwable: Throwable): Unit = {
+    log.logError(s"capture failure, Context: {${operationContext.toString}}", throwable)
    lock.lock()
    try {
      errorCaptureFirstException.compareAndSet(null, throwable)
@@ -296,7 +322,8 @@ class BulkWriter(container: CosmosAsyncContainer,
  }
 
  private def cancelWork(): Unit = {
-    log.logInfo(s"cancelling remaining un process tasks ${activeTasks.get}")
+    log.logInfo(s"cancelling remaining unprocessed tasks ${activeTasks.get}, " +
+      s"Context: ${operationContext.toString}")
    subscriptionDisposable.dispose()
  }
 
@@ -341,10 +368,12 @@ private object BulkWriter {
  // hence we want 2MB/ 1KB items per partition to be buffered
  // 2 * 1024 * 167 items should get buffered on a 16 CPU core VM
  // so per CPU core we want (2 * 1024 * 167 / 16) max items to be buffered
-  val DefaultMaxPendingOperationPerCore = 2 * 1024 * 167 / 16
+  val DefaultMaxPendingOperationPerCore: Int = 2 * 1024 * 167 / 16
 
  val emitFailureHandler: EmitFailureHandler =
-    (signalType, emitResult) => if (emitResult.equals(EmitResult.FAIL_NON_SERIALIZED)) true else false
+    (_, emitResult) => if (emitResult.equals(EmitResult.FAIL_NON_SERIALIZED)) true else false
+
+  val bulkProcessingThresholds = new BulkProcessingThresholds[Object]()
 }
 
 //scalastyle:on multiple.string.literals
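A key change above is that BulkWriter now constructs its BulkProcessingOptions with a BulkProcessingThresholds instance held in the BulkWriter companion object, presumably so that the micro-batch sizes learned from throttling feedback are retained across BulkWriter instances on the same executor rather than being reset per writer. A minimal sketch of that pattern follows; the newBulkOptions helper name is hypothetical (BulkWriter does the equivalent inline when building bulkOptions):

import com.azure.cosmos.{BulkProcessingOptions, BulkProcessingThresholds}

// One shared thresholds object, as in the BulkWriter companion object above.
val sharedThresholds = new BulkProcessingThresholds[Object]()

// Hypothetical helper illustrating the constructor overload used in the diff.
def newBulkOptions(): BulkProcessingOptions[Object] =
  new BulkProcessingOptions[Object](null, sharedThresholds)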

sdk/cosmos/azure-cosmos-spark_3-1_2-12/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReader.scala

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ private case class ItemsPartitionReader
   val taskDiagnosticsContext = SparkTaskContext(diagnosticsContext.correlationActivityId,
     taskContext.stageId(),
     taskContext.partitionId(),
+    taskContext.taskAttemptId(),
     feedRange.toString + " " + cosmosQuery.toSqlQuerySpec.getQueryText)
 
   val listener: OperationListener =

sdk/cosmos/azure-cosmos-spark_3-1_2-12/src/main/scala/com/azure/cosmos/spark/PointWriter.scala

Lines changed: 4 additions & 1 deletion
@@ -61,6 +61,7 @@ class PointWriter(container: CosmosAsyncContainer, cosmosWriteConfig: CosmosWrit
   private val taskDiagnosticsContext = SparkTaskContext(diagnosticsContext.correlationActivityId,
     taskContext.stageId(),
     taskContext.partitionId(),
+    taskContext.taskAttemptId(),
     "PointWriter")
 
   override def scheduleWrite(partitionKeyValue: PartitionKey, objectNode: ObjectNode): Unit = {
@@ -287,9 +288,11 @@ class PointWriter(container: CosmosAsyncContainer, cosmosWriteConfig: CosmosWrit
   private def getOptions(): CosmosItemRequestOptions = {
     val options = new CosmosItemRequestOptions()
     if (diagnosticsConfig.mode.isDefined) {
-      val taskDiagnosticsContext = SparkTaskContext(diagnosticsContext.correlationActivityId,
+      val taskDiagnosticsContext = SparkTaskContext(
+        diagnosticsContext.correlationActivityId,
         taskContext.stageId(),
         taskContext.partitionId(),
+        taskContext.taskAttemptId(),
         "")
 
       val listener: OperationListener =

sdk/cosmos/azure-cosmos-spark_3-1_2-12/src/main/scala/com/azure/cosmos/spark/diagnostics/SparkTaskContext.scala

Lines changed: 2 additions & 0 deletions
@@ -22,13 +22,15 @@ private[spark] case class DeleteOperation(sparkTaskContext: SparkTaskContext, it
 private[spark] case class SparkTaskContext(correlationActivityId: String,
                                            stageId: Int,
                                            partitionId: Long,
+                                           taskAttemptId: Long,
                                            details: String) extends OperationContext {
 
   @transient private lazy val cachedToString = {
     "SparkTaskContext(" +
       "correlationActivityId=" + correlationActivityId +
       ",stageId=" + stageId +
       ",partitionId=" + partitionId +
+      ",taskAttemptId=" + taskAttemptId +
       ",details=" + details + ")"
   }
 
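The remaining Spark-side hunks all thread the Spark task attempt id into SparkTaskContext so every bulk log line carries it. Pulling the pattern together, construction looks roughly like the sketch below; buildTaskContext is a hypothetical helper (not something this commit adds), and the -1 values mirror the BulkWriter fallback used when no TaskContext is available, e.g. outside of a running Spark task:

import org.apache.spark.TaskContext

def buildTaskContext(correlationActivityId: String, details: String): SparkTaskContext = {
  val taskContext = TaskContext.get
  if (taskContext != null) {
    SparkTaskContext(correlationActivityId,
      taskContext.stageId(),
      taskContext.partitionId(),
      taskContext.taskAttemptId(),
      details)
  } else {
    // Fallback mirroring BulkWriter.initializeOperationContext above.
    SparkTaskContext(correlationActivityId, -1, -1, -1, details)
  }
}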
