Skip to content

Commit 1063321

Browse files
authored
Add support for token exchange managed identity (Azure#23775)
* Add support for token exchange managed identity * adding MI token exchange test * fix pipeline bug
1 parent c991552 commit 1063321

10 files changed

+631
-1
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Text;
7+
using System.Threading;
8+
using System.Threading.Tasks;
9+
using Azure.Core;
10+
using Azure.Core.Pipeline;
11+
using Microsoft.Identity.Client;
12+
13+
namespace Azure.Identity
14+
{
15+
internal class ClientAssertionCredential : TokenCredential
16+
{
17+
internal string TenantId { get; }
18+
internal string ClientId { get; }
19+
internal MsalConfidentialClient Client { get; }
20+
internal bool AllowMultiTenantAuthentication { get; }
21+
22+
public ClientAssertionCredential(string tenantId, string clientId, Func<Task<string>> getAssertionCallback, ClientAssertionCredentialOptions options = default) :
23+
this(tenantId, clientId, () => getAssertionCallback().GetAwaiter().GetResult(), options)
24+
{
25+
}
26+
27+
public ClientAssertionCredential(string tenantId, string clientId, Func<string> getAssertionCallback, ClientAssertionCredentialOptions options = default)
28+
{
29+
TenantId = tenantId;
30+
ClientId = clientId;
31+
AllowMultiTenantAuthentication = options?.AllowMultiTenantAuthentication ?? false;
32+
Client = options?.MsalClient ?? new MsalConfidentialClient(options?.Pipeline ?? CredentialPipeline.GetInstance(options), tenantId, clientId, getAssertionCallback, null, null, options?.IsLoggingPIIEnabled ?? false);
33+
}
34+
35+
public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken)
36+
{
37+
using CredentialDiagnosticScope scope = Client.Pipeline.StartGetTokenScope("ClientAssertionCredential.GetToken", requestContext);
38+
39+
try
40+
{
41+
var tenantId = TenantIdResolver.Resolve(TenantId, requestContext, AllowMultiTenantAuthentication);
42+
43+
AuthenticationResult result = Client.AcquireTokenForClientAsync(requestContext.Scopes, tenantId, false, cancellationToken).EnsureCompleted();
44+
45+
return scope.Succeeded(new AccessToken(result.AccessToken, result.ExpiresOn));
46+
}
47+
catch (Exception e)
48+
{
49+
throw scope.FailWrapAndThrow(e);
50+
}
51+
}
52+
53+
public async override ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken)
54+
{
55+
using CredentialDiagnosticScope scope = Client.Pipeline.StartGetTokenScope("ClientAssertionCredential.GetToken", requestContext);
56+
57+
try
58+
{
59+
var tenantId = TenantIdResolver.Resolve(TenantId, requestContext, AllowMultiTenantAuthentication);
60+
61+
AuthenticationResult result = await Client.AcquireTokenForClientAsync(requestContext.Scopes, tenantId, true, cancellationToken).ConfigureAwait(false);
62+
63+
return scope.Succeeded(new AccessToken(result.AccessToken, result.ExpiresOn));
64+
}
65+
catch (Exception e)
66+
{
67+
throw scope.FailWrapAndThrow(e);
68+
}
69+
}
70+
}
71+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
namespace Azure.Identity
5+
{
6+
internal class ClientAssertionCredentialOptions : TokenCredentialOptions
7+
{
8+
internal CredentialPipeline Pipeline { get; set; }
9+
10+
internal MsalConfidentialClient MsalClient { get; set; }
11+
}
12+
}

sdk/identity/Azure.Identity/src/EnvironmentVariables.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,7 @@ internal class EnvironmentVariables
2929
public static string AuthorityHost => Environment.GetEnvironmentVariable("AZURE_AUTHORITY_HOST");
3030

3131
public static string AzureRegionalAuthorityName => Environment.GetEnvironmentVariable("AZURE_REGIONAL_AUTHORITY_NAME");
32+
33+
public static string AzureFederatedTokenFile => Environment.GetEnvironmentVariable("AZURE_FEDERATED_TOKEN_FILE");
3234
}
3335
}

sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ private static ManagedIdentitySource SelectManagedIdentitySource(ManagedIdentity
4646
CloudShellManagedIdentitySource.TryCreate(options) ??
4747
AzureArcManagedIdentitySource.TryCreate(options) ??
4848
ServiceFabricManagedIdentitySource.TryCreate(options) ??
49+
TokenExchangeManagedIdentitySource.TryCreate(options) ??
4950
new ImdsManagedIdentitySource(options.Pipeline, options.ClientId);
5051
}
5152
}

sdk/identity/Azure.Identity/src/MsalClientBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ protected MsalClientBase(CredentialPipeline pipeline, string tenantId, string cl
4343

4444
internal TokenCache TokenCache { get; }
4545

46-
protected CredentialPipeline Pipeline { get; }
46+
protected internal CredentialPipeline Pipeline { get; }
4747

4848
protected abstract ValueTask<TClient> CreateClientAsync(bool async, CancellationToken cancellationToken);
4949

sdk/identity/Azure.Identity/src/MsalConfidentialClient.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4+
using System;
45
using System.Security.Cryptography.X509Certificates;
56
using System.Threading;
67
using System.Threading.Tasks;
@@ -13,6 +14,7 @@ internal class MsalConfidentialClient : MsalClientBase<IConfidentialClientApplic
1314
internal readonly string _clientSecret;
1415
internal readonly bool _includeX5CClaimHeader;
1516
internal readonly IX509Certificate2Provider _certificateProvider;
17+
private readonly Func<string> _assertionCallback;
1618

1719
/// <summary>
1820
/// For mocking purposes only.
@@ -35,6 +37,13 @@ public MsalConfidentialClient(CredentialPipeline pipeline, string tenantId, stri
3537
RegionalAuthority = regionalAuthority;
3638
}
3739

40+
public MsalConfidentialClient(CredentialPipeline pipeline, string tenantId, string clientId, Func<string> assertionCallback, ITokenCacheOptions cacheOptions, RegionalAuthority? regionalAuthority, bool isPiiLoggingEnabled)
41+
: base(pipeline, tenantId, clientId, isPiiLoggingEnabled, cacheOptions)
42+
{
43+
_assertionCallback = assertionCallback;
44+
RegionalAuthority = regionalAuthority;
45+
}
46+
3847
internal RegionalAuthority? RegionalAuthority { get; }
3948

4049
protected override async ValueTask<IConfidentialClientApplication> CreateClientAsync(bool async, CancellationToken cancellationToken)
@@ -49,6 +58,11 @@ protected override async ValueTask<IConfidentialClientApplication> CreateClientA
4958
confClientBuilder.WithClientSecret(_clientSecret);
5059
}
5160

