diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java
index 3d0d985117..3da6c116d9 100644
--- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java
+++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java
@@ -147,6 +147,8 @@ import org.springframework.security.web.server.context.WebSessionServerSecurityC
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
import org.springframework.security.web.server.csrf.CsrfWebFilter;
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
+import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestAttributeHandler;
+import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestHandler;
import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository;
import org.springframework.security.web.server.header.CacheControlServerHttpHeadersWriter;
import org.springframework.security.web.server.header.CompositeServerHttpHeadersWriter;
@@ -1852,12 +1854,28 @@ public class ServerHttpSecurity {
* @param enabled true if should read from multipart form body, else false.
* Default is false
* @return the {@link CsrfSpec} for additional configuration
+ * @deprecated Use
+ * {@link ServerCsrfTokenRequestAttributeHandler#setTokenFromMultipartDataEnabled(boolean)}
+ * instead
*/
+ @Deprecated
public CsrfSpec tokenFromMultipartDataEnabled(boolean enabled) {
this.filter.setTokenFromMultipartDataEnabled(enabled);
return this;
}
+ /**
+ * Specifies a {@link ServerCsrfTokenRequestHandler} that is used to make the
+ * {@code CsrfToken} available as an exchange attribute.
+ * @param requestHandler the {@link ServerCsrfTokenRequestHandler} to use
+ * @return the {@link CsrfSpec} for additional configuration
+ * @since 5.8
+ */
+ public CsrfSpec csrfTokenRequestHandler(ServerCsrfTokenRequestHandler requestHandler) {
+ this.filter.setRequestHandler(requestHandler);
+ return this;
+ }
+
/**
* Allows method chaining to continue configuring the {@link ServerHttpSecurity}
* @return the {@link ServerHttpSecurity} to continue configuring
diff --git a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerCsrfDsl.kt b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerCsrfDsl.kt
index d1cdd139df..f9c9dc5f0d 100644
--- a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerCsrfDsl.kt
+++ b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerCsrfDsl.kt
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ package org.springframework.security.config.web.server
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler
import org.springframework.security.web.server.csrf.CsrfWebFilter
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository
+import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestHandler
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher
/**
@@ -33,13 +34,17 @@ import org.springframework.security.web.server.util.matcher.ServerWebExchangeMat
* is enabled.
* @property tokenFromMultipartDataEnabled if true, the [CsrfWebFilter] should try to resolve the actual CSRF
* token from the body of multipart data requests.
+ * @property csrfTokenRequestHandler the [ServerCsrfTokenRequestHandler] that is used to make the CSRF token
+ * available as an exchange attribute
*/
@ServerSecurityMarker
class ServerCsrfDsl {
var accessDeniedHandler: ServerAccessDeniedHandler? = null
var csrfTokenRepository: ServerCsrfTokenRepository? = null
var requireCsrfProtectionMatcher: ServerWebExchangeMatcher? = null
+ @Deprecated("Use 'csrfTokenRequestHandler' instead")
var tokenFromMultipartDataEnabled: Boolean? = null
+ var csrfTokenRequestHandler: ServerCsrfTokenRequestHandler? = null
private var disabled = false
@@ -56,6 +61,7 @@ class ServerCsrfDsl {
csrfTokenRepository?.also { csrf.csrfTokenRepository(csrfTokenRepository) }
requireCsrfProtectionMatcher?.also { csrf.requireCsrfProtectionMatcher(requireCsrfProtectionMatcher) }
tokenFromMultipartDataEnabled?.also { csrf.tokenFromMultipartDataEnabled(tokenFromMultipartDataEnabled!!) }
+ csrfTokenRequestHandler?.also { csrf.csrfTokenRequestHandler(csrfTokenRequestHandler) }
if (disabled) {
csrf.disable()
}
diff --git a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java
index 5376ac9858..30bad495c0 100644
--- a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java
+++ b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -64,8 +64,11 @@ import org.springframework.security.web.server.context.SecurityContextServerWebE
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
+import org.springframework.security.web.server.csrf.CsrfToken;
import org.springframework.security.web.server.csrf.CsrfWebFilter;
+import org.springframework.security.web.server.csrf.DefaultCsrfToken;
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
+import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestHandler;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
import org.springframework.test.util.ReflectionTestUtils;
@@ -84,6 +87,7 @@ import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.springframework.security.config.Customizer.withDefaults;
@@ -500,6 +504,28 @@ public class ServerHttpSecurityTests {
verify(customServerCsrfTokenRepository).loadToken(any());
}
+ @Test
+ public void postWhenCustomRequestHandlerThenUsed() {
+ CsrfToken csrfToken = new DefaultCsrfToken("headerName", "paramName", "tokenValue");
+ given(this.csrfTokenRepository.loadToken(any(ServerWebExchange.class))).willReturn(Mono.just(csrfToken));
+ given(this.csrfTokenRepository.generateToken(any(ServerWebExchange.class))).willReturn(Mono.empty());
+ ServerCsrfTokenRequestHandler requestHandler = mock(ServerCsrfTokenRequestHandler.class);
+ given(requestHandler.resolveCsrfTokenValue(any(ServerWebExchange.class), any(CsrfToken.class)))
+ .willReturn(Mono.just(csrfToken.getToken()));
+ // @formatter:off
+ this.http.csrf((csrf) -> csrf
+ .csrfTokenRepository(this.csrfTokenRepository)
+ .csrfTokenRequestHandler(requestHandler)
+ );
+ // @formatter:on
+ WebTestClient client = buildClient();
+ client.post().uri("/").exchange().expectStatus().isOk();
+ verify(this.csrfTokenRepository, times(2)).loadToken(any(ServerWebExchange.class));
+ verify(this.csrfTokenRepository).generateToken(any(ServerWebExchange.class));
+ verify(requestHandler).handle(any(ServerWebExchange.class), any());
+ verify(requestHandler).resolveCsrfTokenValue(any(ServerWebExchange.class), any());
+ }
+
@Test
public void shouldConfigureRequestCacheForOAuth2LoginAuthenticationEntryPointAndSuccessHandler() {
ServerRequestCache requestCache = spy(new WebSessionServerRequestCache());
diff --git a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerCsrfDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerCsrfDslTests.kt
index 659d598b27..3475364733 100644
--- a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerCsrfDslTests.kt
+++ b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerCsrfDslTests.kt
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@ import org.junit.jupiter.api.extension.ExtendWith
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.ApplicationContext
import org.springframework.context.annotation.Bean
+import org.springframework.context.annotation.Configuration
import org.springframework.http.MediaType
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity
import org.springframework.security.config.test.SpringTestContext
@@ -33,6 +34,8 @@ import org.springframework.security.web.server.authorization.ServerAccessDeniedH
import org.springframework.security.web.server.csrf.CsrfToken
import org.springframework.security.web.server.csrf.DefaultCsrfToken
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository
+import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestAttributeHandler
+import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestHandler
import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository
import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher
import org.springframework.test.web.reactive.server.WebTestClient
@@ -299,4 +302,55 @@ class ServerCsrfDslTests {
}
}
}
+
+ @Test
+ fun `csrf when custom request handler then handler used`() {
+ this.spring.register(CustomRequestHandlerConfig::class.java).autowire()
+ mockkObject(CustomRequestHandlerConfig.REPOSITORY)
+ every {
+ CustomRequestHandlerConfig.REPOSITORY.loadToken(any())
+ } returns Mono.just(this.token)
+ mockkObject(CustomRequestHandlerConfig.HANDLER)
+ every {
+ CustomRequestHandlerConfig.HANDLER.handle(any(), any())
+ } returns Unit
+ every {
+ CustomRequestHandlerConfig.HANDLER.resolveCsrfTokenValue(any(), any())
+ } returns Mono.just(this.token.token)
+
+ this.client.post()
+ .uri("/")
+ .exchange()
+ .expectStatus().isOk
+ verify(exactly = 2) { CustomRequestHandlerConfig.REPOSITORY.loadToken(any()) }
+ verify(exactly = 1) { CustomRequestHandlerConfig.HANDLER.resolveCsrfTokenValue(any(), any()) }
+ verify(exactly = 1) { CustomRequestHandlerConfig.HANDLER.handle(any(), any()) }
+ }
+
+ @Configuration
+ @EnableWebFluxSecurity
+ @EnableWebFlux
+ open class CustomRequestHandlerConfig {
+ companion object {
+ val REPOSITORY: ServerCsrfTokenRepository = WebSessionServerCsrfTokenRepository()
+ val HANDLER: ServerCsrfTokenRequestHandler = ServerCsrfTokenRequestAttributeHandler()
+ }
+
+ @Bean
+ open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain {
+ return http {
+ csrf {
+ csrfTokenRepository = REPOSITORY
+ csrfTokenRequestHandler = HANDLER
+ }
+ }
+ }
+
+ @RestController
+ internal class TestController {
+ @PostMapping("/")
+ fun home() {
+ }
+ }
+ }
}
diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java
index 241ad767b6..31d92f2041 100644
--- a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java
+++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -23,12 +23,8 @@ import java.util.Set;
import reactor.core.publisher.Mono;
-import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
-import org.springframework.http.MediaType;
-import org.springframework.http.codec.multipart.FormFieldPart;
-import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.security.crypto.codec.Utf8;
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
@@ -63,6 +59,7 @@ import org.springframework.web.server.WebFilterChain;
*
* @author Rob Winch
* @author Parikshit Dutta
+ * @author Steve Riesenberg
* @since 5.0
*/
public class CsrfWebFilter implements WebFilter {
@@ -86,7 +83,7 @@ public class CsrfWebFilter implements WebFilter {
private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(
HttpStatus.FORBIDDEN);
- private boolean isTokenFromMultipartDataEnabled;
+ private ServerCsrfTokenRequestHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler();
public void setAccessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) {
Assert.notNull(accessDeniedHandler, "accessDeniedHandler");
@@ -103,14 +100,34 @@ public class CsrfWebFilter implements WebFilter {
this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher;
}
+ /**
+ * Specifies a {@link ServerCsrfTokenRequestHandler} that is used to make the
+ * {@code CsrfToken} available as an exchange attribute.
+ *
+ * The default is {@link ServerCsrfTokenRequestAttributeHandler}.
+ * @param requestHandler the {@link ServerCsrfTokenRequestHandler} to use
+ * @since 5.8
+ */
+ public void setRequestHandler(ServerCsrfTokenRequestHandler requestHandler) {
+ Assert.notNull(requestHandler, "requestHandler cannot be null");
+ this.requestHandler = requestHandler;
+ }
+
/**
* Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token
* from the body of multipart data requests.
* @param tokenFromMultipartDataEnabled true if should read from multipart form body,
* else false. Default is false
+ * @deprecated Use
+ * {@link ServerCsrfTokenRequestAttributeHandler#setTokenFromMultipartDataEnabled(boolean)}
+ * instead
*/
+ @Deprecated
public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) {
- this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled;
+ if (this.requestHandler instanceof ServerCsrfTokenRequestAttributeHandler) {
+ ((ServerCsrfTokenRequestAttributeHandler) this.requestHandler)
+ .setTokenFromMultipartDataEnabled(tokenFromMultipartDataEnabled);
+ }
}
@Override
@@ -138,30 +155,14 @@ public class CsrfWebFilter implements WebFilter {
}
private Mono containsValidCsrfToken(ServerWebExchange exchange, CsrfToken expected) {
- return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
- .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
- .switchIfEmpty(tokenFromMultipartData(exchange, expected))
+ return this.requestHandler.resolveCsrfTokenValue(exchange, expected)
.map((actual) -> equalsConstantTime(actual, expected.getToken()));
}
- private Mono tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
- if (!this.isTokenFromMultipartDataEnabled) {
- return Mono.empty();
- }
- ServerHttpRequest request = exchange.getRequest();
- HttpHeaders headers = request.getHeaders();
- MediaType contentType = headers.getContentType();
- if (!MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType)) {
- return Mono.empty();
- }
- return exchange.getMultipartData().map((d) -> d.getFirst(expected.getParameterName())).cast(FormFieldPart.class)
- .map(FormFieldPart::value);
- }
-
private Mono continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
return Mono.defer(() -> {
Mono csrfToken = csrfToken(exchange);
- exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken);
+ this.requestHandler.handle(exchange, csrfToken);
return chain.filter(exchange);
});
}
diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestAttributeHandler.java b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestAttributeHandler.java
new file mode 100644
index 0000000000..0db4288dfd
--- /dev/null
+++ b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestAttributeHandler.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server.csrf;
+
+import reactor.core.publisher.Mono;
+
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.http.codec.multipart.FormFieldPart;
+import org.springframework.http.server.reactive.ServerHttpRequest;
+import org.springframework.util.Assert;
+import org.springframework.web.server.ServerWebExchange;
+
+/**
+ * An implementation of the {@link ServerCsrfTokenRequestHandler} interface that is
+ * capable of making the {@link CsrfToken} available as an exchange attribute and
+ * resolving the token value as either a form data value or header of the request.
+ *
+ * @author Steve Riesenberg
+ * @since 5.8
+ */
+public class ServerCsrfTokenRequestAttributeHandler implements ServerCsrfTokenRequestHandler {
+
+ private boolean isTokenFromMultipartDataEnabled;
+
+ @Override
+ public void handle(ServerWebExchange exchange, Mono csrfToken) {
+ Assert.notNull(exchange, "exchange cannot be null");
+ Assert.notNull(csrfToken, "csrfToken cannot be null");
+ exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken);
+ }
+
+ @Override
+ public Mono resolveCsrfTokenValue(ServerWebExchange exchange, CsrfToken csrfToken) {
+ return ServerCsrfTokenRequestHandler.super.resolveCsrfTokenValue(exchange, csrfToken)
+ .switchIfEmpty(tokenFromMultipartData(exchange, csrfToken));
+ }
+
+ /**
+ * Specifies if the {@code ServerCsrfTokenRequestResolver} should try to resolve the
+ * actual CSRF token from the body of multipart data requests.
+ * @param tokenFromMultipartDataEnabled true if should read from multipart form body,
+ * else false. Default is false
+ */
+ public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) {
+ this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled;
+ }
+
+ private Mono tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
+ if (!this.isTokenFromMultipartDataEnabled) {
+ return Mono.empty();
+ }
+ ServerHttpRequest request = exchange.getRequest();
+ HttpHeaders headers = request.getHeaders();
+ MediaType contentType = headers.getContentType();
+ if (!MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType)) {
+ return Mono.empty();
+ }
+ return exchange.getMultipartData().map((d) -> d.getFirst(expected.getParameterName())).cast(FormFieldPart.class)
+ .map(FormFieldPart::value);
+ }
+
+}
diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestHandler.java b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestHandler.java
new file mode 100644
index 0000000000..71fd06734b
--- /dev/null
+++ b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestHandler.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server.csrf;
+
+import reactor.core.publisher.Mono;
+
+import org.springframework.util.Assert;
+import org.springframework.web.server.ServerWebExchange;
+
+/**
+ * A callback interface that is used to make the {@link CsrfToken} created by the
+ * {@link ServerCsrfTokenRepository} available as an exchange attribute. Implementations
+ * of this interface may choose to perform additional tasks or customize how the token is
+ * made available to the application through exchange attributes.
+ *
+ * @author Steve Riesenberg
+ * @since 5.8
+ * @see ServerCsrfTokenRequestAttributeHandler
+ */
+@FunctionalInterface
+public interface ServerCsrfTokenRequestHandler extends ServerCsrfTokenRequestResolver {
+
+ /**
+ * Handles a request using a {@link CsrfToken}.
+ * @param exchange the {@code ServerWebExchange} with the request being handled
+ * @param csrfToken the {@code Mono} created by the
+ * {@link ServerCsrfTokenRepository}
+ */
+ void handle(ServerWebExchange exchange, Mono csrfToken);
+
+ @Override
+ default Mono resolveCsrfTokenValue(ServerWebExchange exchange, CsrfToken csrfToken) {
+ Assert.notNull(exchange, "exchange cannot be null");
+ Assert.notNull(csrfToken, "csrfToken cannot be null");
+ return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(csrfToken.getParameterName())))
+ .switchIfEmpty(
+ Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(csrfToken.getHeaderName())));
+ }
+
+}
diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestResolver.java b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestResolver.java
new file mode 100644
index 0000000000..483f0928ba
--- /dev/null
+++ b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestResolver.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server.csrf;
+
+import reactor.core.publisher.Mono;
+
+import org.springframework.web.server.ServerWebExchange;
+
+/**
+ * Implementations of this interface are capable of resolving the token value of a
+ * {@link CsrfToken} from the provided {@code ServerWebExchange}. Used by the
+ * {@link CsrfWebFilter}.
+ *
+ * @author Steve Riesenberg
+ * @since 5.8
+ * @see ServerCsrfTokenRequestAttributeHandler
+ */
+@FunctionalInterface
+public interface ServerCsrfTokenRequestResolver {
+
+ /**
+ * Returns the token value resolved from the provided {@code ServerWebExchange} and
+ * {@link CsrfToken} or {@code Mono.empty()} if not available.
+ * @param exchange the {@code ServerWebExchange} with the request being processed
+ * @param csrfToken the {@link CsrfToken} created by the
+ * {@link ServerCsrfTokenRepository}
+ * @return the token value resolved from the request
+ */
+ Mono resolveCsrfTokenValue(ServerWebExchange exchange, CsrfToken csrfToken);
+
+}
diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java
index ba20eaff1c..a97fad9d0b 100644
--- a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java
+++ b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -34,13 +34,17 @@ import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.reactive.function.BodyInserters;
+import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebSession;
import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
/**
@@ -65,6 +69,15 @@ public class CsrfWebFilterTests {
private MockServerWebExchange post = MockServerWebExchange.from(MockServerHttpRequest.post("/"));
+ @Test
+ public void setRequestHandlerWhenNullThenThrowsIllegalArgumentException() {
+ // @formatter:off
+ assertThatIllegalArgumentException()
+ .isThrownBy(() -> this.csrfFilter.setRequestHandler(null))
+ .withMessage("requestHandler cannot be null");
+ // @formatter:on
+ }
+
@Test
public void filterWhenGetThenSessionNotCreatedAndChainContinues() {
PublisherProbe chainResult = PublisherProbe.empty();
@@ -145,6 +158,28 @@ public class CsrfWebFilterTests {
chainResult.assertWasSubscribed();
}
+ @Test
+ public void filterWhenRequestHandlerSetThenUsed() {
+ ServerCsrfTokenRequestHandler requestHandler = mock(ServerCsrfTokenRequestHandler.class);
+ given(requestHandler.resolveCsrfTokenValue(any(ServerWebExchange.class), any(CsrfToken.class)))
+ .willReturn(Mono.just(this.token.getToken()));
+ this.csrfFilter.setRequestHandler(requestHandler);
+
+ PublisherProbe chainResult = PublisherProbe.empty();
+ given(this.chain.filter(any())).willReturn(chainResult.mono());
+ this.csrfFilter.setCsrfTokenRepository(this.repository);
+ given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
+ given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
+ this.post = MockServerWebExchange
+ .from(MockServerHttpRequest.post("/").header(this.token.getHeaderName(), this.token.getToken()));
+ Mono result = this.csrfFilter.filter(this.post, this.chain);
+ StepVerifier.create(result).verifyComplete();
+ chainResult.assertWasSubscribed();
+
+ verify(requestHandler).handle(eq(this.post), any());
+ verify(requestHandler).resolveCsrfTokenValue(this.post, this.token);
+ }
+
@Test
// gh-8452
public void matchesRequireCsrfProtectionWhenNonStandardHTTPMethodIsUsed() {
diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestAttributeHandlerTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestAttributeHandlerTests.java
new file mode 100644
index 0000000000..8ee5ed9f3a
--- /dev/null
+++ b/web/src/test/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestAttributeHandlerTests.java
@@ -0,0 +1,132 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server.csrf;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.mock.web.server.MockServerWebExchange;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+
+/**
+ * Tests for {@link ServerCsrfTokenRequestAttributeHandler}.
+ *
+ * @author Steve Riesenberg
+ * @since 5.8
+ */
+public class ServerCsrfTokenRequestAttributeHandlerTests {
+
+ private ServerCsrfTokenRequestAttributeHandler handler;
+
+ private MockServerWebExchange exchange;
+
+ private CsrfToken token;
+
+ @BeforeEach
+ public void setUp() {
+ this.handler = new ServerCsrfTokenRequestAttributeHandler();
+ this.exchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build();
+ this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue");
+ }
+
+ @Test
+ public void handleWhenExchangeIsNullThenThrowsIllegalArgumentException() {
+ // @formatter:off
+ assertThatIllegalArgumentException()
+ .isThrownBy(() -> this.handler.handle(null, Mono.just(this.token)))
+ .withMessage("exchange cannot be null");
+ // @formatter:on
+ }
+
+ @Test
+ public void handleWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() {
+ // @formatter:off
+ assertThatIllegalArgumentException()
+ .isThrownBy(() -> this.handler.handle(this.exchange, null))
+ .withMessage("csrfToken cannot be null");
+ // @formatter:on
+ }
+
+ @Test
+ public void handleWhenValidParametersThenExchangeAttributeSet() {
+ Mono csrfToken = Mono.just(this.token);
+ this.handler.handle(this.exchange, csrfToken);
+ Mono csrfTokenAttribute = this.exchange.getAttribute(CsrfToken.class.getName());
+ assertThat(csrfTokenAttribute).isNotNull();
+ assertThat(csrfTokenAttribute).isEqualTo(csrfToken);
+ }
+
+ @Test
+ public void resolveCsrfTokenValueWhenExchangeIsNullThenThrowsIllegalArgumentException() {
+ // @formatter:off
+ assertThatIllegalArgumentException()
+ .isThrownBy(() -> this.handler.resolveCsrfTokenValue(null, this.token))
+ .withMessage("exchange cannot be null");
+ // @formatter:on
+ }
+
+ @Test
+ public void resolveCsrfTokenValueWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() {
+ // @formatter:off
+ assertThatIllegalArgumentException()
+ .isThrownBy(() -> this.handler.resolveCsrfTokenValue(this.exchange, null))
+ .withMessage("csrfToken cannot be null");
+ // @formatter:on
+ }
+
+ @Test
+ public void resolveCsrfTokenValueWhenTokenNotSetThenReturnsEmptyMono() {
+ Mono csrfToken = this.handler.resolveCsrfTokenValue(this.exchange, this.token);
+ StepVerifier.create(csrfToken).verifyComplete();
+ }
+
+ @Test
+ public void resolveCsrfTokenValueWhenFormDataSetThenReturnsTokenValue() {
+ this.exchange = MockServerWebExchange.builder(MockServerHttpRequest.post("/")
+ .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE)
+ .body(this.token.getParameterName() + "=" + this.token.getToken())).build();
+ Mono csrfToken = this.handler.resolveCsrfTokenValue(this.exchange, this.token);
+ StepVerifier.create(csrfToken).expectNext(this.token.getToken()).verifyComplete();
+ }
+
+ @Test
+ public void resolveCsrfTokenValueWhenHeaderSetThenReturnsTokenValue() {
+ this.exchange = MockServerWebExchange.builder(MockServerHttpRequest.post("/")
+ .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE)
+ .header(this.token.getHeaderName(), this.token.getToken())).build();
+ Mono csrfToken = this.handler.resolveCsrfTokenValue(this.exchange, this.token);
+ StepVerifier.create(csrfToken).expectNext(this.token.getToken()).verifyComplete();
+ }
+
+ @Test
+ public void resolveCsrfTokenValueWhenHeaderAndFormDataSetThenFormDataIsPreferred() {
+ this.exchange = MockServerWebExchange.builder(MockServerHttpRequest.post("/")
+ .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE)
+ .header(this.token.getHeaderName(), "header")
+ .body(this.token.getParameterName() + "=" + this.token.getToken())).build();
+ Mono csrfToken = this.handler.resolveCsrfTokenValue(this.exchange, this.token);
+ StepVerifier.create(csrfToken).expectNext(this.token.getToken()).verifyComplete();
+ }
+
+}