|
4 | 4 | package com.azure.spring.aad.webapp; |
5 | 5 |
|
6 | 6 | import com.azure.spring.aad.AADClientRegistrationRepository; |
| 7 | +import org.slf4j.Logger; |
| 8 | +import org.slf4j.LoggerFactory; |
7 | 9 | import org.springframework.security.core.Authentication; |
8 | 10 | import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; |
9 | 11 | import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; |
10 | 12 | import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; |
11 | 13 | import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; |
12 | 14 | import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; |
13 | 15 | import org.springframework.security.oauth2.core.OAuth2AccessToken; |
| 16 | +import org.springframework.web.context.request.RequestContextHolder; |
| 17 | +import org.springframework.web.context.request.ServletRequestAttributes; |
14 | 18 |
|
15 | 19 | import javax.servlet.http.HttpServletRequest; |
16 | 20 | import javax.servlet.http.HttpServletResponse; |
|
24 | 28 | */ |
25 | 29 | public class AADOAuth2AuthorizedClientRepository implements OAuth2AuthorizedClientRepository { |
26 | 30 |
|
| 31 | + private static final Logger LOGGER = LoggerFactory.getLogger(AADOAuth2AuthorizedClientRepository.class); |
| 32 | + |
27 | 33 | private final AADWebAppClientRegistrationRepository repo; |
28 | 34 | private final OAuth2AuthorizedClientRepository delegate; |
29 | 35 | private final OAuth2AuthorizedClientProvider provider; |
@@ -73,7 +79,15 @@ public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String id, |
73 | 79 | .principal(principal) |
74 | 80 | .attributes(getAttributesConsumer(scopes)) |
75 | 81 | .build(); |
76 | | - return (T) provider.authorize(context); |
| 82 | + OAuth2AuthorizedClient clientGotByRefreshToken = provider.authorize(context); |
| 83 | + try { |
| 84 | + ServletRequestAttributes attributes = |
| 85 | + (ServletRequestAttributes) RequestContextHolder.currentRequestAttributes(); |
| 86 | + delegate.saveAuthorizedClient(clientGotByRefreshToken, principal, request, attributes.getResponse()); |
| 87 | + } catch (IllegalStateException exception) { |
| 88 | + LOGGER.warn("Can not save OAuth2AuthorizedClient.", exception); |
| 89 | + } |
| 90 | + return (T) clientGotByRefreshToken; |
77 | 91 | } |
78 | 92 | return null; |
79 | 93 | } |
|
0 commit comments