Skip to content

Commit 9a2679e

Browse files
authored
Add Azure Arc MI Support (Azure#17013)
1 parent e3e2b83 commit 9a2679e

File tree

5 files changed

+208
-40
lines changed

5 files changed

+208
-40
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.azure.identity;
5+
6+
import com.azure.core.annotation.Immutable;
7+
import com.azure.core.credential.AccessToken;
8+
import com.azure.core.credential.TokenRequestContext;
9+
import com.azure.core.exception.ClientAuthenticationException;
10+
import com.azure.core.util.Configuration;
11+
import com.azure.core.util.logging.ClientLogger;
12+
import com.azure.identity.implementation.IdentityClient;
13+
import reactor.core.publisher.Mono;
14+
15+
/**
16+
* The Managed Service Identity credential for Azure Arc Service.
17+
*/
18+
@Immutable
19+
class ArcIdentityCredential extends ManagedIdentityServiceCredential {
20+
private final String identityEndpoint;
21+
private final ClientLogger logger = new ClientLogger(ArcIdentityCredential.class);
22+
23+
/**
24+
* Creates an instance of {@link ArcIdentityCredential}.
25+
*
26+
* @param clientId The client ID of user assigned or system assigned identity.
27+
* @param identityClient The identity client to acquire a token with.
28+
*/
29+
ArcIdentityCredential(String clientId, IdentityClient identityClient) {
30+
super(clientId, identityClient, "AZURE ARC IDENTITY ENDPOINT");
31+
Configuration configuration = Configuration.getGlobalConfiguration().clone();
32+
this.identityEndpoint = configuration.get(Configuration.PROPERTY_IDENTITY_ENDPOINT);
33+
if (identityEndpoint != null) {
34+
validateEndpointProtocol(this.identityEndpoint, "Identity", logger);
35+
}
36+
}
37+
38+
/**
39+
* Gets an access token for a token request.
40+
*
41+
* @param request The details of the token request.
42+
* @return A publisher that emits an {@link AccessToken}.
43+
*/
44+
public Mono<AccessToken> authenticate(TokenRequestContext request) {
45+
if (getClientId() == null) {
46+
return Mono.error(logger.logExceptionAsError(new ClientAuthenticationException(
47+
"User assigned identity is not supported by the Azure Arc Managed Identity Endpoint. To authenticate "
48+
+ "with the system assigned identity omit the client id when constructing the"
49+
+ " ManagedIdentityCredential.", null)));
50+
}
51+
return identityClient.authenticateToArcManagedIdentityEndpoint(identityEndpoint, request);
52+
}
53+
}

sdk/identity/azure-identity/src/main/java/com/azure/identity/ManagedIdentityCredential.java

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ public final class ManagedIdentityCredential implements TokenCredential {
2323
private final ManagedIdentityServiceCredential managedIdentityServiceCredential;
2424
private final ClientLogger logger = new ClientLogger(ManagedIdentityCredential.class);
2525

26+
static final String PROPERTY_IMDS_ENDPOINT = "IMDS_ENDPOINT";
2627
static final String PROPERTY_IDENTITY_SERVER_THUMBPRINT = "IDENTITY_SERVER_THUMBPRINT";
2728

2829

@@ -42,10 +43,12 @@ public final class ManagedIdentityCredential implements TokenCredential {
4243
if (configuration.contains(PROPERTY_IDENTITY_SERVER_THUMBPRINT)) {
4344
managedIdentityServiceCredential = new ServiceFabricMsiCredential(clientId, identityClient);
4445
} else {
45-
managedIdentityServiceCredential = new AppServiceMsiCredential(clientId, identityClient);
46+
managedIdentityServiceCredential = new VirtualMachineMsiCredential(clientId, identityClient);
4647
}
48+
} else if (configuration.contains(PROPERTY_IMDS_ENDPOINT)) {
49+
managedIdentityServiceCredential = new ArcIdentityCredential(clientId, identityClient);
4750
} else {
48-
managedIdentityServiceCredential = null;
51+
managedIdentityServiceCredential = new VirtualMachineMsiCredential(clientId, identityClient);
4952
}
5053
} else if (configuration.contains(Configuration.PROPERTY_MSI_ENDPOINT)) {
5154
managedIdentityServiceCredential = new AppServiceMsiCredential(clientId, identityClient);
@@ -68,13 +71,13 @@ public Mono<AccessToken> getToken(TokenRequestContext request) {
6871
if (managedIdentityServiceCredential == null) {
6972
return Mono.error(logger.logExceptionAsError(
7073
new CredentialUnavailableException("ManagedIdentityCredential authentication unavailable. "
71-
+ "The Target Azure platform could not be determined from environment variables.")));
74+
+ "The Target Azure platform could not be determined from environment variables.")));
7275
}
7376
return managedIdentityServiceCredential.authenticate(request)
74-
.doOnSuccess((t -> logger.info(String.format("Azure Identity => Managed Identity environment: %s",
75-
managedIdentityServiceCredential.getEnvironment()))))
76-
.doOnNext(token -> LoggingUtil.logTokenSuccess(logger, request))
77-
.doOnError(error -> LoggingUtil.logTokenError(logger, request, error));
77+
.doOnSuccess((t -> logger.info(String.format("Azure Identity => Managed Identity environment: %s",
78+
managedIdentityServiceCredential.getEnvironment()))))
79+
.doOnNext(token -> LoggingUtil.logTokenSuccess(logger, request))
80+
.doOnError(error -> LoggingUtil.logTokenError(logger, request, error));
7881
}
7982
}
8083

sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClient.java

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -750,7 +750,101 @@ public Mono<MsalToken> authenticateWithSharedTokenCache(TokenRequestContext requ
750750

751751

752752
/**
753-
* Asynchronously acquire a token from the App Service Managed Service Identity endpoint.
753+
* Asynchronously acquire a token from the Azure Arc Managed Service Identity endpoint.
754+
*
755+
* @param identityEndpoint the Identity endpoint to acquire token from
756+
* @param request the details of the token request
757+
* @return a Publisher that emits an AccessToken
758+
*/
759+
public Mono<AccessToken> authenticateToArcManagedIdentityEndpoint(String identityEndpoint,
760+
TokenRequestContext request) {
761+
return Mono.fromCallable(() -> {
762+
HttpURLConnection connection = null;
763+
StringBuilder payload = new StringBuilder();
764+
payload.append("resource=");
765+
payload.append(URLEncoder.encode(ScopeUtil.scopesToResource(request.getScopes()), "UTF-8"));
766+
payload.append("&api-version=");
767+
payload.append(URLEncoder.encode("2019-11-01", "UTF-8"));
768+
769+
URL url = new URL(String.format("%s?%s", identityEndpoint, payload));
770+
771+
772+
String secretKey = null;
773+
try {
774+
connection = (HttpURLConnection) url.openConnection();
775+
connection.setRequestMethod("GET");
776+
connection.setRequestProperty("Metadata", "true");
777+
connection.connect();
778+
779+
new Scanner(connection.getInputStream(), "UTF-8").useDelimiter("\\A");
780+
} catch (IOException e) {
781+
if (connection == null) {
782+
throw logger.logExceptionAsError(new ClientAuthenticationException("Failed to initialize "
783+
+ "Http URL connection to the endpoint.",
784+
null, e));
785+
}
786+
int status = connection.getResponseCode();
787+
if (status != 401) {
788+
throw logger.logExceptionAsError(new ClientAuthenticationException(String.format("Expected a 401"
789+
+ " Unauthorized response from Azure Arc Managed Identity Endpoint, received: %d", status),
790+
null, e));
791+
}
792+
793+
String realm = connection.getHeaderField("WWW-Authenticate");
794+
795+
if (realm == null) {
796+
throw logger.logExceptionAsError(new ClientAuthenticationException("Did not receive a value"
797+
+ " for WWW-Authenticate header in the response from Azure Arc Managed Identity Endpoint",
798+
null));
799+
}
800+
801+
int separatorIndex = realm.indexOf("=");
802+
if (separatorIndex == -1) {
803+
throw logger.logExceptionAsError(new ClientAuthenticationException("Did not receive a correct value"
804+
+ " for WWW-Authenticate header in the response from Azure Arc Managed Identity Endpoint",
805+
null));
806+
}
807+
808+
String secretKeyPath = realm.substring(separatorIndex + 1);
809+
secretKey = new String(Files.readAllBytes(Paths.get(secretKeyPath)), StandardCharsets.UTF_8);
810+
811+
} finally {
812+
if (connection != null) {
813+
connection.disconnect();
814+
}
815+
}
816+
817+
818+
if (secretKey == null) {
819+
throw logger.logExceptionAsError(new ClientAuthenticationException("Did not receive a secret value"
820+
+ " in the response from Azure Arc Managed Identity Endpoint",
821+
null));
822+
}
823+
824+
825+
try {
826+
827+
connection = (HttpURLConnection) url.openConnection();
828+
connection.setRequestMethod("GET");
829+
connection.setRequestProperty("Authorization", String.format("Basic %s", secretKey));
830+
connection.setRequestProperty("Metadata", "true");
831+
connection.connect();
832+
833+
Scanner scanner = new Scanner(connection.getInputStream(), "UTF-8").useDelimiter("\\A");
834+
String result = scanner.hasNext() ? scanner.next() : "";
835+
836+
return SERIALIZER_ADAPTER.deserialize(result, MSIToken.class, SerializerEncoding.JSON);
837+
838+
} finally {
839+
if (connection != null) {
840+
connection.disconnect();
841+
}
842+
}
843+
});
844+
}
845+
846+
/**
847+
* Asynchronously acquire a token from the Azure Service Fabric Managed Service Identity endpoint.
754848
*
755849
* @param identityEndpoint the Identity endpoint to acquire token from
756850
* @param identityHeader the identity header to acquire token with
@@ -941,6 +1035,12 @@ public Mono<AccessToken> authenticateToIMDSEndpoint(TokenRequestContext request)
9411035
+ "Connection to IMDS endpoint cannot be established, "
9421036
+ e.getMessage() + ".", e));
9431037
}
1038+
if (responseCode == 400) {
1039+
throw logger.logExceptionAsError(
1040+
new CredentialUnavailableException(
1041+
"ManagedIdentityCredential authentication unavailable. "
1042+
+ "Connection to IMDS endpoint cannot be established.", null));
1043+
}
9441044
if (responseCode == 410
9451045
|| responseCode == 429
9461046
|| responseCode == 404

sdk/identity/azure-identity/src/test/java/com/azure/identity/ManagedIdentityCredentialTest.java

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -67,38 +67,6 @@ public void testMSIEndpoint() throws Exception {
6767
}
6868
}
6969

70-
@Test
71-
public void testIdentityEndpoint() throws Exception {
72-
Configuration configuration = Configuration.getGlobalConfiguration();
73-
74-
try {
75-
// setup
76-
String endpoint = "http://localhost";
77-
String secret = "secret";
78-
String token1 = "token1";
79-
TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com");
80-
OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1);
81-
configuration.put("IDENTITY_ENDPOINT", endpoint);
82-
configuration.put("IDENTITY_HEADER", secret);
83-
84-
// mock
85-
IdentityClient identityClient = PowerMockito.mock(IdentityClient.class);
86-
when(identityClient.authenticateToManagedIdentityEndpoint(endpoint, secret, null, null, request1)).thenReturn(TestUtils.getMockAccessToken(token1, expiresAt));
87-
PowerMockito.whenNew(IdentityClient.class).withAnyArguments().thenReturn(identityClient);
88-
89-
// test
90-
ManagedIdentityCredential credential = new ManagedIdentityCredentialBuilder().clientId(CLIENT_ID).build();
91-
StepVerifier.create(credential.getToken(request1))
92-
.expectNextMatches(token -> token1.equals(token.getToken())
93-
&& expiresAt.getSecond() == token.getExpiresAt().getSecond())
94-
.verifyComplete();
95-
} finally {
96-
// clean up
97-
configuration.remove("IDENTITY_ENDPOINT");
98-
configuration.remove("IDENTITY_HEADER");
99-
}
100-
}
101-
10270
@Test
10371
public void testIMDS() throws Exception {
10472
// setup

sdk/identity/azure-identity/src/test/java/com/azure/identity/implementation/IdentityClientTests.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import com.azure.core.credential.AccessToken;
77
import com.azure.core.credential.TokenRequestContext;
8+
import com.azure.core.exception.ClientAuthenticationException;
89
import com.azure.core.util.Configuration;
910
import com.azure.core.util.logging.ClientLogger;
1011
import com.azure.identity.implementation.util.CertificateUtil;
@@ -35,6 +36,7 @@
3536

3637
import javax.net.ssl.HttpsURLConnection;
3738
import java.io.ByteArrayInputStream;
39+
import java.io.IOException;
3840
import java.io.InputStream;
3941
import java.net.HttpURLConnection;
4042
import java.net.URI;
@@ -260,6 +262,36 @@ public void testValidIdentityEndpointMSICodeFlow() throws Exception {
260262
Assert.assertEquals(expiresOn.getSecond(), token.getExpiresAt().getSecond());
261263
}
262264

265+
@Test (expected = ClientAuthenticationException.class)
266+
public void testInValidIdentityEndpointSecretArcCodeFlow() throws Exception {
267+
// setup
268+
Configuration configuration = Configuration.getGlobalConfiguration();
269+
String endpoint = "http://localhost";
270+
TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com");
271+
configuration.put("IDENTITY_ENDPOINT", endpoint);
272+
// mock
273+
mockForArcCodeFlow(401);
274+
275+
// test
276+
IdentityClient client = new IdentityClientBuilder().build();
277+
client.authenticateToArcManagedIdentityEndpoint(endpoint, request).block();
278+
}
279+
280+
@Test (expected = ClientAuthenticationException.class)
281+
public void testInValidIdentityEndpointResponseCodeArcCodeFlow() throws Exception {
282+
// setup
283+
Configuration configuration = Configuration.getGlobalConfiguration();
284+
String endpoint = "http://localhost";
285+
TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com");
286+
configuration.put("IDENTITY_ENDPOINT", endpoint);
287+
// mock
288+
mockForArcCodeFlow(200);
289+
290+
// test
291+
IdentityClient client = new IdentityClientBuilder().build();
292+
client.authenticateToArcManagedIdentityEndpoint(endpoint, request).block();
293+
}
294+
263295
@Test
264296
public void testValidIMDSCodeFlow() throws Exception {
265297
// setup
@@ -527,6 +559,18 @@ private void mockForServiceFabricCodeFlow(String tokenJson) throws Exception {
527559
when(huc.getInputStream()).thenReturn(inputStream);
528560
}
529561

562+
private void mockForArcCodeFlow(int responseCode) throws Exception {
563+
URL u = PowerMockito.mock(URL.class);
564+
whenNew(URL.class).withAnyArguments().thenReturn(u);
565+
HttpURLConnection initConnection = PowerMockito.mock(HttpURLConnection.class);
566+
when(u.openConnection()).thenReturn(initConnection);
567+
PowerMockito.doNothing().when(initConnection).setRequestMethod(anyString());
568+
PowerMockito.doNothing().when(initConnection).setRequestProperty(anyString(), anyString());
569+
PowerMockito.doNothing().when(initConnection).connect();
570+
when(initConnection.getInputStream()).thenThrow(new IOException());
571+
when(initConnection.getResponseCode()).thenReturn(responseCode);
572+
}
573+
530574
private void mockForIMDSCodeFlow(String tokenJson) throws Exception {
531575
URL u = PowerMockito.mock(URL.class);
532576
whenNew(URL.class).withAnyArguments().thenReturn(u);

0 commit comments

Comments
 (0)