Polish DefaultSaml2AuthenticationRequestContextResolver
Issue gh-8360 Issue gh-8887
This commit is contained in:
+17
-3
@@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
* Copyright 2002-2019 the original author or authors.
|
* Copyright 2002-2020 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
@@ -35,6 +35,9 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
|
|||||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
||||||
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
|
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
|
||||||
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
|
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
|
||||||
|
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
|
||||||
|
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
|
||||||
|
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
|
||||||
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
|
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
|
||||||
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
|
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
|
||||||
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
||||||
@@ -317,15 +320,16 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
|
|||||||
|
|
||||||
private final class AuthenticationRequestEndpointConfig {
|
private final class AuthenticationRequestEndpointConfig {
|
||||||
private String filterProcessingUrl = "/saml2/authenticate/{registrationId}";
|
private String filterProcessingUrl = "/saml2/authenticate/{registrationId}";
|
||||||
|
|
||||||
private AuthenticationRequestEndpointConfig() {
|
private AuthenticationRequestEndpointConfig() {
|
||||||
}
|
}
|
||||||
|
|
||||||
private Filter build(B http) {
|
private Filter build(B http) {
|
||||||
Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
|
Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
|
||||||
|
Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http);
|
||||||
|
|
||||||
return postProcess(new Saml2WebSsoAuthenticationRequestFilter(
|
return postProcess(new Saml2WebSsoAuthenticationRequestFilter(
|
||||||
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository,
|
contextResolver, authenticationRequestResolver));
|
||||||
authenticationRequestResolver));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Saml2AuthenticationRequestFactory getResolver(B http) {
|
private Saml2AuthenticationRequestFactory getResolver(B http) {
|
||||||
@@ -335,6 +339,16 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
|
|||||||
}
|
}
|
||||||
return resolver;
|
return resolver;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Saml2AuthenticationRequestContextResolver getContextResolver(B http) {
|
||||||
|
Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http, Saml2AuthenticationRequestContextResolver.class);
|
||||||
|
if (resolver == null) {
|
||||||
|
return new DefaultSaml2AuthenticationRequestContextResolver(
|
||||||
|
new DefaultRelyingPartyRegistrationResolver(
|
||||||
|
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository));
|
||||||
|
}
|
||||||
|
return resolver;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
+4
-16
@@ -65,10 +65,8 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
|
|||||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
|
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
|
||||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
|
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
|
||||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
|
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
|
||||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
|
||||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
||||||
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
|
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
|
||||||
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
|
|
||||||
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
|
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
|
||||||
import org.springframework.security.web.FilterChainProxy;
|
import org.springframework.security.web.FilterChainProxy;
|
||||||
import org.springframework.security.web.context.HttpRequestResponseHolder;
|
import org.springframework.security.web.context.HttpRequestResponseHolder;
|
||||||
@@ -87,6 +85,7 @@ import static org.mockito.ArgumentMatchers.anyString;
|
|||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
|
import static org.springframework.security.config.Customizer.withDefaults;
|
||||||
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
|
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
|
||||||
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
|
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
|
||||||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
|
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
|
||||||
@@ -161,11 +160,11 @@ public class Saml2LoginConfigurerTests {
|
|||||||
Saml2AuthenticationRequestContext context = authenticationRequestContext().build();
|
Saml2AuthenticationRequestContext context = authenticationRequestContext().build();
|
||||||
Saml2AuthenticationRequestContextResolver resolver =
|
Saml2AuthenticationRequestContextResolver resolver =
|
||||||
CustomAuthenticationRequestContextResolver.resolver;
|
CustomAuthenticationRequestContextResolver.resolver;
|
||||||
when(resolver.resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class)))
|
when(resolver.resolve(any(HttpServletRequest.class)))
|
||||||
.thenReturn(context);
|
.thenReturn(context);
|
||||||
this.mvc.perform(get("/saml2/authenticate/registration-id"))
|
this.mvc.perform(get("/saml2/authenticate/registration-id"))
|
||||||
.andExpect(status().isFound());
|
.andExpect(status().isFound());
|
||||||
verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class));
|
verify(resolver).resolve(any(HttpServletRequest.class));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -276,22 +275,11 @@ public class Saml2LoginConfigurerTests {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void configure(HttpSecurity http) throws Exception {
|
protected void configure(HttpSecurity http) throws Exception {
|
||||||
ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter> processor
|
|
||||||
= new ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter>() {
|
|
||||||
@Override
|
|
||||||
public <O extends Saml2WebSsoAuthenticationRequestFilter> O postProcess(O filter) {
|
|
||||||
filter.setAuthenticationRequestContextResolver(resolver);
|
|
||||||
return filter;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
http
|
http
|
||||||
.authorizeRequests(authz -> authz
|
.authorizeRequests(authz -> authz
|
||||||
.anyRequest().authenticated()
|
.anyRequest().authenticated()
|
||||||
)
|
)
|
||||||
.saml2Login(saml2 -> saml2
|
.saml2Login(withDefaults());
|
||||||
.addObjectPostProcessor(processor)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Bean
|
@Bean
|
||||||
|
|||||||
+13
-23
@@ -30,6 +30,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R
|
|||||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
||||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
||||||
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
|
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
|
||||||
|
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
|
||||||
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
|
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
|
||||||
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
|
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
|
||||||
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
||||||
@@ -69,9 +70,8 @@ import static java.nio.charset.StandardCharsets.ISO_8859_1;
|
|||||||
*/
|
*/
|
||||||
public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter {
|
public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter {
|
||||||
|
|
||||||
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
|
private final Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver;
|
||||||
private Saml2AuthenticationRequestFactory authenticationRequestFactory;
|
private Saml2AuthenticationRequestFactory authenticationRequestFactory;
|
||||||
private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver();
|
|
||||||
|
|
||||||
private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}");
|
private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}");
|
||||||
|
|
||||||
@@ -83,21 +83,24 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
|
|||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
|
public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
|
||||||
this(relyingPartyRegistrationRepository,
|
this(new DefaultSaml2AuthenticationRequestContextResolver(
|
||||||
|
new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)),
|
||||||
new org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory());
|
new org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided parameters
|
* Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided parameters
|
||||||
*
|
*
|
||||||
* @param relyingPartyRegistrationRepository a repository for relying party configurations
|
* @param authenticationRequestContextResolver a strategy for formulating a {@link Saml2AuthenticationRequestContext}
|
||||||
* @since 5.4
|
* @since 5.4
|
||||||
*/
|
*/
|
||||||
public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
|
public Saml2WebSsoAuthenticationRequestFilter(
|
||||||
|
Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver,
|
||||||
Saml2AuthenticationRequestFactory authenticationRequestFactory) {
|
Saml2AuthenticationRequestFactory authenticationRequestFactory) {
|
||||||
Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");
|
|
||||||
|
Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null");
|
||||||
Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null");
|
Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null");
|
||||||
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
|
this.authenticationRequestContextResolver = authenticationRequestContextResolver;
|
||||||
this.authenticationRequestFactory = authenticationRequestFactory;
|
this.authenticationRequestFactory = authenticationRequestFactory;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,17 +126,6 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
|
|||||||
this.redirectMatcher = redirectMatcher;
|
this.redirectMatcher = redirectMatcher;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Use the given {@link Saml2AuthenticationRequestContextResolver} that creates a {@link Saml2AuthenticationRequestContext}
|
|
||||||
*
|
|
||||||
* @param authenticationRequestContextResolver the {@link Saml2AuthenticationRequestContextResolver} to use
|
|
||||||
* @since 5.4
|
|
||||||
*/
|
|
||||||
public void setAuthenticationRequestContextResolver(Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver) {
|
|
||||||
Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null");
|
|
||||||
this.authenticationRequestContextResolver = authenticationRequestContextResolver;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
*/
|
*/
|
||||||
@@ -147,14 +139,12 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
String registrationId = matcher.getVariables().get("registrationId");
|
Saml2AuthenticationRequestContext context = this.authenticationRequestContextResolver.resolve(request);
|
||||||
RelyingPartyRegistration relyingParty =
|
if (context == null) {
|
||||||
this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
|
|
||||||
if (relyingParty == null) {
|
|
||||||
response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
|
response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Saml2AuthenticationRequestContext context = authenticationRequestContextResolver.resolve(request, relyingParty);
|
RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration();
|
||||||
if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) {
|
if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) {
|
||||||
sendRedirect(response, context);
|
sendRedirect(response, context);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
+15
-63
@@ -16,45 +16,45 @@
|
|||||||
|
|
||||||
package org.springframework.security.saml2.provider.service.web;
|
package org.springframework.security.saml2.provider.service.web;
|
||||||
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.function.Function;
|
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
|
||||||
|
import org.springframework.core.convert.converter.Converter;
|
||||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
|
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
|
||||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
||||||
import org.springframework.util.Assert;
|
import org.springframework.util.Assert;
|
||||||
import org.springframework.util.StringUtils;
|
|
||||||
import org.springframework.web.util.UriComponents;
|
|
||||||
import org.springframework.web.util.UriComponentsBuilder;
|
|
||||||
|
|
||||||
import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl;
|
|
||||||
import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The default implementation for {@link Saml2AuthenticationRequestContextResolver}
|
* The default implementation for {@link Saml2AuthenticationRequestContextResolver}
|
||||||
* which uses the current request and given relying party to formulate a {@link Saml2AuthenticationRequestContext}
|
* which uses the current request and given relying party to formulate a {@link Saml2AuthenticationRequestContext}
|
||||||
*
|
*
|
||||||
* @author Shazin Sadakath
|
* @author Shazin Sadakath
|
||||||
|
* @author Josh Cummings
|
||||||
* @since 5.4
|
* @since 5.4
|
||||||
*/
|
*/
|
||||||
public final class DefaultSaml2AuthenticationRequestContextResolver implements Saml2AuthenticationRequestContextResolver {
|
public final class DefaultSaml2AuthenticationRequestContextResolver implements Saml2AuthenticationRequestContextResolver {
|
||||||
|
|
||||||
private final Log logger = LogFactory.getLog(getClass());
|
private final Log logger = LogFactory.getLog(getClass());
|
||||||
|
|
||||||
private static final char PATH_DELIMITER = '/';
|
private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
|
||||||
|
|
||||||
|
public DefaultSaml2AuthenticationRequestContextResolver
|
||||||
|
(Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
|
||||||
|
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public Saml2AuthenticationRequestContext resolve(HttpServletRequest request,
|
public Saml2AuthenticationRequestContext resolve(HttpServletRequest request) {
|
||||||
RelyingPartyRegistration relyingParty) {
|
|
||||||
Assert.notNull(request, "request cannot be null");
|
Assert.notNull(request, "request cannot be null");
|
||||||
Assert.notNull(relyingParty, "relyingParty cannot be null");
|
RelyingPartyRegistration relyingParty = this.relyingPartyRegistrationResolver.convert(request);
|
||||||
|
if (relyingParty == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
if (this.logger.isDebugEnabled()) {
|
if (this.logger.isDebugEnabled()) {
|
||||||
this.logger.debug("Creating SAML 2.0 Authentication Request for Asserting Party [" +
|
this.logger.debug("Creating SAML 2.0 Authentication Request for Asserting Party [" +
|
||||||
relyingParty.getRegistrationId() + "]");
|
relyingParty.getRegistrationId() + "]");
|
||||||
@@ -65,59 +65,11 @@ public final class DefaultSaml2AuthenticationRequestContextResolver implements S
|
|||||||
private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext(
|
private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext(
|
||||||
HttpServletRequest request, RelyingPartyRegistration relyingParty) {
|
HttpServletRequest request, RelyingPartyRegistration relyingParty) {
|
||||||
|
|
||||||
String applicationUri = getApplicationUri(request);
|
|
||||||
Function<String, String> resolver = templateResolver(applicationUri, relyingParty);
|
|
||||||
String localSpEntityId = resolver.apply(relyingParty.getEntityId());
|
|
||||||
String assertionConsumerServiceUrl = resolver.apply(relyingParty.getAssertionConsumerServiceLocation());
|
|
||||||
return Saml2AuthenticationRequestContext.builder()
|
return Saml2AuthenticationRequestContext.builder()
|
||||||
.issuer(localSpEntityId)
|
.issuer(relyingParty.getEntityId())
|
||||||
.relyingPartyRegistration(relyingParty)
|
.relyingPartyRegistration(relyingParty)
|
||||||
.assertionConsumerServiceUrl(assertionConsumerServiceUrl)
|
.assertionConsumerServiceUrl(relyingParty.getAssertionConsumerServiceLocation())
|
||||||
.relayState(request.getParameter("RelayState"))
|
.relayState(request.getParameter("RelayState"))
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
private Function<String, String> templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) {
|
|
||||||
return template -> resolveUrlTemplate(template, applicationUri, relyingParty);
|
|
||||||
}
|
|
||||||
|
|
||||||
private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) {
|
|
||||||
String entityId = relyingParty.getAssertingPartyDetails().getEntityId();
|
|
||||||
String registrationId = relyingParty.getRegistrationId();
|
|
||||||
Map<String, String> uriVariables = new HashMap<>();
|
|
||||||
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl)
|
|
||||||
.replaceQuery(null)
|
|
||||||
.fragment(null)
|
|
||||||
.build();
|
|
||||||
String scheme = uriComponents.getScheme();
|
|
||||||
uriVariables.put("baseScheme", scheme == null ? "" : scheme);
|
|
||||||
String host = uriComponents.getHost();
|
|
||||||
uriVariables.put("baseHost", host == null ? "" : host);
|
|
||||||
// following logic is based on HierarchicalUriComponents#toUriString()
|
|
||||||
int port = uriComponents.getPort();
|
|
||||||
uriVariables.put("basePort", port == -1 ? "" : ":" + port);
|
|
||||||
String path = uriComponents.getPath();
|
|
||||||
if (StringUtils.hasLength(path)) {
|
|
||||||
if (path.charAt(0) != PATH_DELIMITER) {
|
|
||||||
path = PATH_DELIMITER + path;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
uriVariables.put("basePath", path == null ? "" : path);
|
|
||||||
uriVariables.put("baseUrl", uriComponents.toUriString());
|
|
||||||
uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : "");
|
|
||||||
uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");
|
|
||||||
|
|
||||||
return UriComponentsBuilder.fromUriString(template)
|
|
||||||
.buildAndExpand(uriVariables)
|
|
||||||
.toUriString();
|
|
||||||
}
|
|
||||||
|
|
||||||
private static String getApplicationUri(HttpServletRequest request) {
|
|
||||||
UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request))
|
|
||||||
.replacePath(request.getContextPath())
|
|
||||||
.replaceQuery(null)
|
|
||||||
.fragment(null)
|
|
||||||
.build();
|
|
||||||
return uriComponents.toUriString();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
+5
-7
@@ -16,16 +16,16 @@
|
|||||||
|
|
||||||
package org.springframework.security.saml2.provider.service.web;
|
package org.springframework.security.saml2.provider.service.web;
|
||||||
|
|
||||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
|
|
||||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
|
||||||
|
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
|
||||||
|
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This {@code Saml2AuthenticationRequestContextResolver} formulates a
|
* This {@code Saml2AuthenticationRequestContextResolver} formulates a
|
||||||
* <a href="https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf">SAML 2.0 AuthnRequest</a> (line 1968)
|
* <a href="https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf">SAML 2.0 AuthnRequest</a> (line 1968)
|
||||||
*
|
*
|
||||||
* @author Shazin Sadakath
|
* @author Shazin Sadakath
|
||||||
|
* @author Josh Cummings
|
||||||
* @since 5.4
|
* @since 5.4
|
||||||
*/
|
*/
|
||||||
public interface Saml2AuthenticationRequestContextResolver {
|
public interface Saml2AuthenticationRequestContextResolver {
|
||||||
@@ -35,9 +35,7 @@ public interface Saml2AuthenticationRequestContextResolver {
|
|||||||
*
|
*
|
||||||
*
|
*
|
||||||
* @param request the current request
|
* @param request the current request
|
||||||
* @param relyingParty the relying party responsible for saml2 sso authentication
|
* @return the created {@link Saml2AuthenticationRequestContext} for the request
|
||||||
* @return the created {@link Saml2AuthenticationRequestContext} for request/relying party combination
|
|
||||||
*/
|
*/
|
||||||
Saml2AuthenticationRequestContext resolve(HttpServletRequest request,
|
Saml2AuthenticationRequestContext resolve(HttpServletRequest request);
|
||||||
RelyingPartyRegistration relyingParty);
|
|
||||||
}
|
}
|
||||||
|
|||||||
+8
-2
@@ -30,6 +30,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
|
|||||||
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
|
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
|
||||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
||||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
||||||
|
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
|
||||||
import org.springframework.web.util.HtmlUtils;
|
import org.springframework.web.util.HtmlUtils;
|
||||||
import org.springframework.web.util.UriUtils;
|
import org.springframework.web.util.UriUtils;
|
||||||
|
|
||||||
@@ -41,6 +42,7 @@ import static org.mockito.Mockito.verify;
|
|||||||
import static org.mockito.Mockito.verifyNoInteractions;
|
import static org.mockito.Mockito.verifyNoInteractions;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
|
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
|
||||||
|
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
|
||||||
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
|
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
|
||||||
|
|
||||||
public class Saml2WebSsoAuthenticationRequestFilterTests {
|
public class Saml2WebSsoAuthenticationRequestFilterTests {
|
||||||
@@ -49,6 +51,8 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
|
|||||||
private Saml2WebSsoAuthenticationRequestFilter filter;
|
private Saml2WebSsoAuthenticationRequestFilter filter;
|
||||||
private RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class);
|
private RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class);
|
||||||
private Saml2AuthenticationRequestFactory factory = mock(Saml2AuthenticationRequestFactory.class);
|
private Saml2AuthenticationRequestFactory factory = mock(Saml2AuthenticationRequestFactory.class);
|
||||||
|
private Saml2AuthenticationRequestContextResolver resolver =
|
||||||
|
mock(Saml2AuthenticationRequestContextResolver.class);
|
||||||
private MockHttpServletRequest request;
|
private MockHttpServletRequest request;
|
||||||
private MockHttpServletResponse response;
|
private MockHttpServletResponse response;
|
||||||
private MockFilterChain filterChain;
|
private MockFilterChain filterChain;
|
||||||
@@ -188,12 +192,14 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
|
|||||||
when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri");
|
when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri");
|
||||||
when(authenticationRequest.getRelayState()).thenReturn("relay");
|
when(authenticationRequest.getRelayState()).thenReturn("relay");
|
||||||
when(authenticationRequest.getSamlRequest()).thenReturn("saml");
|
when(authenticationRequest.getSamlRequest()).thenReturn("saml");
|
||||||
when(this.repository.findByRegistrationId("registration-id")).thenReturn(relyingParty);
|
when(this.resolver.resolve(this.request)).thenReturn(authenticationRequestContext()
|
||||||
|
.relyingPartyRegistration(relyingParty)
|
||||||
|
.build());
|
||||||
when(this.factory.createPostAuthenticationRequest(any()))
|
when(this.factory.createPostAuthenticationRequest(any()))
|
||||||
.thenReturn(authenticationRequest);
|
.thenReturn(authenticationRequest);
|
||||||
|
|
||||||
Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter
|
Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter
|
||||||
(this.repository, this.factory);
|
(this.resolver, this.factory);
|
||||||
filter.doFilterInternal(this.request, this.response, this.filterChain);
|
filter.doFilterInternal(this.request, this.response, this.filterChain);
|
||||||
assertThat(this.response.getContentAsString())
|
assertThat(this.response.getContentAsString())
|
||||||
.contains("<form action=\"uri\" method=\"post\">")
|
.contains("<form action=\"uri\" method=\"post\">")
|
||||||
|
|||||||
+13
-20
@@ -44,11 +44,13 @@ public class DefaultSaml2AuthenticationRequestContextResolverTests {
|
|||||||
private MockHttpServletRequest request;
|
private MockHttpServletRequest request;
|
||||||
private RelyingPartyRegistration.Builder relyingPartyBuilder;
|
private RelyingPartyRegistration.Builder relyingPartyBuilder;
|
||||||
private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver
|
private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver
|
||||||
= new DefaultSaml2AuthenticationRequestContextResolver();
|
= new DefaultSaml2AuthenticationRequestContextResolver(
|
||||||
|
new DefaultRelyingPartyRegistrationResolver(id -> relyingPartyBuilder.build()));
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setup() {
|
public void setup() {
|
||||||
this.request = new MockHttpServletRequest();
|
this.request = new MockHttpServletRequest();
|
||||||
|
this.request.setPathInfo("/saml2/authenticate/registration-id");
|
||||||
this.relyingPartyBuilder = RelyingPartyRegistration
|
this.relyingPartyBuilder = RelyingPartyRegistration
|
||||||
.withRegistrationId(REGISTRATION_ID)
|
.withRegistrationId(REGISTRATION_ID)
|
||||||
.localEntityIdTemplate(RELYING_PARTY_ENTITY_ID)
|
.localEntityIdTemplate(RELYING_PARTY_ENTITY_ID)
|
||||||
@@ -61,52 +63,43 @@ public class DefaultSaml2AuthenticationRequestContextResolverTests {
|
|||||||
@Test
|
@Test
|
||||||
public void resolveWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() {
|
public void resolveWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() {
|
||||||
this.request.addParameter("RelayState", "relay-state");
|
this.request.addParameter("RelayState", "relay-state");
|
||||||
RelyingPartyRegistration relyingParty = this.relyingPartyBuilder.build();
|
|
||||||
Saml2AuthenticationRequestContext context =
|
Saml2AuthenticationRequestContext context =
|
||||||
this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
|
this.authenticationRequestContextResolver.resolve(this.request);
|
||||||
|
|
||||||
assertThat(context).isNotNull();
|
assertThat(context).isNotNull();
|
||||||
assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo(RELYING_PARTY_SSO_URL);
|
assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo(RELYING_PARTY_SSO_URL);
|
||||||
assertThat(context.getRelayState()).isEqualTo("relay-state");
|
assertThat(context.getRelayState()).isEqualTo("relay-state");
|
||||||
assertThat(context.getDestination()).isEqualTo(ASSERTING_PARTY_SSO_URL);
|
assertThat(context.getDestination()).isEqualTo(ASSERTING_PARTY_SSO_URL);
|
||||||
assertThat(context.getIssuer()).isEqualTo(RELYING_PARTY_ENTITY_ID);
|
assertThat(context.getIssuer()).isEqualTo(RELYING_PARTY_ENTITY_ID);
|
||||||
assertThat(context.getRelyingPartyRegistration()).isSameAs(relyingParty);
|
assertThat(context.getRelyingPartyRegistration().getRegistrationId())
|
||||||
|
.isSameAs(this.relyingPartyBuilder.build().getRegistrationId());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void resolveWhenAssertionConsumerServiceUrlTemplateContainsRegistrationIdThenResolves() {
|
public void resolveWhenAssertionConsumerServiceUrlTemplateContainsRegistrationIdThenResolves() {
|
||||||
RelyingPartyRegistration relyingParty = this.relyingPartyBuilder
|
this.relyingPartyBuilder
|
||||||
.assertionConsumerServiceUrlTemplate("/saml2/authenticate/{registrationId}")
|
.assertionConsumerServiceLocation("/saml2/authenticate/{registrationId}");
|
||||||
.build();
|
|
||||||
Saml2AuthenticationRequestContext context =
|
Saml2AuthenticationRequestContext context =
|
||||||
this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
|
this.authenticationRequestContextResolver.resolve(this.request);
|
||||||
|
|
||||||
assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo("/saml2/authenticate/registration-id");
|
assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo("/saml2/authenticate/registration-id");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void resolveWhenAssertionConsumerServiceUrlTemplateContainsBaseUrlThenResolves() {
|
public void resolveWhenAssertionConsumerServiceUrlTemplateContainsBaseUrlThenResolves() {
|
||||||
RelyingPartyRegistration relyingParty = this.relyingPartyBuilder
|
this.relyingPartyBuilder
|
||||||
.assertionConsumerServiceUrlTemplate("{baseUrl}/saml2/authenticate/{registrationId}")
|
.assertionConsumerServiceLocation("{baseUrl}/saml2/authenticate/{registrationId}");
|
||||||
.build();
|
|
||||||
Saml2AuthenticationRequestContext context =
|
Saml2AuthenticationRequestContext context =
|
||||||
this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
|
this.authenticationRequestContextResolver.resolve(this.request);
|
||||||
|
|
||||||
assertThat(context.getAssertionConsumerServiceUrl())
|
assertThat(context.getAssertionConsumerServiceUrl())
|
||||||
.isEqualTo("http://localhost/saml2/authenticate/registration-id");
|
.isEqualTo("http://localhost/saml2/authenticate/registration-id");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
public void resolveWhenRequestNullThenException() {
|
|
||||||
assertThatCode(() ->
|
|
||||||
this.authenticationRequestContextResolver.resolve(this.request, null))
|
|
||||||
.isInstanceOf(IllegalArgumentException.class);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void resolveWhenRelyingPartyNullThenException() {
|
public void resolveWhenRelyingPartyNullThenException() {
|
||||||
assertThatCode(() ->
|
assertThatCode(() ->
|
||||||
this.authenticationRequestContextResolver.resolve(null, this.relyingPartyBuilder.build()))
|
this.authenticationRequestContextResolver.resolve(null))
|
||||||
.isInstanceOf(IllegalArgumentException.class);
|
.isInstanceOf(IllegalArgumentException.class);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user