1
0
mirror of synced 2026-05-22 13:23:17 +00:00

CsrfTokenRequestHandler extends CsrfTokenRequestResolver

Closes gh-11896
This commit is contained in:
Steve Riesenberg
2022-09-23 11:18:31 -05:00
parent d140d95305
commit 46696a9226
18 changed files with 155 additions and 188 deletions
@@ -48,10 +48,7 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt
* @param csrfTokenRepository the {@link CsrfTokenRepository} to use
*/
public CsrfAuthenticationStrategy(CsrfTokenRepository csrfTokenRepository) {
Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
CsrfTokenRequestProcessor processor = new CsrfTokenRequestProcessor();
processor.setTokenRepository(csrfTokenRepository);
this.requestHandler = processor;
this.requestHandler = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository);
this.csrfTokenRepository = csrfTokenRepository;
}
@@ -82,20 +82,30 @@ public final class CsrfFilter extends OncePerRequestFilter {
private final Log logger = LogFactory.getLog(getClass());
private final CsrfTokenRequestHandler requestHandler;
private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl();
private CsrfTokenRequestHandler requestHandler;
private CsrfTokenRequestResolver requestResolver;
/**
* Creates a new instance.
* @param csrfTokenRepository the {@link CsrfTokenRepository} to use
* @deprecated Use {@link CsrfFilter#CsrfFilter(CsrfTokenRequestHandler)} instead
*/
@Deprecated
public CsrfFilter(CsrfTokenRepository csrfTokenRepository) {
Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor();
csrfTokenRequestProcessor.setTokenRepository(csrfTokenRepository);
this.requestHandler = csrfTokenRequestProcessor;
this.requestResolver = csrfTokenRequestProcessor;
this(new CsrfTokenRepositoryRequestHandler(csrfTokenRepository));
}
/**
* Creates a new instance.
* @param requestHandler the {@link CsrfTokenRequestHandler} to use. Default is
* {@link CsrfTokenRepositoryRequestHandler}.
*/
public CsrfFilter(CsrfTokenRequestHandler requestHandler) {
Assert.notNull(requestHandler, "requestHandler cannot be null");
this.requestHandler = requestHandler;
}
@Override
@@ -116,7 +126,7 @@ public final class CsrfFilter extends OncePerRequestFilter {
return;
}
CsrfToken csrfToken = deferredCsrfToken.get();
String actualToken = this.requestResolver.resolveCsrfTokenValue(request, csrfToken);
String actualToken = this.requestHandler.resolveCsrfTokenValue(request, csrfToken);
if (!equalsConstantTime(csrfToken.getToken(), actualToken)) {
boolean missingToken = deferredCsrfToken.isGenerated();
this.logger.debug(
@@ -164,36 +174,6 @@ public final class CsrfFilter extends OncePerRequestFilter {
this.accessDeniedHandler = accessDeniedHandler;
}
/**
* Specifies a {@link CsrfTokenRequestHandler} that is used to make the
* {@link CsrfToken} available as a request attribute.
*
* <p>
* The default is {@link CsrfTokenRequestProcessor}.
* </p>
* @param requestHandler the {@link CsrfTokenRequestHandler} to use
* @since 5.8
*/
public void setRequestHandler(CsrfTokenRequestHandler requestHandler) {
Assert.notNull(requestHandler, "requestHandler cannot be null");
this.requestHandler = requestHandler;
}
/**
* Specifies a {@link CsrfTokenRequestResolver} that is used to resolve the token
* value from the request.
*
* <p>
* The default is {@link CsrfTokenRequestProcessor}.
* </p>
* @param requestResolver the {@link CsrfTokenRequestResolver} to use
* @since 5.8
*/
public void setRequestResolver(CsrfTokenRequestResolver requestResolver) {
Assert.notNull(requestResolver, "requestResolver cannot be null");
this.requestResolver = requestResolver;
}
/**
* Constant time comparison to prevent against timing attacks.
* @param expected
@@ -24,28 +24,34 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.util.Assert;
/**
* An implementation of the {@link CsrfTokenRequestHandler} and
* {@link CsrfTokenRequestResolver} interfaces that is capable of making the
* {@link CsrfToken} available as a request attribute and resolving the token value as
* either a header or parameter value of the request.
* An implementation of the {@link CsrfTokenRequestHandler} interface that is capable of
* making the {@link CsrfToken} available as a request attribute and resolving the token
* value as either a header or parameter value of the request.
*
* @author Steve Riesenberg
* @since 5.8
*/
public class CsrfTokenRequestProcessor implements CsrfTokenRequestHandler, CsrfTokenRequestResolver {
public class CsrfTokenRepositoryRequestHandler implements CsrfTokenRequestHandler {
private final CsrfTokenRepository csrfTokenRepository;
private String csrfRequestAttributeName;
private CsrfTokenRepository tokenRepository = new HttpSessionCsrfTokenRepository();
/**
* Creates a new instance.
*/
public CsrfTokenRepositoryRequestHandler() {
this(new HttpSessionCsrfTokenRepository());
}
/**
* Sets the {@link CsrfTokenRepository} to use.
* @param tokenRepository the {@link CsrfTokenRepository} to use. Default
* Creates a new instance.
* @param csrfTokenRepository the {@link CsrfTokenRepository} to use. Default
* {@link HttpSessionCsrfTokenRepository}
*/
public void setTokenRepository(CsrfTokenRepository tokenRepository) {
Assert.notNull(tokenRepository, "tokenRepository cannot be null");
this.tokenRepository = tokenRepository;
public CsrfTokenRepositoryRequestHandler(CsrfTokenRepository csrfTokenRepository) {
Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
this.csrfTokenRepository = csrfTokenRepository;
}
/**
@@ -75,17 +81,6 @@ public class CsrfTokenRequestProcessor implements CsrfTokenRequestHandler, CsrfT
return deferredCsrfToken;
}
@Override
public String resolveCsrfTokenValue(HttpServletRequest request, CsrfToken csrfToken) {
Assert.notNull(request, "request cannot be null");
Assert.notNull(csrfToken, "csrfToken cannot be null");
String actualToken = request.getHeader(csrfToken.getHeaderName());
if (actualToken == null) {
actualToken = request.getParameter(csrfToken.getParameterName());
}
return actualToken;
}
private static final class SupplierCsrfToken implements CsrfToken {
private final Supplier<CsrfToken> csrfTokenSupplier;
@@ -150,11 +145,12 @@ public class CsrfTokenRequestProcessor implements CsrfTokenRequestHandler, CsrfT
if (this.csrfToken != null) {
return;
}
this.csrfToken = CsrfTokenRequestProcessor.this.tokenRepository.loadToken(this.request);
this.csrfToken = CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.loadToken(this.request);
this.missingToken = (this.csrfToken == null);
if (this.missingToken) {
this.csrfToken = CsrfTokenRequestProcessor.this.tokenRepository.generateToken(this.request);
CsrfTokenRequestProcessor.this.tokenRepository.saveToken(this.csrfToken, this.request, this.response);
this.csrfToken = CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.generateToken(this.request);
CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.saveToken(this.csrfToken, this.request,
this.response);
}
}
@@ -19,18 +19,20 @@ package org.springframework.security.web.csrf;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.util.Assert;
/**
* A callback interface that is used to determine the {@link CsrfToken} to use and make
* the {@link CsrfToken} available as a request attribute. Implementations of this
* interface may choose to perform additional tasks or customize how the token is made
* available to the application through request attributes.
* An interface that is used to determine the {@link CsrfToken} to use and make the
* {@link CsrfToken} available as a request attribute. Implementations of this interface
* may choose to perform additional tasks or customize how the token is made available to
* the application through request attributes.
*
* @author Steve Riesenberg
* @since 5.8
* @see CsrfTokenRequestProcessor
* @see CsrfTokenRepositoryRequestHandler
*/
@FunctionalInterface
public interface CsrfTokenRequestHandler {
public interface CsrfTokenRequestHandler extends CsrfTokenRequestResolver {
/**
* Handles a request using a {@link CsrfToken}.
@@ -39,4 +41,15 @@ public interface CsrfTokenRequestHandler {
*/
DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response);
@Override
default String resolveCsrfTokenValue(HttpServletRequest request, CsrfToken csrfToken) {
Assert.notNull(request, "request cannot be null");
Assert.notNull(csrfToken, "csrfToken cannot be null");
String actualToken = request.getHeader(csrfToken.getHeaderName());
if (actualToken == null) {
actualToken = request.getParameter(csrfToken.getParameterName());
}
return actualToken;
}
}
@@ -25,7 +25,7 @@ import javax.servlet.http.HttpServletRequest;
*
* @author Steve Riesenberg
* @since 5.8
* @see CsrfTokenRequestProcessor
* @see CsrfTokenRepositoryRequestHandler
*/
@FunctionalInterface
public interface CsrfTokenRequestResolver {
@@ -86,7 +86,11 @@ public class CsrfFilterTests {
}
private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) {
CsrfFilter filter = new CsrfFilter(repository);
return createCsrfFilter(new CsrfTokenRepositoryRequestHandler(repository));
}
private CsrfFilter createCsrfFilter(CsrfTokenRequestHandler requestHandler) {
CsrfFilter filter = new CsrfFilter(requestHandler);
filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
filter.setAccessDeniedHandler(this.deniedHandler);
return filter;
@@ -99,7 +103,7 @@ public class CsrfFilterTests {
@Test
public void constructorNullRepository() {
assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter(null));
assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter((CsrfTokenRequestHandler) null));
}
// SEC-2276
@@ -249,7 +253,7 @@ public class CsrfFilterTests {
@Test
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException {
this.filter = new CsrfFilter(this.tokenRepository);
this.filter = createCsrfFilter(this.tokenRepository);
this.filter.setAccessDeniedHandler(this.deniedHandler);
for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) {
resetRequestResponse();
@@ -269,7 +273,7 @@ public class CsrfFilterTests {
*/
@Test
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception {
this.filter = new CsrfFilter(this.tokenRepository);
this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
this.filter.setAccessDeniedHandler(this.deniedHandler);
for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) {
resetRequestResponse();
@@ -284,7 +288,7 @@ public class CsrfFilterTests {
@Test
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException {
this.filter = new CsrfFilter(this.tokenRepository);
this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
this.filter.setAccessDeniedHandler(this.deniedHandler);
for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) {
resetRequestResponse();
@@ -299,7 +303,7 @@ public class CsrfFilterTests {
@Test
public void doFilterDefaultAccessDenied() throws ServletException, IOException {
this.filter = new CsrfFilter(this.tokenRepository);
this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
@@ -313,7 +317,7 @@ public class CsrfFilterTests {
@Test
public void doFilterWhenSkipRequestInvokedThenSkips() throws Exception {
CsrfTokenRepository repository = mock(CsrfTokenRepository.class);
CsrfFilter filter = new CsrfFilter(repository);
CsrfFilter filter = createCsrfFilter(repository);
lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token);
MockHttpServletRequest request = new MockHttpServletRequest();
CsrfFilter.skipRequest(request);
@@ -340,25 +344,13 @@ public class CsrfFilterTests {
CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
given(requestHandler.handle(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
this.filter.setRequestHandler(requestHandler);
this.filter = createCsrfFilter(requestHandler);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain);
verify(requestHandler).handle(eq(this.request), eq(this.response));
verify(this.filterChain).doFilter(this.request, this.response);
}
@Test
public void doFilterWhenRequestResolverThenUsed() throws Exception {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
CsrfTokenRequestResolver requestResolver = mock(CsrfTokenRequestResolver.class);
given(requestResolver.resolveCsrfTokenValue(this.request, this.token)).willReturn(this.token.getToken());
this.filter.setRequestResolver(requestResolver);
this.filter.doFilter(this.request, this.response, this.filterChain);
verify(requestResolver).resolveCsrfTokenValue(this.request, this.token);
verify(this.filterChain).doFilter(this.request, this.response);
}
@Test
public void setRequireCsrfProtectionMatcherNull() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequireCsrfProtectionMatcher(null));
@@ -373,16 +365,14 @@ public class CsrfFilterTests {
@Test
public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet()
throws ServletException, IOException {
CsrfFilter filter = createCsrfFilter(this.tokenRepository);
String csrfAttrName = "_csrf";
CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor();
csrfTokenRequestProcessor.setTokenRepository(this.tokenRepository);
csrfTokenRequestProcessor.setCsrfRequestAttributeName(csrfAttrName);
filter.setRequestHandler(csrfTokenRequestProcessor);
CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository);
requestHandler.setCsrfRequestAttributeName(csrfAttrName);
this.filter = createCsrfFilter(requestHandler);
CsrfToken expectedCsrfToken = spy(this.token);
given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken);
filter.doFilter(this.request, this.response, this.filterChain);
this.filter.doFilter(this.request, this.response, this.filterChain);
verifyNoInteractions(expectedCsrfToken);
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
@@ -410,6 +400,6 @@ public class CsrfFilterTests {
return this.isGenerated;
}
};
}
}
@@ -31,13 +31,13 @@ import static org.mockito.BDDMockito.given;
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
/**
* Tests for {@link CsrfTokenRequestProcessor}.
* Tests for {@link CsrfTokenRepositoryRequestHandler}.
*
* @author Steve Riesenberg
* @since 5.8
*/
@ExtendWith(MockitoExtension.class)
public class CsrfTokenRequestProcessorTests {
public class CsrfTokenRepositoryRequestHandlerTests {
@Mock
CsrfTokenRepository tokenRepository;
@@ -48,34 +48,48 @@ public class CsrfTokenRequestProcessorTests {
private CsrfToken token;
private CsrfTokenRequestProcessor processor;
private CsrfTokenRepositoryRequestHandler handler;
@BeforeEach
public void setup() {
this.request = new MockHttpServletRequest();
this.response = new MockHttpServletResponse();
this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue");
this.processor = new CsrfTokenRequestProcessor();
this.processor.setTokenRepository(this.tokenRepository);
this.handler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository);
}
@Test
public void constructorWhenCsrfTokenRepositoryIsNullThenThrowsIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> new CsrfTokenRepositoryRequestHandler(null))
.withMessage("csrfTokenRepository cannot be null");
// @formatter:on
}
@Test
public void handleWhenRequestIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(null, this.response))
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.handler.handle(null, this.response))
.withMessage("request cannot be null");
// @formatter:on
}
@Test
public void handleWhenResponseIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(this.request, null))
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.handler.handle(this.request, null))
.withMessage("response cannot be null");
// @formatter:on
}
@Test
public void handleWhenCsrfRequestAttributeSetThenUsed() {
given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
this.processor.setCsrfRequestAttributeName("_csrf");
this.processor.handle(this.request, this.response);
this.handler.setCsrfRequestAttributeName("_csrf");
this.handler.handle(this.request, this.response);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute("_csrf")).isEqualTo(this.token);
}
@@ -83,40 +97,46 @@ public class CsrfTokenRequestProcessorTests {
@Test
public void handleWhenValidParametersThenRequestAttributesSet() {
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
this.processor.handle(this.request, this.response);
this.handler.handle(this.request, this.response);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
}
@Test
public void resolveCsrfTokenValueWhenRequestIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.processor.resolveCsrfTokenValue(null, this.token))
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.handler.resolveCsrfTokenValue(null, this.token))
.withMessage("request cannot be null");
// @formatter:on
}
@Test
public void resolveCsrfTokenValueWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.processor.resolveCsrfTokenValue(this.request, null))
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.handler.resolveCsrfTokenValue(this.request, null))
.withMessage("csrfToken cannot be null");
// @formatter:on
}
@Test
public void resolveCsrfTokenValueWhenTokenNotSetThenReturnsNull() {
String tokenValue = this.processor.resolveCsrfTokenValue(this.request, this.token);
String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
assertThat(tokenValue).isNull();
}
@Test
public void resolveCsrfTokenValueWhenParameterSetThenReturnsTokenValue() {
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
String tokenValue = this.processor.resolveCsrfTokenValue(this.request, this.token);
String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
assertThat(tokenValue).isEqualTo(this.token.getToken());
}
@Test
public void resolveCsrfTokenValueWhenHeaderSetThenReturnsTokenValue() {
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
String tokenValue = this.processor.resolveCsrfTokenValue(this.request, this.token);
String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
assertThat(tokenValue).isEqualTo(this.token.getToken());
}
@@ -124,7 +144,7 @@ public class CsrfTokenRequestProcessorTests {
public void resolveCsrfTokenValueWhenHeaderAndParameterSetThenHeaderIsPreferred() {
this.request.addHeader(this.token.getHeaderName(), "header");
this.request.setParameter(this.token.getParameterName(), "parameter");
String tokenValue = this.processor.resolveCsrfTokenValue(this.request, this.token);
String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
assertThat(tokenValue).isEqualTo("header");
}