1
0
mirror of synced 2026-05-22 13:23:17 +00:00

Allow customization of redirect strategy

The default redirect strategy will provide authorization redirect
URI within HTTP 302 response Location header.
Allowing the configuration of custom redirect strategy will provide
an option for the clients to obtain the authorization URI from e.g.
HTTP response body as JSON payload, without a need to handle
automatic redirection initiated by the HTTP Location header.

Closes gh-11373
This commit is contained in:
Igor Bolic
2022-06-17 09:42:50 +02:00
committed by Rob Winch
parent c9f8d2b111
commit efaee4e56b
27 changed files with 712 additions and 2 deletions
@@ -95,7 +95,7 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
private final ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
private final RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy();
private RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy();
private OAuth2AuthorizationRequestResolver authorizationRequestResolver;
@@ -139,6 +139,15 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
this.authorizationRequestResolver = authorizationRequestResolver;
}
/**
* Sets the redirect strategy for Authorization Endpoint redirect URI.
* @param authorizationRedirectStrategy the redirect strategy
*/
public void setAuthorizationRedirectStrategy(RedirectStrategy authorizationRedirectStrategy) {
Assert.notNull(authorizationRedirectStrategy, "authorizationRedirectStrategy cannot be null");
this.authorizationRedirectStrategy = authorizationRedirectStrategy;
}
/**
* Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s.
* @param authorizationRequestRepository the repository used for storing
@@ -75,7 +75,7 @@ import org.springframework.web.util.UriComponentsBuilder;
*/
public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter {
private final ServerRedirectStrategy authorizationRedirectStrategy = new DefaultServerRedirectStrategy();
private ServerRedirectStrategy authorizationRedirectStrategy = new DefaultServerRedirectStrategy();
private final ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver;
@@ -105,6 +105,15 @@ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter {
this.authorizationRequestResolver = authorizationRequestResolver;
}
/**
* Sets the redirect strategy for Authorization Endpoint redirect URI.
* @param authorizationRedirectStrategy the redirect strategy
*/
public void setAuthorizationRedirectStrategy(ServerRedirectStrategy authorizationRedirectStrategy) {
Assert.notNull(authorizationRedirectStrategy, "authorizationRedirectStrategy cannot be null");
this.authorizationRedirectStrategy = authorizationRedirectStrategy;
}
/**
* Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s.
* @param authorizationRequestRepository the repository used for storing
@@ -17,6 +17,7 @@
package org.springframework.security.oauth2.client.web;
import java.lang.reflect.Constructor;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@@ -30,7 +31,9 @@ import javax.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
@@ -40,6 +43,7 @@ import org.springframework.security.oauth2.client.registration.InMemoryClientReg
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.util.ClassUtils;
import org.springframework.web.util.UriComponentsBuilder;
@@ -116,6 +120,11 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRequestRepository(null));
}
@Test
public void setAuthorizationRedirectStrategyWhenAuthorizationRedirectStrategyIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRedirectStrategy(null));
}
@Test
public void setRequestCacheWhenRequestCacheIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestCache(null));
@@ -333,4 +342,31 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
+ "login_hint=user@provider\\.com");
}
@Test
public void doFilterWhenCustomAuthorizationRedirectStrategySetThenCustomAuthorizationRedirectStrategyUsed()
throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"
+ this.registration1.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
RedirectStrategy customRedirectStrategy = (httpRequest, httpResponse, url) -> {
String redirectUrl = httpResponse.encodeRedirectURL(url);
httpResponse.setStatus(HttpStatus.OK.value());
httpResponse.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_PLAIN_VALUE);
httpResponse.getWriter().write(redirectUrl);
httpResponse.getWriter().flush();
};
this.filter.setAuthorizationRedirectStrategy(customRedirectStrategy);
this.filter.doFilter(request, response, filterChain);
verifyZeroInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value());
assertThat(response.getContentType()).isEqualTo(MediaType.TEXT_PLAIN_VALUE);
assertThat(response.getContentAsString(StandardCharsets.UTF_8))
.matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&"
+ "scope=read:user&state=.{15,}&"
+ "redirect_uri=http://localhost/login/oauth2/code/registration-id");
}
}
@@ -17,6 +17,7 @@
package org.springframework.security.oauth2.client.web.server;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import org.junit.jupiter.api.BeforeEach;
@@ -24,13 +25,20 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.web.server.ServerRedirectStrategy;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.test.web.reactive.server.FluxExchangeResult;
import org.springframework.test.web.reactive.server.WebTestClient;
@@ -81,6 +89,11 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests {
.isThrownBy(() -> new OAuth2AuthorizationRequestRedirectWebFilter(this.clientRepository));
}
@Test
public void setterWhenAuthorizationRedirectStrategyNullThenIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRedirectStrategy(null));
}
@Test
public void filterWhenDoesNotMatchThenClientRegistrationRepositoryNotSubscribed() {
// @formatter:off
@@ -195,4 +208,46 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests {
verifyNoInteractions(this.requestCache);
}
@Test
public void filterWhenCustomRedirectStrategySetThenRedirectUriInResponseBody() {
given(this.clientRepository.findByRegistrationId(this.registration.getRegistrationId()))
.willReturn(Mono.just(this.registration));
given(this.authzRequestRepository.saveAuthorizationRequest(any(), any())).willReturn(Mono.empty());
ServerRedirectStrategy customRedirectStrategy = (exchange, location) -> {
ServerHttpResponse response = exchange.getResponse();
response.setStatusCode(HttpStatus.OK);
response.getHeaders().setContentType(MediaType.TEXT_PLAIN);
DataBuffer buffer = exchange.getResponse().bufferFactory()
.wrap(location.toASCIIString().getBytes(StandardCharsets.UTF_8));
return exchange.getResponse().writeWith(Flux.just(buffer));
};
this.filter.setAuthorizationRedirectStrategy(customRedirectStrategy);
this.filter.setRequestCache(this.requestCache);
FluxExchangeResult<String> result = this.client.get()
.uri("https://example.com/oauth2/authorization/registration-id").exchange().expectHeader()
.contentType(MediaType.TEXT_PLAIN).expectStatus().isOk().returnResult(String.class);
// @formatter:off
StepVerifier.create(result.getResponseBody())
.assertNext((uri) -> {
URI location = URI.create(uri);
assertThat(location)
.hasScheme("https")
.hasHost("example.com")
.hasPath("/login/oauth/authorize")
.hasParameter("response_type", "code")
.hasParameter("client_id", "client-id")
.hasParameter("scope", "read:user")
.hasParameter("state")
.hasParameter("redirect_uri", "https://example.com/login/oauth2/code/registration-id");
})
.verifyComplete();
// @formatter:on
verifyNoInteractions(this.requestCache);
}
}