1
0
mirror of synced 2026-05-22 21:33:16 +00:00

Add Saml2AuthenticationRequestRepository

Closes gh-9185
This commit is contained in:
Marcus Da Coregio
2021-06-28 13:58:42 -03:00
committed by Josh Cummings
parent 6b68a6d62b
commit 16e17d242e
15 changed files with 657 additions and 19 deletions
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -33,6 +33,7 @@ import org.springframework.security.config.annotation.web.configurers.AbstractAu
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer;
import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
@@ -40,6 +41,8 @@ import org.springframework.security.saml2.provider.service.authentication.OpenSa
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
@@ -206,6 +209,7 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
}
this.saml2WebSsoAuthenticationFilter = new Saml2WebSsoAuthenticationFilter(getAuthenticationConverter(http),
this.loginProcessingUrl);
setAuthenticationRequestRepository(http, this.saml2WebSsoAuthenticationFilter);
setAuthenticationFilter(this.saml2WebSsoAuthenticationFilter);
super.loginProcessingUrl(this.loginProcessingUrl);
if (StringUtils.hasText(this.loginPage)) {
@@ -252,6 +256,11 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
}
}
private void setAuthenticationRequestRepository(B http,
Saml2WebSsoAuthenticationFilter saml2WebSsoAuthenticationFilter) {
saml2WebSsoAuthenticationFilter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http));
}
private AuthenticationConverter getAuthenticationConverter(B http) {
if (this.authenticationConverter == null) {
return new Saml2AuthenticationTokenConverter(
@@ -311,6 +320,16 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
return idps;
}
private Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> getAuthenticationRequestRepository(
B http) {
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = getBeanOrNull(http,
Saml2AuthenticationRequestRepository.class);
if (repository == null) {
return new HttpSessionSaml2AuthenticationRequestRepository();
}
return repository;
}
private <C> C getSharedOrBean(B http, Class<C> clazz) {
C shared = http.getSharedObject(clazz);
if (shared != null) {
@@ -348,8 +367,12 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
private Filter build(B http) {
Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http);
return postProcess(
new Saml2WebSsoAuthenticationRequestFilter(contextResolver, authenticationRequestResolver));
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = getAuthenticationRequestRepository(
http);
Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(contextResolver,
authenticationRequestResolver);
filter.setAuthenticationRequestRepository(repository);
return postProcess(filter);
}
private Saml2AuthenticationRequestFactory getResolver(B http) {
@@ -63,6 +63,7 @@ import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMap
import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.core.Saml2Utils;
import org.springframework.security.saml2.core.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
@@ -76,9 +77,11 @@ import org.springframework.security.saml2.provider.service.authentication.TestSa
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.context.HttpRequestResponseHolder;
@@ -237,6 +240,29 @@ public class Saml2LoginConfigurerTests {
assertThat(exception.getCause()).isInstanceOf(IOException.class);
}
@Test
public void authenticationRequestWhenCustomAuthnRequestRepositoryThenUses() throws Exception {
this.spring.register(CustomAuthenticationRequestRepository.class).autowire();
MockHttpServletRequestBuilder request = get("/saml2/authenticate/registration-id");
this.mvc.perform(request).andExpect(status().isFound());
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = this.spring.getContext()
.getBean(Saml2AuthenticationRequestRepository.class);
verify(repository).saveAuthenticationRequest(any(AbstractSaml2AuthenticationRequest.class),
any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void authenticateWhenCustomAuthnRequestRepositoryThenUses() throws Exception {
this.spring.register(CustomAuthenticationRequestRepository.class).autowire();
MockHttpServletRequestBuilder request = post("/login/saml2/sso/registration-id").param("SAMLResponse",
SIGNED_RESPONSE);
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = this.spring.getContext()
.getBean(Saml2AuthenticationRequestRepository.class);
this.mvc.perform(request);
verify(repository).loadAuthenticationRequest(any(HttpServletRequest.class));
verify(repository).removeAuthenticationRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
private void validateSaml2WebSsoAuthenticationFilterConfiguration() {
// get the OpenSamlAuthenticationProvider
Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain);
@@ -371,7 +397,7 @@ public class Saml2LoginConfigurerTests {
@Bean
Saml2AuthenticationRequestContextResolver resolver() {
return resolver;
return this.resolver;
}
}
@@ -420,6 +446,27 @@ public class Saml2LoginConfigurerTests {
}
@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthenticationRequestRepository {
private static final Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = mock(
Saml2AuthenticationRequestRepository.class);
@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
http.authorizeRequests((authz) -> authz.anyRequest().authenticated());
http.saml2Login(withDefaults());
return http.build();
}
@Bean
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository() {
return this.repository;
}
}
static class Saml2LoginConfigBeans {
@Bean