diff --git a/core/src/main/java/org/springframework/security/authentication/DelegatingReactiveAuthenticationManager.java b/core/src/main/java/org/springframework/security/authentication/DelegatingReactiveAuthenticationManager.java index 2b489dbbdb..1a40d210e0 100644 --- a/core/src/main/java/org/springframework/security/authentication/DelegatingReactiveAuthenticationManager.java +++ b/core/src/main/java/org/springframework/security/authentication/DelegatingReactiveAuthenticationManager.java @@ -27,6 +27,7 @@ import reactor.core.publisher.Mono; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.util.Assert; /** @@ -57,11 +58,24 @@ public class DelegatingReactiveAuthenticationManager implements ReactiveAuthenti @Override public Mono authenticate(Authentication authentication) { + return ReactiveSecurityContextHolder.getContext().flatMap((context) -> { + Mono result = doAuthenticate(authentication); + Authentication current = context.getAuthentication(); + if (current == null) { + return result; + } + if (!current.isAuthenticated()) { + return result; + } + return doAuthenticate(current).map((r) -> r.toBuilder().apply(current).build()); + }).switchIfEmpty(doAuthenticate(authentication)); + } + + private Mono doAuthenticate(Authentication authentication) { Flux result = Flux.fromIterable(this.delegates); Function> logging = (m) -> m.authenticate(authentication) .doOnError(AuthenticationException.class, (ex) -> ex.setAuthenticationRequest(authentication)) .doOnError(this.logger::debug); - return ((this.continueOnError) ? result.concatMapDelayError(logging) : result.concatMap(logging)).next(); } diff --git a/core/src/main/java/org/springframework/security/authentication/ProviderManager.java b/core/src/main/java/org/springframework/security/authentication/ProviderManager.java index d90bfe5bad..417d144849 100644 --- a/core/src/main/java/org/springframework/security/authentication/ProviderManager.java +++ b/core/src/main/java/org/springframework/security/authentication/ProviderManager.java @@ -33,6 +33,8 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.CredentialsContainer; import org.springframework.security.core.SpringSecurityMessageSource; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -92,6 +94,9 @@ public class ProviderManager implements AuthenticationManager, MessageSourceAwar private static final Log logger = LogFactory.getLog(ProviderManager.class); + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + private AuthenticationEventPublisher eventPublisher = new NullEventPublisher(); private List providers = Collections.emptyList(); @@ -209,6 +214,7 @@ public class ProviderManager implements AuthenticationManager, MessageSourceAwar lastException = ex; } } + result = applyPreviousAuthentication(result); if (result == null && this.parent != null) { // Allow the parent to try. try { @@ -265,6 +271,20 @@ public class ProviderManager implements AuthenticationManager, MessageSourceAwar throw lastException; } + private @Nullable Authentication applyPreviousAuthentication(@Nullable Authentication result) { + if (result == null) { + return null; + } + Authentication current = this.securityContextHolderStrategy.getContext().getAuthentication(); + if (current == null) { + return result; + } + if (!current.isAuthenticated()) { + return result; + } + return result.toBuilder().apply(current).build(); + } + @SuppressWarnings("deprecation") private void prepareException(AuthenticationException ex, Authentication auth) { ex.setAuthenticationRequest(auth); @@ -287,6 +307,11 @@ public class ProviderManager implements AuthenticationManager, MessageSourceAwar return this.providers; } + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + @Override public void setMessageSource(MessageSource messageSource) { this.messages = new MessageSourceAccessor(messageSource); diff --git a/core/src/test/java/org/springframework/security/authentication/DelegatingReactiveAuthenticationManagerTests.java b/core/src/test/java/org/springframework/security/authentication/DelegatingReactiveAuthenticationManagerTests.java index d44278a60f..6c430bdf03 100644 --- a/core/src/test/java/org/springframework/security/authentication/DelegatingReactiveAuthenticationManagerTests.java +++ b/core/src/test/java/org/springframework/security/authentication/DelegatingReactiveAuthenticationManagerTests.java @@ -27,10 +27,13 @@ import reactor.test.StepVerifier; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; /** * @author Rob Winch @@ -118,6 +121,24 @@ public class DelegatingReactiveAuthenticationManagerTests { assertThat(expected.getAuthenticationRequest()).isEqualTo(this.authentication); } + @Test + void authenticateWhenPreviousAuthenticationThenApplies() { + Authentication factorOne = new TestingAuthenticationToken("user", "pass", "FACTOR_ONE"); + Authentication factorTwo = new TestingAuthenticationToken("user", "pass", "FACTOR_TWO"); + ReactiveAuthenticationManager provider = mock(ReactiveAuthenticationManager.class); + given(provider.authenticate(any())).willReturn(Mono.just(factorTwo)); + ReactiveAuthenticationManager manager = new DelegatingReactiveAuthenticationManager(provider); + Authentication request = new TestingAuthenticationToken("user", "password"); + StepVerifier + .create(manager.authenticate(request) + .flatMapIterable(Authentication::getAuthorities) + .map(GrantedAuthority::getAuthority) + .contextWrite(ReactiveSecurityContextHolder.withAuthentication(factorOne))) + .expectNext("FACTOR_TWO") + .expectNext("FACTOR_ONE") + .verifyComplete(); + } + private DelegatingReactiveAuthenticationManager managerWithContinueOnError() { DelegatingReactiveAuthenticationManager manager = new DelegatingReactiveAuthenticationManager(this.delegate1, this.delegate2); diff --git a/core/src/test/java/org/springframework/security/authentication/ProviderManagerTests.java b/core/src/test/java/org/springframework/security/authentication/ProviderManagerTests.java index 7bb0c136bc..e1c69c5fe7 100644 --- a/core/src/test/java/org/springframework/security/authentication/ProviderManagerTests.java +++ b/core/src/test/java/org/springframework/security/authentication/ProviderManagerTests.java @@ -19,12 +19,16 @@ package org.springframework.security.authentication; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Set; import org.junit.jupiter.api.Test; import org.springframework.context.MessageSource; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.core.context.SecurityContextImpl; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -310,6 +314,22 @@ public class ProviderManagerTests { verifyNoMoreInteractions(publisher); // Child should not publish (duplicate event) } + @Test + void authenticateWhenPreviousAuthenticationThenApplies() { + Authentication factorOne = new TestingAuthenticationToken("user", "pass", "FACTOR_ONE"); + Authentication factorTwo = new TestingAuthenticationToken("user", "pass", "FACTOR_TWO"); + SecurityContextHolderStrategy securityContextHolderStrategy = mock(SecurityContextHolderStrategy.class); + given(securityContextHolderStrategy.getContext()).willReturn(new SecurityContextImpl(factorOne)); + AuthenticationProvider provider = mock(AuthenticationProvider.class); + given(provider.authenticate(any())).willReturn(factorTwo); + given(provider.supports(any())).willReturn(true); + ProviderManager manager = new ProviderManager(provider); + manager.setSecurityContextHolderStrategy(securityContextHolderStrategy); + Authentication request = new TestingAuthenticationToken("user", "password"); + Set authorities = AuthorityUtils.authorityListToSet(manager.authenticate(request).getAuthorities()); + assertThat(authorities).containsExactlyInAnyOrder("FACTOR_ONE", "FACTOR_TWO"); + } + private AuthenticationProvider createProviderWhichThrows(final AuthenticationException ex) { AuthenticationProvider provider = mock(AuthenticationProvider.class); given(provider.supports(any(Class.class))).willReturn(true);