Skip to content

Commit 74a118f

Browse files
committed
Ref #28976: reduce overhead of access token refresh
1 parent b9343d1 commit 74a118f

1 file changed

Lines changed: 43 additions & 20 deletions

File tree

src/main/java/eu/openanalytics/containerproxy/auth/impl/oidc/OpenIdReAuthorizeFilter.java

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,26 @@
2525
import org.springframework.security.oauth2.client.ClientAuthorizationException;
2626
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
2727
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
28+
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
2829
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
30+
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
31+
import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
2932
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
3033
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
3134
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
3235
import org.springframework.security.web.util.matcher.OrRequestMatcher;
3336
import org.springframework.security.web.util.matcher.RequestMatcher;
34-
import org.springframework.web.filter.GenericFilterBean;
37+
import org.springframework.web.filter.OncePerRequestFilter;
3538

39+
import javax.annotation.Nonnull;
3640
import javax.inject.Inject;
3741
import javax.servlet.FilterChain;
3842
import javax.servlet.ServletException;
39-
import javax.servlet.ServletRequest;
40-
import javax.servlet.ServletResponse;
4143
import javax.servlet.http.HttpServletRequest;
4244
import javax.servlet.http.HttpServletResponse;
4345
import java.io.IOException;
46+
import java.time.Clock;
47+
import java.time.Duration;
4448

4549
import static eu.openanalytics.containerproxy.auth.impl.oidc.OpenIDConfiguration.REG_ID;
4650

@@ -52,7 +56,7 @@
5256
* This filter only applies to a limited set of requests and not to requests that are proxied to apps.
5357
* Otherwise, this filter would be called too much and cause too much overhead.
5458
*/
55-
public class OpenIdReAuthorizeFilter extends GenericFilterBean {
59+
public class OpenIdReAuthorizeFilter extends OncePerRequestFilter {
5660

5761
private static final RequestMatcher REQUEST_MATCHER = new OrRequestMatcher(
5862
new AntPathRequestMatcher("/app/**"),
@@ -63,31 +67,50 @@ public class OpenIdReAuthorizeFilter extends GenericFilterBean {
6367
@Inject
6468
private OAuth2AuthorizedClientManager oAuth2AuthorizedClientManager;
6569

70+
@Inject
71+
private OAuth2AuthorizedClientService oAuth2AuthorizedClientService;
72+
73+
private final Clock clock = Clock.systemUTC();
74+
75+
// use clock skew of 20 seconds instead of 60 seconds. Otherwise, if the access token is valid for 1 minute, it would get refreshed at each request.
76+
private final Duration clockSkew = Duration.ofSeconds(20);
77+
6678
@Override
67-
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
68-
if (REQUEST_MATCHER.matches((HttpServletRequest) request)) {
79+
protected void doFilterInternal(@Nonnull HttpServletRequest request, @Nonnull HttpServletResponse response, @Nonnull FilterChain chain) throws ServletException, IOException {
80+
if (REQUEST_MATCHER.matches(request)) {
6981
Authentication auth = SecurityContextHolder.getContext().getAuthentication();
82+
7083
if (auth instanceof OAuth2AuthenticationToken) {
71-
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
72-
.withClientRegistrationId(REG_ID)
73-
.principal(auth)
74-
.attribute(HttpServletRequest.class.getName(), request)
75-
.attribute(HttpServletResponse.class.getName(), response)
76-
.build();
84+
OAuth2AuthorizedClient authorizedClient = oAuth2AuthorizedClientService.loadAuthorizedClient(REG_ID, auth.getName());
85+
86+
if (accessTokenExpired(authorizedClient)) {
87+
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
88+
.withAuthorizedClient(authorizedClient)
89+
.principal(auth)
90+
.build();
7791

78-
// re-authorize
79-
try {
80-
oAuth2AuthorizedClientManager.authorize(authorizeRequest);
81-
} catch (ClientAuthorizationException ex) {
82-
if (ex.getError().getErrorCode().equals(OAuth2ErrorCodes.INVALID_GRANT)) {
83-
// if refresh token has expired or is invalid -> re-start authorization process
84-
throw new ClientAuthorizationRequiredException(ex.getClientRegistrationId());
92+
// re-authorize
93+
try {
94+
oAuth2AuthorizedClientManager.authorize(authorizeRequest);
95+
} catch (ClientAuthorizationException ex) {
96+
if (ex.getError().getErrorCode().equals(OAuth2ErrorCodes.INVALID_GRANT)) {
97+
// if refresh token has expired or is invalid -> re-start authorization process
98+
throw new ClientAuthorizationRequiredException(ex.getClientRegistrationId());
99+
}
100+
throw ex;
85101
}
86-
throw ex;
87102
}
88103
}
89104
}
90105
chain.doFilter(request, response);
91106
}
92107

108+
/**
109+
* See {@link RefreshTokenOAuth2AuthorizedClientProvider}
110+
*/
111+
private boolean accessTokenExpired(OAuth2AuthorizedClient authorizedClient) {
112+
return authorizedClient.getAccessToken().getExpiresAt() == null ||
113+
clock.instant().isAfter(authorizedClient.getAccessToken().getExpiresAt().minus(this.clockSkew));
114+
}
115+
93116
}

0 commit comments

Comments
 (0)