Skip to content

Commit 9681941

Browse files
authored
BearerTokenChallengeAuthenticationPolicy refactor (Azure#19215)
* refactor BearerTokenChallengeAuthenticationPolicy
1 parent c941c23 commit 9681941

File tree

3 files changed

+60
-87
lines changed

3 files changed

+60
-87
lines changed

sdk/core/Azure.Core/src/Shared/ARMChallengeAuthenticationPolicy.cs

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
using System;
5+
using System.Threading.Tasks;
56
using System.Collections.Generic;
67

78
#nullable enable
@@ -29,25 +30,18 @@ public ARMChallengeAuthenticationPolicy(TokenCredential credential, string scope
2930
public ARMChallengeAuthenticationPolicy(TokenCredential credential, IEnumerable<string> scopes)
3031
: base(credential, scopes, TimeSpan.FromMinutes(5), TimeSpan.FromSeconds(30)) { }
3132

32-
/// <summary>
33-
/// Executed in the event a 401 response with a WWW-Authenticate authentication challenge header is received after the initial request.
34-
/// </summary>
35-
/// <remarks>Handles claims authentication challenges.</remarks>
36-
/// <param name="message">The <see cref="HttpMessage"/> to be authenticated.</param>
37-
/// <param name="context">If the return value is <c>true</c>, a <see cref="TokenRequestContext"/>.</param>
38-
/// <returns>A boolean indicated whether the request contained a valid challenge and a <see cref="TokenRequestContext"/> was successfully initialized with it.</returns>
39-
protected override bool TryGetTokenRequestContextFromChallenge(HttpMessage message, out TokenRequestContext context)
33+
/// <inheritdoc cref="BearerTokenChallengeAuthenticationPolicy.AuthenticateRequestOnChallengeAsync(HttpMessage, bool)" />
34+
protected override async ValueTask<bool> AuthenticateRequestOnChallengeAsync(HttpMessage message, bool async)
4035
{
41-
context = default;
42-
4336
var challenge = AuthorizationChallengeParser.GetChallengeParameterFromResponse(message.Response, "Bearer", "claims");
4437
if (challenge == null)
4538
{
4639
return false;
4740
}
4841

4942
string claimsChallenge = Base64Url.DecodeString(challenge.ToString());
50-
context = new TokenRequestContext(Scopes, message.Request.ClientRequestId, claimsChallenge);
43+
var context = new TokenRequestContext(Scopes, message.Request.ClientRequestId, claimsChallenge);
44+
await SetAuthorizationHeader(message, context, async);
5145
return true;
5246
}
5347
}

sdk/core/Azure.Core/src/Shared/BearerTokenChallengeAuthenticationPolicy.cs

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ namespace Azure.Core.Pipeline
1919
internal class BearerTokenChallengeAuthenticationPolicy : HttpPipelinePolicy
2020
{
2121
private readonly AccessTokenCache _accessTokenCache;
22-
protected string[] Scopes { get; set; }
22+
protected string[] Scopes { get; private set; }
23+
private readonly ValueTask<bool> _falseValueTask = new ValueTask<bool>(Task.FromResult(false));
2324

2425
/// <summary>
2526
/// Creates a new instance of <see cref="BearerTokenChallengeAuthenticationPolicy"/> using provided token credential and scope to authenticate for.
@@ -57,17 +58,29 @@ public override void Process(HttpMessage message, ReadOnlyMemory<HttpPipelinePol
5758
ProcessAsync(message, pipeline, false).EnsureCompleted();
5859
}
5960

61+
/// <summary>
62+
/// Executes before <see cref="ProcessAsync(HttpMessage, ReadOnlyMemory{HttpPipelinePolicy})"/> or <see cref="Process(HttpMessage, ReadOnlyMemory{HttpPipelinePolicy})"/> is called.
63+
/// Implementers of this method are expected to call <see cref="SetAuthorizationHeader(HttpMessage, TokenRequestContext, bool)"/> if authorization is required for requests not related to handling a challenge response.
64+
/// </summary>
65+
/// <param name="message">The <see cref="HttpMessage"/> this policy would be applied to.</param>
66+
/// <param name="async">Indicates whether the method was called from an asynchronous context.</param>
67+
/// <returns>The <see cref="ValueTask"/> representing the asynchronous operation.</returns>
68+
protected virtual Task AuthenticateRequestAsync(HttpMessage message, bool async)
69+
{
70+
var context = new TokenRequestContext(Scopes, message.Request.ClientRequestId);
71+
return SetAuthorizationHeader(message, context, async);
72+
}
73+
6074
/// <summary>
6175
/// Executed in the event a 401 response with a WWW-Authenticate authentication challenge header is received after the initial request.
6276
/// </summary>
63-
/// <remarks>Service client libraries may derive from this and extend to handle service specific authentication challenges.</remarks>
77+
/// <remarks>Service client libraries may override this to handle service specific authentication challenges.</remarks>
6478
/// <param name="message">The <see cref="HttpMessage"/> to be authenticated.</param>
65-
/// <param name="context">If the return value is <c>true</c>, a <see cref="TokenRequestContext"/>.</param>
66-
/// <returns>A boolean indicated whether the request contained a valid challenge and a <see cref="TokenRequestContext"/> was successfully initialized with it.</returns>
67-
protected virtual bool TryGetTokenRequestContextFromChallenge(HttpMessage message, out TokenRequestContext context)
79+
/// <param name="async">Indicates whether the method was called from an asynchronous context.</param>
80+
/// <returns>A boolean indicating whether the request was successfully authenticated and should be sent to the transport.</returns>
81+
protected virtual ValueTask<bool> AuthenticateRequestOnChallengeAsync(HttpMessage message, bool async)
6882
{
69-
context = default;
70-
return false;
83+
return _falseValueTask;
7184
}
7285

7386
private async ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline, bool async)
@@ -77,31 +90,14 @@ private async ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory<HttpPip
7790
throw new InvalidOperationException("Bearer token authentication is not permitted for non TLS protected (https) endpoints.");
7891
}
7992

