1
0
mirror of synced 2026-05-22 13:23:17 +00:00

Remove restricted static imports

Replace static imports with class referenced methods. With the exception
of a few well known static imports, checkstyle restricts the static
imports that a class can use. For example, `asList(...)` would be
replaced with `Arrays.asList(...)`.

Issue gh-8945
This commit is contained in:
Phillip Webb
2020-07-27 21:34:26 -07:00
committed by Rob Winch
parent 9a3fa6e812
commit e9130489a6
252 changed files with 2216 additions and 2222 deletions
@@ -29,13 +29,10 @@ import org.apache.commons.logging.LogFactory;
import org.opensaml.core.config.ConfigurationService;
import org.opensaml.core.config.InitializationService;
import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.springframework.security.saml2.Saml2Exception;
import static java.lang.Boolean.FALSE;
import static java.lang.Boolean.TRUE;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.setParserPool;
/**
* An initialization service for initializing OpenSAML. Each Spring Security
* OpenSAML-based component invokes the {@link #initialize()} method at static
@@ -130,12 +127,13 @@ public class OpenSamlInitializationService {
parserPool.setMaxPoolSize(50);
Map<String, Boolean> parserBuilderFeatures = new HashMap<>();
parserBuilderFeatures.put("http://apache.org/xml/features/disallow-doctype-decl", TRUE);
parserBuilderFeatures.put(XMLConstants.FEATURE_SECURE_PROCESSING, TRUE);
parserBuilderFeatures.put("http://xml.org/sax/features/external-general-entities", FALSE);
parserBuilderFeatures.put("http://apache.org/xml/features/validation/schema/normalized-value", FALSE);
parserBuilderFeatures.put("http://xml.org/sax/features/external-parameter-entities", FALSE);
parserBuilderFeatures.put("http://apache.org/xml/features/dom/defer-node-expansion", FALSE);
parserBuilderFeatures.put("http://apache.org/xml/features/disallow-doctype-decl", Boolean.TRUE);
parserBuilderFeatures.put(XMLConstants.FEATURE_SECURE_PROCESSING, Boolean.TRUE);
parserBuilderFeatures.put("http://xml.org/sax/features/external-general-entities", Boolean.FALSE);
parserBuilderFeatures.put("http://apache.org/xml/features/validation/schema/normalized-value",
Boolean.FALSE);
parserBuilderFeatures.put("http://xml.org/sax/features/external-parameter-entities", Boolean.FALSE);
parserBuilderFeatures.put("http://apache.org/xml/features/dom/defer-node-expansion", Boolean.FALSE);
parserPool.setBuilderFeatures(parserBuilderFeatures);
try {
@@ -144,7 +142,7 @@ public class OpenSamlInitializationService {
catch (Exception e) {
throw new Saml2Exception(e);
}
setParserPool(parserPool);
XMLObjectProviderRegistrySupport.setParserPool(parserPool);
registryConsumer.accept(ConfigurationService.get(XMLObjectProviderRegistry.class));
@@ -17,17 +17,13 @@ package org.springframework.security.saml2.core;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Set;
import org.springframework.util.Assert;
import static java.util.Arrays.asList;
import static org.springframework.util.Assert.notEmpty;
import static org.springframework.util.Assert.notNull;
import static org.springframework.util.Assert.state;
/**
* An object for holding a public certificate, any associated private key, and its
* intended <a href=
@@ -127,14 +123,14 @@ public final class Saml2X509Credential {
private Saml2X509Credential(PrivateKey privateKey, boolean keyRequired, X509Certificate certificate,
Saml2X509CredentialType... types) {
notNull(certificate, "certificate cannot be null");
notEmpty(types, "credentials types cannot be empty");
Assert.notNull(certificate, "certificate cannot be null");
Assert.notEmpty(types, "credentials types cannot be empty");
if (keyRequired) {
notNull(privateKey, "privateKey cannot be null");
Assert.notNull(privateKey, "privateKey cannot be null");
}
this.privateKey = privateKey;
this.certificate = certificate;
this.credentialTypes = new LinkedHashSet<>(asList(types));
this.credentialTypes = new LinkedHashSet<>(Arrays.asList(types));
}
/**
@@ -224,7 +220,7 @@ public final class Saml2X509Credential {
break;
}
}
state(valid, () -> usage + " is not a valid usage for this credential");
Assert.state(valid, () -> usage + " is not a valid usage for this credential");
}
}
@@ -17,17 +17,13 @@ package org.springframework.security.saml2.credentials;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Set;
import org.springframework.util.Assert;
import static java.util.Arrays.asList;
import static org.springframework.util.Assert.notEmpty;
import static org.springframework.util.Assert.notNull;
import static org.springframework.util.Assert.state;
/**
* Saml2X509Credential is meant to hold an X509 certificate, or an X509 certificate and a
* private key. Per:
@@ -98,14 +94,14 @@ public class Saml2X509Credential {
private Saml2X509Credential(PrivateKey privateKey, boolean keyRequired, X509Certificate certificate,
Saml2X509CredentialType... types) {
notNull(certificate, "certificate cannot be null");
notEmpty(types, "credentials types cannot be empty");
Assert.notNull(certificate, "certificate cannot be null");
Assert.notEmpty(types, "credentials types cannot be empty");
if (keyRequired) {
notNull(privateKey, "privateKey cannot be null");
Assert.notNull(privateKey, "privateKey cannot be null");
}
this.privateKey = privateKey;
this.certificate = certificate;
this.credentialTypes = new LinkedHashSet<>(asList(types));
this.credentialTypes = new LinkedHashSet<>(Arrays.asList(types));
}
/**
@@ -198,7 +194,7 @@ public class Saml2X509Credential {
break;
}
}
state(valid, () -> usage + " is not a valid usage for this credential");
Assert.state(valid, () -> usage + " is not a valid usage for this credential");
}
}
@@ -20,6 +20,7 @@ import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
@@ -58,6 +59,7 @@ import org.opensaml.saml.criterion.ProtocolCriterion;
import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion;
import org.opensaml.saml.saml2.assertion.ConditionValidator;
import org.opensaml.saml.saml2.assertion.SAML20AssertionValidator;
import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters;
import org.opensaml.saml.saml2.assertion.StatementValidator;
import org.opensaml.saml.saml2.assertion.SubjectConfirmationValidator;
import org.opensaml.saml.saml2.assertion.impl.AudienceRestrictionConditionValidator;
@@ -107,28 +109,12 @@ import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMap
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import static java.util.Arrays.asList;
import static java.util.Collections.singleton;
import static java.util.Collections.singletonList;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.CLOCK_SKEW;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.COND_VALID_AUDIENCES;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.DECRYPTION_ERROR;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_ASSERTION;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_DESTINATION;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_ISSUER;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.MALFORMED_RESPONSE_DATA;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.SUBJECT_NOT_FOUND;
import static org.springframework.util.Assert.notNull;
/**
* Implementation of {@link AuthenticationProvider} for SAML authentications when
* receiving a {@code Response} object containing an {@code Assertion}. This
@@ -188,8 +174,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
private final ParserPool parserPool;
private Converter<Assertion, Collection<? extends GrantedAuthority>> authoritiesExtractor = (a -> singletonList(
new SimpleGrantedAuthority("ROLE_USER")));
private Converter<Assertion, Collection<? extends GrantedAuthority>> authoritiesExtractor = (a -> Collections
.singletonList(new SimpleGrantedAuthority("ROLE_USER")));
private GrantedAuthoritiesMapper authoritiesMapper = (a -> a);
@@ -268,7 +254,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
* user's authorities
*/
public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
notNull(authoritiesMapper, "authoritiesMapper cannot be null");
Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null");
this.authoritiesMapper = authoritiesMapper;
}
@@ -300,7 +286,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
throw e;
}
catch (Exception e) {
throw authException(INTERNAL_VALIDATION_ERROR, e.getMessage(), e);
throw authException(Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR, e.getMessage(), e);
}
}
@@ -324,7 +310,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return (Response) this.responseUnmarshaller.unmarshall(element);
}
catch (Exception e) {
throw authException(MALFORMED_RESPONSE_DATA, e.getMessage(), e);
throw authException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, e.getMessage(), e);
}
}
@@ -340,15 +326,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
Decrypter decrypter = this.decrypterConverter.convert(token);
List<Assertion> assertions = decryptAssertions(decrypter, response);
if (!isSigned(responseSigned, assertions)) {
throw authException(INVALID_SIGNATURE, "Either the response or one of the assertions is unsigned. "
+ "Please either sign the response or all of the assertions.");
throw authException(Saml2ErrorCodes.INVALID_SIGNATURE,
"Either the response or one of the assertions is unsigned. "
+ "Please either sign the response or all of the assertions.");
}
validationExceptions.putAll(validateAssertions(token, response));
Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
NameID nameId = decryptPrincipal(decrypter, firstAssertion);
if (nameId == null || nameId.getValue() == null) {
validationExceptions.put(SUBJECT_NOT_FOUND, authException(SUBJECT_NOT_FOUND,
validationExceptions.put(Saml2ErrorCodes.SUBJECT_NOT_FOUND, authException(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
"Assertion [" + firstAssertion.getID() + "] is missing a subject"));
}
@@ -385,8 +372,9 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
profileValidator.validate(response.getSignature());
}
catch (Exception e) {
validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
"Invalid signature for SAML Response [" + response.getID() + "]: ", e));
validationExceptions.put(Saml2ErrorCodes.INVALID_SIGNATURE,
authException(Saml2ErrorCodes.INVALID_SIGNATURE,
"Invalid signature for SAML Response [" + response.getID() + "]: ", e));
}
try {
@@ -396,13 +384,15 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)));
criteriaSet.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING)));
if (!this.signatureTrustEngineConverter.convert(token).validate(response.getSignature(), criteriaSet)) {
validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
"Invalid signature for SAML Response [" + response.getID() + "]"));
validationExceptions.put(Saml2ErrorCodes.INVALID_SIGNATURE,
authException(Saml2ErrorCodes.INVALID_SIGNATURE,
"Invalid signature for SAML Response [" + response.getID() + "]"));
}
}
catch (Exception e) {
validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
"Invalid signature for SAML Response [" + response.getID() + "]: ", e));
validationExceptions.put(Saml2ErrorCodes.INVALID_SIGNATURE,
authException(Saml2ErrorCodes.INVALID_SIGNATURE,
"Invalid signature for SAML Response [" + response.getID() + "]: ", e));
}
}
@@ -410,13 +400,15 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
if (StringUtils.hasText(destination) && !destination.equals(location)) {
String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]";
validationExceptions.put(INVALID_DESTINATION, authException(INVALID_DESTINATION, message));
validationExceptions.put(Saml2ErrorCodes.INVALID_DESTINATION,
authException(Saml2ErrorCodes.INVALID_DESTINATION, message));
}
String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId();
if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
validationExceptions.put(INVALID_ISSUER, authException(INVALID_ISSUER, message));
validationExceptions.put(Saml2ErrorCodes.INVALID_ISSUER,
authException(Saml2ErrorCodes.INVALID_ISSUER, message));
}
return validationExceptions;
@@ -430,7 +422,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
assertions.add(assertion);
}
catch (DecryptionException e) {
throw authException(DECRYPTION_ERROR, e.getMessage(), e);
throw authException(Saml2ErrorCodes.DECRYPTION_ERROR, e.getMessage(), e);
}
}
response.getAssertions().addAll(assertions);
@@ -441,7 +433,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
Response response) {
List<Assertion> assertions = response.getAssertions();
if (assertions.isEmpty()) {
throw authException(MALFORMED_RESPONSE_DATA, "No assertions found in response.");
throw authException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, "No assertions found in response.");
}
Map<String, Saml2AuthenticationException> validationExceptions = new LinkedHashMap<>();
@@ -461,13 +453,15 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
assertion.getID(), ((Response) assertion.getParent()).getID(),
context.getValidationFailureMessage());
validationExceptions.put(INVALID_ASSERTION, authException(INVALID_ASSERTION, message));
validationExceptions.put(Saml2ErrorCodes.INVALID_ASSERTION,
authException(Saml2ErrorCodes.INVALID_ASSERTION, message));
}
}
catch (Exception e) {
String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(),
((Response) assertion.getParent()).getID(), e.getMessage());
validationExceptions.put(INVALID_ASSERTION, authException(INVALID_ASSERTION, message, e));
validationExceptions.put(Saml2ErrorCodes.INVALID_ASSERTION,
authException(Saml2ErrorCodes.INVALID_ASSERTION, message, e));
}
}
@@ -501,7 +495,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return nameId;
}
catch (DecryptionException e) {
throw authException(DECRYPTION_ERROR, e.getMessage(), e);
throw authException(Saml2ErrorCodes.DECRYPTION_ERROR, e.getMessage(), e);
}
}
@@ -606,11 +600,15 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
String audience = tuple.authentication.getRelyingPartyRegistration().getEntityId();
String recipient = tuple.authentication.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
Map<String, Object> params = new HashMap<>();
params.put(CLOCK_SKEW, OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis());
params.put(COND_VALID_AUDIENCES, singleton(audience));
params.put(SC_VALID_RECIPIENTS, singleton(recipient));
params.put(SIGNATURE_REQUIRED, false); // this verification is performed
// earlier
params.put(SAML2AssertionValidationParameters.CLOCK_SKEW,
OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis());
params.put(SAML2AssertionValidationParameters.COND_VALID_AUDIENCES, Collections.singleton(audience));
params.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton(recipient));
params.put(SAML2AssertionValidationParameters.SIGNATURE_REQUIRED, false); // this
// verification
// is
// performed
// earlier
return new ValidationContext(params);
}
@@ -649,7 +647,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
private static class DecrypterConverter implements Converter<Saml2AuthenticationToken, Decrypter> {
private final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver(
asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(),
Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(),
new SimpleRetrievalMethodEncryptedKeyResolver()));
@Override
@@ -59,13 +59,9 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriUtils;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDeflate;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlEncode;
import static org.springframework.util.StringUtils.hasText;
/**
* @since 5.2
*/
@@ -130,7 +126,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
? serialize(sign(authnRequest, context.getRelyingPartyRegistration())) : serialize(authnRequest);
return Saml2PostAuthenticationRequest.withAuthenticationRequestContext(context)
.samlRequest(samlEncode(xml.getBytes(UTF_8))).build();
.samlRequest(Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8))).build();
}
/**
@@ -142,7 +138,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
AuthnRequest authnRequest = createAuthnRequest(context);
String xml = serialize(authnRequest);
Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context);
String deflatedAndEncoded = samlEncode(samlDeflate(xml));
String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml));
result.samlRequest(deflatedAndEncoded).relayState(context.getRelayState());
if (context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
@@ -264,7 +260,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
StringBuilder queryString = new StringBuilder();
queryString.append("SAMLRequest").append("=").append(UriUtils.encode(samlRequest, StandardCharsets.ISO_8859_1))
.append("&");
if (hasText(relayState)) {
if (StringUtils.hasText(relayState)) {
queryString.append("RelayState").append("=")
.append(UriUtils.encode(relayState, StandardCharsets.ISO_8859_1)).append("&");
}
@@ -277,7 +273,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
Map<String, String> result = new LinkedHashMap<>();
result.put("SAMLRequest", samlRequest);
if (hasText(relayState)) {
if (StringUtils.hasText(relayState)) {
result.put("RelayState", relayState);
}
result.put("SigAlg", algorithmUri);
@@ -22,10 +22,6 @@ import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import static org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequest.withAuthenticationRequestContext;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDeflate;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlEncode;
/**
* Component that generates AuthenticationRequest, <code>samlp:AuthnRequestType</code>
* XML, and accompanying signature data. as defined by
@@ -81,9 +77,10 @@ public interface Saml2AuthenticationRequestFactory {
default Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(
Saml2AuthenticationRequestContext context) {
// backwards compatible with 5.2.x settings
Saml2AuthenticationRequest.Builder resultBuilder = withAuthenticationRequestContext(context);
Saml2AuthenticationRequest.Builder resultBuilder = Saml2AuthenticationRequest
.withAuthenticationRequestContext(context);
String samlRequest = createAuthenticationRequest(resultBuilder.build());
samlRequest = samlEncode(samlDeflate(samlRequest));
samlRequest = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(samlRequest));
return Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context).samlRequest(samlRequest)
.build();
}
@@ -108,9 +105,10 @@ public interface Saml2AuthenticationRequestFactory {
*/
default Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) {
// backwards compatible with 5.2.x settings
Saml2AuthenticationRequest.Builder resultBuilder = withAuthenticationRequestContext(context);
Saml2AuthenticationRequest.Builder resultBuilder = Saml2AuthenticationRequest
.withAuthenticationRequestContext(context);
String samlRequest = createAuthenticationRequest(resultBuilder.build());
samlRequest = samlEncode(samlRequest.getBytes(StandardCharsets.UTF_8));
samlRequest = Saml2Utils.samlEncode(samlRequest.getBytes(StandardCharsets.UTF_8));
return Saml2PostAuthenticationRequest.withAuthenticationRequestContext(context).samlRequest(samlRequest)
.build();
}
@@ -24,8 +24,6 @@ import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.util.Assert;
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRegistrationId;
/**
* Represents an incoming SAML 2.0 response containing an assertion that has not been
* validated. {@link Saml2AuthenticationToken#isAuthenticated()} will always return false.
@@ -78,8 +76,9 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
public Saml2AuthenticationToken(String saml2Response, String recipientUri, String idpEntityId,
String localSpEntityId, List<Saml2X509Credential> credentials) {
super(null);
this.relyingPartyRegistration = withRegistrationId(idpEntityId).entityId(localSpEntityId)
.assertionConsumerServiceLocation(recipientUri).credentials(c -> c.addAll(credentials))
this.relyingPartyRegistration = RelyingPartyRegistration.withRegistrationId(idpEntityId)
.entityId(localSpEntityId).assertionConsumerServiceLocation(recipientUri)
.credentials(c -> c.addAll(credentials))
.assertingPartyDetails(
assertingParty -> assertingParty.entityId(idpEntityId).singleSignOnServiceLocation(idpEntityId))
.build();
@@ -18,8 +18,6 @@ package org.springframework.security.saml2.provider.service.authentication;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
/**
* Data holder for information required to send an {@code AuthNRequest} over a POST
* binding from the service provider to the identity provider
@@ -40,7 +38,7 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR
*/
@Override
public Saml2MessageBinding getBinding() {
return POST;
return Saml2MessageBinding.POST;
}
/**
@@ -18,8 +18,6 @@ package org.springframework.security.saml2.provider.service.authentication;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
/**
* Data holder for information required to send an {@code AuthNRequest} over a REDIRECT
* binding from the service provider to the identity provider
@@ -63,7 +61,7 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
*/
@Override
public Saml2MessageBinding getBinding() {
return REDIRECT;
return Saml2MessageBinding.REDIRECT;
}
/**
@@ -18,6 +18,7 @@ package org.springframework.security.saml2.provider.service.authentication;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.zip.Deflater;
import java.util.zip.DeflaterOutputStream;
import java.util.zip.Inflater;
@@ -27,9 +28,6 @@ import org.apache.commons.codec.binary.Base64;
import org.springframework.security.saml2.Saml2Exception;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.zip.Deflater.DEFLATED;
/**
* @since 5.3
*/
@@ -48,8 +46,8 @@ final class Saml2Utils {
static byte[] samlDeflate(String s) {
try {
ByteArrayOutputStream b = new ByteArrayOutputStream();
DeflaterOutputStream deflater = new DeflaterOutputStream(b, new Deflater(DEFLATED, true));
deflater.write(s.getBytes(UTF_8));
DeflaterOutputStream deflater = new DeflaterOutputStream(b, new Deflater(Deflater.DEFLATED, true));
deflater.write(s.getBytes(StandardCharsets.UTF_8));
deflater.finish();
return b.toByteArray();
}
@@ -64,7 +62,7 @@ final class Saml2Utils {
InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true));
iout.write(b);
iout.finish();
return new String(out.toByteArray(), UTF_8);
return new String(out.toByteArray(), StandardCharsets.UTF_8);
}
catch (IOException e) {
throw new Saml2Exception("Unable to inflate string", e);
@@ -26,6 +26,7 @@ import javax.xml.namespace.QName;
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
import org.opensaml.core.xml.XMLObjectBuilder;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.metadata.AssertionConsumerService;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
@@ -44,9 +45,6 @@ import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.util.Assert;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory;
/**
* Resolves the SAML 2.0 Relying Party Metadata for a given
* {@link RelyingPartyRegistration} using the OpenSAML API.
@@ -64,8 +62,8 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
private final EntityDescriptorMarshaller entityDescriptorMarshaller;
public OpenSamlMetadataResolver() {
this.entityDescriptorMarshaller = (EntityDescriptorMarshaller) getMarshallerFactory()
.getMarshaller(EntityDescriptor.DEFAULT_ELEMENT_NAME);
this.entityDescriptorMarshaller = (EntityDescriptorMarshaller) XMLObjectProviderRegistrySupport
.getMarshallerFactory().getMarshaller(EntityDescriptor.DEFAULT_ELEMENT_NAME);
Assert.notNull(this.entityDescriptorMarshaller, "entityDescriptorMarshaller cannot be null");
}
@@ -135,7 +133,7 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
@SuppressWarnings("unchecked")
private <T> T build(QName elementName) {
XMLObjectBuilder<?> builder = getBuilderFactory().getBuilder(elementName);
XMLObjectBuilder<?> builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName);
if (builder == null) {
throw new Saml2Exception("Unable to resolve Builder for " + elementName);
}
@@ -16,6 +16,7 @@
package org.springframework.security.saml2.provider.service.registration;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
@@ -24,10 +25,6 @@ import java.util.Map;
import org.springframework.util.Assert;
import static java.util.Arrays.asList;
import static org.springframework.util.Assert.notEmpty;
import static org.springframework.util.Assert.notNull;
/**
* @since 5.2
*/
@@ -37,11 +34,11 @@ public class InMemoryRelyingPartyRegistrationRepository
private final Map<String, RelyingPartyRegistration> byRegistrationId;
public InMemoryRelyingPartyRegistrationRepository(RelyingPartyRegistration... registrations) {
this(asList(registrations));
this(Arrays.asList(registrations));
}
public InMemoryRelyingPartyRegistrationRepository(Collection<RelyingPartyRegistration> registrations) {
notEmpty(registrations, "registrations cannot be empty");
Assert.notEmpty(registrations, "registrations cannot be empty");
this.byRegistrationId = createMappingToIdentityProvider(registrations);
}
@@ -49,9 +46,9 @@ public class InMemoryRelyingPartyRegistrationRepository
Collection<RelyingPartyRegistration> rps) {
LinkedHashMap<String, RelyingPartyRegistration> result = new LinkedHashMap<>();
for (RelyingPartyRegistration rp : rps) {
notNull(rp, "relying party collection cannot contain null values");
Assert.notNull(rp, "relying party collection cannot contain null values");
String key = rp.getRegistrationId();
notNull(rp, "relying party identifier cannot be null");
Assert.notNull(rp, "relying party identifier cannot be null");
Assert.isNull(result.get(key), () -> "relying party duplicate identifier '" + key + "' detected.");
result.put(key, rp);
}
@@ -27,6 +27,7 @@ import java.util.List;
import net.shibboleth.utilities.java.support.xml.ParserPool;
import org.opensaml.core.config.ConfigurationService;
import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.opensaml.saml.saml2.metadata.IDPSSODescriptor;
import org.opensaml.saml.saml2.metadata.KeyDescriptor;
@@ -47,12 +48,6 @@ import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2X509Credential;
import static java.lang.Boolean.TRUE;
import static org.opensaml.saml.common.xml.SAMLConstants.SAML20P_NS;
import static org.springframework.security.saml2.core.Saml2X509Credential.encryption;
import static org.springframework.security.saml2.core.Saml2X509Credential.verification;
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRegistrationId;
/**
* An {@link HttpMessageConverter} that takes an {@code IDPSSODescriptor} in an HTTP
* response and converts it into a {@link RelyingPartyRegistration.Builder}.
@@ -133,7 +128,7 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter
HttpInputMessage inputMessage) throws IOException, HttpMessageNotReadableException {
EntityDescriptor descriptor = entityDescriptor(inputMessage.getBody());
IDPSSODescriptor idpssoDescriptor = descriptor.getIDPSSODescriptor(SAML20P_NS);
IDPSSODescriptor idpssoDescriptor = descriptor.getIDPSSODescriptor(SAMLConstants.SAML20P_NS);
if (idpssoDescriptor == null) {
throw new Saml2Exception("Metadata response is missing the necessary IDPSSODescriptor element");
}
@@ -143,20 +138,20 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter
if (keyDescriptor.getUse().equals(UsageType.SIGNING)) {
List<X509Certificate> certificates = certificates(keyDescriptor);
for (X509Certificate certificate : certificates) {
verification.add(verification(certificate));
verification.add(Saml2X509Credential.verification(certificate));
}
}
if (keyDescriptor.getUse().equals(UsageType.ENCRYPTION)) {
List<X509Certificate> certificates = certificates(keyDescriptor);
for (X509Certificate certificate : certificates) {
encryption.add(encryption(certificate));
encryption.add(Saml2X509Credential.encryption(certificate));
}
}
if (keyDescriptor.getUse().equals(UsageType.UNSPECIFIED)) {
List<X509Certificate> certificates = certificates(keyDescriptor);
for (X509Certificate certificate : certificates) {
verification.add(verification(certificate));
encryption.add(encryption(certificate));
verification.add(Saml2X509Credential.verification(certificate));
encryption.add(Saml2X509Credential.encryption(certificate));
}
}
}
@@ -164,9 +159,9 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter
throw new Saml2Exception(
"Metadata response is missing verification certificates, necessary for verifying SAML assertions");
}
RelyingPartyRegistration.Builder builder = withRegistrationId(descriptor.getEntityID())
RelyingPartyRegistration.Builder builder = RelyingPartyRegistration.withRegistrationId(descriptor.getEntityID())
.assertingPartyDetails(party -> party.entityId(descriptor.getEntityID())
.wantAuthnRequestsSigned(TRUE.equals(idpssoDescriptor.getWantAuthnRequestsSigned()))
.wantAuthnRequestsSigned(Boolean.TRUE.equals(idpssoDescriptor.getWantAuthnRequestsSigned()))
.verificationX509Credentials(c -> c.addAll(verification))
.encryptionX509Credentials(c -> c.addAll(encryption)));
for (SingleSignOnService singleSignOnService : idpssoDescriptor.getSingleSignOnServices()) {
@@ -0,0 +1,74 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.servlet.filter;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.web.util.UrlUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
/**
* @since 5.3
*/
final class Saml2ServletUtils {
private static final char PATH_DELIMITER = '/';
static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) {
if (!StringUtils.hasText(template)) {
return baseUrl;
}
String entityId = relyingParty.getAssertingPartyDetails().getEntityId();
String registrationId = relyingParty.getRegistrationId();
Map<String, String> uriVariables = new HashMap<>();
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl).replaceQuery(null).fragment(null)
.build();
String scheme = uriComponents.getScheme();
uriVariables.put("baseScheme", scheme == null ? "" : scheme);
String host = uriComponents.getHost();
uriVariables.put("baseHost", host == null ? "" : host);
// following logic is based on HierarchicalUriComponents#toUriString()
int port = uriComponents.getPort();
uriVariables.put("basePort", port == -1 ? "" : ":" + port);
String path = uriComponents.getPath();
if (StringUtils.hasLength(path)) {
if (path.charAt(0) != PATH_DELIMITER) {
path = PATH_DELIMITER + path;
}
}
uriVariables.put("basePath", path == null ? "" : path);
uriVariables.put("baseUrl", uriComponents.toUriString());
uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : "");
uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");
return UriComponentsBuilder.fromUriString(template).buildAndExpand(uriVariables).toUriString();
}
static String getApplicationUri(HttpServletRequest request) {
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
.replacePath(request.getContextPath()).replaceQuery(null).fragment(null).build();
return uriComponents.toUriString();
}
}
@@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
@@ -30,9 +31,7 @@ import org.springframework.security.web.authentication.AbstractAuthenticationPro
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.session.ChangeSessionIdAuthenticationStrategy;
import org.springframework.util.Assert;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND;
import static org.springframework.util.StringUtils.hasText;
import org.springframework.util.StringUtils;
/**
* @since 5.2
@@ -88,7 +87,8 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
@Override
protected boolean requiresAuthentication(HttpServletRequest request, HttpServletResponse response) {
return (super.requiresAuthentication(request, response) && hasText(request.getParameter("SAMLResponse")));
return (super.requiresAuthentication(request, response)
&& StringUtils.hasText(request.getParameter("SAMLResponse")));
}
@Override
@@ -96,7 +96,7 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
throws AuthenticationException {
Authentication authentication = this.authenticationConverter.convert(request);
if (authentication == null) {
Saml2Error saml2Error = new Saml2Error(RELYING_PARTY_REGISTRATION_NOT_FOUND,
Saml2Error saml2Error = new Saml2Error(Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND,
"No relying party registration found");
throw new Saml2AuthenticationException(saml2Error);
}
@@ -17,6 +17,7 @@
package org.springframework.security.saml2.provider.service.servlet.filter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
@@ -44,8 +45,6 @@ import org.springframework.web.util.HtmlUtils;
import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.web.util.UriUtils;
import static java.nio.charset.StandardCharsets.ISO_8859_1;
/**
* This {@code Filter} formulates a
* <a href="https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf">SAML 2.0
@@ -176,7 +175,8 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
private void addParameter(String name, String value, UriComponentsBuilder builder) {
Assert.hasText(name, "name cannot be empty or null");
if (StringUtils.hasText(value)) {
builder.queryParam(UriUtils.encode(name, ISO_8859_1), UriUtils.encode(value, ISO_8859_1));
builder.queryParam(UriUtils.encode(name, StandardCharsets.ISO_8859_1),
UriUtils.encode(value, StandardCharsets.ISO_8859_1));
}
}
@@ -25,6 +25,7 @@ import javax.servlet.http.HttpServletRequest;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.web.util.UrlUtils;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
@@ -32,10 +33,6 @@ import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl;
import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl;
/**
* A {@link Converter} that resolves a {@link RelyingPartyRegistration} by extracting the
* registration id from the request, querying a
@@ -77,8 +74,9 @@ public final class DefaultRelyingPartyRegistrationResolver
String relyingPartyEntityId = templateResolver.apply(relyingPartyRegistration.getEntityId());
String assertionConsumerServiceLocation = templateResolver
.apply(relyingPartyRegistration.getAssertionConsumerServiceLocation());
return withRelyingPartyRegistration(relyingPartyRegistration).entityId(relyingPartyEntityId)
.assertionConsumerServiceLocation(assertionConsumerServiceLocation).build();
return RelyingPartyRegistration.withRelyingPartyRegistration(relyingPartyRegistration)
.entityId(relyingPartyEntityId).assertionConsumerServiceLocation(assertionConsumerServiceLocation)
.build();
}
private Function<String, String> templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) {
@@ -111,8 +109,8 @@ public final class DefaultRelyingPartyRegistrationResolver
}
private static String getApplicationUri(HttpServletRequest request) {
UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request)).replacePath(request.getContextPath())
.replaceQuery(null).fragment(null).build();
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
.replacePath(request.getContextPath()).replaceQuery(null).fragment(null).build();
return uriComponents.toUriString();
}
@@ -18,6 +18,7 @@ package org.springframework.security.saml2.provider.service.web;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.zip.Inflater;
import java.util.zip.InflaterOutputStream;
@@ -33,8 +34,6 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.util.Assert;
import static java.nio.charset.StandardCharsets.UTF_8;
/**
* An {@link AuthenticationConverter} that generates a {@link Saml2AuthenticationToken}
* appropriate for authenticated a SAML 2.0 Assertion against an
@@ -84,7 +83,7 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
return samlInflate(b);
}
else {
return new String(b, UTF_8);
return new String(b, StandardCharsets.UTF_8);
}
}
@@ -98,7 +97,7 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true));
iout.write(b);
iout.finish();
return new String(out.toByteArray(), UTF_8);
return new String(out.toByteArray(), StandardCharsets.UTF_8);
}
catch (IOException e) {
throw new Saml2Exception("Unable to inflate string", e);
@@ -18,6 +18,7 @@ package org.springframework.security.saml2.core;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.zip.Deflater;
import java.util.zip.DeflaterOutputStream;
import java.util.zip.Inflater;
@@ -27,9 +28,6 @@ import org.apache.commons.codec.binary.Base64;
import org.springframework.security.saml2.Saml2Exception;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.zip.Deflater.DEFLATED;
public final class Saml2Utils {
private static Base64 BASE64 = new Base64(0, new byte[] { '\n' });
@@ -45,8 +43,8 @@ public final class Saml2Utils {
public static byte[] samlDeflate(String s) {
try {
ByteArrayOutputStream b = new ByteArrayOutputStream();
DeflaterOutputStream deflater = new DeflaterOutputStream(b, new Deflater(DEFLATED, true));
deflater.write(s.getBytes(UTF_8));
DeflaterOutputStream deflater = new DeflaterOutputStream(b, new Deflater(Deflater.DEFLATED, true));
deflater.write(s.getBytes(StandardCharsets.UTF_8));
deflater.finish();
return b.toByteArray();
}
@@ -61,7 +59,7 @@ public final class Saml2Utils {
InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true));
iout.write(b);
iout.finish();
return new String(out.toByteArray(), UTF_8);
return new String(out.toByteArray(), StandardCharsets.UTF_8);
}
catch (IOException e) {
throw new Saml2Exception("Unable to inflate string", e);
@@ -17,6 +17,7 @@
package org.springframework.security.saml2.core;
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.security.PrivateKey;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
@@ -27,12 +28,7 @@ import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.security.converter.RsaKeyConverters;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.DECRYPTION;
import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION;
import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.SIGNING;
import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.VERIFICATION;
import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType;
public class Saml2X509CredentialTests {
@@ -60,7 +56,7 @@ public class Saml2X509CredentialTests {
+ "YX/sDTE2AdVBVGaMj1Cb51bPHnNC6Q5kXKQnj/YrLqRQND09Q7ParX0CQQC5NxZr\n"
+ "9jKqhHj8yQD6PlXTsY4Occ7DH6/IoDenfdEVD5qlet0zmd50HatN2Jiqm5ubN7CM\n" + "INrtuLp4YHbgk1mi\n"
+ "-----END PRIVATE KEY-----";
this.key = RsaKeyConverters.pkcs8().convert(new ByteArrayInputStream(keyData.getBytes(UTF_8)));
this.key = RsaKeyConverters.pkcs8().convert(new ByteArrayInputStream(keyData.getBytes(StandardCharsets.UTF_8)));
final CertificateFactory factory = CertificateFactory.getInstance("X.509");
String certificateData = "-----BEGIN CERTIFICATE-----\n"
+ "MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBhMC\n"
@@ -78,23 +74,25 @@ public class Saml2X509CredentialTests {
+ "qK7UFgP1bRl5qksrYX5S0z2iGJh0GvonLUt3e20Ssfl5tTEDDnAEUMLfBkyaxEHD\n"
+ "RZ/nbTJ7VTeZOSyRoVn5XHhpuJ0B\n" + "-----END CERTIFICATE-----";
this.certificate = (X509Certificate) factory
.generateCertificate(new ByteArrayInputStream(certificateData.getBytes(UTF_8)));
.generateCertificate(new ByteArrayInputStream(certificateData.getBytes(StandardCharsets.UTF_8)));
}
@Test
public void constructorWhenRelyingPartyWithCredentialsThenItSucceeds() {
new Saml2X509Credential(this.key, this.certificate, SIGNING);
new Saml2X509Credential(this.key, this.certificate, SIGNING, DECRYPTION);
new Saml2X509Credential(this.key, this.certificate, DECRYPTION);
new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.SIGNING);
new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.SIGNING,
Saml2X509CredentialType.DECRYPTION);
new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.DECRYPTION);
Saml2X509Credential.signing(this.key, this.certificate);
Saml2X509Credential.decryption(this.key, this.certificate);
}
@Test
public void constructorWhenAssertingPartyWithCredentialsThenItSucceeds() {
new Saml2X509Credential(this.certificate, VERIFICATION);
new Saml2X509Credential(this.certificate, VERIFICATION, ENCRYPTION);
new Saml2X509Credential(this.certificate, ENCRYPTION);
new Saml2X509Credential(this.certificate, Saml2X509CredentialType.VERIFICATION);
new Saml2X509Credential(this.certificate, Saml2X509CredentialType.VERIFICATION,
Saml2X509CredentialType.ENCRYPTION);
new Saml2X509Credential(this.certificate, Saml2X509CredentialType.ENCRYPTION);
Saml2X509Credential.verification(this.certificate);
Saml2X509Credential.encryption(this.certificate);
}
@@ -102,49 +100,49 @@ public class Saml2X509CredentialTests {
@Test
public void constructorWhenRelyingPartyWithoutCredentialsThenItFails() {
this.exception.expect(IllegalArgumentException.class);
new Saml2X509Credential(null, (X509Certificate) null, SIGNING);
new Saml2X509Credential(null, (X509Certificate) null, Saml2X509CredentialType.SIGNING);
}
@Test
public void constructorWhenRelyingPartyWithoutPrivateKeyThenItFails() {
this.exception.expect(IllegalArgumentException.class);
new Saml2X509Credential(null, this.certificate, SIGNING);
new Saml2X509Credential(null, this.certificate, Saml2X509CredentialType.SIGNING);
}
@Test
public void constructorWhenRelyingPartyWithoutCertificateThenItFails() {
this.exception.expect(IllegalArgumentException.class);
new Saml2X509Credential(this.key, null, SIGNING);
new Saml2X509Credential(this.key, null, Saml2X509CredentialType.SIGNING);
}
@Test
public void constructorWhenAssertingPartyWithoutCertificateThenItFails() {
this.exception.expect(IllegalArgumentException.class);
new Saml2X509Credential(null, SIGNING);
new Saml2X509Credential(null, Saml2X509CredentialType.SIGNING);
}
@Test
public void constructorWhenRelyingPartyWithEncryptionUsageThenItFails() {
this.exception.expect(IllegalStateException.class);
new Saml2X509Credential(this.key, this.certificate, ENCRYPTION);
new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.ENCRYPTION);
}
@Test
public void constructorWhenRelyingPartyWithVerificationUsageThenItFails() {
this.exception.expect(IllegalStateException.class);
new Saml2X509Credential(this.key, this.certificate, VERIFICATION);
new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.VERIFICATION);
}
@Test
public void constructorWhenAssertingPartyWithSigningUsageThenItFails() {
this.exception.expect(IllegalStateException.class);
new Saml2X509Credential(this.certificate, SIGNING);
new Saml2X509Credential(this.certificate, Saml2X509CredentialType.SIGNING);
}
@Test
public void constructorWhenAssertingPartyWithDecryptionUsageThenItFails() {
this.exception.expect(IllegalStateException.class);
new Saml2X509Credential(this.certificate, DECRYPTION);
new Saml2X509Credential(this.certificate, Saml2X509CredentialType.DECRYPTION);
}
@Test
@@ -17,6 +17,7 @@
package org.springframework.security.saml2.core;
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.security.KeyException;
import java.security.PrivateKey;
import java.security.cert.CertificateException;
@@ -26,37 +27,33 @@ import java.security.cert.X509Certificate;
import org.opensaml.security.crypto.KeySupport;
import org.springframework.security.saml2.Saml2Exception;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.DECRYPTION;
import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION;
import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.SIGNING;
import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.VERIFICATION;
import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType;
public final class TestSaml2X509Credentials {
public static Saml2X509Credential assertingPartySigningCredential() {
return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), SIGNING);
return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), Saml2X509CredentialType.SIGNING);
}
public static Saml2X509Credential assertingPartyEncryptingCredential() {
return new Saml2X509Credential(spCertificate(), ENCRYPTION);
return new Saml2X509Credential(spCertificate(), Saml2X509CredentialType.ENCRYPTION);
}
public static Saml2X509Credential assertingPartyPrivateCredential() {
return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), SIGNING, DECRYPTION);
return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), Saml2X509CredentialType.SIGNING,
Saml2X509CredentialType.DECRYPTION);
}
public static Saml2X509Credential relyingPartyVerifyingCredential() {
return new Saml2X509Credential(idpCertificate(), VERIFICATION);
return new Saml2X509Credential(idpCertificate(), Saml2X509CredentialType.VERIFICATION);
}
public static Saml2X509Credential relyingPartySigningCredential() {
return new Saml2X509Credential(spPrivateKey(), spCertificate(), SIGNING);
return new Saml2X509Credential(spPrivateKey(), spCertificate(), Saml2X509CredentialType.SIGNING);
}
public static Saml2X509Credential relyingPartyDecryptingCredential() {
return new Saml2X509Credential(spPrivateKey(), spCertificate(), DECRYPTION);
return new Saml2X509Credential(spPrivateKey(), spCertificate(), Saml2X509CredentialType.DECRYPTION);
}
private static X509Certificate certificate(String cert) {
@@ -71,7 +68,7 @@ public final class TestSaml2X509Credentials {
private static PrivateKey privateKey(String key) {
try {
return KeySupport.decodePrivateKey(key.getBytes(UTF_8), new char[0]);
return KeySupport.decodePrivateKey(key.getBytes(StandardCharsets.UTF_8), new char[0]);
}
catch (KeyException e) {
throw new Saml2Exception(e);
@@ -17,6 +17,7 @@
package org.springframework.security.saml2.credentials;
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.security.PrivateKey;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
@@ -27,12 +28,7 @@ import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.security.converter.RsaKeyConverters;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.DECRYPTION;
import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION;
import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.SIGNING;
import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.VERIFICATION;
import org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType;
public class Saml2X509CredentialTests {
@@ -62,7 +58,7 @@ public class Saml2X509CredentialTests {
+ "YX/sDTE2AdVBVGaMj1Cb51bPHnNC6Q5kXKQnj/YrLqRQND09Q7ParX0CQQC5NxZr\n"
+ "9jKqhHj8yQD6PlXTsY4Occ7DH6/IoDenfdEVD5qlet0zmd50HatN2Jiqm5ubN7CM\n" + "INrtuLp4YHbgk1mi\n"
+ "-----END PRIVATE KEY-----";
this.key = RsaKeyConverters.pkcs8().convert(new ByteArrayInputStream(keyData.getBytes(UTF_8)));
this.key = RsaKeyConverters.pkcs8().convert(new ByteArrayInputStream(keyData.getBytes(StandardCharsets.UTF_8)));
final CertificateFactory factory = CertificateFactory.getInstance("X.509");
String certificateData = "-----BEGIN CERTIFICATE-----\n"
+ "MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBhMC\n"
@@ -80,69 +76,71 @@ public class Saml2X509CredentialTests {
+ "qK7UFgP1bRl5qksrYX5S0z2iGJh0GvonLUt3e20Ssfl5tTEDDnAEUMLfBkyaxEHD\n"
+ "RZ/nbTJ7VTeZOSyRoVn5XHhpuJ0B\n" + "-----END CERTIFICATE-----";
this.certificate = (X509Certificate) factory
.generateCertificate(new ByteArrayInputStream(certificateData.getBytes(UTF_8)));
.generateCertificate(new ByteArrayInputStream(certificateData.getBytes(StandardCharsets.UTF_8)));
}
@Test
public void constructorWhenRelyingPartyWithCredentialsThenItSucceeds() {
new Saml2X509Credential(this.key, this.certificate, SIGNING);
new Saml2X509Credential(this.key, this.certificate, SIGNING, DECRYPTION);
new Saml2X509Credential(this.key, this.certificate, DECRYPTION);
new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.SIGNING);
new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.SIGNING,
Saml2X509CredentialType.DECRYPTION);
new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.DECRYPTION);
}
@Test
public void constructorWhenAssertingPartyWithCredentialsThenItSucceeds() {
new Saml2X509Credential(this.certificate, VERIFICATION);
new Saml2X509Credential(this.certificate, VERIFICATION, ENCRYPTION);
new Saml2X509Credential(this.certificate, ENCRYPTION);
new Saml2X509Credential(this.certificate, Saml2X509CredentialType.VERIFICATION);
new Saml2X509Credential(this.certificate, Saml2X509CredentialType.VERIFICATION,
Saml2X509CredentialType.ENCRYPTION);
new Saml2X509Credential(this.certificate, Saml2X509CredentialType.ENCRYPTION);
}
@Test
public void constructorWhenRelyingPartyWithoutCredentialsThenItFails() {
this.exception.expect(IllegalArgumentException.class);
new Saml2X509Credential(null, (X509Certificate) null, SIGNING);
new Saml2X509Credential(null, (X509Certificate) null, Saml2X509CredentialType.SIGNING);
}
@Test
public void constructorWhenRelyingPartyWithoutPrivateKeyThenItFails() {
this.exception.expect(IllegalArgumentException.class);
new Saml2X509Credential(null, this.certificate, SIGNING);
new Saml2X509Credential(null, this.certificate, Saml2X509CredentialType.SIGNING);
}
@Test
public void constructorWhenRelyingPartyWithoutCertificateThenItFails() {
this.exception.expect(IllegalArgumentException.class);
new Saml2X509Credential(this.key, null, SIGNING);
new Saml2X509Credential(this.key, null, Saml2X509CredentialType.SIGNING);
}
@Test
public void constructorWhenAssertingPartyWithoutCertificateThenItFails() {
this.exception.expect(IllegalArgumentException.class);
new Saml2X509Credential(null, SIGNING);
new Saml2X509Credential(null, Saml2X509CredentialType.SIGNING);
}
@Test
public void constructorWhenRelyingPartyWithEncryptionUsageThenItFails() {
this.exception.expect(IllegalStateException.class);
new Saml2X509Credential(this.key, this.certificate, ENCRYPTION);
new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.ENCRYPTION);
}
@Test
public void constructorWhenRelyingPartyWithVerificationUsageThenItFails() {
this.exception.expect(IllegalStateException.class);
new Saml2X509Credential(this.key, this.certificate, VERIFICATION);
new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.VERIFICATION);
}
@Test
public void constructorWhenAssertingPartyWithSigningUsageThenItFails() {
this.exception.expect(IllegalStateException.class);
new Saml2X509Credential(this.certificate, SIGNING);
new Saml2X509Credential(this.certificate, Saml2X509CredentialType.SIGNING);
}
@Test
public void constructorWhenAssertingPartyWithDecryptionUsageThenItFails() {
this.exception.expect(IllegalStateException.class);
new Saml2X509Credential(this.certificate, DECRYPTION);
new Saml2X509Credential(this.certificate, Saml2X509CredentialType.DECRYPTION);
}
}
@@ -17,6 +17,7 @@
package org.springframework.security.saml2.credentials;
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.security.KeyException;
import java.security.PrivateKey;
import java.security.cert.CertificateException;
@@ -26,37 +27,33 @@ import java.security.cert.X509Certificate;
import org.opensaml.security.crypto.KeySupport;
import org.springframework.security.saml2.Saml2Exception;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.DECRYPTION;
import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION;
import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.SIGNING;
import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.VERIFICATION;
import org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType;
public final class TestSaml2X509Credentials {
public static Saml2X509Credential assertingPartySigningCredential() {
return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), SIGNING);
return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), Saml2X509CredentialType.SIGNING);
}
public static Saml2X509Credential assertingPartyEncryptingCredential() {
return new Saml2X509Credential(spCertificate(), ENCRYPTION);
return new Saml2X509Credential(spCertificate(), Saml2X509CredentialType.ENCRYPTION);
}
public static Saml2X509Credential assertingPartyPrivateCredential() {
return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), SIGNING, DECRYPTION);
return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), Saml2X509CredentialType.SIGNING,
Saml2X509CredentialType.DECRYPTION);
}
public static Saml2X509Credential relyingPartyVerifyingCredential() {
return new Saml2X509Credential(idpCertificate(), VERIFICATION);
return new Saml2X509Credential(idpCertificate(), Saml2X509CredentialType.VERIFICATION);
}
public static Saml2X509Credential relyingPartySigningCredential() {
return new Saml2X509Credential(spPrivateKey(), spCertificate(), SIGNING);
return new Saml2X509Credential(spPrivateKey(), spCertificate(), Saml2X509CredentialType.SIGNING);
}
public static Saml2X509Credential relyingPartyDecryptingCredential() {
return new Saml2X509Credential(spPrivateKey(), spCertificate(), DECRYPTION);
return new Saml2X509Credential(spPrivateKey(), spCertificate(), Saml2X509CredentialType.DECRYPTION);
}
private static X509Certificate certificate(String cert) {
@@ -71,7 +68,7 @@ public final class TestSaml2X509Credentials {
private static PrivateKey privateKey(String key) {
try {
return KeySupport.decodePrivateKey(key.getBytes(UTF_8), new char[0]);
return KeySupport.decodePrivateKey(key.getBytes(StandardCharsets.UTF_8), new char[0]);
}
catch (KeyException e) {
throw new Saml2Exception(e);
@@ -47,6 +47,7 @@ import org.opensaml.core.xml.io.Marshaller;
import org.opensaml.core.xml.io.MarshallingException;
import org.opensaml.saml.common.assertion.ValidationContext;
import org.opensaml.saml.common.assertion.ValidationResult;
import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters;
import org.opensaml.saml.saml2.assertion.impl.OneTimeUseConditionValidator;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.AttributeStatement;
@@ -64,8 +65,9 @@ import org.xml.sax.InputSource;
import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
import org.springframework.util.StringUtils;
import static java.util.Collections.singleton;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.ArgumentMatchers.any;
@@ -73,21 +75,6 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed;
import static org.springframework.util.StringUtils.hasText;
/**
* Tests for {@link OpenSamlAuthenticationProvider}
@@ -128,16 +115,18 @@ public class OpenSamlAuthenticationProviderTests {
public void authenticateWhenUnknownDataClassThenThrowAuthenticationException() {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA));
Assertion assertion = (Assertion) getBuilderFactory().getBuilder(Assertion.DEFAULT_ELEMENT_NAME)
.buildObject(Assertion.DEFAULT_ELEMENT_NAME);
this.provider.authenticate(token(serialize(assertion), relyingPartyVerifyingCredential()));
Assertion assertion = (Assertion) XMLObjectProviderRegistrySupport.getBuilderFactory()
.getBuilder(Assertion.DEFAULT_ELEMENT_NAME).buildObject(Assertion.DEFAULT_ELEMENT_NAME);
this.provider
.authenticate(token(serialize(assertion), TestSaml2X509Credentials.relyingPartyVerifyingCredential()));
}
@Test
public void authenticateWhenXmlErrorThenThrowAuthenticationException() {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA));
Saml2AuthenticationToken token = token("invalid xml", relyingPartyVerifyingCredential());
Saml2AuthenticationToken token = token("invalid xml",
TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.authenticate(token);
}
@@ -145,10 +134,11 @@ public class OpenSamlAuthenticationProviderTests {
public void authenticateWhenInvalidDestinationThenThrowAuthenticationException() {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_DESTINATION));
Response response = response(DESTINATION + "invalid", ASSERTING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion());
signed(response, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
Response response = TestOpenSamlObjects.response(DESTINATION + "invalid", ASSERTING_PARTY_ENTITY_ID);
response.getAssertions().add(TestOpenSamlObjects.assertion());
TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.authenticate(token);
}
@@ -157,7 +147,8 @@ public class OpenSamlAuthenticationProviderTests {
this.exception.expect(
authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, "No assertions found in response."));
Saml2AuthenticationToken token = token(response(), assertingPartySigningCredential());
Saml2AuthenticationToken token = token(TestOpenSamlObjects.response(),
TestSaml2X509Credentials.assertingPartySigningCredential());
this.provider.authenticate(token);
}
@@ -165,9 +156,9 @@ public class OpenSamlAuthenticationProviderTests {
public void authenticateWhenInvalidSignatureOnAssertionThenThrowAuthenticationException() {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_SIGNATURE));
Response response = response();
response.getAssertions().add(assertion());
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
Response response = TestOpenSamlObjects.response();
response.getAssertions().add(TestOpenSamlObjects.assertion());
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.authenticate(token);
}
@@ -175,13 +166,14 @@ public class OpenSamlAuthenticationProviderTests {
public void authenticateWhenOpenSAMLValidationErrorThenThrowAuthenticationException() throws Exception {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_ASSERTION));
Response response = response();
Assertion assertion = assertion();
Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion();
assertion.getSubject().getSubjectConfirmations().get(0).getSubjectConfirmationData()
.setNotOnOrAfter(DateTime.now().minus(Duration.standardDays(3)));
signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.authenticate(token);
}
@@ -189,12 +181,13 @@ public class OpenSamlAuthenticationProviderTests {
public void authenticateWhenMissingSubjectThenThrowAuthenticationException() {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.SUBJECT_NOT_FOUND));
Response response = response();
Assertion assertion = assertion();
Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion();
assertion.setSubject(null);
signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.authenticate(token);
}
@@ -202,36 +195,39 @@ public class OpenSamlAuthenticationProviderTests {
public void authenticateWhenUsernameMissingThenThrowAuthenticationException() throws Exception {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.SUBJECT_NOT_FOUND));
Response response = response();
Assertion assertion = assertion();
Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion();
assertion.getSubject().getNameID().setValue(null);
signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.authenticate(token);
}
@Test
public void authenticateWhenAssertionContainsValidationAddressThenItSucceeds() throws Exception {
Response response = response();
Assertion assertion = assertion();
Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion();
assertion.getSubject().getSubjectConfirmations()
.forEach(sc -> sc.getSubjectConfirmationData().setAddress("10.10.10.10"));
signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.authenticate(token);
}
@Test
public void authenticateWhenAssertionContainsAttributesThenItSucceeds() {
Response response = response();
Assertion assertion = assertion();
List<AttributeStatement> attributes = attributeStatements();
Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion();
List<AttributeStatement> attributes = TestOpenSamlObjects.attributeStatements();
assertion.getAttributeStatements().addAll(attributes);
signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
Authentication authentication = this.provider.authenticate(token);
Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal();
@@ -250,13 +246,14 @@ public class OpenSamlAuthenticationProviderTests {
@Test
public void authenticateWhenAttributeValueMarshallerConfiguredThenUses() throws Exception {
Response response = response();
Assertion assertion = assertion();
List<AttributeStatement> attributes = attributeStatements();
Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion();
List<AttributeStatement> attributes = TestOpenSamlObjects.attributeStatements();
assertion.getAttributeStatements().addAll(attributes);
signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
Element attributeElement = element("<element>value</element>");
Marshaller marshaller = mock(Marshaller.class);
@@ -278,47 +275,54 @@ public class OpenSamlAuthenticationProviderTests {
public void authenticateWhenEncryptedAssertionWithoutSignatureThenItFails() throws Exception {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_SIGNATURE));
Response response = response();
EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential());
Response response = TestOpenSamlObjects.response();
EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(),
TestSaml2X509Credentials.assertingPartyEncryptingCredential());
response.getEncryptedAssertions().add(encryptedAssertion);
Saml2AuthenticationToken token = token(response, relyingPartyDecryptingCredential());
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyDecryptingCredential());
this.provider.authenticate(token);
}
@Test
public void authenticateWhenEncryptedAssertionWithSignatureThenItSucceeds() throws Exception {
Response response = response();
Assertion assertion = signed(assertion(), assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
EncryptedAssertion encryptedAssertion = encrypted(assertion, assertingPartyEncryptingCredential());
Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.signed(TestOpenSamlObjects.assertion(),
TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion,
TestSaml2X509Credentials.assertingPartyEncryptingCredential());
response.getEncryptedAssertions().add(encryptedAssertion);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential(),
relyingPartyDecryptingCredential());
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential(),
TestSaml2X509Credentials.relyingPartyDecryptingCredential());
this.provider.authenticate(token);
}
@Test
public void authenticateWhenEncryptedAssertionWithResponseSignatureThenItSucceeds() throws Exception {
Response response = response();
EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential());
Response response = TestOpenSamlObjects.response();
EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(),
TestSaml2X509Credentials.assertingPartyEncryptingCredential());
response.getEncryptedAssertions().add(encryptedAssertion);
signed(response, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential(),
relyingPartyDecryptingCredential());
TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential(),
TestSaml2X509Credentials.relyingPartyDecryptingCredential());
this.provider.authenticate(token);
}
@Test
public void authenticateWhenEncryptedNameIdWithSignatureThenItSucceeds() throws Exception {
Response response = response();
Assertion assertion = assertion();
Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion();
NameID nameId = assertion.getSubject().getNameID();
EncryptedID encryptedID = encrypted(nameId, assertingPartyEncryptingCredential());
EncryptedID encryptedID = TestOpenSamlObjects.encrypted(nameId,
TestSaml2X509Credentials.assertingPartyEncryptingCredential());
assertion.getSubject().setNameID(null);
assertion.getSubject().setEncryptedID(encryptedID);
response.getAssertions().add(assertion);
signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential(),
relyingPartyDecryptingCredential());
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential(),
TestSaml2X509Credentials.relyingPartyDecryptingCredential());
this.provider.authenticate(token);
}
@@ -327,10 +331,12 @@ public class OpenSamlAuthenticationProviderTests {
this.exception
.expect(authenticationMatcher(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData"));
Response response = response();
EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential());
Response response = TestOpenSamlObjects.response();
EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(),
TestSaml2X509Credentials.assertingPartyEncryptingCredential());
response.getEncryptedAssertions().add(encryptedAssertion);
Saml2AuthenticationToken token = token(serialize(response), relyingPartyVerifyingCredential());
Saml2AuthenticationToken token = token(serialize(response),
TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.authenticate(token);
}
@@ -339,21 +345,25 @@ public class OpenSamlAuthenticationProviderTests {
this.exception
.expect(authenticationMatcher(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData"));
Response response = response();
EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential());
Response response = TestOpenSamlObjects.response();
EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(),
TestSaml2X509Credentials.assertingPartyEncryptingCredential());
response.getEncryptedAssertions().add(encryptedAssertion);
Saml2AuthenticationToken token = token(serialize(response), assertingPartyPrivateCredential());
Saml2AuthenticationToken token = token(serialize(response),
TestSaml2X509Credentials.assertingPartyPrivateCredential());
this.provider.authenticate(token);
}
@Test
public void writeObjectWhenTypeIsSaml2AuthenticationThenNoException() throws IOException {
Response response = response();
Assertion assertion = signed(assertion(), assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
EncryptedAssertion encryptedAssertion = encrypted(assertion, assertingPartyEncryptingCredential());
Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.signed(TestOpenSamlObjects.assertion(),
TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion,
TestSaml2X509Credentials.assertingPartyEncryptingCredential());
response.getEncryptedAssertions().add(encryptedAssertion);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential(),
relyingPartyDecryptingCredential());
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential(),
TestSaml2X509Credentials.relyingPartyDecryptingCredential());
Saml2Authentication authentication = (Saml2Authentication) this.provider.authenticate(token);
// the following code will throw an exception if authentication isn't serializable
@@ -368,13 +378,14 @@ public class OpenSamlAuthenticationProviderTests {
OneTimeUseConditionValidator validator = mock(OneTimeUseConditionValidator.class);
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
provider.setConditionValidators(Collections.singleton(validator));
Response response = response();
Assertion assertion = assertion();
Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion();
OneTimeUse oneTimeUse = build(OneTimeUse.DEFAULT_ELEMENT_NAME);
assertion.getConditions().getConditions().add(oneTimeUse);
response.getAssertions().add(assertion);
signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
ASSERTING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
given(validator.getServicedCondition()).willReturn(OneTimeUse.DEFAULT_ELEMENT_NAME);
given(validator.validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class)))
.willReturn(ValidationResult.VALID);
@@ -385,17 +396,18 @@ public class OpenSamlAuthenticationProviderTests {
@Test
public void authenticateWhenValidationContextCustomizedThenUsers() {
Map<String, Object> parameters = new HashMap<>();
parameters.put(SC_VALID_RECIPIENTS, singleton(DESTINATION));
parameters.put(SIGNATURE_REQUIRED, false);
parameters.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton(DESTINATION));
parameters.put(SAML2AssertionValidationParameters.SIGNATURE_REQUIRED, false);
ValidationContext context = mock(ValidationContext.class);
given(context.getStaticParameters()).willReturn(parameters);
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
provider.setValidationContextConverter(tuple -> context);
Response response = response();
Assertion assertion = assertion();
Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion();
response.getAssertions().add(assertion);
signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
ASSERTING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
provider.authenticate(token);
verify(context, atLeastOnce()).getStaticParameters();
}
@@ -415,12 +427,12 @@ public class OpenSamlAuthenticationProviderTests {
}
private <T extends XMLObject> T build(QName qName) {
return (T) getBuilderFactory().getBuilder(qName).buildObject(qName);
return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName);
}
private String serialize(XMLObject object) {
try {
Marshaller marshaller = getMarshallerFactory().getMarshaller(object);
Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object);
Element element = marshaller.marshall(object);
return SerializeSupport.nodeToString(element);
}
@@ -444,7 +456,7 @@ public class OpenSamlAuthenticationProviderTests {
if (!code.equals(ex.getError().getErrorCode())) {
return false;
}
if (hasText(description)) {
if (StringUtils.hasText(description)) {
if (!description.equals(ex.getError().getDescription())) {
return false;
}
@@ -17,6 +17,7 @@
package org.springframework.security.saml2.provider.service.authentication;
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.util.function.Consumer;
import java.util.function.Function;
@@ -25,6 +26,7 @@ import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller;
@@ -32,24 +34,16 @@ import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getParserPool;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getUnmarshallerFactory;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate;
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
/**
* Tests for {@link OpenSamlAuthenticationRequestFactory}
@@ -66,8 +60,8 @@ public class OpenSamlAuthenticationRequestFactoryTests {
private RelyingPartyRegistration relyingPartyRegistration;
private AuthnRequestUnmarshaller unmarshaller = (AuthnRequestUnmarshaller) getUnmarshallerFactory()
.getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
private AuthnRequestUnmarshaller unmarshaller = (AuthnRequestUnmarshaller) XMLObjectProviderRegistrySupport
.getUnmarshallerFactory().getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
@Rule
public ExpectedException exception = ExpectedException.none();
@@ -78,7 +72,7 @@ public class OpenSamlAuthenticationRequestFactoryTests {
.assertionConsumerServiceLocation("template")
.providerDetails(c -> c.webSsoUrl("https://destination/sso"))
.providerDetails(c -> c.entityId("remote-entity-id")).localEntityIdTemplate("local-entity-id")
.credentials(c -> c.add(relyingPartySigningCredential()));
.credentials(c -> c.add(TestSaml2X509Credentials.relyingPartySigningCredential()));
this.relyingPartyRegistration = this.relyingPartyRegistrationBuilder.build();
this.contextBuilder = Saml2AuthenticationRequestContext.builder().issuer("https://issuer")
.relyingPartyRegistration(this.relyingPartyRegistration)
@@ -104,58 +98,64 @@ public class OpenSamlAuthenticationRequestFactoryTests {
assertThat(result.getRelayState()).isEqualTo("Relay State Value");
assertThat(result.getSigAlg()).isNotEmpty();
assertThat(result.getSignature()).isNotEmpty();
assertThat(result.getBinding()).isEqualTo(REDIRECT);
assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
}
@Test
public void createRedirectAuthenticationRequestWhenNotSignRequestThenNoSignatureIsPresent() {
this.context = this.contextBuilder.relayState("Relay State Value")
.relyingPartyRegistration(withRelyingPartyRegistration(this.relyingPartyRegistration)
.providerDetails(c -> c.signAuthNRequest(false)).build())
.relyingPartyRegistration(
RelyingPartyRegistration.withRelyingPartyRegistration(this.relyingPartyRegistration)
.providerDetails(c -> c.signAuthNRequest(false)).build())
.build();
Saml2RedirectAuthenticationRequest result = this.factory.createRedirectAuthenticationRequest(this.context);
assertThat(result.getSamlRequest()).isNotEmpty();
assertThat(result.getRelayState()).isEqualTo("Relay State Value");
assertThat(result.getSigAlg()).isNull();
assertThat(result.getSignature()).isNull();
assertThat(result.getBinding()).isEqualTo(REDIRECT);
assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
}
@Test
public void createPostAuthenticationRequestWhenNotSignRequestThenNoSignatureIsPresent() {
this.context = this.contextBuilder.relayState("Relay State Value")
.relyingPartyRegistration(withRelyingPartyRegistration(this.relyingPartyRegistration)
.providerDetails(c -> c.signAuthNRequest(false)).build())
.relyingPartyRegistration(
RelyingPartyRegistration.withRelyingPartyRegistration(this.relyingPartyRegistration)
.providerDetails(c -> c.signAuthNRequest(false)).build())
.build();
Saml2PostAuthenticationRequest result = this.factory.createPostAuthenticationRequest(this.context);
assertThat(result.getSamlRequest()).isNotEmpty();
assertThat(result.getRelayState()).isEqualTo("Relay State Value");
assertThat(result.getBinding()).isEqualTo(POST);
assertThat(new String(samlDecode(result.getSamlRequest()), UTF_8)).doesNotContain("ds:Signature");
assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.POST);
assertThat(new String(Saml2Utils.samlDecode(result.getSamlRequest()), StandardCharsets.UTF_8))
.doesNotContain("ds:Signature");
}
@Test
public void createPostAuthenticationRequestWhenSignRequestThenSignatureIsPresent() {
this.context = this.contextBuilder.relayState("Relay State Value")
.relyingPartyRegistration(withRelyingPartyRegistration(this.relyingPartyRegistration).build()).build();
.relyingPartyRegistration(
RelyingPartyRegistration.withRelyingPartyRegistration(this.relyingPartyRegistration).build())
.build();
Saml2PostAuthenticationRequest result = this.factory.createPostAuthenticationRequest(this.context);
assertThat(result.getSamlRequest()).isNotEmpty();
assertThat(result.getRelayState()).isEqualTo("Relay State Value");
assertThat(result.getBinding()).isEqualTo(POST);
assertThat(new String(samlDecode(result.getSamlRequest()), UTF_8)).contains("ds:Signature");
assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.POST);
assertThat(new String(Saml2Utils.samlDecode(result.getSamlRequest()), StandardCharsets.UTF_8))
.contains("ds:Signature");
}
@Test
public void createAuthenticationRequestWhenDefaultThenReturnsPostBinding() {
AuthnRequest authn = getAuthNRequest(POST);
AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST);
Assert.assertEquals(SAMLConstants.SAML2_POST_BINDING_URI, authn.getProtocolBinding());
}
@Test
public void createAuthenticationRequestWhenSetUriThenReturnsCorrectBinding() {
this.factory.setProtocolBinding(SAMLConstants.SAML2_REDIRECT_BINDING_URI);
AuthnRequest authn = getAuthNRequest(POST);
AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST);
Assert.assertEquals(SAMLConstants.SAML2_REDIRECT_BINDING_URI, authn.getProtocolBinding());
}
@@ -199,29 +199,30 @@ public class OpenSamlAuthenticationRequestFactoryTests {
@Test
public void createPostAuthenticationRequestWhenAssertionConsumerServiceBindingThenUses() {
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationBuilder
.assertionConsumerServiceBinding(REDIRECT).build();
.assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT).build();
Saml2AuthenticationRequestContext context = this.contextBuilder
.relyingPartyRegistration(relyingPartyRegistration).build();
Saml2PostAuthenticationRequest request = this.factory.createPostAuthenticationRequest(context);
String samlRequest = request.getSamlRequest();
String inflated = new String(samlDecode(samlRequest));
String inflated = new String(Saml2Utils.samlDecode(samlRequest));
assertThat(inflated).contains("ProtocolBinding=\"" + SAMLConstants.SAML2_REDIRECT_BINDING_URI + "\"");
}
private AuthnRequest getAuthNRequest(Saml2MessageBinding binding) {
AbstractSaml2AuthenticationRequest result = (binding == REDIRECT)
AbstractSaml2AuthenticationRequest result = (binding == Saml2MessageBinding.REDIRECT)
? this.factory.createRedirectAuthenticationRequest(this.context)
: this.factory.createPostAuthenticationRequest(this.context);
String samlRequest = result.getSamlRequest();
assertThat(samlRequest).isNotEmpty();
if (result.getBinding() == REDIRECT) {
samlRequest = samlInflate(samlDecode(samlRequest));
if (result.getBinding() == Saml2MessageBinding.REDIRECT) {
samlRequest = Saml2Utils.samlInflate(Saml2Utils.samlDecode(samlRequest));
}
else {
samlRequest = new String(samlDecode(samlRequest), UTF_8);
samlRequest = new String(Saml2Utils.samlDecode(samlRequest), StandardCharsets.UTF_8);
}
try {
Document document = getParserPool().parse(new ByteArrayInputStream(samlRequest.getBytes(UTF_8)));
Document document = XMLObjectProviderRegistrySupport.getParserPool()
.parse(new ByteArrayInputStream(samlRequest.getBytes(StandardCharsets.UTF_8)));
Element element = document.getDocumentElement();
return (AuthnRequest) this.unmarshaller.unmarshall(element);
}
@@ -20,12 +20,10 @@ import java.util.UUID;
import org.junit.Test;
import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate;
/**
* Tests for {@link Saml2AuthenticationRequestFactory} default interface methods
@@ -36,7 +34,7 @@ public class Saml2AuthenticationRequestFactoryTests {
.assertionConsumerServiceUrlTemplate("template")
.providerDetails(c -> c.webSsoUrl("https://example.com/destination"))
.providerDetails(c -> c.entityId("remote-entity-id")).localEntityIdTemplate("local-entity-id")
.credentials(c -> c.add(relyingPartySigningCredential())).build();
.credentials(c -> c.add(TestSaml2X509Credentials.relyingPartySigningCredential())).build();
@Test
public void createAuthenticationRequestParametersWhenRedirectDefaultIsUsedMessageIsDeflatedAndEncoded() {
@@ -47,8 +45,8 @@ public class Saml2AuthenticationRequestFactoryTests {
.assertionConsumerServiceUrl("https://example.com/acs-url").build();
Saml2RedirectAuthenticationRequest response = factory.createRedirectAuthenticationRequest(request);
String resultValue = response.getSamlRequest();
byte[] decoded = samlDecode(resultValue);
String inflated = samlInflate(decoded);
byte[] decoded = Saml2Utils.samlDecode(resultValue);
String inflated = Saml2Utils.samlInflate(decoded);
assertThat(inflated).isEqualTo(value);
}
@@ -61,7 +59,7 @@ public class Saml2AuthenticationRequestFactoryTests {
.assertionConsumerServiceUrl("https://example.com/acs-url").build();
Saml2PostAuthenticationRequest response = factory.createPostAuthenticationRequest(request);
String resultValue = response.getSamlRequest();
byte[] decoded = samlDecode(resultValue);
byte[] decoded = Saml2Utils.samlDecode(resultValue);
assertThat(new String(decoded)).isEqualTo(value);
}
@@ -30,6 +30,7 @@ import org.apache.xml.security.encryption.XMLCipherParameters;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.MarshallingException;
import org.opensaml.core.xml.schema.XSAny;
import org.opensaml.core.xml.schema.XSBoolean;
@@ -79,8 +80,6 @@ import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2X509Credential;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
final class TestOpenSamlObjects {
static {
@@ -368,7 +367,7 @@ final class TestOpenSamlObjects {
}
static <T extends XMLObject> T build(QName qName) {
return (T) getBuilderFactory().getBuilder(qName).buildObject(qName);
return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName);
}
}
@@ -16,7 +16,7 @@
package org.springframework.security.saml2.provider.service.authentication;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
/**
* Test {@link Saml2AuthenticationRequestContext}s
@@ -25,7 +25,7 @@ public class TestSaml2AuthenticationRequestContexts {
public static Saml2AuthenticationRequestContext.Builder authenticationRequestContext() {
return Saml2AuthenticationRequestContext.builder().relayState("relayState").issuer("issuer")
.relyingPartyRegistration(relyingPartyRegistration().build())
.relyingPartyRegistration(TestRelyingPartyRegistrations.relyingPartyRegistration().build())
.assertionConsumerServiceUrl("assertionConsumerServiceUrl");
}
@@ -18,13 +18,12 @@ package org.springframework.security.saml2.provider.service.metadata;
import org.junit.Test;
import org.springframework.security.saml2.core.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.full;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials;
/**
* Tests for {@link OpenSamlMetadataResolver}
@@ -34,7 +33,8 @@ public class OpenSamlMetadataResolverTests {
@Test
public void resolveWhenRelyingPartyThenMetadataMatches() {
// given
RelyingPartyRegistration relyingPartyRegistration = full().assertionConsumerServiceBinding(REDIRECT).build();
RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.full()
.assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT).build();
OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver();
// when
@@ -52,9 +52,9 @@ public class OpenSamlMetadataResolverTests {
@Test
public void resolveWhenRelyingPartyNoCredentialsThenMetadataMatches() {
// given
RelyingPartyRegistration relyingPartyRegistration = noCredentials()
.assertingPartyDetails(
party -> party.verificationX509Credentials(c -> c.add(relyingPartyVerifyingCredential())))
RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.noCredentials()
.assertingPartyDetails(party -> party.verificationX509Credentials(
c -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential())))
.build();
OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver();
@@ -25,12 +25,12 @@ import java.util.Base64;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.HttpStatus;
import org.springframework.mock.http.client.MockClientHttpResponse;
import org.springframework.security.saml2.Saml2Exception;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.springframework.http.HttpStatus.OK;
public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverterTests {
@@ -62,7 +62,7 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverterTests {
@Test
public void readWhenMissingIDPSSODescriptorThenException() {
MockClientHttpResponse response = new MockClientHttpResponse(
(String.format(ENTITY_DESCRIPTOR_TEMPLATE, "")).getBytes(), OK);
(String.format(ENTITY_DESCRIPTOR_TEMPLATE, "")).getBytes(), HttpStatus.OK);
assertThatCode(() -> this.converter.read(RelyingPartyRegistration.Builder.class, response))
.isInstanceOf(Saml2Exception.class)
.hasMessageContaining("Metadata response is missing the necessary IDPSSODescriptor element");
@@ -71,7 +71,7 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverterTests {
@Test
public void readWhenMissingVerificationKeyThenException() {
String payload = String.format(ENTITY_DESCRIPTOR_TEMPLATE, String.format(IDP_SSO_DESCRIPTOR_TEMPLATE, ""));
MockClientHttpResponse response = new MockClientHttpResponse(payload.getBytes(), OK);
MockClientHttpResponse response = new MockClientHttpResponse(payload.getBytes(), HttpStatus.OK);
assertThatCode(() -> this.converter.read(RelyingPartyRegistration.Builder.class, response))
.isInstanceOf(Saml2Exception.class).hasMessageContaining(
"Metadata response is missing verification certificates, necessary for verifying SAML assertions");
@@ -81,7 +81,7 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverterTests {
public void readWhenMissingSingleSignOnServiceThenException() {
String payload = String.format(ENTITY_DESCRIPTOR_TEMPLATE,
String.format(IDP_SSO_DESCRIPTOR_TEMPLATE, String.format(KEY_DESCRIPTOR_TEMPLATE, "use=\"signing\"")));
MockClientHttpResponse response = new MockClientHttpResponse(payload.getBytes(), OK);
MockClientHttpResponse response = new MockClientHttpResponse(payload.getBytes(), HttpStatus.OK);
assertThatCode(() -> this.converter.read(RelyingPartyRegistration.Builder.class, response))
.isInstanceOf(Saml2Exception.class).hasMessageContaining(
"Metadata response is missing a SingleSignOnService, necessary for sending AuthnRequests");
@@ -94,7 +94,7 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverterTests {
String.format(KEY_DESCRIPTOR_TEMPLATE, "use=\"signing\"")
+ String.format(KEY_DESCRIPTOR_TEMPLATE, "use=\"encryption\"")
+ String.format(SINGLE_SIGN_ON_SERVICE_TEMPLATE)));
MockClientHttpResponse response = new MockClientHttpResponse(payload.getBytes(), OK);
MockClientHttpResponse response = new MockClientHttpResponse(payload.getBytes(), HttpStatus.OK);
RelyingPartyRegistration registration = this.converter.read(RelyingPartyRegistration.Builder.class, response)
.registrationId("one").build();
RelyingPartyRegistration.AssertingPartyDetails details = registration.getAssertingPartyDetails();
@@ -114,7 +114,7 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverterTests {
public void readWhenKeyDescriptorHasNoUseThenConfiguresBothKeyTypes() throws Exception {
String payload = String.format(ENTITY_DESCRIPTOR_TEMPLATE, String.format(IDP_SSO_DESCRIPTOR_TEMPLATE,
String.format(KEY_DESCRIPTOR_TEMPLATE, "") + String.format(SINGLE_SIGN_ON_SERVICE_TEMPLATE)));
MockClientHttpResponse response = new MockClientHttpResponse(payload.getBytes(), OK);
MockClientHttpResponse response = new MockClientHttpResponse(payload.getBytes(), HttpStatus.OK);
RelyingPartyRegistration registration = this.converter.read(RelyingPartyRegistration.Builder.class, response)
.registrationId("one").build();
RelyingPartyRegistration.AssertingPartyDetails details = registration.getAssertingPartyDetails();
@@ -18,19 +18,17 @@ package org.springframework.security.saml2.provider.service.registration;
import org.junit.Test;
import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRegistrationId;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
public class RelyingPartyRegistrationTests {
@Test
public void withRelyingPartyRegistrationWorks() {
RelyingPartyRegistration registration = relyingPartyRegistration().providerDetails(p -> p.binding(POST))
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration()
.providerDetails(p -> p.binding(Saml2MessageBinding.POST))
.providerDetails(p -> p.signAuthNRequest(false))
.assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT).build();
RelyingPartyRegistration copy = RelyingPartyRegistration.withRelyingPartyRegistration(registration).build();
@@ -59,7 +57,8 @@ public class RelyingPartyRegistrationTests {
.isEqualTo("https://simplesaml-for-spring-saml.cfapps.io/saml2/idp/SSOService.php");
assertThat(copy.getProviderDetails().getBinding()).isEqualTo(registration.getProviderDetails().getBinding())
.isEqualTo(copy.getAssertingPartyDetails().getSingleSignOnServiceBinding())
.isEqualTo(registration.getAssertingPartyDetails().getSingleSignOnServiceBinding()).isEqualTo(POST);
.isEqualTo(registration.getAssertingPartyDetails().getSingleSignOnServiceBinding())
.isEqualTo(Saml2MessageBinding.POST);
assertThat(copy.getProviderDetails().isSignAuthNRequest())
.isEqualTo(registration.getProviderDetails().isSignAuthNRequest())
.isEqualTo(copy.getAssertingPartyDetails().getWantAuthnRequestsSigned())
@@ -76,13 +75,13 @@ public class RelyingPartyRegistrationTests {
@Test
public void buildWhenUsingDefaultsThenAssertionConsumerServiceBindingDefaultsToPost() {
RelyingPartyRegistration relyingPartyRegistration = withRegistrationId("id").entityId("entity-id")
.assertionConsumerServiceLocation("location")
RelyingPartyRegistration relyingPartyRegistration = RelyingPartyRegistration.withRegistrationId("id")
.entityId("entity-id").assertionConsumerServiceLocation("location")
.assertingPartyDetails(
assertingParty -> assertingParty.entityId("entity-id").singleSignOnServiceLocation("location"))
.credentials(c -> c.add(relyingPartyVerifyingCredential())).build();
.credentials(c -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential())).build();
assertThat(relyingPartyRegistration.getAssertionConsumerServiceBinding()).isEqualTo(POST);
assertThat(relyingPartyRegistration.getAssertionConsumerServiceBinding()).isEqualTo(Saml2MessageBinding.POST);
}
}
@@ -23,7 +23,7 @@ import org.junit.Test;
import org.springframework.security.saml2.Saml2Exception;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatCode;
/**
* Tests for {@link RelyingPartyRegistration}
@@ -16,13 +16,10 @@
package org.springframework.security.saml2.provider.service.registration;
import org.springframework.security.saml2.core.TestSaml2X509Credentials;
import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
/**
* Preconfigured test data for {@link RelyingPartyRegistration} objects
*/
@@ -32,12 +29,12 @@ public class TestRelyingPartyRegistrations {
String registrationId = "simplesamlphp";
String rpEntityId = "{baseUrl}/saml2/service-provider-metadata/{registrationId}";
Saml2X509Credential signingCredential = relyingPartySigningCredential();
Saml2X509Credential signingCredential = TestSaml2X509Credentials.relyingPartySigningCredential();
String assertionConsumerServiceLocation = "{baseUrl}"
+ Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI;
String apEntityId = "https://simplesaml-for-spring-saml.cfapps.io/saml2/idp/metadata.php";
Saml2X509Credential verificationCertificate = relyingPartyVerifyingCredential();
Saml2X509Credential verificationCertificate = TestSaml2X509Credentials.relyingPartyVerifyingCredential();
String singleSignOnServiceLocation = "https://simplesaml-for-spring-saml.cfapps.io/saml2/idp/SSOService.php";
return RelyingPartyRegistration.withRegistrationId(registrationId).entityId(rpEntityId)
@@ -55,10 +52,13 @@ public class TestRelyingPartyRegistrations {
public static RelyingPartyRegistration.Builder full() {
return noCredentials()
.signingX509Credentials(c -> c.add(TestSaml2X509Credentials.relyingPartySigningCredential()))
.decryptionX509Credentials(c -> c.add(TestSaml2X509Credentials.relyingPartyDecryptingCredential()))
.signingX509Credentials(c -> c.add(org.springframework.security.saml2.core.TestSaml2X509Credentials
.relyingPartySigningCredential()))
.decryptionX509Credentials(c -> c.add(org.springframework.security.saml2.core.TestSaml2X509Credentials
.relyingPartyDecryptingCredential()))
.assertingPartyDetails(party -> party.verificationX509Credentials(
c -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential())));
c -> c.add(org.springframework.security.saml2.core.TestSaml2X509Credentials
.relyingPartyVerifyingCredential())));
}
}
@@ -27,10 +27,13 @@ import org.junit.Test;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.web.util.HtmlUtils;
import org.springframework.web.util.UriUtils;
@@ -42,9 +45,6 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
public class Saml2WebSsoAuthenticationRequestFilterTests {
@@ -78,7 +78,7 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
this.rpBuilder = RelyingPartyRegistration.withRegistrationId("registration-id")
.providerDetails(c -> c.entityId("idp-entity-id")).providerDetails(c -> c.webSsoUrl(IDP_SSO_URL))
.assertionConsumerServiceUrlTemplate("template")
.credentials(c -> c.add(assertingPartyPrivateCredential()));
.credentials(c -> c.add(TestSaml2X509Credentials.assertingPartyPrivateCredential()));
}
@Test
@@ -133,7 +133,7 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
@Test
public void doFilterWhenPostFormDataIsPresent() throws Exception {
given(this.repository.findByRegistrationId("registration-id"))
.willReturn(this.rpBuilder.providerDetails(c -> c.binding(POST)).build());
.willReturn(this.rpBuilder.providerDetails(c -> c.binding(Saml2MessageBinding.POST)).build());
final String relayStateValue = "https://my-relay-state.example.com?with=param&other=param&javascript{alert('1');}";
final String relayStateEncoded = HtmlUtils.htmlEscape(relayStateValue);
this.request.setParameter("RelayState", relayStateValue);
@@ -147,7 +147,8 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
@Test
public void doFilterWhenSetAuthenticationRequestFactoryThenUses() throws Exception {
RelyingPartyRegistration relyingParty = this.rpBuilder.providerDetails(c -> c.binding(POST)).build();
RelyingPartyRegistration relyingParty = this.rpBuilder.providerDetails(c -> c.binding(Saml2MessageBinding.POST))
.build();
Saml2PostAuthenticationRequest authenticationRequest = mock(Saml2PostAuthenticationRequest.class);
given(authenticationRequest.getAuthenticationRequestUri()).willReturn("uri");
given(authenticationRequest.getRelayState()).willReturn("relay");
@@ -166,13 +167,14 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
@Test
public void doFilterWhenCustomAuthenticationRequestFactoryThenUses() throws Exception {
RelyingPartyRegistration relyingParty = this.rpBuilder.providerDetails(c -> c.binding(POST)).build();
RelyingPartyRegistration relyingParty = this.rpBuilder.providerDetails(c -> c.binding(Saml2MessageBinding.POST))
.build();
Saml2PostAuthenticationRequest authenticationRequest = mock(Saml2PostAuthenticationRequest.class);
given(authenticationRequest.getAuthenticationRequestUri()).willReturn("uri");
given(authenticationRequest.getRelayState()).willReturn("relay");
given(authenticationRequest.getSamlRequest()).willReturn("saml");
given(this.resolver.resolve(this.request))
.willReturn(authenticationRequestContext().relyingPartyRegistration(relyingParty).build());
given(this.resolver.resolve(this.request)).willReturn(TestSaml2AuthenticationRequestContexts
.authenticationRequestContext().relyingPartyRegistration(relyingParty).build());
given(this.factory.createPostAuthenticationRequest(any())).willReturn(authenticationRequest);
Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(this.resolver,
@@ -22,17 +22,18 @@ import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
/**
* Tests for {@link DefaultRelyingPartyRegistrationResolver}
*/
public class DefaultRelyingPartyRegistrationResolverTests {
private final RelyingPartyRegistration registration = relyingPartyRegistration().build();
private final RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration()
.build();
private final RelyingPartyRegistrationRepository repository = new InMemoryRelyingPartyRegistrationRepository(
this.registration);
@@ -20,12 +20,12 @@ import org.junit.Before;
import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
/**
* Tests for {@link DefaultSaml2AuthenticationRequestContextResolver}
@@ -61,7 +61,7 @@ public class DefaultSaml2AuthenticationRequestContextResolverTests {
.providerDetails(c -> c.entityId(ASSERTING_PARTY_ENTITY_ID))
.providerDetails(c -> c.webSsoUrl(ASSERTING_PARTY_SSO_URL))
.assertionConsumerServiceUrlTemplate(RELYING_PARTY_SSO_URL)
.credentials(c -> c.add(relyingPartyVerifyingCredential()));
.credentials(c -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential()));
}
@Test
@@ -32,15 +32,14 @@ import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.saml2.core.Saml2Utils;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.util.StreamUtils;
import org.springframework.web.util.UriUtils;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
@RunWith(MockitoJUnitRunner.class)
public class Saml2AuthenticationTokenConverterTests {
@@ -48,7 +47,8 @@ public class Saml2AuthenticationTokenConverterTests {
@Mock
Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
RelyingPartyRegistration relyingPartyRegistration = relyingPartyRegistration().build();
RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.relyingPartyRegistration()
.build();
@Test
public void convertWhenSamlResponseThenToken() {
@@ -57,7 +57,7 @@ public class Saml2AuthenticationTokenConverterTests {
given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
.willReturn(this.relyingPartyRegistration);
MockHttpServletRequest request = new MockHttpServletRequest();
request.setParameter("SAMLResponse", Saml2Utils.samlEncode("response".getBytes(UTF_8)));
request.setParameter("SAMLResponse", Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8)));
Saml2AuthenticationToken token = converter.convert(request);
assertThat(token.getSaml2Response()).isEqualTo("response");
assertThat(token.getRelyingPartyRegistration().getRegistrationId())
@@ -126,7 +126,7 @@ public class Saml2AuthenticationTokenConverterTests {
private String getSsoCircleEncodedXml() throws IOException {
ClassPathResource resource = new ClassPathResource("saml2-response-sso-circle.encoded");
String response = StreamUtils.copyToString(resource.getInputStream(), StandardCharsets.UTF_8);
return UriUtils.decode(response, UTF_8);
return UriUtils.decode(response, StandardCharsets.UTF_8);
}
}
@@ -23,9 +23,11 @@ import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.saml2.core.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.metadata.Saml2MetadataResolver;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import static org.assertj.core.api.Assertions.assertThat;
@@ -34,8 +36,6 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials;
/**
* Tests for {@link Saml2MetadataFilter}
@@ -108,9 +108,9 @@ public class Saml2MetadataFilterTests {
public void doFilterWhenRelyingPartyRegistrationFoundThenInvokesMetadataResolver() throws Exception {
// given
this.request.setPathInfo("/saml2/service-provider-metadata/validRegistration");
RelyingPartyRegistration validRegistration = noCredentials()
.assertingPartyDetails(
party -> party.verificationX509Credentials(c -> c.add(relyingPartyVerifyingCredential())))
RelyingPartyRegistration validRegistration = TestRelyingPartyRegistrations.noCredentials()
.assertingPartyDetails(party -> party.verificationX509Credentials(
c -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential())))
.build();
String generatedMetadata = "<xml>test</xml>";