Add Saml2AuthenticationTokenConverter
Closes gh-8768
This commit is contained in:
+29
-45
@@ -19,23 +19,19 @@ package org.springframework.security.saml2.provider.service.servlet.filter;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.AuthenticationException;
|
||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
|
||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
|
||||
import org.springframework.security.saml2.core.Saml2Error;
|
||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
|
||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
||||
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
|
||||
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
|
||||
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
|
||||
import org.springframework.security.web.authentication.AuthenticationConverter;
|
||||
import org.springframework.security.web.authentication.session.ChangeSessionIdAuthenticationStrategy;
|
||||
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
||||
import org.springframework.security.web.util.matcher.RequestMatcher;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
import static org.springframework.security.saml2.core.Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND;
|
||||
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
|
||||
import static org.springframework.util.StringUtils.hasText;
|
||||
|
||||
/**
|
||||
@@ -44,8 +40,7 @@ import static org.springframework.util.StringUtils.hasText;
|
||||
public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProcessingFilter {
|
||||
|
||||
public static final String DEFAULT_FILTER_PROCESSES_URI = "/login/saml2/sso/{registrationId}";
|
||||
private final RequestMatcher matcher;
|
||||
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
|
||||
private final AuthenticationConverter authenticationConverter;
|
||||
|
||||
/**
|
||||
* Creates a {@code Saml2WebSsoAuthenticationFilter} authentication filter that is configured
|
||||
@@ -64,16 +59,30 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
|
||||
public Saml2WebSsoAuthenticationFilter(
|
||||
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
|
||||
String filterProcessesUrl) {
|
||||
super(filterProcessesUrl);
|
||||
Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");
|
||||
Assert.hasText(filterProcessesUrl, "filterProcessesUrl must contain a URL pattern");
|
||||
this(new Saml2AuthenticationTokenConverter
|
||||
(new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)),
|
||||
filterProcessesUrl);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link Saml2WebSsoAuthenticationFilter} given the provided parameters
|
||||
*
|
||||
* @param authenticationConverter the strategy for converting an {@link HttpServletRequest}
|
||||
* into an {@link Authentication}
|
||||
* @param filterProcessingUrl the processing URL, must contain a {registrationId} variable
|
||||
* @since 5.4
|
||||
*/
|
||||
public Saml2WebSsoAuthenticationFilter(
|
||||
AuthenticationConverter authenticationConverter,
|
||||
String filterProcessingUrl) {
|
||||
super(filterProcessingUrl);
|
||||
Assert.notNull(authenticationConverter, "authenticationConverter cannot be null");
|
||||
Assert.hasText(filterProcessingUrl, "filterProcessesUrl must contain a URL pattern");
|
||||
Assert.isTrue(
|
||||
filterProcessesUrl.contains("{registrationId}"),
|
||||
filterProcessingUrl.contains("{registrationId}"),
|
||||
"filterProcessesUrl must contain a {registrationId} match variable"
|
||||
);
|
||||
this.matcher = new AntPathRequestMatcher(filterProcessesUrl);
|
||||
setRequiresAuthenticationRequestMatcher(this.matcher);
|
||||
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
|
||||
this.authenticationConverter = authenticationConverter;
|
||||
setAllowSessionCreation(true);
|
||||
setSessionAuthenticationStrategy(new ChangeSessionIdAuthenticationStrategy());
|
||||
}
|
||||
@@ -86,37 +95,12 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
|
||||
@Override
|
||||
public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response)
|
||||
throws AuthenticationException {
|
||||
String saml2Response = request.getParameter("SAMLResponse");
|
||||
byte[] b = Saml2Utils.samlDecode(saml2Response);
|
||||
|
||||
String responseXml = inflateIfRequired(request, b);
|
||||
String registrationId = this.matcher.matcher(request).getVariables().get("registrationId");
|
||||
RelyingPartyRegistration rp =
|
||||
this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
|
||||
if (rp == null) {
|
||||
Authentication authentication = this.authenticationConverter.convert(request);
|
||||
if (authentication == null) {
|
||||
Saml2Error saml2Error = new Saml2Error(RELYING_PARTY_REGISTRATION_NOT_FOUND,
|
||||
"Relying Party Registration not found with ID: " + registrationId);
|
||||
"No relying party registration found");
|
||||
throw new Saml2AuthenticationException(saml2Error);
|
||||
}
|
||||
String applicationUri = Saml2ServletUtils.getApplicationUri(request);
|
||||
String relyingPartyEntityId = Saml2ServletUtils.resolveUrlTemplate(rp.getEntityId(), applicationUri, rp);
|
||||
String assertionConsumerServiceLocation = Saml2ServletUtils.resolveUrlTemplate(
|
||||
rp.getAssertionConsumerServiceLocation(), applicationUri, rp);
|
||||
RelyingPartyRegistration relyingPartyRegistration = withRelyingPartyRegistration(rp)
|
||||
.entityId(relyingPartyEntityId)
|
||||
.assertionConsumerServiceLocation(assertionConsumerServiceLocation)
|
||||
.build();
|
||||
Saml2AuthenticationToken authentication = new Saml2AuthenticationToken(
|
||||
relyingPartyRegistration, responseXml);
|
||||
return getAuthenticationManager().authenticate(authentication);
|
||||
}
|
||||
|
||||
private String inflateIfRequired(HttpServletRequest request, byte[] b) {
|
||||
if (HttpMethod.GET.matches(request.getMethod())) {
|
||||
return Saml2Utils.samlInflate(b);
|
||||
}
|
||||
else {
|
||||
return new String(b, UTF_8);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+105
@@ -0,0 +1,105 @@
|
||||
/*
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.security.saml2.provider.service.web;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.zip.Inflater;
|
||||
import java.util.zip.InflaterOutputStream;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
|
||||
import org.apache.commons.codec.binary.Base64;
|
||||
|
||||
import org.springframework.core.convert.converter.Converter;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.security.saml2.Saml2Exception;
|
||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
|
||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
||||
import org.springframework.security.web.authentication.AuthenticationConverter;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
|
||||
/**
|
||||
* An {@link AuthenticationConverter} that generates a {@link Saml2AuthenticationToken} appropriate
|
||||
* for authenticated a SAML 2.0 Assertion against an
|
||||
* {@link org.springframework.security.authentication.AuthenticationManager}.
|
||||
*
|
||||
* @author Josh Cummings
|
||||
* @since 5.4
|
||||
*/
|
||||
public class Saml2AuthenticationTokenConverter implements AuthenticationConverter {
|
||||
private static Base64 BASE64 = new Base64(0, new byte[]{'\n'});
|
||||
|
||||
private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
|
||||
|
||||
/**
|
||||
* Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for resolving
|
||||
* {@link RelyingPartyRegistration}s
|
||||
*
|
||||
* @param relyingPartyRegistrationResolver the strategy for resolving {@link RelyingPartyRegistration}s
|
||||
*/
|
||||
public Saml2AuthenticationTokenConverter
|
||||
(Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
|
||||
Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
|
||||
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public Saml2AuthenticationToken convert(HttpServletRequest request) {
|
||||
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.convert(request);
|
||||
if (relyingPartyRegistration == null) {
|
||||
return null;
|
||||
}
|
||||
String saml2Response = request.getParameter("SAMLResponse");
|
||||
if (saml2Response == null) {
|
||||
return null;
|
||||
}
|
||||
byte[] b = samlDecode(saml2Response);
|
||||
saml2Response = inflateIfRequired(request, b);
|
||||
return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response);
|
||||
}
|
||||
|
||||
private String inflateIfRequired(HttpServletRequest request, byte[] b) {
|
||||
if (HttpMethod.GET.matches(request.getMethod())) {
|
||||
return samlInflate(b);
|
||||
}
|
||||
else {
|
||||
return new String(b, UTF_8);
|
||||
}
|
||||
}
|
||||
|
||||
private byte[] samlDecode(String s) {
|
||||
return BASE64.decode(s);
|
||||
}
|
||||
|
||||
private String samlInflate(byte[] b) {
|
||||
try {
|
||||
ByteArrayOutputStream out = new ByteArrayOutputStream();
|
||||
InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true));
|
||||
iout.write(b);
|
||||
iout.finish();
|
||||
return new String(out.toByteArray(), UTF_8);
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new Saml2Exception("Unable to inflate string", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
+70
@@ -0,0 +1,70 @@
|
||||
/*
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.security.saml2.core;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.zip.Deflater;
|
||||
import java.util.zip.DeflaterOutputStream;
|
||||
import java.util.zip.Inflater;
|
||||
import java.util.zip.InflaterOutputStream;
|
||||
|
||||
import org.apache.commons.codec.binary.Base64;
|
||||
|
||||
import org.springframework.security.saml2.Saml2Exception;
|
||||
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
import static java.util.zip.Deflater.DEFLATED;
|
||||
|
||||
public final class Saml2Utils {
|
||||
|
||||
private static Base64 BASE64 = new Base64(0, new byte[]{'\n'});
|
||||
|
||||
public static String samlEncode(byte[] b) {
|
||||
return BASE64.encodeAsString(b);
|
||||
}
|
||||
|
||||
public static byte[] samlDecode(String s) {
|
||||
return BASE64.decode(s);
|
||||
}
|
||||
|
||||
public static byte[] samlDeflate(String s) {
|
||||
try {
|
||||
ByteArrayOutputStream b = new ByteArrayOutputStream();
|
||||
DeflaterOutputStream deflater = new DeflaterOutputStream(b, new Deflater(DEFLATED, true));
|
||||
deflater.write(s.getBytes(UTF_8));
|
||||
deflater.finish();
|
||||
return b.toByteArray();
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new Saml2Exception("Unable to deflate string", e);
|
||||
}
|
||||
}
|
||||
|
||||
public static String samlInflate(byte[] b) {
|
||||
try {
|
||||
ByteArrayOutputStream out = new ByteArrayOutputStream();
|
||||
InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true));
|
||||
iout.write(b);
|
||||
iout.finish();
|
||||
return new String(out.toByteArray(), UTF_8);
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new Saml2Exception("Unable to inflate string", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
+9
-1
@@ -48,5 +48,13 @@ public class TestRelyingPartyRegistrations {
|
||||
.credentials(c -> c.add(verificationCertificate));
|
||||
}
|
||||
|
||||
|
||||
public static RelyingPartyRegistration.Builder noCredentials() {
|
||||
return RelyingPartyRegistration.withRegistrationId("registration-id")
|
||||
.entityId("rp-entity-id")
|
||||
.assertionConsumerServiceLocation("https://rp.example.org/acs")
|
||||
.assertingPartyDetails(party -> party
|
||||
.entityId("ap-entity-id")
|
||||
.singleSignOnServiceLocation("https://ap.example.org/sso")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
+1
-1
@@ -89,7 +89,7 @@ public class Saml2WebSsoAuthenticationFilterTests {
|
||||
failBecauseExceptionWasNotThrown(Saml2AuthenticationException.class);
|
||||
} catch (Exception e) {
|
||||
assertThat(e).isInstanceOf(Saml2AuthenticationException.class);
|
||||
assertThat(e.getMessage()).isEqualTo("Relying Party Registration not found with ID: non-existent-id");
|
||||
assertThat(e.getMessage()).isEqualTo("No relying party registration found");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+102
@@ -0,0 +1,102 @@
|
||||
/*
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.security.saml2.provider.service.web;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
|
||||
import org.springframework.core.convert.converter.Converter;
|
||||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.security.saml2.core.Saml2Utils;
|
||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
|
||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
||||
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatCode;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
|
||||
|
||||
@RunWith(MockitoJUnitRunner.class)
|
||||
public class Saml2AuthenticationTokenConverterTests {
|
||||
@Mock
|
||||
Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
|
||||
|
||||
RelyingPartyRegistration relyingPartyRegistration = relyingPartyRegistration().build();
|
||||
|
||||
@Test
|
||||
public void convertWhenSamlResponseThenToken() {
|
||||
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter
|
||||
(this.relyingPartyRegistrationResolver);
|
||||
when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
|
||||
.thenReturn(this.relyingPartyRegistration);
|
||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||
request.setParameter("SAMLResponse", Saml2Utils.samlEncode("response".getBytes(UTF_8)));
|
||||
Saml2AuthenticationToken token = converter.convert(request);
|
||||
assertThat(token.getSaml2Response()).isEqualTo("response");
|
||||
assertThat(token.getRelyingPartyRegistration().getRegistrationId())
|
||||
.isEqualTo(relyingPartyRegistration.getRegistrationId());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void convertWhenNoSamlResponseThenNull() {
|
||||
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter
|
||||
(this.relyingPartyRegistrationResolver);
|
||||
when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
|
||||
.thenReturn(this.relyingPartyRegistration);
|
||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||
assertThat(converter.convert(request)).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void convertWhenNoRelyingPartyRegistrationThenNull() {
|
||||
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter
|
||||
(this.relyingPartyRegistrationResolver);
|
||||
when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
|
||||
.thenReturn(null);
|
||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||
assertThat(converter.convert(request)).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void convertWhenGetRequestThenInflates() {
|
||||
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter
|
||||
(this.relyingPartyRegistrationResolver);
|
||||
when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
|
||||
.thenReturn(this.relyingPartyRegistration);
|
||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||
request.setMethod("GET");
|
||||
byte[] deflated = Saml2Utils.samlDeflate("response");
|
||||
String encoded = Saml2Utils.samlEncode(deflated);
|
||||
request.setParameter("SAMLResponse", encoded);
|
||||
Saml2AuthenticationToken token = converter.convert(request);
|
||||
assertThat(token.getSaml2Response()).isEqualTo("response");
|
||||
assertThat(token.getRelyingPartyRegistration().getRegistrationId())
|
||||
.isEqualTo(relyingPartyRegistration.getRegistrationId());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void constructorWhenResolverIsNullThenIllegalArgument() {
|
||||
assertThatCode(() -> new Saml2AuthenticationTokenConverter(null))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user