Skip to content

Commit f8d9350

Browse files
authored
Add DisableInstanceDiscovery option for private cloud auth (Azure#19872)
1 parent 71b4e49 commit f8d9350

34 files changed

+2224
-33
lines changed

sdk/azidentity/azidentity.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"net/url"
1616
"os"
1717
"regexp"
18+
"strings"
1819

1920
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
2021
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
@@ -56,20 +57,31 @@ var getConfidentialClient = func(clientID, tenantID string, cred confidential.Cr
5657
confidential.WithHTTPClient(newPipelineAdapter(co)),
5758
}
5859
o = append(o, additionalOpts...)
60+
if strings.ToLower(tenantID) == "adfs" {
61+
o = append(o, confidential.WithInstanceDiscovery(false))
62+
}
5963
return confidential.New(clientID, cred, o...)
6064
}
6165

62-
var getPublicClient = func(clientID, tenantID string, co *azcore.ClientOptions) (public.Client, error) {
66+
var getPublicClient = func(clientID, tenantID string, co *azcore.ClientOptions, additionalOpts ...public.Option) (public.Client, error) {
6367
if !validTenantID(tenantID) {
6468
return public.Client{}, errors.New(tenantIDValidationErr)
6569
}
6670
authorityHost, err := setAuthorityHost(co.Cloud)
6771
if err != nil {
6872
return public.Client{}, err
6973
}
70-
return public.New(clientID,
74+
75+
o := []public.Option{
7176
public.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)),
7277
public.WithHTTPClient(newPipelineAdapter(co)),
78+
}
79+
o = append(o, additionalOpts...)
80+
if strings.ToLower(tenantID) == "adfs" {
81+
o = append(o, public.WithInstanceDiscovery(false))
82+
}
83+
return public.New(clientID,
84+
o...,
7385
)
7486
}
7587

sdk/azidentity/client_assertion_credential.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ type ClientAssertionCredential struct {
3232
// ClientAssertionCredentialOptions contains optional parameters for ClientAssertionCredential.
3333
type ClientAssertionCredentialOptions struct {
3434
azcore.ClientOptions
35+
36+
// DisableInstanceDiscovery allows disconnected cloud solutions to skip instance discovery for unknown authority hosts.
37+
DisableInstanceDiscovery bool
3538
}
3639

3740
// NewClientAssertionCredential constructs a ClientAssertionCredential. The getAssertion function must be thread safe. Pass nil for options to accept defaults.
@@ -47,7 +50,7 @@ func NewClientAssertionCredential(tenantID, clientID string, getAssertion func(c
4750
return getAssertion(ctx)
4851
},
4952
)
50-
c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions)
53+
c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions, confidential.WithInstanceDiscovery(!options.DisableInstanceDiscovery))
5154
if err != nil {
5255
return nil, err
5356
}

