Skip to content

Commit d419db7

Browse files
authored
[Messaging Clients] Shared Access Credential Fix (Azure#19872)
The focus of these changes is to enhance the shared access credential to eliminate a benign race condition and properly handle a concurrent update.
1 parent a25314e commit d419db7

File tree

4 files changed

+71
-18
lines changed

4 files changed

+71
-18
lines changed

sdk/eventhub/Azure.Messaging.EventHubs.Shared/src/Authorization/SharedAccessCredential.cs

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
using System;
5+
using System.Runtime.CompilerServices;
56
using System.Threading;
67
using System.Threading.Tasks;
78
using Azure.Core;
@@ -129,17 +130,17 @@ public override AccessToken GetToken(TokenRequestContext requestContext,
129130
if ((!string.Equals(signature.SharedAccessKeyName, name, StringComparison.Ordinal))
130131
|| (!string.Equals(signature.SharedAccessKey, key, StringComparison.Ordinal)))
131132
{
132-
signature = new SharedAccessSignature(signature.Resource, name, key);
133-
Volatile.Write(ref _sharedAccessSignature, signature);
133+
var updatedSignature = new SharedAccessSignature(signature.Resource, name, key);
134+
signature = SafeUpdateSharedAccessSignature(signature, updatedSignature);
134135
}
135136
}
136137

137138
// If the key-based signature is approaching expiration, extend it.
138139

139140
if (signature.SignatureExpiration <= DateTimeOffset.UtcNow.Add(SignatureRefreshBuffer))
140141
{
141-
signature = signature.CloneWithNewExpiration(SignatureExtensionDuration);
142-
Volatile.Write(ref _sharedAccessSignature, signature);
142+
var updatedSignature = signature.CloneWithNewExpiration(SignatureExtensionDuration);
143+
signature = SafeUpdateSharedAccessSignature(signature, updatedSignature);
143144
}
144145

145146
return new AccessToken(signature.Value, signature.SignatureExpiration);
@@ -157,5 +158,32 @@ public override AccessToken GetToken(TokenRequestContext requestContext,
157158
///
158159
public override ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext,
159160
CancellationToken cancellationToken) => new ValueTask<AccessToken>(GetToken(requestContext, cancellationToken));
161+
162+
/// <summary>
163+
/// Attempts to update the current shared access signature reference of the credential while respecting concurrent updates.
164+
/// </summary>
165+
///
166+
/// <param name="cachedSignature">The cached signature that had been previously read. If this value is not the current <c>_sharedAccessSignature</c>, the update will not be performed.</param>
167+
/// <param name="updatedSignature">The signature that was locally updated and intended to replace the <paramref name="cachedSignature"/>.</param>
168+
///
169+
/// <returns>The current value of the <see cref="SharedAccessSignature" /> of the credential, after the attempted update. This will be the <paramref name="updatedSignature"/> if the update was performed.</returns>
170+
///
171+
/// <remarks>
172+
/// The class field "_sharedAccessSignature" may be mutated when calling this method.
173+
/// </remarks>
174+
///
175+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
176+
private SharedAccessSignature SafeUpdateSharedAccessSignature(SharedAccessSignature cachedSignature,
177+
SharedAccessSignature updatedSignature)
178+
{
179+
var signature = Interlocked.CompareExchange(ref _sharedAccessSignature, updatedSignature, cachedSignature);
180+
181+
// If the cached signature doesn't match the active one then the signature was not replaced because it had
182+
// already been updated by another caller; assume that active signature is correct and should be used.
183+
184+
return ReferenceEquals(signature, cachedSignature)
185+
? updatedSignature
186+
: signature;
187+
}
160188
}
161189
}

sdk/eventhub/Azure.Messaging.EventHubs.Shared/tests/Authorization/SharedAccessCredentialTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ public void GetTokenExtendsAnExpiredTokenWhenCreatedWithTheSharedKey()
189189
var signature = new SharedAccessSignature("hub-name", "keyName", "key", value, DateTimeOffset.UtcNow.Subtract(TimeSpan.FromHours(2)));
190190
var credential = new SharedAccessCredential(signature);
191191

