1
0
mirror of synced 2026-05-22 21:33:16 +00:00

Add Saml2AuthenticationTokenConverter

Closes gh-8768
This commit is contained in:
Josh Cummings
2020-08-04 17:28:42 -06:00
parent a10c2c6cf8
commit 5061ae9e79
9 changed files with 386 additions and 49 deletions
@@ -19,23 +19,19 @@ package org.springframework.security.saml2.provider.service.servlet.filter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpMethod;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.session.ChangeSessionIdAuthenticationStrategy;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND;
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
import static org.springframework.util.StringUtils.hasText;
/**
@@ -44,8 +40,7 @@ import static org.springframework.util.StringUtils.hasText;
public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProcessingFilter {
public static final String DEFAULT_FILTER_PROCESSES_URI = "/login/saml2/sso/{registrationId}";
private final RequestMatcher matcher;
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
private final AuthenticationConverter authenticationConverter;
/**
* Creates a {@code Saml2WebSsoAuthenticationFilter} authentication filter that is configured
@@ -64,16 +59,30 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
public Saml2WebSsoAuthenticationFilter(
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
String filterProcessesUrl) {
super(filterProcessesUrl);
Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");
Assert.hasText(filterProcessesUrl, "filterProcessesUrl must contain a URL pattern");
this(new Saml2AuthenticationTokenConverter
(new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)),
filterProcessesUrl);
}
/**
* Creates a {@link Saml2WebSsoAuthenticationFilter} given the provided parameters
*
* @param authenticationConverter the strategy for converting an {@link HttpServletRequest}
* into an {@link Authentication}
* @param filterProcessingUrl the processing URL, must contain a {registrationId} variable
* @since 5.4
*/
public Saml2WebSsoAuthenticationFilter(
AuthenticationConverter authenticationConverter,
String filterProcessingUrl) {
super(filterProcessingUrl);
Assert.notNull(authenticationConverter, "authenticationConverter cannot be null");
Assert.hasText(filterProcessingUrl, "filterProcessesUrl must contain a URL pattern");
Assert.isTrue(
filterProcessesUrl.contains("{registrationId}"),
filterProcessingUrl.contains("{registrationId}"),
"filterProcessesUrl must contain a {registrationId} match variable"
);
this.matcher = new AntPathRequestMatcher(filterProcessesUrl);
setRequiresAuthenticationRequestMatcher(this.matcher);
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
this.authenticationConverter = authenticationConverter;
setAllowSessionCreation(true);
setSessionAuthenticationStrategy(new ChangeSessionIdAuthenticationStrategy());
}
@@ -86,37 +95,12 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
@Override
public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response)
throws AuthenticationException {
String saml2Response = request.getParameter("SAMLResponse");
byte[] b = Saml2Utils.samlDecode(saml2Response);
String responseXml = inflateIfRequired(request, b);
String registrationId = this.matcher.matcher(request).getVariables().get("registrationId");
RelyingPartyRegistration rp =
this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
if (rp == null) {
Authentication authentication = this.authenticationConverter.convert(request);
if (authentication == null) {
Saml2Error saml2Error = new Saml2Error(RELYING_PARTY_REGISTRATION_NOT_FOUND,
"Relying Party Registration not found with ID: " + registrationId);
"No relying party registration found");
throw new Saml2AuthenticationException(saml2Error);
}
String applicationUri = Saml2ServletUtils.getApplicationUri(request);
String relyingPartyEntityId = Saml2ServletUtils.resolveUrlTemplate(rp.getEntityId(), applicationUri, rp);
String assertionConsumerServiceLocation = Saml2ServletUtils.resolveUrlTemplate(
rp.getAssertionConsumerServiceLocation(), applicationUri, rp);
RelyingPartyRegistration relyingPartyRegistration = withRelyingPartyRegistration(rp)
.entityId(relyingPartyEntityId)
.assertionConsumerServiceLocation(assertionConsumerServiceLocation)
.build();
Saml2AuthenticationToken authentication = new Saml2AuthenticationToken(
relyingPartyRegistration, responseXml);
return getAuthenticationManager().authenticate(authentication);
}
private String inflateIfRequired(HttpServletRequest request, byte[] b) {
if (HttpMethod.GET.matches(request.getMethod())) {
return Saml2Utils.samlInflate(b);
}
else {
return new String(b, UTF_8);
}
}
}
@@ -0,0 +1,105 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.web;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.zip.Inflater;
import java.util.zip.InflaterOutputStream;
import javax.servlet.http.HttpServletRequest;
import org.apache.commons.codec.binary.Base64;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpMethod;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.util.Assert;
import static java.nio.charset.StandardCharsets.UTF_8;
/**
* An {@link AuthenticationConverter} that generates a {@link Saml2AuthenticationToken} appropriate
* for authenticated a SAML 2.0 Assertion against an
* {@link org.springframework.security.authentication.AuthenticationManager}.
*
* @author Josh Cummings
* @since 5.4
*/
public class Saml2AuthenticationTokenConverter implements AuthenticationConverter {
private static Base64 BASE64 = new Base64(0, new byte[]{'\n'});
private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
/**
* Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for resolving
* {@link RelyingPartyRegistration}s
*
* @param relyingPartyRegistrationResolver the strategy for resolving {@link RelyingPartyRegistration}s
*/
public Saml2AuthenticationTokenConverter
(Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
}
/**
* {@inheritDoc}
*/
@Override
public Saml2AuthenticationToken convert(HttpServletRequest request) {
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.convert(request);
if (relyingPartyRegistration == null) {
return null;
}
String saml2Response = request.getParameter("SAMLResponse");
if (saml2Response == null) {
return null;
}
byte[] b = samlDecode(saml2Response);
saml2Response = inflateIfRequired(request, b);
return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response);
}
private String inflateIfRequired(HttpServletRequest request, byte[] b) {
if (HttpMethod.GET.matches(request.getMethod())) {
return samlInflate(b);
}
else {
return new String(b, UTF_8);
}
}
private byte[] samlDecode(String s) {
return BASE64.decode(s);
}
private String samlInflate(byte[] b) {
try {
ByteArrayOutputStream out = new ByteArrayOutputStream();
InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true));
iout.write(b);
iout.finish();
return new String(out.toByteArray(), UTF_8);
}
catch (IOException e) {
throw new Saml2Exception("Unable to inflate string", e);
}
}
}
@@ -0,0 +1,70 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.core;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.zip.Deflater;
import java.util.zip.DeflaterOutputStream;
import java.util.zip.Inflater;
import java.util.zip.InflaterOutputStream;
import org.apache.commons.codec.binary.Base64;
import org.springframework.security.saml2.Saml2Exception;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.zip.Deflater.DEFLATED;
public final class Saml2Utils {
private static Base64 BASE64 = new Base64(0, new byte[]{'\n'});
public static String samlEncode(byte[] b) {
return BASE64.encodeAsString(b);
}
public static byte[] samlDecode(String s) {
return BASE64.decode(s);
}
public static byte[] samlDeflate(String s) {
try {
ByteArrayOutputStream b = new ByteArrayOutputStream();
DeflaterOutputStream deflater = new DeflaterOutputStream(b, new Deflater(DEFLATED, true));
deflater.write(s.getBytes(UTF_8));
deflater.finish();
return b.toByteArray();
}
catch (IOException e) {
throw new Saml2Exception("Unable to deflate string", e);
}
}
public static String samlInflate(byte[] b) {
try {
ByteArrayOutputStream out = new ByteArrayOutputStream();
InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true));
iout.write(b);
iout.finish();
return new String(out.toByteArray(), UTF_8);
}
catch (IOException e) {
throw new Saml2Exception("Unable to inflate string", e);
}
}
}
@@ -48,5 +48,13 @@ public class TestRelyingPartyRegistrations {
.credentials(c -> c.add(verificationCertificate));
}
public static RelyingPartyRegistration.Builder noCredentials() {
return RelyingPartyRegistration.withRegistrationId("registration-id")
.entityId("rp-entity-id")
.assertionConsumerServiceLocation("https://rp.example.org/acs")
.assertingPartyDetails(party -> party
.entityId("ap-entity-id")
.singleSignOnServiceLocation("https://ap.example.org/sso")
);
}
}
@@ -89,7 +89,7 @@ public class Saml2WebSsoAuthenticationFilterTests {
failBecauseExceptionWasNotThrown(Saml2AuthenticationException.class);
} catch (Exception e) {
assertThat(e).isInstanceOf(Saml2AuthenticationException.class);
assertThat(e.getMessage()).isEqualTo("Relying Party Registration not found with ID: non-existent-id");
assertThat(e.getMessage()).isEqualTo("No relying party registration found");
}
}
}
@@ -0,0 +1,102 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.web;
import javax.servlet.http.HttpServletRequest;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.core.convert.converter.Converter;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.saml2.core.Saml2Utils;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
@RunWith(MockitoJUnitRunner.class)
public class Saml2AuthenticationTokenConverterTests {
@Mock
Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
RelyingPartyRegistration relyingPartyRegistration = relyingPartyRegistration().build();
@Test
public void convertWhenSamlResponseThenToken() {
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter
(this.relyingPartyRegistrationResolver);
when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
.thenReturn(this.relyingPartyRegistration);
MockHttpServletRequest request = new MockHttpServletRequest();
request.setParameter("SAMLResponse", Saml2Utils.samlEncode("response".getBytes(UTF_8)));
Saml2AuthenticationToken token = converter.convert(request);
assertThat(token.getSaml2Response()).isEqualTo("response");
assertThat(token.getRelyingPartyRegistration().getRegistrationId())
.isEqualTo(relyingPartyRegistration.getRegistrationId());
}
@Test
public void convertWhenNoSamlResponseThenNull() {
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter
(this.relyingPartyRegistrationResolver);
when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
.thenReturn(this.relyingPartyRegistration);
MockHttpServletRequest request = new MockHttpServletRequest();
assertThat(converter.convert(request)).isNull();
}
@Test
public void convertWhenNoRelyingPartyRegistrationThenNull() {
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter
(this.relyingPartyRegistrationResolver);
when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
.thenReturn(null);
MockHttpServletRequest request = new MockHttpServletRequest();
assertThat(converter.convert(request)).isNull();
}
@Test
public void convertWhenGetRequestThenInflates() {
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter
(this.relyingPartyRegistrationResolver);
when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
.thenReturn(this.relyingPartyRegistration);
MockHttpServletRequest request = new MockHttpServletRequest();
request.setMethod("GET");
byte[] deflated = Saml2Utils.samlDeflate("response");
String encoded = Saml2Utils.samlEncode(deflated);
request.setParameter("SAMLResponse", encoded);
Saml2AuthenticationToken token = converter.convert(request);
assertThat(token.getSaml2Response()).isEqualTo("response");
assertThat(token.getRelyingPartyRegistration().getRegistrationId())
.isEqualTo(relyingPartyRegistration.getRegistrationId());
}
@Test
public void constructorWhenResolverIsNullThenIllegalArgument() {
assertThatCode(() -> new Saml2AuthenticationTokenConverter(null))
.isInstanceOf(IllegalArgumentException.class);
}
}