diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java index 2755a288dd..035a2dec4c 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java @@ -16,6 +16,7 @@ package org.springframework.security.saml2.provider.service.authentication; +import org.opensaml.saml.common.xml.SAMLConstants; import org.springframework.util.Assert; import org.joda.time.DateTime; @@ -32,6 +33,7 @@ import java.util.UUID; public class OpenSamlAuthenticationRequestFactory implements Saml2AuthenticationRequestFactory { private Clock clock = Clock.systemUTC(); private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance(); + private String protocolBinding = SAMLConstants.SAML2_POST_BINDING_URI; /** * {@inheritDoc} @@ -43,7 +45,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication auth.setIssueInstant(new DateTime(this.clock.millis())); auth.setForceAuthn(Boolean.FALSE); auth.setIsPassive(Boolean.FALSE); - auth.setProtocolBinding("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"); + auth.setProtocolBinding(protocolBinding); Issuer issuer = this.saml.buildSAMLObject(Issuer.class); issuer.setValue(request.getIssuer()); auth.setIssuer(issuer); @@ -67,4 +69,21 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication Assert.notNull(clock, "clock cannot be null"); this.clock = clock; } + + /** + * Sets the {@code protocolBinding} to use when generating authentication requests + * Acceptable values are {@link SAMLConstants#SAML2_POST_BINDING_URI} and + * {@link SAMLConstants#SAML2_REDIRECT_BINDING_URI} + * + * @param protocolBinding + * @throws IllegalArgumentException if the protocolBinding is not valid + */ + public void setProtocolBinding(String protocolBinding) { + boolean isAllowedBinding = SAMLConstants.SAML2_POST_BINDING_URI.equals(protocolBinding) || + SAMLConstants.SAML2_REDIRECT_BINDING_URI.equals(protocolBinding); + if (!isAllowedBinding) { + throw new IllegalArgumentException("Invalid protocol binding: " + protocolBinding); + } + this.protocolBinding = protocolBinding; + } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java new file mode 100644 index 0000000000..b7823e07c5 --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.saml2.core.AuthnRequest; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.springframework.security.saml2.provider.service.authentication.TestSaml2X509Credentials.relyingPartyCredentials; + +public class OpenSamlAuthenticationRequestFactoryTests { + + private OpenSamlAuthenticationRequestFactory factory; + private Saml2AuthenticationRequest request; + + @Rule + public ExpectedException exception = ExpectedException.none(); + + @Before + public void setUp() { + request = Saml2AuthenticationRequest.builder() + .issuer("https://issuer") + .destination("https://destination/sso") + .assertionConsumerServiceUrl("https://issuer/sso") + .credentials(c -> c.addAll(relyingPartyCredentials())) + .build(); + factory = new OpenSamlAuthenticationRequestFactory(); + } + + @Test + public void createAuthenticationRequestWhenDefaultThenReturnsPostBinding() { + AuthnRequest authn = getAuthNRequest(); + Assert.assertEquals(SAMLConstants.SAML2_POST_BINDING_URI, authn.getProtocolBinding()); + } + + @Test + public void createAuthenticationRequestWhenSetUriThenReturnsCorrectBinding() { + factory.setProtocolBinding(SAMLConstants.SAML2_REDIRECT_BINDING_URI); + AuthnRequest authn = getAuthNRequest(); + Assert.assertEquals(SAMLConstants.SAML2_REDIRECT_BINDING_URI, authn.getProtocolBinding()); + } + + @Test + public void createAuthenticationRequestWhenSetUnsupportredUriThenThrowsIllegalArgumentException() { + exception.expect(IllegalArgumentException.class); + exception.expectMessage(containsString("my-invalid-binding")); + factory.setProtocolBinding("my-invalid-binding"); + } + + private AuthnRequest getAuthNRequest() { + String xml = factory.createAuthenticationRequest(request); + return (AuthnRequest) OpenSamlImplementation.getInstance().resolve(xml); + } +}