Skip to content

Commit 65685ab

Browse files
authored
Removed dependency on TokenRequestOptions.TenantID (Azure#17709)
It wasn't working anyways and has been removed in the upcoming release. Refactored the expiring resource machinery to use generics.
1 parent 33fa99a commit 65685ab

File tree

5 files changed

+28
-36
lines changed

5 files changed

+28
-36
lines changed

sdk/keyvault/internal/CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
# Release History
22

3-
## 0.3.1 (Unreleased)
3+
## 0.4.0 (Unreleased)
44

55
### Features Added
66

77
### Breaking Changes
8+
* Updated `ExpiringResource` and its dependent types to use generics.
89

910
### Bugs Fixed
1011

1112
### Other Changes
13+
* Remove reference to `TokenRequestOptions.TenantID` as it's been removed and wasn't working anyways.
1214

1315
## 0.3.0 (2022-04-04)
1416

sdk/keyvault/internal/challenge_policy.go

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ const bearerHeader = "Bearer "
2727

2828
type KeyVaultChallengePolicy struct {
2929
// mainResource is the resource to be retrieved using the tenant specified in the credential
30-
mainResource *ExpiringResource
30+
mainResource *ExpiringResource[*azcore.AccessToken, acquiringResourceState]
3131
cred azcore.TokenCredential
3232
scope *string
3333
tenantID *string
@@ -73,12 +73,10 @@ func (k *KeyVaultChallengePolicy) Do(req *policy.Request) (*http.Response, error
7373
return nil, err
7474
}
7575

76-
if token, ok := tk.(*azcore.AccessToken); ok {
77-
req.Raw().Header.Set(
78-
headerAuthorization,
79-
fmt.Sprintf("%s%s", bearerHeader, token.Token),
80-
)
81-
}
76+
req.Raw().Header.Set(
77+
headerAuthorization,
78+
fmt.Sprintf("%s%s", bearerHeader, tk.Token),
79+
)
8280

8381
// send a copy of the request
8482
cloneReq := req.Clone(req.Raw().Context())
@@ -104,15 +102,10 @@ func (k *KeyVaultChallengePolicy) Do(req *policy.Request) (*http.Response, error
104102
return resp, err
105103
}
106104

107-
if token, ok := tk.(*azcore.AccessToken); ok {
108-
req.Raw().Header.Set(
109-
headerAuthorization,
110-
bearerHeader+token.Token,
111-
)
112-
} else {
113-
// tk is not an azcore.AccessToken type, something went wrong and we should return the 401 and accompanying error
114-
return resp, cloneReqErr
115-
}
105+
req.Raw().Header.Set(
106+
headerAuthorization,
107+
bearerHeader+tk.Token,
108+
)
116109

117110
// send the original request now
118111
return req.Next()
@@ -220,13 +213,11 @@ type acquiringResourceState struct {
220213

221214
// acquire acquires or updates the resource; only one
222215
// thread/goroutine at a time ever calls this function
223-
func acquire(state interface{}) (newResource interface{}, newExpiration time.Time, err error) {
224-
s := state.(acquiringResourceState)
225-
tk, err := s.p.cred.GetToken(
226-
s.req.Raw().Context(),
216+
func acquire(state acquiringResourceState) (newResource *azcore.AccessToken, newExpiration time.Time, err error) {
217+
tk, err := state.p.cred.GetToken(
218+
state.req.Raw().Context(),
227219
policy.TokenRequestOptions{
228-
Scopes: []string{*s.p.scope},
229-
TenantID: *s.p.scope,
220+
Scopes: []string{*state.p.scope},
230221
},
231222
)
232223
if err != nil {

sdk/keyvault/internal/constants.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77
package internal
88

99
const (
10-
version = "v0.3.1" //nolint
10+
version = "v0.4.0" //nolint
1111
)

sdk/keyvault/internal/expiring_resource.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,34 +12,34 @@ import (
1212
)
1313

1414
// AcquireResource abstracts a method for refreshing an expiring resource.
15-
type AcquireResource func(state interface{}) (newResource interface{}, newExpiration time.Time, err error)
15+
type AcquireResource[TResource any, TState any] func(state TState) (newResource TResource, newExpiration time.Time, err error)
1616

1717
// ExpiringResource is a temporal resource (usually a credential), that requires periodic refreshing.
18-
type ExpiringResource struct {
18+
type ExpiringResource[TResource any, TState any] struct {
1919
// cond is used to synchronize access to the shared resource embodied by the remaining fields
2020
cond *sync.Cond
2121

2222
// acquiring indicates that some thread/goroutine is in the process of acquiring/updating the resource
2323
acquiring bool
2424

2525
// resource contains the value of the shared resource
26-
resource interface{}
26+
resource TResource
2727

2828
// expiration indicates when the shared resource expires; it is 0 if the resource was never acquired
2929
expiration time.Time
3030

3131
// acquireResource is the callback function that actually acquires the resource
32-
acquireResource AcquireResource
32+
acquireResource AcquireResource[TResource, TState]
3333
}
3434

3535
// NewExpiringResource creates a new ExpiringResource that uses the specified AcquireResource for refreshing.
36-
func NewExpiringResource(ar AcquireResource) *ExpiringResource {
37-
return &ExpiringResource{cond: sync.NewCond(&sync.Mutex{}), acquireResource: ar}
36+
func NewExpiringResource[TResource any, TState any](ar AcquireResource[TResource, TState]) *ExpiringResource[TResource, TState] {
37+
return &ExpiringResource[TResource, TState]{cond: sync.NewCond(&sync.Mutex{}), acquireResource: ar}
3838
}
3939

4040
// GetResource returns the underlying resource.
4141
// If the resource is fresh, no refresh is performed.
42-
func (er *ExpiringResource) GetResource(state interface{}) (interface{}, error) {
42+
func (er *ExpiringResource[TResource, TState]) GetResource(state TState) (TResource, error) {
4343
// If the resource is expiring within this time window, update it eagerly.
4444
// This allows other threads/goroutines to keep running by using the not-yet-expired
4545
// resource value while one thread/goroutine updates the resource.
@@ -98,7 +98,7 @@ func (er *ExpiringResource) GetResource(state interface{}) (interface{}, error)
9898
return resource, err // Return the resource this thread/goroutine can use
9999
}
100100

101-
func (er *ExpiringResource) Reset() {
101+
func (er *ExpiringResource[TResource, TState]) Reset() {
102102
// acquire exclusive lock
103103
er.cond.L.Lock()
104104

sdk/keyvault/internal/expiring_resource_test.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,14 @@ import (
1515
)
1616

1717
func TestNewExpiringResource(t *testing.T) {
18-
er := NewExpiringResource(func(state interface{}) (newResource interface{}, newExpiration time.Time, err error) {
19-
s := state.(string)
20-
switch s {
18+
er := NewExpiringResource(func(state string) (newResource string, newExpiration time.Time, err error) {
19+
switch state {
2120
case "initial":
2221
return "updated", time.Now(), nil
2322
case "updated":
2423
return "refreshed", time.Now().Add(1 * time.Hour), nil
2524
default:
26-
t.Fatalf("unexpected state %s", s)
25+
t.Fatalf("unexpected state %s", state)
2726
return "", time.Time{}, errors.New("unexpected")
2827
}
2928
})

0 commit comments

Comments
 (0)