diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java index 19719dc7c4..5e804c4b0e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java @@ -15,6 +15,13 @@ */ package org.springframework.security.oauth2.client.web; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.springframework.lang.Nullable; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; @@ -31,13 +38,6 @@ import org.springframework.util.StringUtils; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; - /** * The default implementation of an {@link OAuth2AuthorizedClientManager}. * @@ -84,13 +84,13 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori if (authorizedClient != null) { contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient); } else { - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); - Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( clientRegistrationId, principal, servletRequest); if (authorizedClient != null) { contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient); } else { + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); + Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); contextBuilder = OAuth2AuthorizationContext.withClientRegistration(clientRegistration); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index cb1d831d9e..d60031b043 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -15,6 +15,18 @@ */ package org.springframework.security.oauth2.client.web.reactive.function.client; +import java.net.URI; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -23,6 +35,8 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import reactor.util.context.Context; + import org.springframework.core.codec.ByteBufferEncoder; import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.http.HttpHeaders; @@ -76,26 +90,26 @@ import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.reactive.function.BodyInserter; import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.WebClient; -import reactor.util.context.Context; - -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.net.URI; -import java.time.Duration; -import java.time.Instant; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.function.Consumer; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; import static org.springframework.http.HttpMethod.GET; import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId; -import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.*; +import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY; +import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.authentication; +import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getAuthentication; +import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getRequest; +import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getResponse; +import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletRequest; +import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletResponse; +import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient; /** * @author Rob Winch @@ -603,8 +617,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()), eq(authentication), eq(servletRequest))).thenReturn(authorizedClient); - when(this.clientRegistrationRepository.findByRegistrationId(eq(authentication.getAuthorizedClientRegistrationId()))).thenReturn(this.registration); - // Default request attributes set final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com")) .attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build();