From a10c2c6cf881dc9c83296d4890a3c8d038d47243 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Tue, 28 Jul 2020 17:19:48 -0600 Subject: [PATCH] Polish DefaultSaml2AuthenticationRequestContextResolver Issue gh-8360 Issue gh-8887 --- .../saml2/Saml2LoginConfigurer.java | 20 ++++- .../saml2/Saml2LoginConfigurerTests.java | 20 +---- ...aml2WebSsoAuthenticationRequestFilter.java | 36 ++++----- ...2AuthenticationRequestContextResolver.java | 78 ++++--------------- ...2AuthenticationRequestContextResolver.java | 12 ++- ...ebSsoAuthenticationRequestFilterTests.java | 10 ++- ...enticationRequestContextResolverTests.java | 33 ++++---- 7 files changed, 75 insertions(+), 134 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java index 73f33203c1..321a492f0a 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,6 +35,9 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter; +import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; +import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver; +import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; @@ -317,15 +320,16 @@ public final class Saml2LoginConfigurer> extend private final class AuthenticationRequestEndpointConfig { private String filterProcessingUrl = "/saml2/authenticate/{registrationId}"; + private AuthenticationRequestEndpointConfig() { } private Filter build(B http) { Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http); + Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http); return postProcess(new Saml2WebSsoAuthenticationRequestFilter( - Saml2LoginConfigurer.this.relyingPartyRegistrationRepository, - authenticationRequestResolver)); + contextResolver, authenticationRequestResolver)); } private Saml2AuthenticationRequestFactory getResolver(B http) { @@ -335,6 +339,16 @@ public final class Saml2LoginConfigurer> extend } return resolver; } + + private Saml2AuthenticationRequestContextResolver getContextResolver(B http) { + Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http, Saml2AuthenticationRequestContextResolver.class); + if (resolver == null) { + return new DefaultSaml2AuthenticationRequestContextResolver( + new DefaultRelyingPartyRegistrationResolver( + Saml2LoginConfigurer.this.relyingPartyRegistrationRepository)); + } + return resolver; + } } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java index 440a10f933..21845bcec3 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java @@ -65,10 +65,8 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; -import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.context.HttpRequestResponseHolder; @@ -87,6 +85,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext; import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; @@ -161,11 +160,11 @@ public class Saml2LoginConfigurerTests { Saml2AuthenticationRequestContext context = authenticationRequestContext().build(); Saml2AuthenticationRequestContextResolver resolver = CustomAuthenticationRequestContextResolver.resolver; - when(resolver.resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class))) + when(resolver.resolve(any(HttpServletRequest.class))) .thenReturn(context); this.mvc.perform(get("/saml2/authenticate/registration-id")) .andExpect(status().isFound()); - verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class)); + verify(resolver).resolve(any(HttpServletRequest.class)); } @Test @@ -276,22 +275,11 @@ public class Saml2LoginConfigurerTests { @Override protected void configure(HttpSecurity http) throws Exception { - ObjectPostProcessor processor - = new ObjectPostProcessor() { - @Override - public O postProcess(O filter) { - filter.setAuthenticationRequestContextResolver(resolver); - return filter; - } - }; - http .authorizeRequests(authz -> authz .anyRequest().authenticated() ) - .saml2Login(saml2 -> saml2 - .addObjectPostProcessor(processor) - ); + .saml2Login(withDefaults()); } @Bean diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java index f0270de213..3fa3e9522c 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java @@ -30,6 +30,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; @@ -69,9 +70,8 @@ import static java.nio.charset.StandardCharsets.ISO_8859_1; */ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter { - private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; + private final Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver; private Saml2AuthenticationRequestFactory authenticationRequestFactory; - private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver(); private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}"); @@ -83,21 +83,24 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter */ @Deprecated public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { - this(relyingPartyRegistrationRepository, + this(new DefaultSaml2AuthenticationRequestContextResolver( + new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)), new org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory()); } /** * Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided parameters * - * @param relyingPartyRegistrationRepository a repository for relying party configurations + * @param authenticationRequestContextResolver a strategy for formulating a {@link Saml2AuthenticationRequestContext} * @since 5.4 */ - public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, + public Saml2WebSsoAuthenticationRequestFilter( + Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver, Saml2AuthenticationRequestFactory authenticationRequestFactory) { - Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null"); + + Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null"); Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null"); - this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository; + this.authenticationRequestContextResolver = authenticationRequestContextResolver; this.authenticationRequestFactory = authenticationRequestFactory; } @@ -123,17 +126,6 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter this.redirectMatcher = redirectMatcher; } - /** - * Use the given {@link Saml2AuthenticationRequestContextResolver} that creates a {@link Saml2AuthenticationRequestContext} - * - * @param authenticationRequestContextResolver the {@link Saml2AuthenticationRequestContextResolver} to use - * @since 5.4 - */ - public void setAuthenticationRequestContextResolver(Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver) { - Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null"); - this.authenticationRequestContextResolver = authenticationRequestContextResolver; - } - /** * {@inheritDoc} */ @@ -147,14 +139,12 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter return; } - String registrationId = matcher.getVariables().get("registrationId"); - RelyingPartyRegistration relyingParty = - this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId); - if (relyingParty == null) { + Saml2AuthenticationRequestContext context = this.authenticationRequestContextResolver.resolve(request); + if (context == null) { response.sendError(HttpServletResponse.SC_UNAUTHORIZED); return; } - Saml2AuthenticationRequestContext context = authenticationRequestContextResolver.resolve(request, relyingParty); + RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration(); if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) { sendRedirect(response, context); } else { diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java index 7910b74bb9..b9d15b7860 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java @@ -16,45 +16,45 @@ package org.springframework.security.saml2.provider.service.web; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; import javax.servlet.http.HttpServletRequest; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.util.Assert; -import org.springframework.util.StringUtils; -import org.springframework.web.util.UriComponents; -import org.springframework.web.util.UriComponentsBuilder; - -import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl; -import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl; /** * The default implementation for {@link Saml2AuthenticationRequestContextResolver} * which uses the current request and given relying party to formulate a {@link Saml2AuthenticationRequestContext} * * @author Shazin Sadakath + * @author Josh Cummings * @since 5.4 */ public final class DefaultSaml2AuthenticationRequestContextResolver implements Saml2AuthenticationRequestContextResolver { private final Log logger = LogFactory.getLog(getClass()); - private static final char PATH_DELIMITER = '/'; + private final Converter relyingPartyRegistrationResolver; + + public DefaultSaml2AuthenticationRequestContextResolver + (Converter relyingPartyRegistrationResolver) { + this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; + } /** * {@inheritDoc} */ @Override - public Saml2AuthenticationRequestContext resolve(HttpServletRequest request, - RelyingPartyRegistration relyingParty) { + public Saml2AuthenticationRequestContext resolve(HttpServletRequest request) { Assert.notNull(request, "request cannot be null"); - Assert.notNull(relyingParty, "relyingParty cannot be null"); + RelyingPartyRegistration relyingParty = this.relyingPartyRegistrationResolver.convert(request); + if (relyingParty == null) { + return null; + } if (this.logger.isDebugEnabled()) { this.logger.debug("Creating SAML 2.0 Authentication Request for Asserting Party [" + relyingParty.getRegistrationId() + "]"); @@ -65,59 +65,11 @@ public final class DefaultSaml2AuthenticationRequestContextResolver implements S private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext( HttpServletRequest request, RelyingPartyRegistration relyingParty) { - String applicationUri = getApplicationUri(request); - Function resolver = templateResolver(applicationUri, relyingParty); - String localSpEntityId = resolver.apply(relyingParty.getEntityId()); - String assertionConsumerServiceUrl = resolver.apply(relyingParty.getAssertionConsumerServiceLocation()); return Saml2AuthenticationRequestContext.builder() - .issuer(localSpEntityId) + .issuer(relyingParty.getEntityId()) .relyingPartyRegistration(relyingParty) - .assertionConsumerServiceUrl(assertionConsumerServiceUrl) + .assertionConsumerServiceUrl(relyingParty.getAssertionConsumerServiceLocation()) .relayState(request.getParameter("RelayState")) .build(); } - - private Function templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) { - return template -> resolveUrlTemplate(template, applicationUri, relyingParty); - } - - private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) { - String entityId = relyingParty.getAssertingPartyDetails().getEntityId(); - String registrationId = relyingParty.getRegistrationId(); - Map uriVariables = new HashMap<>(); - UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl) - .replaceQuery(null) - .fragment(null) - .build(); - String scheme = uriComponents.getScheme(); - uriVariables.put("baseScheme", scheme == null ? "" : scheme); - String host = uriComponents.getHost(); - uriVariables.put("baseHost", host == null ? "" : host); - // following logic is based on HierarchicalUriComponents#toUriString() - int port = uriComponents.getPort(); - uriVariables.put("basePort", port == -1 ? "" : ":" + port); - String path = uriComponents.getPath(); - if (StringUtils.hasLength(path)) { - if (path.charAt(0) != PATH_DELIMITER) { - path = PATH_DELIMITER + path; - } - } - uriVariables.put("basePath", path == null ? "" : path); - uriVariables.put("baseUrl", uriComponents.toUriString()); - uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : ""); - uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : ""); - - return UriComponentsBuilder.fromUriString(template) - .buildAndExpand(uriVariables) - .toUriString(); - } - - private static String getApplicationUri(HttpServletRequest request) { - UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request)) - .replacePath(request.getContextPath()) - .replaceQuery(null) - .fragment(null) - .build(); - return uriComponents.toUriString(); - } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationRequestContextResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationRequestContextResolver.java index 1c86ec239e..db24c8ff90 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationRequestContextResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationRequestContextResolver.java @@ -16,16 +16,16 @@ package org.springframework.security.saml2.provider.service.web; -import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; - import javax.servlet.http.HttpServletRequest; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; + /** * This {@code Saml2AuthenticationRequestContextResolver} formulates a * SAML 2.0 AuthnRequest (line 1968) * * @author Shazin Sadakath + * @author Josh Cummings * @since 5.4 */ public interface Saml2AuthenticationRequestContextResolver { @@ -35,9 +35,7 @@ public interface Saml2AuthenticationRequestContextResolver { * * * @param request the current request - * @param relyingParty the relying party responsible for saml2 sso authentication - * @return the created {@link Saml2AuthenticationRequestContext} for request/relying party combination + * @return the created {@link Saml2AuthenticationRequestContext} for the request */ - Saml2AuthenticationRequestContext resolve(HttpServletRequest request, - RelyingPartyRegistration relyingParty); + Saml2AuthenticationRequestContext resolve(HttpServletRequest request); } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java index f9613e080d..1ea4d636c9 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java @@ -30,6 +30,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.web.util.HtmlUtils; import org.springframework.web.util.UriUtils; @@ -41,6 +42,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential; +import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext; import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST; public class Saml2WebSsoAuthenticationRequestFilterTests { @@ -49,6 +51,8 @@ public class Saml2WebSsoAuthenticationRequestFilterTests { private Saml2WebSsoAuthenticationRequestFilter filter; private RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class); private Saml2AuthenticationRequestFactory factory = mock(Saml2AuthenticationRequestFactory.class); + private Saml2AuthenticationRequestContextResolver resolver = + mock(Saml2AuthenticationRequestContextResolver.class); private MockHttpServletRequest request; private MockHttpServletResponse response; private MockFilterChain filterChain; @@ -188,12 +192,14 @@ public class Saml2WebSsoAuthenticationRequestFilterTests { when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri"); when(authenticationRequest.getRelayState()).thenReturn("relay"); when(authenticationRequest.getSamlRequest()).thenReturn("saml"); - when(this.repository.findByRegistrationId("registration-id")).thenReturn(relyingParty); + when(this.resolver.resolve(this.request)).thenReturn(authenticationRequestContext() + .relyingPartyRegistration(relyingParty) + .build()); when(this.factory.createPostAuthenticationRequest(any())) .thenReturn(authenticationRequest); Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter - (this.repository, this.factory); + (this.resolver, this.factory); filter.doFilterInternal(this.request, this.response, this.filterChain); assertThat(this.response.getContentAsString()) .contains("
") diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java index 182b700965..80f2cd6afc 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java @@ -44,11 +44,13 @@ public class DefaultSaml2AuthenticationRequestContextResolverTests { private MockHttpServletRequest request; private RelyingPartyRegistration.Builder relyingPartyBuilder; private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver - = new DefaultSaml2AuthenticationRequestContextResolver(); + = new DefaultSaml2AuthenticationRequestContextResolver( + new DefaultRelyingPartyRegistrationResolver(id -> relyingPartyBuilder.build())); @Before public void setup() { this.request = new MockHttpServletRequest(); + this.request.setPathInfo("/saml2/authenticate/registration-id"); this.relyingPartyBuilder = RelyingPartyRegistration .withRegistrationId(REGISTRATION_ID) .localEntityIdTemplate(RELYING_PARTY_ENTITY_ID) @@ -61,52 +63,43 @@ public class DefaultSaml2AuthenticationRequestContextResolverTests { @Test public void resolveWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() { this.request.addParameter("RelayState", "relay-state"); - RelyingPartyRegistration relyingParty = this.relyingPartyBuilder.build(); Saml2AuthenticationRequestContext context = - this.authenticationRequestContextResolver.resolve(this.request, relyingParty); + this.authenticationRequestContextResolver.resolve(this.request); assertThat(context).isNotNull(); assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo(RELYING_PARTY_SSO_URL); assertThat(context.getRelayState()).isEqualTo("relay-state"); assertThat(context.getDestination()).isEqualTo(ASSERTING_PARTY_SSO_URL); assertThat(context.getIssuer()).isEqualTo(RELYING_PARTY_ENTITY_ID); - assertThat(context.getRelyingPartyRegistration()).isSameAs(relyingParty); + assertThat(context.getRelyingPartyRegistration().getRegistrationId()) + .isSameAs(this.relyingPartyBuilder.build().getRegistrationId()); } @Test public void resolveWhenAssertionConsumerServiceUrlTemplateContainsRegistrationIdThenResolves() { - RelyingPartyRegistration relyingParty = this.relyingPartyBuilder - .assertionConsumerServiceUrlTemplate("/saml2/authenticate/{registrationId}") - .build(); + this.relyingPartyBuilder + .assertionConsumerServiceLocation("/saml2/authenticate/{registrationId}"); Saml2AuthenticationRequestContext context = - this.authenticationRequestContextResolver.resolve(this.request, relyingParty); + this.authenticationRequestContextResolver.resolve(this.request); assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo("/saml2/authenticate/registration-id"); } @Test public void resolveWhenAssertionConsumerServiceUrlTemplateContainsBaseUrlThenResolves() { - RelyingPartyRegistration relyingParty = this.relyingPartyBuilder - .assertionConsumerServiceUrlTemplate("{baseUrl}/saml2/authenticate/{registrationId}") - .build(); + this.relyingPartyBuilder + .assertionConsumerServiceLocation("{baseUrl}/saml2/authenticate/{registrationId}"); Saml2AuthenticationRequestContext context = - this.authenticationRequestContextResolver.resolve(this.request, relyingParty); + this.authenticationRequestContextResolver.resolve(this.request); assertThat(context.getAssertionConsumerServiceUrl()) .isEqualTo("http://localhost/saml2/authenticate/registration-id"); } - @Test - public void resolveWhenRequestNullThenException() { - assertThatCode(() -> - this.authenticationRequestContextResolver.resolve(this.request, null)) - .isInstanceOf(IllegalArgumentException.class); - } - @Test public void resolveWhenRelyingPartyNullThenException() { assertThatCode(() -> - this.authenticationRequestContextResolver.resolve(null, this.relyingPartyBuilder.build())) + this.authenticationRequestContextResolver.resolve(null)) .isInstanceOf(IllegalArgumentException.class); } }