diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java index fc99d177ab..1c91be9378 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,10 +36,13 @@ import reactor.util.context.Context; import org.springframework.beans.factory.DisposableBean; import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.util.Assert; import org.springframework.web.context.request.RequestAttributes; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; @@ -61,24 +64,37 @@ import org.springframework.web.context.request.ServletRequestAttributes; @Configuration(proxyBeanMethods = false) class SecurityReactorContextConfiguration { + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + @Bean SecurityReactorContextSubscriberRegistrar securityReactorContextSubscriberRegistrar() { - return new SecurityReactorContextSubscriberRegistrar(); + SecurityReactorContextSubscriberRegistrar registrar = new SecurityReactorContextSubscriberRegistrar(); + registrar.setSecurityContextHolderStrategy(this.securityContextHolderStrategy); + return registrar; + } + + @Autowired(required = false) + void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; } static class SecurityReactorContextSubscriberRegistrar implements InitializingBean, DisposableBean { private static final String SECURITY_REACTOR_CONTEXT_OPERATOR_KEY = "org.springframework.security.SECURITY_REACTOR_CONTEXT_OPERATOR"; - private static final Map> CONTEXT_ATTRIBUTE_VALUE_LOADERS = new HashMap<>(); + private final Map> CONTEXT_ATTRIBUTE_VALUE_LOADERS = new HashMap<>(); - static { - CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletRequest.class, + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + + SecurityReactorContextSubscriberRegistrar() { + this.CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletRequest.class, SecurityReactorContextSubscriberRegistrar::getRequest); - CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletResponse.class, + this.CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletResponse.class, SecurityReactorContextSubscriberRegistrar::getResponse); - CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(Authentication.class, - SecurityReactorContextSubscriberRegistrar::getAuthentication); + this.CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(Authentication.class, this::getAuthentication); } @Override @@ -93,6 +109,11 @@ class SecurityReactorContextConfiguration { Hooks.resetOnLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY); } + void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + CoreSubscriber createSubscriberIfNecessary(CoreSubscriber delegate) { if (delegate.currentContext().hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES)) { // Already enriched. No need to create Subscriber so return original @@ -101,8 +122,8 @@ class SecurityReactorContextConfiguration { return new SecurityReactorContextSubscriber<>(delegate, getContextAttributes()); } - private static Map getContextAttributes() { - return new LoadingMap<>(CONTEXT_ATTRIBUTE_VALUE_LOADERS); + private Map getContextAttributes() { + return new LoadingMap<>(this.CONTEXT_ATTRIBUTE_VALUE_LOADERS); } private static HttpServletRequest getRequest() { @@ -123,8 +144,8 @@ class SecurityReactorContextConfiguration { return null; } - private static Authentication getAuthentication() { - return SecurityContextHolder.getContext().getAuthentication(); + private Authentication getAuthentication() { + return this.securityContextHolderStrategy.getContext().getAuthentication(); } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationResourceServerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationResourceServerTests.java index a2f099c0bb..f89f65a678 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationResourceServerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationResourceServerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,9 +28,11 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContextExtension; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthentication; import org.springframework.security.oauth2.server.resource.authentication.TestBearerTokenAuthentications; import org.springframework.security.oauth2.server.resource.web.reactive.function.client.ServletBearerExchangeFilterFunction; @@ -40,6 +42,8 @@ import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.reactive.function.client.WebClient; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.verify; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; @@ -85,6 +89,21 @@ public class SecurityReactorContextConfigurationResourceServerTests { // @formatter:on } + @Test + public void requestWhenCustomSecurityContextHolderStrategyThenUses() throws Exception { + BearerTokenAuthentication authentication = TestBearerTokenAuthentications.bearer(); + this.spring.register(BearerFilterConfig.class, WebServerConfig.class, Controller.class, + SecurityContextChangedListenerConfig.class).autowire(); + MockHttpServletRequestBuilder authenticatedRequest = get("/token").with(authentication(authentication)); + // @formatter:off + this.mockMvc.perform(authenticatedRequest) + .andExpect(status().isOk()) + .andExpect(content().string("Bearer token")); + // @formatter:on + SecurityContextHolderStrategy strategy = this.spring.getContext().getBean(SecurityContextHolderStrategy.class); + verify(strategy, atLeastOnce()).getContext(); + } + @EnableWebSecurity static class BearerFilterConfig extends WebSecurityConfigurerAdapter { diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java index 00497ed67c..baee67b1b7 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,12 +38,14 @@ import org.springframework.http.HttpStatus; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber; import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContextExtension; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction; import org.springframework.web.context.request.RequestAttributes; import org.springframework.web.context.request.RequestContextHolder; @@ -54,6 +56,8 @@ import org.springframework.web.reactive.function.client.ExchangeFilterFunction; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.entry; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; /** * Tests for {@link SecurityReactorContextConfiguration}. @@ -232,6 +236,38 @@ public class SecurityReactorContextConfigurationTests { // @formatter:on } + @Test + public void createPublisherWhenCustomSecurityContextHolderStrategyThenUses() { + this.spring.register(SecurityConfig.class, SecurityContextChangedListenerConfig.class).autowire(); + SecurityContextHolderStrategy strategy = this.spring.getContext().getBean(SecurityContextHolderStrategy.class); + strategy.getContext().setAuthentication(this.authentication); + ClientResponse clientResponseOk = ClientResponse.create(HttpStatus.OK).build(); + // @formatter:off + ExchangeFilterFunction filter = (req, next) -> Mono.deferContextual(Mono::just) + .filter((ctx) -> ctx.hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES)) + .map((ctx) -> ctx.get(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES)) + .cast(Map.class) + .map((attributes) -> clientResponseOk); + // @formatter:on + ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build(); + MockExchangeFunction exchange = new MockExchangeFunction(); + Map expectedContextAttributes = new HashMap<>(); + expectedContextAttributes.put(HttpServletRequest.class, null); + expectedContextAttributes.put(HttpServletResponse.class, null); + expectedContextAttributes.put(Authentication.class, this.authentication); + Mono clientResponseMono = filter.filter(clientRequest, exchange) + .flatMap((response) -> filter.filter(clientRequest, exchange)); + // @formatter:off + StepVerifier.create(clientResponseMono) + .expectAccessibleContext() + .contains(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES, expectedContextAttributes) + .then() + .expectNext(clientResponseOk) + .verifyComplete(); + // @formatter:on + verify(strategy, times(2)).getContext(); + } + @EnableWebSecurity static class SecurityConfig extends WebSecurityConfigurerAdapter {