Skip to content

Commit 3f7acd2

Browse files
authored
Cosmos DB: Add AAD authentication (Azure#17742)
Adding azcosmos.NewClient with support for TokenCredential
1 parent ffb48f0 commit 3f7acd2

File tree

12 files changed

+413
-5
lines changed

12 files changed

+413
-5
lines changed

sdk/data/azcosmos/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
### Features Added
66
* Added single partition query support.
7+
* Added Azure AD authentication support through `azcosmos.NewClient`
78

89
### Breaking Changes
910
* This module now requires Go 1.18

sdk/data/azcosmos/README.md

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,29 @@ The following section provides several code snippets covering some of the most c
4444

4545
### Create Cosmos Client
4646

47+
The clients support different forms of authentication. The azcosmos library supports authorization via Azure Active Directory or an account key.
48+
49+
**Using Azure Active Directory**
50+
51+
```go
52+
import "github.com/Azure/azure-sdk-for-go/sdk/azidentity"
53+
54+
cred, err := azidentity.NewDefaultAzureCredential(nil)
55+
handle(err)
56+
client, err := azcosmos.NewClient("myAccountEndpointURL", cred, nil)
57+
handle(err)
58+
```
59+
60+
**Using account keys**
61+
4762
```go
4863
const (
4964
cosmosDbEndpoint = "someEndpoint"
5065
cosmosDbKey = "someKey"
5166
)
5267

53-
cred, _ := azcosmos.NewKeyCredential(cosmosDbKey)
68+
cred, err := azcosmos.NewKeyCredential(cosmosDbKey)
69+
handle(err)
5470
client, err := azcosmos.NewClientWithKey(cosmosDbEndpoint, cred, nil)
5571
handle(err)
5672
```

sdk/data/azcosmos/cosmos_client.go

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@ import (
77
"bytes"
88
"context"
99
"errors"
10+
"fmt"
1011
"net/http"
12+
"net/url"
1113
"time"
1214

15+
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
1316
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
1417
azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
1518
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
@@ -26,31 +29,52 @@ func (c *Client) Endpoint() string {
2629
return c.endpoint
2730
}
2831

29-
// NewClientWithKey creates a new instance of Cosmos client with the specified values. It uses the default pipeline configuration.
32+
// NewClientWithKey creates a new instance of Cosmos client with shared key authentication. It uses the default pipeline configuration.
3033
// endpoint - The cosmos service endpoint to use.
3134
// cred - The credential used to authenticate with the cosmos service.
3235
// options - Optional Cosmos client options. Pass nil to accept default values.
3336
func NewClientWithKey(endpoint string, cred KeyCredential, o *ClientOptions) (*Client, error) {
34-
return &Client{endpoint: endpoint, pipeline: newPipeline(cred, o)}, nil
37+
return &Client{endpoint: endpoint, pipeline: newPipeline([]policy.Policy{newSharedKeyCredPolicy(cred)}, o)}, nil
3538
}
3639

37-
func newPipeline(cred KeyCredential, options *ClientOptions) azruntime.Pipeline {
40+
// NewClient creates a new instance of Cosmos client with Azure AD access token authentication. It uses the default pipeline configuration.
41+
// endpoint - The cosmos service endpoint to use.
42+
// cred - The credential used to authenticate with the cosmos service.
43+
// options - Optional Cosmos client options. Pass nil to accept default values.
44+
func NewClient(endpoint string, cred azcore.TokenCredential, o *ClientOptions) (*Client, error) {
45+
scope, err := createScopeFromEndpoint(endpoint)
46+
if err != nil {
47+
return nil, err
48+
}
49+
return &Client{endpoint: endpoint, pipeline: newPipeline([]policy.Policy{azruntime.NewBearerTokenPolicy(cred, scope, nil), &cosmosBearerTokenPolicy{}}, o)}, nil
50+
}
51+
52+
func newPipeline(authPolicy []policy.Policy, options *ClientOptions) azruntime.Pipeline {
3853
if options == nil {
3954
options = &ClientOptions{}
4055
}
4156

4257
return azruntime.NewPipeline("azcosmos", serviceLibVersion,
4358
azruntime.PipelineOptions{
4459
PerCall: []policy.Policy{
45-
newSharedKeyCredPolicy(cred),
4660
&headerPolicies{
4761
enableContentResponseOnWrite: options.EnableContentResponseOnWrite,
4862
},
4963
},
64+
PerRetry: authPolicy,
5065
},
5166
&options.ClientOptions)
5267
}
5368

69+
func createScopeFromEndpoint(endpoint string) ([]string, error) {
70+
u, err := url.Parse(endpoint)
71+
if err != nil {
72+
return nil, err
73+
}
74+
75+
return []string{fmt.Sprintf("%s://%s/.default", u.Scheme, u.Hostname())}, nil
76+
}
77+
5478
// NewDatabase returns a struct that represents a database and allows database level operations.
5579
// id - The id of the database.
5680
func (c *Client) NewDatabase(id string) (*DatabaseClient, error) {

sdk/data/azcosmos/cosmos_client_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,22 @@ func TestSendQuery(t *testing.T) {
354354
}
355355
}
356356

357+
func TestCreateScopeFromEndpoint(t *testing.T) {
358+
url := "https://foo.documents.azure.com:443/"
359+
scope, err := createScopeFromEndpoint(url)
360+
if err != nil {
361+
t.Fatal(err)
362+
}
363+
364+
if scope[0] != "https://foo.documents.azure.com/.default" {
365+
t.Errorf("Expected %v, but got %v", "https://foo.documents.azure.com/.default", scope[0])
366+
}
367+
368+
if len(scope) != 1 {
369+
t.Errorf("Expected %v, but got %v", 1, len(scope))
370+
}
371+
}
372+
357373
type pipelineVerifier struct {
358374
requests []pipelineVerifierRequest
359375
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package azcosmos
5+
6+
import (
7+
"errors"
8+
"fmt"
9+
"net/http"
10+
11+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
12+
)
13+
14+
const lenBearerTokenPrefix = len("Bearer ")
15+
16+
type cosmosBearerTokenPolicy struct {
17+
}
18+
19+
func (b *cosmosBearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) {
20+
currentAuthorization := req.Raw().Header.Get(headerAuthorization)
21+
if currentAuthorization == "" {
22+
return nil, errors.New("authorization header is missing")
23+
}
24+
25+
token := currentAuthorization[lenBearerTokenPrefix:]
26+
req.Raw().Header.Set(headerAuthorization, fmt.Sprintf("type=aad&ver=1.0&sig=%v", token))
27+
return req.Next()
28+
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package azcosmos
5+
6+
import (
7+
"context"
8+
"net/http"
9+
"testing"
10+
11+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
12+
azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
13+
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
14+
)
15+
16+
func TestConvertBearerToken(t *testing.T) {
17+
srv, close := mock.NewTLSServer()
18+
defer close()
19+
srv.SetResponse(mock.WithStatusCode(http.StatusOK))
20+
21+
verifier := bearerTokenVerify{}
22+
pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{&mockAuthPolicy{}, &cosmosBearerTokenPolicy{}, &verifier}}, &policy.ClientOptions{Transport: srv})
23+
req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL())
24+
req.SetOperationValue(pipelineRequestOptions{
25+
isWriteOperation: true,
26+
})
27+
28+
if err != nil {
29+
t.Fatalf("unexpected error: %v", err)
30+
}
31+
32+
_, err = pl.Do(req)
33+
if err != nil {
34+
t.Fatalf("unexpected error: %v", err)
35+
}
36+
37+
if verifier.authHeaderContent != "type=aad&ver=1.0&sig=this is a test token" {
38+
t.Fatalf("Expected auth header content to be 'type=aad&ver=1.0&sig=this is a test token', got %s", verifier.authHeaderContent)
39+
}
40+
}
41+
42+
type bearerTokenVerify struct {
43+
authHeaderContent string
44+
}
45+
46+
func (p *bearerTokenVerify) Do(req *policy.Request) (*http.Response, error) {
47+
p.authHeaderContent = req.Raw().Header.Get(headerAuthorization)
48+
49+
return req.Next()
50+
}
51+
52+
type mockAuthPolicy struct{}
53+
54+
func (p *mockAuthPolicy) Do(req *policy.Request) (*http.Response, error) {
55+
req.Raw().Header.Set(headerAuthorization, "Bearer this is a test token")
56+
57+
return req.Next()
58+
}

sdk/data/azcosmos/doc.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,19 @@ The azcosmos package is capable of:
1212
1313
Creating the Client
1414
15+
Types of Credentials
16+
The clients support different forms of authentication. The azcosmos library supports
17+
authorization via Azure Active Directory or an account key.
18+
19+
Using Azure Active Directory
20+
To create a client, you can use any of the TokenCredential implementations provided by `azidentity`.
21+
cred, err := azidentity.NewClientSecretCredential("tenantId", "clientId", "clientSecret")
22+
handle(err)
23+
client, err := azcosmos.NewClient("myAccountEndpointURL", cred, nil)
24+
handle(err)
25+
26+
27+
Using account keys
1528
To create a client, you will need the account's endpoint URL and a key credential.
1629
1730
cred, err := azcosmos.NewKeyCredential("myAccountKey")
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package azcosmos
5+
6+
import (
7+
"context"
8+
"encoding/json"
9+
"testing"
10+
)
11+
12+
func TestAAD(t *testing.T) {
13+
emulatorTests := newEmulatorTests(t)
14+
client := emulatorTests.getClient(t)
15+
16+
database := emulatorTests.createDatabase(t, context.TODO(), client, "aadTest")
17+
defer emulatorTests.deleteDatabase(t, context.TODO(), database)
18+
properties := ContainerProperties{
19+
ID: "aContainer",
20+
PartitionKeyDefinition: PartitionKeyDefinition{
21+
Paths: []string{"/id"},
22+
},
23+
}
24+
25+
_, err := database.CreateContainer(context.TODO(), properties, nil)
26+
if err != nil {
27+
t.Fatalf("Failed to create container: %v", err)
28+
}
29+
30+
aadClient := emulatorTests.getAadClient(t)
31+
32+
item := map[string]string{
33+
"id": "1",
34+
"value": "2",
35+
}
36+
37+
container, _ := aadClient.NewContainer("aadTest", "aContainer")
38+
pk := NewPartitionKeyString("1")
39+
40+
marshalled, err := json.Marshal(item)
41+
if err != nil {
42+
t.Fatal(err)
43+
}
44+
45+
itemResponse, err := container.CreateItem(context.TODO(), pk, marshalled, nil)
46+
if err != nil {
47+
t.Fatalf("Failed to create item: %v", err)
48+
}
49+
50+
if itemResponse.SessionToken == "" {
51+
t.Fatalf("Session token is empty")
52+
}
53+
54+
// No content on write by default
55+
if len(itemResponse.Value) != 0 {
56+
t.Fatalf("Expected empty response, got %v", itemResponse.Value)
57+
}
58+
59+
itemResponse, err = container.ReadItem(context.TODO(), pk, "1", nil)
60+
if err != nil {
61+
t.Fatalf("Failed to read item: %v", err)
62+
}
63+
64+
if len(itemResponse.Value) == 0 {
65+
t.Fatalf("Expected non-empty response, got %v", itemResponse.Value)
66+
}
67+
68+
var itemResponseBody map[string]interface{}
69+
err = json.Unmarshal(itemResponse.Value, &itemResponseBody)
70+
if err != nil {
71+
t.Fatalf("Failed to unmarshal item response: %v", err)
72+
}
73+
if itemResponseBody["id"] != "1" {
74+
t.Fatalf("Expected id to be 1, got %v", itemResponseBody["id"])
75+
}
76+
if itemResponseBody["value"] != "2" {
77+
t.Fatalf("Expected value to be 2, got %v", itemResponseBody["value"])
78+
}
79+
80+
item["value"] = "3"
81+
marshalled, err = json.Marshal(item)
82+
if err != nil {
83+
t.Fatal(err)
84+
}
85+
itemResponse, err = container.ReplaceItem(context.TODO(), pk, "1", marshalled, &ItemOptions{EnableContentResponseOnWrite: true})
86+
if err != nil {
87+
t.Fatalf("Failed to replace item: %v", err)
88+
}
89+
90+
// Explicitly requesting body on write
91+
if len(itemResponse.Value) == 0 {
92+
t.Fatalf("Expected non-empty response, got %v", itemResponse.Value)
93+
}
94+
95+
err = json.Unmarshal(itemResponse.Value, &itemResponseBody)
96+
if err != nil {
97+
t.Fatalf("Failed to unmarshal item response: %v", err)
98+
}
99+
if itemResponseBody["id"] != "1" {
100+
t.Fatalf("Expected id to be 1, got %v", itemResponseBody["id"])
101+
}
102+
if itemResponseBody["value"] != "3" {
103+
t.Fatalf("Expected value to be 3, got %v", itemResponseBody["value"])
104+
}
105+
106+
item["value"] = "4"
107+
marshalled, err = json.Marshal(item)
108+
if err != nil {
109+
t.Fatal(err)
110+
}
111+
itemResponse, err = container.UpsertItem(context.TODO(), pk, marshalled, &ItemOptions{EnableContentResponseOnWrite: true})
112+
if err != nil {
113+
t.Fatalf("Failed to upsert item: %v", err)
114+
}
115+
116+
// Explicitly requesting body on write
117+
if len(itemResponse.Value) == 0 {
118+
t.Fatalf("Expected non-empty response, got %v", itemResponse.Value)
119+
}
120+
121+
err = json.Unmarshal(itemResponse.Value, &itemResponseBody)
122+
if err != nil {
123+
t.Fatalf("Failed to unmarshal item response: %v", err)
124+
}
125+
if itemResponseBody["id"] != "1" {
126+
t.Fatalf("Expected id to be 1, got %v", itemResponseBody["id"])
127+
}
128+
if itemResponseBody["value"] != "4" {
129+
t.Fatalf("Expected value to be 4, got %v", itemResponseBody["value"])
130+
}
131+
132+
itemResponse, err = container.DeleteItem(context.TODO(), pk, "1", nil)
133+
if err != nil {
134+
t.Fatalf("Failed to replace item: %v", err)
135+
}
136+
137+
if len(itemResponse.Value) != 0 {
138+
t.Fatalf("Expected empty response, got %v", itemResponse.Value)
139+
}
140+
}

0 commit comments

Comments
 (0)