From f4ca90e71952d257211431b39b6a5677af3ec7aa Mon Sep 17 00:00:00 2001 From: Steve Riesenberg Date: Thu, 1 Sep 2022 15:11:04 -0500 Subject: [PATCH] Add reactive interfaces for CSRF request handling Issue gh-11959 --- .../config/web/server/ServerHttpSecurity.java | 18 +++ .../config/web/server/ServerCsrfDsl.kt | 8 +- .../web/server/ServerHttpSecurityTests.java | 28 +++- .../config/web/server/ServerCsrfDslTests.kt | 56 +++++++- .../web/server/csrf/CsrfWebFilter.java | 51 +++---- ...erverCsrfTokenRequestAttributeHandler.java | 77 ++++++++++ .../csrf/ServerCsrfTokenRequestHandler.java | 54 +++++++ .../csrf/ServerCsrfTokenRequestResolver.java | 45 ++++++ .../web/server/csrf/CsrfWebFilterTests.java | 37 ++++- ...CsrfTokenRequestAttributeHandlerTests.java | 132 ++++++++++++++++++ 10 files changed, 477 insertions(+), 29 deletions(-) create mode 100644 web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestAttributeHandler.java create mode 100644 web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestHandler.java create mode 100644 web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestResolver.java create mode 100644 web/src/test/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestAttributeHandlerTests.java diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 3d0d985117..3da6c116d9 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -147,6 +147,8 @@ import org.springframework.security.web.server.context.WebSessionServerSecurityC import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler; import org.springframework.security.web.server.csrf.CsrfWebFilter; import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository; +import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestAttributeHandler; +import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestHandler; import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository; import org.springframework.security.web.server.header.CacheControlServerHttpHeadersWriter; import org.springframework.security.web.server.header.CompositeServerHttpHeadersWriter; @@ -1852,12 +1854,28 @@ public class ServerHttpSecurity { * @param enabled true if should read from multipart form body, else false. * Default is false * @return the {@link CsrfSpec} for additional configuration + * @deprecated Use + * {@link ServerCsrfTokenRequestAttributeHandler#setTokenFromMultipartDataEnabled(boolean)} + * instead */ + @Deprecated public CsrfSpec tokenFromMultipartDataEnabled(boolean enabled) { this.filter.setTokenFromMultipartDataEnabled(enabled); return this; } + /** + * Specifies a {@link ServerCsrfTokenRequestHandler} that is used to make the + * {@code CsrfToken} available as an exchange attribute. + * @param requestHandler the {@link ServerCsrfTokenRequestHandler} to use + * @return the {@link CsrfSpec} for additional configuration + * @since 5.8 + */ + public CsrfSpec csrfTokenRequestHandler(ServerCsrfTokenRequestHandler requestHandler) { + this.filter.setRequestHandler(requestHandler); + return this; + } + /** * Allows method chaining to continue configuring the {@link ServerHttpSecurity} * @return the {@link ServerHttpSecurity} to continue configuring diff --git a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerCsrfDsl.kt b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerCsrfDsl.kt index d1cdd139df..f9c9dc5f0d 100644 --- a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerCsrfDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerCsrfDsl.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ package org.springframework.security.config.web.server import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler import org.springframework.security.web.server.csrf.CsrfWebFilter import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository +import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestHandler import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher /** @@ -33,13 +34,17 @@ import org.springframework.security.web.server.util.matcher.ServerWebExchangeMat * is enabled. * @property tokenFromMultipartDataEnabled if true, the [CsrfWebFilter] should try to resolve the actual CSRF * token from the body of multipart data requests. + * @property csrfTokenRequestHandler the [ServerCsrfTokenRequestHandler] that is used to make the CSRF token + * available as an exchange attribute */ @ServerSecurityMarker class ServerCsrfDsl { var accessDeniedHandler: ServerAccessDeniedHandler? = null var csrfTokenRepository: ServerCsrfTokenRepository? = null var requireCsrfProtectionMatcher: ServerWebExchangeMatcher? = null + @Deprecated("Use 'csrfTokenRequestHandler' instead") var tokenFromMultipartDataEnabled: Boolean? = null + var csrfTokenRequestHandler: ServerCsrfTokenRequestHandler? = null private var disabled = false @@ -56,6 +61,7 @@ class ServerCsrfDsl { csrfTokenRepository?.also { csrf.csrfTokenRepository(csrfTokenRepository) } requireCsrfProtectionMatcher?.also { csrf.requireCsrfProtectionMatcher(requireCsrfProtectionMatcher) } tokenFromMultipartDataEnabled?.also { csrf.tokenFromMultipartDataEnabled(tokenFromMultipartDataEnabled!!) } + csrfTokenRequestHandler?.also { csrf.csrfTokenRequestHandler(csrfTokenRequestHandler) } if (disabled) { csrf.disable() } diff --git a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java index 5376ac9858..30bad495c0 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -64,8 +64,11 @@ import org.springframework.security.web.server.context.SecurityContextServerWebE import org.springframework.security.web.server.context.ServerSecurityContextRepository; import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler; +import org.springframework.security.web.server.csrf.CsrfToken; import org.springframework.security.web.server.csrf.CsrfWebFilter; +import org.springframework.security.web.server.csrf.DefaultCsrfToken; import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository; +import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestHandler; import org.springframework.security.web.server.savedrequest.ServerRequestCache; import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache; import org.springframework.test.util.ReflectionTestUtils; @@ -84,6 +87,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.springframework.security.config.Customizer.withDefaults; @@ -500,6 +504,28 @@ public class ServerHttpSecurityTests { verify(customServerCsrfTokenRepository).loadToken(any()); } + @Test + public void postWhenCustomRequestHandlerThenUsed() { + CsrfToken csrfToken = new DefaultCsrfToken("headerName", "paramName", "tokenValue"); + given(this.csrfTokenRepository.loadToken(any(ServerWebExchange.class))).willReturn(Mono.just(csrfToken)); + given(this.csrfTokenRepository.generateToken(any(ServerWebExchange.class))).willReturn(Mono.empty()); + ServerCsrfTokenRequestHandler requestHandler = mock(ServerCsrfTokenRequestHandler.class); + given(requestHandler.resolveCsrfTokenValue(any(ServerWebExchange.class), any(CsrfToken.class))) + .willReturn(Mono.just(csrfToken.getToken())); + // @formatter:off + this.http.csrf((csrf) -> csrf + .csrfTokenRepository(this.csrfTokenRepository) + .csrfTokenRequestHandler(requestHandler) + ); + // @formatter:on + WebTestClient client = buildClient(); + client.post().uri("/").exchange().expectStatus().isOk(); + verify(this.csrfTokenRepository, times(2)).loadToken(any(ServerWebExchange.class)); + verify(this.csrfTokenRepository).generateToken(any(ServerWebExchange.class)); + verify(requestHandler).handle(any(ServerWebExchange.class), any()); + verify(requestHandler).resolveCsrfTokenValue(any(ServerWebExchange.class), any()); + } + @Test public void shouldConfigureRequestCacheForOAuth2LoginAuthenticationEntryPointAndSuccessHandler() { ServerRequestCache requestCache = spy(new WebSessionServerRequestCache()); diff --git a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerCsrfDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerCsrfDslTests.kt index 659d598b27..3475364733 100644 --- a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerCsrfDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerCsrfDslTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import org.junit.jupiter.api.extension.ExtendWith import org.springframework.beans.factory.annotation.Autowired import org.springframework.context.ApplicationContext import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.Configuration import org.springframework.http.MediaType import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity import org.springframework.security.config.test.SpringTestContext @@ -33,6 +34,8 @@ import org.springframework.security.web.server.authorization.ServerAccessDeniedH import org.springframework.security.web.server.csrf.CsrfToken import org.springframework.security.web.server.csrf.DefaultCsrfToken import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository +import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestAttributeHandler +import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestHandler import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher import org.springframework.test.web.reactive.server.WebTestClient @@ -299,4 +302,55 @@ class ServerCsrfDslTests { } } } + + @Test + fun `csrf when custom request handler then handler used`() { + this.spring.register(CustomRequestHandlerConfig::class.java).autowire() + mockkObject(CustomRequestHandlerConfig.REPOSITORY) + every { + CustomRequestHandlerConfig.REPOSITORY.loadToken(any()) + } returns Mono.just(this.token) + mockkObject(CustomRequestHandlerConfig.HANDLER) + every { + CustomRequestHandlerConfig.HANDLER.handle(any(), any()) + } returns Unit + every { + CustomRequestHandlerConfig.HANDLER.resolveCsrfTokenValue(any(), any()) + } returns Mono.just(this.token.token) + + this.client.post() + .uri("/") + .exchange() + .expectStatus().isOk + verify(exactly = 2) { CustomRequestHandlerConfig.REPOSITORY.loadToken(any()) } + verify(exactly = 1) { CustomRequestHandlerConfig.HANDLER.resolveCsrfTokenValue(any(), any()) } + verify(exactly = 1) { CustomRequestHandlerConfig.HANDLER.handle(any(), any()) } + } + + @Configuration + @EnableWebFluxSecurity + @EnableWebFlux + open class CustomRequestHandlerConfig { + companion object { + val REPOSITORY: ServerCsrfTokenRepository = WebSessionServerCsrfTokenRepository() + val HANDLER: ServerCsrfTokenRequestHandler = ServerCsrfTokenRequestAttributeHandler() + } + + @Bean + open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + return http { + csrf { + csrfTokenRepository = REPOSITORY + csrfTokenRequestHandler = HANDLER + } + } + } + + @RestController + internal class TestController { + @PostMapping("/") + fun home() { + } + } + } } diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java index 241ad767b6..31d92f2041 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,12 +23,8 @@ import java.util.Set; import reactor.core.publisher.Mono; -import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; -import org.springframework.http.codec.multipart.FormFieldPart; -import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.security.crypto.codec.Utf8; import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler; import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler; @@ -63,6 +59,7 @@ import org.springframework.web.server.WebFilterChain; * * @author Rob Winch * @author Parikshit Dutta + * @author Steve Riesenberg * @since 5.0 */ public class CsrfWebFilter implements WebFilter { @@ -86,7 +83,7 @@ public class CsrfWebFilter implements WebFilter { private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler( HttpStatus.FORBIDDEN); - private boolean isTokenFromMultipartDataEnabled; + private ServerCsrfTokenRequestHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler(); public void setAccessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) { Assert.notNull(accessDeniedHandler, "accessDeniedHandler"); @@ -103,14 +100,34 @@ public class CsrfWebFilter implements WebFilter { this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher; } + /** + * Specifies a {@link ServerCsrfTokenRequestHandler} that is used to make the + * {@code CsrfToken} available as an exchange attribute. + *

+ * The default is {@link ServerCsrfTokenRequestAttributeHandler}. + * @param requestHandler the {@link ServerCsrfTokenRequestHandler} to use + * @since 5.8 + */ + public void setRequestHandler(ServerCsrfTokenRequestHandler requestHandler) { + Assert.notNull(requestHandler, "requestHandler cannot be null"); + this.requestHandler = requestHandler; + } + /** * Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token * from the body of multipart data requests. * @param tokenFromMultipartDataEnabled true if should read from multipart form body, * else false. Default is false + * @deprecated Use + * {@link ServerCsrfTokenRequestAttributeHandler#setTokenFromMultipartDataEnabled(boolean)} + * instead */ + @Deprecated public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) { - this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled; + if (this.requestHandler instanceof ServerCsrfTokenRequestAttributeHandler) { + ((ServerCsrfTokenRequestAttributeHandler) this.requestHandler) + .setTokenFromMultipartDataEnabled(tokenFromMultipartDataEnabled); + } } @Override @@ -138,30 +155,14 @@ public class CsrfWebFilter implements WebFilter { } private Mono containsValidCsrfToken(ServerWebExchange exchange, CsrfToken expected) { - return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(expected.getParameterName()))) - .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName()))) - .switchIfEmpty(tokenFromMultipartData(exchange, expected)) + return this.requestHandler.resolveCsrfTokenValue(exchange, expected) .map((actual) -> equalsConstantTime(actual, expected.getToken())); } - private Mono tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) { - if (!this.isTokenFromMultipartDataEnabled) { - return Mono.empty(); - } - ServerHttpRequest request = exchange.getRequest(); - HttpHeaders headers = request.getHeaders(); - MediaType contentType = headers.getContentType(); - if (!MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType)) { - return Mono.empty(); - } - return exchange.getMultipartData().map((d) -> d.getFirst(expected.getParameterName())).cast(FormFieldPart.class) - .map(FormFieldPart::value); - } - private Mono continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) { return Mono.defer(() -> { Mono csrfToken = csrfToken(exchange); - exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken); + this.requestHandler.handle(exchange, csrfToken); return chain.filter(exchange); }); } diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestAttributeHandler.java b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestAttributeHandler.java new file mode 100644 index 0000000000..0db4288dfd --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestAttributeHandler.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.server.csrf; + +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.codec.multipart.FormFieldPart; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; + +/** + * An implementation of the {@link ServerCsrfTokenRequestHandler} interface that is + * capable of making the {@link CsrfToken} available as an exchange attribute and + * resolving the token value as either a form data value or header of the request. + * + * @author Steve Riesenberg + * @since 5.8 + */ +public class ServerCsrfTokenRequestAttributeHandler implements ServerCsrfTokenRequestHandler { + + private boolean isTokenFromMultipartDataEnabled; + + @Override + public void handle(ServerWebExchange exchange, Mono csrfToken) { + Assert.notNull(exchange, "exchange cannot be null"); + Assert.notNull(csrfToken, "csrfToken cannot be null"); + exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken); + } + + @Override + public Mono resolveCsrfTokenValue(ServerWebExchange exchange, CsrfToken csrfToken) { + return ServerCsrfTokenRequestHandler.super.resolveCsrfTokenValue(exchange, csrfToken) + .switchIfEmpty(tokenFromMultipartData(exchange, csrfToken)); + } + + /** + * Specifies if the {@code ServerCsrfTokenRequestResolver} should try to resolve the + * actual CSRF token from the body of multipart data requests. + * @param tokenFromMultipartDataEnabled true if should read from multipart form body, + * else false. Default is false + */ + public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) { + this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled; + } + + private Mono tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) { + if (!this.isTokenFromMultipartDataEnabled) { + return Mono.empty(); + } + ServerHttpRequest request = exchange.getRequest(); + HttpHeaders headers = request.getHeaders(); + MediaType contentType = headers.getContentType(); + if (!MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType)) { + return Mono.empty(); + } + return exchange.getMultipartData().map((d) -> d.getFirst(expected.getParameterName())).cast(FormFieldPart.class) + .map(FormFieldPart::value); + } + +} diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestHandler.java b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestHandler.java new file mode 100644 index 0000000000..71fd06734b --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestHandler.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.server.csrf; + +import reactor.core.publisher.Mono; + +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; + +/** + * A callback interface that is used to make the {@link CsrfToken} created by the + * {@link ServerCsrfTokenRepository} available as an exchange attribute. Implementations + * of this interface may choose to perform additional tasks or customize how the token is + * made available to the application through exchange attributes. + * + * @author Steve Riesenberg + * @since 5.8 + * @see ServerCsrfTokenRequestAttributeHandler + */ +@FunctionalInterface +public interface ServerCsrfTokenRequestHandler extends ServerCsrfTokenRequestResolver { + + /** + * Handles a request using a {@link CsrfToken}. + * @param exchange the {@code ServerWebExchange} with the request being handled + * @param csrfToken the {@code Mono} created by the + * {@link ServerCsrfTokenRepository} + */ + void handle(ServerWebExchange exchange, Mono csrfToken); + + @Override + default Mono resolveCsrfTokenValue(ServerWebExchange exchange, CsrfToken csrfToken) { + Assert.notNull(exchange, "exchange cannot be null"); + Assert.notNull(csrfToken, "csrfToken cannot be null"); + return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(csrfToken.getParameterName()))) + .switchIfEmpty( + Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(csrfToken.getHeaderName()))); + } + +} diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestResolver.java b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestResolver.java new file mode 100644 index 0000000000..483f0928ba --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestResolver.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.server.csrf; + +import reactor.core.publisher.Mono; + +import org.springframework.web.server.ServerWebExchange; + +/** + * Implementations of this interface are capable of resolving the token value of a + * {@link CsrfToken} from the provided {@code ServerWebExchange}. Used by the + * {@link CsrfWebFilter}. + * + * @author Steve Riesenberg + * @since 5.8 + * @see ServerCsrfTokenRequestAttributeHandler + */ +@FunctionalInterface +public interface ServerCsrfTokenRequestResolver { + + /** + * Returns the token value resolved from the provided {@code ServerWebExchange} and + * {@link CsrfToken} or {@code Mono.empty()} if not available. + * @param exchange the {@code ServerWebExchange} with the request being processed + * @param csrfToken the {@link CsrfToken} created by the + * {@link ServerCsrfTokenRepository} + * @return the token value resolved from the request + */ + Mono resolveCsrfTokenValue(ServerWebExchange exchange, CsrfToken csrfToken); + +} diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java index ba20eaff1c..a97fad9d0b 100644 --- a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,13 +34,17 @@ import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebSession; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; /** @@ -65,6 +69,15 @@ public class CsrfWebFilterTests { private MockServerWebExchange post = MockServerWebExchange.from(MockServerHttpRequest.post("/")); + @Test + public void setRequestHandlerWhenNullThenThrowsIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.csrfFilter.setRequestHandler(null)) + .withMessage("requestHandler cannot be null"); + // @formatter:on + } + @Test public void filterWhenGetThenSessionNotCreatedAndChainContinues() { PublisherProbe chainResult = PublisherProbe.empty(); @@ -145,6 +158,28 @@ public class CsrfWebFilterTests { chainResult.assertWasSubscribed(); } + @Test + public void filterWhenRequestHandlerSetThenUsed() { + ServerCsrfTokenRequestHandler requestHandler = mock(ServerCsrfTokenRequestHandler.class); + given(requestHandler.resolveCsrfTokenValue(any(ServerWebExchange.class), any(CsrfToken.class))) + .willReturn(Mono.just(this.token.getToken())); + this.csrfFilter.setRequestHandler(requestHandler); + + PublisherProbe chainResult = PublisherProbe.empty(); + given(this.chain.filter(any())).willReturn(chainResult.mono()); + this.csrfFilter.setCsrfTokenRepository(this.repository); + given(this.repository.loadToken(any())).willReturn(Mono.just(this.token)); + given(this.repository.generateToken(any())).willReturn(Mono.just(this.token)); + this.post = MockServerWebExchange + .from(MockServerHttpRequest.post("/").header(this.token.getHeaderName(), this.token.getToken())); + Mono result = this.csrfFilter.filter(this.post, this.chain); + StepVerifier.create(result).verifyComplete(); + chainResult.assertWasSubscribed(); + + verify(requestHandler).handle(eq(this.post), any()); + verify(requestHandler).resolveCsrfTokenValue(this.post, this.token); + } + @Test // gh-8452 public void matchesRequireCsrfProtectionWhenNonStandardHTTPMethodIsUsed() { diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestAttributeHandlerTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestAttributeHandlerTests.java new file mode 100644 index 0000000000..8ee5ed9f3a --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestAttributeHandlerTests.java @@ -0,0 +1,132 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.server.csrf; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link ServerCsrfTokenRequestAttributeHandler}. + * + * @author Steve Riesenberg + * @since 5.8 + */ +public class ServerCsrfTokenRequestAttributeHandlerTests { + + private ServerCsrfTokenRequestAttributeHandler handler; + + private MockServerWebExchange exchange; + + private CsrfToken token; + + @BeforeEach + public void setUp() { + this.handler = new ServerCsrfTokenRequestAttributeHandler(); + this.exchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build(); + this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue"); + } + + @Test + public void handleWhenExchangeIsNullThenThrowsIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.handler.handle(null, Mono.just(this.token))) + .withMessage("exchange cannot be null"); + // @formatter:on + } + + @Test + public void handleWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.handler.handle(this.exchange, null)) + .withMessage("csrfToken cannot be null"); + // @formatter:on + } + + @Test + public void handleWhenValidParametersThenExchangeAttributeSet() { + Mono csrfToken = Mono.just(this.token); + this.handler.handle(this.exchange, csrfToken); + Mono csrfTokenAttribute = this.exchange.getAttribute(CsrfToken.class.getName()); + assertThat(csrfTokenAttribute).isNotNull(); + assertThat(csrfTokenAttribute).isEqualTo(csrfToken); + } + + @Test + public void resolveCsrfTokenValueWhenExchangeIsNullThenThrowsIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.handler.resolveCsrfTokenValue(null, this.token)) + .withMessage("exchange cannot be null"); + // @formatter:on + } + + @Test + public void resolveCsrfTokenValueWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.handler.resolveCsrfTokenValue(this.exchange, null)) + .withMessage("csrfToken cannot be null"); + // @formatter:on + } + + @Test + public void resolveCsrfTokenValueWhenTokenNotSetThenReturnsEmptyMono() { + Mono csrfToken = this.handler.resolveCsrfTokenValue(this.exchange, this.token); + StepVerifier.create(csrfToken).verifyComplete(); + } + + @Test + public void resolveCsrfTokenValueWhenFormDataSetThenReturnsTokenValue() { + this.exchange = MockServerWebExchange.builder(MockServerHttpRequest.post("/") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .body(this.token.getParameterName() + "=" + this.token.getToken())).build(); + Mono csrfToken = this.handler.resolveCsrfTokenValue(this.exchange, this.token); + StepVerifier.create(csrfToken).expectNext(this.token.getToken()).verifyComplete(); + } + + @Test + public void resolveCsrfTokenValueWhenHeaderSetThenReturnsTokenValue() { + this.exchange = MockServerWebExchange.builder(MockServerHttpRequest.post("/") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .header(this.token.getHeaderName(), this.token.getToken())).build(); + Mono csrfToken = this.handler.resolveCsrfTokenValue(this.exchange, this.token); + StepVerifier.create(csrfToken).expectNext(this.token.getToken()).verifyComplete(); + } + + @Test + public void resolveCsrfTokenValueWhenHeaderAndFormDataSetThenFormDataIsPreferred() { + this.exchange = MockServerWebExchange.builder(MockServerHttpRequest.post("/") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .header(this.token.getHeaderName(), "header") + .body(this.token.getParameterName() + "=" + this.token.getToken())).build(); + Mono csrfToken = this.handler.resolveCsrfTokenValue(this.exchange, this.token); + StepVerifier.create(csrfToken).expectNext(this.token.getToken()).verifyComplete(); + } + +}