Skip to content

Commit 60f92e8

Browse files
authored
Added support for multi-tenant authentication for Key Vault clients. (Azure#25300)
* Added support for multi-tenant authentication on Key Vault libraries. * Updated CHANGELOG. * Removed RuntimeException being thrown from KeyVaultCredentialPolicy. * Added tests and fixed an issue that grabbed the wrong segment when parsing an authorization URI. * Fixed test issues. * Fixed test issues for good.
1 parent b03f3da commit 60f92e8

File tree

21 files changed

+1001
-114
lines changed

21 files changed

+1001
-114
lines changed

sdk/keyvault/azure-security-keyvault-administration/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
### Features Added
66

7+
- Added support for multi-tenant authentication in clients.
8+
79
### Breaking Changes
810

911
### Bugs Fixed

sdk/keyvault/azure-security-keyvault-administration/src/main/java/com/azure/security/keyvault/administration/implementation/KeyVaultCredentialPolicy.java

Lines changed: 79 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import reactor.core.publisher.Flux;
1313
import reactor.core.publisher.Mono;
1414

15+
import java.net.URI;
16+
import java.net.URISyntaxException;
1517
import java.net.URL;
1618
import java.nio.ByteBuffer;
1719
import java.util.Collections;
@@ -34,8 +36,8 @@ public class KeyVaultCredentialPolicy extends BearerTokenAuthenticationPolicy {
3436
private static final String KEY_VAULT_STASHED_CONTENT_KEY = "KeyVaultCredentialPolicyStashedBody";
3537
private static final String KEY_VAULT_STASHED_CONTENT_LENGTH_KEY = "KeyVaultCredentialPolicyStashedContentLength";
3638
private static final String WWW_AUTHENTICATE = "WWW-Authenticate";
37-
private static final ConcurrentMap<String, String> SCOPE_CACHE = new ConcurrentHashMap<>();
38-
private String scope;
39+
private static final ConcurrentMap<String, ChallengeParameters> CHALLENGE_CACHE = new ConcurrentHashMap<>();
40+
private ChallengeParameters challenge;
3941

4042
/**
4143
* Creates a {@link KeyVaultCredentialPolicy}.
@@ -80,6 +82,7 @@ private static Map<String, String> extractChallengeAttributes(String authenticat
8082
*
8183
* @param authenticateHeader The authentication header containing all the challenges.
8284
* @param authChallengePrefix The authentication challenge name.
85+
*
8386
* @return A boolean indicating if the challenge is a bearer challenge or not.
8487
*/
8588
private static boolean isBearerChallenge(String authenticateHeader, String authChallengePrefix) {
@@ -92,15 +95,17 @@ public Mono<Void> authorizeRequest(HttpPipelineCallContext context) {
9295
return Mono.defer(() -> {
9396
HttpRequest request = context.getHttpRequest();
9497

95-
// If this policy doesn't have an authorityScope cached try to get it from the static challenge cache.
96-
if (this.scope == null) {
98+
// If this policy doesn't have challenge parameters cached try to get it from the static challenge cache.
99+
if (this.challenge == null) {
97100
String authority = getRequestAuthority(request);
98-
this.scope = SCOPE_CACHE.get(authority);
101+
this.challenge = CHALLENGE_CACHE.get(authority);
99102
}
100103

101-
if (this.scope != null) {
102-
// We fetched the scope from the cache, but we have not initialized the scopes in the base yet.
103-
TokenRequestContext tokenRequestContext = new TokenRequestContext().addScopes(this.scope);
104+
if (this.challenge != null) {
105+
// We fetched the challenge from the cache, but we have not initialized the scopes in the base yet.
106+
TokenRequestContext tokenRequestContext = new TokenRequestContext()
107+
.addScopes(this.challenge.getScopes())
108+
.setTenantId(this.challenge.getTenantId());
104109

105110
return setAuthorizationHeader(context, tokenRequestContext);
106111
}
@@ -150,33 +155,92 @@ public Mono<Boolean> authorizeRequestOnChallenge(HttpPipelineCallContext context
150155
}
151156

152157
if (scope == null) {
153-
this.scope = SCOPE_CACHE.get(authority);
158+
this.challenge = CHALLENGE_CACHE.get(authority);
154159

155-
if (this.scope == null) {
160+
if (this.challenge == null) {
156161
return Mono.just(false);
157162
}
158163
} else {
159-
this.scope = scope;
164+
String authorization = challengeAttributes.get("authorization");
165+
166+
if (authorization == null) {
167+
authorization = challengeAttributes.get("authorization_uri");
168+
}
160169

161-
SCOPE_CACHE.put(authority, this.scope);
170+
final URI authorizationUri;
171+
172+
try {
173+
authorizationUri = new URI(authorization);
174+
} catch (URISyntaxException e) {
175+
// The challenge authorization URI is invalid.
176+
return Mono.just(false);
177+
}
178+
179+
this.challenge = new ChallengeParameters(authorizationUri, new String[] { scope });
180+
181+
CHALLENGE_CACHE.put(authority, this.challenge);
162182
}
163183

164-
TokenRequestContext tokenRequestContext = new TokenRequestContext().addScopes(this.scope);
184+
TokenRequestContext tokenRequestContext = new TokenRequestContext()
185+
.addScopes(this.challenge.getScopes())
186+
.setTenantId(this.challenge.getTenantId());
165187

166188
return setAuthorizationHeader(context, tokenRequestContext)
167189
.then(Mono.just(true));
168190
});
169191
}
170192

171-
static void clearCache() {
172-
SCOPE_CACHE.clear();
193+
private static class ChallengeParameters {
194+
private final URI authorizationUri;
195+
private final String tenantId;
196+
private final String[] scopes;
197+
198+
ChallengeParameters(URI authorizationUri, String[] scopes) {
199+
this.authorizationUri = authorizationUri;
200+
tenantId = authorizationUri.getPath().split("/")[1];
201+
this.scopes = scopes;
202+
}
203+
204+
/**
205+
* Get the {@code authorization} or {@code authorization_uri} parameter from the challenge response.
206+
*/
207+
public URI getAuthorizationUri() {
208+
return authorizationUri;
209+
}
210+
211+
/**
212+
* Get the {@code resource} or {@code scope} parameter from the challenge response. This should end with
213+
* "/.default".
214+
*/
215+
public String[] getScopes() {
216+
return scopes;
217+
}
218+
219+
/**
220+
* Get the tenant ID from {@code authorizationUri}.
221+
*/
222+
public String getTenantId() {
223+
return tenantId;
224+
}
225+
}
226+
227+
public static void clearCache() {
228+
CHALLENGE_CACHE.clear();
173229
}
174230

231+
/**
232+
* Gets the host name and port of the Key Vault or Managed HSM endpoint.
233+
*
234+
* @param request The {@link HttpRequest} to extract the host name and port from.
235+
*
236+
* @return The host name and port of the Key Vault or Managed HSM endpoint.
237+
*/
175238
private static String getRequestAuthority(HttpRequest request) {
176239
URL url = request.getUrl();
177240
String authority = url.getAuthority();
178241
int port = url.getPort();
179242

243+
// Append port for complete authority.
180244
if (!authority.contains(":") && port > 0) {
181245
authority = authority + ":" + port;
182246
}

sdk/keyvault/azure-security-keyvault-certificates/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
### Features Added
66

7+
- Added support for multi-tenant authentication in clients.
8+
79
### Breaking Changes
810

911
### Bugs Fixed

sdk/keyvault/azure-security-keyvault-certificates/src/main/java/com/azure/security/keyvault/certificates/implementation/KeyVaultCredentialPolicy.java

Lines changed: 79 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import reactor.core.publisher.Flux;
1313
import reactor.core.publisher.Mono;
1414

15+
import java.net.URI;
16+
import java.net.URISyntaxException;
1517
import java.net.URL;
1618
import java.nio.ByteBuffer;
1719
import java.util.Collections;
@@ -34,8 +36,8 @@ public class KeyVaultCredentialPolicy extends BearerTokenAuthenticationPolicy {
3436
private static final String KEY_VAULT_STASHED_CONTENT_KEY = "KeyVaultCredentialPolicyStashedBody";
3537
private static final String KEY_VAULT_STASHED_CONTENT_LENGTH_KEY = "KeyVaultCredentialPolicyStashedContentLength";
3638
private static final String WWW_AUTHENTICATE = "WWW-Authenticate";
37-
private static final ConcurrentMap<String, String> SCOPE_CACHE = new ConcurrentHashMap<>();
38-
private String scope;
39+
private static final ConcurrentMap<String, ChallengeParameters> CHALLENGE_CACHE = new ConcurrentHashMap<>();
40+
private ChallengeParameters challenge;
3941

4042
/**
4143
* Creates a {@link KeyVaultCredentialPolicy}.
@@ -80,6 +82,7 @@ private static Map<String, String> extractChallengeAttributes(String authenticat
8082
*
8183
* @param authenticateHeader The authentication header containing all the challenges.
8284
* @param authChallengePrefix The authentication challenge name.
85+
*
8386
* @return A boolean indicating if the challenge is a bearer challenge or not.
8487
*/
8588
private static boolean isBearerChallenge(String authenticateHeader, String authChallengePrefix) {
@@ -92,15 +95,17 @@ public Mono<Void> authorizeRequest(HttpPipelineCallContext context) {
9295
return Mono.defer(() -> {
9396
HttpRequest request = context.getHttpRequest();
9497

95-
// If this policy doesn't have an authorityScope cached try to get it from the static challenge cache.
96-
if (this.scope == null) {
98+
// If this policy doesn't have challenge parameters cached try to get it from the static challenge cache.
99+
if (this.challenge == null) {
97100
String authority = getRequestAuthority(request);
98-
this.scope = SCOPE_CACHE.get(authority);
101+
this.challenge = CHALLENGE_CACHE.get(authority);
99102
}
100103

101-
if (this.scope != null) {
102-
// We fetched the scope from the cache, but we have not initialized the scopes in the base yet.
103-
TokenRequestContext tokenRequestContext = new TokenRequestContext().addScopes(this.scope);
104+
if (this.challenge != null) {
105+
// We fetched the challenge from the cache, but we have not initialized the scopes in the base yet.
106+
TokenRequestContext tokenRequestContext = new TokenRequestContext()
107+
.addScopes(this.challenge.getScopes())
108+
.setTenantId(this.challenge.getTenantId());
104109

105110
return setAuthorizationHeader(context, tokenRequestContext);
106111
}
@@ -150,33 +155,92 @@ public Mono<Boolean> authorizeRequestOnChallenge(HttpPipelineCallContext context
150155
}
151156

152157
if (scope == null) {
153-
this.scope = SCOPE_CACHE.get(authority);
158+
this.challenge = CHALLENGE_CACHE.get(authority);
154159

155-
if (this.scope == null) {
160+
if (this.challenge == null) {
156161
return Mono.just(false);
157162
}
158163
} else {
159-
this.scope = scope;
164+
String authorization = challengeAttributes.get("authorization");
165+
166+
if (authorization == null) {
167+
authorization = challengeAttributes.get("authorization_uri");
168+
}
160169

161-
SCOPE_CACHE.put(authority, this.scope);
170+
final URI authorizationUri;
171+
172+
try {
173+
authorizationUri = new URI(authorization);
174+
} catch (URISyntaxException e) {
175+
// The challenge authorization URI is invalid.
176+
return Mono.just(false);
177+
}
178+
179+
this.challenge = new ChallengeParameters(authorizationUri, new String[] { scope });
180+
181+
CHALLENGE_CACHE.put(authority, this.challenge);
162182
}
163183

164-
TokenRequestContext tokenRequestContext = new TokenRequestContext().addScopes(this.scope);
184+
TokenRequestContext tokenRequestContext = new TokenRequestContext()
185+
.addScopes(this.challenge.getScopes())
186+
.setTenantId(this.challenge.getTenantId());
165187

166188
return setAuthorizationHeader(context, tokenRequestContext)
167189
.then(Mono.just(true));
168190
});
169191
}
170192

171-
static void clearCache() {
172-
SCOPE_CACHE.clear();
193+
private static class ChallengeParameters {
194+
private final URI authorizationUri;
195+
private final String tenantId;
196+
private final String[] scopes;
197+
198+
ChallengeParameters(URI authorizationUri, String[] scopes) {
199+
this.authorizationUri = authorizationUri;
200+
tenantId = authorizationUri.getPath().split("/")[1];
201+
this.scopes = scopes;
202+
}
203+
204+
/**
205+
* Get the {@code authorization} or {@code authorization_uri} parameter from the challenge response.
206+
*/
207+
public URI getAuthorizationUri() {
208+
return authorizationUri;
209+
}
210+
211+
/**
212+
* Get the {@code resource} or {@code scope} parameter from the challenge response. This should end with
213+
* "/.default".
214+
*/
215+
public String[] getScopes() {
216+
return scopes;
217+
}
218+
219+
/**
220+
* Get the tenant ID from {@code authorizationUri}.
221+
*/
222+
public String getTenantId() {
223+
return tenantId;
224+
}
225+
}
226+
227+
public static void clearCache() {
228+
CHALLENGE_CACHE.clear();
173229
}
174230

231+
/**
232+
* Gets the host name and port of the Key Vault or Managed HSM endpoint.
233+
*
234+
* @param request The {@link HttpRequest} to extract the host name and port from.
235+
*
236+
* @return The host name and port of the Key Vault or Managed HSM endpoint.
237+
*/
175238
private static String getRequestAuthority(HttpRequest request) {
176239
URL url = request.getUrl();
177240
String authority = url.getAuthority();
178241
int port = url.getPort();
179242

243+
// Append port for complete authority.
180244
if (!authority.contains(":") && port > 0) {
181245
authority = authority + ":" + port;
182246
}

sdk/keyvault/azure-security-keyvault-certificates/src/test/java/com/azure/security/keyvault/certificates/CertificateClientTest.java

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import com.azure.core.util.Context;
1313
import com.azure.core.util.polling.PollResponse;
1414
import com.azure.core.util.polling.SyncPoller;
15+
import com.azure.security.keyvault.certificates.implementation.KeyVaultCredentialPolicy;
1516
import com.azure.security.keyvault.certificates.models.CertificateContact;
1617
import com.azure.security.keyvault.certificates.models.CertificateIssuer;
1718
import com.azure.security.keyvault.certificates.models.CertificateContentType;
@@ -35,6 +36,8 @@
3536
import java.util.HashMap;
3637
import java.util.Arrays;
3738
import java.util.HashSet;
39+
import java.util.UUID;
40+
3841
import static org.junit.jupiter.api.Assertions.assertEquals;
3942
import static org.junit.jupiter.api.Assertions.assertNotNull;
4043
import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -52,7 +55,12 @@ protected void beforeTest() {
5255
}
5356

5457
private void createCertificateClient(HttpClient httpClient, CertificateServiceVersion serviceVersion) {
55-
HttpPipeline httpPipeline = getHttpPipeline(httpClient);
58+
createCertificateClient(httpClient, serviceVersion, null);
59+
}
60+
61+
private void createCertificateClient(HttpClient httpClient, CertificateServiceVersion serviceVersion,
62+
String testTenantId) {
63+
HttpPipeline httpPipeline = getHttpPipeline(httpClient, testTenantId);
5664
CertificateAsyncClient asyncClient = spy(new CertificateClientBuilder()
5765
.vaultUrl(getEndpoint())
5866
.pipeline(httpPipeline)
@@ -81,6 +89,31 @@ public void createCertificate(HttpClient httpClient, CertificateServiceVersion s
8189
});
8290
}
8391

92+
@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
93+
@MethodSource("getTestParameters")
94+
public void createCertificateWithMultipleTenants(HttpClient httpClient, CertificateServiceVersion serviceVersion) {
95+
createCertificateClient(httpClient, serviceVersion, testResourceNamer.randomUuid());
96+
createCertificateRunner((policy) -> {
97+
String certName = generateResourceId("testCer");
98+
SyncPoller<CertificateOperation, KeyVaultCertificateWithPolicy> certPoller =
99+
client.beginCreateCertificate(certName, policy);
100+
certPoller.waitForCompletion();
101+
KeyVaultCertificateWithPolicy expected = certPoller.getFinalResult();
102+
assertEquals(certName, expected.getName());
103+
assertNotNull(expected.getProperties().getCreatedOn());
104+
});
105+
KeyVaultCredentialPolicy.clearCache(); // Ensure we don't have anything cached and try again.
106+
createCertificateRunner((policy) -> {
107+
String certName = generateResourceId("testCer2");
108+
SyncPoller<CertificateOperation, KeyVaultCertificateWithPolicy> certPoller =
109+
client.beginCreateCertificate(certName, policy);
110+
certPoller.waitForCompletion();
111+
KeyVaultCertificateWithPolicy expected = certPoller.getFinalResult();
112+
assertEquals(certName, expected.getName());
113+
assertNotNull(expected.getProperties().getCreatedOn());
114+
});
115+
}
116+
84117
private void deleteAndPurgeCertificate(String certName) {
85118
SyncPoller<DeletedCertificate, Void> deletePoller = client.beginDeleteCertificate(certName);
86119
deletePoller.poll();

0 commit comments

Comments
 (0)