192-
var expectedExpiration = DateTimeOffset.Now.Add(GetSignatureExtensionDuration());
192+
var expectedExpiration = DateTimeOffset.UtcNow.Add(GetSignatureExtensionDuration());
193193
Assert.That(credential.GetToken(new TokenRequestContext(), default).ExpiresOn, Is.EqualTo(expectedExpiration).Within(TimeSpan.FromMinutes(1)));
194194
}
195195

@@ -206,7 +206,7 @@ public void GetTokenExtendsATokenCloseToExpiringWhenCreatedWithTheSharedKey()
206206
var signature = new SharedAccessSignature("hub-name", "keyName", "key", value, tokenExpiration);
207207
var credential = new SharedAccessCredential(signature);
208208

209-
var expectedExpiration = DateTimeOffset.Now.Add(GetSignatureExtensionDuration());
209+
var expectedExpiration = DateTimeOffset.UtcNow.Add(GetSignatureExtensionDuration());
210210
Assert.That(credential.GetToken(new TokenRequestContext(), default).ExpiresOn, Is.EqualTo(expectedExpiration).Within(TimeSpan.FromMinutes(1)));
211211
}
212212

sdk/servicebus/Azure.Messaging.ServiceBus/src/Authorization/SharedAccessCredential.cs

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
using System;
5+
using System.Runtime.CompilerServices;
56
using System.Threading;
67
using System.Threading.Tasks;
78
using Azure.Core;
@@ -73,8 +74,7 @@ public SharedAccessCredential(AzureSasCredential sourceCredential)
7374
/// <param name="sourceCredential">The <see cref="AzureNamedKeyCredential"/> to base signatures on.</param>
7475
/// <param name="signatureResource">The fully-qualified identifier for the resource to which this credential is intended to serve as authorization for. This is also known as the "token audience" in some contexts.</param>
7576
///
76-
public SharedAccessCredential(AzureNamedKeyCredential sourceCredential,
77-
string signatureResource)
77+
public SharedAccessCredential(AzureNamedKeyCredential sourceCredential, string signatureResource)
7878
{
7979
Argument.AssertNotNull(sourceCredential, nameof(sourceCredential));
8080
Argument.AssertNotNullOrEmpty(signatureResource, nameof(signatureResource));
@@ -96,8 +96,7 @@ public SharedAccessCredential(AzureNamedKeyCredential sourceCredential,
9696
///
9797
/// <returns>The token representing the shared access signature for this credential.</returns>
9898
///
99-
public override AccessToken GetToken(TokenRequestContext requestContext,
100-
CancellationToken cancellationToken)
99+
public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken)
101100
{
102101
var signature = Volatile.Read(ref _sharedAccessSignature);
103102

@@ -129,17 +128,17 @@ public override AccessToken GetToken(TokenRequestContext requestContext,
129128
if ((!string.Equals(signature.SharedAccessKeyName, name, StringComparison.Ordinal))
130129
|| (!string.Equals(signature.SharedAccessKey, key, StringComparison.Ordinal)))
131130
{
132-
signature = new SharedAccessSignature(signature.Resource, name, key);
133-
Volatile.Write(ref _sharedAccessSignature, signature);
131+
var updatedSignature = new SharedAccessSignature(signature.Resource, name, key);
132+
signature = SafeUpdateSharedAccessSignature(signature, updatedSignature);
134133
}
135134
}
136135

137136
// If the key-based signature is approaching expiration, extend it.
138137

139138
if (signature.SignatureExpiration <= DateTimeOffset.UtcNow.Add(SignatureRefreshBuffer))
140139
{
141-
signature = signature.CloneWithNewExpiration(SignatureExtensionDuration);
142-
Volatile.Write(ref _sharedAccessSignature, signature);
140+
var updatedSignature = signature.CloneWithNewExpiration(SignatureExtensionDuration);
141+
signature = SafeUpdateSharedAccessSignature(signature, updatedSignature);
143142
}
144143

145144
return new AccessToken(signature.Value, signature.SignatureExpiration);
@@ -155,7 +154,33 @@ public override AccessToken GetToken(TokenRequestContext requestContext,
155154
///
156155
/// <returns>The token representing the shared access signature for this credential.</returns>
157156
///
158-
public override ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext,
159-
CancellationToken cancellationToken) => new ValueTask<AccessToken>(GetToken(requestContext, cancellationToken));
157+
public override ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) =>
158+
new ValueTask<AccessToken>(GetToken(requestContext, cancellationToken));
159+
160+
/// <summary>
161+
/// Attempts to update the current shared access signature reference of the credential while respecting concurrent updates.
162+
/// </summary>
163+
///
164+
/// <param name="cachedSignature">The cached signature that had been previously read. If this value is not the current <c>_sharedAccessSignature</c>, the update will not be performed.</param>
165+
/// <param name="updatedSignature">The signature that was locally updated and intended to replace the <paramref name="cachedSignature"/>.</param>
166+
///
167+
/// <returns>The current value of the <see cref="SharedAccessSignature" /> of the credential, after the attempted update. This will be the <paramref name="updatedSignature"/> if the update was performed.</returns>
168+
///
169+
/// <remarks>
170+
/// The class field "_sharedAccessSignature" may be mutated when calling this method.
171+
/// </remarks>
172+
///
173+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
174+
private SharedAccessSignature SafeUpdateSharedAccessSignature(SharedAccessSignature cachedSignature, SharedAccessSignature updatedSignature)
175+
{
176+
var signature = Interlocked.CompareExchange(ref _sharedAccessSignature, updatedSignature, cachedSignature);
177+
178+
// If the cached signature doesn't match the active one then the signature was not replaced because it had
179+
// already been updated by another caller; assume that active signature is correct and should be used.
180+
181+
return ReferenceEquals(signature, cachedSignature)
182+
? updatedSignature
183+
: signature;
184+
}
160185
}
161186
}

