Skip to content

Commit d772b7c

Browse files
authored
Fix test failure because of "CONTENT_LEN_INVALID" (Azure#20034)
* fix test failure because of "CONTENT_LEN_INVALID" * fix lint * no need to redo challenge request when do retry request * rewrite auth policy to reduce auth requests time for each request * add changelog item * set test run time to 30m as ACR service is too slow sometimes * refine with code review * fix lint * prepare release * add synchronization
1 parent d85e5c3 commit d772b7c

File tree

5 files changed

+110
-95
lines changed

5 files changed

+110
-95
lines changed

sdk/containers/azcontainerregistry/CHANGELOG.md

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
# Release History
22

3-
## 0.1.1 (Unreleased)
4-
5-
### Features Added
6-
7-
### Breaking Changes
3+
## 0.1.1 (2023-03-07)
84

95
### Bugs Fixed
6+
* Fix possible failure when request retry
107

118
### Other Changes
9+
* Rewrite auth policy to promote efficiency of auth process
1210

1311
## 0.1.0 (2023-02-07)
1412

sdk/containers/azcontainerregistry/assets.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
"AssetsRepo": "Azure/azure-sdk-assets",
33
"AssetsRepoPrefixPath": "go",
44
"TagPrefix": "go/containers/azcontainerregistry",
5-
"Tag": "go/containers/azcontainerregistry_36bdeb68b8"
5+
"Tag": "go/containers/azcontainerregistry_4e940b1981"
66
}

sdk/containers/azcontainerregistry/authentication_policy.go

Lines changed: 79 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,14 @@
77
package azcontainerregistry
88

