diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index 98c13bf40b..082b8e4fc9 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -180,6 +180,8 @@ public final class OAuth2LoginConfigurer> private OAuth2AuthorizedClientRepository authorizedClientRepository; + private SecurityContextRepository securityContextRepository; + /** * Sets the repository of client registrations. * @param clientRegistrationRepository the repository of client registrations @@ -233,6 +235,17 @@ public final class OAuth2LoginConfigurer> return this; } + /** + * Sets the {@link SecurityContextRepository} to use. + * @param securityContextRepository the {@link SecurityContextRepository} to use + * @return the {@link OAuth2LoginConfigurer} for further configuration + */ + @Override + public OAuth2LoginConfigurer securityContextRepository(SecurityContextRepository securityContextRepository) { + this.securityContextRepository = securityContextRepository; + return this; + } + /** * Sets the registry for managing the OIDC client-provider session link * @param oidcSessionRegistry the {@link OidcSessionRegistry} to use @@ -299,6 +312,9 @@ public final class OAuth2LoginConfigurer> RequestMatcher processUri = getRequestMatcherBuilder().matcher(this.loginProcessingUrl); authenticationFilter.setRequiresAuthenticationRequestMatcher(processUri); authenticationFilter.setSecurityContextHolderStrategy(getSecurityContextHolderStrategy()); + if (this.securityContextRepository != null) { + authenticationFilter.setSecurityContextRepository(this.securityContextRepository); + } this.setAuthenticationFilter(authenticationFilter); super.loginProcessingUrl(this.loginProcessingUrl); if (this.loginPage != null) { diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java index 613d825fff..1e0baa57fc 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java @@ -106,6 +106,7 @@ import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.authentication.HttpStatusEntryPoint; import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; +import org.springframework.security.web.context.NullSecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.servlet.TestMockHttpServletRequests; import org.springframework.security.web.session.HttpSessionDestroyedEvent; @@ -116,6 +117,7 @@ import org.springframework.web.context.support.AnnotationConfigWebApplicationCon import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatNoException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; @@ -717,6 +719,12 @@ public class OAuth2LoginConfigurerTests { verify(this.context.getBean(SpyObjectPostProcessor.class).spy).authenticate(any()); } + // gh-16623 + @Test + public void oauth2LoginWithCustomSecurityContextRepository() { + assertThatNoException().isThrownBy(() -> loadConfig(OAuth2LoginConfigSecurityContextRepository.class)); + } + private void loadConfig(Class... configs) { AnnotationConfigWebApplicationContext applicationContext = new AnnotationConfigWebApplicationContext(); applicationContext.register(configs); @@ -961,6 +969,24 @@ public class OAuth2LoginConfigurerTests { } + @Configuration + @EnableWebSecurity + static class OAuth2LoginConfigSecurityContextRepository extends CommonSecurityFilterChainConfig { + + @Bean + SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .oauth2Login((login) -> login + .clientRegistrationRepository( + new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION)) + .securityContextRepository(new NullSecurityContextRepository())); + // @formatter:on + return super.configureFilterChain(http); + } + + } + @Configuration @EnableWebSecurity static class OAuth2LoginConfigCustomAuthorizationRequestResolver extends CommonSecurityFilterChainConfig {