sdk/servicebus/Azure.Messaging.ServiceBus/tests/Authorization/SharedAccessCredentialTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ public void GetTokenExtendsAnExpiredTokenWhenCreatedWithTheSharedKey()
189189
var signature = new SharedAccessSignature("hub-name", "keyName", "key", value, DateTimeOffset.UtcNow.Subtract(TimeSpan.FromHours(2)));
190190
var credential = new SharedAccessCredential(signature);
191191

192-
var expectedExpiration = DateTimeOffset.Now.Add(GetSignatureExtensionDuration());
192+
var expectedExpiration = DateTimeOffset.UtcNow.Add(GetSignatureExtensionDuration());
193193
Assert.That(credential.GetToken(new TokenRequestContext(), default).ExpiresOn, Is.EqualTo(expectedExpiration).Within(TimeSpan.FromMinutes(1)));
194194
}
195195

@@ -206,7 +206,7 @@ public void GetTokenExtendsATokenCloseToExpiringWhenCreatedWithTheSharedKey()
206206
var signature = new SharedAccessSignature("hub-name", "keyName", "key", value, tokenExpiration);
207207
var credential = new SharedAccessCredential(signature);
208208

209-
var expectedExpiration = DateTimeOffset.Now.Add(GetSignatureExtensionDuration());
209+
var expectedExpiration = DateTimeOffset.UtcNow.Add(GetSignatureExtensionDuration());
210210
Assert.That(credential.GetToken(new TokenRequestContext(), default).ExpiresOn, Is.EqualTo(expectedExpiration).Within(TimeSpan.FromMinutes(1)));
211211
}
212212

0 commit comments

Comments
 (0)