99
import (
10-
"bytes"
1110
"encoding/base64"
1211
"encoding/json"
1312
"errors"
1413
"fmt"
15-
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
16-
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
1714
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
1815
"net/http"
1916
"strings"
17+
"sync/atomic"
2018
"time"
2119

2220
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
@@ -32,93 +30,120 @@ const (
3230
type authenticationPolicyOptions struct {
3331
}
3432

33+
// authenticationPolicy is a policy to do the challenge-based authentication for container registry service. The authorization flow is as follows:
34+
// Step 1: GET /api/v1/acr/repositories
35+
// Return Header: 401: www-authenticate header - Bearer realm="{url}",service="{serviceName}",scope="{scope}",error="invalid_token"
36+
// Step 2: Retrieve the serviceName, scope from the WWW-Authenticate header.
37+
// Step 3: POST /api/oauth2/exchange
38+
// Request Body : { service, scope, grant-type, aadToken with ARM scope }
39+
// Response Body: { refreshToken }
40+
// Step 4: POST /api/oauth2/token
41+
// Request Body: { refreshToken, scope, grant-type }
42+
// Response Body: { accessToken }
43+
// Step 5: GET /api/v1/acr/repositories
44+
// Request Header: { Bearer acrTokenAccess }
45+
// Each registry service shares one refresh token, it will be cached in refreshTokenCache until expire time.
46+
// Since the scope will be different for different API/repository/artifact, accessTokenCache will only work when continuously calling same API.
3547
type authenticationPolicy struct {
36-
mainResource *temporal.Resource[azcore.AccessToken, acquiringResourceState]
37-
cred azcore.TokenCredential
38-
aadScopes []string
39-
acrScope string
40-
acrService string
41-
authClient *authenticationClient
48+
refreshTokenCache *temporal.Resource[azcore.AccessToken, acquiringResourceState]
49+
accessTokenCache atomic.Value
50+
cred azcore.TokenCredential
51+
aadScopes []string
52+
authClient *authenticationClient
4253
}
4354

4455
func newAuthenticationPolicy(cred azcore.TokenCredential, scopes []string, authClient *authenticationClient, opts *authenticationPolicyOptions) *authenticationPolicy {
4556
return &authenticationPolicy{
46-
cred: cred,
47-
aadScopes: scopes,
48-
authClient: authClient,
49-
mainResource: temporal.NewResource(acquire),
57+
cred: cred,
58+
aadScopes: scopes,
59+
authClient: authClient,
60+
refreshTokenCache: temporal.NewResource(acquireRefreshToken),
5061
}
5162
}
5263

5364
func (p *authenticationPolicy) Do(req *policy.Request) (*http.Response, error) {
54-
// send a copy of the original request without body content
55-
challengeReq, err := p.getChallengeRequest(*req)
56-
if err != nil {
57-
return nil, err
65+
var resp *http.Response
66+
var err error
67+
if req.Raw().Header.Get(headerAuthorization) != "" {
68+
// retry request could do the request with existed token directly
69+
resp, err = req.Next()
70+
} else if accessToken := p.accessTokenCache.Load(); accessToken != nil && accessToken != "" {
71+
// if there is a previous access token, then we try to use this token to do the request
72+
req.Raw().Header.Set(
73+
headerAuthorization,
74+
fmt.Sprintf("%s%s", bearerHeader, accessToken),
75+
)
76+
resp, err = req.Next()
77+
} else {
78+
// do challenge process for the initial request
79+
var challengeReq *policy.Request
80+
challengeReq, err = p.getChallengeRequest(*req)
81+
if err != nil {
82+
return nil, err
83+
}
84+
resp, err = challengeReq.Next()
5885
}
59-
resp, err := challengeReq.Next()
6086
if err != nil {
6187
return nil, err
6288
}
6389

64-
// do challenge process
65-
if resp.StatusCode == 401 {
66-
err := p.findServiceAndScope(resp)
67-
if err != nil {
90+
// if 401 response, then try to get access token
91+
if resp.StatusCode == http.StatusUnauthorized {
92+
var service, scope, accessToken string
93+
if service, scope, err = findServiceAndScope(resp); err != nil {
6894
return nil, err
6995
}
70-
71-
accessToken, err := p.getAccessToken(req)
72-
if err != nil {
96+
if accessToken, err = p.getAccessToken(req, service, scope); err != nil {
7397
return nil, err
7498
}
75-
99+
p.accessTokenCache.Store(accessToken)
76100
req.Raw().Header.Set(
77101
headerAuthorization,
78102
fmt.Sprintf("%s%s", bearerHeader, accessToken),
79103
)
80-
81-
// send the original request with auth
104+
// since the request may already been used once, body should be rewound
105+
if err = req.RewindBody(); err != nil {
106+
return nil, err
107+
}
82108
return req.Next()
83109
}
84110

85111
return resp, nil
86112
}
87113

88-
func (p *authenticationPolicy) getAccessToken(req *policy.Request) (string, error) {
114+
func (p *authenticationPolicy) getAccessToken(req *policy.Request, service, scope string) (string, error) {
89115
// anonymous access
90116
if p.cred == nil {
91-
resp, err := p.authClient.ExchangeACRRefreshTokenForACRAccessToken(req.Raw().Context(), p.acrService, p.acrScope, "", &authenticationClientExchangeACRRefreshTokenForACRAccessTokenOptions{GrantType: to.Ptr(tokenGrantTypePassword)})
117+
resp, err := p.authClient.ExchangeACRRefreshTokenForACRAccessToken(req.Raw().Context(), service, scope, "", &authenticationClientExchangeACRRefreshTokenForACRAccessTokenOptions{GrantType: to.Ptr(tokenGrantTypePassword)})
92118
if err != nil {
93119
return "", err
94120
}
95121
return *resp.acrAccessToken.AccessToken, nil
96122
}
97123

98124
// access with token
99-
as := acquiringResourceState{
100-
policy: p,
101-
req: req,
102-
}
103-
104125
// get refresh token from cache/request
105-
refreshToken, err := p.mainResource.Get(as)
126+
refreshToken, err := p.refreshTokenCache.Get(acquiringResourceState{
127+
policy: p,
128+
req: req,
129+
service: service,
130+
})
106131
if err != nil {
107132
return "", err
108133
}
109134

110135
// get access token from request
111-
resp, err := p.authClient.ExchangeACRRefreshTokenForACRAccessToken(req.Raw().Context(), p.acrService, p.acrScope, refreshToken.Token, &authenticationClientExchangeACRRefreshTokenForACRAccessTokenOptions{GrantType: to.Ptr(tokenGrantTypeRefreshToken)})
136+
resp, err := p.authClient.ExchangeACRRefreshTokenForACRAccessToken(req.Raw().Context(), service, scope, refreshToken.Token, &authenticationClientExchangeACRRefreshTokenForACRAccessTokenOptions{GrantType: to.Ptr(tokenGrantTypeRefreshToken)})
112137
if err != nil {
113138
return "", err
114139
}
115140
return *resp.acrAccessToken.AccessToken, nil
116141
}
117142

118-
func (p *authenticationPolicy) findServiceAndScope(resp *http.Response) error {
143+
func findServiceAndScope(resp *http.Response) (string, string, error) {
119144
authHeader := resp.Header.Get("WWW-Authenticate")
120145
if authHeader == "" {
121-
return errors.New("response has no WWW-Authenticate header for challenge authentication")
146+
return "", "", errors.New("response has no WWW-Authenticate header for challenge authentication")
122147
}
123148

124149
authHeader = strings.ReplaceAll(authHeader, "Bearer ", "")
@@ -131,54 +156,35 @@ func (p *authenticationPolicy) findServiceAndScope(resp *http.Response) error {
131156
}
132157
}
133158

134-
if v, ok := valuesMap["scope"]; ok {
135-
p.acrScope = v
136-
}
137-
if p.acrScope == "" {
138-
return errors.New("could not find a valid scope in the WWW-Authenticate header")
159+
if _, ok := valuesMap["service"]; !ok {
160+
return "", "", errors.New("could not find a valid service in the WWW-Authenticate header")
139161
}
140162

141-
if v, ok := valuesMap["service"]; ok {
142-
p.acrService = v
143-
}
144-
if p.acrService == "" {
145-
return errors.New("could not find a valid service in the WWW-Authenticate header")
163+
if _, ok := valuesMap["scope"]; !ok {
164+
return "", "", errors.New("could not find a valid scope in the WWW-Authenticate header")
146165
}
147166

148-
return nil
167+
return valuesMap["service"], valuesMap["scope"], nil
149168
}
150169

151-
func (p authenticationPolicy) getChallengeRequest(orig policy.Request) (*policy.Request, error) {
152-
req, err := runtime.NewRequest(orig.Raw().Context(), orig.Raw().Method, orig.Raw().URL.String())
153-
if err != nil {
154-
return nil, err
155-
}
156-
157-
req.Raw().Header = orig.Raw().Header
158-
req.Raw().Header.Set("Content-Length", "0")
159-
req.Raw().ContentLength = 0
160-
161-
copied := orig.Clone(orig.Raw().Context())
162-
copied.Raw().Body = req.Body()
163-
copied.Raw().ContentLength = 0
164-
copied.Raw().Header.Set("Content-Length", "0")
165-
err = copied.SetBody(streaming.NopCloser(bytes.NewReader([]byte{})), "application/json")
170+
func (p authenticationPolicy) getChallengeRequest(oriReq policy.Request) (*policy.Request, error) {
171+
copied := oriReq.Clone(oriReq.Raw().Context())
172+
err := copied.SetBody(nil, "")
166173
if err != nil {
167174
return nil, err
168175
}
169176
copied.Raw().Header.Del("Content-Type")
170-
171-
return copied, err
177+
return copied, nil
172178
}
173179

174180
type acquiringResourceState struct {
175-
req *policy.Request
176-
policy *authenticationPolicy
181+
req *policy.Request
182+
policy *authenticationPolicy
183+
service string
177184
}
178185

179-
// acquire acquires or updates the resource; only one
180-
// thread/goroutine at a time ever calls this function
181-
func acquire(state acquiringResourceState) (newResource azcore.AccessToken, newExpiration time.Time, err error) {
186+
// acquireRefreshToken acquires or updates the refresh token of ACR service; only one thread/goroutine at a time ever calls this function
187+
func acquireRefreshToken(state acquiringResourceState) (newResource azcore.AccessToken, newExpiration time.Time, err error) {
182188
// get AAD token from credential
183189
aadToken, err := state.policy.cred.GetToken(
184190
state.req.Raw().Context(),
@@ -191,7 +197,7 @@ func acquire(state acquiringResourceState) (newResource azcore.AccessToken, newE
191197
}
192198

193199
// exchange refresh token with AAD token
194-
refreshResp, err := state.policy.authClient.ExchangeAADAccessTokenForACRRefreshToken(state.req.Raw().Context(), postContentSchemaGrantTypeAccessToken, state.policy.acrService, &authenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{
200+
refreshResp, err := state.policy.authClient.ExchangeAADAccessTokenForACRRefreshToken(state.req.Raw().Context(), postContentSchemaGrantTypeAccessToken, state.service, &authenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{
195201
AccessToken: &aadToken.Token,
196202
})
197203
if err != nil {

sdk/containers/azcontainerregistry/authentication_policy_test.go

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,19 @@
77
package azcontainerregistry
88

99
import (
10+
"bytes"
1011
"context"
1112
"fmt"
1213
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
1314
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
15+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
1416
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
1517
"github.com/Azure/azure-sdk-for-go/sdk/internal/temporal"
1618
"github.com/stretchr/testify/require"
1719
"net/http"
1820
"reflect"
1921
"strings"
22+
"sync/atomic"
2023
"testing"
2124
"time"
2225
)
@@ -77,7 +80,7 @@ func Test_getJWTExpireTime(t *testing.T) {
7780
}
7881
}
7982

80-
func Test_authenticationPolicy_findServiceAndScope(t *testing.T) {
83+
func Test_findServiceAndScope(t *testing.T) {
8184
resp1 := http.Response{}
8285
resp1.Header = http.Header{}
8386
resp1.Header.Set("WWW-Authenticate", "Bearer realm=\"https://contosoregistry.azurecr.io/oauth2/token\",service=\"contosoregistry.azurecr.io\",scope=\"registry:catalog:*\"")
@@ -97,14 +100,13 @@ func Test_authenticationPolicy_findServiceAndScope(t *testing.T) {
97100
{"error", "error", &http.Response{}, true},
98101
} {
99102
t.Run(fmt.Sprintf("%s-%s", test.acrService, test.acrScope), func(t *testing.T) {
100-
p := &authenticationPolicy{}
101-
err := p.findServiceAndScope(test.resp)
103+
service, scope, err := findServiceAndScope(test.resp)
102104
if test.err {
103105
require.Error(t, err)
104106
} else {
105107
require.NoError(t, err)
106-
require.Equal(t, test.acrScope, p.acrScope)
107-
require.Equal(t, test.acrService, p.acrService)
108+
require.Equal(t, test.acrService, service)
109+
require.Equal(t, test.acrScope, scope)
108110
}
109111
})
110112
}
@@ -118,16 +120,15 @@ func Test_authenticationPolicy_getAccessToken_live(t *testing.T) {
118120
}
119121
authClient := newAuthenticationClient(endpoint, &authenticationClientOptions{options})
120122
p := &authenticationPolicy{
121-
temporal.NewResource(acquire),
123+
temporal.NewResource(acquireRefreshToken),
124+
atomic.Value{},
122125
cred,
123126
[]string{options.Cloud.Services[ServiceName].Audience + "/.default"},
124-
"registry:catalog:*",
125-
strings.TrimPrefix(endpoint, "https://"),
126127
authClient,
127128
}
128129
request, err := runtime.NewRequest(context.Background(), http.MethodGet, "https://test.com")
129130
require.NoError(t, err)
130-
token, err := p.getAccessToken(request)
131+
token, err := p.getAccessToken(request, strings.TrimPrefix(endpoint, "https://"), "registry:catalog:*")
131132
require.NoError(t, err)
132133
require.NotEmpty(t, token)
133134
}
@@ -137,16 +138,12 @@ func Test_authenticationPolicy_getAccessToken_live_anonymous(t *testing.T) {
137138
endpoint, _, options := getEndpointCredAndClientOptions(t)
138139
authClient := newAuthenticationClient(endpoint, &authenticationClientOptions{options})
139140
p := &authenticationPolicy{
140-
temporal.NewResource(acquire),
141-
nil,
142-
nil,
143-
"registry:catalog:*",
144-
strings.TrimPrefix(endpoint, "https://"),
145-
authClient,
141+
refreshTokenCache: temporal.NewResource(acquireRefreshToken),
142+
authClient: authClient,
146143
}
147144
request, err := runtime.NewRequest(context.Background(), http.MethodGet, "https://test.com")
148145
require.NoError(t, err)
149-
token, err := p.getAccessToken(request)
146+
token, err := p.getAccessToken(request, strings.TrimPrefix(endpoint, "https://"), "registry:catalog:*")
150147
require.NoError(t, err)
151148
require.NotEmpty(t, token)
152149
}
@@ -171,3 +168,16 @@ func Test_authenticationPolicy_anonymousAccess(t *testing.T) {
171168
_, err = client.UpdateRepositoryProperties(ctx, repositoryName, &ClientUpdateRepositoryPropertiesOptions{Value: &RepositoryWriteableProperties{CanDelete: to.Ptr(true)}})
172169
require.Error(t, err)
173170
}
171+
172+
func Test_authenticationPolicy_getChallengeRequest(t *testing.T) {
173+
oriReq, err := runtime.NewRequest(context.Background(), http.MethodPost, "https://test.com")
174+
require.NoError(t, err)
175+
testBody := []byte("test")
176+
err = oriReq.SetBody(streaming.NopCloser(bytes.NewReader(testBody)), "text/plain")
177+
require.NoError(t, err)
178+
p := &authenticationPolicy{}
179+
challengeReq, err := p.getChallengeRequest(*oriReq)
180+
require.NoError(t, err)
181+
require.Equal(t, fmt.Sprintf("%d", len(testBody)), oriReq.Raw().Header.Get("Content-Length"))
182+
require.Equal(t, "", challengeReq.Raw().Header.Get("Content-Length"))
183+
}

sdk/containers/azcontainerregistry/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ stages:
2626
parameters:
2727
ServiceDirectory: 'containers/azcontainerregistry'
2828
RunLiveTests: true
29+
TestRunTime: '30m'
2930
SupportedClouds: 'Public,UsGov,China'
3031
EnvVars:
3132
AZURE_CLIENT_ID: $(AZCONTAINERREGISTRY_CLIENT_ID)

0 commit comments

Comments
 (0)