1
0
mirror of synced 2026-05-22 21:33:16 +00:00

Polish DefaultSaml2AuthenticationRequestContextResolver

Issue gh-8360
Issue gh-8887
This commit is contained in:
Josh Cummings
2020-07-28 17:19:48 -06:00
parent 015281ff53
commit a10c2c6cf8
7 changed files with 75 additions and 134 deletions
@@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -35,6 +35,9 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint; import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@@ -317,15 +320,16 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
private final class AuthenticationRequestEndpointConfig { private final class AuthenticationRequestEndpointConfig {
private String filterProcessingUrl = "/saml2/authenticate/{registrationId}"; private String filterProcessingUrl = "/saml2/authenticate/{registrationId}";
private AuthenticationRequestEndpointConfig() { private AuthenticationRequestEndpointConfig() {
} }
private Filter build(B http) { private Filter build(B http) {
Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http); Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http);
return postProcess(new Saml2WebSsoAuthenticationRequestFilter( return postProcess(new Saml2WebSsoAuthenticationRequestFilter(
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository, contextResolver, authenticationRequestResolver));
authenticationRequestResolver));
} }
private Saml2AuthenticationRequestFactory getResolver(B http) { private Saml2AuthenticationRequestFactory getResolver(B http) {
@@ -335,6 +339,16 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
} }
return resolver; return resolver;
} }
private Saml2AuthenticationRequestContextResolver getContextResolver(B http) {
Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http, Saml2AuthenticationRequestContextResolver.class);
if (resolver == null) {
return new DefaultSaml2AuthenticationRequestContextResolver(
new DefaultRelyingPartyRegistrationResolver(
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository));
}
return resolver;
}
} }
} }
@@ -65,10 +65,8 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpRequestResponseHolder;
@@ -87,6 +85,7 @@ import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext; import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
@@ -161,11 +160,11 @@ public class Saml2LoginConfigurerTests {
Saml2AuthenticationRequestContext context = authenticationRequestContext().build(); Saml2AuthenticationRequestContext context = authenticationRequestContext().build();
Saml2AuthenticationRequestContextResolver resolver = Saml2AuthenticationRequestContextResolver resolver =
CustomAuthenticationRequestContextResolver.resolver; CustomAuthenticationRequestContextResolver.resolver;
when(resolver.resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class))) when(resolver.resolve(any(HttpServletRequest.class)))
.thenReturn(context); .thenReturn(context);
this.mvc.perform(get("/saml2/authenticate/registration-id")) this.mvc.perform(get("/saml2/authenticate/registration-id"))
.andExpect(status().isFound()); .andExpect(status().isFound());
verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class)); verify(resolver).resolve(any(HttpServletRequest.class));
} }
@Test @Test
@@ -276,22 +275,11 @@ public class Saml2LoginConfigurerTests {
@Override @Override
protected void configure(HttpSecurity http) throws Exception { protected void configure(HttpSecurity http) throws Exception {
ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter> processor
= new ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter>() {
@Override
public <O extends Saml2WebSsoAuthenticationRequestFilter> O postProcess(O filter) {
filter.setAuthenticationRequestContextResolver(resolver);
return filter;
}
};
http http
.authorizeRequests(authz -> authz .authorizeRequests(authz -> authz
.anyRequest().authenticated() .anyRequest().authenticated()
) )
.saml2Login(saml2 -> saml2 .saml2Login(withDefaults());
.addObjectPostProcessor(processor)
);
} }
@Bean @Bean
@@ -30,6 +30,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@@ -69,9 +70,8 @@ import static java.nio.charset.StandardCharsets.ISO_8859_1;
*/ */
public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter { public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter {
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; private final Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver;
private Saml2AuthenticationRequestFactory authenticationRequestFactory; private Saml2AuthenticationRequestFactory authenticationRequestFactory;
private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver();
private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}"); private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}");
@@ -83,21 +83,24 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
*/ */
@Deprecated @Deprecated
public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
this(relyingPartyRegistrationRepository, this(new DefaultSaml2AuthenticationRequestContextResolver(
new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)),
new org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory()); new org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory());
} }
/** /**
* Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided parameters * Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided parameters
* *
* @param relyingPartyRegistrationRepository a repository for relying party configurations * @param authenticationRequestContextResolver a strategy for formulating a {@link Saml2AuthenticationRequestContext}
* @since 5.4 * @since 5.4
*/ */
public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, public Saml2WebSsoAuthenticationRequestFilter(
Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver,
Saml2AuthenticationRequestFactory authenticationRequestFactory) { Saml2AuthenticationRequestFactory authenticationRequestFactory) {
Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");
Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null");
Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null"); Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null");
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository; this.authenticationRequestContextResolver = authenticationRequestContextResolver;
this.authenticationRequestFactory = authenticationRequestFactory; this.authenticationRequestFactory = authenticationRequestFactory;
} }
@@ -123,17 +126,6 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
this.redirectMatcher = redirectMatcher; this.redirectMatcher = redirectMatcher;
} }
/**
* Use the given {@link Saml2AuthenticationRequestContextResolver} that creates a {@link Saml2AuthenticationRequestContext}
*
* @param authenticationRequestContextResolver the {@link Saml2AuthenticationRequestContextResolver} to use
* @since 5.4
*/
public void setAuthenticationRequestContextResolver(Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver) {
Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null");
this.authenticationRequestContextResolver = authenticationRequestContextResolver;
}
/** /**
* {@inheritDoc} * {@inheritDoc}
*/ */
@@ -147,14 +139,12 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
return; return;
} }
String registrationId = matcher.getVariables().get("registrationId"); Saml2AuthenticationRequestContext context = this.authenticationRequestContextResolver.resolve(request);
RelyingPartyRegistration relyingParty = if (context == null) {
this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
if (relyingParty == null) {
response.sendError(HttpServletResponse.SC_UNAUTHORIZED); response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
return; return;
} }
Saml2AuthenticationRequestContext context = authenticationRequestContextResolver.resolve(request, relyingParty); RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration();
if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) { if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) {
sendRedirect(response, context); sendRedirect(response, context);
} else { } else {
@@ -16,45 +16,45 @@
package org.springframework.security.saml2.provider.service.web; package org.springframework.security.saml2.provider.service.web;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl;
import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl;
/** /**
* The default implementation for {@link Saml2AuthenticationRequestContextResolver} * The default implementation for {@link Saml2AuthenticationRequestContextResolver}
* which uses the current request and given relying party to formulate a {@link Saml2AuthenticationRequestContext} * which uses the current request and given relying party to formulate a {@link Saml2AuthenticationRequestContext}
* *
* @author Shazin Sadakath * @author Shazin Sadakath
* @author Josh Cummings
* @since 5.4 * @since 5.4
*/ */
public final class DefaultSaml2AuthenticationRequestContextResolver implements Saml2AuthenticationRequestContextResolver { public final class DefaultSaml2AuthenticationRequestContextResolver implements Saml2AuthenticationRequestContextResolver {
private final Log logger = LogFactory.getLog(getClass()); private final Log logger = LogFactory.getLog(getClass());
private static final char PATH_DELIMITER = '/'; private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
public DefaultSaml2AuthenticationRequestContextResolver
(Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
}
/** /**
* {@inheritDoc} * {@inheritDoc}
*/ */
@Override @Override
public Saml2AuthenticationRequestContext resolve(HttpServletRequest request, public Saml2AuthenticationRequestContext resolve(HttpServletRequest request) {
RelyingPartyRegistration relyingParty) {
Assert.notNull(request, "request cannot be null"); Assert.notNull(request, "request cannot be null");
Assert.notNull(relyingParty, "relyingParty cannot be null"); RelyingPartyRegistration relyingParty = this.relyingPartyRegistrationResolver.convert(request);
if (relyingParty == null) {
return null;
}
if (this.logger.isDebugEnabled()) { if (this.logger.isDebugEnabled()) {
this.logger.debug("Creating SAML 2.0 Authentication Request for Asserting Party [" + this.logger.debug("Creating SAML 2.0 Authentication Request for Asserting Party [" +
relyingParty.getRegistrationId() + "]"); relyingParty.getRegistrationId() + "]");
@@ -65,59 +65,11 @@ public final class DefaultSaml2AuthenticationRequestContextResolver implements S
private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext( private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext(
HttpServletRequest request, RelyingPartyRegistration relyingParty) { HttpServletRequest request, RelyingPartyRegistration relyingParty) {
String applicationUri = getApplicationUri(request);
Function<String, String> resolver = templateResolver(applicationUri, relyingParty);
String localSpEntityId = resolver.apply(relyingParty.getEntityId());
String assertionConsumerServiceUrl = resolver.apply(relyingParty.getAssertionConsumerServiceLocation());
return Saml2AuthenticationRequestContext.builder() return Saml2AuthenticationRequestContext.builder()
.issuer(localSpEntityId) .issuer(relyingParty.getEntityId())
.relyingPartyRegistration(relyingParty) .relyingPartyRegistration(relyingParty)
.assertionConsumerServiceUrl(assertionConsumerServiceUrl) .assertionConsumerServiceUrl(relyingParty.getAssertionConsumerServiceLocation())
.relayState(request.getParameter("RelayState")) .relayState(request.getParameter("RelayState"))
.build(); .build();
} }
private Function<String, String> templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) {
return template -> resolveUrlTemplate(template, applicationUri, relyingParty);
}
private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) {
String entityId = relyingParty.getAssertingPartyDetails().getEntityId();
String registrationId = relyingParty.getRegistrationId();
Map<String, String> uriVariables = new HashMap<>();
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl)
.replaceQuery(null)
.fragment(null)
.build();
String scheme = uriComponents.getScheme();
uriVariables.put("baseScheme", scheme == null ? "" : scheme);
String host = uriComponents.getHost();
uriVariables.put("baseHost", host == null ? "" : host);
// following logic is based on HierarchicalUriComponents#toUriString()
int port = uriComponents.getPort();
uriVariables.put("basePort", port == -1 ? "" : ":" + port);
String path = uriComponents.getPath();
if (StringUtils.hasLength(path)) {
if (path.charAt(0) != PATH_DELIMITER) {
path = PATH_DELIMITER + path;
}
}
uriVariables.put("basePath", path == null ? "" : path);
uriVariables.put("baseUrl", uriComponents.toUriString());
uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : "");
uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");
return UriComponentsBuilder.fromUriString(template)
.buildAndExpand(uriVariables)
.toUriString();
}
private static String getApplicationUri(HttpServletRequest request) {
UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request))
.replacePath(request.getContextPath())
.replaceQuery(null)
.fragment(null)
.build();
return uriComponents.toUriString();
}
} }
@@ -16,16 +16,16 @@
package org.springframework.security.saml2.provider.service.web; package org.springframework.security.saml2.provider.service.web;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
/** /**
* This {@code Saml2AuthenticationRequestContextResolver} formulates a * This {@code Saml2AuthenticationRequestContextResolver} formulates a
* <a href="https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf">SAML 2.0 AuthnRequest</a> (line 1968) * <a href="https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf">SAML 2.0 AuthnRequest</a> (line 1968)
* *
* @author Shazin Sadakath * @author Shazin Sadakath
* @author Josh Cummings
* @since 5.4 * @since 5.4
*/ */
public interface Saml2AuthenticationRequestContextResolver { public interface Saml2AuthenticationRequestContextResolver {
@@ -35,9 +35,7 @@ public interface Saml2AuthenticationRequestContextResolver {
* *
* *
* @param request the current request * @param request the current request
* @param relyingParty the relying party responsible for saml2 sso authentication * @return the created {@link Saml2AuthenticationRequestContext} for the request
* @return the created {@link Saml2AuthenticationRequestContext} for request/relying party combination
*/ */
Saml2AuthenticationRequestContext resolve(HttpServletRequest request, Saml2AuthenticationRequestContext resolve(HttpServletRequest request);
RelyingPartyRegistration relyingParty);
} }
@@ -30,6 +30,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.web.util.HtmlUtils; import org.springframework.web.util.HtmlUtils;
import org.springframework.web.util.UriUtils; import org.springframework.web.util.UriUtils;
@@ -41,6 +42,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST; import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
public class Saml2WebSsoAuthenticationRequestFilterTests { public class Saml2WebSsoAuthenticationRequestFilterTests {
@@ -49,6 +51,8 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
private Saml2WebSsoAuthenticationRequestFilter filter; private Saml2WebSsoAuthenticationRequestFilter filter;
private RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class); private RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class);
private Saml2AuthenticationRequestFactory factory = mock(Saml2AuthenticationRequestFactory.class); private Saml2AuthenticationRequestFactory factory = mock(Saml2AuthenticationRequestFactory.class);
private Saml2AuthenticationRequestContextResolver resolver =
mock(Saml2AuthenticationRequestContextResolver.class);
private MockHttpServletRequest request; private MockHttpServletRequest request;
private MockHttpServletResponse response; private MockHttpServletResponse response;
private MockFilterChain filterChain; private MockFilterChain filterChain;
@@ -188,12 +192,14 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri"); when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri");
when(authenticationRequest.getRelayState()).thenReturn("relay"); when(authenticationRequest.getRelayState()).thenReturn("relay");
when(authenticationRequest.getSamlRequest()).thenReturn("saml"); when(authenticationRequest.getSamlRequest()).thenReturn("saml");
when(this.repository.findByRegistrationId("registration-id")).thenReturn(relyingParty); when(this.resolver.resolve(this.request)).thenReturn(authenticationRequestContext()
.relyingPartyRegistration(relyingParty)
.build());
when(this.factory.createPostAuthenticationRequest(any())) when(this.factory.createPostAuthenticationRequest(any()))
.thenReturn(authenticationRequest); .thenReturn(authenticationRequest);
Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter
(this.repository, this.factory); (this.resolver, this.factory);
filter.doFilterInternal(this.request, this.response, this.filterChain); filter.doFilterInternal(this.request, this.response, this.filterChain);
assertThat(this.response.getContentAsString()) assertThat(this.response.getContentAsString())
.contains("<form action=\"uri\" method=\"post\">") .contains("<form action=\"uri\" method=\"post\">")
@@ -44,11 +44,13 @@ public class DefaultSaml2AuthenticationRequestContextResolverTests {
private MockHttpServletRequest request; private MockHttpServletRequest request;
private RelyingPartyRegistration.Builder relyingPartyBuilder; private RelyingPartyRegistration.Builder relyingPartyBuilder;
private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver
= new DefaultSaml2AuthenticationRequestContextResolver(); = new DefaultSaml2AuthenticationRequestContextResolver(
new DefaultRelyingPartyRegistrationResolver(id -> relyingPartyBuilder.build()));
@Before @Before
public void setup() { public void setup() {
this.request = new MockHttpServletRequest(); this.request = new MockHttpServletRequest();
this.request.setPathInfo("/saml2/authenticate/registration-id");
this.relyingPartyBuilder = RelyingPartyRegistration this.relyingPartyBuilder = RelyingPartyRegistration
.withRegistrationId(REGISTRATION_ID) .withRegistrationId(REGISTRATION_ID)
.localEntityIdTemplate(RELYING_PARTY_ENTITY_ID) .localEntityIdTemplate(RELYING_PARTY_ENTITY_ID)
@@ -61,52 +63,43 @@ public class DefaultSaml2AuthenticationRequestContextResolverTests {
@Test @Test
public void resolveWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() { public void resolveWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() {
this.request.addParameter("RelayState", "relay-state"); this.request.addParameter("RelayState", "relay-state");
RelyingPartyRegistration relyingParty = this.relyingPartyBuilder.build();
Saml2AuthenticationRequestContext context = Saml2AuthenticationRequestContext context =
this.authenticationRequestContextResolver.resolve(this.request, relyingParty); this.authenticationRequestContextResolver.resolve(this.request);
assertThat(context).isNotNull(); assertThat(context).isNotNull();
assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo(RELYING_PARTY_SSO_URL); assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo(RELYING_PARTY_SSO_URL);
assertThat(context.getRelayState()).isEqualTo("relay-state"); assertThat(context.getRelayState()).isEqualTo("relay-state");
assertThat(context.getDestination()).isEqualTo(ASSERTING_PARTY_SSO_URL); assertThat(context.getDestination()).isEqualTo(ASSERTING_PARTY_SSO_URL);
assertThat(context.getIssuer()).isEqualTo(RELYING_PARTY_ENTITY_ID); assertThat(context.getIssuer()).isEqualTo(RELYING_PARTY_ENTITY_ID);
assertThat(context.getRelyingPartyRegistration()).isSameAs(relyingParty); assertThat(context.getRelyingPartyRegistration().getRegistrationId())
.isSameAs(this.relyingPartyBuilder.build().getRegistrationId());
} }
@Test @Test
public void resolveWhenAssertionConsumerServiceUrlTemplateContainsRegistrationIdThenResolves() { public void resolveWhenAssertionConsumerServiceUrlTemplateContainsRegistrationIdThenResolves() {
RelyingPartyRegistration relyingParty = this.relyingPartyBuilder this.relyingPartyBuilder
.assertionConsumerServiceUrlTemplate("/saml2/authenticate/{registrationId}") .assertionConsumerServiceLocation("/saml2/authenticate/{registrationId}");
.build();
Saml2AuthenticationRequestContext context = Saml2AuthenticationRequestContext context =
this.authenticationRequestContextResolver.resolve(this.request, relyingParty); this.authenticationRequestContextResolver.resolve(this.request);
assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo("/saml2/authenticate/registration-id"); assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo("/saml2/authenticate/registration-id");
} }
@Test @Test
public void resolveWhenAssertionConsumerServiceUrlTemplateContainsBaseUrlThenResolves() { public void resolveWhenAssertionConsumerServiceUrlTemplateContainsBaseUrlThenResolves() {
RelyingPartyRegistration relyingParty = this.relyingPartyBuilder this.relyingPartyBuilder
.assertionConsumerServiceUrlTemplate("{baseUrl}/saml2/authenticate/{registrationId}") .assertionConsumerServiceLocation("{baseUrl}/saml2/authenticate/{registrationId}");
.build();
Saml2AuthenticationRequestContext context = Saml2AuthenticationRequestContext context =
this.authenticationRequestContextResolver.resolve(this.request, relyingParty); this.authenticationRequestContextResolver.resolve(this.request);
assertThat(context.getAssertionConsumerServiceUrl()) assertThat(context.getAssertionConsumerServiceUrl())
.isEqualTo("http://localhost/saml2/authenticate/registration-id"); .isEqualTo("http://localhost/saml2/authenticate/registration-id");
} }
@Test
public void resolveWhenRequestNullThenException() {
assertThatCode(() ->
this.authenticationRequestContextResolver.resolve(this.request, null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test @Test
public void resolveWhenRelyingPartyNullThenException() { public void resolveWhenRelyingPartyNullThenException() {
assertThatCode(() -> assertThatCode(() ->
this.authenticationRequestContextResolver.resolve(null, this.relyingPartyBuilder.build())) this.authenticationRequestContextResolver.resolve(null))
.isInstanceOf(IllegalArgumentException.class); .isInstanceOf(IllegalArgumentException.class);
} }
} }