首页 > 解决方案 > Spring Oauth2 Client,自动刷新过期的access_token

问题描述

让我解释一下我的用例。

我需要一个 spring boot oauth2 客户端应用程序(不是资源服务器,因为我们已经有一个单独的资源服务器)。我还有以下要求:

  1. 对于每个向资源服务器发出的请求,我们需要发送 id_token。(通过自定义 resttemplate 来完成)。

  2. 对于任何请求,无论它是否调用资源服务器,如果访问令牌已过期,我的应用程序必须自动刷新它(无需任何用户干预,如任何弹出或重定向。)

  3. 如果 refresh_token 也过期,则必须注销用户。

问题:

对于第 2 点和第 3 点,我花了很多时间阅读文档和代码以及 Stack Overflow,但无法找到解决方案(或不明白)。所以我决定将我在许多博客和文档中找到的所有部分放在一起,并提出我的解决方案。以下是我对第 2 点的解决方案。

  1. 我们可以看看下面的代码并建议这种方法是否有任何问题?

    1. 如何解决第 3 点我正在考虑扩展第 2 点的解决方案,但不确定我需要编写什么代码,有人可以指导我吗?
/**
 * 
 * @author agam
 *
 */
@Component
public class ExpiredTokenFilter extends OncePerRequestFilter {

    private static final Logger log = LoggerFactory.getLogger(ExpiredTokenFilter.class);

    private Duration accessTokenExpiresSkew = Duration.ofMillis(1000);

    private Clock clock = Clock.systemUTC();

    @Autowired
    private OAuth2AuthorizedClientService oAuth2AuthorizedClientService;

    @Autowired
    CustomOidcUserService userService;

    private DefaultRefreshTokenTokenResponseClient accessTokenResponseClient;

    private JwtDecoderFactory<ClientRegistration> jwtDecoderFactory;

    private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token";

    public ExpiredTokenFilter() {
        super();
        this.accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient();
        this.jwtDecoderFactory = new OidcIdTokenDecoderFactory();
    }

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
            throws ServletException, IOException {
        log.debug("my custom filter called ");
        /**
         * check if authentication is done.
         */
        if (null != SecurityContextHolder.getContext().getAuthentication()) {
            OAuth2AuthenticationToken currentUser = (OAuth2AuthenticationToken) SecurityContextHolder.getContext()
                    .getAuthentication();
            OAuth2AuthorizedClient authorizedClient = this.oAuth2AuthorizedClientService
                    .loadAuthorizedClient(currentUser.getAuthorizedClientRegistrationId(), currentUser.getName());
            /**
             * Check if token existing token is expired.
             */
            if (isExpired(authorizedClient.getAccessToken())) {

                /*
                 * do something to get new access token
                 */
                log.debug(
                        "=========================== Token Expired !! going to refresh ================================================");
                ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
                /*
                 * Call Auth server token endpoint to refresh token. 
                 */
                OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
                        clientRegistration, authorizedClient.getAccessToken(), authorizedClient.getRefreshToken());
                OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenResponseClient
                        .getTokenResponse(refreshTokenGrantRequest);
                /*
                 * Convert id_token to OidcToken.
                 */
                OidcIdToken idToken = createOidcToken(clientRegistration, accessTokenResponse);
                /*
                 * Since I have already implemented a custom OidcUserService, reuse existing
                 * code to get new user. 
                 */
                OidcUser oidcUser = this.userService.loadUser(new OidcUserRequest(clientRegistration,
                        accessTokenResponse.getAccessToken(), idToken, accessTokenResponse.getAdditionalParameters()));

                log.debug(
                        "=========================== Token Refresh Done !! ================================================");
                /*
                 * Print old and new id_token, just in case.
                 */
                DefaultOidcUser user = (DefaultOidcUser) currentUser.getPrincipal();
                log.debug("new id token is " + oidcUser.getIdToken().getTokenValue());
                log.debug("old id token was " + user.getIdToken().getTokenValue());
                /*
                 * Create new authentication(OAuth2AuthenticationToken).
                 */
                OAuth2AuthenticationToken updatedUser = new OAuth2AuthenticationToken(oidcUser,
                        oidcUser.getAuthorities(), currentUser.getAuthorizedClientRegistrationId());
                /*
                 * Update access_token and refresh_token by saving new authorized client.
                 */
                OAuth2AuthorizedClient updatedAuthorizedClient = new OAuth2AuthorizedClient(clientRegistration,
                        currentUser.getName(), accessTokenResponse.getAccessToken(),
                        accessTokenResponse.getRefreshToken());
                this.oAuth2AuthorizedClientService.saveAuthorizedClient(updatedAuthorizedClient, updatedUser);
                /*
                 * Set new authentication in SecurityContextHolder.
                 */
                SecurityContextHolder.getContext().setAuthentication(updatedUser);
            }

        }
        filterChain.doFilter(request, response);
    }

    private Boolean isExpired(OAuth2AccessToken oAuth2AccessToken) {
        Instant now = this.clock.instant();
        Instant expiresAt = oAuth2AccessToken.getExpiresAt();
        return now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew));
    }

    private OidcIdToken createOidcToken(ClientRegistration clientRegistration,
            OAuth2AccessTokenResponse accessTokenResponse) {
        JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
        Jwt jwt;
        try {
            jwt = jwtDecoder
                    .decode((String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN));
        } catch (JwtException ex) {
            OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(), null);
            throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex);
        }
        OidcIdToken idToken = new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(),
                jwt.getClaims());
        return idToken;
    }
}

