Skip to content

Commit 1cd56ad

Browse files
authored
Workload identity credential defaults to environment configuration (Azure#20478)
1 parent 9b0eb84 commit 1cd56ad

File tree

5 files changed

+155
-63
lines changed

5 files changed

+155
-63
lines changed

sdk/azidentity/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66

77
### Breaking Changes
88
> These changes affect only code written against a beta version such as v1.3.0-beta.4
9+
* Moved `NewWorkloadIdentityCredential()` parameters into `WorkloadIdentityCredentialOptions`.
10+
The constructor now reads default configuration from environment variables set by the Azure
11+
workload identity webhook by default.
12+
([#20478](https://github.com/Azure/azure-sdk-for-go/pull/20478))
913
* Removed CAE support. It will return in the next beta release.
14+
([#20479](https://github.com/Azure/azure-sdk-for-go/pull/20479))
1015

1116
### Bugs Fixed
1217
* Fixed an issue in `DefaultAzureCredential` that could cause the managed identity endpoint check to fail in rare circumstances.

sdk/azidentity/azidentity_test.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,8 +393,13 @@ func TestAdditionallyAllowedTenants(t *testing.T) {
393393
{
394394
name: credNameWorkloadIdentity,
395395
ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) {
396-
o := WorkloadIdentityCredentialOptions{AdditionallyAllowedTenants: test.allowed, ClientOptions: co}
397-
return NewWorkloadIdentityCredential(fakeTenantID, fakeClientID, af, &o)
396+
return NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
397+
AdditionallyAllowedTenants: test.allowed,
398+
ClientID: fakeClientID,
399+
ClientOptions: co,
400+
TenantID: fakeTenantID,
401+
TokenFilePath: af,
402+
})
398403
},
399404
},
400405
{

sdk/azidentity/default_azure_credential.go

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -79,36 +79,20 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Default
7979
}
8080

8181
// workload identity requires values for AZURE_AUTHORITY_HOST, AZURE_CLIENT_ID, AZURE_FEDERATED_TOKEN_FILE, AZURE_TENANT_ID
82-
haveWorkloadConfig := false
83-
clientID, haveClientID := os.LookupEnv(azureClientID)
84-
if haveClientID {
85-
if file, ok := os.LookupEnv(azureFederatedTokenFile); ok {
86-
if _, ok := os.LookupEnv(azureAuthorityHost); ok {
87-
if tenantID, ok := os.LookupEnv(azureTenantID); ok {
88-
haveWorkloadConfig = true
89-
workloadCred, err := NewWorkloadIdentityCredential(tenantID, clientID, file, &WorkloadIdentityCredentialOptions{
90-
AdditionallyAllowedTenants: additionalTenants,
91-
ClientOptions: options.ClientOptions,
92-
DisableInstanceDiscovery: options.DisableInstanceDiscovery,
93-
})
94-
if err == nil {
95-
creds = append(creds, workloadCred)
96-
} else {
97-
errorMessages = append(errorMessages, credNameWorkloadIdentity+": "+err.Error())
98-
creds = append(creds, &defaultCredentialErrorReporter{credType: credNameWorkloadIdentity, err: err})
99-
}
100-
}
101-
}
102-
}
103-
}
104-
if !haveWorkloadConfig {
105-
err := errors.New("missing environment variables for workload identity. Check webhook and pod configuration")
82+
wic, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
83+
AdditionallyAllowedTenants: additionalTenants,
84+
ClientOptions: options.ClientOptions,
85+
DisableInstanceDiscovery: options.DisableInstanceDiscovery,
86+
})
87+
if err == nil {
88+
creds = append(creds, wic)
89+
} else {
90+
errorMessages = append(errorMessages, credNameWorkloadIdentity+": "+err.Error())
10691
creds = append(creds, &defaultCredentialErrorReporter{credType: credNameWorkloadIdentity, err: err})
10792
}
108-
10993
o := &ManagedIdentityCredentialOptions{ClientOptions: options.ClientOptions}
110-
if haveClientID {
111-
o.ID = ClientID(clientID)
94+
if ID, ok := os.LookupEnv(azureClientID); ok {
95+
o.ID = ClientID(ID)
11296
}
11397
miCred, err := NewManagedIdentityCredential(o)
11498
if err == nil {

sdk/azidentity/workload_identity.go

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package azidentity
88

99
import (
1010
"context"
11+
"errors"
1112
"os"
1213
"sync"
1314
"time"
@@ -37,16 +38,42 @@ type WorkloadIdentityCredentialOptions struct {
3738
// Add the wildcard value "*" to allow the credential to acquire tokens for any tenant in which the
3839
// application is registered.
3940
AdditionallyAllowedTenants []string
41+
// ClientID of the service principal. Defaults to the value of the environment variable AZURE_CLIENT_ID.
42+
ClientID string
4043
// DisableInstanceDiscovery allows disconnected cloud solutions to skip instance discovery for unknown authority hosts.
4144
DisableInstanceDiscovery bool
45+
// TenantID of the service principal. Defaults to the value of the environment variable AZURE_TENANT_ID.
46+
TenantID string
47+
// TokenFilePath is the path a file containing the workload identity token. Defaults to the value of the
48+
// environment variable AZURE_FEDERATED_TOKEN_FILE.
49+
TokenFilePath string
4250
}
4351

44-
// NewWorkloadIdentityCredential constructs a WorkloadIdentityCredential. tenantID and clientID specify the identity the credential authenticates.
45-
// file is a path to a file containing a Kubernetes service account token that authenticates the identity.
46-
func NewWorkloadIdentityCredential(tenantID, clientID, file string, options *WorkloadIdentityCredentialOptions) (*WorkloadIdentityCredential, error) {
52+
// NewWorkloadIdentityCredential constructs a WorkloadIdentityCredential. Service principal configuration is read
53+
// from environment variables as set by the Azure workload identity webhook. Set options to override those values.
54+
func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) (*WorkloadIdentityCredential, error) {
4755
if options == nil {
4856
options = &WorkloadIdentityCredentialOptions{}
4957
}
58+
ok := false
59+
clientID := options.ClientID
60+
if clientID == "" {
61+
if clientID, ok = os.LookupEnv(azureClientID); !ok {
62+
return nil, errors.New("no client ID specified. Check pod configuration or set ClientID in the options")
63+
}
64+
}
65+
file := options.TokenFilePath
66+
if file == "" {
67+
if file, ok = os.LookupEnv(azureFederatedTokenFile); !ok {
68+
return nil, errors.New("no token file specified. Check pod configuration or set TokenFilePath in the options")
69+
}
70+
}
71+
tenantID := options.TenantID
72+
if tenantID == "" {
73+
if tenantID, ok = os.LookupEnv(azureTenantID); !ok {
74+
return nil, errors.New("no tenant ID specified. Check pod configuration or set TenantID in the options")
75+
}
76+
}
5077
w := WorkloadIdentityCredential{file: file, mtx: &sync.RWMutex{}}
5178
caco := ClientAssertionCredentialOptions{
5279
AdditionallyAllowedTenants: options.AdditionallyAllowedTenants,

sdk/azidentity/workload_identity_test.go

Lines changed: 102 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import (
2323
"time"
2424

2525
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
26-
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
2726
"github.com/golang-jwt/jwt/v4"
2827
"github.com/google/uuid"
2928
)
@@ -71,8 +70,13 @@ func TestWorkloadIdentityCredential_Live(t *testing.T) {
7170
t.Run(name, func(t *testing.T) {
7271
co, stop := initRecording(t)
7372
defer stop()
74-
o := WorkloadIdentityCredentialOptions{ClientOptions: co, DisableInstanceDiscovery: b}
75-
cred, err := NewWorkloadIdentityCredential(liveSP.tenantID, liveSP.clientID, f, &o)
73+
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
74+
ClientID: liveSP.clientID,
75+
ClientOptions: co,
76+
DisableInstanceDiscovery: b,
77+
TenantID: liveSP.tenantID,
78+
TokenFilePath: f,
79+
})
7680
if err != nil {
7781
t.Fatal(err)
7882
}
@@ -86,7 +90,7 @@ func TestWorkloadIdentityCredential(t *testing.T) {
8690
if err := os.WriteFile(tempFile, []byte(tokenValue), os.ModePerm); err != nil {
8791
t.Fatalf("failed to write token file: %v", err)
8892
}
89-
validateReq := func(req *http.Request) bool {
93+
sts := mockSTS{tenant: fakeTenantID, tokenRequestCallback: func(req *http.Request) {
9094
if err := req.ParseForm(); err != nil {
9195
t.Error(err)
9296
}
@@ -103,18 +107,13 @@ func TestWorkloadIdentityCredential(t *testing.T) {
103107
if actual := strings.Split(req.URL.Path, "/")[1]; actual != fakeTenantID {
104108
t.Errorf(`unexpected tenant "%s"`, actual)
105109
}
106-
return true
107-
}
108-
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
109-
defer close()
110-
srv.AppendResponse(mock.WithBody(instanceDiscoveryResponse))
111-
srv.AppendResponse(mock.WithBody(tenantDiscoveryResponse))
112-
srv.AppendResponse(mock.WithPredicate(validateReq), mock.WithBody(accessTokenRespSuccess))
113-
srv.AppendResponse()
114-
opts := WorkloadIdentityCredentialOptions{
115-
ClientOptions: policy.ClientOptions{Transport: srv},
116-
}
117-
cred, err := NewWorkloadIdentityCredential(fakeTenantID, fakeClientID, tempFile, &opts)
110+
}}
111+
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
112+
ClientID: fakeClientID,
113+
ClientOptions: policy.ClientOptions{Transport: &sts},
114+
TenantID: fakeTenantID,
115+
TokenFilePath: tempFile,
116+
})
118117
if err != nil {
119118
t.Fatal(err)
120119
}
@@ -124,7 +123,7 @@ func TestWorkloadIdentityCredential(t *testing.T) {
124123
func TestWorkloadIdentityCredential_Expiration(t *testing.T) {
125124
tokenReqs := 0
126125
tempFile := filepath.Join(t.TempDir(), "test-workload-token-file")
127-
validateReq := func(req *http.Request) bool {
126+
sts := mockSTS{tenant: fakeTenantID, tokenRequestCallback: func(req *http.Request) {
128127
if err := req.ParseForm(); err != nil {
129128
t.Error(err)
130129
}
@@ -134,20 +133,13 @@ func TestWorkloadIdentityCredential_Expiration(t *testing.T) {
134133
t.Errorf(`expected assertion "%d", got "%s"`, tokenReqs, actual[0])
135134
}
136135
tokenReqs++
137-
return true
138-
}
139-
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
140-
defer close()
141-
srv.AppendResponse(mock.WithBody(instanceDiscoveryResponse))
142-
srv.AppendResponse(mock.WithBody(tenantDiscoveryResponse))
143-
srv.AppendResponse(mock.WithPredicate(validateReq), mock.WithBody(accessTokenRespSuccess))
144-
srv.AppendResponse()
145-
srv.AppendResponse(mock.WithPredicate(validateReq), mock.WithBody(accessTokenRespSuccess))
146-
srv.AppendResponse()
147-
opts := WorkloadIdentityCredentialOptions{
148-
ClientOptions: policy.ClientOptions{Transport: srv},
149-
}
150-
cred, err := NewWorkloadIdentityCredential(fakeTenantID, fakeClientID, tempFile, &opts)
136+
}}
137+
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
138+
ClientID: fakeClientID,
139+
ClientOptions: policy.ClientOptions{Transport: &sts},
140+
TenantID: fakeTenantID,
141+
TokenFilePath: tempFile,
142+
})
151143
if err != nil {
152144
t.Fatal(err)
153145
}
@@ -167,3 +159,82 @@ func TestWorkloadIdentityCredential_Expiration(t *testing.T) {
167159
t.Fatalf("expected 2 token requests, got %d", tokenReqs)
168160
}
169161
}
162+
163+
func TestTestWorkloadIdentityCredential_IncompleteConfig(t *testing.T) {
164+
f := filepath.Join(t.TempDir(), t.Name())
165+
for _, env := range []map[string]string{
166+
{},
167+
168+
{azureClientID: fakeClientID},
169+
{azureFederatedTokenFile: f},
170+
{azureTenantID: fakeTenantID},
171+
172+
{azureClientID: fakeClientID, azureTenantID: fakeTenantID},
173+
{azureClientID: fakeClientID, azureFederatedTokenFile: f},
174+
{azureTenantID: fakeTenantID, azureFederatedTokenFile: f},
175+
} {
176+
t.Run("", func(t *testing.T) {
177+
for k, v := range env {
178+
t.Setenv(k, v)
179+
}
180+
if _, err := NewWorkloadIdentityCredential(nil); err == nil {
181+
t.Fatal("expected an error")
182+
}
183+
})
184+
}
185+
}
186+
187+
func TestWorkloadIdentityCredential_Options(t *testing.T) {
188+
clientID := "not-" + fakeClientID
189+
tenantID := "not-" + fakeTenantID
190+
wrongFile := filepath.Join(t.TempDir(), "wrong")
191+
rightFile := filepath.Join(t.TempDir(), "right")
192+
if err := os.WriteFile(rightFile, []byte(tokenValue), os.ModePerm); err != nil {
193+
t.Fatal(err)
194+
}
195+
sts := mockSTS{
196+
tenant: tenantID,
197+
tokenRequestCallback: func(req *http.Request) {
198+
if err := req.ParseForm(); err != nil {
199+
t.Error(err)
200+
}
201+
if actual, ok := req.PostForm["client_assertion"]; !ok {
202+
t.Error("expected a client_assertion")
203+
} else if len(actual) != 1 || actual[0] != tokenValue {
204+
t.Errorf(`unexpected assertion "%s"`, actual[0])
205+
}
206+
if actual, ok := req.PostForm["client_id"]; !ok {
207+
t.Error("expected a client_id")
208+
} else if len(actual) != 1 || actual[0] != clientID {
209+
t.Errorf(`unexpected assertion "%s"`, actual[0])
210+
}
211+
if actual := strings.Split(req.URL.Path, "/")[1]; actual != tenantID {
212+
t.Errorf(`unexpected tenant "%s"`, actual)
213+
}
214+
},
215+
}
216+
// options should override environment variables
217+
for k, v := range map[string]string{
218+
azureClientID: fakeClientID,
219+
azureFederatedTokenFile: wrongFile,
220+
azureTenantID: fakeTenantID,
221+
} {
222+
t.Setenv(k, v)
223+
}
224+
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
225+
ClientID: clientID,
226+
ClientOptions: policy.ClientOptions{Transport: &sts},
227+
TenantID: tenantID,
228+
TokenFilePath: rightFile,
229+
})
230+
if err != nil {
231+
t.Fatal(err)
232+
}
233+
tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
234+
if err != nil {
235+
t.Fatal(err)
236+
}
237+
if tk.Token != tokenValue {
238+
t.Fatalf("unexpected token %q", tk.Token)
239+
}
240+
}

0 commit comments

Comments
 (0)