1
0
mirror of synced 2026-05-22 21:33:16 +00:00

CsrfTokenRequestAttributeHandler -> CsrfTokenRequestHandler

This renames CsrfTokenRequestAttributeHandler to CsrfTokenRequestHandler and
moves usage from CsrfFilter into CsrfTokenRequestHandler.

Closes gh-11892
This commit is contained in:
Rob Winch
2022-09-22 09:26:53 -05:00
parent c1d27612af
commit d94677f87e
27 changed files with 408 additions and 312 deletions
@@ -41,7 +41,7 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt
private final CsrfTokenRepository csrfTokenRepository;
private CsrfTokenRequestAttributeHandler requestAttributeHandler = new CsrfTokenRequestProcessor();
private CsrfTokenRequestHandler requestHandler;
/**
* Creates a new instance
@@ -49,30 +49,28 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt
*/
public CsrfAuthenticationStrategy(CsrfTokenRepository csrfTokenRepository) {
Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
CsrfTokenRequestProcessor processor = new CsrfTokenRequestProcessor();
processor.setTokenRepository(csrfTokenRepository);
this.requestHandler = processor;
this.csrfTokenRepository = csrfTokenRepository;
}
/**
* Specify a {@link CsrfTokenRequestAttributeHandler} to use for making the
* {@code CsrfToken} available as a request attribute.
* @param requestAttributeHandler the {@link CsrfTokenRequestAttributeHandler} to use
* Specify a {@link CsrfTokenRequestHandler} to use for making the {@code CsrfToken}
* available as a request attribute.
* @param requestHandler the {@link CsrfTokenRequestHandler} to use
*/
public void setRequestAttributeHandler(CsrfTokenRequestAttributeHandler requestAttributeHandler) {
Assert.notNull(requestAttributeHandler, "requestAttributeHandler cannot be null");
this.requestAttributeHandler = requestAttributeHandler;
public void setRequestHandler(CsrfTokenRequestHandler requestHandler) {
Assert.notNull(requestHandler, "requestHandler cannot be null");
this.requestHandler = requestHandler;
}
@Override
public void onAuthentication(Authentication authentication, HttpServletRequest request,
HttpServletResponse response) throws SessionAuthenticationException {
boolean containsToken = this.csrfTokenRepository.loadToken(request) != null;
if (containsToken) {
this.csrfTokenRepository.saveToken(null, request, response);
CsrfToken newToken = this.csrfTokenRepository.generateToken(request);
this.csrfTokenRepository.saveToken(newToken, request, response);
this.requestAttributeHandler.handle(request, response, () -> newToken);
this.logger.debug("Replaced CSRF Token");
}
this.csrfTokenRepository.saveToken(null, request, response);
this.requestHandler.handle(request, response);
this.logger.debug("Replaced CSRF Token");
}
}
@@ -82,21 +82,19 @@ public final class CsrfFilter extends OncePerRequestFilter {
private final Log logger = LogFactory.getLog(getClass());
private final CsrfTokenRepository tokenRepository;
private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl();
private CsrfTokenRequestAttributeHandler requestAttributeHandler;
private CsrfTokenRequestHandler requestHandler;
private CsrfTokenRequestResolver requestResolver;
public CsrfFilter(CsrfTokenRepository csrfTokenRepository) {
Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
this.tokenRepository = csrfTokenRepository;
CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor();
this.requestAttributeHandler = csrfTokenRequestProcessor;
csrfTokenRequestProcessor.setTokenRepository(csrfTokenRepository);
this.requestHandler = csrfTokenRequestProcessor;
this.requestResolver = csrfTokenRequestProcessor;
}
@@ -108,15 +106,7 @@ public final class CsrfFilter extends OncePerRequestFilter {
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
request.setAttribute(HttpServletResponse.class.getName(), response);
CsrfToken csrfToken = this.tokenRepository.loadToken(request);
boolean missingToken = (csrfToken == null);
if (missingToken) {
csrfToken = this.tokenRepository.generateToken(request);
this.tokenRepository.saveToken(csrfToken, request, response);
}
final CsrfToken finalCsrfToken = csrfToken;
this.requestAttributeHandler.handle(request, response, () -> finalCsrfToken);
DeferredCsrfToken deferredCsrfToken = this.requestHandler.handle(request, response);
if (!this.requireCsrfProtectionMatcher.matches(request)) {
if (this.logger.isTraceEnabled()) {
this.logger.trace("Did not protect against CSRF since request did not match "
@@ -125,8 +115,10 @@ public final class CsrfFilter extends OncePerRequestFilter {
filterChain.doFilter(request, response);
return;
}
CsrfToken csrfToken = deferredCsrfToken.get();
String actualToken = this.requestResolver.resolveCsrfTokenValue(request, csrfToken);
if (!equalsConstantTime(csrfToken.getToken(), actualToken)) {
boolean missingToken = deferredCsrfToken.isGenerated();
this.logger.debug(
LogMessage.of(() -> "Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request)));
AccessDeniedException exception = (!missingToken) ? new InvalidCsrfTokenException(csrfToken, actualToken)
@@ -173,18 +165,18 @@ public final class CsrfFilter extends OncePerRequestFilter {
}
/**
* Specifies a {@link CsrfTokenRequestAttributeHandler} that is used to make the
* Specifies a {@link CsrfTokenRequestHandler} that is used to make the
* {@link CsrfToken} available as a request attribute.
*
* <p>
* The default is {@link CsrfTokenRequestProcessor}.
* </p>
* @param requestAttributeHandler the {@link CsrfTokenRequestAttributeHandler} to use
* @param requestHandler the {@link CsrfTokenRequestHandler} to use
* @since 5.8
*/
public void setRequestAttributeHandler(CsrfTokenRequestAttributeHandler requestAttributeHandler) {
Assert.notNull(requestAttributeHandler, "requestAttributeHandler cannot be null");
this.requestAttributeHandler = requestAttributeHandler;
public void setRequestHandler(CsrfTokenRequestHandler requestHandler) {
Assert.notNull(requestHandler, "requestHandler cannot be null");
this.requestHandler = requestHandler;
}
/**
@@ -16,14 +16,12 @@
package org.springframework.security.web.csrf;
import java.util.function.Supplier;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
* A callback interface that is used to make the {@link CsrfToken} created by the
* {@link CsrfTokenRepository} available as a request attribute. Implementations of this
* A callback interface that is used to determine the {@link CsrfToken} to use and make
* the {@link CsrfToken} available as a request attribute. Implementations of this
* interface may choose to perform additional tasks or customize how the token is made
* available to the application through request attributes.
*
@@ -32,14 +30,13 @@ import javax.servlet.http.HttpServletResponse;
* @see CsrfTokenRequestProcessor
*/
@FunctionalInterface
public interface CsrfTokenRequestAttributeHandler {
public interface CsrfTokenRequestHandler {
/**
* Handles a request using a {@link CsrfToken}.
* @param request the {@code HttpServletRequest} being handled
* @param response the {@code HttpServletResponse} being handled
* @param csrfToken the {@link CsrfToken} created by the {@link CsrfTokenRepository}
*/
void handle(HttpServletRequest request, HttpServletResponse response, Supplier<CsrfToken> csrfToken);
DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response);
}
@@ -24,7 +24,7 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.util.Assert;
/**
* An implementation of the {@link CsrfTokenRequestAttributeHandler} and
* An implementation of the {@link CsrfTokenRequestHandler} and
* {@link CsrfTokenRequestResolver} interfaces that is capable of making the
* {@link CsrfToken} available as a request attribute and resolving the token value as
* either a header or parameter value of the request.
@@ -32,10 +32,22 @@ import org.springframework.util.Assert;
* @author Steve Riesenberg
* @since 5.8
*/
public class CsrfTokenRequestProcessor implements CsrfTokenRequestAttributeHandler, CsrfTokenRequestResolver {
public class CsrfTokenRequestProcessor implements CsrfTokenRequestHandler, CsrfTokenRequestResolver {
private String csrfRequestAttributeName;
private CsrfTokenRepository tokenRepository = new HttpSessionCsrfTokenRepository();
/**
* Sets the {@link CsrfTokenRepository} to use.
* @param tokenRepository the {@link CsrfTokenRepository} to use. Default
* {@link HttpSessionCsrfTokenRepository}
*/
public void setTokenRepository(CsrfTokenRepository tokenRepository) {
Assert.notNull(tokenRepository, "tokenRepository cannot be null");
this.tokenRepository = tokenRepository;
}
/**
* The {@link CsrfToken} is available as a request attribute named
* {@code CsrfToken.class.getName()}. By default, an additional request attribute that
@@ -49,16 +61,18 @@ public class CsrfTokenRequestProcessor implements CsrfTokenRequestAttributeHandl
}
@Override
public void handle(HttpServletRequest request, HttpServletResponse response, Supplier<CsrfToken> csrfToken) {
public DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response) {
Assert.notNull(request, "request cannot be null");
Assert.notNull(response, "response cannot be null");
Assert.notNull(csrfToken, "csrfToken supplier cannot be null");
CsrfToken actualCsrfToken = csrfToken.get();
Assert.notNull(actualCsrfToken, "csrfToken cannot be null");
request.setAttribute(CsrfToken.class.getName(), actualCsrfToken);
request.setAttribute(HttpServletResponse.class.getName(), response);
DeferredCsrfToken deferredCsrfToken = new RepositoryDeferredCsrfToken(request, response);
CsrfToken csrfToken = new SupplierCsrfToken(deferredCsrfToken::get);
request.setAttribute(CsrfToken.class.getName(), csrfToken);
String csrfAttrName = (this.csrfRequestAttributeName != null) ? this.csrfRequestAttributeName
: actualCsrfToken.getParameterName();
request.setAttribute(csrfAttrName, actualCsrfToken);
: csrfToken.getParameterName();
request.setAttribute(csrfAttrName, csrfToken);
return deferredCsrfToken;
}
@Override
@@ -72,4 +86,78 @@ public class CsrfTokenRequestProcessor implements CsrfTokenRequestAttributeHandl
return actualToken;
}
private static final class SupplierCsrfToken implements CsrfToken {
private final Supplier<CsrfToken> csrfTokenSupplier;
private SupplierCsrfToken(Supplier<CsrfToken> csrfTokenSupplier) {
this.csrfTokenSupplier = csrfTokenSupplier;
}
@Override
public String getHeaderName() {
return getDelegate().getHeaderName();
}
@Override
public String getParameterName() {
return getDelegate().getParameterName();
}
@Override
public String getToken() {
return getDelegate().getToken();
}
private CsrfToken getDelegate() {
CsrfToken delegate = this.csrfTokenSupplier.get();
if (delegate == null) {
throw new IllegalStateException("csrfTokenSupplier returned null delegate");
}
return delegate;
}
}
private final class RepositoryDeferredCsrfToken implements DeferredCsrfToken {
private final HttpServletRequest request;
private final HttpServletResponse response;
private CsrfToken csrfToken;
private Boolean missingToken;
RepositoryDeferredCsrfToken(HttpServletRequest request, HttpServletResponse response) {
this.request = request;
this.response = response;
}
@Override
public CsrfToken get() {
init();
return this.csrfToken;
}
@Override
public boolean isGenerated() {
init();
return this.missingToken;
}
private void init() {
if (this.csrfToken != null) {
return;
}
this.csrfToken = CsrfTokenRequestProcessor.this.tokenRepository.loadToken(this.request);
this.missingToken = (this.csrfToken == null);
if (this.missingToken) {
this.csrfToken = CsrfTokenRequestProcessor.this.tokenRepository.generateToken(this.request);
CsrfTokenRequestProcessor.this.tokenRepository.saveToken(this.csrfToken, this.request, this.response);
}
}
}
}
@@ -0,0 +1,41 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.csrf;
/**
* An interface that allows delayed access to a {@link CsrfToken} that may be generated.
*
* @author Rob Winch
* @since 5.8
*/
public interface DeferredCsrfToken {
/***
* Gets the {@link CsrfToken}
* @return a non-null {@link CsrfToken}
*/
CsrfToken get();
/**
* Returns true if {@link #get()} refers to a generated {@link CsrfToken} or false if
* it already existed.
* @return true if {@link #get()} refers to a generated {@link CsrfToken} or false if
* it already existed.
*/
boolean isGenerated();
}
@@ -27,7 +27,10 @@ import org.springframework.util.Assert;
*
* @author Rob Winch
* @since 4.1
* @deprecated Use org.springframework.security.web.csrf.CsrfTokenRequestHandler which
* returns a {@link DeferredCsrfToken}
*/
@Deprecated
public final class LazyCsrfTokenRepository implements CsrfTokenRepository {
/**
@@ -75,27 +75,24 @@ public class CsrfAuthenticationStrategyTests {
}
@Test
public void setRequestAttributeHandlerWhenNullThenIllegalStateException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.strategy.setRequestAttributeHandler(null))
.withMessage("requestAttributeHandler cannot be null");
public void setRequestHandlerWhenNullThenIllegalStateException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.strategy.setRequestHandler(null))
.withMessage("requestHandler cannot be null");
}
@Test
public void onAuthenticationWhenCustomRequestAttributeHandlerThenUsed() {
given(this.csrfTokenRepository.loadToken(this.request)).willReturn(this.existingToken);
given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken);
CsrfTokenRequestAttributeHandler requestAttributeHandler = mock(CsrfTokenRequestAttributeHandler.class);
this.strategy.setRequestAttributeHandler(requestAttributeHandler);
public void onAuthenticationWhenCustomRequestHandlerThenUsed() {
CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
this.strategy.setRequestHandler(requestHandler);
this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request,
this.response);
verify(requestAttributeHandler).handle(eq(this.request), eq(this.response), any());
verifyNoMoreInteractions(requestAttributeHandler);
verify(requestHandler).handle(eq(this.request), eq(this.response));
verifyNoMoreInteractions(requestHandler);
}
@Test
public void logoutRemovesCsrfTokenAndSavesNew() {
given(this.csrfTokenRepository.loadToken(this.request)).willReturn(this.existingToken);
given(this.csrfTokenRepository.loadToken(this.request)).willReturn(null, this.existingToken);
given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken);
this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request,
this.response);
@@ -114,7 +111,6 @@ public class CsrfAuthenticationStrategyTests {
@Test
public void delaySavingCsrf() {
this.strategy = new CsrfAuthenticationStrategy(new LazyCsrfTokenRepository(this.csrfTokenRepository));
given(this.csrfTokenRepository.loadToken(this.request)).willReturn(this.existingToken);
given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken);
this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request,
this.response);
@@ -128,10 +124,11 @@ public class CsrfAuthenticationStrategyTests {
}
@Test
public void logoutRemovesNoActionIfNullToken() {
public void logoutWhenNoCsrfToken() {
given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken);
this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request,
this.response);
verify(this.csrfTokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class),
verify(this.csrfTokenRepository).saveToken(any(CsrfToken.class), any(HttpServletRequest.class),
any(HttpServletResponse.class));
}
@@ -24,8 +24,6 @@ import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.assertj.core.api.AbstractObjectAssert;
import org.assertj.core.api.ObjectAssert;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
@@ -46,10 +44,12 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
/**
* @author Rob Winch
@@ -126,8 +126,8 @@ public class CsrfFilterTests {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
verifyNoMoreInteractions(this.filterChain);
}
@@ -138,8 +138,8 @@ public class CsrfFilterTests {
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
verifyNoMoreInteractions(this.filterChain);
}
@@ -150,8 +150,8 @@ public class CsrfFilterTests {
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
verifyNoMoreInteractions(this.filterChain);
}
@@ -164,8 +164,8 @@ public class CsrfFilterTests {
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
verifyNoMoreInteractions(this.filterChain);
}
@@ -175,8 +175,8 @@ public class CsrfFilterTests {
given(this.requestMatcher.matches(this.request)).willReturn(false);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
verify(this.filterChain).doFilter(this.request, this.response);
verifyNoMoreInteractions(this.deniedHandler);
}
@@ -186,8 +186,8 @@ public class CsrfFilterTests {
given(this.requestMatcher.matches(this.request)).willReturn(false);
given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
verify(this.filterChain).doFilter(this.request, this.response);
verifyNoMoreInteractions(this.deniedHandler);
}
@@ -198,8 +198,8 @@ public class CsrfFilterTests {
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
verify(this.filterChain).doFilter(this.request, this.response);
verifyNoMoreInteractions(this.deniedHandler);
}
@@ -212,8 +212,8 @@ public class CsrfFilterTests {
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
verify(this.filterChain).doFilter(this.request, this.response);
verifyNoMoreInteractions(this.deniedHandler);
}
@@ -224,8 +224,8 @@ public class CsrfFilterTests {
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
verify(this.filterChain).doFilter(this.request, this.response);
verifyNoMoreInteractions(this.deniedHandler);
verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class),
@@ -238,8 +238,8 @@ public class CsrfFilterTests {
given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain);
assertToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
// LazyCsrfTokenRepository requires the response as an attribute
assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response);
verify(this.filterChain).doFilter(this.request, this.response);
@@ -304,8 +304,8 @@ public class CsrfFilterTests {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
verifyNoMoreInteractions(this.filterChain);
}
@@ -336,14 +336,14 @@ public class CsrfFilterTests {
}
@Test
public void doFilterWhenRequestAttributeHandlerThenUsed() throws Exception {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
CsrfTokenRequestAttributeHandler requestAttributeHandler = mock(CsrfTokenRequestAttributeHandler.class);
this.filter.setRequestAttributeHandler(requestAttributeHandler);
public void doFilterWhenRequestHandlerThenUsed() throws Exception {
CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
given(requestHandler.handle(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
this.filter.setRequestHandler(requestHandler);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain);
verify(requestAttributeHandler).handle(eq(this.request), eq(this.response), any());
verify(requestHandler).handle(eq(this.request), eq(this.response));
verify(this.filterChain).doFilter(this.request, this.response);
}
@@ -376,39 +376,40 @@ public class CsrfFilterTests {
CsrfFilter filter = createCsrfFilter(this.tokenRepository);
String csrfAttrName = "_csrf";
CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor();
csrfTokenRequestProcessor.setTokenRepository(this.tokenRepository);
csrfTokenRequestProcessor.setCsrfRequestAttributeName(csrfAttrName);
filter.setRequestAttributeHandler(csrfTokenRequestProcessor);
CsrfToken expectedCsrfToken = mock(CsrfToken.class);
filter.setRequestHandler(csrfTokenRequestProcessor);
CsrfToken expectedCsrfToken = spy(this.token);
given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken);
filter.doFilter(this.request, this.response, this.filterChain);
verifyNoInteractions(expectedCsrfToken);
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
assertThat(tokenFromRequest).isEqualTo(expectedCsrfToken);
assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken);
}
private static CsrfTokenAssert assertToken(Object token) {
return new CsrfTokenAssert((CsrfToken) token);
}
private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
private static class CsrfTokenAssert extends AbstractObjectAssert<CsrfTokenAssert, CsrfToken> {
private final CsrfToken csrfToken;
/**
* Creates a new {@link ObjectAssert}.
* @param actual the target to verify.
*/
protected CsrfTokenAssert(CsrfToken actual) {
super(actual, CsrfTokenAssert.class);
private final boolean isGenerated;
private TestDeferredCsrfToken(CsrfToken csrfToken, boolean isGenerated) {
this.csrfToken = csrfToken;
this.isGenerated = isGenerated;
}
CsrfTokenAssert isEqualTo(CsrfToken expected) {
assertThat(this.actual.getHeaderName()).isEqualTo(expected.getHeaderName());
assertThat(this.actual.getParameterName()).isEqualTo(expected.getParameterName());
assertThat(this.actual.getToken()).isEqualTo(expected.getToken());
return this;
@Override
public CsrfToken get() {
return this.csrfToken;
}
}
@Override
public boolean isGenerated() {
return this.isGenerated;
}
};
}
@@ -0,0 +1,48 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.csrf;
import org.assertj.core.api.AbstractAssert;
import org.assertj.core.api.Assertions;
/**
* Assertion for validating the properties on CsrfToken are the same.
*/
public class CsrfTokenAssert extends AbstractAssert<CsrfTokenAssert, CsrfToken> {
protected CsrfTokenAssert(CsrfToken csrfToken) {
super(csrfToken, CsrfTokenAssert.class);
}
public static CsrfTokenAssert assertThatCsrfToken(Object csrfToken) {
return new CsrfTokenAssert((CsrfToken) csrfToken);
}
public static CsrfTokenAssert assertThat(CsrfToken csrfToken) {
return new CsrfTokenAssert(csrfToken);
}
public CsrfTokenAssert isEqualTo(CsrfToken csrfToken) {
isNotNull();
assertThat(csrfToken).isNotNull();
Assertions.assertThat(this.actual.getHeaderName()).isEqualTo(csrfToken.getHeaderName());
Assertions.assertThat(this.actual.getParameterName()).isEqualTo(csrfToken.getParameterName());
Assertions.assertThat(this.actual.getToken()).isEqualTo(csrfToken.getToken());
return this;
}
}
@@ -18,12 +18,17 @@ package org.springframework.security.web.csrf;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.BDDMockito.given;
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
/**
* Tests for {@link CsrfTokenRequestProcessor}.
@@ -31,8 +36,12 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
* @author Steve Riesenberg
* @since 5.8
*/
@ExtendWith(MockitoExtension.class)
public class CsrfTokenRequestProcessorTests {
@Mock
CsrfTokenRepository tokenRepository;
private MockHttpServletRequest request;
private MockHttpServletResponse response;
@@ -47,48 +56,36 @@ public class CsrfTokenRequestProcessorTests {
this.response = new MockHttpServletResponse();
this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue");
this.processor = new CsrfTokenRequestProcessor();
this.processor.setTokenRepository(this.tokenRepository);
}
@Test
public void handleWhenRequestIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.processor.handle(null, this.response, () -> this.token))
assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(null, this.response))
.withMessage("request cannot be null");
}
@Test
public void handleWhenResponseIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.processor.handle(this.request, null, () -> this.token))
assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(this.request, null))
.withMessage("response cannot be null");
}
@Test
public void handleWhenCsrfTokenSupplierIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(this.request, this.response, null))
.withMessage("csrfToken supplier cannot be null");
}
@Test
public void handleWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.processor.handle(this.request, this.response, () -> null))
.withMessage("csrfToken cannot be null");
}
@Test
public void handleWhenCsrfRequestAttributeSetThenUsed() {
given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
this.processor.setCsrfRequestAttributeName("_csrf");
this.processor.handle(this.request, this.response, () -> this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(this.request.getAttribute("_csrf")).isEqualTo(this.token);
this.processor.handle(this.request, this.response);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute("_csrf")).isEqualTo(this.token);
}
@Test
public void handleWhenValidParametersThenRequestAttributesSet() {
this.processor.handle(this.request, this.response, () -> this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
this.processor.handle(this.request, this.response);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
}
@Test