OAuth2AuthorizationEndpointFilter is applied after AuthorizationFilter
Closes gh-18251
This commit is contained in:
+29
-3
@@ -16,10 +16,12 @@
|
||||
|
||||
package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import jakarta.servlet.Filter;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
|
||||
import org.springframework.http.HttpMethod;
|
||||
@@ -36,10 +38,12 @@ import org.springframework.security.oauth2.server.authorization.authentication.O
|
||||
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationValidator;
|
||||
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationProvider;
|
||||
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationToken;
|
||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
|
||||
import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
|
||||
import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter;
|
||||
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter;
|
||||
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationConsentAuthenticationConverter;
|
||||
import org.springframework.security.web.access.intercept.AuthorizationFilter;
|
||||
import org.springframework.security.web.authentication.AuthenticationConverter;
|
||||
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
||||
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
|
||||
@@ -50,6 +54,7 @@ import org.springframework.security.web.servlet.util.matcher.PathPatternRequestM
|
||||
import org.springframework.security.web.util.matcher.OrRequestMatcher;
|
||||
import org.springframework.security.web.util.matcher.RequestMatcher;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.ReflectionUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
@@ -83,6 +88,8 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C
|
||||
|
||||
private Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authorizationCodeRequestAuthenticationValidator;
|
||||
|
||||
private Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authorizationCodeRequestAuthenticationValidatorComposite;
|
||||
|
||||
private SessionAuthenticationStrategy sessionAuthenticationStrategy;
|
||||
|
||||
/**
|
||||
@@ -248,8 +255,16 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C
|
||||
authenticationProviders.addAll(0, this.authenticationProviders);
|
||||
}
|
||||
this.authenticationProvidersConsumer.accept(authenticationProviders);
|
||||
authenticationProviders.forEach(
|
||||
(authenticationProvider) -> httpSecurity.authenticationProvider(postProcess(authenticationProvider)));
|
||||
authenticationProviders.forEach((authenticationProvider) -> {
|
||||
httpSecurity.authenticationProvider(postProcess(authenticationProvider));
|
||||
if (authenticationProvider instanceof OAuth2AuthorizationCodeRequestAuthenticationProvider) {
|
||||
Method method = ReflectionUtils.findMethod(OAuth2AuthorizationCodeRequestAuthenticationProvider.class,
|
||||
"getAuthenticationValidatorComposite");
|
||||
ReflectionUtils.makeAccessible(method);
|
||||
this.authorizationCodeRequestAuthenticationValidatorComposite = (Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext>) ReflectionUtils
|
||||
.invokeMethod(method, authenticationProvider);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -282,7 +297,18 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C
|
||||
if (this.sessionAuthenticationStrategy != null) {
|
||||
authorizationEndpointFilter.setSessionAuthenticationStrategy(this.sessionAuthenticationStrategy);
|
||||
}
|
||||
httpSecurity.addFilterBefore(postProcess(authorizationEndpointFilter),
|
||||
httpSecurity.addFilterAfter(postProcess(authorizationEndpointFilter), AuthorizationFilter.class);
|
||||
// Create and add
|
||||
// OAuth2AuthorizationEndpointFilter.OAuth2AuthorizationCodeRequestValidatingFilter
|
||||
Method method = ReflectionUtils.findMethod(OAuth2AuthorizationEndpointFilter.class,
|
||||
"createAuthorizationCodeRequestValidatingFilter", RegisteredClientRepository.class, Consumer.class);
|
||||
ReflectionUtils.makeAccessible(method);
|
||||
RegisteredClientRepository registeredClientRepository = OAuth2ConfigurerUtils
|
||||
.getRegisteredClientRepository(httpSecurity);
|
||||
Filter authorizationCodeRequestValidatingFilter = (Filter) ReflectionUtils.invokeMethod(method,
|
||||
authorizationEndpointFilter, registeredClientRepository,
|
||||
this.authorizationCodeRequestAuthenticationValidatorComposite);
|
||||
httpSecurity.addFilterBefore(postProcess(authorizationCodeRequestValidatingFilter),
|
||||
AbstractPreAuthenticatedProcessingFilter.class);
|
||||
}
|
||||
|
||||
|
||||
+15
-6
@@ -307,8 +307,8 @@ public class OAuth2AuthorizationCodeGrantTests {
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||
|
||||
this.mvc
|
||||
.perform(
|
||||
get(DEFAULT_AUTHORIZATION_ENDPOINT_URI).params(getAuthorizationRequestParameters(registeredClient)))
|
||||
.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
|
||||
.queryParams(getAuthorizationRequestParameters(registeredClient)))
|
||||
.andExpect(status().isBadRequest())
|
||||
.andReturn();
|
||||
}
|
||||
@@ -851,21 +851,31 @@ public class OAuth2AuthorizationCodeGrantTests {
|
||||
this.spring.register(AuthorizationServerConfigurationCustomAuthorizationEndpoint.class).autowire();
|
||||
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||
this.registeredClientRepository.save(registeredClient);
|
||||
|
||||
TestingAuthenticationToken principal = new TestingAuthenticationToken("principalName", "password");
|
||||
Map<String, Object> additionalParameters = new HashMap<>();
|
||||
additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE);
|
||||
additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
|
||||
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = new OAuth2AuthorizationCodeRequestAuthenticationToken(
|
||||
"https://provider.com/oauth2/authorize", registeredClient.getClientId(), principal,
|
||||
registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes(),
|
||||
additionalParameters);
|
||||
OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode("code", Instant.now(),
|
||||
Instant.now().plus(5, ChronoUnit.MINUTES));
|
||||
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
|
||||
"https://provider.com/oauth2/authorize", registeredClient.getClientId(), principal, authorizationCode,
|
||||
registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED,
|
||||
registeredClient.getScopes());
|
||||
given(authorizationRequestConverter.convert(any())).willReturn(authorizationCodeRequestAuthenticationResult);
|
||||
given(authorizationRequestConverter.convert(any())).willReturn(authorizationCodeRequestAuthentication);
|
||||
given(authorizationRequestAuthenticationProvider
|
||||
.supports(eq(OAuth2AuthorizationCodeRequestAuthenticationToken.class))).willReturn(true);
|
||||
given(authorizationRequestAuthenticationProvider.authenticate(any()))
|
||||
.willReturn(authorizationCodeRequestAuthenticationResult);
|
||||
|
||||
this.mvc
|
||||
.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI).params(getAuthorizationRequestParameters(registeredClient))
|
||||
.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
|
||||
.queryParams(getAuthorizationRequestParameters(registeredClient))
|
||||
.with(user("user")))
|
||||
.andExpect(status().isOk());
|
||||
|
||||
@@ -880,8 +890,7 @@ public class OAuth2AuthorizationCodeGrantTests {
|
||||
|| converter instanceof OAuth2AuthorizationCodeRequestAuthenticationConverter
|
||||
|| converter instanceof OAuth2AuthorizationConsentAuthenticationConverter);
|
||||
|
||||
verify(authorizationRequestAuthenticationProvider)
|
||||
.authenticate(eq(authorizationCodeRequestAuthenticationResult));
|
||||
verify(authorizationRequestAuthenticationProvider).authenticate(eq(authorizationCodeRequestAuthentication));
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
ArgumentCaptor<List<AuthenticationProvider>> authenticationProvidersCaptor = ArgumentCaptor
|
||||
|
||||
Reference in New Issue
Block a user