From 9a8ddebc94faec94c094de9bc18efea2bfe2845d Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Tue, 26 Sep 2017 15:24:08 -0400 Subject: [PATCH] Use param matching for Authorization Response Fixes gh-4576 --- ...ionCodeAuthenticationFilterConfigurer.java | 3 +- .../oauth2/client/OAuth2LoginConfigurer.java | 2 +- ...ionCodeAuthenticationProcessingFilter.java | 41 +++++++++++-------- ...uthorizationCodeRequestRedirectFilter.java | 5 +++ ...deAuthenticationProcessingFilterTests.java | 9 ++-- .../AuthorizationRequestAttributes.java | 21 +++++++++- .../oauth2/core/endpoint/OAuth2Parameter.java | 4 +- 7 files changed, 58 insertions(+), 27 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeAuthenticationFilterConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeAuthenticationFilterConfigurer.java index fd663f2aa7..19ca2d78bc 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeAuthenticationFilterConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeAuthenticationFilterConfigurer.java @@ -36,7 +36,6 @@ import org.springframework.security.oauth2.core.AccessToken; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.oidc.client.user.OidcUserService; import org.springframework.security.web.util.matcher.RequestMatcher; -import org.springframework.security.web.util.matcher.RequestVariablesExtractor; import org.springframework.util.Assert; import java.net.URI; @@ -48,7 +47,7 @@ import java.util.Map; /** * @author Joe Grandja */ -final class AuthorizationCodeAuthenticationFilterConfigurer, R extends RequestMatcher & RequestVariablesExtractor> extends +final class AuthorizationCodeAuthenticationFilterConfigurer, R extends RequestMatcher> extends AbstractAuthenticationFilterConfigurer, AuthorizationCodeAuthenticationProcessingFilter> { private R authorizationResponseMatcher; diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index e021026bd8..f340063000 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -166,7 +166,7 @@ public final class OAuth2LoginConfigurer> exten private RedirectionEndpointConfig() { } - public RedirectionEndpointConfig requestMatcher(R authorizationResponseMatcher) { + public RedirectionEndpointConfig requestMatcher(R authorizationResponseMatcher) { Assert.notNull(authorizationResponseMatcher, "authorizationResponseMatcher cannot be null"); OAuth2LoginConfigurer.this.authorizationCodeAuthenticationFilterConfigurer.authorizationResponseMatcher(authorizationResponseMatcher); return this; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilter.java index b55bb67cf5..90ca191bed 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilter.java @@ -37,9 +37,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter; import org.springframework.security.oauth2.core.endpoint.TokenResponseAttributes; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; -import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; -import org.springframework.security.web.util.matcher.RequestVariablesExtractor; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -111,20 +109,18 @@ import java.io.IOException; */ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAuthenticationProcessingFilter { public static final String DEFAULT_AUTHORIZATION_RESPONSE_BASE_URI = "/oauth2/authorize/code"; - public static final String REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId"; - public static final String DEFAULT_AUTHORIZATION_RESPONSE_URI = DEFAULT_AUTHORIZATION_RESPONSE_BASE_URI + "/{" + REGISTRATION_ID_URI_VARIABLE_NAME + "}"; private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found"; private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter"; private static final String INVALID_REDIRECT_URI_PARAMETER_ERROR_CODE = "invalid_redirect_uri_parameter"; private final ErrorResponseAttributesConverter errorResponseConverter = new ErrorResponseAttributesConverter(); private final AuthorizationCodeAuthorizationResponseAttributesConverter authorizationCodeResponseConverter = new AuthorizationCodeAuthorizationResponseAttributesConverter(); - private RequestMatcher authorizationResponseMatcher = new AntPathRequestMatcher(DEFAULT_AUTHORIZATION_RESPONSE_URI); private ClientRegistrationRepository clientRegistrationRepository; + private RequestMatcher authorizationResponseMatcher = new AuthorizationResponseMatcher(); private AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository(); public AuthorizationCodeAuthenticationProcessingFilter() { - super(DEFAULT_AUTHORIZATION_RESPONSE_URI); + super(new AuthorizationResponseMatcher()); } @Override @@ -140,17 +136,8 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut } AuthorizationRequestAttributes matchingAuthorizationRequest = this.resolveAuthorizationRequest(request); - - String registrationId = ((RequestVariablesExtractor)this.getAuthorizationResponseMatcher()) - .extractUriTemplateVariables(request).get(REGISTRATION_ID_URI_VARIABLE_NAME); - ClientRegistration clientRegistration = null; - if (!StringUtils.isEmpty(registrationId)) { - clientRegistration = this.getClientRegistrationRepository().findByRegistrationId(registrationId); - } - if (clientRegistration == null || !matchingAuthorizationRequest.getClientId().equals(clientRegistration.getClientId())) { - OAuth2Error oauth2Error = new OAuth2Error(OAuth2Error.INVALID_REQUEST_ERROR_CODE); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } + String registrationId = (String)matchingAuthorizationRequest.getAdditionalParameters().get(OAuth2Parameter.REGISTRATION_ID); + ClientRegistration clientRegistration = this.getClientRegistrationRepository().findByRegistrationId(registrationId); // The clientRegistration.redirectUri may contain Uri template variables, whether it's configured by // the user or configured by default. In these cases, the redirectUri will be expanded and ultimately changed @@ -180,7 +167,7 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut return this.authorizationResponseMatcher; } - public final void setAuthorizationResponseMatcher(T authorizationResponseMatcher) { + public final void setAuthorizationResponseMatcher(T authorizationResponseMatcher) { Assert.notNull(authorizationResponseMatcher, "authorizationResponseMatcher cannot be null"); this.authorizationResponseMatcher = authorizationResponseMatcher; this.setRequiresAuthenticationRequestMatcher(authorizationResponseMatcher); @@ -228,4 +215,22 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } } + + private static class AuthorizationResponseMatcher implements RequestMatcher { + + @Override + public boolean matches(HttpServletRequest request) { + return this.successResponse(request) || this.errorResponse(request); + } + + private boolean successResponse(HttpServletRequest request) { + return StringUtils.hasText(request.getParameter(OAuth2Parameter.CODE)) && + StringUtils.hasText(request.getParameter(OAuth2Parameter.STATE)); + } + + private boolean errorResponse(HttpServletRequest request) { + return StringUtils.hasText(request.getParameter(OAuth2Parameter.ERROR)) && + StringUtils.hasText(request.getParameter(OAuth2Parameter.STATE)); + } + } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeRequestRedirectFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeRequestRedirectFilter.java index 265ee9362e..e43b04504a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeRequestRedirectFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeRequestRedirectFilter.java @@ -19,6 +19,7 @@ import org.springframework.security.crypto.keygen.StringKeyGenerator; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.core.endpoint.AuthorizationRequestAttributes; +import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter; import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; @@ -122,6 +123,9 @@ public class AuthorizationCodeRequestRedirectFilter extends OncePerRequestFilter String redirectUriStr = this.expandRedirectUri(request, clientRegistration); + Map additionalParameters = new HashMap<>(); + additionalParameters.put(OAuth2Parameter.REGISTRATION_ID, clientRegistration.getRegistrationId()); + AuthorizationRequestAttributes authorizationRequestAttributes = AuthorizationRequestAttributes.withAuthorizationCode() .clientId(clientRegistration.getClientId()) @@ -129,6 +133,7 @@ public class AuthorizationCodeRequestRedirectFilter extends OncePerRequestFilter .redirectUri(redirectUriStr) .scope(clientRegistration.getScope()) .state(this.stateGenerator.generateKey()) + .additionalParameters(additionalParameters) .build(); this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequestAttributes, request, response); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilterTests.java index f06bf070b5..42b2309326 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilterTests.java @@ -38,9 +38,8 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Matchers.any; +import java.util.HashMap; +import java.util.Map; /** * Tests {@link AuthorizationCodeAuthenticationProcessingFilter}. @@ -233,6 +232,9 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { ClientRegistration clientRegistration, String state) { + Map additionalParameters = new HashMap<>(); + additionalParameters.put(OAuth2Parameter.REGISTRATION_ID, clientRegistration.getRegistrationId()); + AuthorizationRequestAttributes authorizationRequestAttributes = AuthorizationRequestAttributes.withAuthorizationCode() .clientId(clientRegistration.getClientId()) @@ -240,6 +242,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { .redirectUri(clientRegistration.getRedirectUri()) .scope(clientRegistration.getScope()) .state(state) + .additionalParameters(additionalParameters) .build(); authorizationRequestRepository.saveAuthorizationRequest(authorizationRequestAttributes, request, response); diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/AuthorizationRequestAttributes.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/AuthorizationRequestAttributes.java index 1a5b197fe4..06df2dbaf9 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/AuthorizationRequestAttributes.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/AuthorizationRequestAttributes.java @@ -21,7 +21,9 @@ import org.springframework.util.CollectionUtils; import java.io.Serializable; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.LinkedHashSet; +import java.util.Map; import java.util.Set; /** @@ -43,6 +45,7 @@ public final class AuthorizationRequestAttributes implements Serializable { private String redirectUri; private Set scope; private String state; + private Map additionalParameters; private AuthorizationRequestAttributes() { } @@ -75,6 +78,10 @@ public final class AuthorizationRequestAttributes implements Serializable { return this.state; } + public Map getAdditionalParameters() { + return this.additionalParameters; + } + public static Builder withAuthorizationCode() { return new Builder(AuthorizationGrantType.AUTHORIZATION_CODE); } @@ -107,8 +114,7 @@ public final class AuthorizationRequestAttributes implements Serializable { } public Builder scope(Set scope) { - this.authorizationRequest.scope = Collections.unmodifiableSet( - CollectionUtils.isEmpty(scope) ? Collections.emptySet() : new LinkedHashSet<>(scope)); + this.authorizationRequest.scope = scope; return this; } @@ -117,9 +123,20 @@ public final class AuthorizationRequestAttributes implements Serializable { return this; } + public Builder additionalParameters(Map additionalParameters) { + this.authorizationRequest.additionalParameters = additionalParameters; + return this; + } + public AuthorizationRequestAttributes build() { Assert.hasText(this.authorizationRequest.clientId, "clientId cannot be empty"); Assert.hasText(this.authorizationRequest.authorizeUri, "authorizeUri cannot be empty"); + this.authorizationRequest.scope = Collections.unmodifiableSet( + CollectionUtils.isEmpty(this.authorizationRequest.scope) ? + Collections.emptySet() : new LinkedHashSet<>(this.authorizationRequest.scope)); + this.authorizationRequest.additionalParameters = Collections.unmodifiableMap( + CollectionUtils.isEmpty(this.authorizationRequest.additionalParameters) ? + Collections.emptyMap() : new LinkedHashMap<>(this.authorizationRequest.additionalParameters)); return this.authorizationRequest; } } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2Parameter.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2Parameter.java index fd3becc823..c548b3ec2f 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2Parameter.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2Parameter.java @@ -16,7 +16,7 @@ package org.springframework.security.oauth2.core.endpoint; /** - * Standard parameters defined in the OAuth Parameters Registry + * Standard and additional (custom) parameters defined in the OAuth Parameters Registry * and used by the authorization endpoint and token endpoint. * * @author Joe Grandja @@ -43,4 +43,6 @@ public interface OAuth2Parameter { String ERROR_URI = "error_uri"; + String REGISTRATION_ID = "registration_id"; // Non-standard additional parameter + }