diff --git a/config/src/test/java/org/springframework/security/config/web/server/CorsSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/CorsSpecTests.java index 2ad6cf7b39..07d22be617 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/CorsSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/CorsSpecTests.java @@ -63,6 +63,9 @@ public class CorsSpecTests { @BeforeEach public void setup() { this.http = new TestingServerHttpSecurity().applicationContext(this.context); + } + + private void givenGetCorsConfigurationWillReturnWildcard() { CorsConfiguration value = new CorsConfiguration(); value.setAllowedOrigins(Arrays.asList("*")); given(this.source.getCorsConfiguration(any())).willReturn(value); @@ -70,6 +73,7 @@ public class CorsSpecTests { @Test public void corsWhenEnabledThenAccessControlAllowOriginAndSecurityHeaders() { + givenGetCorsConfigurationWillReturnWildcard(); this.http.cors().configurationSource(this.source); this.expectedHeaders.set("Access-Control-Allow-Origin", "*"); this.expectedHeaders.set("X-Frame-Options", "DENY"); @@ -78,6 +82,7 @@ public class CorsSpecTests { @Test public void corsWhenEnabledInLambdaThenAccessControlAllowOriginAndSecurityHeaders() { + givenGetCorsConfigurationWillReturnWildcard(); this.http.cors((cors) -> cors.configurationSource(this.source)); this.expectedHeaders.set("Access-Control-Allow-Origin", "*"); this.expectedHeaders.set("X-Frame-Options", "DENY"); @@ -86,6 +91,7 @@ public class CorsSpecTests { @Test public void corsWhenCorsConfigurationSourceBeanThenAccessControlAllowOriginAndSecurityHeaders() { + givenGetCorsConfigurationWillReturnWildcard(); given(this.context.getBeanNamesForType(any(ResolvableType.class))).willReturn(new String[] { "source" }, new String[0]); given(this.context.getBean("source")).willReturn(this.source); diff --git a/core/src/test/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerTests.java b/core/src/test/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerTests.java index 2305d9afdd..ce5d2b015e 100644 --- a/core/src/test/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerTests.java +++ b/core/src/test/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerTests.java @@ -61,6 +61,9 @@ public class DefaultMethodSecurityExpressionHandlerTests { @BeforeEach public void setup() { this.handler = new DefaultMethodSecurityExpressionHandler(); + } + + private void setupMocks() { given(this.methodInvocation.getThis()).willReturn(new Foo()); given(this.methodInvocation.getMethod()).willReturn(Foo.class.getMethods()[0]); } @@ -77,6 +80,7 @@ public class DefaultMethodSecurityExpressionHandlerTests { @Test public void createEvaluationContextCustomTrustResolver() { + setupMocks(); this.handler.setTrustResolver(this.trustResolver); Expression expression = this.handler.getExpressionParser().parseExpression("anonymous"); EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.methodInvocation); @@ -87,6 +91,7 @@ public class DefaultMethodSecurityExpressionHandlerTests { @Test @SuppressWarnings("unchecked") public void filterByKeyWhenUsingMapThenFiltersMap() { + setupMocks(); final Map map = new HashMap<>(); map.put("key1", "value1"); map.put("key2", "value2"); @@ -104,6 +109,7 @@ public class DefaultMethodSecurityExpressionHandlerTests { @Test @SuppressWarnings("unchecked") public void filterByValueWhenUsingMapThenFiltersMap() { + setupMocks(); final Map map = new HashMap<>(); map.put("key1", "value1"); map.put("key2", "value2"); @@ -121,6 +127,7 @@ public class DefaultMethodSecurityExpressionHandlerTests { @Test @SuppressWarnings("unchecked") public void filterByKeyAndValueWhenUsingMapThenFiltersMap() { + setupMocks(); final Map map = new HashMap<>(); map.put("key1", "value1"); map.put("key2", "value2"); @@ -139,6 +146,7 @@ public class DefaultMethodSecurityExpressionHandlerTests { @Test @SuppressWarnings("unchecked") public void filterWhenUsingStreamThenFiltersStream() { + setupMocks(); final Stream stream = Stream.of("1", "2", "3"); Expression expression = this.handler.getExpressionParser().parseExpression("filterObject ne '2'"); EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.methodInvocation); @@ -150,6 +158,7 @@ public class DefaultMethodSecurityExpressionHandlerTests { @Test public void filterStreamWhenClosedThenUpstreamGetsClosed() { + setupMocks(); final Stream upstream = mock(Stream.class); doReturn(Stream.empty()).when(upstream).filter(any()); Expression expression = this.handler.getExpressionParser().parseExpression("true"); diff --git a/core/src/test/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManagerTests.java b/core/src/test/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManagerTests.java index 24ddb2916e..df6f583223 100644 --- a/core/src/test/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManagerTests.java +++ b/core/src/test/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManagerTests.java @@ -78,10 +78,6 @@ public class UserDetailsRepositoryReactiveAuthenticationManagerTests { @BeforeEach public void setup() { this.manager = new UserDetailsRepositoryReactiveAuthenticationManager(this.userDetailsService); - given(this.scheduler.schedule(any())).willAnswer((a) -> { - Runnable r = a.getArgument(0); - return Schedulers.immediate().schedule(r); - }); } @Test @@ -91,6 +87,10 @@ public class UserDetailsRepositoryReactiveAuthenticationManagerTests { @Test public void authentiateWhenCustomSchedulerThenUsed() { + given(this.scheduler.schedule(any())).willAnswer((a) -> { + Runnable r = a.getArgument(0); + return Schedulers.immediate().schedule(r); + }); given(this.userDetailsService.findByUsername(any())).willReturn(Mono.just(this.user)); given(this.encoder.matches(any(), any())).willReturn(true); this.manager.setScheduler(this.scheduler); diff --git a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java index 2d830f6f03..c5110efd2a 100644 --- a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java @@ -64,6 +64,10 @@ public class DelegatingSecurityContextCallableTests { @SuppressWarnings("serial") public void setUp() throws Exception { this.originalSecurityContext = SecurityContextHolder.createEmptyContext(); + this.executor = Executors.newFixedThreadPool(1); + } + + private void givenDelegateCallWillAnswerWithCurrentSecurityContext() throws Exception { given(this.delegate.call()).willAnswer(new Returns(this.callableResult) { @Override public Object answer(InvocationOnMock invocation) throws Throwable { @@ -72,7 +76,6 @@ public class DelegatingSecurityContextCallableTests { return super.answer(invocation); } }); - this.executor = Executors.newFixedThreadPool(1); } @AfterEach @@ -104,12 +107,14 @@ public class DelegatingSecurityContextCallableTests { @Test public void call() throws Exception { + givenDelegateCallWillAnswerWithCurrentSecurityContext(); this.callable = new DelegatingSecurityContextCallable<>(this.delegate, this.securityContext); assertWrapped(this.callable); } @Test public void callDefaultSecurityContext() throws Exception { + givenDelegateCallWillAnswerWithCurrentSecurityContext(); SecurityContextHolder.setContext(this.securityContext); this.callable = new DelegatingSecurityContextCallable<>(this.delegate); // ensure callable is what sets up the SecurityContextHolder @@ -120,6 +125,7 @@ public class DelegatingSecurityContextCallableTests { // SEC-3031 @Test public void callOnSameThread() throws Exception { + givenDelegateCallWillAnswerWithCurrentSecurityContext(); this.originalSecurityContext = this.securityContext; SecurityContextHolder.setContext(this.originalSecurityContext); this.callable = new DelegatingSecurityContextCallable<>(this.delegate, this.securityContext); @@ -139,6 +145,7 @@ public class DelegatingSecurityContextCallableTests { @Test public void createNullSecurityContext() throws Exception { + givenDelegateCallWillAnswerWithCurrentSecurityContext(); SecurityContextHolder.setContext(this.securityContext); this.callable = DelegatingSecurityContextCallable.create(this.delegate, null); // ensure callable is what sets up the SecurityContextHolder @@ -148,6 +155,7 @@ public class DelegatingSecurityContextCallableTests { @Test public void create() throws Exception { + givenDelegateCallWillAnswerWithCurrentSecurityContext(); this.callable = DelegatingSecurityContextCallable.create(this.delegate, this.securityContext); assertWrapped(this.callable); } diff --git a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java index 4f01a2b255..8b1c3852f6 100644 --- a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java @@ -63,11 +63,14 @@ public class DelegatingSecurityContextRunnableTests { @BeforeEach public void setUp() { this.originalSecurityContext = SecurityContextHolder.createEmptyContext(); + this.executor = Executors.newFixedThreadPool(1); + } + + private void givenDelegateRunWillAnswerWithCurrentSecurityContext() { willAnswer((Answer) (invocation) -> { assertThat(SecurityContextHolder.getContext()).isEqualTo(this.securityContext); return null; }).given(this.delegate).run(); - this.executor = Executors.newFixedThreadPool(1); } @AfterEach @@ -99,12 +102,14 @@ public class DelegatingSecurityContextRunnableTests { @Test public void call() throws Exception { + givenDelegateRunWillAnswerWithCurrentSecurityContext(); this.runnable = new DelegatingSecurityContextRunnable(this.delegate, this.securityContext); assertWrapped(this.runnable); } @Test public void callDefaultSecurityContext() throws Exception { + givenDelegateRunWillAnswerWithCurrentSecurityContext(); SecurityContextHolder.setContext(this.securityContext); this.runnable = new DelegatingSecurityContextRunnable(this.delegate); SecurityContextHolder.clearContext(); // ensure runnable is what sets up the @@ -115,6 +120,7 @@ public class DelegatingSecurityContextRunnableTests { // SEC-3031 @Test public void callOnSameThread() throws Exception { + givenDelegateRunWillAnswerWithCurrentSecurityContext(); this.originalSecurityContext = this.securityContext; SecurityContextHolder.setContext(this.originalSecurityContext); this.executor = synchronousExecutor(); @@ -135,6 +141,7 @@ public class DelegatingSecurityContextRunnableTests { @Test public void createNullSecurityContext() throws Exception { + givenDelegateRunWillAnswerWithCurrentSecurityContext(); SecurityContextHolder.setContext(this.securityContext); this.runnable = DelegatingSecurityContextRunnable.create(this.delegate, null); SecurityContextHolder.clearContext(); // ensure runnable is what sets up the @@ -144,6 +151,7 @@ public class DelegatingSecurityContextRunnableTests { @Test public void create() throws Exception { + givenDelegateRunWillAnswerWithCurrentSecurityContext(); this.runnable = DelegatingSecurityContextRunnable.create(this.delegate, this.securityContext); assertWrapped(this.runnable); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 54c13f5e68..343a0f47ef 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -176,7 +176,18 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { this.clientRegistrationRepository, this.authorizedClientRepository); this.authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager); + } + + private void setupMocks() { + setupMockSaveAuthorizedClient(); + setupMockHeaders(); + } + + private void setupMockSaveAuthorizedClient() { given(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).willReturn(Mono.empty()); + } + + private void setupMockHeaders() { given(this.exchange.getResponse().headers()).willReturn(mock(ClientResponse.Headers.class)); } @@ -250,6 +261,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() { + setupMocks(); // @formatter:off OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse .withToken("new-token") @@ -314,6 +326,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshRequiredThenRefresh() { + setupMocks(); OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") .tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(3600).refreshToken("refresh-1").build(); given(this.refreshTokenTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(response)); @@ -353,6 +366,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() { + setupMocks(); OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") .tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(3600).refreshToken("refresh-1").build(); given(this.refreshTokenTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(response)); @@ -427,6 +441,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenUnauthorizedThenInvokeFailureHandler() { + setupMockHeaders(); this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); PublisherProbe publisherProbe = PublisherProbe.empty(); given(this.authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())) @@ -501,6 +516,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenForbiddenThenInvokeFailureHandler() { + setupMockHeaders(); this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); PublisherProbe publisherProbe = PublisherProbe.empty(); given(this.authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())) @@ -636,6 +652,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenOtherHttpStatusShouldNotInvokeFailureHandler() { + setupMockHeaders(); this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", @@ -650,6 +667,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenPasswordClientNotAuthorizedThenGetNewToken() { + setupMocks(); TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); ClientRegistration registration = TestClientRegistrations.password().build(); OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("new-token") @@ -798,6 +816,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { // gh-7544 @Test public void filterWhenClientCredentialsClientNotAuthorizedAndOutsideRequestContextThenGetNewToken() { + setupMockHeaders(); // Use UnAuthenticatedServerOAuth2AuthorizedClientRepository when operating // outside of a request context ServerOAuth2AuthorizedClientRepository unauthenticatedAuthorizedClientRepository = spy( diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index 1c01827d90..d177e1a87a 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -95,8 +95,6 @@ public class OAuth2AuthorizedClientArgumentResolverTests { this.clientRegistration = TestClientRegistrations.clientRegistration().build(); this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.authentication.getName(), TestOAuth2AccessTokens.noScopes()); - given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())) - .willReturn(Mono.just(this.authorizedClient)); } @Test @@ -146,6 +144,8 @@ public class OAuth2AuthorizedClientArgumentResolverTests { @Test public void resolveArgumentWhenRegistrationIdEmptyAndOAuth2AuthenticationThenResolves() { + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())) + .willReturn(Mono.just(this.authorizedClient)); this.authentication = mock(OAuth2AuthenticationToken.class); given(((OAuth2AuthenticationToken) this.authentication).getAuthorizedClientRegistrationId()) .willReturn("client1"); @@ -155,6 +155,8 @@ public class OAuth2AuthorizedClientArgumentResolverTests { @Test public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenResolves() { + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())) + .willReturn(Mono.just(this.authorizedClient)); this.authentication = null; MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); @@ -163,6 +165,10 @@ public class OAuth2AuthorizedClientArgumentResolverTests { @Test public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() { + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())) + .willReturn(Mono.just(this.authorizedClient)); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())) + .willReturn(Mono.just(this.authorizedClient)); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilterTests.java index 045dc2ea8c..d515c73e06 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilterTests.java @@ -72,9 +72,6 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests { FilteringWebHandler webHandler = new FilteringWebHandler((e) -> e.getResponse().setComplete(), Arrays.asList(this.filter)); this.client = WebTestClient.bindToWebHandler(webHandler).build(); - given(this.clientRepository.findByRegistrationId(this.registration.getRegistrationId())) - .willReturn(Mono.just(this.registration)); - given(this.authzRequestRepository.saveAuthorizationRequest(any(), any())).willReturn(Mono.empty()); } @Test @@ -96,6 +93,9 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests { @Test public void filterWhenDoesMatchThenClientRegistrationRepositoryNotSubscribed() { + given(this.clientRepository.findByRegistrationId(this.registration.getRegistrationId())) + .willReturn(Mono.just(this.registration)); + given(this.authzRequestRepository.saveAuthorizationRequest(any(), any())).willReturn(Mono.empty()); // @formatter:off FluxExchangeResult result = this.client.get() .uri("https://example.com/oauth2/authorization/registration-id") @@ -116,6 +116,9 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests { // gh-5520 @Test public void filterWhenDoesMatchThenResolveRedirectUriExpandedExcludesQueryString() { + given(this.clientRepository.findByRegistrationId(this.registration.getRegistrationId())) + .willReturn(Mono.just(this.registration)); + given(this.authzRequestRepository.saveAuthorizationRequest(any(), any())).willReturn(Mono.empty()); // @formatter:off FluxExchangeResult result = this.client.get() .uri("https://example.com/oauth2/authorization/registration-id?foo=bar").exchange().expectStatus() @@ -137,6 +140,9 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests { @Test public void filterWhenExceptionThenRedirected() { + given(this.clientRepository.findByRegistrationId(this.registration.getRegistrationId())) + .willReturn(Mono.just(this.registration)); + given(this.authzRequestRepository.saveAuthorizationRequest(any(), any())).willReturn(Mono.empty()); FilteringWebHandler webHandler = new FilteringWebHandler( (e) -> Mono.error(new ClientAuthorizationRequiredException(this.registration.getRegistrationId())), Arrays.asList(this.filter)); @@ -153,6 +159,9 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests { @Test public void filterWhenExceptionThenSaveRequestSessionAttribute() { + given(this.clientRepository.findByRegistrationId(this.registration.getRegistrationId())) + .willReturn(Mono.just(this.registration)); + given(this.authzRequestRepository.saveAuthorizationRequest(any(), any())).willReturn(Mono.empty()); this.filter.setRequestCache(this.requestCache); given(this.requestCache.saveRequest(any())).willReturn(Mono.empty()); FilteringWebHandler webHandler = new FilteringWebHandler( @@ -172,6 +181,9 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests { @Test public void filterWhenPathMatchesThenRequestSessionAttributeNotSaved() { + given(this.clientRepository.findByRegistrationId(this.registration.getRegistrationId())) + .willReturn(Mono.just(this.registration)); + given(this.authzRequestRepository.saveAuthorizationRequest(any(), any())).willReturn(Mono.empty()); this.filter.setRequestCache(this.requestCache); // @formatter:off this.client.get() diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurerTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurerTests.java index 80bf54749b..f4cfabe261 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurerTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurerTests.java @@ -19,7 +19,6 @@ package org.springframework.security.test.web.servlet.setup; import javax.servlet.Filter; import javax.servlet.ServletContext; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; @@ -55,14 +54,9 @@ public class SecurityMockMvcConfigurerTests { @Mock private ServletContext servletContext; - @BeforeEach - public void setup() { - given(this.context.getServletContext()).willReturn(this.servletContext); - } - @Test public void beforeMockMvcCreatedOverrideBean() throws Exception { - returnFilterBean(); + given(this.context.getServletContext()).willReturn(this.servletContext); SecurityMockMvcConfigurer configurer = new SecurityMockMvcConfigurer(this.filter); configurer.afterConfigurerAdded(this.builder); configurer.beforeMockMvcCreated(this.builder, this.context); @@ -72,6 +66,7 @@ public class SecurityMockMvcConfigurerTests { @Test public void beforeMockMvcCreatedBean() throws Exception { + given(this.context.getServletContext()).willReturn(this.servletContext); returnFilterBean(); SecurityMockMvcConfigurer configurer = new SecurityMockMvcConfigurer(); configurer.afterConfigurerAdded(this.builder); @@ -81,6 +76,7 @@ public class SecurityMockMvcConfigurerTests { @Test public void beforeMockMvcCreatedNoBean() throws Exception { + given(this.context.getServletContext()).willReturn(this.servletContext); SecurityMockMvcConfigurer configurer = new SecurityMockMvcConfigurer(this.filter); configurer.afterConfigurerAdded(this.builder); configurer.beforeMockMvcCreated(this.builder, this.context); diff --git a/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java index e4f93f4951..a8a97eebb1 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java @@ -23,7 +23,6 @@ import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -76,8 +75,7 @@ public class AuthenticationFilterTests { @Mock private RequestMatcher requestMatcher; - @BeforeEach - public void setup() { + private void givenResolveWillReturnAuthenticationManager() { given(this.authenticationManagerResolver.resolve(any())).willReturn(this.authenticationManager); } @@ -131,6 +129,7 @@ public class AuthenticationFilterTests { @Test public void filterWhenAuthenticationManagerResolverDefaultsAndAuthenticationSuccessThenContinues() throws Exception { + givenResolveWillReturnAuthenticationManager(); Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE"); given(this.authenticationConverter.convert(any())).willReturn(authentication); given(this.authenticationManager.authenticate(any())).willReturn(authentication); @@ -163,6 +162,7 @@ public class AuthenticationFilterTests { @Test public void filterWhenAuthenticationManagerResolverDefaultsAndAuthenticationFailThenUnauthorized() throws Exception { + givenResolveWillReturnAuthenticationManager(); Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE"); given(this.authenticationConverter.convert(any())).willReturn(authentication); given(this.authenticationManager.authenticate(any())).willThrow(new BadCredentialsException("failed")); @@ -191,6 +191,7 @@ public class AuthenticationFilterTests { @Test public void filterWhenConvertAndAuthenticationSuccessThenSuccess() throws Exception { + givenResolveWillReturnAuthenticationManager(); Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE_USER"); given(this.authenticationConverter.convert(any())).willReturn(authentication); given(this.authenticationManager.authenticate(any())).willReturn(authentication); @@ -208,6 +209,7 @@ public class AuthenticationFilterTests { @Test public void filterWhenConvertAndAuthenticationEmptyThenServerError() throws Exception { + givenResolveWillReturnAuthenticationManager(); Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE_USER"); given(this.authenticationConverter.convert(any())).willReturn(authentication); given(this.authenticationManager.authenticate(any())).willReturn(null); diff --git a/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java b/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java index 63c1859c21..5ad35258d4 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java @@ -56,8 +56,6 @@ public class LazyCsrfTokenRepositoryTests { @BeforeEach public void setup() { this.token = new DefaultCsrfToken("header", "param", "token"); - given(this.delegate.generateToken(this.request)).willReturn(this.token); - given(this.request.getAttribute(HttpServletResponse.class.getName())).willReturn(this.response); } @Test @@ -73,6 +71,8 @@ public class LazyCsrfTokenRepositoryTests { @Test public void generateTokenGetTokenSavesToken() { + given(this.delegate.generateToken(this.request)).willReturn(this.token); + given(this.request.getAttribute(HttpServletResponse.class.getName())).willReturn(this.response); CsrfToken newToken = this.repository.generateToken(this.request); newToken.getToken(); verify(this.delegate).saveToken(this.token, this.request, this.response); diff --git a/web/src/test/java/org/springframework/security/web/debug/DebugFilterTests.java b/web/src/test/java/org/springframework/security/web/debug/DebugFilterTests.java index 0c2ed56571..309b2454e1 100644 --- a/web/src/test/java/org/springframework/security/web/debug/DebugFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/debug/DebugFilterTests.java @@ -77,15 +77,19 @@ public class DebugFilterTests { @BeforeEach public void setUp() { - given(this.request.getHeaderNames()).willReturn(Collections.enumeration(Collections.emptyList())); - given(this.request.getServletPath()).willReturn("/login"); this.filter = new DebugFilter(this.fcp); ReflectionTestUtils.setField(this.filter, "logger", this.logger); this.requestAttr = DebugFilter.ALREADY_FILTERED_ATTR_NAME; } + private void setupMocks() { + given(this.request.getHeaderNames()).willReturn(Collections.enumeration(Collections.emptyList())); + given(this.request.getServletPath()).willReturn("/login"); + } + @Test public void doFilterProcessesRequests() throws Exception { + setupMocks(); this.filter.doFilter(this.request, this.response, this.filterChain); verify(this.logger).info(anyString()); verify(this.request).setAttribute(this.requestAttr, Boolean.TRUE); @@ -97,6 +101,7 @@ public class DebugFilterTests { // SEC-1901 @Test public void doFilterProcessesForwardedRequests() throws Exception { + setupMocks(); given(this.request.getAttribute(this.requestAttr)).willReturn(Boolean.TRUE); HttpServletRequest request = new DebugRequestWrapper(this.request); this.filter.doFilter(request, this.response, this.filterChain); @@ -107,6 +112,7 @@ public class DebugFilterTests { @Test public void doFilterDoesNotWrapWithDebugRequestWrapperAgain() throws Exception { + setupMocks(); given(this.request.getAttribute(this.requestAttr)).willReturn(Boolean.TRUE); HttpServletRequest fireWalledRequest = new HttpServletRequestWrapper(new DebugRequestWrapper(this.request)); this.filter.doFilter(fireWalledRequest, this.response, this.filterChain); diff --git a/web/src/test/java/org/springframework/security/web/server/authentication/DelegatingServerAuthenticationSuccessHandlerTests.java b/web/src/test/java/org/springframework/security/web/server/authentication/DelegatingServerAuthenticationSuccessHandlerTests.java index 3473394377..23c178f434 100644 --- a/web/src/test/java/org/springframework/security/web/server/authentication/DelegatingServerAuthenticationSuccessHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/server/authentication/DelegatingServerAuthenticationSuccessHandlerTests.java @@ -21,7 +21,6 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -60,9 +59,11 @@ public class DelegatingServerAuthenticationSuccessHandlerTests { @Mock private Authentication authentication; - @BeforeEach - public void setup() { + private void givenDelegate1WillReturnMock() { given(this.delegate1.onAuthenticationSuccess(any(), any())).willReturn(this.delegate1Result.mono()); + } + + private void givenDelegate2WillReturnMock() { given(this.delegate2.onAuthenticationSuccess(any(), any())).willReturn(this.delegate2Result.mono()); } @@ -80,6 +81,7 @@ public class DelegatingServerAuthenticationSuccessHandlerTests { @Test public void onAuthenticationSuccessWhenSingleThenExecuted() { + givenDelegate1WillReturnMock(); DelegatingServerAuthenticationSuccessHandler handler = new DelegatingServerAuthenticationSuccessHandler( this.delegate1); handler.onAuthenticationSuccess(this.exchange, this.authentication).block(); @@ -88,6 +90,8 @@ public class DelegatingServerAuthenticationSuccessHandlerTests { @Test public void onAuthenticationSuccessWhenMultipleThenExecuted() { + givenDelegate1WillReturnMock(); + givenDelegate2WillReturnMock(); DelegatingServerAuthenticationSuccessHandler handler = new DelegatingServerAuthenticationSuccessHandler( this.delegate1, this.delegate2); handler.onAuthenticationSuccess(this.exchange, this.authentication).block(); diff --git a/web/src/test/java/org/springframework/security/web/server/authentication/ServerFormLoginAuthenticationConverterTests.java b/web/src/test/java/org/springframework/security/web/server/authentication/ServerFormLoginAuthenticationConverterTests.java index 16544a9d05..40b9347646 100644 --- a/web/src/test/java/org/springframework/security/web/server/authentication/ServerFormLoginAuthenticationConverterTests.java +++ b/web/src/test/java/org/springframework/security/web/server/authentication/ServerFormLoginAuthenticationConverterTests.java @@ -16,7 +16,6 @@ package org.springframework.security.web.server.authentication; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -46,13 +45,13 @@ public class ServerFormLoginAuthenticationConverterTests { private ServerFormLoginAuthenticationConverter converter = new ServerFormLoginAuthenticationConverter(); - @BeforeEach - public void setup() { + public void setupMocks() { given(this.exchange.getFormData()).willReturn(Mono.just(this.data)); } @Test public void applyWhenUsernameAndPasswordThenCreatesTokenSuccess() { + setupMocks(); String username = "username"; String password = "password"; this.data.add("username", username); @@ -65,6 +64,7 @@ public class ServerFormLoginAuthenticationConverterTests { @Test public void applyWhenCustomParametersAndUsernameAndPasswordThenCreatesTokenSuccess() { + setupMocks(); String usernameParameter = "j_username"; String passwordParameter = "j_password"; String username = "username"; @@ -81,6 +81,7 @@ public class ServerFormLoginAuthenticationConverterTests { @Test public void applyWhenNoDataThenCreatesTokenSuccess() { + setupMocks(); Authentication authentication = this.converter.convert(this.exchange).block(); assertThat(authentication.getName()).isNullOrEmpty(); assertThat(authentication.getCredentials()).isNull(); diff --git a/web/src/test/java/org/springframework/security/web/server/authentication/ServerX509AuthenticationConverterTests.java b/web/src/test/java/org/springframework/security/web/server/authentication/ServerX509AuthenticationConverterTests.java index e2f1073244..693309e1d4 100644 --- a/web/src/test/java/org/springframework/security/web/server/authentication/ServerX509AuthenticationConverterTests.java +++ b/web/src/test/java/org/springframework/security/web/server/authentication/ServerX509AuthenticationConverterTests.java @@ -53,6 +53,9 @@ public class ServerX509AuthenticationConverterTests { public void setUp() throws Exception { this.request = MockServerHttpRequest.get("/"); this.certificate = X509TestUtils.buildTestCertificate(); + } + + private void givenExtractPrincipalWillReturn() { given(this.principalExtractor.extractPrincipal(any())).willReturn("Luke Taylor"); } @@ -65,6 +68,7 @@ public class ServerX509AuthenticationConverterTests { @Test public void shouldReturnAuthenticationForValidCertificate() { + givenExtractPrincipalWillReturn(); this.request.sslInfo(new MockSslInfo(this.certificate)); Authentication authentication = this.converter.convert(MockServerWebExchange.from(this.request.build())) .block(); diff --git a/web/src/test/java/org/springframework/security/web/server/authentication/logout/DelegatingServerLogoutHandlerTests.java b/web/src/test/java/org/springframework/security/web/server/authentication/logout/DelegatingServerLogoutHandlerTests.java index eda78ec76e..8d609a6cfb 100644 --- a/web/src/test/java/org/springframework/security/web/server/authentication/logout/DelegatingServerLogoutHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/server/authentication/logout/DelegatingServerLogoutHandlerTests.java @@ -22,7 +22,6 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -61,10 +60,12 @@ public class DelegatingServerLogoutHandlerTests { @Mock private Authentication authentication; - @BeforeEach - public void setup() { + private void givenDelegate1WillReturn() { given(this.delegate1.logout(any(WebFilterExchange.class), any(Authentication.class))) .willReturn(this.delegate1Result.mono()); + } + + private void givenDelegate2WillReturn() { given(this.delegate2.logout(any(WebFilterExchange.class), any(Authentication.class))) .willReturn(this.delegate2Result.mono()); } @@ -92,6 +93,7 @@ public class DelegatingServerLogoutHandlerTests { @Test public void logoutWhenSingleThenExecuted() { + givenDelegate1WillReturn(); DelegatingServerLogoutHandler handler = new DelegatingServerLogoutHandler(this.delegate1); handler.logout(this.exchange, this.authentication).block(); this.delegate1Result.assertWasSubscribed(); @@ -99,6 +101,8 @@ public class DelegatingServerLogoutHandlerTests { @Test public void logoutWhenMultipleThenExecuted() { + givenDelegate1WillReturn(); + givenDelegate2WillReturn(); DelegatingServerLogoutHandler handler = new DelegatingServerLogoutHandler(this.delegate1, this.delegate2); handler.logout(this.exchange, this.authentication).block(); this.delegate1Result.assertWasSubscribed(); diff --git a/web/src/test/java/org/springframework/security/web/server/authorization/ExceptionTranslationWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/authorization/ExceptionTranslationWebFilterTests.java index f258c4569c..b8a0aa9258 100644 --- a/web/src/test/java/org/springframework/security/web/server/authorization/ExceptionTranslationWebFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/server/authorization/ExceptionTranslationWebFilterTests.java @@ -74,9 +74,6 @@ public class ExceptionTranslationWebFilterTests { @BeforeEach public void setup() { - given(this.exchange.getResponse()).willReturn(new MockServerHttpResponse()); - given(this.deniedHandler.handle(any(), any())).willReturn(this.deniedPublisher.mono()); - given(this.entryPoint.commence(any(), any())).willReturn(this.entryPointPublisher.mono()); this.filter.setAuthenticationEntryPoint(this.entryPoint); this.filter.setAccessDeniedHandler(this.deniedHandler); } @@ -100,6 +97,7 @@ public class ExceptionTranslationWebFilterTests { @Test public void filterWhenAccessDeniedExceptionAndNotAuthenticatedThenHandled() { + given(this.entryPoint.commence(any(), any())).willReturn(this.entryPointPublisher.mono()); given(this.exchange.getPrincipal()).willReturn(Mono.empty()); given(this.chain.filter(this.exchange)).willReturn(Mono.error(new AccessDeniedException("Not Authorized"))); StepVerifier.create(this.filter.filter(this.exchange, this.chain)).verifyComplete(); @@ -109,6 +107,7 @@ public class ExceptionTranslationWebFilterTests { @Test public void filterWhenDefaultsAndAccessDeniedExceptionAndAuthenticatedThenForbidden() { + given(this.exchange.getResponse()).willReturn(new MockServerHttpResponse()); this.filter = new ExceptionTranslationWebFilter(); given(this.exchange.getPrincipal()).willReturn(Mono.just(this.principal)); given(this.chain.filter(this.exchange)).willReturn(Mono.error(new AccessDeniedException("Not Authorized"))); @@ -118,6 +117,7 @@ public class ExceptionTranslationWebFilterTests { @Test public void filterWhenDefaultsAndAccessDeniedExceptionAndNotAuthenticatedThenUnauthorized() { + given(this.exchange.getResponse()).willReturn(new MockServerHttpResponse()); this.filter = new ExceptionTranslationWebFilter(); given(this.exchange.getPrincipal()).willReturn(Mono.empty()); given(this.chain.filter(this.exchange)).willReturn(Mono.error(new AccessDeniedException("Not Authorized"))); @@ -127,6 +127,8 @@ public class ExceptionTranslationWebFilterTests { @Test public void filterWhenAccessDeniedExceptionAndAuthenticatedThenHandled() { + given(this.deniedHandler.handle(any(), any())).willReturn(this.deniedPublisher.mono()); + given(this.entryPoint.commence(any(), any())).willReturn(this.entryPointPublisher.mono()); given(this.exchange.getPrincipal()).willReturn(Mono.just(this.principal)); given(this.chain.filter(this.exchange)).willReturn(Mono.error(new AccessDeniedException("Not Authorized"))); StepVerifier.create(this.filter.filter(this.exchange, this.chain)).expectComplete().verify(); @@ -136,6 +138,7 @@ public class ExceptionTranslationWebFilterTests { @Test public void filterWhenAccessDeniedExceptionAndAnonymousAuthenticatedThenHandled() { + given(this.entryPoint.commence(any(), any())).willReturn(this.entryPointPublisher.mono()); given(this.exchange.getPrincipal()).willReturn(Mono.just(this.anonymousPrincipal)); given(this.chain.filter(this.exchange)).willReturn(Mono.error(new AccessDeniedException("Not Authorized"))); StepVerifier.create(this.filter.filter(this.exchange, this.chain)).expectComplete().verify(); diff --git a/web/src/test/java/org/springframework/security/web/server/context/ReactorContextWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/context/ReactorContextWebFilterTests.java index e730f5928b..7f3e6a3f1f 100644 --- a/web/src/test/java/org/springframework/security/web/server/context/ReactorContextWebFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/server/context/ReactorContextWebFilterTests.java @@ -67,7 +67,6 @@ public class ReactorContextWebFilterTests { public void setup() { this.filter = new ReactorContextWebFilter(this.repository); this.handler = WebTestHandler.bindToWebFilters(this.filter); - given(this.repository.load(any())).willReturn(this.securityContext.mono()); } @Test @@ -77,12 +76,14 @@ public class ReactorContextWebFilterTests { @Test public void filterWhenNoPrincipalAccessThenNoInteractions() { + given(this.repository.load(any())).willReturn(this.securityContext.mono()); this.handler.exchange(this.exchange); this.securityContext.assertWasNotSubscribed(); } @Test public void filterWhenGetPrincipalMonoThenNoInteractions() { + given(this.repository.load(any())).willReturn(this.securityContext.mono()); this.handler = WebTestHandler.bindToWebFilters(this.filter, (e, c) -> { ReactiveSecurityContextHolder.getContext(); return c.filter(e); @@ -105,6 +106,7 @@ public class ReactorContextWebFilterTests { @Test // gh-4962 public void filterWhenMainContextThenDoesNotOverride() { + given(this.repository.load(any())).willReturn(this.securityContext.mono()); String contextKey = "main"; WebFilter mainContextWebFilter = (e, c) -> c.filter(e).subscriberContext(Context.of(contextKey, true)); WebFilterChain chain = new DefaultWebFilterChain((e) -> Mono.empty(), mainContextWebFilter, this.filter); diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfServerLogoutHandlerTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfServerLogoutHandlerTests.java index 2fdfc91961..397b2b1890 100644 --- a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfServerLogoutHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfServerLogoutHandlerTests.java @@ -57,7 +57,6 @@ public class CsrfServerLogoutHandlerTests { this.exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); this.filterExchange = new WebFilterExchange(this.exchange, this.filterChain); this.handler = new CsrfServerLogoutHandler(this.csrfTokenRepository); - given(this.csrfTokenRepository.saveToken(this.exchange, null)).willReturn(Mono.empty()); } @Test @@ -68,6 +67,7 @@ public class CsrfServerLogoutHandlerTests { @Test public void logoutRemovesCsrfToken() { + given(this.csrfTokenRepository.saveToken(this.exchange, null)).willReturn(Mono.empty()); this.handler.logout(this.filterExchange, new TestingAuthenticationToken("user", "password", "ROLE_USER")) .block(); verify(this.csrfTokenRepository).saveToken(this.exchange, null); diff --git a/web/src/test/java/org/springframework/security/web/server/transport/HttpsRedirectWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/transport/HttpsRedirectWebFilterTests.java index d06a330bbf..38dc442cb4 100644 --- a/web/src/test/java/org/springframework/security/web/server/transport/HttpsRedirectWebFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/server/transport/HttpsRedirectWebFilterTests.java @@ -55,11 +55,11 @@ public class HttpsRedirectWebFilterTests { @BeforeEach public void configureFilter() { this.filter = new HttpsRedirectWebFilter(); - given(this.chain.filter(any(ServerWebExchange.class))).willReturn(Mono.empty()); } @Test public void filterWhenExchangeIsInsecureThenRedirects() { + given(this.chain.filter(any(ServerWebExchange.class))).willReturn(Mono.empty()); ServerWebExchange exchange = get("http://localhost"); this.filter.filter(exchange, this.chain).block(); assertThat(statusCode(exchange)).isEqualTo(302); @@ -68,6 +68,7 @@ public class HttpsRedirectWebFilterTests { @Test public void filterWhenExchangeIsSecureThenNoRedirect() { + given(this.chain.filter(any(ServerWebExchange.class))).willReturn(Mono.empty()); ServerWebExchange exchange = get("https://localhost"); this.filter.filter(exchange, this.chain).block(); assertThat(exchange.getResponse().getStatusCode()).isNull(); @@ -75,6 +76,7 @@ public class HttpsRedirectWebFilterTests { @Test public void filterWhenExchangeMismatchesThenNoRedirect() { + given(this.chain.filter(any(ServerWebExchange.class))).willReturn(Mono.empty()); ServerWebExchangeMatcher matcher = mock(ServerWebExchangeMatcher.class); given(matcher.matches(any(ServerWebExchange.class))) .willReturn(ServerWebExchangeMatcher.MatchResult.notMatch()); @@ -86,6 +88,7 @@ public class HttpsRedirectWebFilterTests { @Test public void filterWhenExchangeMatchesAndRequestIsInsecureThenRedirects() { + given(this.chain.filter(any(ServerWebExchange.class))).willReturn(Mono.empty()); ServerWebExchangeMatcher matcher = mock(ServerWebExchangeMatcher.class); given(matcher.matches(any(ServerWebExchange.class))).willReturn(ServerWebExchangeMatcher.MatchResult.match()); this.filter.setRequiresHttpsRedirectMatcher(matcher); @@ -98,6 +101,7 @@ public class HttpsRedirectWebFilterTests { @Test public void filterWhenRequestIsInsecureThenPortMapperRemapsPort() { + given(this.chain.filter(any(ServerWebExchange.class))).willReturn(Mono.empty()); PortMapper portMapper = mock(PortMapper.class); given(portMapper.lookupHttpsPort(314)).willReturn(159); this.filter.setPortMapper(portMapper); @@ -110,12 +114,14 @@ public class HttpsRedirectWebFilterTests { @Test public void filterWhenRequestIsInsecureAndNoPortMappingThenThrowsIllegalState() { + given(this.chain.filter(any(ServerWebExchange.class))).willReturn(Mono.empty()); ServerWebExchange exchange = get("http://localhost:1234"); assertThatIllegalStateException().isThrownBy(() -> this.filter.filter(exchange, this.chain).block()); } @Test public void filterWhenInsecureRequestHasAPathThenRedirects() { + given(this.chain.filter(any(ServerWebExchange.class))).willReturn(Mono.empty()); ServerWebExchange exchange = get("http://localhost:8080/path/page.html?query=string"); this.filter.filter(exchange, this.chain).block(); assertThat(statusCode(exchange)).isEqualTo(302); diff --git a/web/src/test/java/org/springframework/security/web/util/OnCommittedResponseWrapperTests.java b/web/src/test/java/org/springframework/security/web/util/OnCommittedResponseWrapperTests.java index 9b7dac899d..82a6cab806 100644 --- a/web/src/test/java/org/springframework/security/web/util/OnCommittedResponseWrapperTests.java +++ b/web/src/test/java/org/springframework/security/web/util/OnCommittedResponseWrapperTests.java @@ -59,18 +59,26 @@ public class OnCommittedResponseWrapperTests { OnCommittedResponseWrapperTests.this.committed = true; } }; + } + + private void givenGetWriterThenReturn() throws IOException { given(this.delegate.getWriter()).willReturn(this.writer); + } + + private void givenGetOutputStreamThenReturn() throws IOException { given(this.delegate.getOutputStream()).willReturn(this.out); } @Test public void printWriterHashCode() throws Exception { + givenGetWriterThenReturn(); int expected = this.writer.hashCode(); assertThat(this.response.getWriter().hashCode()).isEqualTo(expected); } @Test public void printWriterCheckError() throws Exception { + givenGetWriterThenReturn(); boolean expected = true; given(this.writer.checkError()).willReturn(expected); assertThat(this.response.getWriter().checkError()).isEqualTo(expected); @@ -78,6 +86,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterWriteInt() throws Exception { + givenGetWriterThenReturn(); int expected = 1; this.response.getWriter().write(expected); verify(this.writer).write(expected); @@ -85,6 +94,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterWriteCharIntInt() throws Exception { + givenGetWriterThenReturn(); char[] buff = new char[0]; int off = 2; int len = 3; @@ -94,6 +104,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterWriteChar() throws Exception { + givenGetWriterThenReturn(); char[] buff = new char[0]; this.response.getWriter().write(buff); verify(this.writer).write(buff); @@ -101,6 +112,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterWriteStringIntInt() throws Exception { + givenGetWriterThenReturn(); String s = ""; int off = 2; int len = 3; @@ -110,6 +122,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterWriteString() throws Exception { + givenGetWriterThenReturn(); String s = ""; this.response.getWriter().write(s); verify(this.writer).write(s); @@ -117,6 +130,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintBoolean() throws Exception { + givenGetWriterThenReturn(); boolean b = true; this.response.getWriter().print(b); verify(this.writer).print(b); @@ -124,6 +138,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintChar() throws Exception { + givenGetWriterThenReturn(); char c = 1; this.response.getWriter().print(c); verify(this.writer).print(c); @@ -131,6 +146,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintInt() throws Exception { + givenGetWriterThenReturn(); int i = 1; this.response.getWriter().print(i); verify(this.writer).print(i); @@ -138,6 +154,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintLong() throws Exception { + givenGetWriterThenReturn(); long l = 1; this.response.getWriter().print(l); verify(this.writer).print(l); @@ -145,6 +162,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintFloat() throws Exception { + givenGetWriterThenReturn(); float f = 1; this.response.getWriter().print(f); verify(this.writer).print(f); @@ -152,6 +170,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintDouble() throws Exception { + givenGetWriterThenReturn(); double x = 1; this.response.getWriter().print(x); verify(this.writer).print(x); @@ -159,6 +178,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintCharArray() throws Exception { + givenGetWriterThenReturn(); char[] x = new char[0]; this.response.getWriter().print(x); verify(this.writer).print(x); @@ -166,6 +186,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintString() throws Exception { + givenGetWriterThenReturn(); String x = "1"; this.response.getWriter().print(x); verify(this.writer).print(x); @@ -173,6 +194,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintObject() throws Exception { + givenGetWriterThenReturn(); Object x = "1"; this.response.getWriter().print(x); verify(this.writer).print(x); @@ -180,12 +202,14 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintln() throws Exception { + givenGetWriterThenReturn(); this.response.getWriter().println(); verify(this.writer).println(); } @Test public void printWriterPrintlnBoolean() throws Exception { + givenGetWriterThenReturn(); boolean b = true; this.response.getWriter().println(b); verify(this.writer).println(b); @@ -193,6 +217,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintlnChar() throws Exception { + givenGetWriterThenReturn(); char c = 1; this.response.getWriter().println(c); verify(this.writer).println(c); @@ -200,6 +225,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintlnInt() throws Exception { + givenGetWriterThenReturn(); int i = 1; this.response.getWriter().println(i); verify(this.writer).println(i); @@ -207,6 +233,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintlnLong() throws Exception { + givenGetWriterThenReturn(); long l = 1; this.response.getWriter().println(l); verify(this.writer).println(l); @@ -214,6 +241,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintlnFloat() throws Exception { + givenGetWriterThenReturn(); float f = 1; this.response.getWriter().println(f); verify(this.writer).println(f); @@ -221,6 +249,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintlnDouble() throws Exception { + givenGetWriterThenReturn(); double x = 1; this.response.getWriter().println(x); verify(this.writer).println(x); @@ -228,6 +257,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintlnCharArray() throws Exception { + givenGetWriterThenReturn(); char[] x = new char[0]; this.response.getWriter().println(x); verify(this.writer).println(x); @@ -235,6 +265,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintlnString() throws Exception { + givenGetWriterThenReturn(); String x = "1"; this.response.getWriter().println(x); verify(this.writer).println(x); @@ -242,6 +273,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintlnObject() throws Exception { + givenGetWriterThenReturn(); Object x = "1"; this.response.getWriter().println(x); verify(this.writer).println(x); @@ -249,6 +281,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintfStringObjectVargs() throws Exception { + givenGetWriterThenReturn(); String format = "format"; Object[] args = new Object[] { "1" }; this.response.getWriter().printf(format, args); @@ -257,6 +290,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterPrintfLocaleStringObjectVargs() throws Exception { + givenGetWriterThenReturn(); Locale l = Locale.US; String format = "format"; Object[] args = new Object[] { "1" }; @@ -266,6 +300,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterFormatStringObjectVargs() throws Exception { + givenGetWriterThenReturn(); String format = "format"; Object[] args = new Object[] { "1" }; this.response.getWriter().format(format, args); @@ -274,6 +309,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterFormatLocaleStringObjectVargs() throws Exception { + givenGetWriterThenReturn(); Locale l = Locale.US; String format = "format"; Object[] args = new Object[] { "1" }; @@ -283,6 +319,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterAppendCharSequence() throws Exception { + givenGetWriterThenReturn(); String x = "a"; this.response.getWriter().append(x); verify(this.writer).append(x); @@ -290,6 +327,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterAppendCharSequenceIntInt() throws Exception { + givenGetWriterThenReturn(); String x = "abcdef"; int start = 1; int end = 3; @@ -299,6 +337,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterAppendChar() throws Exception { + givenGetWriterThenReturn(); char x = 1; this.response.getWriter().append(x); verify(this.writer).append(x); @@ -307,12 +346,14 @@ public class OnCommittedResponseWrapperTests { // servletoutputstream @Test public void outputStreamHashCode() throws Exception { + givenGetOutputStreamThenReturn(); int expected = this.out.hashCode(); assertThat(this.response.getOutputStream().hashCode()).isEqualTo(expected); } @Test public void outputStreamWriteInt() throws Exception { + givenGetOutputStreamThenReturn(); int expected = 1; this.response.getOutputStream().write(expected); verify(this.out).write(expected); @@ -320,6 +361,7 @@ public class OnCommittedResponseWrapperTests { @Test public void outputStreamWriteByte() throws Exception { + givenGetOutputStreamThenReturn(); byte[] expected = new byte[0]; this.response.getOutputStream().write(expected); verify(this.out).write(expected); @@ -327,6 +369,7 @@ public class OnCommittedResponseWrapperTests { @Test public void outputStreamWriteByteIntInt() throws Exception { + givenGetOutputStreamThenReturn(); int start = 1; int end = 2; byte[] expected = new byte[0]; @@ -336,6 +379,7 @@ public class OnCommittedResponseWrapperTests { @Test public void outputStreamPrintBoolean() throws Exception { + givenGetOutputStreamThenReturn(); boolean b = true; this.response.getOutputStream().print(b); verify(this.out).print(b); @@ -343,6 +387,7 @@ public class OnCommittedResponseWrapperTests { @Test public void outputStreamPrintChar() throws Exception { + givenGetOutputStreamThenReturn(); char c = 1; this.response.getOutputStream().print(c); verify(this.out).print(c); @@ -350,6 +395,7 @@ public class OnCommittedResponseWrapperTests { @Test public void outputStreamPrintInt() throws Exception { + givenGetOutputStreamThenReturn(); int i = 1; this.response.getOutputStream().print(i); verify(this.out).print(i); @@ -357,6 +403,7 @@ public class OnCommittedResponseWrapperTests { @Test public void outputStreamPrintLong() throws Exception { + givenGetOutputStreamThenReturn(); long l = 1; this.response.getOutputStream().print(l); verify(this.out).print(l); @@ -364,6 +411,7 @@ public class OnCommittedResponseWrapperTests { @Test public void outputStreamPrintFloat() throws Exception { + givenGetOutputStreamThenReturn(); float f = 1; this.response.getOutputStream().print(f); verify(this.out).print(f); @@ -371,6 +419,7 @@ public class OnCommittedResponseWrapperTests { @Test public void outputStreamPrintDouble() throws Exception { + givenGetOutputStreamThenReturn(); double x = 1; this.response.getOutputStream().print(x); verify(this.out).print(x); @@ -378,6 +427,7 @@ public class OnCommittedResponseWrapperTests { @Test public void outputStreamPrintString() throws Exception { + givenGetOutputStreamThenReturn(); String x = "1"; this.response.getOutputStream().print(x); verify(this.out).print(x); @@ -385,12 +435,14 @@ public class OnCommittedResponseWrapperTests { @Test public void outputStreamPrintln() throws Exception { + givenGetOutputStreamThenReturn(); this.response.getOutputStream().println(); verify(this.out).println(); } @Test public void outputStreamPrintlnBoolean() throws Exception { + givenGetOutputStreamThenReturn(); boolean b = true; this.response.getOutputStream().println(b); verify(this.out).println(b); @@ -398,6 +450,7 @@ public class OnCommittedResponseWrapperTests { @Test public void outputStreamPrintlnChar() throws Exception { + givenGetOutputStreamThenReturn(); char c = 1; this.response.getOutputStream().println(c); verify(this.out).println(c); @@ -405,6 +458,7 @@ public class OnCommittedResponseWrapperTests { @Test public void outputStreamPrintlnInt() throws Exception { + givenGetOutputStreamThenReturn(); int i = 1; this.response.getOutputStream().println(i); verify(this.out).println(i); @@ -412,6 +466,7 @@ public class OnCommittedResponseWrapperTests { @Test public void outputStreamPrintlnLong() throws Exception { + givenGetOutputStreamThenReturn(); long l = 1; this.response.getOutputStream().println(l); verify(this.out).println(l); @@ -419,6 +474,7 @@ public class OnCommittedResponseWrapperTests { @Test public void outputStreamPrintlnFloat() throws Exception { + givenGetOutputStreamThenReturn(); float f = 1; this.response.getOutputStream().println(f); verify(this.out).println(f); @@ -426,6 +482,7 @@ public class OnCommittedResponseWrapperTests { @Test public void outputStreamPrintlnDouble() throws Exception { + givenGetOutputStreamThenReturn(); double x = 1; this.response.getOutputStream().println(x); verify(this.out).println(x); @@ -433,6 +490,7 @@ public class OnCommittedResponseWrapperTests { @Test public void outputStreamPrintlnString() throws Exception { + givenGetOutputStreamThenReturn(); String x = "1"; this.response.getOutputStream().println(x); verify(this.out).println(x); @@ -443,6 +501,7 @@ public class OnCommittedResponseWrapperTests { // gh-3823 @Test public void contentLengthPrintWriterWriteNullCommits() throws Exception { + givenGetWriterThenReturn(); String expected = null; this.response.setContentLength(String.valueOf(expected).length() + 1); this.response.getWriter().write(expected); @@ -453,6 +512,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterWriteIntCommits() throws Exception { + givenGetWriterThenReturn(); int expected = 1; this.response.setContentLength(String.valueOf(expected).length()); this.response.getWriter().write(expected); @@ -461,6 +521,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterWriteIntMultiDigitCommits() throws Exception { + givenGetWriterThenReturn(); int expected = 10000; this.response.setContentLength(String.valueOf(expected).length()); this.response.getWriter().write(expected); @@ -469,6 +530,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPlus1PrintWriterWriteIntMultiDigitCommits() throws Exception { + givenGetWriterThenReturn(); int expected = 10000; this.response.setContentLength(String.valueOf(expected).length() + 1); this.response.getWriter().write(expected); @@ -479,6 +541,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterWriteCharIntIntCommits() throws Exception { + givenGetWriterThenReturn(); char[] buff = new char[0]; int off = 2; int len = 3; @@ -489,6 +552,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterWriteCharCommits() throws Exception { + givenGetWriterThenReturn(); char[] buff = new char[4]; this.response.setContentLength(buff.length); this.response.getWriter().write(buff); @@ -497,6 +561,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterWriteStringIntIntCommits() throws Exception { + givenGetWriterThenReturn(); String s = ""; int off = 2; int len = 3; @@ -507,6 +572,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterWriteStringCommits() throws IOException { + givenGetWriterThenReturn(); String body = "something"; this.response.setContentLength(body.length()); this.response.getWriter().write(body); @@ -515,6 +581,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterWriteStringContentLengthCommits() throws IOException { + givenGetWriterThenReturn(); String body = "something"; this.response.getWriter().write(body); this.response.setContentLength(body.length()); @@ -523,6 +590,7 @@ public class OnCommittedResponseWrapperTests { @Test public void printWriterWriteStringDoesNotCommit() throws IOException { + givenGetWriterThenReturn(); String body = "something"; this.response.getWriter().write(body); assertThat(this.committed).isFalse(); @@ -530,6 +598,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintBooleanCommits() throws Exception { + givenGetWriterThenReturn(); boolean b = true; this.response.setContentLength(1); this.response.getWriter().print(b); @@ -538,6 +607,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintCharCommits() throws Exception { + givenGetWriterThenReturn(); char c = 1; this.response.setContentLength(1); this.response.getWriter().print(c); @@ -546,6 +616,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintIntCommits() throws Exception { + givenGetWriterThenReturn(); int i = 1234; this.response.setContentLength(String.valueOf(i).length()); this.response.getWriter().print(i); @@ -554,6 +625,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintLongCommits() throws Exception { + givenGetWriterThenReturn(); long l = 12345; this.response.setContentLength(String.valueOf(l).length()); this.response.getWriter().print(l); @@ -562,6 +634,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintFloatCommits() throws Exception { + givenGetWriterThenReturn(); float f = 12345; this.response.setContentLength(String.valueOf(f).length()); this.response.getWriter().print(f); @@ -570,6 +643,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintDoubleCommits() throws Exception { + givenGetWriterThenReturn(); double x = 1.2345; this.response.setContentLength(String.valueOf(x).length()); this.response.getWriter().print(x); @@ -578,6 +652,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintCharArrayCommits() throws Exception { + givenGetWriterThenReturn(); char[] x = new char[10]; this.response.setContentLength(x.length); this.response.getWriter().print(x); @@ -586,6 +661,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintStringCommits() throws Exception { + givenGetWriterThenReturn(); String x = "12345"; this.response.setContentLength(x.length()); this.response.getWriter().print(x); @@ -594,6 +670,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintObjectCommits() throws Exception { + givenGetWriterThenReturn(); Object x = "12345"; this.response.setContentLength(String.valueOf(x).length()); this.response.getWriter().print(x); @@ -602,6 +679,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintlnCommits() throws Exception { + givenGetWriterThenReturn(); this.response.setContentLength(NL.length()); this.response.getWriter().println(); assertThat(this.committed).isTrue(); @@ -609,6 +687,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintlnBooleanCommits() throws Exception { + givenGetWriterThenReturn(); boolean b = true; this.response.setContentLength(1); this.response.getWriter().println(b); @@ -617,6 +696,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintlnCharCommits() throws Exception { + givenGetWriterThenReturn(); char c = 1; this.response.setContentLength(1); this.response.getWriter().println(c); @@ -625,6 +705,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintlnIntCommits() throws Exception { + givenGetWriterThenReturn(); int i = 12345; this.response.setContentLength(String.valueOf(i).length()); this.response.getWriter().println(i); @@ -633,6 +714,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintlnLongCommits() throws Exception { + givenGetWriterThenReturn(); long l = 12345678; this.response.setContentLength(String.valueOf(l).length()); this.response.getWriter().println(l); @@ -641,6 +723,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintlnFloatCommits() throws Exception { + givenGetWriterThenReturn(); float f = 1234; this.response.setContentLength(String.valueOf(f).length()); this.response.getWriter().println(f); @@ -649,6 +732,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintlnDoubleCommits() throws Exception { + givenGetWriterThenReturn(); double x = 1; this.response.setContentLength(String.valueOf(x).length()); this.response.getWriter().println(x); @@ -657,6 +741,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintlnCharArrayCommits() throws Exception { + givenGetWriterThenReturn(); char[] x = new char[20]; this.response.setContentLength(x.length); this.response.getWriter().println(x); @@ -665,6 +750,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintlnStringCommits() throws Exception { + givenGetWriterThenReturn(); String x = "1"; this.response.setContentLength(String.valueOf(x).length()); this.response.getWriter().println(x); @@ -673,6 +759,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterPrintlnObjectCommits() throws Exception { + givenGetWriterThenReturn(); Object x = "1"; this.response.setContentLength(String.valueOf(x).length()); this.response.getWriter().println(x); @@ -681,6 +768,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterAppendCharSequenceCommits() throws Exception { + givenGetWriterThenReturn(); String x = "a"; this.response.setContentLength(String.valueOf(x).length()); this.response.getWriter().append(x); @@ -689,6 +777,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterAppendCharSequenceIntIntCommits() throws Exception { + givenGetWriterThenReturn(); String x = "abcdef"; int start = 1; int end = 3; @@ -699,6 +788,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPrintWriterAppendCharCommits() throws Exception { + givenGetWriterThenReturn(); char x = 1; this.response.setContentLength(1); this.response.getWriter().append(x); @@ -707,6 +797,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamWriteIntCommits() throws Exception { + givenGetOutputStreamThenReturn(); int expected = 1; this.response.setContentLength(String.valueOf(expected).length()); this.response.getOutputStream().write(expected); @@ -715,6 +806,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamWriteIntMultiDigitCommits() throws Exception { + givenGetOutputStreamThenReturn(); int expected = 10000; this.response.setContentLength(String.valueOf(expected).length()); this.response.getOutputStream().write(expected); @@ -723,6 +815,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthPlus1OutputStreamWriteIntMultiDigitCommits() throws Exception { + givenGetOutputStreamThenReturn(); int expected = 10000; this.response.setContentLength(String.valueOf(expected).length() + 1); this.response.getOutputStream().write(expected); @@ -734,6 +827,7 @@ public class OnCommittedResponseWrapperTests { // gh-171 @Test public void contentLengthPlus1OutputStreamWriteByteArrayMultiDigitCommits() throws Exception { + givenGetOutputStreamThenReturn(); String expected = "{\n" + " \"parameterName\" : \"_csrf\",\n" + " \"token\" : \"06300b65-c4aa-4c8f-8cda-39ee17f545a0\",\n" + " \"headerName\" : \"X-CSRF-TOKEN\"\n" + "}"; @@ -746,6 +840,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamPrintBooleanCommits() throws Exception { + givenGetOutputStreamThenReturn(); boolean b = true; this.response.setContentLength(1); this.response.getOutputStream().print(b); @@ -754,6 +849,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamPrintCharCommits() throws Exception { + givenGetOutputStreamThenReturn(); char c = 1; this.response.setContentLength(1); this.response.getOutputStream().print(c); @@ -762,6 +858,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamPrintIntCommits() throws Exception { + givenGetOutputStreamThenReturn(); int i = 1234; this.response.setContentLength(String.valueOf(i).length()); this.response.getOutputStream().print(i); @@ -770,6 +867,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamPrintLongCommits() throws Exception { + givenGetOutputStreamThenReturn(); long l = 12345; this.response.setContentLength(String.valueOf(l).length()); this.response.getOutputStream().print(l); @@ -778,6 +876,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamPrintFloatCommits() throws Exception { + givenGetOutputStreamThenReturn(); float f = 12345; this.response.setContentLength(String.valueOf(f).length()); this.response.getOutputStream().print(f); @@ -786,6 +885,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamPrintDoubleCommits() throws Exception { + givenGetOutputStreamThenReturn(); double x = 1.2345; this.response.setContentLength(String.valueOf(x).length()); this.response.getOutputStream().print(x); @@ -794,6 +894,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamPrintStringCommits() throws Exception { + givenGetOutputStreamThenReturn(); String x = "12345"; this.response.setContentLength(x.length()); this.response.getOutputStream().print(x); @@ -802,6 +903,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamPrintlnCommits() throws Exception { + givenGetOutputStreamThenReturn(); this.response.setContentLength(NL.length()); this.response.getOutputStream().println(); assertThat(this.committed).isTrue(); @@ -809,6 +911,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamPrintlnBooleanCommits() throws Exception { + givenGetOutputStreamThenReturn(); boolean b = true; this.response.setContentLength(1); this.response.getOutputStream().println(b); @@ -817,6 +920,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamPrintlnCharCommits() throws Exception { + givenGetOutputStreamThenReturn(); char c = 1; this.response.setContentLength(1); this.response.getOutputStream().println(c); @@ -825,6 +929,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamPrintlnIntCommits() throws Exception { + givenGetOutputStreamThenReturn(); int i = 12345; this.response.setContentLength(String.valueOf(i).length()); this.response.getOutputStream().println(i); @@ -833,6 +938,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamPrintlnLongCommits() throws Exception { + givenGetOutputStreamThenReturn(); long l = 12345678; this.response.setContentLength(String.valueOf(l).length()); this.response.getOutputStream().println(l); @@ -841,6 +947,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamPrintlnFloatCommits() throws Exception { + givenGetOutputStreamThenReturn(); float f = 1234; this.response.setContentLength(String.valueOf(f).length()); this.response.getOutputStream().println(f); @@ -849,6 +956,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamPrintlnDoubleCommits() throws Exception { + givenGetOutputStreamThenReturn(); double x = 1; this.response.setContentLength(String.valueOf(x).length()); this.response.getOutputStream().println(x); @@ -857,6 +965,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamPrintlnStringCommits() throws Exception { + givenGetOutputStreamThenReturn(); String x = "1"; this.response.setContentLength(String.valueOf(x).length()); this.response.getOutputStream().println(x); @@ -872,6 +981,7 @@ public class OnCommittedResponseWrapperTests { @Test public void contentLengthOutputStreamWriteStringCommits() throws IOException { + givenGetOutputStreamThenReturn(); String body = "something"; this.response.setContentLength(body.length()); this.response.getOutputStream().print(body); @@ -881,6 +991,7 @@ public class OnCommittedResponseWrapperTests { // gh-7261 @Test public void contentLengthLongOutputStreamWriteStringCommits() throws IOException { + givenGetOutputStreamThenReturn(); String body = "something"; this.response.setContentLengthLong(body.length()); this.response.getOutputStream().print(body); @@ -889,6 +1000,7 @@ public class OnCommittedResponseWrapperTests { @Test public void addHeaderContentLengthPrintWriterWriteStringCommits() throws Exception { + givenGetWriterThenReturn(); int expected = 1234; this.response.addHeader("Content-Length", String.valueOf(String.valueOf(expected).length())); this.response.getWriter().write(expected); @@ -897,6 +1009,7 @@ public class OnCommittedResponseWrapperTests { @Test public void bufferSizePrintWriterWriteCommits() throws Exception { + givenGetWriterThenReturn(); String expected = "1234567890"; given(this.response.getBufferSize()).willReturn(expected.length()); this.response.getWriter().write(expected); @@ -905,6 +1018,7 @@ public class OnCommittedResponseWrapperTests { @Test public void bufferSizeCommitsOnce() throws Exception { + givenGetWriterThenReturn(); String expected = "1234567890"; given(this.response.getBufferSize()).willReturn(expected.length()); this.response.getWriter().write(expected); diff --git a/web/src/test/java/org/springframework/security/web/util/TextEscapeUtilsTests.java b/web/src/test/java/org/springframework/security/web/util/TextEscapeUtilsTests.java index 6ad525eb7b..71482298c2 100644 --- a/web/src/test/java/org/springframework/security/web/util/TextEscapeUtilsTests.java +++ b/web/src/test/java/org/springframework/security/web/util/TextEscapeUtilsTests.java @@ -57,12 +57,12 @@ public class TextEscapeUtilsTests { */ @Test public void validSurrogatePairIsAccepted() { - assertThat(TextEscapeUtils.escapeEntities("abc\uD801a")).isEqualTo("abc𐐀a"); + assertThat(TextEscapeUtils.escapeEntities("abc\uD801\uDC00a")).isEqualTo("abc𐐀a"); } @Test public void undefinedSurrogatePairIsIgnored() { - assertThat(TextEscapeUtils.escapeEntities("abc\uD888a")).isEqualTo("abca"); + assertThat(TextEscapeUtils.escapeEntities("abc\uD888\uDC00a")).isEqualTo("abca"); } }