80-
TokenRequestContext context;
81-
82-
// If the message already has a challenge response due to a sub-class pre-processing the request, get the context from the challenge.
83-
if (message.HasResponse && message.Response.Status == (int)HttpStatusCode.Unauthorized && message.Response.Headers.Contains(HttpHeader.Names.WWWAuthenticate))
84-
{
85-
if (!TryGetTokenRequestContextFromChallenge(message, out context))
86-
{
87-
// We were unsuccessful in handling the challenge, so bail out now.
88-
return;
89-
}
90-
Scopes = context.Scopes;
91-
}
92-
else
93-
{
94-
context = new TokenRequestContext(Scopes, message.Request.ClientRequestId);
95-
}
96-
97-
await AuthenticateRequestAsync(message, context, async).ConfigureAwait(false);
98-
9993
if (async)
10094
{
95+
await AuthenticateRequestAsync(message, true).ConfigureAwait(false);
10196
await ProcessNextAsync(message, pipeline).ConfigureAwait(false);
10297
}
10398
else
10499
{
100+
AuthenticateRequestAsync(message, false).EnsureCompleted();
105101
ProcessNext(message, pipeline);
106102
}
107103

@@ -111,13 +107,8 @@ private async ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory<HttpPip
111107
// Attempt to get the TokenRequestContext based on the challenge.
112108
// If we fail to get the context, the challenge was not present or invalid.
113109
// If we succeed in getting the context, authenticate the request and pass it up the policy chain.
114-
if (TryGetTokenRequestContextFromChallenge(message, out context))
110+
if (await AuthenticateRequestOnChallengeAsync(message, async).ConfigureAwait(false))
115111
{
116-
// Ensure the scopes are consistent with what was set by <see cref="TryGetTokenRequestContextFromChallenge" />.
117-
Scopes = context.Scopes;
118-
119-
await AuthenticateRequestAsync(message, context, async).ConfigureAwait(false);
120-
121112
if (async)
122113
{
123114
await ProcessNextAsync(message, pipeline).ConfigureAwait(false);
@@ -130,7 +121,13 @@ private async ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory<HttpPip
130121
}
131122
}
132123

