Skip to content

Commit e7a8c6c

Browse files
authored
Make UserPrincipalManager#getRoles more robust. (Azure#31803)
1 parent 980f59e commit e7a8c6c

File tree

2 files changed

+31
-12
lines changed

2 files changed

+31
-12
lines changed

sdk/spring/spring-cloud-azure-autoconfigure/src/main/java/com/azure/spring/cloud/autoconfigure/aad/filter/UserPrincipalManager.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@
3030
import java.net.MalformedURLException;
3131
import java.net.URL;
3232
import java.text.ParseException;
33-
import java.util.Collection;
33+
import java.util.Collections;
3434
import java.util.HashSet;
3535
import java.util.Optional;
3636
import java.util.Set;
3737
import java.util.stream.Collectors;
3838
import java.util.stream.Stream;
39+
import java.util.stream.StreamSupport;
3940

4041
/**
4142
* A user principal manager to load user info from JWT.
@@ -153,11 +154,19 @@ public UserPrincipal buildUserPrincipal(String aadIssuedBearerToken) throws Pars
153154
}
154155

155156
Set<String> getRoles(JWTClaimsSet set) {
156-
return Optional.of(set)
157-
.map(p -> p.getClaim(AadJwtClaimNames.ROLES))
158-
.map(Collection.class::cast)
159-
.map(Collection<Object>::stream)
160-
.orElseGet(Stream::empty)
157+
if (set == null) {
158+
return Collections.emptySet();
159+
}
160+
Object rolesClaim = set.getClaim(AadJwtClaimNames.ROLES);
161+
if (rolesClaim == null) {
162+
return Collections.emptySet();
163+
}
164+
if (rolesClaim instanceof Iterable<?>) {
165+
return StreamSupport.stream(((Iterable<?>) rolesClaim).spliterator(), false)
166+
.map(Object::toString)
167+
.collect(Collectors.toSet());
168+
}
169+
return Stream.of(rolesClaim)
161170
.map(Object::toString)
162171
.collect(Collectors.toSet());
163172
}

sdk/spring/spring-cloud-azure-autoconfigure/src/test/java/com/azure/spring/cloud/autoconfigure/aad/filter/UserPrincipalManagerTests.java

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
import java.nio.file.Paths;
2121
import java.security.cert.CertificateFactory;
2222
import java.security.cert.X509Certificate;
23+
import java.util.ArrayList;
2324
import java.util.Arrays;
25+
import java.util.Collection;
26+
import java.util.HashSet;
2427
import java.util.Set;
2528
import java.util.stream.Stream;
2629

@@ -77,14 +80,21 @@ void nullIssuer() {
7780
}
7881

7982
@Test
80-
void testRolesExtracted() {
83+
void getRolesTest() {
84+
rolesExtractedAsExpected(null, new ArrayList<>());
85+
rolesExtractedAsExpected("role1", Arrays.asList("role1"));
86+
rolesExtractedAsExpected(Arrays.asList("role1", "role2"), Arrays.asList("role1", "role2"));
87+
rolesExtractedAsExpected(new HashSet<>(Arrays.asList("role1", "role2")), Arrays.asList("role1", "role2"));
88+
}
89+
90+
private void rolesExtractedAsExpected(Object rolesClaimValue, Collection<String> expected) {
8191
JWTClaimsSet set = new JWTClaimsSet.Builder()
82-
.claim("roles", Arrays.asList("role1", "role2"))
92+
.claim("roles", rolesClaimValue)
8393
.build();
84-
Set<String> result = new UserPrincipalManager(null).getRoles(set);
85-
assertEquals(2, result.size());
86-
assertTrue(result.contains("role1"));
87-
assertTrue(result.contains("role2"));
94+
Set<String> actual = new UserPrincipalManager(null).getRoles(set);
95+
assertEquals(expected.size(), actual.size());
96+
assertTrue(expected.containsAll(actual));
97+
assertTrue(actual.containsAll(expected));
8898
}
8999

90100
private String readJwtValidIssuerTxt() {

0 commit comments

Comments
 (0)