|
5 | 5 |
|
6 | 6 | import com.azure.spring.autoconfigure.aad.AADAuthenticationProperties; |
7 | 7 | import com.azure.spring.autoconfigure.aad.AADTokenClaim; |
8 | | -import com.azure.spring.autoconfigure.aad.JacksonObjectMapperFactory; |
9 | | -import com.azure.spring.autoconfigure.aad.Membership; |
10 | | -import com.azure.spring.autoconfigure.aad.Memberships; |
11 | | -import com.fasterxml.jackson.databind.ObjectMapper; |
12 | | -import com.nimbusds.oauth2.sdk.http.HTTPResponse; |
13 | | -import org.slf4j.Logger; |
14 | | -import org.slf4j.LoggerFactory; |
15 | | -import org.springframework.http.HttpHeaders; |
16 | | -import org.springframework.http.HttpMethod; |
17 | | -import org.springframework.http.MediaType; |
18 | 8 | import org.springframework.security.core.GrantedAuthority; |
19 | 9 | import org.springframework.security.core.authority.SimpleGrantedAuthority; |
20 | 10 | import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; |
|
26 | 16 | import org.springframework.security.oauth2.core.OAuth2AuthenticationException; |
27 | 17 | import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; |
28 | 18 | import org.springframework.security.oauth2.core.oidc.user.OidcUser; |
| 19 | +import org.springframework.util.StringUtils; |
29 | 20 |
|
30 | | -import java.io.BufferedReader; |
31 | | -import java.io.IOException; |
32 | | -import java.io.InputStreamReader; |
33 | | -import java.net.HttpURLConnection; |
34 | | -import java.net.URL; |
35 | | -import java.nio.charset.StandardCharsets; |
36 | | -import java.util.LinkedHashSet; |
| 21 | +import java.util.Collections; |
37 | 22 | import java.util.Optional; |
38 | 23 | import java.util.Set; |
39 | 24 | import java.util.stream.Collectors; |
40 | 25 |
|
41 | | -import static com.azure.spring.autoconfigure.aad.Constants.DEFAULT_AUTHORITY_SET; |
42 | 26 | import static com.azure.spring.autoconfigure.aad.Constants.ROLE_PREFIX; |
43 | 27 |
|
44 | 28 | /** |
45 | 29 | * This implementation will retrieve group info of user from Microsoft Graph and map groups to {@link |
46 | 30 | * GrantedAuthority}. |
47 | 31 | */ |
48 | 32 | public class AzureActiveDirectoryOAuth2UserService implements OAuth2UserService<OidcUserRequest, OidcUser> { |
49 | | - private static final Logger LOGGER = LoggerFactory.getLogger(AzureActiveDirectoryOAuth2UserService.class); |
50 | 33 |
|
51 | 34 | private final OidcUserService oidcUserService; |
52 | 35 | private final AADAuthenticationProperties properties; |
| 36 | + private final GraphClient graphClient; |
53 | 37 |
|
54 | 38 | public AzureActiveDirectoryOAuth2UserService( |
55 | 39 | AADAuthenticationProperties properties |
56 | 40 | ) { |
57 | 41 | this.properties = properties; |
58 | 42 | this.oidcUserService = new OidcUserService(); |
| 43 | + this.graphClient = new GraphClient(properties); |
59 | 44 | } |
60 | 45 |
|
61 | 46 | @Override |
62 | 47 | public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException { |
63 | 48 | // Delegate to the default implementation for loading a user |
64 | 49 | OidcUser oidcUser = oidcUserService.loadUser(userRequest); |
65 | | - Set<SimpleGrantedAuthority> authorities = |
66 | | - Optional.of(userRequest) |
67 | | - .map(OAuth2UserRequest::getAccessToken) |
68 | | - .map(AbstractOAuth2Token::getTokenValue) |
69 | | - .map(this::getGroups) |
70 | | - .map(this::toGrantedAuthoritySet) |
71 | | - .filter(g -> !g.isEmpty()) |
72 | | - .orElse(DEFAULT_AUTHORITY_SET); |
| 50 | + Set<String> groups = Optional.of(userRequest) |
| 51 | + .map(OAuth2UserRequest::getAccessToken) |
| 52 | + .map(AbstractOAuth2Token::getTokenValue) |
| 53 | + .map(graphClient::getGroupsFromGraph) |
| 54 | + .orElseGet(Collections::emptySet); |
| 55 | + Set<String> groupRoles = groups.stream() |
| 56 | + .filter(properties::isAllowedGroup) |
| 57 | + .map(group -> ROLE_PREFIX + group) |
| 58 | + .collect(Collectors.toSet()); |
| 59 | + Set<String> allRoles = oidcUser.getAuthorities() |
| 60 | + .stream() |
| 61 | + .map(GrantedAuthority::getAuthority) |
| 62 | + .collect(Collectors.toSet()); |
| 63 | + allRoles.addAll(groupRoles); |
| 64 | + Set<SimpleGrantedAuthority> authorities = allRoles.stream() |
| 65 | + .map(SimpleGrantedAuthority::new) |
| 66 | + .collect(Collectors.toSet()); |
73 | 67 | String nameAttributeKey = |
74 | 68 | Optional.of(userRequest) |
75 | 69 | .map(OAuth2UserRequest::getClientRegistration) |
76 | 70 | .map(ClientRegistration::getProviderDetails) |
77 | 71 | .map(ClientRegistration.ProviderDetails::getUserInfoEndpoint) |
78 | 72 | .map(ClientRegistration.ProviderDetails.UserInfoEndpoint::getUserNameAttributeName) |
79 | | - .filter(s -> !s.isEmpty()) |
| 73 | + .filter(StringUtils::hasText) |
80 | 74 | .orElse(AADTokenClaim.NAME); |
81 | 75 | // Create a copy of oidcUser but use the mappedAuthorities instead |
82 | 76 | return new DefaultOidcUser(authorities, oidcUser.getIdToken(), nameAttributeKey); |
83 | 77 | } |
84 | | - |
85 | | - public Set<SimpleGrantedAuthority> toGrantedAuthoritySet(final Set<String> groups) { |
86 | | - Set<SimpleGrantedAuthority> grantedAuthoritySet = |
87 | | - groups.stream() |
88 | | - .filter(properties::isAllowedGroup) |
89 | | - .map(group -> new SimpleGrantedAuthority(ROLE_PREFIX + group)) |
90 | | - .collect(Collectors.toSet()); |
91 | | - return Optional.of(grantedAuthoritySet) |
92 | | - .filter(g -> !g.isEmpty()) |
93 | | - .orElse(DEFAULT_AUTHORITY_SET); |
94 | | - } |
95 | | - |
96 | | - public Set<String> getGroups(String accessToken) { |
97 | | - final Set<String> groups = new LinkedHashSet<>(); |
98 | | - final ObjectMapper objectMapper = JacksonObjectMapperFactory.getInstance(); |
99 | | - String aadMembershipRestUri = properties.getGraphMembershipUri(); |
100 | | - while (aadMembershipRestUri != null) { |
101 | | - Memberships memberships; |
102 | | - try { |
103 | | - String membershipsJson = getUserMemberships(accessToken, aadMembershipRestUri); |
104 | | - memberships = objectMapper.readValue(membershipsJson, Memberships.class); |
105 | | - } catch (IOException ioException) { |
106 | | - LOGGER.error("Can not get group information from graph server.", ioException); |
107 | | - break; |
108 | | - } |
109 | | - memberships.getValue() |
110 | | - .stream() |
111 | | - .filter(this::isGroupObject) |
112 | | - .map(Membership::getDisplayName) |
113 | | - .forEach(groups::add); |
114 | | - aadMembershipRestUri = Optional.of(memberships) |
115 | | - .map(Memberships::getOdataNextLink) |
116 | | - .orElse(null); |
117 | | - } |
118 | | - return groups; |
119 | | - } |
120 | | - |
121 | | - private String getUserMemberships(String accessToken, String urlString) throws IOException { |
122 | | - URL url = new URL(urlString); |
123 | | - final HttpURLConnection connection = (HttpURLConnection) url.openConnection(); |
124 | | - connection.setRequestMethod(HttpMethod.GET.toString()); |
125 | | - connection.setRequestProperty(HttpHeaders.AUTHORIZATION, String.format("Bearer %s", accessToken)); |
126 | | - connection.setRequestProperty(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE); |
127 | | - connection.setRequestProperty(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE); |
128 | | - final String responseInJson = getResponseString(connection); |
129 | | - final int responseCode = connection.getResponseCode(); |
130 | | - if (responseCode == HTTPResponse.SC_OK) { |
131 | | - return responseInJson; |
132 | | - } else { |
133 | | - throw new IllegalStateException( |
134 | | - "Response is not " + HTTPResponse.SC_OK + ", response json: " + responseInJson); |
135 | | - } |
136 | | - } |
137 | | - |
138 | | - private String getResponseString(HttpURLConnection connection) throws IOException { |
139 | | - try (BufferedReader reader = |
140 | | - new BufferedReader( |
141 | | - new InputStreamReader(connection.getInputStream(), |
142 | | - StandardCharsets.UTF_8)) |
143 | | - ) { |
144 | | - final StringBuilder stringBuffer = new StringBuilder(); |
145 | | - String line; |
146 | | - while ((line = reader.readLine()) != null) { |
147 | | - stringBuffer.append(line); |
148 | | - } |
149 | | - return stringBuffer.toString(); |
150 | | - } |
151 | | - } |
152 | | - |
153 | | - private boolean isGroupObject(final Membership membership) { |
154 | | - return membership.getObjectType().equals(properties.getUserGroup().getValue()); |
155 | | - } |
156 | 78 | } |
0 commit comments