1
0
mirror of synced 2026-05-22 13:23:17 +00:00

Migrate to BDD Mockito

Migrate Mockito imports to use the BDD variant. This aligns better with
the "given" / "when" / "then" style used in most tests since the "given"
block now uses Mockito `given(...)` calls.

The commit also updates a few tests that were accidentally using
Power Mockito when regular Mockito could be used.

Issue gh-8945
This commit is contained in:
Phillip Webb
2020-07-27 12:53:19 -07:00
committed by Rob Winch
parent c12ced6aaa
commit db55ef4b3b
259 changed files with 2126 additions and 2125 deletions
@@ -28,6 +28,8 @@ import java.util.function.Consumer;
import java.util.function.Function;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.AssertingPartyDetails;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.ProviderDetails;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.util.Assert;
@@ -69,10 +69,10 @@ import static java.util.Collections.singleton;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
@@ -260,7 +260,7 @@ public class OpenSamlAuthenticationProviderTests {
Element attributeElement = element("<element>value</element>");
Marshaller marshaller = mock(Marshaller.class);
when(marshaller.marshall(any(XMLObject.class))).thenReturn(attributeElement);
given(marshaller.marshall(any(XMLObject.class))).willReturn(attributeElement);
try {
XMLObjectProviderRegistrySupport.getMarshallerFactory()
@@ -375,9 +375,9 @@ public class OpenSamlAuthenticationProviderTests {
response.getAssertions().add(assertion);
signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
when(validator.getServicedCondition()).thenReturn(OneTimeUse.DEFAULT_ELEMENT_NAME);
when(validator.validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class)))
.thenReturn(ValidationResult.VALID);
given(validator.getServicedCondition()).willReturn(OneTimeUse.DEFAULT_ELEMENT_NAME);
given(validator.validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class)))
.willReturn(ValidationResult.VALID);
provider.authenticate(token);
verify(validator).validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class));
}
@@ -388,7 +388,7 @@ public class OpenSamlAuthenticationProviderTests {
parameters.put(SC_VALID_RECIPIENTS, singleton(DESTINATION));
parameters.put(SIGNATURE_REQUIRED, false);
ValidationContext context = mock(ValidationContext.class);
when(context.getStaticParameters()).thenReturn(parameters);
given(context.getStaticParameters()).willReturn(parameters);
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
provider.setValidationContextConverter(tuple -> context);
Response response = response();
@@ -39,9 +39,9 @@ import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getParserPool;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getUnmarshallerFactory;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
@@ -170,7 +170,7 @@ public class OpenSamlAuthenticationRequestFactoryTests {
public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver = mock(
Function.class);
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {
given(authnRequestConsumerResolver.apply(this.context)).willReturn(authnRequest -> {
});
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
@@ -182,7 +182,7 @@ public class OpenSamlAuthenticationRequestFactoryTests {
public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver = mock(
Function.class);
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {
given(authnRequestConsumerResolver.apply(this.context)).willReturn(authnRequest -> {
});
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
@@ -31,8 +31,8 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.failBecauseExceptionWasNotThrown;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class Saml2WebSsoAuthenticationFilterTests {
@@ -81,7 +81,7 @@ public class Saml2WebSsoAuthenticationFilterTests {
@Test
public void attemptAuthenticationWhenRegistrationIdDoesNotExistThenThrowsException() {
when(this.repository.findByRegistrationId("non-existent-id")).thenReturn(null);
given(this.repository.findByRegistrationId("non-existent-id")).willReturn(null);
this.filter = new Saml2WebSsoAuthenticationFilter(this.repository, "/some/other/path/{registrationId}");
@@ -38,10 +38,10 @@ import org.springframework.web.util.UriUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
@@ -83,14 +83,14 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
@Test
public void doFilterWhenNoRelayStateThenRedirectDoesNotContainParameter() throws ServletException, IOException {
when(this.repository.findByRegistrationId("registration-id")).thenReturn(this.rpBuilder.build());
given(this.repository.findByRegistrationId("registration-id")).willReturn(this.rpBuilder.build());
this.filter.doFilterInternal(this.request, this.response, this.filterChain);
assertThat(this.response.getHeader("Location")).doesNotContain("RelayState=").startsWith(IDP_SSO_URL);
}
@Test
public void doFilterWhenRelayStateThenRedirectDoesContainParameter() throws ServletException, IOException {
when(this.repository.findByRegistrationId("registration-id")).thenReturn(this.rpBuilder.build());
given(this.repository.findByRegistrationId("registration-id")).willReturn(this.rpBuilder.build());
this.request.setParameter("RelayState", "my-relay-state");
this.filter.doFilterInternal(this.request, this.response, this.filterChain);
assertThat(this.response.getHeader("Location")).contains("RelayState=my-relay-state").startsWith(IDP_SSO_URL);
@@ -98,7 +98,7 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
@Test
public void doFilterWhenRelayStateThatRequiresEncodingThenRedirectDoesContainsEncodedParameter() throws Exception {
when(this.repository.findByRegistrationId("registration-id")).thenReturn(this.rpBuilder.build());
given(this.repository.findByRegistrationId("registration-id")).willReturn(this.rpBuilder.build());
final String relayStateValue = "https://my-relay-state.example.com?with=param&other=param";
final String relayStateEncoded = UriUtils.encode(relayStateValue, StandardCharsets.ISO_8859_1);
this.request.setParameter("RelayState", relayStateValue);
@@ -109,7 +109,7 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
@Test
public void doFilterWhenSimpleSignatureSpecifiedThenSignatureParametersAreInTheRedirectURL() throws Exception {
when(this.repository.findByRegistrationId("registration-id")).thenReturn(this.rpBuilder.build());
given(this.repository.findByRegistrationId("registration-id")).willReturn(this.rpBuilder.build());
final String relayStateValue = "https://my-relay-state.example.com?with=param&other=param";
final String relayStateEncoded = UriUtils.encode(relayStateValue, StandardCharsets.ISO_8859_1);
this.request.setParameter("RelayState", relayStateValue);
@@ -120,8 +120,8 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
@Test
public void doFilterWhenSignatureIsDisabledThenSignatureParametersAreNotInTheRedirectURL() throws Exception {
when(this.repository.findByRegistrationId("registration-id"))
.thenReturn(this.rpBuilder.providerDetails(c -> c.signAuthNRequest(false)).build());
given(this.repository.findByRegistrationId("registration-id"))
.willReturn(this.rpBuilder.providerDetails(c -> c.signAuthNRequest(false)).build());
final String relayStateValue = "https://my-relay-state.example.com?with=param&other=param";
final String relayStateEncoded = UriUtils.encode(relayStateValue, StandardCharsets.ISO_8859_1);
this.request.setParameter("RelayState", relayStateValue);
@@ -132,8 +132,8 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
@Test
public void doFilterWhenPostFormDataIsPresent() throws Exception {
when(this.repository.findByRegistrationId("registration-id"))
.thenReturn(this.rpBuilder.providerDetails(c -> c.binding(POST)).build());
given(this.repository.findByRegistrationId("registration-id"))
.willReturn(this.rpBuilder.providerDetails(c -> c.binding(POST)).build());
final String relayStateValue = "https://my-relay-state.example.com?with=param&other=param&javascript{alert('1');}";
final String relayStateEncoded = HtmlUtils.htmlEscape(relayStateValue);
this.request.setParameter("RelayState", relayStateValue);
@@ -149,11 +149,11 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
public void doFilterWhenSetAuthenticationRequestFactoryThenUses() throws Exception {
RelyingPartyRegistration relyingParty = this.rpBuilder.providerDetails(c -> c.binding(POST)).build();
Saml2PostAuthenticationRequest authenticationRequest = mock(Saml2PostAuthenticationRequest.class);
when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri");
when(authenticationRequest.getRelayState()).thenReturn("relay");
when(authenticationRequest.getSamlRequest()).thenReturn("saml");
when(this.repository.findByRegistrationId("registration-id")).thenReturn(relyingParty);
when(this.factory.createPostAuthenticationRequest(any())).thenReturn(authenticationRequest);
given(authenticationRequest.getAuthenticationRequestUri()).willReturn("uri");
given(authenticationRequest.getRelayState()).willReturn("relay");
given(authenticationRequest.getSamlRequest()).willReturn("saml");
given(this.repository.findByRegistrationId("registration-id")).willReturn(relyingParty);
given(this.factory.createPostAuthenticationRequest(any())).willReturn(authenticationRequest);
Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(this.repository);
filter.setAuthenticationRequestFactory(this.factory);
@@ -168,12 +168,12 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
public void doFilterWhenCustomAuthenticationRequestFactoryThenUses() throws Exception {
RelyingPartyRegistration relyingParty = this.rpBuilder.providerDetails(c -> c.binding(POST)).build();
Saml2PostAuthenticationRequest authenticationRequest = mock(Saml2PostAuthenticationRequest.class);
when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri");
when(authenticationRequest.getRelayState()).thenReturn("relay");
when(authenticationRequest.getSamlRequest()).thenReturn("saml");
when(this.resolver.resolve(this.request))
.thenReturn(authenticationRequestContext().relyingPartyRegistration(relyingParty).build());
when(this.factory.createPostAuthenticationRequest(any())).thenReturn(authenticationRequest);
given(authenticationRequest.getAuthenticationRequestUri()).willReturn("uri");
given(authenticationRequest.getRelayState()).willReturn("relay");
given(authenticationRequest.getSamlRequest()).willReturn("saml");
given(this.resolver.resolve(this.request))
.willReturn(authenticationRequestContext().relyingPartyRegistration(relyingParty).build());
given(this.factory.createPostAuthenticationRequest(any())).willReturn(authenticationRequest);
Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(this.resolver,
this.factory);
@@ -39,7 +39,7 @@ import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import static org.mockito.BDDMockito.given;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
@RunWith(MockitoJUnitRunner.class)
@@ -54,8 +54,8 @@ public class Saml2AuthenticationTokenConverterTests {
public void convertWhenSamlResponseThenToken() {
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
this.relyingPartyRegistrationResolver);
when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
.thenReturn(this.relyingPartyRegistration);
given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
.willReturn(this.relyingPartyRegistration);
MockHttpServletRequest request = new MockHttpServletRequest();
request.setParameter("SAMLResponse", Saml2Utils.samlEncode("response".getBytes(UTF_8)));
Saml2AuthenticationToken token = converter.convert(request);
@@ -68,8 +68,8 @@ public class Saml2AuthenticationTokenConverterTests {
public void convertWhenNoSamlResponseThenNull() {
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
this.relyingPartyRegistrationResolver);
when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
.thenReturn(this.relyingPartyRegistration);
given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
.willReturn(this.relyingPartyRegistration);
MockHttpServletRequest request = new MockHttpServletRequest();
assertThat(converter.convert(request)).isNull();
}
@@ -78,7 +78,7 @@ public class Saml2AuthenticationTokenConverterTests {
public void convertWhenNoRelyingPartyRegistrationThenNull() {
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
this.relyingPartyRegistrationResolver);
when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))).thenReturn(null);
given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))).willReturn(null);
MockHttpServletRequest request = new MockHttpServletRequest();
assertThat(converter.convert(request)).isNull();
}
@@ -87,8 +87,8 @@ public class Saml2AuthenticationTokenConverterTests {
public void convertWhenGetRequestThenInflates() {
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
this.relyingPartyRegistrationResolver);
when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
.thenReturn(this.relyingPartyRegistration);
given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
.willReturn(this.relyingPartyRegistration);
MockHttpServletRequest request = new MockHttpServletRequest();
request.setMethod("GET");
byte[] deflated = Saml2Utils.samlDeflate("response");
@@ -109,8 +109,8 @@ public class Saml2AuthenticationTokenConverterTests {
public void convertWhenUsingSamlUtilsBase64ThenXmlIsValid() throws Exception {
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
this.relyingPartyRegistrationResolver);
when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
.thenReturn(this.relyingPartyRegistration);
given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
.willReturn(this.relyingPartyRegistration);
MockHttpServletRequest request = new MockHttpServletRequest();
request.setParameter("SAMLResponse", getSsoCircleEncodedXml());
Saml2AuthenticationToken token = converter.convert(request);
@@ -30,10 +30,10 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials;
@@ -94,7 +94,7 @@ public class Saml2MetadataFilterTests {
public void doFilterWhenNoRelyingPartyRegistrationThenUnauthorized() throws Exception {
// given
this.request.setPathInfo("/saml2/service-provider-metadata/invalidRegistration");
when(this.repository.findByRegistrationId("invalidRegistration")).thenReturn(null);
given(this.repository.findByRegistrationId("invalidRegistration")).willReturn(null);
// when
this.filter.doFilter(this.request, this.response, this.chain);
@@ -114,7 +114,7 @@ public class Saml2MetadataFilterTests {
.build();
String generatedMetadata = "<xml>test</xml>";
when(this.resolver.resolve(validRegistration)).thenReturn(generatedMetadata);
given(this.resolver.resolve(validRegistration)).willReturn(generatedMetadata);
this.filter = new Saml2MetadataFilter(request -> validRegistration, this.resolver);