我愿意接受任何改进我的代码的建议。谢谢。

标签: spring-securityoauth-2.0spring-security-oauth2spring-oauth2

解决方案


没有足够的细节来完全理解您的用例。很高兴了解:

  • Spring 安全性正在围绕 OAuth2 快速发展,请考虑提及您正在使用的版本。我的回答假设5.2+
  • 您是在 servlet(用户以某种方式登录)还是非 servlet(类似@Scheduled方法)环境中

根据有限的信息和我有限的知识,我有以下提示:

  • 考虑使用WebClient代替RestTemplate,这是他们未来的方式。它是反应性的,但不要害怕。它也可以在“阻塞”环境中使用,您不会充分发挥它的潜力,但您仍然可以从它对 OAuth2 的更好支持中受益
  • WebClient本身有一个ServletOAuth2AuthorizedClientExchangeFilterFunction几乎可以实现您想要实现的目标
  • 创建时ServletOAuth2AuthorizedClientExchangeFilterFunction,您传入AuthorizedClientServiceOAuth2AuthorizedClientManagerwhich 是关于如何(重新)验证客户端的策略。

示例配置可能如下所示:

@Bean
public WebClient webClient(ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientService authorizedClientService) {

    AuthorizedClientServiceOAuth2AuthorizedClientManager manager = new AuthorizedClientServiceOAuth2AuthorizedClientManager(clientRegistrationRepository, authorizedClientService);
    manager.setAuthorizedClientProvider(new DelegatingOAuth2AuthorizedClientProvider(
            new RefreshTokenOAuth2AuthorizedClientProvider(),
            new ClientCredentialsOAuth2AuthorizedClientProvider()));

    ServletOAuth2AuthorizedClientExchangeFilterFunction oauth2 = new ServletOAuth2AuthorizedClientExchangeFilterFunction(manager);

    oauth2.setDefaultClientRegistrationId("your-client-registratioin-id");

    return WebClient.builder()
            .filter(oauth2)
            .apply(oauth2.oauth2Configuration())
            .build();
}

并将其用作:

@Autowire
private final WebClient webClient;

...

webClient.get()
    .uri("http://localhost:8081/api/message")
            .retrieve()
            .bodyToMono(String.class)
            .map(string -> "Retrieved using password grant: " + string)
            .subscribe(log::info);

希望这有助于朝着正确的方向前进!玩得开心


推荐阅读