1
0
mirror of synced 2026-05-22 21:33:16 +00:00

OAuth2AuthorizationEndpointFilter is applied after AuthorizationFilter

Closes gh-18251
This commit is contained in:
Joe Grandja
2025-11-28 06:06:35 -05:00
parent 244b5a16be
commit c53e66a217
7 changed files with 232 additions and 83 deletions
@@ -16,10 +16,12 @@
package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import jakarta.servlet.Filter;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.http.HttpMethod;
@@ -36,10 +38,12 @@ import org.springframework.security.oauth2.server.authorization.authentication.O
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationValidator;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationConsentAuthenticationConverter;
import org.springframework.security.web.access.intercept.AuthorizationFilter;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
@@ -50,6 +54,7 @@ import org.springframework.security.web.servlet.util.matcher.PathPatternRequestM
import org.springframework.security.web.util.matcher.OrRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
/**
@@ -83,6 +88,8 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C
private Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authorizationCodeRequestAuthenticationValidator;
private Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authorizationCodeRequestAuthenticationValidatorComposite;
private SessionAuthenticationStrategy sessionAuthenticationStrategy;
/**
@@ -248,8 +255,16 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C
authenticationProviders.addAll(0, this.authenticationProviders);
}
this.authenticationProvidersConsumer.accept(authenticationProviders);
authenticationProviders.forEach(
(authenticationProvider) -> httpSecurity.authenticationProvider(postProcess(authenticationProvider)));
authenticationProviders.forEach((authenticationProvider) -> {
httpSecurity.authenticationProvider(postProcess(authenticationProvider));
if (authenticationProvider instanceof OAuth2AuthorizationCodeRequestAuthenticationProvider) {
Method method = ReflectionUtils.findMethod(OAuth2AuthorizationCodeRequestAuthenticationProvider.class,
"getAuthenticationValidatorComposite");
ReflectionUtils.makeAccessible(method);
this.authorizationCodeRequestAuthenticationValidatorComposite = (Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext>) ReflectionUtils
.invokeMethod(method, authenticationProvider);
}
});
}
@Override
@@ -282,7 +297,18 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C
if (this.sessionAuthenticationStrategy != null) {
authorizationEndpointFilter.setSessionAuthenticationStrategy(this.sessionAuthenticationStrategy);
}
httpSecurity.addFilterBefore(postProcess(authorizationEndpointFilter),
httpSecurity.addFilterAfter(postProcess(authorizationEndpointFilter), AuthorizationFilter.class);
// Create and add
// OAuth2AuthorizationEndpointFilter.OAuth2AuthorizationCodeRequestValidatingFilter
Method method = ReflectionUtils.findMethod(OAuth2AuthorizationEndpointFilter.class,
"createAuthorizationCodeRequestValidatingFilter", RegisteredClientRepository.class, Consumer.class);
ReflectionUtils.makeAccessible(method);
RegisteredClientRepository registeredClientRepository = OAuth2ConfigurerUtils
.getRegisteredClientRepository(httpSecurity);
Filter authorizationCodeRequestValidatingFilter = (Filter) ReflectionUtils.invokeMethod(method,
authorizationEndpointFilter, registeredClientRepository,
this.authorizationCodeRequestAuthenticationValidatorComposite);
httpSecurity.addFilterBefore(postProcess(authorizationCodeRequestValidatingFilter),
AbstractPreAuthenticatedProcessingFilter.class);
}
@@ -307,8 +307,8 @@ public class OAuth2AuthorizationCodeGrantTests {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
this.mvc
.perform(
get(DEFAULT_AUTHORIZATION_ENDPOINT_URI).params(getAuthorizationRequestParameters(registeredClient)))
.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.queryParams(getAuthorizationRequestParameters(registeredClient)))
.andExpect(status().isBadRequest())
.andReturn();
}
@@ -851,21 +851,31 @@ public class OAuth2AuthorizationCodeGrantTests {
this.spring.register(AuthorizationServerConfigurationCustomAuthorizationEndpoint.class).autowire();
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
this.registeredClientRepository.save(registeredClient);
TestingAuthenticationToken principal = new TestingAuthenticationToken("principalName", "password");
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE);
additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = new OAuth2AuthorizationCodeRequestAuthenticationToken(
"https://provider.com/oauth2/authorize", registeredClient.getClientId(), principal,
registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes(),
additionalParameters);
OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode("code", Instant.now(),
Instant.now().plus(5, ChronoUnit.MINUTES));
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
"https://provider.com/oauth2/authorize", registeredClient.getClientId(), principal, authorizationCode,
registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED,
registeredClient.getScopes());
given(authorizationRequestConverter.convert(any())).willReturn(authorizationCodeRequestAuthenticationResult);
given(authorizationRequestConverter.convert(any())).willReturn(authorizationCodeRequestAuthentication);
given(authorizationRequestAuthenticationProvider
.supports(eq(OAuth2AuthorizationCodeRequestAuthenticationToken.class))).willReturn(true);
given(authorizationRequestAuthenticationProvider.authenticate(any()))
.willReturn(authorizationCodeRequestAuthenticationResult);
this.mvc
.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI).params(getAuthorizationRequestParameters(registeredClient))
.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.queryParams(getAuthorizationRequestParameters(registeredClient))
.with(user("user")))
.andExpect(status().isOk());
@@ -880,8 +890,7 @@ public class OAuth2AuthorizationCodeGrantTests {
|| converter instanceof OAuth2AuthorizationCodeRequestAuthenticationConverter
|| converter instanceof OAuth2AuthorizationConsentAuthenticationConverter);
verify(authorizationRequestAuthenticationProvider)
.authenticate(eq(authorizationCodeRequestAuthenticationResult));
verify(authorizationRequestAuthenticationProvider).authenticate(eq(authorizationCodeRequestAuthentication));
@SuppressWarnings("unchecked")
ArgumentCaptor<List<AuthenticationProvider>> authenticationProvidersCaptor = ArgumentCaptor