sdk/azidentity/client_certificate_credential.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ const credNameCert = "ClientCertificateCredential"
2525
type ClientCertificateCredentialOptions struct {
2626
azcore.ClientOptions
2727

28+
// DisableInstanceDiscovery allows disconnected cloud solutions to skip instance discovery for unknown authority hosts.
29+
DisableInstanceDiscovery bool
30+
2831
// SendCertificateChain controls whether the credential sends the public certificate chain in the x5c
2932
// header of each token request's JWT. This is required for Subject Name/Issuer (SNI) authentication.
3033
// Defaults to False.
@@ -52,6 +55,7 @@ func NewClientCertificateCredential(tenantID string, clientID string, certs []*x
5255
if options.SendCertificateChain {
5356
o = append(o, confidential.WithX5C())
5457
}
58+
o = append(o, confidential.WithInstanceDiscovery(!options.DisableInstanceDiscovery))
5559
c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions, o...)
5660
if err != nil {
5761
return nil, err

sdk/azidentity/client_certificate_credential_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
2020
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
2121
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
22+
"github.com/Azure/azure-sdk-for-go/sdk/internal/recording"
2223
)
2324

2425
type certTest struct {
@@ -224,6 +225,52 @@ func TestClientCertificateCredential_Live(t *testing.T) {
224225
testGetTokenSuccess(t, cred)
225226
})
226227
}
228+
t.Run("instance discovery disabled", func(t *testing.T) {
229+
if liveSP.pemPath == "" {
230+
t.Skip("no certificate file specified")
231+
}
232+
certData, err := os.ReadFile(liveSP.pemPath)
233+
if err != nil {
234+
t.Fatalf(`failed to read cert: %v`, err)
235+
}
236+
certs, key, err := ParseCertificates(certData, nil)
237+
if err != nil {
238+
t.Fatalf(`failed to parse cert: %v`, err)
239+
}
240+
o, stop := initRecording(t)
241+
defer stop()
242+
opts := &ClientCertificateCredentialOptions{ClientOptions: o, DisableInstanceDiscovery: true}
243+
cred, err := NewClientCertificateCredential(liveSP.tenantID, liveSP.clientID, certs, key, opts)
244+
if err != nil {
245+
t.Fatalf("failed to construct credential: %v", err)
246+
}
247+
testGetTokenSuccess(t, cred)
248+
})
249+
}
250+
251+
func TestClientCertificateCredentialADFS_Live(t *testing.T) {
252+
if recording.GetRecordMode() != recording.PlaybackMode {
253+
if adfsLiveSP.clientID == "" || adfsLiveSP.certPath == "" || adfsScope == "" {
254+
t.Skip("set ADFS_SP_* to run this test live")
255+
}
256+
}
257+
certData, err := os.ReadFile(adfsLiveSP.certPath)
258+
if err != nil {
259+
t.Fatalf(`failed to read cert: %v`, err)
260+
}
261+
certs, key, err := ParseCertificates(certData, nil)
262+
if err != nil {
263+
t.Fatalf(`failed to parse cert: %v`, err)
264+
}
265+
o, stop := initRecording(t)
266+
defer stop()
267+
o.Cloud.ActiveDirectoryAuthorityHost = adfsAuthority
268+
opts := &ClientCertificateCredentialOptions{ClientOptions: o, DisableInstanceDiscovery: true}
269+
cred, err := NewClientCertificateCredential("adfs", adfsLiveSP.clientID, certs, key, opts)
270+
if err != nil {
271+
t.Fatalf("failed to construct credential: %v", err)
272+
}
273+
testGetTokenSuccess(t, cred, adfsScope)
227274
}
228275

229276
func TestClientCertificateCredential_InvalidCertLive(t *testing.T) {

sdk/azidentity/client_secret_credential.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ const credNameSecret = "ClientSecretCredential"
2020
// ClientSecretCredentialOptions contains optional parameters for ClientSecretCredential.
2121
type ClientSecretCredentialOptions struct {
2222
azcore.ClientOptions
23+
24+
// DisableInstanceDiscovery allows disconnected cloud solutions to skip instance discovery for unknown authority hosts.
25+
DisableInstanceDiscovery bool
2326
}
2427

2528
// ClientSecretCredential authenticates an application with a client secret.
@@ -36,7 +39,7 @@ func NewClientSecretCredential(tenantID string, clientID string, clientSecret st
3639
if err != nil {
3740
return nil, err
3841
}
39-
c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions)
42+
c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions, confidential.WithInstanceDiscovery(!options.DisableInstanceDiscovery))
4043
if err != nil {
4144
return nil, err
4245
}

sdk/azidentity/client_secret_credential_test.go

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"testing"
1414

1515
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
16+
"github.com/Azure/azure-sdk-for-go/sdk/internal/recording"
1617
)
1718

1819
const secret = "secret"
@@ -40,14 +41,39 @@ func TestClientSecretCredential_GetTokenSuccess(t *testing.T) {
4041
}
4142

4243
func TestClientSecretCredential_Live(t *testing.T) {
44+
for _, disabledID := range []bool{true, false} {
45+
name := "default options"
46+
if disabledID {
47+
name = "instance discovery disabled"
48+
}
49+
t.Run(name, func(t *testing.T) {
50+
opts, stop := initRecording(t)
51+
defer stop()
52+
o := ClientSecretCredentialOptions{ClientOptions: opts, DisableInstanceDiscovery: disabledID}
53+
cred, err := NewClientSecretCredential(liveSP.tenantID, liveSP.clientID, liveSP.secret, &o)
54+
if err != nil {
55+
t.Fatalf("failed to construct credential: %v", err)
56+
}
57+
testGetTokenSuccess(t, cred)
58+
})
59+
}
60+
}
61+
62+
func TestClientSecretCredentialADFS_Live(t *testing.T) {
63+
if recording.GetRecordMode() != recording.PlaybackMode {
64+
if adfsLiveSP.clientID == "" || adfsLiveSP.secret == "" || adfsScope == "" {
65+
t.Skip("set ADFS_SP_* environment variables to run this test live")
66+
}
67+
}
4368
opts, stop := initRecording(t)
4469
defer stop()
45-
o := ClientSecretCredentialOptions{ClientOptions: opts}
46-
cred, err := NewClientSecretCredential(liveSP.tenantID, liveSP.clientID, liveSP.secret, &o)
70+
opts.Cloud.ActiveDirectoryAuthorityHost = adfsAuthority
71+
o := ClientSecretCredentialOptions{ClientOptions: opts, DisableInstanceDiscovery: true}
72+
cred, err := NewClientSecretCredential("adfs", adfsLiveSP.clientID, adfsLiveSP.secret, &o)
4773
if err != nil {
4874
t.Fatalf("failed to construct credential: %v", err)
4975
}
50-
testGetTokenSuccess(t, cred)
76+
testGetTokenSuccess(t, cred, adfsScope)
5177
}
5278

