@@ -23,7 +23,6 @@ import (
2323 "time"
2424
2525 "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
26- "github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
2726 "github.com/golang-jwt/jwt/v4"
2827 "github.com/google/uuid"
2928)
@@ -71,8 +70,13 @@ func TestWorkloadIdentityCredential_Live(t *testing.T) {
7170 t .Run (name , func (t * testing.T ) {
7271 co , stop := initRecording (t )
7372 defer stop ()
74- o := WorkloadIdentityCredentialOptions {ClientOptions : co , DisableInstanceDiscovery : b }
75- cred , err := NewWorkloadIdentityCredential (liveSP .tenantID , liveSP .clientID , f , & o )
73+ cred , err := NewWorkloadIdentityCredential (& WorkloadIdentityCredentialOptions {
74+ ClientID : liveSP .clientID ,
75+ ClientOptions : co ,
76+ DisableInstanceDiscovery : b ,
77+ TenantID : liveSP .tenantID ,
78+ TokenFilePath : f ,
79+ })
7680 if err != nil {
7781 t .Fatal (err )
7882 }
@@ -86,7 +90,7 @@ func TestWorkloadIdentityCredential(t *testing.T) {
8690 if err := os .WriteFile (tempFile , []byte (tokenValue ), os .ModePerm ); err != nil {
8791 t .Fatalf ("failed to write token file: %v" , err )
8892 }
89- validateReq := func (req * http.Request ) bool {
93+ sts := mockSTS { tenant : fakeTenantID , tokenRequestCallback : func (req * http.Request ) {
9094 if err := req .ParseForm (); err != nil {
9195 t .Error (err )
9296 }
@@ -103,18 +107,13 @@ func TestWorkloadIdentityCredential(t *testing.T) {
103107 if actual := strings .Split (req .URL .Path , "/" )[1 ]; actual != fakeTenantID {
104108 t .Errorf (`unexpected tenant "%s"` , actual )
105109 }
106- return true
107- }
108- srv , close := mock .NewServer (mock .WithTransformAllRequestsToTestServerUrl ())
109- defer close ()
110- srv .AppendResponse (mock .WithBody (instanceDiscoveryResponse ))
111- srv .AppendResponse (mock .WithBody (tenantDiscoveryResponse ))
112- srv .AppendResponse (mock .WithPredicate (validateReq ), mock .WithBody (accessTokenRespSuccess ))
113- srv .AppendResponse ()
114- opts := WorkloadIdentityCredentialOptions {
115- ClientOptions : policy.ClientOptions {Transport : srv },
116- }
117- cred , err := NewWorkloadIdentityCredential (fakeTenantID , fakeClientID , tempFile , & opts )
110+ }}
111+ cred , err := NewWorkloadIdentityCredential (& WorkloadIdentityCredentialOptions {
112+ ClientID : fakeClientID ,
113+ ClientOptions : policy.ClientOptions {Transport : & sts },
114+ TenantID : fakeTenantID ,
115+ TokenFilePath : tempFile ,
116+ })
118117 if err != nil {
119118 t .Fatal (err )
120119 }
@@ -124,7 +123,7 @@ func TestWorkloadIdentityCredential(t *testing.T) {
124123func TestWorkloadIdentityCredential_Expiration (t * testing.T ) {
125124 tokenReqs := 0
126125 tempFile := filepath .Join (t .TempDir (), "test-workload-token-file" )
127- validateReq := func (req * http.Request ) bool {
126+ sts := mockSTS { tenant : fakeTenantID , tokenRequestCallback : func (req * http.Request ) {
128127 if err := req .ParseForm (); err != nil {
129128 t .Error (err )
130129 }
@@ -134,20 +133,13 @@ func TestWorkloadIdentityCredential_Expiration(t *testing.T) {
134133 t .Errorf (`expected assertion "%d", got "%s"` , tokenReqs , actual [0 ])
135134 }
136135 tokenReqs ++
137- return true
138- }
139- srv , close := mock .NewServer (mock .WithTransformAllRequestsToTestServerUrl ())
140- defer close ()
141- srv .AppendResponse (mock .WithBody (instanceDiscoveryResponse ))
142- srv .AppendResponse (mock .WithBody (tenantDiscoveryResponse ))
143- srv .AppendResponse (mock .WithPredicate (validateReq ), mock .WithBody (accessTokenRespSuccess ))
144- srv .AppendResponse ()
145- srv .AppendResponse (mock .WithPredicate (validateReq ), mock .WithBody (accessTokenRespSuccess ))
146- srv .AppendResponse ()
147- opts := WorkloadIdentityCredentialOptions {
148- ClientOptions : policy.ClientOptions {Transport : srv },
149- }
150- cred , err := NewWorkloadIdentityCredential (fakeTenantID , fakeClientID , tempFile , & opts )
136+ }}
137+ cred , err := NewWorkloadIdentityCredential (& WorkloadIdentityCredentialOptions {
138+ ClientID : fakeClientID ,
139+ ClientOptions : policy.ClientOptions {Transport : & sts },
140+ TenantID : fakeTenantID ,
141+ TokenFilePath : tempFile ,
142+ })
151143 if err != nil {
152144 t .Fatal (err )
153145 }
@@ -167,3 +159,82 @@ func TestWorkloadIdentityCredential_Expiration(t *testing.T) {
167159 t .Fatalf ("expected 2 token requests, got %d" , tokenReqs )
168160 }
169161}
162+
163+ func TestTestWorkloadIdentityCredential_IncompleteConfig (t * testing.T ) {
164+ f := filepath .Join (t .TempDir (), t .Name ())
165+ for _ , env := range []map [string ]string {
166+ {},
167+
168+ {azureClientID : fakeClientID },
169+ {azureFederatedTokenFile : f },
170+ {azureTenantID : fakeTenantID },
171+
172+ {azureClientID : fakeClientID , azureTenantID : fakeTenantID },
173+ {azureClientID : fakeClientID , azureFederatedTokenFile : f },
174+ {azureTenantID : fakeTenantID , azureFederatedTokenFile : f },
175+ } {
176+ t .Run ("" , func (t * testing.T ) {
177+ for k , v := range env {
178+ t .Setenv (k , v )
179+ }
180+ if _ , err := NewWorkloadIdentityCredential (nil ); err == nil {
181+ t .Fatal ("expected an error" )
182+ }
183+ })
184+ }
185+ }
186+
187+ func TestWorkloadIdentityCredential_Options (t * testing.T ) {
188+ clientID := "not-" + fakeClientID
189+ tenantID := "not-" + fakeTenantID
190+ wrongFile := filepath .Join (t .TempDir (), "wrong" )
191+ rightFile := filepath .Join (t .TempDir (), "right" )
192+ if err := os .WriteFile (rightFile , []byte (tokenValue ), os .ModePerm ); err != nil {
193+ t .Fatal (err )
194+ }
195+ sts := mockSTS {
196+ tenant : tenantID ,
197+ tokenRequestCallback : func (req * http.Request ) {
198+ if err := req .ParseForm (); err != nil {
199+ t .Error (err )
200+ }
201+ if actual , ok := req .PostForm ["client_assertion" ]; ! ok {
202+ t .Error ("expected a client_assertion" )
203+ } else if len (actual ) != 1 || actual [0 ] != tokenValue {
204+ t .Errorf (`unexpected assertion "%s"` , actual [0 ])
205+ }
206+ if actual , ok := req .PostForm ["client_id" ]; ! ok {
207+ t .Error ("expected a client_id" )
208+ } else if len (actual ) != 1 || actual [0 ] != clientID {
209+ t .Errorf (`unexpected assertion "%s"` , actual [0 ])
210+ }
211+ if actual := strings .Split (req .URL .Path , "/" )[1 ]; actual != tenantID {
212+ t .Errorf (`unexpected tenant "%s"` , actual )
213+ }
214+ },
215+ }
216+ // options should override environment variables
217+ for k , v := range map [string ]string {
218+ azureClientID : fakeClientID ,
219+ azureFederatedTokenFile : wrongFile ,
220+ azureTenantID : fakeTenantID ,
221+ } {
222+ t .Setenv (k , v )
223+ }
224+ cred , err := NewWorkloadIdentityCredential (& WorkloadIdentityCredentialOptions {
225+ ClientID : clientID ,
226+ ClientOptions : policy.ClientOptions {Transport : & sts },
227+ TenantID : tenantID ,
228+ TokenFilePath : rightFile ,
229+ })
230+ if err != nil {
231+ t .Fatal (err )
232+ }
233+ tk , err := cred .GetToken (context .Background (), policy.TokenRequestOptions {Scopes : []string {liveTestScope }})
234+ if err != nil {
235+ t .Fatal (err )
236+ }
237+ if tk .Token != tokenValue {
238+ t .Fatalf ("unexpected token %q" , tk .Token )
239+ }
240+ }
0 commit comments