133-
private async Task AuthenticateRequestAsync(HttpMessage message, TokenRequestContext context, bool async)
124+
/// <summary>
125+
/// Sets the Authorization header on the <see cref="Request"/>.
126+
/// </summary>
127+
/// <param name="message">The <see cref="HttpMessage"/> with the <see cref="Request"/> to be authorized.</param>
128+
/// <param name="context">The <see cref="TokenRequestContext"/> used to authorize the <see cref="Request"/>.</param>
129+
/// <param name="async">Indicates whether the method was called from an asynchronous context.</param>
130+
protected async Task SetAuthorizationHeader(HttpMessage message, TokenRequestContext context, bool async)
134131
{
135132
string headerValue;
136133
if (async)

sdk/keyvault/Azure.Security.KeyVault.Shared/src/ChallengeBasedAuthenticationPolicy.cs

Lines changed: 25 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -13,69 +13,50 @@ namespace Azure.Security.KeyVault
1313
internal class ChallengeBasedAuthenticationPolicy : BearerTokenChallengeAuthenticationPolicy
1414
{
1515
private static ConcurrentDictionary<string, AuthorityScope> _scopeCache = new ConcurrentDictionary<string, AuthorityScope>();
16+
private const string KeyVaultStashedContentKey = "KeyVaultContent";
1617
private AuthorityScope _scope;
1718

1819
public ChallengeBasedAuthenticationPolicy(TokenCredential credential) : base(credential, Array.Empty<string>())
1920
{ }
2021

21-
public override void Process(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline)
22-
{
23-
PreProcessAsync(message, pipeline, false).EnsureCompleted();
24-
base.Process(message, pipeline);
25-
}
26-
27-
public override async ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline)
28-
{
29-
await PreProcessAsync(message, pipeline, true).ConfigureAwait(false);
30-
await base.ProcessAsync(message, pipeline).ConfigureAwait(false);
31-
}
32-
33-
protected async Task PreProcessAsync(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline, bool async)
22+
/// <inheritdoc cref="BearerTokenChallengeAuthenticationPolicy.AuthenticateRequestOnChallengeAsync(HttpMessage, bool)" />
23+
protected override async Task AuthenticateRequestAsync(HttpMessage message, bool async)
3424
{
3525
if (message.Request.Uri.Scheme != Uri.UriSchemeHttps)
3626
{
3727
throw new InvalidOperationException("Bearer token authentication is not permitted for non TLS protected (https) endpoints.");
3828
}
3929

40-
// if this policy doesn't have _scope cached try to get it from the static challenge cache.
41-
if (_scope != null)
30+
// If this policy doesn't have _scope cached try to get it from the static challenge cache.
31+
if (_scope == null)
4232
{
43-
return;
33+
string authority = GetRequestAuthority(message.Request);
34+
_scopeCache.TryGetValue(authority, out _scope);
4435
}
4536

46-
string authority = GetRequestAuthority(message.Request);
47-
_scopeCache.TryGetValue(authority, out _scope);
48-
49-
if (_scope == null)
37+
if (_scope != null)
5038
{
51-
// The body is removed from the initial request because Key Vault supports other authentication schemes which also protect the body of the request.
52-
// As a result, before we know the auth scheme we need to avoid sending an unprotected body to Key Vault.
53-
// We don't currently support this enhanced auth scheme in the SDK but we still don't want to send any unprotected data to vaults which require it.
39+
// We fetched the scope from the cache, but we have not initialized the Scopes in the base yet.
40+
var context = new TokenRequestContext(_scope.Scopes, message.Request.ClientRequestId);
41+
await SetAuthorizationHeader(message, context, async).ConfigureAwait(false);
42+
return;
43+
}
5444

55-
RequestContent originalContent = message.Request.Content;
56-
message.Request.Content = null;
45+
// The body is removed from the initial request because Key Vault supports other authentication schemes which also protect the body of the request.
46+
// As a result, before we know the auth scheme we need to avoid sending an unprotected body to Key Vault.
47+
// We don't currently support this enhanced auth scheme in the SDK but we still don't want to send any unprotected data to vaults which require it.
5748

58-
if (async)
59-
{
60-
await ProcessNextAsync(message, pipeline).ConfigureAwait(false);
61-
}
62-
else
63-
{
64-
ProcessNext(message, pipeline);
65-
}
49+
message.SetProperty(KeyVaultStashedContentKey, message.Request.Content);
50+
message.Request.Content = null;
51+
}
6652

67-
// set the content to the original content.
68-
message.Request.Content = originalContent;
69-
}
70-
else
53+
protected override async ValueTask<bool> AuthenticateRequestOnChallengeAsync(HttpMessage message, bool async)
54+
{
55+
if (message.Request.Content == null && message.TryGetProperty(KeyVaultStashedContentKey, out var content))
7156
{
72-
// We fetched the scope from the cache, but we have not initialized the Scopes in the base yet.
73-
Scopes = _scope.Scopes;
57+
message.Request.Content = content as RequestContent;
7458
}
75-
}
7659

77-
protected override bool TryGetTokenRequestContextFromChallenge(HttpMessage message, out TokenRequestContext context)
78-
{
7960
string authority = GetRequestAuthority(message.Request);
8061
string scope = AuthorizationChallengeParser.GetChallengeParameterFromResponse(message.Response, "Bearer", "resource");
8162
if (scope != null)
@@ -100,7 +81,8 @@ protected override bool TryGetTokenRequestContextFromChallenge(HttpMessage messa
10081
_scopeCache[authority] = _scope;
10182
}
10283

103-
context = new TokenRequestContext(_scope.Scopes, message.Request.ClientRequestId);
84+
var context = new TokenRequestContext(_scope.Scopes, message.Request.ClientRequestId);
85+
await SetAuthorizationHeader(message, context, async).ConfigureAwait(false);
10486
return true;
10587
}
10688

0 commit comments

Comments
 (0)