1
0
mirror of synced 2026-05-22 21:33:16 +00:00

Add deferred CsrfTokenRepository.loadDeferredToken

* Move DeferredCsrfToken to top-level and implement Supplier<CsrfToken>
* Move RepositoryDeferredCsrfToken to top-level and make package-private
* Add CsrfTokenRepository.loadToken(HttpServletRequest, HttpServletResponse)
* Update CsrfFilter
* Rename CsrfTokenRepositoryRequestHandler to CsrfTokenRequestAttributeHandler

Issue gh-11892
Closes gh-11918
This commit is contained in:
Steve Riesenberg
2022-09-27 14:53:54 -05:00
parent 0e215a21ad
commit 475b3bb6bb
31 changed files with 536 additions and 353 deletions
@@ -36,7 +36,6 @@ import org.springframework.security.web.csrf.CsrfAuthenticationStrategy;
import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfLogoutHandler;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
@@ -249,13 +248,7 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
@SuppressWarnings("unchecked")
@Override
public void configure(H http) {
CsrfFilter filter;
if (this.requestHandler != null) {
filter = new CsrfFilter(this.requestHandler);
}
else {
filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.csrfTokenRepository));
}
CsrfFilter filter = new CsrfFilter(this.csrfTokenRepository);
RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher();
if (requireCsrfProtectionMatcher != null) {
filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher);
@@ -272,6 +265,9 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
if (sessionConfigurer != null) {
sessionConfigurer.addSessionAuthenticationStrategy(getSessionAuthenticationStrategy());
}
if (this.requestHandler != null) {
filter.setRequestHandler(this.requestHandler);
}
filter = postProcess(filter);
http.addFilter(filter);
}
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -41,7 +41,6 @@ import org.springframework.security.web.access.DelegatingAccessDeniedHandler;
import org.springframework.security.web.csrf.CsrfAuthenticationStrategy;
import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfLogoutHandler;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
@@ -112,18 +111,13 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
new BeanComponentDefinition(lazyTokenRepository.getBeanDefinition(), this.csrfRepositoryRef));
}
BeanDefinitionBuilder builder = BeanDefinitionBuilder.rootBeanDefinition(CsrfFilter.class);
if (!StringUtils.hasText(this.requestHandlerRef)) {
BeanDefinition csrfTokenRequestHandler = BeanDefinitionBuilder
.rootBeanDefinition(CsrfTokenRepositoryRequestHandler.class)
.addConstructorArgReference(this.csrfRepositoryRef).getBeanDefinition();
builder.addConstructorArgValue(csrfTokenRequestHandler);
}
else {
builder.addConstructorArgReference(this.requestHandlerRef);
}
builder.addConstructorArgReference(this.csrfRepositoryRef);
if (StringUtils.hasText(this.requestMatcherRef)) {
builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef);
}
if (StringUtils.hasText(this.requestHandlerRef)) {
builder.addPropertyReference("requestHandler", this.requestHandlerRef);
}
this.csrfFilter = builder.getBeanDefinition();
return this.csrfFilter;
}
@@ -1152,7 +1152,7 @@ csrf-options.attlist &=
## The CsrfTokenRepository to use. The default is HttpSessionCsrfTokenRepository wrapped by LazyCsrfTokenRepository.
attribute token-repository-ref { xsd:token }?
csrf-options.attlist &=
## The CsrfTokenRequestHandler to use. The default is CsrfTokenRepositoryRequestHandler.
## The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestAttributeHandler.
attribute request-handler-ref { xsd:token }?
headers =
@@ -3258,7 +3258,7 @@
</xs:attribute>
<xs:attribute name="request-handler-ref" type="xs:token">
<xs:annotation>
<xs:documentation>The CsrfTokenRequestHandler to use. The default is CsrfTokenRepositoryRequestHandler.
<xs:documentation>The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestAttributeHandler.
</xs:documentation>
</xs:annotation>
</xs:attribute>
@@ -33,7 +33,7 @@ import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.web.DefaultSecurityFilterChain;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
@@ -85,7 +85,7 @@ public class DeferHttpSessionJavaConfigTests {
csrfRepository.setDeferLoadToken(true);
HttpSessionRequestCache requestCache = new HttpSessionRequestCache();
requestCache.setMatchingRequestParameterName("continue");
CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler();
CsrfTokenRequestAttributeHandler requestHandler = new CsrfTokenRequestAttributeHandler();
requestHandler.setCsrfRequestAttributeName("_csrf");
// @formatter:off
http
@@ -44,8 +44,10 @@ import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler;
import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.firewall.StrictHttpFirewall;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -61,7 +63,6 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.hamcrest.Matchers.containsString;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
@@ -207,30 +208,30 @@ public class CsrfConfigurerTests {
public void loginWhenCsrfEnabledThenDoesNotRedirectToPreviousPostRequest() throws Exception {
CsrfDisablesPostRequestFromRequestCacheConfig.REPO = mock(CsrfTokenRepository.class);
DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadToken(any())).willReturn(csrfToken);
given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.generateToken(any())).willReturn(csrfToken);
given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class))).willReturn(new TestDeferredCsrfToken(csrfToken));
this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire();
MvcResult mvcResult = this.mvc.perform(post("/some-url")).andReturn();
this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf())
.session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound())
.andExpect(redirectedUrl("/"));
verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce())
.loadToken(any(HttpServletRequest.class));
.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void loginWhenCsrfEnabledThenRedirectsToPreviousGetRequest() throws Exception {
CsrfDisablesPostRequestFromRequestCacheConfig.REPO = mock(CsrfTokenRepository.class);
DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadToken(any())).willReturn(csrfToken);
given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.generateToken(any())).willReturn(csrfToken);
given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class))).willReturn(new TestDeferredCsrfToken(csrfToken));
this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire();
MvcResult mvcResult = this.mvc.perform(get("/some-url")).andReturn();
this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf())
.session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound())
.andExpect(redirectedUrl("http://localhost/some-url"));
verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce())
.loadToken(any(HttpServletRequest.class));
.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
// SEC-2422
@@ -277,11 +278,13 @@ public class CsrfConfigurerTests {
@Test
public void getWhenCustomCsrfTokenRepositoryThenRepositoryIsUsed() throws Exception {
CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class);
given(CsrfTokenRepositoryConfig.REPO.loadToken(any()))
.willReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"));
given(CsrfTokenRepositoryConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class)))
.willReturn(new TestDeferredCsrfToken(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")));
this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire();
this.mvc.perform(get("/")).andExpect(status().isOk());
verify(CsrfTokenRepositoryConfig.REPO).loadToken(any(HttpServletRequest.class));
verify(CsrfTokenRepositoryConfig.REPO).loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class));
}
@Test
@@ -297,8 +300,8 @@ public class CsrfConfigurerTests {
public void loginWhenCustomCsrfTokenRepositoryThenCsrfTokenIsCleared() throws Exception {
CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class);
DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
given(CsrfTokenRepositoryConfig.REPO.loadToken(any())).willReturn(csrfToken);
given(CsrfTokenRepositoryConfig.REPO.generateToken(any())).willReturn(csrfToken);
given(CsrfTokenRepositoryConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class))).willReturn(new TestDeferredCsrfToken(csrfToken));
this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire();
// @formatter:off
MockHttpServletRequestBuilder loginRequest = post("/login")
@@ -314,11 +317,13 @@ public class CsrfConfigurerTests {
@Test
public void getWhenCustomCsrfTokenRepositoryInLambdaThenRepositoryIsUsed() throws Exception {
CsrfTokenRepositoryInLambdaConfig.REPO = mock(CsrfTokenRepository.class);
given(CsrfTokenRepositoryInLambdaConfig.REPO.loadToken(any()))
.willReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"));
given(CsrfTokenRepositoryInLambdaConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class)))
.willReturn(new TestDeferredCsrfToken(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")));
this.spring.register(CsrfTokenRepositoryInLambdaConfig.class, BasicController.class).autowire();
this.mvc.perform(get("/")).andExpect(status().isOk());
verify(CsrfTokenRepositoryInLambdaConfig.REPO).loadToken(any(HttpServletRequest.class));
verify(CsrfTokenRepositoryInLambdaConfig.REPO).loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class));
}
@Test
@@ -418,30 +423,30 @@ public class CsrfConfigurerTests {
}
@Test
public void getLoginWhenCsrfTokenRequestProcessorSetThenRespondsWithNormalCsrfToken() throws Exception {
public void getLoginWhenCsrfTokenRequestHandlerSetThenRespondsWithNormalCsrfToken() throws Exception {
CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken);
CsrfTokenRequestProcessorConfig.HANDLER = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository);
this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire();
given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)))
.willReturn(new TestDeferredCsrfToken(csrfToken));
CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository;
CsrfTokenRequestHandlerConfig.HANDLER = new CsrfTokenRequestAttributeHandler();
this.spring.register(CsrfTokenRequestHandlerConfig.class, BasicController.class).autowire();
this.mvc.perform(get("/login")).andExpect(status().isOk())
.andExpect(content().string(containsString(csrfToken.getToken())));
verify(csrfTokenRepository).loadToken(any(HttpServletRequest.class));
verify(csrfTokenRepository).generateToken(any(HttpServletRequest.class));
verify(csrfTokenRepository).saveToken(eq(csrfToken), any(HttpServletRequest.class),
any(HttpServletResponse.class));
verify(csrfTokenRepository).loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class));
verifyNoMoreInteractions(csrfTokenRepository);
}
@Test
public void loginWhenCsrfTokenRequestProcessorSetAndNormalCsrfTokenThenSuccess() throws Exception {
public void loginWhenCsrfTokenRequestHandlerSetAndNormalCsrfTokenThenSuccess() throws Exception {
CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
given(csrfTokenRepository.loadToken(any(HttpServletRequest.class))).willReturn(null, csrfToken);
given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken);
CsrfTokenRequestProcessorConfig.HANDLER = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository);
given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)))
.willReturn(new TestDeferredCsrfToken(csrfToken));
CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository;
CsrfTokenRequestHandlerConfig.HANDLER = new CsrfTokenRequestAttributeHandler();
this.spring.register(CsrfTokenRequestHandlerConfig.class, BasicController.class).autowire();
this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire();
// @formatter:off
MockHttpServletRequestBuilder loginRequest = post("/login")
.header(csrfToken.getHeaderName(), csrfToken.getToken())
@@ -449,9 +454,8 @@ public class CsrfConfigurerTests {
.param("password", "password");
// @formatter:on
this.mvc.perform(loginRequest).andExpect(redirectedUrl("/"));
verify(csrfTokenRepository, times(2)).loadToken(any(HttpServletRequest.class));
verify(csrfTokenRepository).generateToken(any(HttpServletRequest.class));
verify(csrfTokenRepository).saveToken(eq(csrfToken), any(HttpServletRequest.class),
verify(csrfTokenRepository).saveToken(isNull(), any(HttpServletRequest.class), any(HttpServletResponse.class));
verify(csrfTokenRepository, times(2)).loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class));
verifyNoMoreInteractions(csrfTokenRepository);
}
@@ -799,9 +803,11 @@ public class CsrfConfigurerTests {
@Configuration
@EnableWebSecurity
static class CsrfTokenRequestProcessorConfig {
static class CsrfTokenRequestHandlerConfig {
static CsrfTokenRepositoryRequestHandler HANDLER;
static CsrfTokenRepository REPO;
static CsrfTokenRequestHandler HANDLER;
@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
@@ -811,7 +817,10 @@ public class CsrfConfigurerTests {
.anyRequest().authenticated()
)
.formLogin(Customizer.withDefaults())
.csrf((csrf) -> csrf.csrfTokenRequestHandler(HANDLER));
.csrf((csrf) -> csrf
.csrfTokenRepository(REPO)
.csrfTokenRequestHandler(HANDLER)
);
// @formatter:on
return http.build();
@@ -841,4 +850,24 @@ public class CsrfConfigurerTests {
}
private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
private final CsrfToken csrfToken;
private TestDeferredCsrfToken(CsrfToken csrfToken) {
this.csrfToken = csrfToken;
}
@Override
public CsrfToken get() {
return this.csrfToken;
}
@Override
public boolean isGenerated() {
return false;
}
}
}
@@ -30,7 +30,6 @@ import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpMethod;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockHttpSession;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.config.test.SpringTestContext;
@@ -42,7 +41,6 @@ import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.stereotype.Controller;
import org.springframework.test.context.junit.jupiter.SpringExtension;
@@ -546,9 +544,8 @@ public class CsrfConfigTests {
@Override
public void match(MvcResult result) {
MockHttpServletRequest request = result.getRequest();
MockHttpServletResponse response = result.getResponse();
DeferredCsrfToken token = WebTestUtils.getCsrfTokenRequestHandler(request).handle(request, response);
assertThat(token.isGenerated()).isFalse();
CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request);
assertThat(token).isNotNull();
}
}
@@ -564,8 +561,7 @@ public class CsrfConfigTests {
@Override
public void match(MvcResult result) throws Exception {
MockHttpServletRequest request = result.getRequest();
MockHttpServletResponse response = result.getResponse();
CsrfToken token = WebTestUtils.getCsrfTokenRequestHandler(request).handle(request, response).get();
CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request);
assertThat(token).isNotNull();
assertThat(token.getToken()).isEqualTo(this.token.apply(result));
}
@@ -26,7 +26,7 @@
<csrf request-handler-ref="requestHandler"/>
</http>
<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler"
<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler"
p:csrfRequestAttributeName="csrf-attribute-name"/>
<b:import resource="CsrfConfigTests-shared-userservice.xml"/>
</b:beans>
@@ -42,7 +42,7 @@
<b:bean id="csrfRepository" class="org.springframework.security.web.csrf.LazyCsrfTokenRepository"
c:delegate-ref="httpSessionCsrfRepository"
p:deferLoadToken="true"/>
<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler"
<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler"
p:csrfRequestAttributeName="_csrf"/>
<b:import resource="CsrfConfigTests-shared-userservice.xml"/>
</b:beans>