61+
if (_assertionCallback != null)
62+
{
63+
confClientBuilder.WithClientAssertion(_assertionCallback);
64+
}
65+
5266
if (_certificateProvider != null)
5367
{
5468
X509Certificate2 clientCertificate = await _certificateProvider.GetCertificateAsync(async, cancellationToken).ConfigureAwait(false);
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.IO;
7+
using System.Text;
8+
using System.Threading;
9+
using System.Threading.Tasks;
10+
using Azure.Core;
11+
12+
namespace Azure.Identity
13+
{
14+
internal class TokenExchangeManagedIdentitySource : ManagedIdentitySource
15+
{
16+
private TokenFileCache _tokenFileCache;
17+
private ClientAssertionCredential _clientAssertionCredential;
18+
19+
private TokenExchangeManagedIdentitySource(CredentialPipeline pipeline, string tenantId, string clientId, string tokenFilePath)
20+
: base(pipeline)
21+
{
22+
_tokenFileCache = new TokenFileCache(tokenFilePath);
23+
_clientAssertionCredential = new ClientAssertionCredential(tenantId, clientId, _tokenFileCache.GetTokenFileContents, new ClientAssertionCredentialOptions { Pipeline = pipeline });
24+
}
25+
26+
public static ManagedIdentitySource TryCreate(ManagedIdentityClientOptions options)
27+
{
28+
string tokenFilePath = EnvironmentVariables.AzureFederatedTokenFile;
29+
string tenantId = EnvironmentVariables.TenantId;
30+
string clientId = options.ClientId ?? EnvironmentVariables.ClientId;
31+
32+
if (string.IsNullOrEmpty(tokenFilePath) || string.IsNullOrEmpty(tenantId) || string.IsNullOrEmpty(clientId))
33+
{
34+
return default;
35+
}
36+
37+
return new TokenExchangeManagedIdentitySource(options.Pipeline, tenantId, clientId, tokenFilePath);
38+
}
39+
40+
public async override ValueTask<AccessToken> AuthenticateAsync(bool async, TokenRequestContext context, CancellationToken cancellationToken)
41+
{
42+
return async ? await _clientAssertionCredential.GetTokenAsync(context, cancellationToken).ConfigureAwait(false) : _clientAssertionCredential.GetToken(context, cancellationToken);
43+
}
44+
45+
protected override Request CreateRequest(string[] scopes)
46+
{
47+
throw new NotImplementedException();
48+
}
49+
50+
// Ideally this class should handle I/O asynchronously, and have a design similar to AccessTokenCache in BearerTokenAuthenticationPolicy.
51+
// However, MSAL currently only accepts sync callbacks for client assertions so this has been radically simplified in light of this. If MSAL
52+
// were to add support for an async callback we should update this accordingly.
53+
// See, https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/issues/2863
54+
private class TokenFileCache
55+
{
56+
private readonly object _lock = new object();
57+
private readonly string _tokenFilePath;
58+
private string _tokenFileContents;
59+
private DateTimeOffset _refreshOn = DateTimeOffset.MinValue;
60+
61+
public TokenFileCache(string tokenFilePath)
62+
{
63+
_tokenFilePath = tokenFilePath;
64+
}
65+
66+
public string GetTokenFileContents()
67+
{
68+
if (_refreshOn <= DateTimeOffset.UtcNow)
69+
{
70+
lock (_lock)
71+
{
72+
if (_refreshOn <= DateTimeOffset.UtcNow)
73+
{
74+
_tokenFileContents = File.ReadAllText(_tokenFilePath);
75+
76+
_refreshOn = DateTimeOffset.UtcNow.AddMinutes(5);
77+
}
78+
}
79+
}
80+
81+
return _tokenFileContents;
82+
}
83+
}
84+
}
85+
}
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.IO;
7+
using System.Linq;
8+
using System.Security.Cryptography;
9+
using System.Security.Cryptography.X509Certificates;
10+
using System.Text;
11+
using System.Text.Json;
12+
using System.Threading.Tasks;
13+
using Azure.Core;
14+
using Azure.Core.TestFramework;
15+
using NUnit.Framework;
16+
17+
namespace Azure.Identity.Tests
18+
{
19+
public class ManagedIdentityCredentialFederatedTokenLiveTests : IdentityRecordedTestBase
20+
{
21+
public ManagedIdentityCredentialFederatedTokenLiveTests(bool isAsync) : base(isAsync)
22+
{
23+
}
24+
25+
[SetUp]
26+
public void ClearDiscoveryCache()
27+
{
28+
StaticCachesUtilities.ClearStaticMetadataProviderCache();
29+
StaticCachesUtilities.ClearAuthorityEndpointResolutionManagerCache();
30+
}
31+
32+
[NonParallelizable]
33+
[Test]
34+
public async Task VerifyViaMockK8TokenExchangeEnvironment()
35+
{
36+
var tenantId = TestEnvironment.ServicePrincipalTenantId;
37+
var clientId = TestEnvironment.ServicePrincipalClientId;
38+
var authorityHostUrl = TestEnvironment.AuthorityHostUrl;
39+
40+
var assertionAudienceBuilder = new RequestUriBuilder();
41+
assertionAudienceBuilder.Reset(new Uri(authorityHostUrl));
42+
assertionAudienceBuilder.AppendPath(tenantId);
43+
assertionAudienceBuilder.AppendPath("/oauth2/v2.0/token", escape: false);
44+
var assertionAudience = assertionAudienceBuilder.ToString();
45+
46+
var assertionCert = new X509Certificate2(TestEnvironment.ServicePrincipalCertificatePfxPath);
47+
48+
string tokenFilePath = Path.Combine(Path.GetTempPath(), Path.GetTempFileName());
49+
50+
File.WriteAllText(tokenFilePath, CreateClientAssertionJWT(clientId, assertionAudience, assertionCert));
51+
52+
try
53+
{
54+
using (var environment = new TestEnvVar(new()
55+
{
56+
{ "MSI_ENDPOINT", null },
57+
{ "MSI_SECRET", null },
58+
{ "IDENTITY_ENDPOINT", null },
59+
{ "IDENTITY_HEADER", null },
60+
{ "AZURE_POD_IDENTITY_AUTHORITY_HOST", null },
61+
{ "AZURE_CLIENT_ID", clientId },
62+
{ "AZURE_TENANT_ID", tenantId },
63+
{ "AZURE_AUTHORITY_HOST", authorityHostUrl },
64+
{ "AZURE_FEDERATED_TOKEN_FILE", tokenFilePath }
65+
}))
66+
{
67+
var options = InstrumentClientOptions(new TokenCredentialOptions());
68+
var credential = InstrumentClient(new ManagedIdentityCredential(options: options));
69+
70+
var tokenRequestContext = new TokenRequestContext(new[] { AzureAuthorityHosts.GetDefaultScope(new Uri(TestEnvironment.AuthorityHostUrl)) });
71+
72+
var accessToken = await credential.GetTokenAsync(tokenRequestContext);
73+
74+
Assert.IsNotNull(accessToken.Token);
75+
}
76+
}
77+
finally
78+
{
79+
File.Delete(tokenFilePath);
80+
}
81+
}
82+
83+
private static string CreateClientAssertionJWT(string clientId, string audience, X509Certificate2 clientCertificate)
84+
{
85+
var headerBuff = new ArrayBufferWriter<byte>();
86+
87+
using (var headerJson = new Utf8JsonWriter(headerBuff))
88+
{
89+
headerJson.WriteStartObject();
90+
91+
headerJson.WriteString("typ", "JWT");
92+
headerJson.WriteString("alg", "RS256");
93+
headerJson.WriteString("x5t", HexToBase64Url(clientCertificate.Thumbprint));
94+
95+
headerJson.WriteEndObject();
96+
97+
headerJson.Flush();
98+
}
99+
100+
var payloadBuff = new ArrayBufferWriter<byte>();
101+
102+
using (var payloadJson = new Utf8JsonWriter(payloadBuff))
103+
{
104+
payloadJson.WriteStartObject();
105+
106+
payloadJson.WriteString("jti", Guid.NewGuid());
107+
payloadJson.WriteString("aud", audience);
108+
payloadJson.WriteString("iss", clientId);
109+
payloadJson.WriteString("sub", clientId);
110+
payloadJson.WriteNumber("nbf", DateTimeOffset.UtcNow.ToUnixTimeSeconds());
111+
payloadJson.WriteNumber("exp", (DateTimeOffset.UtcNow + TimeSpan.FromMinutes(30)).ToUnixTimeSeconds());
112+
113+
payloadJson.WriteEndObject();
114+
115+
payloadJson.Flush();
116+
}
117+
118+
string header = Base64Url.Encode(headerBuff.WrittenMemory.ToArray());
119+
120+
string payload = Base64Url.Encode(payloadBuff.WrittenMemory.ToArray());
121+
122+
string flattenedJws = header + "." + payload;
123+
124+
byte[] signature = clientCertificate.GetRSAPrivateKey().SignData(Encoding.ASCII.GetBytes(flattenedJws), HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
125+
126+
return flattenedJws + "." + Base64Url.Encode(signature);
127+
}
128+
129+
private static string HexToBase64Url(string hex)
130+
{
131+
byte[] bytes = new byte[hex.Length / 2];
132+
133+
for (int i = 0; i < hex.Length; i += 2)
134+
bytes[i / 2] = Convert.ToByte(hex.Substring(i, 2), 16);
135+
136+
return Base64Url.Encode(bytes);
137+
}
138+
}
139+
}

0 commit comments

Comments
 (0)