Skip to content

Commit a248e81

Browse files
authored
Restore multitenant and ARM cross-tenant authentication API (Azure#20655)
1 parent 3824f00 commit a248e81

16 files changed

+62
-19
lines changed

sdk/azcore/arm/policy/policy.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ import (
1414

1515
// BearerTokenOptions configures the bearer token policy's behavior.
1616
type BearerTokenOptions struct {
17+
// AuxiliaryTenants are additional tenant IDs for authenticating cross-tenant requests.
18+
// The policy will add a token from each of these tenants to every request. The
19+
// authenticating user or service principal must be a guest in these tenants, and the
20+
// policy's credential must support multitenant authentication.
21+
AuxiliaryTenants []string
22+
1723
// Scopes contains the list of permission scopes required for the token.
1824
Scopes []string
1925
}
@@ -44,6 +50,12 @@ type RegistrationOptions struct {
4450
type ClientOptions struct {
4551
policy.ClientOptions
4652

53+
// AuxiliaryTenants are additional tenant IDs for authenticating cross-tenant requests.
54+
// The client will add a token from each of these tenants to every request. The
55+
// authenticating user or service principal must be a guest in these tenants, and the
56+
// client's credential must support multitenant authentication.
57+
AuxiliaryTenants []string
58+
4759
// DisableRPRegistration disables the auto-RP registration policy. Defaults to false.
4860
DisableRPRegistration bool
4961
}

sdk/azcore/arm/runtime/pipeline.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ func NewPipeline(module, version string, cred azcore.TokenCredential, plOpts azr
2929
return azruntime.Pipeline{}, err
3030
}
3131
authPolicy := NewBearerTokenPolicy(cred, &armpolicy.BearerTokenOptions{
32-
Scopes: []string{conf.Audience + "/.default"},
32+
AuxiliaryTenants: options.AuxiliaryTenants,
33+
Scopes: []string{conf.Audience + "/.default"},
3334
})
3435
perRetry := make([]azpolicy.Policy, len(plOpts.PerRetry), len(plOpts.PerRetry)+1)
3536
copy(perRetry, plOpts.PerRetry)

sdk/azcore/arm/runtime/policy_bearer_token.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"net/http"
1010
"strings"
11+
"time"
1112

1213
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
1314
armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy"
@@ -26,6 +27,19 @@ type acquiringResourceState struct {
2627
tenant string
2728
}
2829

30+
// acquire acquires or updates the resource; only one
31+
// thread/goroutine at a time ever calls this function
32+
func acquire(state acquiringResourceState) (newResource azcore.AccessToken, newExpiration time.Time, err error) {
33+
tk, err := state.p.cred.GetToken(state.ctx, azpolicy.TokenRequestOptions{
34+
Scopes: state.p.scopes,
35+
TenantID: state.tenant,
36+
})
37+
if err != nil {
38+
return azcore.AccessToken{}, time.Time{}, err
39+
}
40+
return tk, tk.ExpiresOn, nil
41+
}
42+
2943
// BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential.
3044
type BearerTokenPolicy struct {
3145
auxResources map[string]*temporal.Resource[azcore.AccessToken, acquiringResourceState]
@@ -42,6 +56,10 @@ func NewBearerTokenPolicy(cred azcore.TokenCredential, opts *armpolicy.BearerTok
4256
opts = &armpolicy.BearerTokenOptions{}
4357
}
4458
p := &BearerTokenPolicy{cred: cred}
59+
p.auxResources = make(map[string]*temporal.Resource[azcore.AccessToken, acquiringResourceState], len(opts.AuxiliaryTenants))
60+
for _, t := range opts.AuxiliaryTenants {
61+
p.auxResources[t] = temporal.NewResource(acquire)
62+
}
4563
p.scopes = make([]string, len(opts.Scopes))
4664
copy(p.scopes, opts.Scopes)
4765
p.btp = azruntime.NewBearerTokenPolicy(cred, opts.Scopes, &azpolicy.BearerTokenOptions{

sdk/azcore/arm/runtime/policy_bearer_token_test.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@ func TestBearerPolicy_GetTokenFailsNoDeadlock(t *testing.T) {
163163
}
164164

165165
func TestAuxiliaryTenants(t *testing.T) {
166-
t.Skip("unskip this test after restoring cross-tenant auth support")
167166
srv, close := mock.NewTLSServer()
168167
defer close()
169168
srv.SetResponse(mock.WithStatusCode(http.StatusOK))
@@ -177,13 +176,13 @@ func TestAuxiliaryTenants(t *testing.T) {
177176
getTokenImpl: func(ctx context.Context, options azpolicy.TokenRequestOptions) (azcore.AccessToken, error) {
178177
require.False(t, expectCache, "client should have used a cached token instead of requesting another")
179178
tenant := primary
180-
// if options.TenantID != "" {
181-
// tenant = options.TenantID
182-
// }
179+
if options.TenantID != "" {
180+
tenant = options.TenantID
181+
}
183182
return azcore.AccessToken{Token: tenant, ExpiresOn: time.Now().Add(time.Hour).UTC()}, nil
184183
},
185184
},
186-
&armpolicy.BearerTokenOptions{ /*AuxiliaryTenants: auxTenants,*/ Scopes: []string{scope}},
185+
&armpolicy.BearerTokenOptions{AuxiliaryTenants: auxTenants, Scopes: []string{scope}},
187186
)
188187
pipeline := newTestPipeline(&azpolicy.ClientOptions{Transport: srv, PerRetryPolicies: []azpolicy.Policy{b}})
189188
expected := strings.Split(shared.BearerTokenPrefix+strings.Join(auxTenants, ","+shared.BearerTokenPrefix), ",")

sdk/azcore/internal/exported/exported.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ type AccessToken struct {
5353
type TokenRequestOptions struct {
5454
// Scopes contains the list of permission scopes required for the token.
5555
Scopes []string
56+
57+
// TenantID identifies the tenant from which to request the token. azidentity credentials authenticate in
58+
// their configured default tenants when this field isn't set.
59+
TenantID string
5660
}
5761

5862
// TokenCredential represents a credential capable of providing an OAuth token.

sdk/azidentity/azidentity_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,6 @@ func Test_NonHTTPSAuthorityHost(t *testing.T) {
279279
}
280280

281281
func TestAdditionallyAllowedTenants(t *testing.T) {
282-
t.Skip("unskip this test after restoring TokenRequestOptions.TenantID")
283282
af := filepath.Join(t.TempDir(), t.Name()+credNameWorkloadIdentity)
284283
if err := os.WriteFile(af, []byte("assertion"), os.ModePerm); err != nil {
285284
t.Fatal(err)
@@ -321,7 +320,7 @@ func TestAdditionallyAllowedTenants(t *testing.T) {
321320
err: true,
322321
},
323322
} {
324-
tro := policy.TokenRequestOptions{Scopes: []string{liveTestScope}}
323+
tro := policy.TokenRequestOptions{Scopes: []string{liveTestScope}, TenantID: test.tenant}
325324
for _, subtest := range []struct {
326325
ctor func(azcore.ClientOptions) (azcore.TokenCredential, error)
327326
env map[string]string

sdk/azidentity/azure_cli_credential.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func (c *AzureCLICredential) GetToken(ctx context.Context, opts policy.TokenRequ
8181
}
8282

8383
func (c *AzureCLICredential) requestToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
84-
b, err := c.tokenProvider(ctx, opts.Scopes[0], "")
84+
b, err := c.tokenProvider(ctx, opts.Scopes[0], opts.TenantID)
8585
if err != nil {
8686
return azcore.AccessToken{}, err
8787
}

sdk/azidentity/azure_cli_credential_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ func TestAzureCLICredential_GetTokenInvalidToken(t *testing.T) {
6565
}
6666

6767
func TestAzureCLICredential_TenantID(t *testing.T) {
68-
t.Skip("unskip this test after restoring TokenRequestOptions.TenantID")
6968
expected := "expected-tenant-id"
7069
called := false
7170
options := AzureCLICredentialOptions{

sdk/azidentity/client_assertion_credential.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,12 @@ func (c *ClientAssertionCredential) GetToken(ctx context.Context, opts policy.To
7272
}
7373

7474
func (c *ClientAssertionCredential) silentAuth(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
75-
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes)
75+
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithTenantID(opts.TenantID))
7676
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
7777
}
7878

7979
func (c *ClientAssertionCredential) requestToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
80-
ar, err := c.client.AcquireTokenByCredential(ctx, opts.Scopes)
80+
ar, err := c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithTenantID(opts.TenantID))
8181
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
8282
}
8383

sdk/azidentity/client_certificate_credential.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,12 @@ func (c *ClientCertificateCredential) GetToken(ctx context.Context, opts policy.
7979
}
8080

8181
func (c *ClientCertificateCredential) silentAuth(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
82-
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes)
82+
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithTenantID(opts.TenantID))
8383
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
8484
}
8585

8686
func (c *ClientCertificateCredential) requestToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
87-
ar, err := c.client.AcquireTokenByCredential(ctx, opts.Scopes)
87+
ar, err := c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithTenantID(opts.TenantID))
8888
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
8989
}
9090

0 commit comments

Comments
 (0)