diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java similarity index 60% rename from oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java rename to oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java index 38bf53487c..16400dc922 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java @@ -15,20 +15,13 @@ */ package org.springframework.security.oauth2.client; -import org.springframework.lang.Nullable; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; -import org.springframework.web.server.ServerWebExchange; import reactor.core.publisher.Mono; import java.util.Collections; -import java.util.HashMap; import java.util.Map; -import java.util.Optional; import java.util.function.Function; /** @@ -36,25 +29,31 @@ import java.util.function.Function; * that is capable of operating outside of a {@code ServerHttpRequest} context, * e.g. in a scheduled/background thread and/or in the service-tier. * + *

This is a reactive equivalent of {@link org.springframework.security.oauth2.client.AuthorizedClientServiceOAuth2AuthorizedClientManager}

+ * * @author Ankur Pathak + * @author Phil Clay * @see ReactiveOAuth2AuthorizedClientManager * @see ReactiveOAuth2AuthorizedClientProvider * @see ReactiveOAuth2AuthorizedClientService - * @since 5.3 + * @since 5.2.2 */ -public final class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager { +public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager + implements ReactiveOAuth2AuthorizedClientManager { + private final ReactiveClientRegistrationRepository clientRegistrationRepository; private final ReactiveOAuth2AuthorizedClientService authorizedClientService; private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = context -> Mono.empty(); private Function>> contextAttributesMapper = new DefaultContextAttributesMapper(); /** - * Constructs an {@code OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} using the provided parameters. + * Constructs an {@code AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} using the provided parameters. * * @param clientRegistrationRepository the repository of client registrations * @param authorizedClientService the authorized client service */ - public OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(ReactiveClientRegistrationRepository clientRegistrationRepository, + public AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager( + ReactiveClientRegistrationRepository clientRegistrationRepository, ReactiveOAuth2AuthorizedClientService authorizedClientService) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); @@ -62,35 +61,42 @@ public final class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientMa this.authorizedClientService = authorizedClientService; } - @Nullable @Override public Mono authorize(OAuth2AuthorizeRequest authorizeRequest) { Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); + + return createAuthorizationContext(authorizeRequest) + .flatMap(this::authorizeAndSave); + } + + private Mono createAuthorizationContext(OAuth2AuthorizeRequest authorizeRequest) { String clientRegistrationId = authorizeRequest.getClientRegistrationId(); - OAuth2AuthorizedClient authorizedClient = authorizeRequest.getAuthorizedClient(); Authentication principal = authorizeRequest.getPrincipal(); - // @formatter:off - return Mono.justOrEmpty(authorizedClient) + return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient()) .map(OAuth2AuthorizationContext::withAuthorizedClient) .switchIfEmpty(Mono.defer(() -> this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) - .flatMap(clientRegistration -> this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName()) - .map(OAuth2AuthorizationContext::withAuthorizedClient) - .switchIfEmpty(Mono.fromSupplier(() -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration))) - ) - .switchIfEmpty(Mono.error(new IllegalArgumentException("Could not find ClientRegistration with id '" + clientRegistrationId + "'"))) - ) - ) + .flatMap(clientRegistration -> this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName()) + .map(OAuth2AuthorizationContext::withAuthorizedClient) + .switchIfEmpty(Mono.fromSupplier(() -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration)))) + .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Could not find ClientRegistration with id '" + clientRegistrationId + "'"))))) .flatMap(contextBuilder -> this.contextAttributesMapper.apply(authorizeRequest) - .filter(contextAttributes-> !CollectionUtils.isEmpty(contextAttributes)) - .map(contextAttributes -> contextBuilder.principal(principal) - .attributes(attributes -> { - attributes.putAll(contextAttributes); - }).build()) - ).flatMap(authorizationContext -> this.authorizedClientProvider.authorize(authorizationContext) - .doOnNext(_authorizedClient -> authorizedClientService.saveAuthorizedClient(_authorizedClient, principal)) - .switchIfEmpty(Mono.defer(()-> Mono.justOrEmpty(Optional.ofNullable(authorizationContext.getAuthorizedClient())))) - ); - // @formatter:on + .defaultIfEmpty(Collections.emptyMap()) + .map(contextAttributes -> { + OAuth2AuthorizationContext.Builder builder = contextBuilder.principal(principal); + if (!contextAttributes.isEmpty()) { + builder = builder.attributes(attributes -> attributes.putAll(contextAttributes)); + } + return builder.build(); + })); + } + + private Mono authorizeAndSave(OAuth2AuthorizationContext authorizationContext) { + return this.authorizedClientProvider.authorize(authorizationContext) + .flatMap(authorizedClient -> this.authorizedClientService.saveAuthorizedClient( + authorizedClient, + authorizationContext.getPrincipal()) + .thenReturn(authorizedClient)) + .switchIfEmpty(Mono.defer(()-> Mono.justOrEmpty(authorizationContext.getAuthorizedClient()))); } /** @@ -115,33 +121,17 @@ public final class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientMa this.contextAttributesMapper = contextAttributesMapper; } - private static Mono currentServerWebExchange() { - return Mono.subscriberContext() - .filter(c -> c.hasKey(ServerWebExchange.class)) - .map(c -> c.get(ServerWebExchange.class)); - } - /** * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}. */ public static class DefaultContextAttributesMapper implements Function>> { + private final AuthorizedClientServiceOAuth2AuthorizedClientManager.DefaultContextAttributesMapper mapper = + new AuthorizedClientServiceOAuth2AuthorizedClientManager.DefaultContextAttributesMapper(); + @Override public Mono> apply(OAuth2AuthorizeRequest authorizeRequest) { - ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName()); - return Mono.justOrEmpty(serverWebExchange) - .switchIfEmpty(Mono.defer(() -> currentServerWebExchange())) - .flatMap(exchange -> { - Map contextAttributes = Collections.emptyMap(); - String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE); - if (StringUtils.hasText(scope)) { - contextAttributes = new HashMap<>(); - contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, - StringUtils.delimitedListToStringArray(scope, " ")); - } - return Mono.just(contextAttributes); - }) - .defaultIfEmpty(Collections.emptyMap()); + return Mono.fromCallable(() -> mapper.apply(authorizeRequest)); } } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java similarity index 85% rename from oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java rename to oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java index 2be9e4139d..ab9c7ff382 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java @@ -28,39 +28,49 @@ import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import reactor.test.publisher.PublisherProbe; +import java.util.Map; import java.util.function.Function; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; /** - * Tests for {@link OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager}. + * Tests for {@link AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager}. * * @author Ankur Pathak + * @author Phil Clay */ -public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests { +public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests { private ReactiveClientRegistrationRepository clientRegistrationRepository; private ReactiveOAuth2AuthorizedClientService authorizedClientService; private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; - private Function contextAttributesMapper; - private OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager authorizedClientManager; + private Function>> contextAttributesMapper; + private AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager authorizedClientManager; private ClientRegistration clientRegistration; private Authentication principal; private OAuth2AuthorizedClient authorizedClient; private ArgumentCaptor authorizationContextCaptor; + private PublisherProbe saveAuthorizedClientProbe; @SuppressWarnings("unchecked") @Before public void setup() { this.clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class); this.authorizedClientService = mock(ReactiveOAuth2AuthorizedClientService.class); + this.saveAuthorizedClientProbe = PublisherProbe.empty(); + when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(this.saveAuthorizedClientProbe.mono()); this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class); this.contextAttributesMapper = mock(Function.class); - this.authorizedClientManager = new OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager( + when(this.contextAttributesMapper.apply(any())).thenReturn(Mono.empty()); + this.authorizedClientManager = new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientService); this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider); this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper); @@ -73,23 +83,23 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT @Test public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(null, this.authorizedClientService)) + assertThatThrownBy(() -> new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(null, this.authorizedClientService)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("reactiveClientRegistrationRepository cannot be null"); + .hasMessage("clientRegistrationRepository cannot be null"); } @Test public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null)) + assertThatThrownBy(() -> new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("reactiveAuthorizedClientService cannot be null"); + .hasMessage("authorizedClientService cannot be null"); } @Test public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("reactiveAuthorizedClientProvider cannot be null"); + .hasMessage("authorizedClientProvider cannot be null"); } @Test @@ -132,7 +142,7 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT .build(); Mono authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); - authorizedClient.subscribe(); + StepVerifier.create(authorizedClient).verifyComplete(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); @@ -142,7 +152,6 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - StepVerifier.create(authorizedClient).expectComplete(); verify(this.authorizedClientService, never()).saveAuthorizedClient( any(OAuth2AuthorizedClient.class), eq(this.principal)); } @@ -163,7 +172,9 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT .build(); Mono authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); - authorizedClient.subscribe(); + StepVerifier.create(authorizedClient) + .expectNext(this.authorizedClient) + .verifyComplete(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); @@ -173,9 +184,9 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient)); verify(this.authorizedClientService).saveAuthorizedClient( eq(this.authorizedClient), eq(this.principal)); + this.saveAuthorizedClientProbe.assertWasSubscribed(); } @SuppressWarnings("unchecked") @@ -197,8 +208,9 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT .build(); Mono authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); - authorizedClient.subscribe(); - + StepVerifier.create(authorizedClient) + .expectNext(reauthorizedClient) + .verifyComplete(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); @@ -207,9 +219,9 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient)); verify(this.authorizedClientService).saveAuthorizedClient( eq(reauthorizedClient), eq(this.principal)); + this.saveAuthorizedClientProbe.assertWasSubscribed(); } @SuppressWarnings("unchecked") @@ -221,8 +233,9 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT .build(); Mono authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); - authorizedClient.subscribe(); - + StepVerifier.create(authorizedClient) + .expectNext(this.authorizedClient) + .verifyComplete(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); @@ -231,7 +244,6 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient)); verify(this.authorizedClientService, never()).saveAuthorizedClient( any(OAuth2AuthorizedClient.class), eq(this.principal)); } @@ -250,7 +262,9 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT .build(); Mono authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); - authorizedClient.subscribe(); + StepVerifier.create(authorizedClient) + .expectNext(reauthorizedClient) + .verifyComplete(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); @@ -260,9 +274,9 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient)); verify(this.authorizedClientService).saveAuthorizedClient( eq(reauthorizedClient), eq(this.principal)); + this.saveAuthorizedClientProbe.assertWasSubscribed(); } @SuppressWarnings("unchecked") @@ -274,14 +288,20 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); - OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) .principal(this.principal) .attribute(OAuth2ParameterNames.SCOPE, "read write") .build(); + + this.authorizedClientManager.setContextAttributesMapper(new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.DefaultContextAttributesMapper()); Mono authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); - authorizedClient.subscribe(); + StepVerifier.create(authorizedClient) + .expectNext(reauthorizedClient) + .verifyComplete(); + verify(this.authorizedClientService).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal)); + this.saveAuthorizedClientProbe.assertWasSubscribed(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); @@ -293,8 +313,5 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); assertThat(requestScopeAttribute).contains("read", "write"); - StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient)); - verify(this.authorizedClientService).saveAuthorizedClient( - eq(reauthorizedClient), eq(this.principal)); } }