Skip to content

Commit 9b22ee2

Browse files
authored
Cosmos DB - Adds TransactionalBatch support (Azure#17795)
Adding support for TransactionalBatch
1 parent b0ab728 commit 9b22ee2

18 files changed

+1261
-2
lines changed

sdk/data/azcosmos/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
### Features Added
66

7+
* Added Transactional Batch support
8+
79
### Breaking Changes
810

911
### Bugs Fixed

sdk/data/azcosmos/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ This client library enables client applications to connect to Azure Cosmos via t
88

99
### Prerequisites
1010

11-
* Go versions 1.16 or higher
11+
* Go versions 1.18 or higher
1212
* An Azure subscription or free Azure Cosmos DB trial account
1313

1414
Note: If you don't have an Azure subscription, create a free account before you begin.

sdk/data/azcosmos/cosmos_client.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,26 @@ func (c *Client) sendDeleteRequest(
238238
return c.executeAndEnsureSuccessResponse(req)
239239
}
240240

241+
func (c *Client) sendBatchRequest(
242+
ctx context.Context,
243+
path string,
244+
batch []batchOperation,
245+
operationContext pipelineRequestOptions,
246+
requestOptions cosmosRequestOptions,
247+
requestEnricher func(*policy.Request)) (*http.Response, error) {
248+
req, err := c.createRequest(path, ctx, http.MethodPost, operationContext, requestOptions, requestEnricher)
249+
if err != nil {
250+
return nil, err
251+
}
252+
253+
err = c.attachContent(batch, req)
254+
if err != nil {
255+
return nil, err
256+
}
257+
258+
return c.executeAndEnsureSuccessResponse(req)
259+
}
260+
241261
func (c *Client) createRequest(
242262
path string,
243263
ctx context.Context,
@@ -288,7 +308,6 @@ func (c *Client) attachContent(content interface{}, req *policy.Request) error {
288308
default:
289309
// Otherwise, we need to marshal it
290310
err = azruntime.MarshalAsJSON(req, content)
291-
292311
}
293312

294313
if err != nil {

sdk/data/azcosmos/cosmos_client_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,50 @@ func TestSendQuery(t *testing.T) {
354354
}
355355
}
356356

357+
func TestSendBatch(t *testing.T) {
358+
srv, close := mock.NewTLSServer()
359+
defer close()
360+
srv.SetResponse(
361+
mock.WithStatusCode(200))
362+
verifier := pipelineVerifier{}
363+
pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{&verifier}}, &policy.ClientOptions{Transport: srv})
364+
client := &Client{endpoint: srv.URL(), pipeline: pl}
365+
operationContext := pipelineRequestOptions{
366+
resourceType: resourceTypeDocument,
367+
resourceAddress: "",
368+
}
369+
370+
batch := TransactionalBatch{}
371+
batch.partitionKey = NewPartitionKeyString("foo")
372+
373+
body := map[string]string{
374+
"foo": "bar",
375+
}
376+
377+
itemMarshall, _ := json.Marshal(body)
378+
379+
batch.CreateItem(itemMarshall, nil)
380+
batch.ReadItem("someId", nil)
381+
382+
marshalled, err := json.Marshal(batch.operations)
383+
if err != nil {
384+
t.Fatal(err)
385+
}
386+
387+
_, err = client.sendBatchRequest(context.Background(), "/", batch.operations, operationContext, &TransactionalBatchOptions{}, nil)
388+
if err != nil {
389+
t.Fatal(err)
390+
}
391+
392+
if verifier.requests[0].method != http.MethodPost {
393+
t.Errorf("Expected %v, but got %v", http.MethodPost, verifier.requests[0].method)
394+
}
395+
396+
if verifier.requests[0].body != string(marshalled) {
397+
t.Errorf("Expected %v, but got %v", string(marshalled), verifier.requests[0].body)
398+
}
399+
}
400+
357401
func TestCreateScopeFromEndpoint(t *testing.T) {
358402
url := "https://foo.documents.azure.com:443/"
359403
scope, err := createScopeFromEndpoint(url)

sdk/data/azcosmos/cosmos_container.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package azcosmos
55

66
import (
77
"context"
8+
"errors"
89

910
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
1011
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
@@ -456,6 +457,66 @@ func (c *ContainerClient) NewQueryItemsPager(query string, partitionKey Partitio
456457
})
457458
}
458459