5379
func TestClientSecretCredential_InvalidSecretLive(t *testing.T) {

sdk/azidentity/default_azure_credential.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ import (
2323
type DefaultAzureCredentialOptions struct {
2424
azcore.ClientOptions
2525

26+
// DisableInstanceDiscovery allows disconnected cloud solutions to skip instance discovery for unknown authority hosts.
27+
DisableInstanceDiscovery bool
28+
2629
// TenantID identifies the tenant the Azure CLI should authenticate in.
2730
// Defaults to the CLI's default tenant, which is typically the home tenant of the user logged in to the CLI.
2831
TenantID string
@@ -56,7 +59,7 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Default
5659
options = &DefaultAzureCredentialOptions{}
5760
}
5861

59-
envCred, err := NewEnvironmentCredential(&EnvironmentCredentialOptions{ClientOptions: options.ClientOptions})
62+
envCred, err := NewEnvironmentCredential(&EnvironmentCredentialOptions{ClientOptions: options.ClientOptions, DisableInstanceDiscovery: options.DisableInstanceDiscovery})
6063
if err == nil {
6164
creds = append(creds, envCred)
6265
} else {

sdk/azidentity/device_code_credential.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,18 @@ const credNameDeviceCode = "DeviceCodeCredential"
2222
type DeviceCodeCredentialOptions struct {
2323
azcore.ClientOptions
2424

25+
// ClientID is the ID of the application users will authenticate to.
26+
// Defaults to the ID of an Azure development application.
27+
ClientID string
28+
29+
// DisableInstanceDiscovery allows disconnected cloud solutions to skip instance discovery for unknown authority hosts.
30+
DisableInstanceDiscovery bool
31+
2532
// TenantID is the Azure Active Directory tenant the credential authenticates in. Defaults to the
2633
// "organizations" tenant, which can authenticate work and school accounts. Required for single-tenant
2734
// applications.
2835
TenantID string
29-
// ClientID is the ID of the application users will authenticate to.
30-
// Defaults to the ID of an Azure development application.
31-
ClientID string
36+
3237
// UserPrompt controls how the credential presents authentication instructions. The credential calls
3338
// this function with authentication details when it receives a device code. By default, the credential
3439
// prints these details to stdout.
@@ -78,7 +83,7 @@ func NewDeviceCodeCredential(options *DeviceCodeCredentialOptions) (*DeviceCodeC
7883
cp = *options
7984
}
8085
cp.init()
81-
c, err := getPublicClient(cp.ClientID, cp.TenantID, &cp.ClientOptions)
86+
c, err := getPublicClient(cp.ClientID, cp.TenantID, &cp.ClientOptions, public.WithInstanceDiscovery(!cp.DisableInstanceDiscovery))
8287
if err != nil {
8388
return nil, err
8489
}

sdk/azidentity/device_code_credential_test.go

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,44 @@ func TestDeviceCodeCredential_Live(t *testing.T) {
8989
if recording.GetRecordMode() != recording.PlaybackMode {
9090
t.Skip("this test requires manual recording and can't pass live in CI")
9191
}
92+
for _, disabledID := range []bool{true, false} {
93+
name := "default options"
94+
if disabledID {
95+
name = "instance discovery disabled"
96+
}
97+
t.Run(name, func(t *testing.T) {
98+
o, stop := initRecording(t)
99+
defer stop()
100+
opts := DeviceCodeCredentialOptions{TenantID: liveUser.tenantID, ClientOptions: o, DisableInstanceDiscovery: disabledID}
101+
if recording.GetRecordMode() == recording.PlaybackMode {
102+
opts.UserPrompt = func(ctx context.Context, m DeviceCodeMessage) error { return nil }
103+
}
104+
cred, err := NewDeviceCodeCredential(&opts)
105+
if err != nil {
106+
t.Fatal(err)
107+
}
108+
testGetTokenSuccess(t, cred)
109+
})
110+
}
111+
}
112+
113+
func TestDeviceCodeCredentialADFS_Live(t *testing.T) {
114+
if recording.GetRecordMode() != recording.PlaybackMode {
115+
t.Skip("this test requires manual recording and can't pass live in CI")
116+
}
117+
if adfsLiveSP.clientID == "" {
118+
t.Skip("set ADFS_SP_* environment variables to run this test")
119+
}
92120
o, stop := initRecording(t)
93121
defer stop()
94-
opts := DeviceCodeCredentialOptions{TenantID: liveUser.tenantID, ClientOptions: o}
122+
o.Cloud.ActiveDirectoryAuthorityHost = adfsAuthority
123+
opts := DeviceCodeCredentialOptions{TenantID: "adfs", ClientID: adfsLiveUser.clientID, ClientOptions: o, DisableInstanceDiscovery: true}
95124
if recording.GetRecordMode() == recording.PlaybackMode {
96125
opts.UserPrompt = func(ctx context.Context, m DeviceCodeMessage) error { return nil }
97126
}
98127
cred, err := NewDeviceCodeCredential(&opts)
99128
if err != nil {
100129
t.Fatal(err)
101130
}
102-
testGetTokenSuccess(t, cred)
131+
testGetTokenSuccess(t, cred, adfsScope)
103132
}

sdk/azidentity/environment_credential.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ const envVarSendCertChain = "AZURE_CLIENT_SEND_CERTIFICATE_CHAIN"
2323
// EnvironmentCredentialOptions contains optional parameters for EnvironmentCredential
2424
type EnvironmentCredentialOptions struct {
2525
azcore.ClientOptions
26+
27+
// DisableInstanceDiscovery allows disconnected cloud solutions to skip instance discovery for unknown authority hosts.
28+
DisableInstanceDiscovery bool
2629
}
2730

2831
// EnvironmentCredential authenticates a service principal with a secret or certificate, or a user with a password, depending
@@ -74,7 +77,7 @@ func NewEnvironmentCredential(options *EnvironmentCredentialOptions) (*Environme
7477
}
7578
if clientSecret := os.Getenv(azureClientSecret); clientSecret != "" {
7679
log.Write(EventAuthentication, "EnvironmentCredential will authenticate with ClientSecretCredential")
77-
o := &ClientSecretCredentialOptions{ClientOptions: options.ClientOptions}
80+
o := &ClientSecretCredentialOptions{ClientOptions: options.ClientOptions, DisableInstanceDiscovery: options.DisableInstanceDiscovery}
7881
cred, err := NewClientSecretCredential(tenantID, clientID, clientSecret, o)
7982
if err != nil {
8083
return nil, err
@@ -95,7 +98,7 @@ func NewEnvironmentCredential(options *EnvironmentCredentialOptions) (*Environme
9598
if err != nil {
9699
return nil, fmt.Errorf(`failed to load certificate from "%s": %v`, certPath, err)
97100
}
98-
o := &ClientCertificateCredentialOptions{ClientOptions: options.ClientOptions}
101+
o := &ClientCertificateCredentialOptions{ClientOptions: options.ClientOptions, DisableInstanceDiscovery: options.DisableInstanceDiscovery}
99102
if v, ok := os.LookupEnv(envVarSendCertChain); ok {
100103
o.SendCertificateChain = v == "1" || strings.ToLower(v) == "true"
101104
}
@@ -108,7 +111,7 @@ func NewEnvironmentCredential(options *EnvironmentCredentialOptions) (*Environme
108111
if username := os.Getenv(azureUsername); username != "" {
109112
if password := os.Getenv(azurePassword); password != "" {
110113
log.Write(EventAuthentication, "EnvironmentCredential will authenticate with UsernamePasswordCredential")
111-
o := &UsernamePasswordCredentialOptions{ClientOptions: options.ClientOptions}
114+
o := &UsernamePasswordCredentialOptions{ClientOptions: options.ClientOptions, DisableInstanceDiscovery: options.DisableInstanceDiscovery}
112115
cred, err := NewUsernamePasswordCredential(tenantID, clientID, username, password, o)
113116
if err != nil {
114117
return nil, err

0 commit comments

Comments
 (0)