460+
// NewTransactionalBatch creates a batch of operations to be committed as a single unit.
461+
// See https://docs.microsoft.com/azure/cosmos-db/sql/transactional-batch
462+
func (c *ContainerClient) NewTransactionalBatch(partitionKey PartitionKey) TransactionalBatch {
463+
return TransactionalBatch{partitionKey: partitionKey}
464+
}
465+
466+
// ExecuteTransactionalBatch executes a transactional batch.
467+
// Once executed, verify the Success property of the response to determine if the batch was committed
468+
func (c *ContainerClient) ExecuteTransactionalBatch(ctx context.Context, b TransactionalBatch, o *TransactionalBatchOptions) (TransactionalBatchResponse, error) {
469+
if len(b.operations) == 0 {
470+
return TransactionalBatchResponse{}, errors.New("no operations in batch")
471+
}
472+
473+
h := headerOptionsOverride{
474+
partitionKey: &b.partitionKey,
475+
}
476+
477+
if o == nil {
478+
o = &TransactionalBatchOptions{}
479+
} else {
480+
h.enableContentResponseOnWrite = &o.EnableContentResponseOnWrite
481+
}
482+
483+
// If contentResponseOnWrite is not enabled at the client level the
484+
// service will not even send a batch response payload
485+
// Instead we should automatically enforce contentResponseOnWrite for all
486+
// batch requests whenever at least one of the item operations requires a content response (read operation)
487+
enableContentResponseOnWriteForReadOperations := true
488+
for _, op := range b.operations {
489+
if op.getOperationType() == operationTypeRead {
490+
h.enableContentResponseOnWrite = &enableContentResponseOnWriteForReadOperations
491+
break
492+
}
493+
}
494+
495+
operationContext := pipelineRequestOptions{
496+
resourceType: resourceTypeDocument,
497+
resourceAddress: c.link,
498+
isWriteOperation: true,
499+
headerOptionsOverride: &h}
500+
501+
path, err := generatePathForNameBased(resourceTypeDocument, operationContext.resourceAddress, true)
502+
if err != nil {
503+
return TransactionalBatchResponse{}, err
504+
}
505+
506+
azResponse, err := c.database.client.sendBatchRequest(
507+
ctx,
508+
path,
509+
b.operations,
510+
operationContext,
511+
o,
512+
nil)
513+
if err != nil {
514+
return TransactionalBatchResponse{}, err
515+
}
516+
517+
return newTransactionalBatchResponse(azResponse)
518+
}
519+
459520
func (c *ContainerClient) getRID(ctx context.Context) (string, error) {
460521
containerResponse, err := c.Read(ctx, nil)
461522
if err != nil {

sdk/data/azcosmos/cosmos_container_test.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,3 +520,72 @@ func TestContainerQueryItems(t *testing.T) {
520520
}
521521
}
522522
}
523+
524+
func TestContainerExecuteBatch(t *testing.T) {
525+
batchResponseRaw := []map[string]interface{}{
526+
{"statusCode": 200, "requestCharge": 10.0, "eTag": "someETag", "resourceBody": "someBody"},
527+
{"statusCode": 201, "requestCharge": 11.0, "eTag": "someETag2"},
528+
}
529+
530+
jsonString, err := json.Marshal(batchResponseRaw)
531+
if err != nil {
532+
t.Fatal(err)
533+
}
534+
535+
srv, close := mock.NewTLSServer()
536+
defer close()
537+
srv.SetResponse(
538+
mock.WithBody(jsonString),
539+
mock.WithStatusCode(http.StatusOK),
540+
mock.WithHeader(cosmosHeaderEtag, "someEtag"),
541+
mock.WithHeader(cosmosHeaderActivityId, "someActivityId"),
542+
mock.WithHeader(cosmosHeaderRequestCharge, "13.42"))
543+
544+
verifier := pipelineVerifier{}
545+
546+
pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{&verifier}}, &policy.ClientOptions{Transport: srv})
547+
client := &Client{endpoint: srv.URL(), pipeline: pl}
548+
549+
database, _ := newDatabase("databaseId", client)
550+
container, _ := newContainer("containerId", database)
551+
552+
pk := NewPartitionKeyString("pk")
553+
batch := container.NewTransactionalBatch(pk)
554+
_, err = container.ExecuteTransactionalBatch(context.TODO(), batch, nil)
555+
if err == nil {
556+
t.Fatal("Expected error, but got nil")
557+
}
558+
559+
batch.ReadItem("someId", nil)
560+
561+
body := map[string]string{
562+
"foo": "bar",
563+
}
564+
565+
itemMarshall, _ := json.Marshal(body)
566+
batch.CreateItem(itemMarshall, nil)
567+
568+
_, err = container.ExecuteTransactionalBatch(context.TODO(), batch, nil)
569+
if err != nil {
570+
t.Fatal(err)
571+
}
572+
573+
if len(verifier.requests) != 1 {
574+
t.Fatalf("Expected 1 request, got %d", len(verifier.requests))
575+
}
576+
577+
request := verifier.requests[0]
578+
579+
if request.method != http.MethodPost {
580+
t.Errorf("Expected method to be %s, but got %s", http.MethodPost, request.method)
581+
}
582+
583+
if request.url.RequestURI() != "/dbs/databaseId/colls/containerId/docs" {
584+
t.Errorf("Expected url to be %s, but got %s", "/dbs/databaseId/colls/containerId/docs", request.url.RequestURI())
585+
}
586+
587+
marshalledOperations, _ := json.Marshal(batch.operations)
588+
if request.body != string(marshalledOperations) {
589+
t.Errorf("Expected %v, but got %v", string(marshalledOperations), request.body)
590+
}
591+
}

sdk/data/azcosmos/cosmos_headers.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ const (
3030
cosmosHeaderQueryMetrics string = "x-ms-documentdb-query-metrics"
3131
cosmosHeaderIndexUtilization string = "x-ms-cosmos-index-utilization"
3232
cosmosHeaderCorrelatedActivityId string = "x-ms-cosmos-correlated-activityid"
33+
cosmosHeaderIsBatchRequest string = "x-ms-cosmos-is-batch-request"
34+
cosmosHeaderIsBatchAtomic string = "x-ms-cosmos-batch-atomic"
35+
cosmosHeaderIsBatchOrdered string = "x-ms-cosmos-batch-ordered"
3336
headerXmsDate string = "x-ms-date"
3437
headerAuthorization string = "Authorization"
3538
headerContentType string = "Content-Type"

0 commit comments

Comments
 (0)