diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java index bd2fcf3c40..e39ff4a86f 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java @@ -48,6 +48,8 @@ import org.springframework.security.oauth2.server.authorization.authentication.O import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationException; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContext; +import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenGenerator; import org.springframework.security.oauth2.server.authorization.web.NimbusJwkSetEndpointFilter; @@ -58,6 +60,7 @@ import org.springframework.security.web.servlet.util.matcher.PathPatternRequestM import org.springframework.security.web.util.matcher.OrRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; +import org.springframework.web.util.UriComponentsBuilder; /** * An {@link AbstractHttpConfigurer} for OAuth 2.1 Authorization Server support. @@ -78,6 +81,7 @@ import org.springframework.util.Assert; * @see OAuth2TokenRevocationEndpointConfigurer * @see OAuth2DeviceAuthorizationEndpointConfigurer * @see OAuth2DeviceVerificationEndpointConfigurer + * @see OAuth2ClientRegistrationEndpointConfigurer * @see OidcConfigurer * @see RegisteredClientRepository * @see OAuth2AuthorizationService @@ -268,6 +272,25 @@ public final class OAuth2AuthorizationServerConfigurer return this; } + /** + * Configures the OAuth 2.0 Dynamic Client Registration Endpoint. + * @param clientRegistrationEndpointCustomizer the {@link Customizer} providing access + * to the {@link OAuth2ClientRegistrationEndpointConfigurer} + * @return the {@link OAuth2AuthorizationServerConfigurer} for further configuration + */ + public OAuth2AuthorizationServerConfigurer clientRegistrationEndpoint( + Customizer clientRegistrationEndpointCustomizer) { + OAuth2ClientRegistrationEndpointConfigurer clientRegistrationEndpointConfigurer = getConfigurer( + OAuth2ClientRegistrationEndpointConfigurer.class); + if (clientRegistrationEndpointConfigurer == null) { + addConfigurer(OAuth2ClientRegistrationEndpointConfigurer.class, + new OAuth2ClientRegistrationEndpointConfigurer(this::postProcess)); + clientRegistrationEndpointConfigurer = getConfigurer(OAuth2ClientRegistrationEndpointConfigurer.class); + } + clientRegistrationEndpointCustomizer.customize(clientRegistrationEndpointConfigurer); + return this; + } + /** * Configures OpenID Connect 1.0 support (disabled by default). * @param oidcCustomizer the {@link Customizer} providing access to the @@ -377,6 +400,12 @@ public final class OAuth2AuthorizationServerConfigurer httpSecurity.csrf((csrf) -> csrf.ignoringRequestMatchers(this.endpointsMatcher)); + if (getConfigurer(OAuth2ClientRegistrationEndpointConfigurer.class) != null) { + httpSecurity + // Accept access tokens for Client Registration + .oauth2ResourceServer((oauth2ResourceServer) -> oauth2ResourceServer.jwt(Customizer.withDefaults())); + } + OidcConfigurer oidcConfigurer = getConfigurer(OidcConfigurer.class); if (oidcConfigurer != null) { if (oidcConfigurer.getConfigurer(OidcUserInfoEndpointConfigurer.class) != null @@ -392,6 +421,27 @@ public final class OAuth2AuthorizationServerConfigurer @Override public void configure(HttpSecurity httpSecurity) { + OAuth2ClientRegistrationEndpointConfigurer clientRegistrationEndpointConfigurer = getConfigurer( + OAuth2ClientRegistrationEndpointConfigurer.class); + if (clientRegistrationEndpointConfigurer != null) { + OAuth2AuthorizationServerMetadataEndpointConfigurer authorizationServerMetadataEndpointConfigurer = getConfigurer( + OAuth2AuthorizationServerMetadataEndpointConfigurer.class); + + authorizationServerMetadataEndpointConfigurer.addDefaultAuthorizationServerMetadataCustomizer((builder) -> { + AuthorizationServerContext authorizationServerContext = AuthorizationServerContextHolder.getContext(); + String issuer = authorizationServerContext.getIssuer(); + AuthorizationServerSettings authorizationServerSettings = authorizationServerContext + .getAuthorizationServerSettings(); + + String clientRegistrationEndpoint = UriComponentsBuilder.fromUriString(issuer) + .path(authorizationServerSettings.getClientRegistrationEndpoint()) + .build() + .toUriString(); + + builder.clientRegistrationEndpoint(clientRegistrationEndpoint); + }); + } + this.configurers.values().forEach((configurer) -> configurer.configure(httpSecurity)); AuthorizationServerSettings authorizationServerSettings = OAuth2ConfigurerUtils diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientRegistrationEndpointConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientRegistrationEndpointConfigurer.java new file mode 100644 index 0000000000..c6b5931f53 --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientRegistrationEndpointConfigurer.java @@ -0,0 +1,277 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; + +import jakarta.servlet.http.HttpServletRequest; + +import org.springframework.http.HttpMethod; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.config.ObjectPostProcessor; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.crypto.password.PasswordEncoder; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientRegistrationAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientRegistrationAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; +import org.springframework.security.oauth2.server.authorization.web.OAuth2ClientRegistrationEndpointFilter; +import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ClientRegistrationAuthenticationConverter; +import org.springframework.security.web.access.intercept.AuthorizationFilter; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.security.web.authentication.DelegatingAuthenticationConverter; +import org.springframework.security.web.servlet.util.matcher.PathPatternRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; + +/** + * Configurer for OAuth 2.0 Dynamic Client Registration Endpoint. + * + * @author Joe Grandja + * @since 7.0 + * @see OAuth2AuthorizationServerConfigurer#clientRegistrationEndpoint + * @see OAuth2ClientRegistrationEndpointFilter + */ +public final class OAuth2ClientRegistrationEndpointConfigurer extends AbstractOAuth2Configurer { + + private RequestMatcher requestMatcher; + + private final List clientRegistrationRequestConverters = new ArrayList<>(); + + private Consumer> clientRegistrationRequestConvertersConsumer = ( + clientRegistrationRequestConverters) -> { + }; + + private final List authenticationProviders = new ArrayList<>(); + + private Consumer> authenticationProvidersConsumer = (authenticationProviders) -> { + }; + + private AuthenticationSuccessHandler clientRegistrationResponseHandler; + + private AuthenticationFailureHandler errorResponseHandler; + + private boolean openRegistrationAllowed; + + /** + * Restrict for internal use only. + * @param objectPostProcessor an {@code ObjectPostProcessor} + */ + OAuth2ClientRegistrationEndpointConfigurer(ObjectPostProcessor objectPostProcessor) { + super(objectPostProcessor); + } + + /** + * Adds an {@link AuthenticationConverter} used when attempting to extract a Client + * Registration Request from {@link HttpServletRequest} to an instance of + * {@link OAuth2ClientRegistrationAuthenticationToken} used for authenticating the + * request. + * @param clientRegistrationRequestConverter an {@link AuthenticationConverter} used + * when attempting to extract a Client Registration Request from + * {@link HttpServletRequest} + * @return the {@link OAuth2ClientRegistrationEndpointConfigurer} for further + * configuration + */ + public OAuth2ClientRegistrationEndpointConfigurer clientRegistrationRequestConverter( + AuthenticationConverter clientRegistrationRequestConverter) { + Assert.notNull(clientRegistrationRequestConverter, "clientRegistrationRequestConverter cannot be null"); + this.clientRegistrationRequestConverters.add(clientRegistrationRequestConverter); + return this; + } + + /** + * Sets the {@code Consumer} providing access to the {@code List} of default and + * (optionally) added + * {@link #clientRegistrationRequestConverter(AuthenticationConverter) + * AuthenticationConverter}'s allowing the ability to add, remove, or customize a + * specific {@link AuthenticationConverter}. + * @param clientRegistrationRequestConvertersConsumer the {@code Consumer} providing + * access to the {@code List} of default and (optionally) added + * {@link AuthenticationConverter}'s + * @return the {@link OAuth2ClientRegistrationEndpointConfigurer} for further + * configuration + */ + public OAuth2ClientRegistrationEndpointConfigurer clientRegistrationRequestConverters( + Consumer> clientRegistrationRequestConvertersConsumer) { + Assert.notNull(clientRegistrationRequestConvertersConsumer, + "clientRegistrationRequestConvertersConsumer cannot be null"); + this.clientRegistrationRequestConvertersConsumer = clientRegistrationRequestConvertersConsumer; + return this; + } + + /** + * Adds an {@link AuthenticationProvider} used for authenticating an + * {@link OAuth2ClientRegistrationAuthenticationToken}. + * @param authenticationProvider an {@link AuthenticationProvider} used for + * authenticating an {@link OAuth2ClientRegistrationAuthenticationToken} + * @return the {@link OAuth2ClientRegistrationEndpointConfigurer} for further + * configuration + */ + public OAuth2ClientRegistrationEndpointConfigurer authenticationProvider( + AuthenticationProvider authenticationProvider) { + Assert.notNull(authenticationProvider, "authenticationProvider cannot be null"); + this.authenticationProviders.add(authenticationProvider); + return this; + } + + /** + * Sets the {@code Consumer} providing access to the {@code List} of default and + * (optionally) added {@link #authenticationProvider(AuthenticationProvider) + * AuthenticationProvider}'s allowing the ability to add, remove, or customize a + * specific {@link AuthenticationProvider}. + * @param authenticationProvidersConsumer the {@code Consumer} providing access to the + * {@code List} of default and (optionally) added {@link AuthenticationProvider}'s + * @return the {@link OAuth2ClientRegistrationEndpointConfigurer} for further + * configuration + */ + public OAuth2ClientRegistrationEndpointConfigurer authenticationProviders( + Consumer> authenticationProvidersConsumer) { + Assert.notNull(authenticationProvidersConsumer, "authenticationProvidersConsumer cannot be null"); + this.authenticationProvidersConsumer = authenticationProvidersConsumer; + return this; + } + + /** + * Sets the {@link AuthenticationSuccessHandler} used for handling an + * {@link OAuth2ClientRegistrationAuthenticationToken} and returning the + * {@link OAuth2ClientRegistration Client Registration Response}. + * @param clientRegistrationResponseHandler the {@link AuthenticationSuccessHandler} + * used for handling an {@link OAuth2ClientRegistrationAuthenticationToken} + * @return the {@link OAuth2ClientRegistrationEndpointConfigurer} for further + * configuration + */ + public OAuth2ClientRegistrationEndpointConfigurer clientRegistrationResponseHandler( + AuthenticationSuccessHandler clientRegistrationResponseHandler) { + this.clientRegistrationResponseHandler = clientRegistrationResponseHandler; + return this; + } + + /** + * Sets the {@link AuthenticationFailureHandler} used for handling an + * {@link OAuth2AuthenticationException} and returning the {@link OAuth2Error Error + * Response}. + * @param errorResponseHandler the {@link AuthenticationFailureHandler} used for + * handling an {@link OAuth2AuthenticationException} + * @return the {@link OAuth2ClientRegistrationEndpointConfigurer} for further + * configuration + */ + public OAuth2ClientRegistrationEndpointConfigurer errorResponseHandler( + AuthenticationFailureHandler errorResponseHandler) { + this.errorResponseHandler = errorResponseHandler; + return this; + } + + /** + * Set to {@code true} if open client registration (with no initial access token) is + * allowed. The default is {@code false}. + * @param openRegistrationAllowed {@code true} if open client registration is allowed, + * {@code false} otherwise + * @return the {@link OAuth2ClientRegistrationEndpointConfigurer} for further + * configuration + */ + public OAuth2ClientRegistrationEndpointConfigurer openRegistrationAllowed(boolean openRegistrationAllowed) { + this.openRegistrationAllowed = openRegistrationAllowed; + return this; + } + + @Override + void init(HttpSecurity httpSecurity) { + AuthorizationServerSettings authorizationServerSettings = OAuth2ConfigurerUtils + .getAuthorizationServerSettings(httpSecurity); + String clientRegistrationEndpointUri = authorizationServerSettings.isMultipleIssuersAllowed() + ? OAuth2ConfigurerUtils + .withMultipleIssuersPattern(authorizationServerSettings.getClientRegistrationEndpoint()) + : authorizationServerSettings.getClientRegistrationEndpoint(); + this.requestMatcher = PathPatternRequestMatcher.withDefaults() + .matcher(HttpMethod.POST, clientRegistrationEndpointUri); + + List authenticationProviders = createDefaultAuthenticationProviders(httpSecurity, + this.openRegistrationAllowed); + if (!this.authenticationProviders.isEmpty()) { + authenticationProviders.addAll(0, this.authenticationProviders); + } + this.authenticationProvidersConsumer.accept(authenticationProviders); + authenticationProviders.forEach( + (authenticationProvider) -> httpSecurity.authenticationProvider(postProcess(authenticationProvider))); + } + + @Override + void configure(HttpSecurity httpSecurity) { + AuthenticationManager authenticationManager = httpSecurity.getSharedObject(AuthenticationManager.class); + AuthorizationServerSettings authorizationServerSettings = OAuth2ConfigurerUtils + .getAuthorizationServerSettings(httpSecurity); + + String clientRegistrationEndpointUri = authorizationServerSettings.isMultipleIssuersAllowed() + ? OAuth2ConfigurerUtils + .withMultipleIssuersPattern(authorizationServerSettings.getClientRegistrationEndpoint()) + : authorizationServerSettings.getClientRegistrationEndpoint(); + OAuth2ClientRegistrationEndpointFilter clientRegistrationEndpointFilter = new OAuth2ClientRegistrationEndpointFilter( + authenticationManager, clientRegistrationEndpointUri); + List authenticationConverters = createDefaultAuthenticationConverters(); + if (!this.clientRegistrationRequestConverters.isEmpty()) { + authenticationConverters.addAll(0, this.clientRegistrationRequestConverters); + } + this.clientRegistrationRequestConvertersConsumer.accept(authenticationConverters); + clientRegistrationEndpointFilter + .setAuthenticationConverter(new DelegatingAuthenticationConverter(authenticationConverters)); + if (this.clientRegistrationResponseHandler != null) { + clientRegistrationEndpointFilter.setAuthenticationSuccessHandler(this.clientRegistrationResponseHandler); + } + if (this.errorResponseHandler != null) { + clientRegistrationEndpointFilter.setAuthenticationFailureHandler(this.errorResponseHandler); + } + httpSecurity.addFilterAfter(postProcess(clientRegistrationEndpointFilter), AuthorizationFilter.class); + } + + @Override + RequestMatcher getRequestMatcher() { + return this.requestMatcher; + } + + private static List createDefaultAuthenticationConverters() { + List authenticationConverters = new ArrayList<>(); + + authenticationConverters.add(new OAuth2ClientRegistrationAuthenticationConverter()); + + return authenticationConverters; + } + + private static List createDefaultAuthenticationProviders(HttpSecurity httpSecurity, + boolean openRegistrationAllowed) { + List authenticationProviders = new ArrayList<>(); + + OAuth2ClientRegistrationAuthenticationProvider clientRegistrationAuthenticationProvider = new OAuth2ClientRegistrationAuthenticationProvider( + OAuth2ConfigurerUtils.getRegisteredClientRepository(httpSecurity), + OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity)); + PasswordEncoder passwordEncoder = OAuth2ConfigurerUtils.getOptionalBean(httpSecurity, PasswordEncoder.class); + if (passwordEncoder != null) { + clientRegistrationAuthenticationProvider.setPasswordEncoder(passwordEncoder); + } + clientRegistrationAuthenticationProvider.setOpenRegistrationAllowed(openRegistrationAllowed); + authenticationProviders.add(clientRegistrationAuthenticationProvider); + + return authenticationProviders; + } + +} diff --git a/config/src/test/java/org/springframework/security/SerializationSamples.java b/config/src/test/java/org/springframework/security/SerializationSamples.java index f931c83bc1..33c61fd28d 100644 --- a/config/src/test/java/org/springframework/security/SerializationSamples.java +++ b/config/src/test/java/org/springframework/security/SerializationSamples.java @@ -162,6 +162,7 @@ import org.springframework.security.oauth2.jwt.TestJwts; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationServerMetadata; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration; import org.springframework.security.oauth2.server.authorization.OAuth2TokenIntrospection; import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; @@ -170,6 +171,7 @@ import org.springframework.security.oauth2.server.authorization.authentication.O import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationGrantAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientRegistrationAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceAuthorizationConsentAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceAuthorizationRequestAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceVerificationAuthenticationToken; @@ -478,6 +480,18 @@ final class SerializationSamples { authenticationToken.setDetails(details); return authenticationToken; }); + OAuth2ClientRegistration oauth2ClientRegistration = OAuth2ClientRegistration.builder() + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .scope("scope1") + .redirectUri("https://localhost/oauth2/callback") + .build(); + generatorByClassName.put(OAuth2ClientRegistration.class, (r) -> oauth2ClientRegistration); + generatorByClassName.put(OAuth2ClientRegistrationAuthenticationToken.class, (r) -> { + OAuth2ClientRegistrationAuthenticationToken authenticationToken = new OAuth2ClientRegistrationAuthenticationToken( + principal, oauth2ClientRegistration); + authenticationToken.setDetails(details); + return authenticationToken; + }); OidcClientRegistration oidcClientRegistration = OidcClientRegistration.builder() .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) .scope("scope1") diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerMetadataTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerMetadataTests.java index 5b57cdf4ef..dcd15b69f3 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerMetadataTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerMetadataTests.java @@ -36,12 +36,14 @@ import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; +import org.springframework.security.config.Customizer; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContextExtension; import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationServerMetadata; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationServerMetadataClaimNames; import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository; @@ -75,6 +77,9 @@ public class OAuth2AuthorizationServerMetadataTests { public final SpringTestContext spring = new SpringTestContext(this); + @Autowired + private AuthorizationServerSettings authorizationServerSettings; + @Autowired private MockMvc mvc; @@ -156,6 +161,17 @@ public class OAuth2AuthorizationServerMetadataTests { hasItems("scope1", "scope2"))); } + @Test + public void requestWhenAuthorizationServerMetadataRequestAndClientRegistrationEnabledThenMetadataResponseIncludesRegistrationEndpoint() + throws Exception { + this.spring.register(AuthorizationServerConfigurationWithClientRegistrationEnabled.class).autowire(); + + this.mvc.perform(get(ISSUER.concat(DEFAULT_OAUTH2_AUTHORIZATION_SERVER_METADATA_ENDPOINT_URI))) + .andExpect(status().is2xxSuccessful()) + .andExpect(jsonPath("$.registration_endpoint") + .value(ISSUER.concat(this.authorizationServerSettings.getClientRegistrationEndpoint()))); + } + @EnableWebSecurity @Import(OAuth2AuthorizationServerConfiguration.class) static class AuthorizationServerConfiguration { @@ -179,6 +195,11 @@ public class OAuth2AuthorizationServerMetadataTests { return jwkSource; } + @Bean + JwtDecoder jwtDecoder(JWKSource jwkSource) { + return OAuth2AuthorizationServerConfiguration.jwtDecoder(jwkSource); + } + @Bean AuthorizationServerSettings authorizationServerSettings() { return AuthorizationServerSettings.builder().issuer(ISSUER).build(); @@ -224,4 +245,26 @@ public class OAuth2AuthorizationServerMetadataTests { } + @EnableWebSecurity + @Configuration(proxyBeanMethods = false) + static class AuthorizationServerConfigurationWithClientRegistrationEnabled + extends AuthorizationServerConfiguration { + + // @formatter:off + @Bean + SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception { + http + .oauth2AuthorizationServer((authorizationServer) -> + authorizationServer + .clientRegistrationEndpoint(Customizer.withDefaults()) + ) + .authorizeHttpRequests((authorize) -> + authorize.anyRequest().authenticated() + ); + return http.build(); + } + // @formatter:on + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientRegistrationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientRegistrationTests.java new file mode 100644 index 0000000000..a0512af77a --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientRegistrationTests.java @@ -0,0 +1,776 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization; + +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.source.JWKSource; +import com.nimbusds.jose.proc.SecurityContext; +import jakarta.servlet.http.HttpServletResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.assertj.core.data.TemporalUnitWithinOffset; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.server.ServletServerHttpResponse; +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; +import org.springframework.mock.http.MockHttpOutputMessage; +import org.springframework.mock.http.client.MockClientHttpResponse; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.config.Customizer; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; +import org.springframework.security.config.test.SpringTestContext; +import org.springframework.security.config.test.SpringTestContextExtension; +import org.springframework.security.crypto.factory.PasswordEncoderFactories; +import org.springframework.security.crypto.password.PasswordEncoder; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; +import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientRegistrationAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientRegistrationAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository.RegisteredClientParametersMapper; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.converter.OAuth2ClientRegistrationRegisteredClientConverter; +import org.springframework.security.oauth2.server.authorization.converter.RegisteredClientOAuth2ClientRegistrationConverter; +import org.springframework.security.oauth2.server.authorization.http.converter.OAuth2ClientRegistrationHttpMessageConverter; +import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; +import org.springframework.security.oauth2.server.authorization.settings.ClientSettings; +import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ClientRegistrationAuthenticationConverter; +import org.springframework.security.web.SecurityFilterChain; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; +import org.springframework.util.CollectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.CoreMatchers.containsString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.jwt; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +/** + * Integration tests for OAuth 2.0 Dynamic Client Registration. + * + * @author Joe Grandja + */ +@ExtendWith(SpringTestContextExtension.class) +public class OAuth2ClientRegistrationTests { + + private static final String ISSUER = "https://example.com:8443/issuer1"; + + private static final String DEFAULT_TOKEN_ENDPOINT_URI = "/oauth2/token"; + + private static final String DEFAULT_OAUTH2_CLIENT_REGISTRATION_ENDPOINT_URI = "/oauth2/register"; + + private static final HttpMessageConverter accessTokenHttpResponseConverter = new OAuth2AccessTokenResponseHttpMessageConverter(); + + private static final HttpMessageConverter clientRegistrationHttpMessageConverter = new OAuth2ClientRegistrationHttpMessageConverter(); + + private static EmbeddedDatabase db; + + private static JWKSource jwkSource; + + public final SpringTestContext spring = new SpringTestContext(this); + + @Autowired + private MockMvc mvc; + + @Autowired + private JdbcOperations jdbcOperations; + + @Autowired + private RegisteredClientRepository registeredClientRepository; + + private static AuthenticationConverter authenticationConverter; + + private static Consumer> authenticationConvertersConsumer; + + private static AuthenticationProvider authenticationProvider; + + private static Consumer> authenticationProvidersConsumer; + + private static AuthenticationSuccessHandler authenticationSuccessHandler; + + private static AuthenticationFailureHandler authenticationFailureHandler; + + private MockWebServer server; + + @BeforeAll + public static void init() { + JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK); + jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet); + db = new EmbeddedDatabaseBuilder().generateUniqueName(true) + .setType(EmbeddedDatabaseType.HSQL) + .setScriptEncoding("UTF-8") + .addScript("org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql") + .addScript( + "org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql") + .build(); + authenticationConverter = mock(AuthenticationConverter.class); + authenticationConvertersConsumer = mock(Consumer.class); + authenticationProvider = mock(AuthenticationProvider.class); + authenticationProvidersConsumer = mock(Consumer.class); + authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class); + authenticationFailureHandler = mock(AuthenticationFailureHandler.class); + } + + @BeforeEach + public void setup() throws Exception { + this.server = new MockWebServer(); + this.server.start(); + given(authenticationProvider.supports(OAuth2ClientRegistrationAuthenticationToken.class)).willReturn(true); + } + + @AfterEach + public void tearDown() throws Exception { + this.server.shutdown(); + this.jdbcOperations.update("truncate table oauth2_authorization"); + this.jdbcOperations.update("truncate table oauth2_registered_client"); + reset(authenticationConverter); + reset(authenticationConvertersConsumer); + reset(authenticationProvider); + reset(authenticationProvidersConsumer); + reset(authenticationSuccessHandler); + reset(authenticationFailureHandler); + } + + @AfterAll + public static void destroy() { + db.shutdown(); + } + + @Test + public void requestWhenClientRegistrationRequestAuthorizedThenClientRegistrationResponse() throws Exception { + this.spring.register(AuthorizationServerConfiguration.class).autowire(); + + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .scope("scope1") + .scope("scope2") + .build(); + // @formatter:on + + OAuth2ClientRegistration clientRegistrationResponse = registerClient(clientRegistration); + + assertClientRegistrationResponse(clientRegistration, clientRegistrationResponse); + } + + @Test + public void requestWhenOpenClientRegistrationRequestThenClientRegistrationResponse() throws Exception { + this.spring.register(OpenClientRegistrationConfiguration.class).autowire(); + + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .scope("scope1") + .scope("scope2") + .build(); + // @formatter:on + + MvcResult mvcResult = this.mvc + .perform(post(ISSUER.concat(DEFAULT_OAUTH2_CLIENT_REGISTRATION_ENDPOINT_URI)) + .contentType(MediaType.APPLICATION_JSON) + .content(getClientRegistrationRequestContent(clientRegistration))) + .andExpect(status().isCreated()) + .andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store"))) + .andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache"))) + .andReturn(); + + OAuth2ClientRegistration clientRegistrationResponse = readClientRegistrationResponse(mvcResult.getResponse()); + + assertClientRegistrationResponse(clientRegistration, clientRegistrationResponse); + } + + @Test + public void requestWhenClientRegistrationEndpointCustomizedThenUsed() throws Exception { + this.spring.register(CustomClientRegistrationConfiguration.class).autowire(); + + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .scope("scope1") + .scope("scope2") + .build(); + // @formatter:on + + willAnswer((invocation) -> { + HttpServletResponse response = invocation.getArgument(1, HttpServletResponse.class); + ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + httpResponse.setStatusCode(HttpStatus.CREATED); + new OAuth2ClientRegistrationHttpMessageConverter().write(clientRegistration, null, httpResponse); + return null; + }).given(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any()); + + registerClient(clientRegistration); + + verify(authenticationConverter).convert(any()); + ArgumentCaptor> authenticationConvertersCaptor = ArgumentCaptor + .forClass(List.class); + verify(authenticationConvertersConsumer).accept(authenticationConvertersCaptor.capture()); + List authenticationConverters = authenticationConvertersCaptor.getValue(); + assertThat(authenticationConverters).hasSize(2) + .allMatch((converter) -> converter == authenticationConverter + || converter instanceof OAuth2ClientRegistrationAuthenticationConverter); + + verify(authenticationProvider).authenticate(any()); + ArgumentCaptor> authenticationProvidersCaptor = ArgumentCaptor + .forClass(List.class); + verify(authenticationProvidersConsumer).accept(authenticationProvidersCaptor.capture()); + List authenticationProviders = authenticationProvidersCaptor.getValue(); + assertThat(authenticationProviders).hasSize(2) + .allMatch((provider) -> provider == authenticationProvider + || provider instanceof OAuth2ClientRegistrationAuthenticationProvider); + + verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any()); + verifyNoInteractions(authenticationFailureHandler); + } + + @Test + public void requestWhenClientRegistrationEndpointCustomizedWithAuthenticationFailureHandlerThenUsed() + throws Exception { + this.spring.register(CustomClientRegistrationConfiguration.class).autowire(); + + given(authenticationProvider.authenticate(any())).willThrow(new OAuth2AuthenticationException("error")); + + this.mvc.perform(post(ISSUER.concat(DEFAULT_OAUTH2_CLIENT_REGISTRATION_ENDPOINT_URI)).with(jwt())); + + verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), any()); + verifyNoInteractions(authenticationSuccessHandler); + } + + @Test + public void requestWhenClientRegistersWithSecretThenClientAuthenticationSuccess() throws Exception { + this.spring.register(AuthorizationServerConfiguration.class).autowire(); + + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .scope("scope1") + .scope("scope2") + .build(); + // @formatter:on + + OAuth2ClientRegistration clientRegistrationResponse = registerClient(clientRegistration); + + this.mvc + .perform(post(ISSUER.concat(DEFAULT_TOKEN_ENDPOINT_URI)) + .param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .param(OAuth2ParameterNames.SCOPE, "scope1") + .with(httpBasic(clientRegistrationResponse.getClientId(), + clientRegistrationResponse.getClientSecret()))) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.access_token").isNotEmpty()) + .andExpect(jsonPath("$.scope").value("scope1")) + .andReturn(); + } + + @Test + public void requestWhenClientRegistersWithCustomMetadataThenSavedToRegisteredClient() throws Exception { + this.spring.register(CustomClientMetadataConfiguration.class).autowire(); + + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .scope("scope1") + .scope("scope2") + .claim("custom-metadata-name-1", "value-1") + .claim("custom-metadata-name-2", "value-2") + .claim("non-registered-custom-metadata", "value-3") + .build(); + // @formatter:on + + OAuth2ClientRegistration clientRegistrationResponse = registerClient(clientRegistration); + + RegisteredClient registeredClient = this.registeredClientRepository + .findByClientId(clientRegistrationResponse.getClientId()); + + assertClientRegistrationResponse(clientRegistration, clientRegistrationResponse); + assertThat(clientRegistrationResponse.getClaim("custom-metadata-name-1")).isEqualTo("value-1"); + assertThat(clientRegistrationResponse.getClaim("custom-metadata-name-2")).isEqualTo("value-2"); + assertThat(clientRegistrationResponse.getClaim("non-registered-custom-metadata")).isNull(); + + assertThat(registeredClient.getClientSettings().getSetting("custom-metadata-name-1")) + .isEqualTo("value-1"); + assertThat(registeredClient.getClientSettings().getSetting("custom-metadata-name-2")) + .isEqualTo("value-2"); + assertThat(registeredClient.getClientSettings().getSetting("non-registered-custom-metadata")).isNull(); + } + + @Test + public void requestWhenClientRegistersWithSecretExpirationThenClientRegistrationResponse() throws Exception { + this.spring.register(ClientSecretExpirationConfiguration.class).autowire(); + + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .scope("scope1") + .scope("scope2") + .build(); + // @formatter:on + + OAuth2ClientRegistration clientRegistrationResponse = registerClient(clientRegistration); + + Instant expectedSecretExpiryDate = Instant.now().plus(Duration.ofHours(24)); + TemporalUnitWithinOffset allowedDelta = new TemporalUnitWithinOffset(1, ChronoUnit.MINUTES); + + // Returned response contains expiration date + assertThat(clientRegistrationResponse.getClientSecretExpiresAt()).isNotNull() + .isCloseTo(expectedSecretExpiryDate, allowedDelta); + + RegisteredClient registeredClient = this.registeredClientRepository + .findByClientId(clientRegistrationResponse.getClientId()); + + // Persisted RegisteredClient contains expiration date + assertThat(registeredClient).isNotNull(); + assertThat(registeredClient.getClientSecretExpiresAt()).isNotNull() + .isCloseTo(expectedSecretExpiryDate, allowedDelta); + } + + private OAuth2ClientRegistration registerClient(OAuth2ClientRegistration clientRegistration) throws Exception { + // ***** (1) Obtain the "initial" access token used for registering the client + + String clientRegistrationScope = "client.create"; + // @formatter:off + RegisteredClient clientRegistrar = RegisteredClient.withId("client-registrar-1") + .clientId("client-registrar-1") + .clientSecret("{noop}secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .scope(clientRegistrationScope) + .build(); + // @formatter:on + this.registeredClientRepository.save(clientRegistrar); + + MvcResult mvcResult = this.mvc + .perform(post(ISSUER.concat(DEFAULT_TOKEN_ENDPOINT_URI)) + .param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .param(OAuth2ParameterNames.SCOPE, clientRegistrationScope) + .with(httpBasic("client-registrar-1", "secret"))) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.access_token").isNotEmpty()) + .andExpect(jsonPath("$.scope").value(clientRegistrationScope)) + .andReturn(); + + OAuth2AccessToken accessToken = readAccessTokenResponse(mvcResult.getResponse()).getAccessToken(); + + // ***** (2) Register the client + + HttpHeaders httpHeaders = new HttpHeaders(); + httpHeaders.setBearerAuth(accessToken.getTokenValue()); + + // Register the client + mvcResult = this.mvc + .perform(post(ISSUER.concat(DEFAULT_OAUTH2_CLIENT_REGISTRATION_ENDPOINT_URI)).headers(httpHeaders) + .contentType(MediaType.APPLICATION_JSON) + .content(getClientRegistrationRequestContent(clientRegistration))) + .andExpect(status().isCreated()) + .andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store"))) + .andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache"))) + .andReturn(); + + return readClientRegistrationResponse(mvcResult.getResponse()); + } + + private static void assertClientRegistrationResponse(OAuth2ClientRegistration clientRegistrationRequest, + OAuth2ClientRegistration clientRegistrationResponse) { + assertThat(clientRegistrationResponse.getClientId()).isNotNull(); + assertThat(clientRegistrationResponse.getClientIdIssuedAt()).isNotNull(); + assertThat(clientRegistrationResponse.getClientSecret()).isNotNull(); + assertThat(clientRegistrationResponse.getClientSecretExpiresAt()).isNull(); + assertThat(clientRegistrationResponse.getClientName()).isEqualTo(clientRegistrationRequest.getClientName()); + assertThat(clientRegistrationResponse.getRedirectUris()) + .containsExactlyInAnyOrderElementsOf(clientRegistrationRequest.getRedirectUris()); + assertThat(clientRegistrationResponse.getGrantTypes()) + .containsExactlyInAnyOrderElementsOf(clientRegistrationRequest.getGrantTypes()); + assertThat(clientRegistrationResponse.getResponseTypes()) + .containsExactly(OAuth2AuthorizationResponseType.CODE.getValue()); + assertThat(clientRegistrationResponse.getScopes()) + .containsExactlyInAnyOrderElementsOf(clientRegistrationRequest.getScopes()); + assertThat(clientRegistrationResponse.getTokenEndpointAuthenticationMethod()) + .isEqualTo(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue()); + } + + private static OAuth2AccessTokenResponse readAccessTokenResponse(MockHttpServletResponse response) + throws Exception { + MockClientHttpResponse httpResponse = new MockClientHttpResponse(response.getContentAsByteArray(), + HttpStatus.valueOf(response.getStatus())); + return accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse); + } + + private static byte[] getClientRegistrationRequestContent(OAuth2ClientRegistration clientRegistration) + throws Exception { + MockHttpOutputMessage httpRequest = new MockHttpOutputMessage(); + clientRegistrationHttpMessageConverter.write(clientRegistration, null, httpRequest); + return httpRequest.getBodyAsBytes(); + } + + private static OAuth2ClientRegistration readClientRegistrationResponse(MockHttpServletResponse response) + throws Exception { + MockClientHttpResponse httpResponse = new MockClientHttpResponse(response.getContentAsByteArray(), + HttpStatus.valueOf(response.getStatus())); + return clientRegistrationHttpMessageConverter.read(OAuth2ClientRegistration.class, httpResponse); + } + + @EnableWebSecurity + @Configuration(proxyBeanMethods = false) + static class CustomClientRegistrationConfiguration extends AuthorizationServerConfiguration { + + // @formatter:off + @Bean + @Override + public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception { + http + .oauth2AuthorizationServer((authorizationServer) -> + authorizationServer + .clientRegistrationEndpoint((clientRegistration) -> + clientRegistration + .clientRegistrationRequestConverter(authenticationConverter) + .clientRegistrationRequestConverters(authenticationConvertersConsumer) + .authenticationProvider(authenticationProvider) + .authenticationProviders(authenticationProvidersConsumer) + .clientRegistrationResponseHandler(authenticationSuccessHandler) + .errorResponseHandler(authenticationFailureHandler) + ) + ) + .authorizeHttpRequests((authorize) -> + authorize.anyRequest().authenticated() + ); + return http.build(); + } + // @formatter:on + + } + + @EnableWebSecurity + @Configuration(proxyBeanMethods = false) + static class CustomClientMetadataConfiguration extends AuthorizationServerConfiguration { + + // @formatter:off + @Bean + @Override + public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception { + http + .oauth2AuthorizationServer((authorizationServer) -> + authorizationServer + .clientRegistrationEndpoint((clientRegistration) -> + clientRegistration + .authenticationProviders(configureClientRegistrationConverters()) + ) + ) + .authorizeHttpRequests((authorize) -> + authorize.anyRequest().authenticated() + ); + return http.build(); + } + // @formatter:on + + private Consumer> configureClientRegistrationConverters() { + // @formatter:off + return (authenticationProviders) -> + authenticationProviders.forEach((authenticationProvider) -> { + List supportedCustomClientMetadata = List.of("custom-metadata-name-1", "custom-metadata-name-2"); + if (authenticationProvider instanceof OAuth2ClientRegistrationAuthenticationProvider provider) { + provider.setRegisteredClientConverter(new CustomRegisteredClientConverter(supportedCustomClientMetadata)); + provider.setClientRegistrationConverter(new CustomClientRegistrationConverter(supportedCustomClientMetadata)); + } + }); + // @formatter:on + } + + } + + @EnableWebSecurity + @Configuration(proxyBeanMethods = false) + static class ClientSecretExpirationConfiguration extends AuthorizationServerConfiguration { + + // @formatter:off + @Bean + @Override + public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception { + http + .oauth2AuthorizationServer((authorizationServer) -> + authorizationServer + .clientRegistrationEndpoint((clientRegistration) -> + clientRegistration + .authenticationProviders(configureClientRegistrationConverters()) + ) + ) + .authorizeHttpRequests((authorize) -> + authorize.anyRequest().authenticated() + ); + return http.build(); + } + // @formatter:on + + private Consumer> configureClientRegistrationConverters() { + // @formatter:off + return (authenticationProviders) -> + authenticationProviders.forEach((authenticationProvider) -> { + if (authenticationProvider instanceof OAuth2ClientRegistrationAuthenticationProvider provider) { + provider.setRegisteredClientConverter(new ClientSecretExpirationRegisteredClientConverter()); + } + }); + // @formatter:on + } + + } + + @EnableWebSecurity + @Configuration(proxyBeanMethods = false) + static class OpenClientRegistrationConfiguration extends AuthorizationServerConfiguration { + + // @formatter:off + @Bean + @Override + public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception { + http + .oauth2AuthorizationServer((authorizationServer) -> + authorizationServer + .clientRegistrationEndpoint((clientRegistration) -> + clientRegistration + .openRegistrationAllowed(true) + ) + ) + .authorizeHttpRequests((authorize) -> + authorize + .requestMatchers("/**/oauth2/register").permitAll() + .anyRequest().authenticated() + ); + return http.build(); + } + // @formatter:on + + } + + @EnableWebSecurity + @Configuration(proxyBeanMethods = false) + static class AuthorizationServerConfiguration { + + // @formatter:off + @Bean + SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception { + http + .oauth2AuthorizationServer((authorizationServer) -> + authorizationServer + .clientRegistrationEndpoint(Customizer.withDefaults()) + ) + .authorizeHttpRequests((authorize) -> + authorize.anyRequest().authenticated() + ); + return http.build(); + } + // @formatter:on + + @Bean + RegisteredClientRepository registeredClientRepository(JdbcOperations jdbcOperations) { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + RegisteredClientParametersMapper registeredClientParametersMapper = new RegisteredClientParametersMapper(); + JdbcRegisteredClientRepository registeredClientRepository = new JdbcRegisteredClientRepository( + jdbcOperations); + registeredClientRepository.setRegisteredClientParametersMapper(registeredClientParametersMapper); + registeredClientRepository.save(registeredClient); + return registeredClientRepository; + } + + @Bean + OAuth2AuthorizationService authorizationService(JdbcOperations jdbcOperations, + RegisteredClientRepository registeredClientRepository) { + return new JdbcOAuth2AuthorizationService(jdbcOperations, registeredClientRepository); + } + + @Bean + JdbcOperations jdbcOperations() { + return new JdbcTemplate(db); + } + + @Bean + JWKSource jwkSource() { + return jwkSource; + } + + @Bean + JwtDecoder jwtDecoder(JWKSource jwkSource) { + return OAuth2AuthorizationServerConfiguration.jwtDecoder(jwkSource); + } + + @Bean + AuthorizationServerSettings authorizationServerSettings() { + return AuthorizationServerSettings.builder().multipleIssuersAllowed(true).build(); + } + + @Bean + PasswordEncoder passwordEncoder() { + return PasswordEncoderFactories.createDelegatingPasswordEncoder(); + } + + } + + private static final class CustomRegisteredClientConverter + implements Converter { + + private final OAuth2ClientRegistrationRegisteredClientConverter delegate = new OAuth2ClientRegistrationRegisteredClientConverter(); + + private final List supportedCustomClientMetadata; + + private CustomRegisteredClientConverter(List supportedCustomClientMetadata) { + this.supportedCustomClientMetadata = supportedCustomClientMetadata; + } + + @Override + public RegisteredClient convert(OAuth2ClientRegistration clientRegistration) { + RegisteredClient registeredClient = this.delegate.convert(clientRegistration); + + ClientSettings.Builder clientSettingsBuilder = ClientSettings + .withSettings(registeredClient.getClientSettings().getSettings()); + if (!CollectionUtils.isEmpty(this.supportedCustomClientMetadata)) { + clientRegistration.getClaims().forEach((claim, value) -> { + if (this.supportedCustomClientMetadata.contains(claim)) { + clientSettingsBuilder.setting(claim, value); + } + }); + } + + return RegisteredClient.from(registeredClient).clientSettings(clientSettingsBuilder.build()).build(); + } + + } + + private static final class CustomClientRegistrationConverter + implements Converter { + + private final RegisteredClientOAuth2ClientRegistrationConverter delegate = new RegisteredClientOAuth2ClientRegistrationConverter(); + + private final List supportedCustomClientMetadata; + + private CustomClientRegistrationConverter(List supportedCustomClientMetadata) { + this.supportedCustomClientMetadata = supportedCustomClientMetadata; + } + + @Override + public OAuth2ClientRegistration convert(RegisteredClient registeredClient) { + OAuth2ClientRegistration clientRegistration = this.delegate.convert(registeredClient); + + Map clientMetadata = new HashMap<>(clientRegistration.getClaims()); + if (!CollectionUtils.isEmpty(this.supportedCustomClientMetadata)) { + Map clientSettings = registeredClient.getClientSettings().getSettings(); + this.supportedCustomClientMetadata.forEach((customClaim) -> { + if (clientSettings.containsKey(customClaim)) { + clientMetadata.put(customClaim, clientSettings.get(customClaim)); + } + }); + } + + return OAuth2ClientRegistration.withClaims(clientMetadata).build(); + } + + } + + /** + * This customization adds client secret expiration time by setting + * {@code RegisteredClient.clientSecretExpiresAt} during + * {@code OAuth2ClientRegistration} -> {@code RegisteredClient} conversion + */ + private static final class ClientSecretExpirationRegisteredClientConverter + implements Converter { + + private static final OAuth2ClientRegistrationRegisteredClientConverter delegate = new OAuth2ClientRegistrationRegisteredClientConverter(); + + @Override + public RegisteredClient convert(OAuth2ClientRegistration clientRegistration) { + RegisteredClient registeredClient = delegate.convert(clientRegistration); + RegisteredClient.Builder registeredClientBuilder = RegisteredClient.from(registeredClient); + + Instant clientSecretExpiresAt = Instant.now().plus(Duration.ofHours(24)); + registeredClientBuilder.clientSecretExpiresAt(clientSecretExpiresAt); + + return registeredClientBuilder.build(); + } + + } + +} diff --git a/config/src/test/resources/serialized/7.0.x/org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration.serialized b/config/src/test/resources/serialized/7.0.x/org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration.serialized new file mode 100644 index 0000000000..3755c77ea0 Binary files /dev/null and b/config/src/test/resources/serialized/7.0.x/org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration.serialized differ diff --git a/config/src/test/resources/serialized/7.0.x/org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientRegistrationAuthenticationToken.serialized b/config/src/test/resources/serialized/7.0.x/org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientRegistrationAuthenticationToken.serialized new file mode 100644 index 0000000000..4caafa608b Binary files /dev/null and b/config/src/test/resources/serialized/7.0.x/org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientRegistrationAuthenticationToken.serialized differ diff --git a/config/src/test/resources/serialized/7.0.x/org.springframework.security.oauth2.server.authorization.oidc.OidcClientRegistration.serialized b/config/src/test/resources/serialized/7.0.x/org.springframework.security.oauth2.server.authorization.oidc.OidcClientRegistration.serialized index 2565b5f56c..dcb4247260 100644 Binary files a/config/src/test/resources/serialized/7.0.x/org.springframework.security.oauth2.server.authorization.oidc.OidcClientRegistration.serialized and b/config/src/test/resources/serialized/7.0.x/org.springframework.security.oauth2.server.authorization.oidc.OidcClientRegistration.serialized differ diff --git a/config/src/test/resources/serialized/7.0.x/org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken.serialized b/config/src/test/resources/serialized/7.0.x/org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken.serialized index 280e98ce65..e8f554db5c 100644 Binary files a/config/src/test/resources/serialized/7.0.x/org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken.serialized and b/config/src/test/resources/serialized/7.0.x/org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken.serialized differ diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/AbstractOAuth2ClientRegistration.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/AbstractOAuth2ClientRegistration.java new file mode 100644 index 0000000000..f020eef1a7 --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/AbstractOAuth2ClientRegistration.java @@ -0,0 +1,367 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization; + +import java.io.Serial; +import java.io.Serializable; +import java.net.URI; +import java.net.URL; +import java.time.Instant; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import org.springframework.util.Assert; + +/** + * A base representation of an OAuth 2.0 Client Registration Request and Response, which + * is sent to and returned from the Client Registration Endpoint, and contains a set of + * claims about the Client's Registration information. The claims are defined by the OAuth + * 2.0 Dynamic Client Registration Protocol specification. + * + * @author Joe Grandja + * @since 7.0 + * @see OAuth2ClientMetadataClaimAccessor + * @see 3.1. Client Registration + * Request + * @see 3.2.1. Client + * Registration Response + */ +public abstract class AbstractOAuth2ClientRegistration implements OAuth2ClientMetadataClaimAccessor, Serializable { + + @Serial + private static final long serialVersionUID = 8042785346181558593L; + + private final Map claims; + + protected AbstractOAuth2ClientRegistration(Map claims) { + Assert.notEmpty(claims, "claims cannot be empty"); + this.claims = Collections.unmodifiableMap(new LinkedHashMap<>(claims)); + } + + /** + * Returns the metadata as claims. + * @return a {@code Map} of the metadata as claims + */ + @Override + public Map getClaims() { + return this.claims; + } + + /** + * A builder for subclasses of {@link AbstractOAuth2ClientRegistration}. + * + * @param the type of object + * @param the type of the builder + */ + protected abstract static class AbstractBuilder> { + + private final Map claims = new LinkedHashMap<>(); + + protected AbstractBuilder() { + } + + protected Map getClaims() { + return this.claims; + } + + @SuppressWarnings("unchecked") + protected final B getThis() { + // avoid unchecked casts in subclasses by using "getThis()" instead of "(B) + // this" + return (B) this; + } + + /** + * Sets the Client Identifier, REQUIRED. + * @param clientId the Client Identifier + * @return the {@link AbstractBuilder} for further configuration + */ + public B clientId(String clientId) { + return claim(OAuth2ClientMetadataClaimNames.CLIENT_ID, clientId); + } + + /** + * Sets the time at which the Client Identifier was issued, OPTIONAL. + * @param clientIdIssuedAt the time at which the Client Identifier was issued + * @return the {@link AbstractBuilder} for further configuration + */ + public B clientIdIssuedAt(Instant clientIdIssuedAt) { + return claim(OAuth2ClientMetadataClaimNames.CLIENT_ID_ISSUED_AT, clientIdIssuedAt); + } + + /** + * Sets the Client Secret, OPTIONAL. + * @param clientSecret the Client Secret + * @return the {@link AbstractBuilder} for further configuration + */ + public B clientSecret(String clientSecret) { + return claim(OAuth2ClientMetadataClaimNames.CLIENT_SECRET, clientSecret); + } + + /** + * Sets the time at which the {@code client_secret} will expire or {@code null} if + * it will not expire, REQUIRED if {@code client_secret} was issued. + * @param clientSecretExpiresAt the time at which the {@code client_secret} will + * expire or {@code null} if it will not expire + * @return the {@link AbstractBuilder} for further configuration + */ + public B clientSecretExpiresAt(Instant clientSecretExpiresAt) { + return claim(OAuth2ClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT, clientSecretExpiresAt); + } + + /** + * Sets the name of the Client to be presented to the End-User, OPTIONAL. + * @param clientName the name of the Client to be presented to the End-User + * @return the {@link AbstractBuilder} for further configuration + */ + public B clientName(String clientName) { + return claim(OAuth2ClientMetadataClaimNames.CLIENT_NAME, clientName); + } + + /** + * Add the redirection {@code URI} used by the Client, REQUIRED for redirect-based + * flows. + * @param redirectUri the redirection {@code URI} used by the Client + * @return the {@link AbstractBuilder} for further configuration + */ + public B redirectUri(String redirectUri) { + addClaimToClaimList(OAuth2ClientMetadataClaimNames.REDIRECT_URIS, redirectUri); + return getThis(); + } + + /** + * A {@code Consumer} of the redirection {@code URI} values used by the Client, + * allowing the ability to add, replace, or remove, REQUIRED for redirect-based + * flows. + * @param redirectUrisConsumer a {@code Consumer} of the redirection {@code URI} + * values used by the Client + * @return the {@link AbstractBuilder} for further configuration + */ + public B redirectUris(Consumer> redirectUrisConsumer) { + acceptClaimValues(OAuth2ClientMetadataClaimNames.REDIRECT_URIS, redirectUrisConsumer); + return getThis(); + } + + /** + * Sets the authentication method used by the Client for the Token Endpoint, + * OPTIONAL. + * @param tokenEndpointAuthenticationMethod the authentication method used by the + * Client for the Token Endpoint + * @return the {@link AbstractBuilder} for further configuration + */ + public B tokenEndpointAuthenticationMethod(String tokenEndpointAuthenticationMethod) { + return claim(OAuth2ClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD, tokenEndpointAuthenticationMethod); + } + + /** + * Add the OAuth 2.0 {@code grant_type} that the Client will restrict itself to + * using, OPTIONAL. + * @param grantType the OAuth 2.0 {@code grant_type} that the Client will restrict + * itself to using + * @return the {@link AbstractBuilder} for further configuration + */ + public B grantType(String grantType) { + addClaimToClaimList(OAuth2ClientMetadataClaimNames.GRANT_TYPES, grantType); + return getThis(); + } + + /** + * A {@code Consumer} of the OAuth 2.0 {@code grant_type} values that the Client + * will restrict itself to using, allowing the ability to add, replace, or remove, + * OPTIONAL. + * @param grantTypesConsumer a {@code Consumer} of the OAuth 2.0 + * {@code grant_type} values that the Client will restrict itself to using + * @return the {@link AbstractBuilder} for further configuration + */ + public B grantTypes(Consumer> grantTypesConsumer) { + acceptClaimValues(OAuth2ClientMetadataClaimNames.GRANT_TYPES, grantTypesConsumer); + return getThis(); + } + + /** + * Add the OAuth 2.0 {@code response_type} that the Client will restrict itself to + * using, OPTIONAL. + * @param responseType the OAuth 2.0 {@code response_type} that the Client will + * restrict itself to using + * @return the {@link AbstractBuilder} for further configuration + */ + public B responseType(String responseType) { + addClaimToClaimList(OAuth2ClientMetadataClaimNames.RESPONSE_TYPES, responseType); + return getThis(); + } + + /** + * A {@code Consumer} of the OAuth 2.0 {@code response_type} values that the + * Client will restrict itself to using, allowing the ability to add, replace, or + * remove, OPTIONAL. + * @param responseTypesConsumer a {@code Consumer} of the OAuth 2.0 + * {@code response_type} values that the Client will restrict itself to using + * @return the {@link AbstractBuilder} for further configuration + */ + public B responseTypes(Consumer> responseTypesConsumer) { + acceptClaimValues(OAuth2ClientMetadataClaimNames.RESPONSE_TYPES, responseTypesConsumer); + return getThis(); + } + + /** + * Add the OAuth 2.0 {@code scope} that the Client will restrict itself to using, + * OPTIONAL. + * @param scope the OAuth 2.0 {@code scope} that the Client will restrict itself + * to using + * @return the {@link AbstractBuilder} for further configuration + */ + public B scope(String scope) { + addClaimToClaimList(OAuth2ClientMetadataClaimNames.SCOPE, scope); + return getThis(); + } + + /** + * A {@code Consumer} of the OAuth 2.0 {@code scope} values that the Client will + * restrict itself to using, allowing the ability to add, replace, or remove, + * OPTIONAL. + * @param scopesConsumer a {@code Consumer} of the OAuth 2.0 {@code scope} values + * that the Client will restrict itself to using + * @return the {@link AbstractBuilder} for further configuration + */ + public B scopes(Consumer> scopesConsumer) { + acceptClaimValues(OAuth2ClientMetadataClaimNames.SCOPE, scopesConsumer); + return getThis(); + } + + /** + * Sets the {@code URL} for the Client's JSON Web Key Set, OPTIONAL. + * @param jwkSetUrl the {@code URL} for the Client's JSON Web Key Set + * @return the {@link AbstractBuilder} for further configuration + */ + public B jwkSetUrl(String jwkSetUrl) { + return claim(OAuth2ClientMetadataClaimNames.JWKS_URI, jwkSetUrl); + } + + /** + * Sets the claim. + * @param name the claim name + * @param value the claim value + * @return the {@link AbstractBuilder} for further configuration + */ + public B claim(String name, Object value) { + Assert.hasText(name, "name cannot be empty"); + Assert.notNull(value, "value cannot be null"); + this.claims.put(name, value); + return getThis(); + } + + /** + * Provides access to every {@link #claim(String, Object)} declared so far + * allowing the ability to add, replace, or remove. + * @param claimsConsumer a {@code Consumer} of the claims + * @return the {@link AbstractBuilder} for further configurations + */ + public B claims(Consumer> claimsConsumer) { + claimsConsumer.accept(this.claims); + return getThis(); + } + + /** + * Validate the claims and build the {@link AbstractOAuth2ClientRegistration}. + * @return the {@link AbstractOAuth2ClientRegistration} + */ + public abstract T build(); + + protected void validate() { + if (this.claims.get(OAuth2ClientMetadataClaimNames.CLIENT_ID_ISSUED_AT) != null + || this.claims.get(OAuth2ClientMetadataClaimNames.CLIENT_SECRET) != null) { + Assert.notNull(this.claims.get(OAuth2ClientMetadataClaimNames.CLIENT_ID), "client_id cannot be null"); + } + if (this.claims.get(OAuth2ClientMetadataClaimNames.CLIENT_ID_ISSUED_AT) != null) { + Assert.isInstanceOf(Instant.class, this.claims.get(OAuth2ClientMetadataClaimNames.CLIENT_ID_ISSUED_AT), + "client_id_issued_at must be of type Instant"); + } + if (this.claims.get(OAuth2ClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT) != null) { + Assert.notNull(this.claims.get(OAuth2ClientMetadataClaimNames.CLIENT_SECRET), + "client_secret cannot be null"); + Assert.isInstanceOf(Instant.class, + this.claims.get(OAuth2ClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT), + "client_secret_expires_at must be of type Instant"); + } + if (this.claims.get(OAuth2ClientMetadataClaimNames.REDIRECT_URIS) != null) { + Assert.isInstanceOf(List.class, this.claims.get(OAuth2ClientMetadataClaimNames.REDIRECT_URIS), + "redirect_uris must be of type List"); + Assert.notEmpty((List) this.claims.get(OAuth2ClientMetadataClaimNames.REDIRECT_URIS), + "redirect_uris cannot be empty"); + } + if (this.claims.get(OAuth2ClientMetadataClaimNames.GRANT_TYPES) != null) { + Assert.isInstanceOf(List.class, this.claims.get(OAuth2ClientMetadataClaimNames.GRANT_TYPES), + "grant_types must be of type List"); + Assert.notEmpty((List) this.claims.get(OAuth2ClientMetadataClaimNames.GRANT_TYPES), + "grant_types cannot be empty"); + } + if (this.claims.get(OAuth2ClientMetadataClaimNames.RESPONSE_TYPES) != null) { + Assert.isInstanceOf(List.class, this.claims.get(OAuth2ClientMetadataClaimNames.RESPONSE_TYPES), + "response_types must be of type List"); + Assert.notEmpty((List) this.claims.get(OAuth2ClientMetadataClaimNames.RESPONSE_TYPES), + "response_types cannot be empty"); + } + if (this.claims.get(OAuth2ClientMetadataClaimNames.SCOPE) != null) { + Assert.isInstanceOf(List.class, this.claims.get(OAuth2ClientMetadataClaimNames.SCOPE), + "scope must be of type List"); + Assert.notEmpty((List) this.claims.get(OAuth2ClientMetadataClaimNames.SCOPE), + "scope cannot be empty"); + } + if (this.claims.get(OAuth2ClientMetadataClaimNames.JWKS_URI) != null) { + validateURL(this.claims.get(OAuth2ClientMetadataClaimNames.JWKS_URI), "jwksUri must be a valid URL"); + } + } + + @SuppressWarnings("unchecked") + private void addClaimToClaimList(String name, String value) { + Assert.hasText(name, "name cannot be empty"); + Assert.notNull(value, "value cannot be null"); + this.claims.computeIfAbsent(name, (k) -> new LinkedList()); + ((List) this.claims.get(name)).add(value); + } + + @SuppressWarnings("unchecked") + private void acceptClaimValues(String name, Consumer> valuesConsumer) { + Assert.hasText(name, "name cannot be empty"); + Assert.notNull(valuesConsumer, "valuesConsumer cannot be null"); + this.claims.computeIfAbsent(name, (k) -> new LinkedList()); + List values = (List) this.claims.get(name); + valuesConsumer.accept(values); + } + + private static void validateURL(Object url, String errorMessage) { + if (URL.class.isAssignableFrom(url.getClass())) { + return; + } + + try { + new URI(url.toString()).toURL(); + } + catch (Exception ex) { + throw new IllegalArgumentException(errorMessage, ex); + } + } + + } + +} diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2ClientMetadataClaimAccessor.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2ClientMetadataClaimAccessor.java new file mode 100644 index 0000000000..655ca7b7c6 --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2ClientMetadataClaimAccessor.java @@ -0,0 +1,138 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization; + +import java.net.URL; +import java.time.Instant; +import java.util.List; + +import org.springframework.security.oauth2.core.ClaimAccessor; + +/** + * A {@link ClaimAccessor} for the claims that are contained in the OAuth 2.0 Client + * Registration Request and Response. + * + * @author Joe Grandja + * @since 7.0 + * @see ClaimAccessor + * @see OAuth2ClientMetadataClaimNames + * @see OAuth2ClientRegistration + * @see 2. Client Metadata + */ +public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor { + + /** + * Returns the Client Identifier {@code (client_id)}. + * @return the Client Identifier + */ + default String getClientId() { + return getClaimAsString(OAuth2ClientMetadataClaimNames.CLIENT_ID); + } + + /** + * Returns the time at which the Client Identifier was issued + * {@code (client_id_issued_at)}. + * @return the time at which the Client Identifier was issued + */ + default Instant getClientIdIssuedAt() { + return getClaimAsInstant(OAuth2ClientMetadataClaimNames.CLIENT_ID_ISSUED_AT); + } + + /** + * Returns the Client Secret {@code (client_secret)}. + * @return the Client Secret + */ + default String getClientSecret() { + return getClaimAsString(OAuth2ClientMetadataClaimNames.CLIENT_SECRET); + } + + /** + * Returns the time at which the {@code client_secret} will expire + * {@code (client_secret_expires_at)}. + * @return the time at which the {@code client_secret} will expire + */ + default Instant getClientSecretExpiresAt() { + return getClaimAsInstant(OAuth2ClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT); + } + + /** + * Returns the name of the Client to be presented to the End-User + * {@code (client_name)}. + * @return the name of the Client to be presented to the End-User + */ + default String getClientName() { + return getClaimAsString(OAuth2ClientMetadataClaimNames.CLIENT_NAME); + } + + /** + * Returns the redirection {@code URI} values used by the Client + * {@code (redirect_uris)}. + * @return the redirection {@code URI} values used by the Client + */ + default List getRedirectUris() { + return getClaimAsStringList(OAuth2ClientMetadataClaimNames.REDIRECT_URIS); + } + + /** + * Returns the authentication method used by the Client for the Token Endpoint + * {@code (token_endpoint_auth_method)}. + * @return the authentication method used by the Client for the Token Endpoint + */ + default String getTokenEndpointAuthenticationMethod() { + return getClaimAsString(OAuth2ClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD); + } + + /** + * Returns the OAuth 2.0 {@code grant_type} values that the Client will restrict + * itself to using {@code (grant_types)}. + * @return the OAuth 2.0 {@code grant_type} values that the Client will restrict + * itself to using + */ + default List getGrantTypes() { + return getClaimAsStringList(OAuth2ClientMetadataClaimNames.GRANT_TYPES); + } + + /** + * Returns the OAuth 2.0 {@code response_type} values that the Client will restrict + * itself to using {@code (response_types)}. + * @return the OAuth 2.0 {@code response_type} values that the Client will restrict + * itself to using + */ + default List getResponseTypes() { + return getClaimAsStringList(OAuth2ClientMetadataClaimNames.RESPONSE_TYPES); + } + + /** + * Returns the OAuth 2.0 {@code scope} values that the Client will restrict itself to + * using {@code (scope)}. + * @return the OAuth 2.0 {@code scope} values that the Client will restrict itself to + * using + */ + default List getScopes() { + return getClaimAsStringList(OAuth2ClientMetadataClaimNames.SCOPE); + } + + /** + * Returns the {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}. + * @return the {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)} + */ + default URL getJwkSetUrl() { + return getClaimAsURL(OAuth2ClientMetadataClaimNames.JWKS_URI); + } + +} diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2ClientMetadataClaimNames.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2ClientMetadataClaimNames.java new file mode 100644 index 0000000000..7ec6ad4358 --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2ClientMetadataClaimNames.java @@ -0,0 +1,93 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization; + +/** + * The names of the claims defined by OAuth 2.0 Dynamic Client Registration Protocol that + * are contained in the OAuth 2.0 Client Registration Request and Response. + * + * @author Joe Grandja + * @since 7.0 + * @see 2. Client Metadata + */ +public class OAuth2ClientMetadataClaimNames { + + /** + * {@code client_id} - the Client Identifier + */ + public static final String CLIENT_ID = "client_id"; + + /** + * {@code client_id_issued_at} - the time at which the Client Identifier was issued + */ + public static final String CLIENT_ID_ISSUED_AT = "client_id_issued_at"; + + /** + * {@code client_secret} - the Client Secret + */ + public static final String CLIENT_SECRET = "client_secret"; + + /** + * {@code client_secret_expires_at} - the time at which the {@code client_secret} will + * expire or 0 if it will not expire + */ + public static final String CLIENT_SECRET_EXPIRES_AT = "client_secret_expires_at"; + + /** + * {@code client_name} - the name of the Client to be presented to the End-User + */ + public static final String CLIENT_NAME = "client_name"; + + /** + * {@code redirect_uris} - the redirection {@code URI} values used by the Client + */ + public static final String REDIRECT_URIS = "redirect_uris"; + + /** + * {@code token_endpoint_auth_method} - the authentication method used by the Client + * for the Token Endpoint + */ + public static final String TOKEN_ENDPOINT_AUTH_METHOD = "token_endpoint_auth_method"; + + /** + * {@code grant_types} - the OAuth 2.0 {@code grant_type} values that the Client will + * restrict itself to using + */ + public static final String GRANT_TYPES = "grant_types"; + + /** + * {@code response_types} - the OAuth 2.0 {@code response_type} values that the Client + * will restrict itself to using + */ + public static final String RESPONSE_TYPES = "response_types"; + + /** + * {@code scope} - a space-separated list of OAuth 2.0 {@code scope} values that the + * Client will restrict itself to using + */ + public static final String SCOPE = "scope"; + + /** + * {@code jwks_uri} - the {@code URL} for the Client's JSON Web Key Set + */ + public static final String JWKS_URI = "jwks_uri"; + + protected OAuth2ClientMetadataClaimNames() { + } + +} diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2ClientRegistration.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2ClientRegistration.java new file mode 100644 index 0000000000..f6627f0996 --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2ClientRegistration.java @@ -0,0 +1,87 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization; + +import java.io.Serial; +import java.util.Map; + +import org.springframework.util.Assert; + +/** + * A representation of an OAuth 2.0 Client Registration Request and Response, which is + * sent to and returned from the Client Registration Endpoint, and contains a set of + * claims about the Client's Registration information. The claims are defined by the OAuth + * 2.0 Dynamic Client Registration Protocol specification. + * + * @author Joe Grandja + * @since 7.0 + * @see AbstractOAuth2ClientRegistration + * @see 3.1. Client Registration + * Request + * @see 3.2.1. Client + * Registration Response + */ +public final class OAuth2ClientRegistration extends AbstractOAuth2ClientRegistration { + + @Serial + private static final long serialVersionUID = 283805553286847831L; + + private OAuth2ClientRegistration(Map claims) { + super(claims); + } + + /** + * Constructs a new {@link Builder} with empty claims. + * @return the {@link Builder} + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Constructs a new {@link Builder} with the provided claims. + * @param claims the claims to initialize the builder + * @return the {@link Builder} + */ + public static Builder withClaims(Map claims) { + Assert.notEmpty(claims, "claims cannot be empty"); + return new Builder().claims((c) -> c.putAll(claims)); + } + + /** + * Helps configure an {@link OAuth2ClientRegistration}. + */ + public static final class Builder extends AbstractBuilder { + + private Builder() { + } + + /** + * Validate the claims and build the {@link OAuth2ClientRegistration}. + * @return the {@link OAuth2ClientRegistration} + */ + @Override + public OAuth2ClientRegistration build() { + validate(); + return new OAuth2ClientRegistration(getClaims()); + } + + } + +} diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientRegistrationAuthenticationProvider.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientRegistrationAuthenticationProvider.java new file mode 100644 index 0000000000..54a910ac0a --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientRegistrationAuthenticationProvider.java @@ -0,0 +1,305 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.crypto.factory.PasswordEncoderFactories; +import org.springframework.security.crypto.password.PasswordEncoder; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientMetadataClaimNames; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.converter.OAuth2ClientRegistrationRegisteredClientConverter; +import org.springframework.security.oauth2.server.authorization.converter.RegisteredClientOAuth2ClientRegistrationConverter; +import org.springframework.security.oauth2.server.resource.authentication.AbstractOAuth2TokenAuthenticationToken; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +/** + * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Dynamic Client + * Registration Endpoint. + * + * @author Joe Grandja + * @since 7.0 + * @see RegisteredClientRepository + * @see OAuth2AuthorizationService + * @see OAuth2ClientRegistrationAuthenticationToken + * @see PasswordEncoder + * @see 3. Client + * Registration Endpoint + */ +public final class OAuth2ClientRegistrationAuthenticationProvider implements AuthenticationProvider { + + private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc7591#section-3.2.2"; + + private static final String DEFAULT_CLIENT_REGISTRATION_AUTHORIZED_SCOPE = "client.create"; + + private final Log logger = LogFactory.getLog(getClass()); + + private final RegisteredClientRepository registeredClientRepository; + + private final OAuth2AuthorizationService authorizationService; + + private Converter clientRegistrationConverter; + + private Converter registeredClientConverter; + + private PasswordEncoder passwordEncoder; + + private boolean openRegistrationAllowed; + + /** + * Constructs an {@code OAuth2ClientRegistrationAuthenticationProvider} using the + * provided parameters. + * @param registeredClientRepository the repository of registered clients + */ + public OAuth2ClientRegistrationAuthenticationProvider(RegisteredClientRepository registeredClientRepository, + OAuth2AuthorizationService authorizationService) { + Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); + Assert.notNull(authorizationService, "authorizationService cannot be null"); + this.registeredClientRepository = registeredClientRepository; + this.authorizationService = authorizationService; + this.clientRegistrationConverter = new RegisteredClientOAuth2ClientRegistrationConverter(); + this.registeredClientConverter = new OAuth2ClientRegistrationRegisteredClientConverter(); + this.passwordEncoder = PasswordEncoderFactories.createDelegatingPasswordEncoder(); + } + + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + OAuth2ClientRegistrationAuthenticationToken clientRegistrationAuthentication = (OAuth2ClientRegistrationAuthenticationToken) authentication; + + // Check if "initial" access token is not provided + AbstractOAuth2TokenAuthenticationToken accessTokenAuthentication = null; + if (clientRegistrationAuthentication.getPrincipal() != null && AbstractOAuth2TokenAuthenticationToken.class + .isAssignableFrom(clientRegistrationAuthentication.getPrincipal().getClass())) { + accessTokenAuthentication = (AbstractOAuth2TokenAuthenticationToken) clientRegistrationAuthentication + .getPrincipal(); + } + if (accessTokenAuthentication == null) { + if (this.openRegistrationAllowed) { + return registerClient(clientRegistrationAuthentication, null); + } + else { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + } + } + + // Validate the "initial" access token + if (!accessTokenAuthentication.isAuthenticated()) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + } + + String accessTokenValue = accessTokenAuthentication.getToken().getTokenValue(); + OAuth2Authorization authorization = this.authorizationService.findByToken(accessTokenValue, + OAuth2TokenType.ACCESS_TOKEN); + if (authorization == null) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + } + + if (this.logger.isTraceEnabled()) { + this.logger.trace("Retrieved authorization with initial access token"); + } + + OAuth2Authorization.Token authorizedAccessToken = authorization.getAccessToken(); + if (!authorizedAccessToken.isActive()) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + } + checkScope(authorizedAccessToken, Collections.singleton(DEFAULT_CLIENT_REGISTRATION_AUTHORIZED_SCOPE)); + + return registerClient(clientRegistrationAuthentication, authorization); + } + + @Override + public boolean supports(Class authentication) { + return OAuth2ClientRegistrationAuthenticationToken.class.isAssignableFrom(authentication); + } + + /** + * Sets the {@link Converter} used for converting an {@link OAuth2ClientRegistration} + * to a {@link RegisteredClient}. + * @param registeredClientConverter the {@link Converter} used for converting an + * {@link OAuth2ClientRegistration} to a {@link RegisteredClient} + */ + public void setRegisteredClientConverter( + Converter registeredClientConverter) { + Assert.notNull(registeredClientConverter, "registeredClientConverter cannot be null"); + this.registeredClientConverter = registeredClientConverter; + } + + /** + * Sets the {@link Converter} used for converting a {@link RegisteredClient} to an + * {@link OAuth2ClientRegistration}. + * @param clientRegistrationConverter the {@link Converter} used for converting a + * {@link RegisteredClient} to an {@link OAuth2ClientRegistration} + */ + public void setClientRegistrationConverter( + Converter clientRegistrationConverter) { + Assert.notNull(clientRegistrationConverter, "clientRegistrationConverter cannot be null"); + this.clientRegistrationConverter = clientRegistrationConverter; + } + + /** + * Sets the {@link PasswordEncoder} used to encode the + * {@link RegisteredClient#getClientSecret() client secret}. If not set, the client + * secret will be encoded using + * {@link PasswordEncoderFactories#createDelegatingPasswordEncoder()}. + * @param passwordEncoder the {@link PasswordEncoder} used to encode the client secret + */ + public void setPasswordEncoder(PasswordEncoder passwordEncoder) { + Assert.notNull(passwordEncoder, "passwordEncoder cannot be null"); + this.passwordEncoder = passwordEncoder; + } + + /** + * Set to {@code true} if open client registration (with no initial access token) is + * allowed. The default is {@code false}. + * @param openRegistrationAllowed {@code true} if open client registration is allowed, + * {@code false} otherwise + */ + public void setOpenRegistrationAllowed(boolean openRegistrationAllowed) { + this.openRegistrationAllowed = openRegistrationAllowed; + } + + private OAuth2ClientRegistrationAuthenticationToken registerClient( + OAuth2ClientRegistrationAuthenticationToken clientRegistrationAuthentication, + OAuth2Authorization authorization) { + + if (!isValidRedirectUris(clientRegistrationAuthentication.getClientRegistration().getRedirectUris())) { + throwInvalidClientRegistration(OAuth2ErrorCodes.INVALID_REDIRECT_URI, + OAuth2ClientMetadataClaimNames.REDIRECT_URIS); + } + + if (this.logger.isTraceEnabled()) { + this.logger.trace("Validated client registration request parameters"); + } + + RegisteredClient registeredClient = this.registeredClientConverter + .convert(clientRegistrationAuthentication.getClientRegistration()); + + if (StringUtils.hasText(registeredClient.getClientSecret())) { + // Encode the client secret + RegisteredClient updatedRegisteredClient = RegisteredClient.from(registeredClient) + .clientSecret(this.passwordEncoder.encode(registeredClient.getClientSecret())) + .build(); + this.registeredClientRepository.save(updatedRegisteredClient); + if (ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue() + .equals(clientRegistrationAuthentication.getClientRegistration() + .getTokenEndpointAuthenticationMethod())) { + // Return the hashed client_secret + registeredClient = updatedRegisteredClient; + } + } + else { + this.registeredClientRepository.save(registeredClient); + } + + if (this.logger.isTraceEnabled()) { + this.logger.trace("Saved registered client"); + } + + if (authorization != null) { + // Invalidate the "initial" access token as it can only be used once + OAuth2Authorization.Builder builder = OAuth2Authorization.from(authorization) + .invalidate(authorization.getAccessToken().getToken()); + if (authorization.getRefreshToken() != null) { + builder.invalidate(authorization.getRefreshToken().getToken()); + } + authorization = builder.build(); + this.authorizationService.save(authorization); + + if (this.logger.isTraceEnabled()) { + this.logger.trace("Saved authorization with invalidated initial access token"); + } + } + + OAuth2ClientRegistration clientRegistration = this.clientRegistrationConverter.convert(registeredClient); + + if (this.logger.isTraceEnabled()) { + this.logger.trace("Authenticated client registration request"); + } + + OAuth2ClientRegistrationAuthenticationToken clientRegistrationAuthenticationResult = new OAuth2ClientRegistrationAuthenticationToken( + (Authentication) clientRegistrationAuthentication.getPrincipal(), clientRegistration); + clientRegistrationAuthenticationResult.setDetails(clientRegistrationAuthentication.getDetails()); + return clientRegistrationAuthenticationResult; + } + + @SuppressWarnings("unchecked") + private static void checkScope(OAuth2Authorization.Token authorizedAccessToken, + Set requiredScope) { + Collection authorizedScope = Collections.emptySet(); + if (authorizedAccessToken.getClaims().containsKey(OAuth2ParameterNames.SCOPE)) { + authorizedScope = (Collection) authorizedAccessToken.getClaims().get(OAuth2ParameterNames.SCOPE); + } + if (!authorizedScope.containsAll(requiredScope)) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); + } + else if (authorizedScope.size() != requiredScope.size()) { + // Restrict the access token to only contain the required scope + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + } + } + + private static boolean isValidRedirectUris(List redirectUris) { + if (CollectionUtils.isEmpty(redirectUris)) { + return true; + } + + for (String redirectUri : redirectUris) { + try { + URI validRedirectUri = new URI(redirectUri); + if (validRedirectUri.getFragment() != null) { + return false; + } + } + catch (URISyntaxException ex) { + return false; + } + } + + return true; + } + + private static void throwInvalidClientRegistration(String errorCode, String fieldName) { + OAuth2Error error = new OAuth2Error(errorCode, "Invalid Client Registration: " + fieldName, ERROR_URI); + throw new OAuth2AuthenticationException(error); + } + +} diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientRegistrationAuthenticationToken.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientRegistrationAuthenticationToken.java new file mode 100644 index 0000000000..5204182052 --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientRegistrationAuthenticationToken.java @@ -0,0 +1,84 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.io.Serial; +import java.util.Collections; + +import org.springframework.lang.Nullable; +import org.springframework.security.authentication.AbstractAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration; +import org.springframework.util.Assert; + +/** + * An {@link Authentication} implementation used for the OAuth 2.0 Dynamic Client + * Registration Endpoint. + * + * @author Joe Grandja + * @since 7.0 + * @see AbstractAuthenticationToken + * @see OAuth2ClientRegistration + * @see OAuth2ClientRegistrationAuthenticationProvider + */ +public class OAuth2ClientRegistrationAuthenticationToken extends AbstractAuthenticationToken { + + @Serial + private static final long serialVersionUID = 7135429161909989115L; + + @Nullable + private final Authentication principal; + + private final OAuth2ClientRegistration clientRegistration; + + /** + * Constructs an {@code OAuth2ClientRegistrationAuthenticationToken} using the + * provided parameters. + * @param principal the authenticated principal + * @param clientRegistration the client registration + */ + public OAuth2ClientRegistrationAuthenticationToken(@Nullable Authentication principal, + OAuth2ClientRegistration clientRegistration) { + super(Collections.emptyList()); + Assert.notNull(clientRegistration, "clientRegistration cannot be null"); + this.principal = principal; + this.clientRegistration = clientRegistration; + if (principal != null) { + setAuthenticated(principal.isAuthenticated()); + } + } + + @Nullable + @Override + public Object getPrincipal() { + return this.principal; + } + + @Override + public Object getCredentials() { + return ""; + } + + /** + * Returns the client registration. + * @return the client registration + */ + public OAuth2ClientRegistration getClientRegistration() { + return this.clientRegistration; + } + +} diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/converter/OAuth2ClientRegistrationRegisteredClientConverter.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/converter/OAuth2ClientRegistrationRegisteredClientConverter.java new file mode 100644 index 0000000000..ebad12cd0f --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/converter/OAuth2ClientRegistrationRegisteredClientConverter.java @@ -0,0 +1,110 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.converter; + +import java.time.Instant; +import java.util.Base64; +import java.util.UUID; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; +import org.springframework.security.crypto.keygen.StringKeyGenerator; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.settings.ClientSettings; +import org.springframework.util.CollectionUtils; + +/** + * A {@link Converter} that converts the provided {@link OAuth2ClientRegistration} to a + * {@link RegisteredClient}. + * + * @author Joe Grandja + * @since 7.0 + */ +public final class OAuth2ClientRegistrationRegisteredClientConverter + implements Converter { + + private static final StringKeyGenerator CLIENT_ID_GENERATOR = new Base64StringKeyGenerator( + Base64.getUrlEncoder().withoutPadding(), 32); + + private static final StringKeyGenerator CLIENT_SECRET_GENERATOR = new Base64StringKeyGenerator( + Base64.getUrlEncoder().withoutPadding(), 48); + + @Override + public RegisteredClient convert(OAuth2ClientRegistration clientRegistration) { + // @formatter:off + RegisteredClient.Builder builder = RegisteredClient.withId(UUID.randomUUID().toString()) + .clientId(CLIENT_ID_GENERATOR.generateKey()) + .clientIdIssuedAt(Instant.now()) + .clientName(clientRegistration.getClientName()); + + if (ClientAuthenticationMethod.CLIENT_SECRET_POST.getValue().equals(clientRegistration.getTokenEndpointAuthenticationMethod())) { + builder + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) + .clientSecret(CLIENT_SECRET_GENERATOR.generateKey()); + } + else if (ClientAuthenticationMethod.NONE.getValue().equals(clientRegistration.getTokenEndpointAuthenticationMethod())) { + builder.clientAuthenticationMethod(ClientAuthenticationMethod.NONE); + } + else { + builder + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) + .clientSecret(CLIENT_SECRET_GENERATOR.generateKey()); + } + + if (!CollectionUtils.isEmpty(clientRegistration.getRedirectUris())) { + builder.redirectUris((redirectUris) -> + redirectUris.addAll(clientRegistration.getRedirectUris())); + } + + if (!CollectionUtils.isEmpty(clientRegistration.getGrantTypes())) { + builder.authorizationGrantTypes((authorizationGrantTypes) -> + clientRegistration.getGrantTypes().forEach((grantType) -> + authorizationGrantTypes.add(new AuthorizationGrantType(grantType)))); + } + else { + builder.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE); + } + + if (!CollectionUtils.isEmpty(clientRegistration.getResponseTypes()) && + clientRegistration.getResponseTypes().contains(OAuth2AuthorizationResponseType.CODE.getValue())) { + builder.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE); + } + + if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { + builder.scopes((scopes) -> + scopes.addAll(clientRegistration.getScopes())); + } + + ClientSettings.Builder clientSettingsBuilder = ClientSettings.builder() + .requireProofKey(true) + .requireAuthorizationConsent(true); + if (clientRegistration.getJwkSetUrl() != null) { + clientSettingsBuilder.jwkSetUrl(clientRegistration.getJwkSetUrl().toString()); + } + + builder + .clientSettings(clientSettingsBuilder.build()); + + return builder.build(); + // @formatter:on + } + +} diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/converter/RegisteredClientOAuth2ClientRegistrationConverter.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/converter/RegisteredClientOAuth2ClientRegistrationConverter.java new file mode 100644 index 0000000000..7339dfcf10 --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/converter/RegisteredClientOAuth2ClientRegistrationConverter.java @@ -0,0 +1,84 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.converter; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.settings.ClientSettings; +import org.springframework.util.CollectionUtils; + +/** + * A {@link Converter} that converts the provided {@link RegisteredClient} to an + * {@link OAuth2ClientRegistration}. + * + * @author Joe Grandja + * @since 7.0 + */ +public final class RegisteredClientOAuth2ClientRegistrationConverter + implements Converter { + + @Override + public OAuth2ClientRegistration convert(RegisteredClient registeredClient) { + // @formatter:off + OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder() + .clientId(registeredClient.getClientId()) + .clientIdIssuedAt(registeredClient.getClientIdIssuedAt()) + .clientName(registeredClient.getClientName()); + + builder + .tokenEndpointAuthenticationMethod(registeredClient.getClientAuthenticationMethods().iterator().next().getValue()); + + if (registeredClient.getClientSecret() != null) { + builder.clientSecret(registeredClient.getClientSecret()); + } + + if (registeredClient.getClientSecretExpiresAt() != null) { + builder.clientSecretExpiresAt(registeredClient.getClientSecretExpiresAt()); + } + + if (!CollectionUtils.isEmpty(registeredClient.getRedirectUris())) { + builder.redirectUris((redirectUris) -> + redirectUris.addAll(registeredClient.getRedirectUris())); + } + + builder.grantTypes((grantTypes) -> + registeredClient.getAuthorizationGrantTypes().forEach((authorizationGrantType) -> + grantTypes.add(authorizationGrantType.getValue()))); + + if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.AUTHORIZATION_CODE)) { + builder.responseType(OAuth2AuthorizationResponseType.CODE.getValue()); + } + + if (!CollectionUtils.isEmpty(registeredClient.getScopes())) { + builder.scopes((scopes) -> + scopes.addAll(registeredClient.getScopes())); + } + + ClientSettings clientSettings = registeredClient.getClientSettings(); + + if (clientSettings.getJwkSetUrl() != null) { + builder.jwkSetUrl(clientSettings.getJwkSetUrl()); + } + + return builder.build(); + // @formatter:on + } + +} diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/http/converter/OAuth2ClientRegistrationHttpMessageConverter.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/http/converter/OAuth2ClientRegistrationHttpMessageConverter.java new file mode 100644 index 0000000000..d9b2777152 --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/http/converter/OAuth2ClientRegistrationHttpMessageConverter.java @@ -0,0 +1,233 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.http.converter; + +import java.net.URL; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.convert.TypeDescriptor; +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.converter.AbstractHttpMessageConverter; +import org.springframework.http.converter.GenericHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.HttpMessageNotWritableException; +import org.springframework.security.oauth2.core.converter.ClaimConversionService; +import org.springframework.security.oauth2.core.converter.ClaimTypeConverter; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientMetadataClaimNames; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +/** + * A {@link HttpMessageConverter} for an {@link OAuth2ClientRegistration OAuth 2.0 Dynamic + * Client Registration Request and Response}. + * + * @author Joe Grandja + * @since 7.0 + * @see AbstractHttpMessageConverter + * @see OAuth2ClientRegistration + */ +public class OAuth2ClientRegistrationHttpMessageConverter + extends AbstractHttpMessageConverter { + + private static final ParameterizedTypeReference> STRING_OBJECT_MAP = new ParameterizedTypeReference<>() { + }; + + private final GenericHttpMessageConverter jsonMessageConverter = HttpMessageConverters + .getJsonMessageConverter(); + + private Converter, OAuth2ClientRegistration> clientRegistrationConverter = new MapOAuth2ClientRegistrationConverter(); + + private Converter> clientRegistrationParametersConverter = new OAuth2ClientRegistrationMapConverter(); + + public OAuth2ClientRegistrationHttpMessageConverter() { + super(MediaType.APPLICATION_JSON, new MediaType("application", "*+json")); + } + + @Override + protected boolean supports(Class clazz) { + return OAuth2ClientRegistration.class.isAssignableFrom(clazz); + } + + @Override + @SuppressWarnings("unchecked") + protected OAuth2ClientRegistration readInternal(Class clazz, + HttpInputMessage inputMessage) throws HttpMessageNotReadableException { + try { + Map clientRegistrationParameters = (Map) this.jsonMessageConverter + .read(STRING_OBJECT_MAP.getType(), null, inputMessage); + return this.clientRegistrationConverter.convert(clientRegistrationParameters); + } + catch (Exception ex) { + throw new HttpMessageNotReadableException( + "An error occurred reading the OAuth 2.0 Client Registration: " + ex.getMessage(), ex, + inputMessage); + } + } + + @Override + protected void writeInternal(OAuth2ClientRegistration clientRegistration, HttpOutputMessage outputMessage) + throws HttpMessageNotWritableException { + try { + Map clientRegistrationParameters = this.clientRegistrationParametersConverter + .convert(clientRegistration); + this.jsonMessageConverter.write(clientRegistrationParameters, STRING_OBJECT_MAP.getType(), + MediaType.APPLICATION_JSON, outputMessage); + } + catch (Exception ex) { + throw new HttpMessageNotWritableException( + "An error occurred writing the OAuth 2.0 Client Registration: " + ex.getMessage(), ex); + } + } + + /** + * Sets the {@link Converter} used for converting the OAuth 2.0 Client Registration + * parameters to an {@link OAuth2ClientRegistration}. + * @param clientRegistrationConverter the {@link Converter} used for converting to an + * {@link OAuth2ClientRegistration} + */ + public final void setClientRegistrationConverter( + Converter, OAuth2ClientRegistration> clientRegistrationConverter) { + Assert.notNull(clientRegistrationConverter, "clientRegistrationConverter cannot be null"); + this.clientRegistrationConverter = clientRegistrationConverter; + } + + /** + * Sets the {@link Converter} used for converting the {@link OAuth2ClientRegistration} + * to a {@code Map} representation of the OAuth 2.0 Client Registration parameters. + * @param clientRegistrationParametersConverter the {@link Converter} used for + * converting to a {@code Map} representation of the OAuth 2.0 Client Registration + * parameters + */ + public final void setClientRegistrationParametersConverter( + Converter> clientRegistrationParametersConverter) { + Assert.notNull(clientRegistrationParametersConverter, "clientRegistrationParametersConverter cannot be null"); + this.clientRegistrationParametersConverter = clientRegistrationParametersConverter; + } + + private static final class MapOAuth2ClientRegistrationConverter + implements Converter, OAuth2ClientRegistration> { + + private static final ClaimConversionService CLAIM_CONVERSION_SERVICE = ClaimConversionService + .getSharedInstance(); + + private static final TypeDescriptor OBJECT_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(Object.class); + + private static final TypeDescriptor STRING_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(String.class); + + private static final TypeDescriptor INSTANT_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(Instant.class); + + private static final TypeDescriptor URL_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(URL.class); + + private static final Converter INSTANT_CONVERTER = getConverter(INSTANT_TYPE_DESCRIPTOR); + + private final ClaimTypeConverter claimTypeConverter; + + private MapOAuth2ClientRegistrationConverter() { + Converter stringConverter = getConverter(STRING_TYPE_DESCRIPTOR); + Converter collectionStringConverter = getConverter( + TypeDescriptor.collection(Collection.class, STRING_TYPE_DESCRIPTOR)); + Converter urlConverter = getConverter(URL_TYPE_DESCRIPTOR); + + Map> claimConverters = new HashMap<>(); + claimConverters.put(OAuth2ClientMetadataClaimNames.CLIENT_ID, stringConverter); + claimConverters.put(OAuth2ClientMetadataClaimNames.CLIENT_ID_ISSUED_AT, INSTANT_CONVERTER); + claimConverters.put(OAuth2ClientMetadataClaimNames.CLIENT_SECRET, stringConverter); + claimConverters.put(OAuth2ClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT, + MapOAuth2ClientRegistrationConverter::convertClientSecretExpiresAt); + claimConverters.put(OAuth2ClientMetadataClaimNames.CLIENT_NAME, stringConverter); + claimConverters.put(OAuth2ClientMetadataClaimNames.REDIRECT_URIS, collectionStringConverter); + claimConverters.put(OAuth2ClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD, stringConverter); + claimConverters.put(OAuth2ClientMetadataClaimNames.GRANT_TYPES, collectionStringConverter); + claimConverters.put(OAuth2ClientMetadataClaimNames.RESPONSE_TYPES, collectionStringConverter); + claimConverters.put(OAuth2ClientMetadataClaimNames.SCOPE, + MapOAuth2ClientRegistrationConverter::convertScope); + claimConverters.put(OAuth2ClientMetadataClaimNames.JWKS_URI, urlConverter); + this.claimTypeConverter = new ClaimTypeConverter(claimConverters); + } + + @Override + public OAuth2ClientRegistration convert(Map source) { + Map parsedClaims = this.claimTypeConverter.convert(source); + Object clientSecretExpiresAt = parsedClaims.get(OAuth2ClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT); + if (clientSecretExpiresAt instanceof Number && clientSecretExpiresAt.equals(0)) { + parsedClaims.remove(OAuth2ClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT); + } + return OAuth2ClientRegistration.withClaims(parsedClaims).build(); + } + + private static Converter getConverter(TypeDescriptor targetDescriptor) { + return (source) -> CLAIM_CONVERSION_SERVICE.convert(source, OBJECT_TYPE_DESCRIPTOR, targetDescriptor); + } + + private static Instant convertClientSecretExpiresAt(Object clientSecretExpiresAt) { + if (clientSecretExpiresAt != null && String.valueOf(clientSecretExpiresAt).equals("0")) { + // 0 indicates that client_secret_expires_at does not expire + return null; + } + return (Instant) INSTANT_CONVERTER.convert(clientSecretExpiresAt); + } + + private static List convertScope(Object scope) { + if (scope == null) { + return Collections.emptyList(); + } + return Arrays.asList(StringUtils.delimitedListToStringArray(scope.toString(), " ")); + } + + } + + private static final class OAuth2ClientRegistrationMapConverter + implements Converter> { + + @Override + public Map convert(OAuth2ClientRegistration source) { + Map responseClaims = new LinkedHashMap<>(source.getClaims()); + if (source.getClientIdIssuedAt() != null) { + responseClaims.put(OAuth2ClientMetadataClaimNames.CLIENT_ID_ISSUED_AT, + source.getClientIdIssuedAt().getEpochSecond()); + } + if (source.getClientSecret() != null) { + long clientSecretExpiresAt = 0; + if (source.getClientSecretExpiresAt() != null) { + clientSecretExpiresAt = source.getClientSecretExpiresAt().getEpochSecond(); + } + responseClaims.put(OAuth2ClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT, clientSecretExpiresAt); + } + if (!CollectionUtils.isEmpty(source.getScopes())) { + responseClaims.put(OAuth2ClientMetadataClaimNames.SCOPE, + StringUtils.collectionToDelimitedString(source.getScopes(), " ")); + } + return responseClaims; + } + + } + +} diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientMetadataClaimAccessor.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientMetadataClaimAccessor.java index 2e0f9ed40c..9fad4443d4 100644 --- a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientMetadataClaimAccessor.java +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientMetadataClaimAccessor.java @@ -17,7 +17,6 @@ package org.springframework.security.oauth2.server.authorization.oidc; import java.net.URL; -import java.time.Instant; import java.util.List; import org.springframework.security.oauth2.core.ClaimAccessor; @@ -26,6 +25,7 @@ import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientMetadataClaimAccessor; /** * A {@link ClaimAccessor} for the "claims" that are contained in the OpenID Client @@ -34,7 +34,7 @@ import org.springframework.security.oauth2.jwt.Jwt; * @author Ovidiu Popa * @author Joe Grandja * @since 7.0 - * @see ClaimAccessor + * @see OAuth2ClientMetadataClaimAccessor * @see OidcClientMetadataClaimNames * @see OidcClientRegistration * @see 3.1. * Client Registration Metadata */ -public interface OidcClientMetadataClaimAccessor extends ClaimAccessor { - - /** - * Returns the Client Identifier {@code (client_id)}. - * @return the Client Identifier - */ - default String getClientId() { - return getClaimAsString(OidcClientMetadataClaimNames.CLIENT_ID); - } - - /** - * Returns the time at which the Client Identifier was issued - * {@code (client_id_issued_at)}. - * @return the time at which the Client Identifier was issued - */ - default Instant getClientIdIssuedAt() { - return getClaimAsInstant(OidcClientMetadataClaimNames.CLIENT_ID_ISSUED_AT); - } - - /** - * Returns the Client Secret {@code (client_secret)}. - * @return the Client Secret - */ - default String getClientSecret() { - return getClaimAsString(OidcClientMetadataClaimNames.CLIENT_SECRET); - } - - /** - * Returns the time at which the {@code client_secret} will expire - * {@code (client_secret_expires_at)}. - * @return the time at which the {@code client_secret} will expire - */ - default Instant getClientSecretExpiresAt() { - return getClaimAsInstant(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT); - } - - /** - * Returns the name of the Client to be presented to the End-User - * {@code (client_name)}. - * @return the name of the Client to be presented to the End-User - */ - default String getClientName() { - return getClaimAsString(OidcClientMetadataClaimNames.CLIENT_NAME); - } - - /** - * Returns the redirection {@code URI} values used by the Client - * {@code (redirect_uris)}. - * @return the redirection {@code URI} values used by the Client - */ - default List getRedirectUris() { - return getClaimAsStringList(OidcClientMetadataClaimNames.REDIRECT_URIS); - } +public interface OidcClientMetadataClaimAccessor extends OAuth2ClientMetadataClaimAccessor { /** * Returns the post logout redirection {@code URI} values used by the Client @@ -109,15 +57,6 @@ public interface OidcClientMetadataClaimAccessor extends ClaimAccessor { return getClaimAsStringList(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS); } - /** - * Returns the authentication method used by the Client for the Token Endpoint - * {@code (token_endpoint_auth_method)}. - * @return the authentication method used by the Client for the Token Endpoint - */ - default String getTokenEndpointAuthenticationMethod() { - return getClaimAsString(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD); - } - /** * Returns the {@link JwsAlgorithm JWS} algorithm that must be used for signing the * {@link Jwt JWT} used to authenticate the Client at the Token Endpoint for the @@ -131,44 +70,6 @@ public interface OidcClientMetadataClaimAccessor extends ClaimAccessor { return getClaimAsString(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_SIGNING_ALG); } - /** - * Returns the OAuth 2.0 {@code grant_type} values that the Client will restrict - * itself to using {@code (grant_types)}. - * @return the OAuth 2.0 {@code grant_type} values that the Client will restrict - * itself to using - */ - default List getGrantTypes() { - return getClaimAsStringList(OidcClientMetadataClaimNames.GRANT_TYPES); - } - - /** - * Returns the OAuth 2.0 {@code response_type} values that the Client will restrict - * itself to using {@code (response_types)}. - * @return the OAuth 2.0 {@code response_type} values that the Client will restrict - * itself to using - */ - default List getResponseTypes() { - return getClaimAsStringList(OidcClientMetadataClaimNames.RESPONSE_TYPES); - } - - /** - * Returns the OAuth 2.0 {@code scope} values that the Client will restrict itself to - * using {@code (scope)}. - * @return the OAuth 2.0 {@code scope} values that the Client will restrict itself to - * using - */ - default List getScopes() { - return getClaimAsStringList(OidcClientMetadataClaimNames.SCOPE); - } - - /** - * Returns the {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}. - * @return the {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)} - */ - default URL getJwkSetUrl() { - return getClaimAsURL(OidcClientMetadataClaimNames.JWKS_URI); - } - /** * Returns the {@link SignatureAlgorithm JWS} algorithm required for signing the * {@link OidcIdToken ID Token} issued to the Client diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientMetadataClaimNames.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientMetadataClaimNames.java index b122b1c94b..2527330e23 100644 --- a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientMetadataClaimNames.java +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientMetadataClaimNames.java @@ -20,6 +20,7 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientMetadataClaimNames; /** * The names of the "claims" defined by OpenID Connect Dynamic Client Registration 1.0 @@ -28,6 +29,7 @@ import org.springframework.security.oauth2.jwt.Jwt; * @author Ovidiu Popa * @author Joe Grandja * @since 7.0 + * @see OAuth2ClientMetadataClaimNames * @see 2. * Client Metadata @@ -35,38 +37,7 @@ import org.springframework.security.oauth2.jwt.Jwt; * "https://openid.net/specs/openid-connect-rpinitiated-1_0.html#ClientMetadata">3.1. * Client Registration Metadata */ -public final class OidcClientMetadataClaimNames { - - /** - * {@code client_id} - the Client Identifier - */ - public static final String CLIENT_ID = "client_id"; - - /** - * {@code client_id_issued_at} - the time at which the Client Identifier was issued - */ - public static final String CLIENT_ID_ISSUED_AT = "client_id_issued_at"; - - /** - * {@code client_secret} - the Client Secret - */ - public static final String CLIENT_SECRET = "client_secret"; - - /** - * {@code client_secret_expires_at} - the time at which the {@code client_secret} will - * expire or 0 if it will not expire - */ - public static final String CLIENT_SECRET_EXPIRES_AT = "client_secret_expires_at"; - - /** - * {@code client_name} - the name of the Client to be presented to the End-User - */ - public static final String CLIENT_NAME = "client_name"; - - /** - * {@code redirect_uris} - the redirection {@code URI} values used by the Client - */ - public static final String REDIRECT_URIS = "redirect_uris"; +public final class OidcClientMetadataClaimNames extends OAuth2ClientMetadataClaimNames { /** * {@code post_logout_redirect_uris} - the post logout redirection {@code URI} values @@ -76,12 +47,6 @@ public final class OidcClientMetadataClaimNames { */ public static final String POST_LOGOUT_REDIRECT_URIS = "post_logout_redirect_uris"; - /** - * {@code token_endpoint_auth_method} - the authentication method used by the Client - * for the Token Endpoint - */ - public static final String TOKEN_ENDPOINT_AUTH_METHOD = "token_endpoint_auth_method"; - /** * {@code token_endpoint_auth_signing_alg} - the {@link JwsAlgorithm JWS} algorithm * that must be used for signing the {@link Jwt JWT} used to authenticate the Client @@ -91,29 +56,6 @@ public final class OidcClientMetadataClaimNames { */ public static final String TOKEN_ENDPOINT_AUTH_SIGNING_ALG = "token_endpoint_auth_signing_alg"; - /** - * {@code grant_types} - the OAuth 2.0 {@code grant_type} values that the Client will - * restrict itself to using - */ - public static final String GRANT_TYPES = "grant_types"; - - /** - * {@code response_types} - the OAuth 2.0 {@code response_type} values that the Client - * will restrict itself to using - */ - public static final String RESPONSE_TYPES = "response_types"; - - /** - * {@code scope} - a space-separated list of OAuth 2.0 {@code scope} values that the - * Client will restrict itself to using - */ - public static final String SCOPE = "scope"; - - /** - * {@code jwks_uri} - the {@code URL} for the Client's JSON Web Key Set - */ - public static final String JWKS_URI = "jwks_uri"; - /** * {@code id_token_signed_response_alg} - the {@link JwsAlgorithm JWS} algorithm * required for signing the {@link OidcIdToken ID Token} issued to the Client diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientRegistration.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientRegistration.java index c51cfb3fad..e20397628c 100644 --- a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientRegistration.java +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientRegistration.java @@ -17,12 +17,6 @@ package org.springframework.security.oauth2.server.authorization.oidc; import java.io.Serial; -import java.io.Serializable; -import java.net.URI; -import java.net.URL; -import java.time.Instant; -import java.util.Collections; -import java.util.LinkedHashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -33,6 +27,7 @@ import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.server.authorization.AbstractOAuth2ClientRegistration; import org.springframework.util.Assert; /** @@ -44,6 +39,7 @@ import org.springframework.util.Assert; * @author Ovidiu Popa * @author Joe Grandja * @since 7.0 + * @see AbstractOAuth2ClientRegistration * @see OidcClientMetadataClaimAccessor * @see 3.1. @@ -55,25 +51,14 @@ import org.springframework.util.Assert; * "https://openid.net/specs/openid-connect-rpinitiated-1_0.html#ClientMetadata">3.1. * Client Registration Metadata */ -public final class OidcClientRegistration implements OidcClientMetadataClaimAccessor, Serializable { +public final class OidcClientRegistration extends AbstractOAuth2ClientRegistration + implements OidcClientMetadataClaimAccessor { @Serial - private static final long serialVersionUID = 6518710174552040014L; - - private final Map claims; + private static final long serialVersionUID = -8485448209864668396L; private OidcClientRegistration(Map claims) { - Assert.notEmpty(claims, "claims cannot be empty"); - this.claims = Collections.unmodifiableMap(new LinkedHashMap<>(claims)); - } - - /** - * Returns the metadata as claims. - * @return a {@code Map} of the metadata as claims - */ - @Override - public Map getClaims() { - return this.claims; + super(claims); } /** @@ -97,82 +82,11 @@ public final class OidcClientRegistration implements OidcClientMetadataClaimAcce /** * Helps configure an {@link OidcClientRegistration}. */ - public static final class Builder { - - private final Map claims = new LinkedHashMap<>(); + public static final class Builder extends AbstractBuilder { private Builder() { } - /** - * Sets the Client Identifier, REQUIRED. - * @param clientId the Client Identifier - * @return the {@link Builder} for further configuration - */ - public Builder clientId(String clientId) { - return claim(OidcClientMetadataClaimNames.CLIENT_ID, clientId); - } - - /** - * Sets the time at which the Client Identifier was issued, OPTIONAL. - * @param clientIdIssuedAt the time at which the Client Identifier was issued - * @return the {@link Builder} for further configuration - */ - public Builder clientIdIssuedAt(Instant clientIdIssuedAt) { - return claim(OidcClientMetadataClaimNames.CLIENT_ID_ISSUED_AT, clientIdIssuedAt); - } - - /** - * Sets the Client Secret, OPTIONAL. - * @param clientSecret the Client Secret - * @return the {@link Builder} for further configuration - */ - public Builder clientSecret(String clientSecret) { - return claim(OidcClientMetadataClaimNames.CLIENT_SECRET, clientSecret); - } - - /** - * Sets the time at which the {@code client_secret} will expire or {@code null} if - * it will not expire, REQUIRED if {@code client_secret} was issued. - * @param clientSecretExpiresAt the time at which the {@code client_secret} will - * expire or {@code null} if it will not expire - * @return the {@link Builder} for further configuration - */ - public Builder clientSecretExpiresAt(Instant clientSecretExpiresAt) { - return claim(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT, clientSecretExpiresAt); - } - - /** - * Sets the name of the Client to be presented to the End-User, OPTIONAL. - * @param clientName the name of the Client to be presented to the End-User - * @return the {@link Builder} for further configuration - */ - public Builder clientName(String clientName) { - return claim(OidcClientMetadataClaimNames.CLIENT_NAME, clientName); - } - - /** - * Add the redirection {@code URI} used by the Client, REQUIRED. - * @param redirectUri the redirection {@code URI} used by the Client - * @return the {@link Builder} for further configuration - */ - public Builder redirectUri(String redirectUri) { - addClaimToClaimList(OidcClientMetadataClaimNames.REDIRECT_URIS, redirectUri); - return this; - } - - /** - * A {@code Consumer} of the redirection {@code URI} values used by the Client, - * allowing the ability to add, replace, or remove, REQUIRED. - * @param redirectUrisConsumer a {@code Consumer} of the redirection {@code URI} - * values used by the Client - * @return the {@link Builder} for further configuration - */ - public Builder redirectUris(Consumer> redirectUrisConsumer) { - acceptClaimValues(OidcClientMetadataClaimNames.REDIRECT_URIS, redirectUrisConsumer); - return this; - } - /** * Add the post logout redirection {@code URI} used by the Client, OPTIONAL. The * {@code post_logout_redirect_uri} parameter is used by the client when @@ -199,17 +113,6 @@ public final class OidcClientRegistration implements OidcClientMetadataClaimAcce return this; } - /** - * Sets the authentication method used by the Client for the Token Endpoint, - * OPTIONAL. - * @param tokenEndpointAuthenticationMethod the authentication method used by the - * Client for the Token Endpoint - * @return the {@link Builder} for further configuration - */ - public Builder tokenEndpointAuthenticationMethod(String tokenEndpointAuthenticationMethod) { - return claim(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD, tokenEndpointAuthenticationMethod); - } - /** * Sets the {@link JwsAlgorithm JWS} algorithm that must be used for signing the * {@link Jwt JWT} used to authenticate the Client at the Token Endpoint for the @@ -225,90 +128,6 @@ public final class OidcClientRegistration implements OidcClientMetadataClaimAcce return claim(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_SIGNING_ALG, authenticationSigningAlgorithm); } - /** - * Add the OAuth 2.0 {@code grant_type} that the Client will restrict itself to - * using, OPTIONAL. - * @param grantType the OAuth 2.0 {@code grant_type} that the Client will restrict - * itself to using - * @return the {@link Builder} for further configuration - */ - public Builder grantType(String grantType) { - addClaimToClaimList(OidcClientMetadataClaimNames.GRANT_TYPES, grantType); - return this; - } - - /** - * A {@code Consumer} of the OAuth 2.0 {@code grant_type} values that the Client - * will restrict itself to using, allowing the ability to add, replace, or remove, - * OPTIONAL. - * @param grantTypesConsumer a {@code Consumer} of the OAuth 2.0 - * {@code grant_type} values that the Client will restrict itself to using - * @return the {@link Builder} for further configuration - */ - public Builder grantTypes(Consumer> grantTypesConsumer) { - acceptClaimValues(OidcClientMetadataClaimNames.GRANT_TYPES, grantTypesConsumer); - return this; - } - - /** - * Add the OAuth 2.0 {@code response_type} that the Client will restrict itself to - * using, OPTIONAL. - * @param responseType the OAuth 2.0 {@code response_type} that the Client will - * restrict itself to using - * @return the {@link Builder} for further configuration - */ - public Builder responseType(String responseType) { - addClaimToClaimList(OidcClientMetadataClaimNames.RESPONSE_TYPES, responseType); - return this; - } - - /** - * A {@code Consumer} of the OAuth 2.0 {@code response_type} values that the - * Client will restrict itself to using, allowing the ability to add, replace, or - * remove, OPTIONAL. - * @param responseTypesConsumer a {@code Consumer} of the OAuth 2.0 - * {@code response_type} values that the Client will restrict itself to using - * @return the {@link Builder} for further configuration - */ - public Builder responseTypes(Consumer> responseTypesConsumer) { - acceptClaimValues(OidcClientMetadataClaimNames.RESPONSE_TYPES, responseTypesConsumer); - return this; - } - - /** - * Add the OAuth 2.0 {@code scope} that the Client will restrict itself to using, - * OPTIONAL. - * @param scope the OAuth 2.0 {@code scope} that the Client will restrict itself - * to using - * @return the {@link Builder} for further configuration - */ - public Builder scope(String scope) { - addClaimToClaimList(OidcClientMetadataClaimNames.SCOPE, scope); - return this; - } - - /** - * A {@code Consumer} of the OAuth 2.0 {@code scope} values that the Client will - * restrict itself to using, allowing the ability to add, replace, or remove, - * OPTIONAL. - * @param scopesConsumer a {@code Consumer} of the OAuth 2.0 {@code scope} values - * that the Client will restrict itself to using - * @return the {@link Builder} for further configuration - */ - public Builder scopes(Consumer> scopesConsumer) { - acceptClaimValues(OidcClientMetadataClaimNames.SCOPE, scopesConsumer); - return this; - } - - /** - * Sets the {@code URL} for the Client's JSON Web Key Set, OPTIONAL. - * @param jwkSetUrl the {@code URL} for the Client's JSON Web Key Set - * @return the {@link Builder} for further configuration - */ - public Builder jwkSetUrl(String jwkSetUrl) { - return claim(OidcClientMetadataClaimNames.JWKS_URI, jwkSetUrl); - } - /** * Sets the {@link SignatureAlgorithm JWS} algorithm required for signing the * {@link OidcIdToken ID Token} issued to the Client, OPTIONAL. @@ -343,120 +162,51 @@ public final class OidcClientRegistration implements OidcClientMetadataClaimAcce return claim(OidcClientMetadataClaimNames.REGISTRATION_CLIENT_URI, registrationClientUrl); } - /** - * Sets the claim. - * @param name the claim name - * @param value the claim value - * @return the {@link Builder} for further configuration - */ - public Builder claim(String name, Object value) { - Assert.hasText(name, "name cannot be empty"); - Assert.notNull(value, "value cannot be null"); - this.claims.put(name, value); - return this; - } - - /** - * Provides access to every {@link #claim(String, Object)} declared so far - * allowing the ability to add, replace, or remove. - * @param claimsConsumer a {@code Consumer} of the claims - * @return the {@link Builder} for further configurations - */ - public Builder claims(Consumer> claimsConsumer) { - claimsConsumer.accept(this.claims); - return this; - } - /** * Validate the claims and build the {@link OidcClientRegistration}. *

* The following claims are REQUIRED: {@code client_id}, {@code redirect_uris}. * @return the {@link OidcClientRegistration} */ + @Override public OidcClientRegistration build() { validate(); - return new OidcClientRegistration(this.claims); + return new OidcClientRegistration(getClaims()); } - private void validate() { - if (this.claims.get(OidcClientMetadataClaimNames.CLIENT_ID_ISSUED_AT) != null - || this.claims.get(OidcClientMetadataClaimNames.CLIENT_SECRET) != null) { - Assert.notNull(this.claims.get(OidcClientMetadataClaimNames.CLIENT_ID), "client_id cannot be null"); - } - if (this.claims.get(OidcClientMetadataClaimNames.CLIENT_ID_ISSUED_AT) != null) { - Assert.isInstanceOf(Instant.class, this.claims.get(OidcClientMetadataClaimNames.CLIENT_ID_ISSUED_AT), - "client_id_issued_at must be of type Instant"); - } - if (this.claims.get(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT) != null) { - Assert.notNull(this.claims.get(OidcClientMetadataClaimNames.CLIENT_SECRET), - "client_secret cannot be null"); - Assert.isInstanceOf(Instant.class, - this.claims.get(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT), - "client_secret_expires_at must be of type Instant"); - } - Assert.notNull(this.claims.get(OidcClientMetadataClaimNames.REDIRECT_URIS), "redirect_uris cannot be null"); - Assert.isInstanceOf(List.class, this.claims.get(OidcClientMetadataClaimNames.REDIRECT_URIS), + @Override + protected void validate() { + super.validate(); + Assert.notNull(getClaims().get(OidcClientMetadataClaimNames.REDIRECT_URIS), "redirect_uris cannot be null"); + Assert.isInstanceOf(List.class, getClaims().get(OidcClientMetadataClaimNames.REDIRECT_URIS), "redirect_uris must be of type List"); - Assert.notEmpty((List) this.claims.get(OidcClientMetadataClaimNames.REDIRECT_URIS), + Assert.notEmpty((List) getClaims().get(OidcClientMetadataClaimNames.REDIRECT_URIS), "redirect_uris cannot be empty"); - if (this.claims.get(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS) != null) { - Assert.isInstanceOf(List.class, this.claims.get(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS), + if (getClaims().get(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS) != null) { + Assert.isInstanceOf(List.class, getClaims().get(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS), "post_logout_redirect_uris must be of type List"); - Assert.notEmpty((List) this.claims.get(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS), + Assert.notEmpty((List) getClaims().get(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS), "post_logout_redirect_uris cannot be empty"); } - if (this.claims.get(OidcClientMetadataClaimNames.GRANT_TYPES) != null) { - Assert.isInstanceOf(List.class, this.claims.get(OidcClientMetadataClaimNames.GRANT_TYPES), - "grant_types must be of type List"); - Assert.notEmpty((List) this.claims.get(OidcClientMetadataClaimNames.GRANT_TYPES), - "grant_types cannot be empty"); - } - if (this.claims.get(OidcClientMetadataClaimNames.RESPONSE_TYPES) != null) { - Assert.isInstanceOf(List.class, this.claims.get(OidcClientMetadataClaimNames.RESPONSE_TYPES), - "response_types must be of type List"); - Assert.notEmpty((List) this.claims.get(OidcClientMetadataClaimNames.RESPONSE_TYPES), - "response_types cannot be empty"); - } - if (this.claims.get(OidcClientMetadataClaimNames.SCOPE) != null) { - Assert.isInstanceOf(List.class, this.claims.get(OidcClientMetadataClaimNames.SCOPE), - "scope must be of type List"); - Assert.notEmpty((List) this.claims.get(OidcClientMetadataClaimNames.SCOPE), "scope cannot be empty"); - } - if (this.claims.get(OidcClientMetadataClaimNames.JWKS_URI) != null) { - validateURL(this.claims.get(OidcClientMetadataClaimNames.JWKS_URI), "jwksUri must be a valid URL"); - } } @SuppressWarnings("unchecked") private void addClaimToClaimList(String name, String value) { Assert.hasText(name, "name cannot be empty"); Assert.notNull(value, "value cannot be null"); - this.claims.computeIfAbsent(name, (k) -> new LinkedList()); - ((List) this.claims.get(name)).add(value); + getClaims().computeIfAbsent(name, (k) -> new LinkedList()); + ((List) getClaims().get(name)).add(value); } @SuppressWarnings("unchecked") private void acceptClaimValues(String name, Consumer> valuesConsumer) { Assert.hasText(name, "name cannot be empty"); Assert.notNull(valuesConsumer, "valuesConsumer cannot be null"); - this.claims.computeIfAbsent(name, (k) -> new LinkedList()); - List values = (List) this.claims.get(name); + getClaims().computeIfAbsent(name, (k) -> new LinkedList()); + List values = (List) getClaims().get(name); valuesConsumer.accept(values); } - private static void validateURL(Object url, String errorMessage) { - if (URL.class.isAssignableFrom(url.getClass())) { - return; - } - - try { - new URI(url.toString()).toURL(); - } - catch (Exception ex) { - throw new IllegalArgumentException(errorMessage, ex); - } - } - } } diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationToken.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationToken.java index 11961eb56e..f894ee151e 100644 --- a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationToken.java +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationToken.java @@ -40,7 +40,7 @@ import org.springframework.util.Assert; public class OidcClientRegistrationAuthenticationToken extends AbstractAuthenticationToken { @Serial - private static final long serialVersionUID = -6198261907690781217L; + private static final long serialVersionUID = 5392324479052435784L; private final Authentication principal; diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/settings/AuthorizationServerSettings.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/settings/AuthorizationServerSettings.java index deeaef0429..b7abf76da6 100644 --- a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/settings/AuthorizationServerSettings.java +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/settings/AuthorizationServerSettings.java @@ -137,6 +137,15 @@ public final class AuthorizationServerSettings extends AbstractSettings { return getSetting(ConfigurationSettingNames.AuthorizationServer.TOKEN_INTROSPECTION_ENDPOINT); } + /** + * Returns the OAuth 2.0 Dynamic Client Registration endpoint. The default is + * {@code /oauth2/register}. + * @return the OAuth 2.0 Dynamic Client Registration endpoint + */ + public String getClientRegistrationEndpoint() { + return getSetting(ConfigurationSettingNames.AuthorizationServer.CLIENT_REGISTRATION_ENDPOINT); + } + /** * Returns the OpenID Connect 1.0 Client Registration endpoint. The default is * {@code /connect/register}. @@ -177,6 +186,7 @@ public final class AuthorizationServerSettings extends AbstractSettings { .jwkSetEndpoint("/oauth2/jwks") .tokenRevocationEndpoint("/oauth2/revoke") .tokenIntrospectionEndpoint("/oauth2/introspect") + .clientRegistrationEndpoint("/oauth2/register") .oidcClientRegistrationEndpoint("/connect/register") .oidcUserInfoEndpoint("/userinfo") .oidcLogoutEndpoint("/connect/logout"); @@ -315,6 +325,17 @@ public final class AuthorizationServerSettings extends AbstractSettings { tokenIntrospectionEndpoint); } + /** + * Sets the OAuth 2.0 Dynamic Client Registration endpoint. + * @param clientRegistrationEndpoint the OAuth 2.0 Dynamic Client Registration + * endpoint + * @return the {@link Builder} for further configuration + */ + public Builder clientRegistrationEndpoint(String clientRegistrationEndpoint) { + return setting(ConfigurationSettingNames.AuthorizationServer.CLIENT_REGISTRATION_ENDPOINT, + clientRegistrationEndpoint); + } + /** * Sets the OpenID Connect 1.0 Client Registration endpoint. * @param oidcClientRegistrationEndpoint the OpenID Connect 1.0 Client diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/settings/ConfigurationSettingNames.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/settings/ConfigurationSettingNames.java index 97ff73517d..1f7a1dfbc0 100644 --- a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/settings/ConfigurationSettingNames.java +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/settings/ConfigurationSettingNames.java @@ -150,6 +150,12 @@ public final class ConfigurationSettingNames { public static final String TOKEN_INTROSPECTION_ENDPOINT = AUTHORIZATION_SERVER_SETTINGS_NAMESPACE .concat("token-introspection-endpoint"); + /** + * Set the OAuth 2.0 Dynamic Client Registration endpoint. + */ + public static final String CLIENT_REGISTRATION_ENDPOINT = AUTHORIZATION_SERVER_SETTINGS_NAMESPACE + .concat("client-registration-endpoint"); + /** * Set the OpenID Connect 1.0 Client Registration endpoint. */ diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientRegistrationEndpointFilter.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientRegistrationEndpointFilter.java new file mode 100644 index 0000000000..a94ca87e9f --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientRegistrationEndpointFilter.java @@ -0,0 +1,212 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.web; + +import java.io.IOException; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import org.springframework.core.log.LogMessage; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.server.ServletServerHttpResponse; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientRegistrationAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientRegistrationAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.http.converter.OAuth2ClientRegistrationHttpMessageConverter; +import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ClientRegistrationAuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.security.web.servlet.util.matcher.PathPatternRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; +import org.springframework.web.filter.OncePerRequestFilter; + +/** + * A {@code Filter} that processes OAuth 2.0 Dynamic Client Registration Requests. + * + * @author Joe Grandja + * @since 7.0 + * @see OAuth2ClientRegistration + * @see OAuth2ClientRegistrationAuthenticationConverter + * @see OAuth2ClientRegistrationAuthenticationProvider + * @see 3. Client + * Registration Endpoint + */ +public final class OAuth2ClientRegistrationEndpointFilter extends OncePerRequestFilter { + + /** + * The default endpoint {@code URI} for OAuth 2.0 Client Registration requests. + */ + private static final String DEFAULT_OAUTH2_CLIENT_REGISTRATION_ENDPOINT_URI = "/oauth2/register"; + + private final AuthenticationManager authenticationManager; + + private final RequestMatcher clientRegistrationEndpointMatcher; + + private final HttpMessageConverter clientRegistrationHttpMessageConverter = new OAuth2ClientRegistrationHttpMessageConverter(); + + private final HttpMessageConverter errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter(); + + private AuthenticationConverter authenticationConverter = new OAuth2ClientRegistrationAuthenticationConverter(); + + private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendClientRegistrationResponse; + + private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse; + + /** + * Constructs an {@code OAuth2ClientRegistrationEndpointFilter} using the provided + * parameters. + * @param authenticationManager the authentication manager + */ + public OAuth2ClientRegistrationEndpointFilter(AuthenticationManager authenticationManager) { + this(authenticationManager, DEFAULT_OAUTH2_CLIENT_REGISTRATION_ENDPOINT_URI); + } + + /** + * Constructs an {@code OAuth2ClientRegistrationEndpointFilter} using the provided + * parameters. + * @param authenticationManager the authentication manager + * @param clientRegistrationEndpointUri the endpoint {@code URI} for OAuth 2.0 Client + * Registration requests + */ + public OAuth2ClientRegistrationEndpointFilter(AuthenticationManager authenticationManager, + String clientRegistrationEndpointUri) { + Assert.notNull(authenticationManager, "authenticationManager cannot be null"); + Assert.hasText(clientRegistrationEndpointUri, "clientRegistrationEndpointUri cannot be empty"); + this.authenticationManager = authenticationManager; + this.clientRegistrationEndpointMatcher = PathPatternRequestMatcher.withDefaults() + .matcher(HttpMethod.POST, clientRegistrationEndpointUri); + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + if (!this.clientRegistrationEndpointMatcher.matches(request)) { + filterChain.doFilter(request, response); + return; + } + + try { + Authentication clientRegistrationAuthentication = this.authenticationConverter.convert(request); + + Authentication clientRegistrationAuthenticationResult = this.authenticationManager + .authenticate(clientRegistrationAuthentication); + + this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, + clientRegistrationAuthenticationResult); + } + catch (OAuth2AuthenticationException ex) { + if (this.logger.isTraceEnabled()) { + this.logger.trace(LogMessage.format("Client registration request failed: %s", ex.getError()), ex); + } + this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex); + } + catch (Exception ex) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, + "OAuth 2.0 Client Registration Error: " + ex.getMessage(), + "https://datatracker.ietf.org/doc/html/rfc7591#section-3.2.2"); + if (this.logger.isTraceEnabled()) { + this.logger.trace(error.getDescription(), ex); + } + this.authenticationFailureHandler.onAuthenticationFailure(request, response, + new OAuth2AuthenticationException(error)); + } + finally { + SecurityContextHolder.clearContext(); + } + } + + /** + * Sets the {@link AuthenticationConverter} used when attempting to extract a Client + * Registration Request from {@link HttpServletRequest} to an instance of + * {@link OAuth2ClientRegistrationAuthenticationToken} used for authenticating the + * request. + * @param authenticationConverter an {@link AuthenticationConverter} used when + * attempting to extract a Client Registration Request from {@link HttpServletRequest} + */ + public void setAuthenticationConverter(AuthenticationConverter authenticationConverter) { + Assert.notNull(authenticationConverter, "authenticationConverter cannot be null"); + this.authenticationConverter = authenticationConverter; + } + + /** + * Sets the {@link AuthenticationSuccessHandler} used for handling an + * {@link OAuth2ClientRegistrationAuthenticationToken} and returning the + * {@link OAuth2ClientRegistration Client Registration Response}. + * @param authenticationSuccessHandler the {@link AuthenticationSuccessHandler} used + * for handling an {@link OAuth2ClientRegistrationAuthenticationToken} + */ + public void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) { + Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null"); + this.authenticationSuccessHandler = authenticationSuccessHandler; + } + + /** + * Sets the {@link AuthenticationFailureHandler} used for handling an + * {@link OAuth2AuthenticationException} and returning the {@link OAuth2Error Error + * Response}. + * @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used + * for handling an {@link OAuth2AuthenticationException} + */ + public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) { + Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null"); + this.authenticationFailureHandler = authenticationFailureHandler; + } + + private void sendClientRegistrationResponse(HttpServletRequest request, HttpServletResponse response, + Authentication authentication) throws IOException { + OAuth2ClientRegistration clientRegistration = ((OAuth2ClientRegistrationAuthenticationToken) authentication) + .getClientRegistration(); + ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + httpResponse.setStatusCode(HttpStatus.CREATED); + this.clientRegistrationHttpMessageConverter.write(clientRegistration, null, httpResponse); + } + + private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response, + AuthenticationException authenticationException) throws IOException { + OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError(); + HttpStatus httpStatus = HttpStatus.BAD_REQUEST; + if (OAuth2ErrorCodes.INVALID_TOKEN.equals(error.getErrorCode())) { + httpStatus = HttpStatus.UNAUTHORIZED; + } + else if (OAuth2ErrorCodes.INSUFFICIENT_SCOPE.equals(error.getErrorCode())) { + httpStatus = HttpStatus.FORBIDDEN; + } + else if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) { + httpStatus = HttpStatus.UNAUTHORIZED; + } + ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + httpResponse.setStatusCode(httpStatus); + this.errorHttpResponseConverter.write(error, null, httpResponse); + } + +} diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ClientRegistrationAuthenticationConverter.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ClientRegistrationAuthenticationConverter.java new file mode 100644 index 0000000000..4741384fad --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ClientRegistrationAuthenticationConverter.java @@ -0,0 +1,69 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.web.authentication; + +import jakarta.servlet.http.HttpServletRequest; + +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientRegistrationAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.http.converter.OAuth2ClientRegistrationHttpMessageConverter; +import org.springframework.security.oauth2.server.authorization.web.OAuth2ClientRegistrationEndpointFilter; +import org.springframework.security.web.authentication.AuthenticationConverter; + +/** + * Attempts to extract an OAuth 2.0 Dynamic Client Registration Request from + * {@link HttpServletRequest} and then converts to an + * {@link OAuth2ClientRegistrationAuthenticationToken} used for authenticating the + * request. + * + * @author Joe Grandja + * @since 7.0 + * @see AuthenticationConverter + * @see OAuth2ClientRegistrationAuthenticationToken + * @see OAuth2ClientRegistrationEndpointFilter + */ +public final class OAuth2ClientRegistrationAuthenticationConverter implements AuthenticationConverter { + + private final HttpMessageConverter clientRegistrationHttpMessageConverter = new OAuth2ClientRegistrationHttpMessageConverter(); + + @Override + public Authentication convert(HttpServletRequest request) { + Authentication principal = SecurityContextHolder.getContext().getAuthentication(); + + OAuth2ClientRegistration clientRegistration; + try { + clientRegistration = this.clientRegistrationHttpMessageConverter.read(OAuth2ClientRegistration.class, + new ServletServerHttpRequest(request)); + } + catch (Exception ex) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, + "OAuth 2.0 Client Registration Error: " + ex.getMessage(), + "https://datatracker.ietf.org/doc/html/rfc7591#section-3.2.2"); + throw new OAuth2AuthenticationException(error, ex); + } + + return new OAuth2ClientRegistrationAuthenticationToken(principal, clientRegistration); + } + +} diff --git a/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2ClientRegistrationTests.java b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2ClientRegistrationTests.java new file mode 100644 index 0000000000..b51b93c3bf --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2ClientRegistrationTests.java @@ -0,0 +1,360 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization; + +import java.net.URL; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; + +import org.junit.jupiter.api.Test; + +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link OAuth2ClientRegistration}. + * + * @author Joe Grandja + */ +public class OAuth2ClientRegistrationTests { + + @Test + public void buildWhenAllClaimsProvidedThenCreated() throws Exception { + // @formatter:off + Instant clientIdIssuedAt = Instant.now(); + Instant clientSecretExpiresAt = clientIdIssuedAt.plus(30, ChronoUnit.DAYS); + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .clientId("client-id") + .clientIdIssuedAt(clientIdIssuedAt) + .clientSecret("client-secret") + .clientSecretExpiresAt(clientSecretExpiresAt) + .clientName("client-name") + .redirectUri("https://client.example.com") + .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue()) + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .responseType(OAuth2AuthorizationResponseType.CODE.getValue()) + .scope("scope1") + .scope("scope2") + .jwkSetUrl("https://client.example.com/jwks") + .claim("a-claim", "a-value") + .build(); + // @formatter:on + + assertThat(clientRegistration.getClientId()).isEqualTo("client-id"); + assertThat(clientRegistration.getClientIdIssuedAt()).isEqualTo(clientIdIssuedAt); + assertThat(clientRegistration.getClientSecret()).isEqualTo("client-secret"); + assertThat(clientRegistration.getClientSecretExpiresAt()).isEqualTo(clientSecretExpiresAt); + assertThat(clientRegistration.getClientName()).isEqualTo("client-name"); + assertThat(clientRegistration.getRedirectUris()).containsOnly("https://client.example.com"); + assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()) + .isEqualTo(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue()); + assertThat(clientRegistration.getGrantTypes()).containsExactlyInAnyOrder("authorization_code", + "client_credentials"); + assertThat(clientRegistration.getResponseTypes()).containsOnly("code"); + assertThat(clientRegistration.getScopes()).containsExactlyInAnyOrder("scope1", "scope2"); + assertThat(clientRegistration.getJwkSetUrl()).isEqualTo(new URL("https://client.example.com/jwks")); + assertThat(clientRegistration.getClaimAsString("a-claim")).isEqualTo("a-value"); + } + + @Test + public void withClaimsWhenClaimsProvidedThenCreated() throws Exception { + Instant clientIdIssuedAt = Instant.now(); + Instant clientSecretExpiresAt = clientIdIssuedAt.plus(30, ChronoUnit.DAYS); + HashMap claims = new HashMap<>(); + claims.put(OAuth2ClientMetadataClaimNames.CLIENT_ID, "client-id"); + claims.put(OAuth2ClientMetadataClaimNames.CLIENT_ID_ISSUED_AT, clientIdIssuedAt); + claims.put(OAuth2ClientMetadataClaimNames.CLIENT_SECRET, "client-secret"); + claims.put(OAuth2ClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT, clientSecretExpiresAt); + claims.put(OAuth2ClientMetadataClaimNames.CLIENT_NAME, "client-name"); + claims.put(OAuth2ClientMetadataClaimNames.REDIRECT_URIS, + Collections.singletonList("https://client.example.com")); + claims.put(OAuth2ClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD, + ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue()); + claims.put(OAuth2ClientMetadataClaimNames.GRANT_TYPES, + Arrays.asList(AuthorizationGrantType.AUTHORIZATION_CODE.getValue(), + AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())); + claims.put(OAuth2ClientMetadataClaimNames.RESPONSE_TYPES, Collections.singletonList("code")); + claims.put(OAuth2ClientMetadataClaimNames.SCOPE, Arrays.asList("scope1", "scope2")); + claims.put(OAuth2ClientMetadataClaimNames.JWKS_URI, "https://client.example.com/jwks"); + claims.put("a-claim", "a-value"); + + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.withClaims(claims).build(); + + assertThat(clientRegistration.getClientId()).isEqualTo("client-id"); + assertThat(clientRegistration.getClientIdIssuedAt()).isEqualTo(clientIdIssuedAt); + assertThat(clientRegistration.getClientSecret()).isEqualTo("client-secret"); + assertThat(clientRegistration.getClientSecretExpiresAt()).isEqualTo(clientSecretExpiresAt); + assertThat(clientRegistration.getClientName()).isEqualTo("client-name"); + assertThat(clientRegistration.getRedirectUris()).containsOnly("https://client.example.com"); + assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()) + .isEqualTo(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue()); + assertThat(clientRegistration.getGrantTypes()).containsExactlyInAnyOrder("authorization_code", + "client_credentials"); + assertThat(clientRegistration.getResponseTypes()).containsOnly("code"); + assertThat(clientRegistration.getScopes()).containsExactlyInAnyOrder("scope1", "scope2"); + assertThat(clientRegistration.getJwkSetUrl()).isEqualTo(new URL("https://client.example.com/jwks")); + assertThat(clientRegistration.getClaimAsString("a-claim")).isEqualTo("a-value"); + } + + @Test + public void withClaimsWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> OAuth2ClientRegistration.withClaims(null)) + .withMessage("claims cannot be empty"); + } + + @Test + public void withClaimsWhenEmptyThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> OAuth2ClientRegistration.withClaims(Collections.emptyMap())) + .withMessage("claims cannot be empty"); + } + + @Test + public void buildWhenMissingClientIdThenThrowIllegalArgumentException() { + OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder().clientIdIssuedAt(Instant.now()); + + assertThatIllegalArgumentException().isThrownBy(builder::build).withMessage("client_id cannot be null"); + } + + @Test + public void buildWhenClientSecretAndMissingClientIdThenThrowIllegalArgumentException() { + OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder().clientSecret("client-secret"); + + assertThatIllegalArgumentException().isThrownBy(builder::build).withMessage("client_id cannot be null"); + } + + @Test + public void buildWhenClientIdIssuedAtNotInstantThenThrowIllegalArgumentException() { + // @formatter:off + OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder() + .clientId("client-id") + .claim(OAuth2ClientMetadataClaimNames.CLIENT_ID_ISSUED_AT, "clientIdIssuedAt"); + // @formatter:on + + assertThatIllegalArgumentException().isThrownBy(builder::build) + .withMessageStartingWith("client_id_issued_at must be of type Instant"); + } + + @Test + public void buildWhenMissingClientSecretThenThrowIllegalArgumentException() { + // @formatter:off + OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder() + .clientId("client-id") + .clientIdIssuedAt(Instant.now()) + .clientSecretExpiresAt(Instant.now().plus(30, ChronoUnit.DAYS)); + // @formatter:on + + assertThatIllegalArgumentException().isThrownBy(builder::build).withMessage("client_secret cannot be null"); + } + + @Test + public void buildWhenClientSecretExpiresAtNotInstantThenThrowIllegalArgumentException() { + // @formatter:off + OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder() + .clientId("client-id") + .clientIdIssuedAt(Instant.now()) + .clientSecret("client-secret") + .claim(OAuth2ClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT, "clientSecretExpiresAt"); + // @formatter:on + + assertThatIllegalArgumentException().isThrownBy(builder::build) + .withMessageStartingWith("client_secret_expires_at must be of type Instant"); + } + + @Test + public void buildWhenRedirectUrisNotListThenThrowIllegalArgumentException() { + OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder() + .claim(OAuth2ClientMetadataClaimNames.REDIRECT_URIS, "redirectUris"); + + assertThatIllegalArgumentException().isThrownBy(builder::build) + .withMessageStartingWith("redirect_uris must be of type List"); + } + + @Test + public void buildWhenRedirectUrisEmptyListThenThrowIllegalArgumentException() { + OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder() + .claim(OAuth2ClientMetadataClaimNames.REDIRECT_URIS, Collections.emptyList()); + + assertThatIllegalArgumentException().isThrownBy(builder::build).withMessage("redirect_uris cannot be empty"); + } + + @Test + public void buildWhenRedirectUrisAddingOrRemovingThenCorrectValues() { + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .redirectUri("https://client1.example.com") + .redirectUris((redirectUris) -> { + redirectUris.clear(); + redirectUris.add("https://client2.example.com"); + }) + .build(); + // @formatter:on + + assertThat(clientRegistration.getRedirectUris()).containsExactly("https://client2.example.com"); + } + + @Test + public void buildWhenGrantTypesNotListThenThrowIllegalArgumentException() { + OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder() + .claim(OAuth2ClientMetadataClaimNames.GRANT_TYPES, "grantTypes"); + + assertThatIllegalArgumentException().isThrownBy(builder::build) + .withMessageStartingWith("grant_types must be of type List"); + } + + @Test + public void buildWhenGrantTypesEmptyListThenThrowIllegalArgumentException() { + OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder() + .claim(OAuth2ClientMetadataClaimNames.GRANT_TYPES, Collections.emptyList()); + + assertThatIllegalArgumentException().isThrownBy(builder::build).withMessage("grant_types cannot be empty"); + } + + @Test + public void buildWhenGrantTypesAddingOrRemovingThenCorrectValues() { + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .grantType("authorization_code") + .grantTypes((grantTypes) -> { + grantTypes.clear(); + grantTypes.add("client_credentials"); + }) + .build(); + // @formatter:on + + assertThat(clientRegistration.getGrantTypes()).containsExactly("client_credentials"); + } + + @Test + public void buildWhenResponseTypesNotListThenThrowIllegalArgumentException() { + OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder() + .claim(OAuth2ClientMetadataClaimNames.RESPONSE_TYPES, "responseTypes"); + + assertThatIllegalArgumentException().isThrownBy(builder::build) + .withMessageStartingWith("response_types must be of type List"); + } + + @Test + public void buildWhenResponseTypesEmptyListThenThrowIllegalArgumentException() { + OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder() + .claim(OAuth2ClientMetadataClaimNames.RESPONSE_TYPES, Collections.emptyList()); + + assertThatIllegalArgumentException().isThrownBy(builder::build).withMessage("response_types cannot be empty"); + } + + @Test + public void buildWhenResponseTypesAddingOrRemovingThenCorrectValues() { + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .responseType("token") + .responseTypes((responseTypes) -> { + responseTypes.clear(); + responseTypes.add("code"); + }) + .build(); + // @formatter:on + + assertThat(clientRegistration.getResponseTypes()).containsExactly("code"); + } + + @Test + public void buildWhenScopesNotListThenThrowIllegalArgumentException() { + OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder() + .claim(OAuth2ClientMetadataClaimNames.SCOPE, "scopes"); + + assertThatIllegalArgumentException().isThrownBy(builder::build) + .withMessageStartingWith("scope must be of type List"); + } + + @Test + public void buildWhenScopesEmptyListThenThrowIllegalArgumentException() { + OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder() + .claim(OAuth2ClientMetadataClaimNames.SCOPE, Collections.emptyList()); + + assertThatIllegalArgumentException().isThrownBy(builder::build).withMessage("scope cannot be empty"); + } + + @Test + public void buildWhenScopesAddingOrRemovingThenCorrectValues() { + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .scope("should-be-removed") + .scopes((scopes) -> { + scopes.clear(); + scopes.add("scope1"); + }) + .build(); + // @formatter:on + + assertThat(clientRegistration.getScopes()).containsExactly("scope1"); + } + + @Test + public void buildWhenJwksUriNotUrlThenThrowIllegalArgumentException() { + OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder() + .claim(OAuth2ClientMetadataClaimNames.JWKS_URI, "not an url"); + + assertThatIllegalArgumentException().isThrownBy(builder::build).withMessage("jwksUri must be a valid URL"); + } + + @Test + public void claimWhenNameNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> OAuth2ClientRegistration.builder().claim(null, "claim-value")) + .withMessage("name cannot be empty"); + } + + @Test + public void claimWhenValueNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> OAuth2ClientRegistration.builder().claim("claim-name", null)) + .withMessage("value cannot be null"); + } + + @Test + public void claimsWhenRemovingClaimThenNotPresent() { + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .redirectUri("https://client.example.com") + .claim("claim-name", "claim-value") + .claims((claims) -> claims.remove("claim-name")) + .build(); + // @formatter:on + + assertThat(clientRegistration.hasClaim("claim-name")).isFalse(); + } + + @Test + public void claimsWhenAddingClaimThenPresent() { + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .claim("claim-name", "claim-value") + .build(); + // @formatter:on + + assertThat(clientRegistration.hasClaim("claim-name")).isTrue(); + } + +} diff --git a/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientRegistrationAuthenticationProviderTests.java b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientRegistrationAuthenticationProviderTests.java new file mode 100644 index 0000000000..b4fdbf8954 --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientRegistrationAuthenticationProviderTests.java @@ -0,0 +1,500 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.crypto.password.NoOpPasswordEncoder; +import org.springframework.security.crypto.password.PasswordEncoder; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jwt.JwsHeader; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.TestJwsHeaders; +import org.springframework.security.oauth2.jwt.TestJwtClaimsSets; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientMetadataClaimNames; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +/** + * Tests for {@link OAuth2ClientRegistrationAuthenticationProvider}. + * + * @author Joe Grandja + */ +public class OAuth2ClientRegistrationAuthenticationProviderTests { + + private RegisteredClientRepository registeredClientRepository; + + private OAuth2AuthorizationService authorizationService; + + private PasswordEncoder passwordEncoder; + + private OAuth2ClientRegistrationAuthenticationProvider authenticationProvider; + + @BeforeEach + public void setUp() { + this.registeredClientRepository = mock(RegisteredClientRepository.class); + this.authorizationService = mock(OAuth2AuthorizationService.class); + this.passwordEncoder = spy(new PasswordEncoder() { + @Override + public String encode(CharSequence rawPassword) { + return NoOpPasswordEncoder.getInstance().encode(rawPassword); + } + + @Override + public boolean matches(CharSequence rawPassword, String encodedPassword) { + return NoOpPasswordEncoder.getInstance().matches(rawPassword, encodedPassword); + } + }); + this.authenticationProvider = new OAuth2ClientRegistrationAuthenticationProvider( + this.registeredClientRepository, this.authorizationService); + this.authenticationProvider.setPasswordEncoder(this.passwordEncoder); + } + + @Test + public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2ClientRegistrationAuthenticationProvider(null, this.authorizationService)) + .withMessage("registeredClientRepository cannot be null"); + } + + @Test + public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2ClientRegistrationAuthenticationProvider(this.registeredClientRepository, null)) + .withMessage("authorizationService cannot be null"); + } + + @Test + public void setRegisteredClientConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authenticationProvider.setRegisteredClientConverter(null)) + .withMessage("registeredClientConverter cannot be null"); + } + + @Test + public void setClientRegistrationConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authenticationProvider.setClientRegistrationConverter(null)) + .withMessage("clientRegistrationConverter cannot be null"); + } + + @Test + public void setPasswordEncoderWhenNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.authenticationProvider.setPasswordEncoder(null)) + .withMessage("passwordEncoder cannot be null"); + } + + @Test + public void supportsWhenTypeOAuth2ClientRegistrationAuthenticationTokenThenReturnTrue() { + assertThat(this.authenticationProvider.supports(OAuth2ClientRegistrationAuthenticationToken.class)).isTrue(); + } + + @Test + public void authenticateWhenPrincipalNotOAuth2TokenAuthenticationTokenThenThrowOAuth2AuthenticationException() { + TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "credentials"); + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .redirectUri("https://client.example.com") + .build(); + + OAuth2ClientRegistrationAuthenticationToken authentication = new OAuth2ClientRegistrationAuthenticationToken( + principal, clientRegistration); + + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + } + + @Test + public void authenticateWhenPrincipalNotAuthenticatedThenThrowOAuth2AuthenticationException() { + JwtAuthenticationToken principal = new JwtAuthenticationToken(createJwtClientRegistration()); + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .redirectUri("https://client.example.com") + .build(); + + OAuth2ClientRegistrationAuthenticationToken authentication = new OAuth2ClientRegistrationAuthenticationToken( + principal, clientRegistration); + + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + } + + @Test + public void authenticateWhenAccessTokenNotFoundThenThrowOAuth2AuthenticationException() { + Jwt jwt = createJwtClientRegistration(); + JwtAuthenticationToken principal = new JwtAuthenticationToken(jwt, + AuthorityUtils.createAuthorityList("SCOPE_client.create")); + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .redirectUri("https://client.example.com") + .build(); + + OAuth2ClientRegistrationAuthenticationToken authentication = new OAuth2ClientRegistrationAuthenticationToken( + principal, clientRegistration); + + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + verify(this.authorizationService).findByToken(eq(jwt.getTokenValue()), eq(OAuth2TokenType.ACCESS_TOKEN)); + } + + @Test + public void authenticateWhenAccessTokenNotActiveThenThrowOAuth2AuthenticationException() { + Jwt jwt = createJwtClientRegistration(); + OAuth2AccessToken jwtAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, jwtAccessToken, jwt.getClaims()) + .invalidate(jwtAccessToken) + .build(); + given(this.authorizationService.findByToken(eq(jwtAccessToken.getTokenValue()), + eq(OAuth2TokenType.ACCESS_TOKEN))) + .willReturn(authorization); + + JwtAuthenticationToken principal = new JwtAuthenticationToken(jwt, + AuthorityUtils.createAuthorityList("SCOPE_client.create")); + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .redirectUri("https://client.example.com") + .build(); + + OAuth2ClientRegistrationAuthenticationToken authentication = new OAuth2ClientRegistrationAuthenticationToken( + principal, clientRegistration); + + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + verify(this.authorizationService).findByToken(eq(jwtAccessToken.getTokenValue()), + eq(OAuth2TokenType.ACCESS_TOKEN)); + } + + @Test + public void authenticateWhenAccessTokenNotAuthorizedThenThrowOAuth2AuthenticationException() { + Jwt jwt = createJwt(Collections.singleton("unauthorized.scope")); + OAuth2AccessToken jwtAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, jwtAccessToken, jwt.getClaims()) + .build(); + given(this.authorizationService.findByToken(eq(jwtAccessToken.getTokenValue()), + eq(OAuth2TokenType.ACCESS_TOKEN))) + .willReturn(authorization); + + JwtAuthenticationToken principal = new JwtAuthenticationToken(jwt, + AuthorityUtils.createAuthorityList("SCOPE_unauthorized.scope")); + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .redirectUri("https://client.example.com") + .build(); + + OAuth2ClientRegistrationAuthenticationToken authentication = new OAuth2ClientRegistrationAuthenticationToken( + principal, clientRegistration); + + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); + verify(this.authorizationService).findByToken(eq(jwtAccessToken.getTokenValue()), + eq(OAuth2TokenType.ACCESS_TOKEN)); + } + + @Test + public void authenticateWhenAccessTokenContainsRequiredScopeAndAdditionalScopeThenThrowOAuth2AuthenticationException() { + Jwt jwt = createJwt(new HashSet<>(Arrays.asList("client.create", "scope1"))); + OAuth2AccessToken jwtAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, jwtAccessToken, jwt.getClaims()) + .build(); + given(this.authorizationService.findByToken(eq(jwtAccessToken.getTokenValue()), + eq(OAuth2TokenType.ACCESS_TOKEN))) + .willReturn(authorization); + + JwtAuthenticationToken principal = new JwtAuthenticationToken(jwt, + AuthorityUtils.createAuthorityList("SCOPE_client.create", "SCOPE_scope1")); + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .redirectUri("https://client.example.com") + .build(); + + OAuth2ClientRegistrationAuthenticationToken authentication = new OAuth2ClientRegistrationAuthenticationToken( + principal, clientRegistration); + + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + verify(this.authorizationService).findByToken(eq(jwtAccessToken.getTokenValue()), + eq(OAuth2TokenType.ACCESS_TOKEN)); + } + + @Test + public void authenticateWhenInvalidRedirectUriThenThrowOAuth2AuthenticationException() { + Jwt jwt = createJwtClientRegistration(); + OAuth2AccessToken jwtAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, jwtAccessToken, jwt.getClaims()) + .build(); + given(this.authorizationService.findByToken(eq(jwtAccessToken.getTokenValue()), + eq(OAuth2TokenType.ACCESS_TOKEN))) + .willReturn(authorization); + + JwtAuthenticationToken principal = new JwtAuthenticationToken(jwt, + AuthorityUtils.createAuthorityList("SCOPE_client.create")); + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .redirectUri("invalid uri") + .build(); + // @formatter:on + + OAuth2ClientRegistrationAuthenticationToken authentication = new OAuth2ClientRegistrationAuthenticationToken( + principal, clientRegistration); + + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .satisfies((error) -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REDIRECT_URI); + assertThat(error.getDescription()).contains(OAuth2ClientMetadataClaimNames.REDIRECT_URIS); + }); + verify(this.authorizationService).findByToken(eq(jwtAccessToken.getTokenValue()), + eq(OAuth2TokenType.ACCESS_TOKEN)); + } + + @Test + public void authenticateWhenRedirectUriContainsFragmentThenThrowOAuth2AuthenticationException() { + Jwt jwt = createJwtClientRegistration(); + OAuth2AccessToken jwtAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, jwtAccessToken, jwt.getClaims()) + .build(); + given(this.authorizationService.findByToken(eq(jwtAccessToken.getTokenValue()), + eq(OAuth2TokenType.ACCESS_TOKEN))) + .willReturn(authorization); + + JwtAuthenticationToken principal = new JwtAuthenticationToken(jwt, + AuthorityUtils.createAuthorityList("SCOPE_client.create")); + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .redirectUri("https://client.example.com#fragment") + .build(); + // @formatter:on + + OAuth2ClientRegistrationAuthenticationToken authentication = new OAuth2ClientRegistrationAuthenticationToken( + principal, clientRegistration); + + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .satisfies((error) -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REDIRECT_URI); + assertThat(error.getDescription()).contains(OAuth2ClientMetadataClaimNames.REDIRECT_URIS); + }); + verify(this.authorizationService).findByToken(eq(jwtAccessToken.getTokenValue()), + eq(OAuth2TokenType.ACCESS_TOKEN)); + } + + @Test + public void authenticateWhenValidAccessTokenThenReturnClientRegistration() { + Jwt jwt = createJwtClientRegistration(); + OAuth2AccessToken jwtAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, jwtAccessToken, jwt.getClaims()) + .build(); + given(this.authorizationService.findByToken(eq(jwtAccessToken.getTokenValue()), + eq(OAuth2TokenType.ACCESS_TOKEN))) + .willReturn(authorization); + + JwtAuthenticationToken principal = new JwtAuthenticationToken(jwt, + AuthorityUtils.createAuthorityList("SCOPE_client.create")); + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .scope("scope1") + .scope("scope2") + .build(); + // @formatter:on + + OAuth2ClientRegistrationAuthenticationToken authentication = new OAuth2ClientRegistrationAuthenticationToken( + principal, clientRegistration); + OAuth2ClientRegistrationAuthenticationToken authenticationResult = (OAuth2ClientRegistrationAuthenticationToken) this.authenticationProvider + .authenticate(authentication); + + ArgumentCaptor registeredClientCaptor = ArgumentCaptor.forClass(RegisteredClient.class); + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + + verify(this.authorizationService).findByToken(eq(jwtAccessToken.getTokenValue()), + eq(OAuth2TokenType.ACCESS_TOKEN)); + verify(this.registeredClientRepository).save(registeredClientCaptor.capture()); + verify(this.authorizationService).save(authorizationCaptor.capture()); + verify(this.passwordEncoder).encode(any()); + + // assert "initial" access token is invalidated + OAuth2Authorization authorizationResult = authorizationCaptor.getValue(); + assertThat(authorizationResult.getAccessToken().isInvalidated()).isTrue(); + if (authorizationResult.getRefreshToken() != null) { + assertThat(authorizationResult.getRefreshToken().isInvalidated()).isTrue(); + } + + assertClientRegistration(clientRegistration, authenticationResult.getClientRegistration(), + registeredClientCaptor.getValue()); + } + + @Test + public void authenticateWhenOpenRegistrationThenReturnClientRegistration() { + this.authenticationProvider.setOpenRegistrationAllowed(true); + + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .scope("scope1") + .scope("scope2") + .build(); + // @formatter:on + + OAuth2ClientRegistrationAuthenticationToken authentication = new OAuth2ClientRegistrationAuthenticationToken( + null, clientRegistration); + OAuth2ClientRegistrationAuthenticationToken authenticationResult = (OAuth2ClientRegistrationAuthenticationToken) this.authenticationProvider + .authenticate(authentication); + + ArgumentCaptor registeredClientCaptor = ArgumentCaptor.forClass(RegisteredClient.class); + + verifyNoInteractions(this.authorizationService); + verify(this.registeredClientRepository).save(registeredClientCaptor.capture()); + verify(this.passwordEncoder).encode(any()); + + assertClientRegistration(clientRegistration, authenticationResult.getClientRegistration(), + registeredClientCaptor.getValue()); + } + + private static void assertClientRegistration(OAuth2ClientRegistration clientRegistrationRequest, + OAuth2ClientRegistration clientRegistrationResult, RegisteredClient registeredClient) { + + assertThat(registeredClient.getId()).isNotNull(); + assertThat(registeredClient.getClientId()).isNotNull(); + assertThat(registeredClient.getClientIdIssuedAt()).isNotNull(); + assertThat(registeredClient.getClientSecret()).isNotNull(); + assertThat(registeredClient.getClientName()).isEqualTo(clientRegistrationRequest.getClientName()); + assertThat(registeredClient.getClientAuthenticationMethods()) + .containsExactly(ClientAuthenticationMethod.CLIENT_SECRET_BASIC); + assertThat(registeredClient.getRedirectUris()).containsExactly("https://client.example.com"); + assertThat(registeredClient.getAuthorizationGrantTypes()).containsExactlyInAnyOrder( + AuthorizationGrantType.AUTHORIZATION_CODE, AuthorizationGrantType.CLIENT_CREDENTIALS); + assertThat(registeredClient.getScopes()).containsExactlyInAnyOrder("scope1", "scope2"); + assertThat(registeredClient.getClientSettings().isRequireProofKey()).isTrue(); + assertThat(registeredClient.getClientSettings().isRequireAuthorizationConsent()).isTrue(); + + assertThat(clientRegistrationResult.getClientId()).isEqualTo(registeredClient.getClientId()); + assertThat(clientRegistrationResult.getClientIdIssuedAt()).isEqualTo(registeredClient.getClientIdIssuedAt()); + assertThat(clientRegistrationResult.getClientSecret()).isEqualTo(registeredClient.getClientSecret()); + assertThat(clientRegistrationResult.getClientSecretExpiresAt()) + .isEqualTo(registeredClient.getClientSecretExpiresAt()); + assertThat(clientRegistrationResult.getClientName()).isEqualTo(registeredClient.getClientName()); + assertThat(clientRegistrationResult.getRedirectUris()) + .containsExactlyInAnyOrderElementsOf(registeredClient.getRedirectUris()); + + List grantTypes = new ArrayList<>(); + registeredClient.getAuthorizationGrantTypes() + .forEach((authorizationGrantType) -> grantTypes.add(authorizationGrantType.getValue())); + assertThat(clientRegistrationResult.getGrantTypes()).containsExactlyInAnyOrderElementsOf(grantTypes); + + assertThat(clientRegistrationResult.getResponseTypes()) + .containsExactly(OAuth2AuthorizationResponseType.CODE.getValue()); + assertThat(clientRegistrationResult.getScopes()) + .containsExactlyInAnyOrderElementsOf(registeredClient.getScopes()); + assertThat(clientRegistrationResult.getTokenEndpointAuthenticationMethod()) + .isEqualTo(registeredClient.getClientAuthenticationMethods().iterator().next().getValue()); + } + + private static Jwt createJwtClientRegistration() { + return createJwt(Collections.singleton("client.create")); + } + + private static Jwt createJwt(Set scopes) { + // @formatter:off + JwsHeader jwsHeader = TestJwsHeaders.jwsHeader() + .build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet() + .claim(OAuth2ParameterNames.SCOPE, scopes) + .build(); + Jwt jwt = Jwt.withTokenValue("jwt-access-token") + .headers((headers) -> headers.putAll(jwsHeader.getHeaders())) + .claims((claims) -> claims.putAll(jwtClaimsSet.getClaims())) + .build(); + // @formatter:on + return jwt; + } + +} diff --git a/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/http/converter/OAuth2ClientRegistrationHttpMessageConverterTests.java b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/http/converter/OAuth2ClientRegistrationHttpMessageConverterTests.java new file mode 100644 index 0000000000..eac8747f80 --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/http/converter/OAuth2ClientRegistrationHttpMessageConverterTests.java @@ -0,0 +1,234 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.http.converter; + +import java.net.URL; +import java.time.Instant; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpStatus; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.HttpMessageNotWritableException; +import org.springframework.mock.http.MockHttpOutputMessage; +import org.springframework.mock.http.client.MockClientHttpResponse; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link OAuth2ClientRegistrationHttpMessageConverter} + * + * @author Joe Grandja + * @since 7.0 + */ +public class OAuth2ClientRegistrationHttpMessageConverterTests { + + private final OAuth2ClientRegistrationHttpMessageConverter messageConverter = new OAuth2ClientRegistrationHttpMessageConverter(); + + @Test + public void supportsWhenOAuth2ClientRegistrationThenTrue() { + assertThat(this.messageConverter.supports(OAuth2ClientRegistration.class)).isTrue(); + } + + @Test + public void setClientRegistrationConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.messageConverter.setClientRegistrationConverter(null)) + .withMessageContaining("clientRegistrationConverter cannot be null"); + } + + @Test + public void setClientRegistrationParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.messageConverter.setClientRegistrationParametersConverter(null)) + .withMessageContaining("clientRegistrationParametersConverter cannot be null"); + } + + @Test + public void readInternalWhenValidParametersThenSuccess() throws Exception { + // @formatter:off + String clientRegistrationRequest = "{\n" + + " \"client_id\": \"client-id\",\n" + + " \"client_id_issued_at\": 1607633867,\n" + + " \"client_secret\": \"client-secret\",\n" + + " \"client_secret_expires_at\": 1607637467,\n" + + " \"client_name\": \"client-name\",\n" + + " \"redirect_uris\": [\n" + + " \"https://client.example.com\"\n" + + " ],\n" + + " \"token_endpoint_auth_method\": \"client_secret_basic\",\n" + + " \"grant_types\": [\n" + + " \"authorization_code\",\n" + + " \"client_credentials\"\n" + + " ],\n" + + " \"response_types\":[\n" + + " \"code\"\n" + + " ],\n" + + " \"scope\": \"scope1 scope2\",\n" + + " \"jwks_uri\": \"https://client.example.com/jwks\",\n" + + " \"a-claim\": \"a-value\"\n" + + "}\n"; + // @formatter:on + MockClientHttpResponse response = new MockClientHttpResponse(clientRegistrationRequest.getBytes(), + HttpStatus.OK); + OAuth2ClientRegistration clientRegistration = this.messageConverter.readInternal(OAuth2ClientRegistration.class, + response); + + assertThat(clientRegistration.getClientId()).isEqualTo("client-id"); + assertThat(clientRegistration.getClientIdIssuedAt()).isEqualTo(Instant.ofEpochSecond(1607633867L)); + assertThat(clientRegistration.getClientSecret()).isEqualTo("client-secret"); + assertThat(clientRegistration.getClientSecretExpiresAt()).isEqualTo(Instant.ofEpochSecond(1607637467L)); + assertThat(clientRegistration.getClientName()).isEqualTo("client-name"); + assertThat(clientRegistration.getRedirectUris()).containsOnly("https://client.example.com"); + assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()) + .isEqualTo(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue()); + assertThat(clientRegistration.getGrantTypes()).containsExactlyInAnyOrder("authorization_code", + "client_credentials"); + assertThat(clientRegistration.getResponseTypes()).containsOnly("code"); + assertThat(clientRegistration.getScopes()).containsExactlyInAnyOrder("scope1", "scope2"); + assertThat(clientRegistration.getJwkSetUrl()).isEqualTo(new URL("https://client.example.com/jwks")); + assertThat(clientRegistration.getClaimAsString("a-claim")).isEqualTo("a-value"); + } + + @Test + public void readInternalWhenClientSecretNoExpiryThenSuccess() { + // @formatter:off + String clientRegistrationRequest = "{\n" + + " \"client_id\": \"client-id\",\n" + + " \"client_secret\": \"client-secret\",\n" + + " \"client_secret_expires_at\": 0,\n" + + " \"redirect_uris\": [\n" + + " \"https://client.example.com\"\n" + + " ]\n" + + "}\n"; + // @formatter:on + MockClientHttpResponse response = new MockClientHttpResponse(clientRegistrationRequest.getBytes(), + HttpStatus.OK); + OAuth2ClientRegistration clientRegistration = this.messageConverter.readInternal(OAuth2ClientRegistration.class, + response); + + assertThat(clientRegistration.getClaims()).hasSize(3); + assertThat(clientRegistration.getClientId()).isEqualTo("client-id"); + assertThat(clientRegistration.getClientSecret()).isEqualTo("client-secret"); + assertThat(clientRegistration.getClientSecretExpiresAt()).isNull(); + assertThat(clientRegistration.getRedirectUris()).containsOnly("https://client.example.com"); + } + + @Test + public void readInternalWhenFailingConverterThenThrowException() { + String errorMessage = "this is not a valid converter"; + this.messageConverter.setClientRegistrationConverter((source) -> { + throw new RuntimeException(errorMessage); + }); + MockClientHttpResponse response = new MockClientHttpResponse("{}".getBytes(), HttpStatus.OK); + + assertThatExceptionOfType(HttpMessageNotReadableException.class) + .isThrownBy(() -> this.messageConverter.readInternal(OAuth2ClientRegistration.class, response)) + .withMessageContaining("An error occurred reading the OAuth 2.0 Client Registration") + .withMessageContaining(errorMessage); + } + + @Test + public void writeInternalWhenClientRegistrationThenSuccess() { + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .clientId("client-id") + .clientIdIssuedAt(Instant.ofEpochSecond(1607633867)) + .clientSecret("client-secret") + .clientSecretExpiresAt(Instant.ofEpochSecond(1607637467)) + .clientName("client-name") + .redirectUri("https://client.example.com") + .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue()) + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .responseType(OAuth2AuthorizationResponseType.CODE.getValue()) + .scope("scope1") + .scope("scope2") + .jwkSetUrl("https://client.example.com/jwks") + .claim("a-claim", "a-value") + .build(); + // @formatter:on + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + this.messageConverter.writeInternal(clientRegistration, outputMessage); + + String clientRegistrationResponse = outputMessage.getBodyAsString(); + assertThat(clientRegistrationResponse).contains("\"client_id\":\"client-id\""); + assertThat(clientRegistrationResponse).contains("\"client_id_issued_at\":1607633867"); + assertThat(clientRegistrationResponse).contains("\"client_secret\":\"client-secret\""); + assertThat(clientRegistrationResponse).contains("\"client_secret_expires_at\":1607637467"); + assertThat(clientRegistrationResponse).contains("\"client_name\":\"client-name\""); + assertThat(clientRegistrationResponse).contains("\"redirect_uris\":[\"https://client.example.com\"]"); + assertThat(clientRegistrationResponse).contains("\"token_endpoint_auth_method\":\"client_secret_basic\""); + assertThat(clientRegistrationResponse) + .contains("\"grant_types\":[\"authorization_code\",\"client_credentials\"]"); + assertThat(clientRegistrationResponse).contains("\"response_types\":[\"code\"]"); + assertThat(clientRegistrationResponse).contains("\"scope\":\"scope1 scope2\""); + assertThat(clientRegistrationResponse).contains("\"jwks_uri\":\"https://client.example.com/jwks\""); + assertThat(clientRegistrationResponse).contains("\"a-claim\":\"a-value\""); + } + + @Test + public void writeInternalWhenClientSecretNoExpiryThenSuccess() { + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .clientId("client-id") + .clientSecret("client-secret") + .redirectUri("https://client.example.com") + .build(); + // @formatter:on + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + this.messageConverter.writeInternal(clientRegistration, outputMessage); + + String clientRegistrationResponse = outputMessage.getBodyAsString(); + assertThat(clientRegistrationResponse).contains("\"client_id\":\"client-id\""); + assertThat(clientRegistrationResponse).contains("\"client_secret\":\"client-secret\""); + assertThat(clientRegistrationResponse).contains("\"client_secret_expires_at\":0"); + assertThat(clientRegistrationResponse).contains("\"redirect_uris\":[\"https://client.example.com\"]"); + } + + @Test + public void writeInternalWhenWriteFailsThenThrowException() { + String errorMessage = "this is not a valid converter"; + Converter> failingConverter = (source) -> { + throw new RuntimeException(errorMessage); + }; + this.messageConverter.setClientRegistrationParametersConverter(failingConverter); + + // @formatter:off + OAuth2ClientRegistration clientRegistration = OAuth2ClientRegistration.builder() + .redirectUri("https://client.example.com") + .build(); + // @formatter:off + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + + assertThatExceptionOfType(HttpMessageNotWritableException.class).isThrownBy(() -> this.messageConverter.writeInternal(clientRegistration, outputMessage)) + .withMessageContaining("An error occurred writing the OAuth 2.0 Client Registration") + .withMessageContaining(errorMessage); + } + +} diff --git a/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/settings/AuthorizationServerSettingsTests.java b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/settings/AuthorizationServerSettingsTests.java index ade47ff329..51e8ff4611 100644 --- a/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/settings/AuthorizationServerSettingsTests.java +++ b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/settings/AuthorizationServerSettingsTests.java @@ -41,6 +41,7 @@ public class AuthorizationServerSettingsTests { assertThat(authorizationServerSettings.getJwkSetEndpoint()).isEqualTo("/oauth2/jwks"); assertThat(authorizationServerSettings.getTokenRevocationEndpoint()).isEqualTo("/oauth2/revoke"); assertThat(authorizationServerSettings.getTokenIntrospectionEndpoint()).isEqualTo("/oauth2/introspect"); + assertThat(authorizationServerSettings.getClientRegistrationEndpoint()).isEqualTo("/oauth2/register"); assertThat(authorizationServerSettings.getOidcClientRegistrationEndpoint()).isEqualTo("/connect/register"); assertThat(authorizationServerSettings.getOidcUserInfoEndpoint()).isEqualTo("/userinfo"); assertThat(authorizationServerSettings.getOidcLogoutEndpoint()).isEqualTo("/connect/logout"); @@ -54,6 +55,7 @@ public class AuthorizationServerSettingsTests { String jwkSetEndpoint = "/oauth2/v1/jwks"; String tokenRevocationEndpoint = "/oauth2/v1/revoke"; String tokenIntrospectionEndpoint = "/oauth2/v1/introspect"; + String clientRegistrationEndpoint = "/oauth2/v1/register"; String oidcClientRegistrationEndpoint = "/connect/v1/register"; String oidcUserInfoEndpoint = "/connect/v1/userinfo"; String oidcLogoutEndpoint = "/connect/v1/logout"; @@ -68,6 +70,7 @@ public class AuthorizationServerSettingsTests { .tokenRevocationEndpoint(tokenRevocationEndpoint) .tokenIntrospectionEndpoint(tokenIntrospectionEndpoint) .tokenRevocationEndpoint(tokenRevocationEndpoint) + .clientRegistrationEndpoint(clientRegistrationEndpoint) .oidcClientRegistrationEndpoint(oidcClientRegistrationEndpoint) .oidcUserInfoEndpoint(oidcUserInfoEndpoint) .oidcLogoutEndpoint(oidcLogoutEndpoint) @@ -82,6 +85,7 @@ public class AuthorizationServerSettingsTests { assertThat(authorizationServerSettings.getJwkSetEndpoint()).isEqualTo(jwkSetEndpoint); assertThat(authorizationServerSettings.getTokenRevocationEndpoint()).isEqualTo(tokenRevocationEndpoint); assertThat(authorizationServerSettings.getTokenIntrospectionEndpoint()).isEqualTo(tokenIntrospectionEndpoint); + assertThat(authorizationServerSettings.getClientRegistrationEndpoint()).isEqualTo(clientRegistrationEndpoint); assertThat(authorizationServerSettings.getOidcClientRegistrationEndpoint()) .isEqualTo(oidcClientRegistrationEndpoint); assertThat(authorizationServerSettings.getOidcUserInfoEndpoint()).isEqualTo(oidcUserInfoEndpoint); @@ -111,6 +115,7 @@ public class AuthorizationServerSettingsTests { assertThat(authorizationServerSettings.getJwkSetEndpoint()).isEqualTo("/oauth2/jwks"); assertThat(authorizationServerSettings.getTokenRevocationEndpoint()).isEqualTo("/oauth2/revoke"); assertThat(authorizationServerSettings.getTokenIntrospectionEndpoint()).isEqualTo("/oauth2/introspect"); + assertThat(authorizationServerSettings.getClientRegistrationEndpoint()).isEqualTo("/oauth2/register"); assertThat(authorizationServerSettings.getOidcClientRegistrationEndpoint()).isEqualTo("/connect/register"); assertThat(authorizationServerSettings.getOidcUserInfoEndpoint()).isEqualTo("/userinfo"); assertThat(authorizationServerSettings.getOidcLogoutEndpoint()).isEqualTo("/connect/logout"); @@ -123,7 +128,7 @@ public class AuthorizationServerSettingsTests { .settings((settings) -> settings.put("name2", "value2")) .build(); - assertThat(authorizationServerSettings.getSettings()).hasSize(14); + assertThat(authorizationServerSettings.getSettings()).hasSize(15); assertThat(authorizationServerSettings.getSetting("name1")).isEqualTo("value1"); assertThat(authorizationServerSettings.getSetting("name2")).isEqualTo("value2"); } @@ -168,6 +173,13 @@ public class AuthorizationServerSettingsTests { .withMessage("value cannot be null"); } + @Test + public void clientRegistrationEndpointWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> AuthorizationServerSettings.builder().clientRegistrationEndpoint(null)) + .withMessage("value cannot be null"); + } + @Test public void oidcClientRegistrationEndpointWhenNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException() diff --git a/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientRegistrationEndpointFilterTests.java b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientRegistrationEndpointFilterTests.java new file mode 100644 index 0000000000..15d9c505b1 --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientRegistrationEndpointFilterTests.java @@ -0,0 +1,415 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.web; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.http.HttpStatus; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.mock.http.client.MockClientHttpRequest; +import org.springframework.mock.http.client.MockClientHttpResponse; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; +import org.springframework.security.oauth2.jwt.JwsHeader; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.TestJwsHeaders; +import org.springframework.security.oauth2.jwt.TestJwtClaimsSets; +import org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientRegistrationAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.http.converter.OAuth2ClientRegistrationHttpMessageConverter; +import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +/** + * Tests for {@link OAuth2ClientRegistrationEndpointFilter}. + * + * @author Joe Grandja + */ +public class OAuth2ClientRegistrationEndpointFilterTests { + + private static final String DEFAULT_OAUTH2_CLIENT_REGISTRATION_ENDPOINT_URI = "/oauth2/register"; + + private AuthenticationManager authenticationManager; + + private OAuth2ClientRegistrationEndpointFilter filter; + + private final HttpMessageConverter clientRegistrationHttpMessageConverter = new OAuth2ClientRegistrationHttpMessageConverter(); + + private final HttpMessageConverter errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter(); + + @BeforeEach + public void setup() { + this.authenticationManager = mock(AuthenticationManager.class); + this.filter = new OAuth2ClientRegistrationEndpointFilter(this.authenticationManager); + } + + @AfterEach + public void cleanup() { + SecurityContextHolder.clearContext(); + } + + @Test + public void constructorWhenAuthenticationManagerNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> new OAuth2ClientRegistrationEndpointFilter(null)) + .withMessage("authenticationManager cannot be null"); + } + + @Test + public void constructorWhenClientRegistrationEndpointUriNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2ClientRegistrationEndpointFilter(this.authenticationManager, null)) + .withMessage("clientRegistrationEndpointUri cannot be empty"); + } + + @Test + public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationConverter(null)) + .withMessage("authenticationConverter cannot be null"); + } + + @Test + public void setAuthenticationSuccessHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null)) + .withMessage("authenticationSuccessHandler cannot be null"); + } + + @Test + public void setAuthenticationFailureHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationFailureHandler(null)) + .withMessage("authenticationFailureHandler cannot be null"); + } + + @Test + public void doFilterWhenNotClientRegistrationRequestThenNotProcessed() throws Exception { + String requestUri = "/path"; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenClientRegistrationRequestGetThenNotProcessed() throws Exception { + String requestUri = DEFAULT_OAUTH2_CLIENT_REGISTRATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenClientRegistrationRequestInvalidThenInvalidRequestError() throws Exception { + String requestUri = DEFAULT_OAUTH2_CLIENT_REGISTRATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + request.setContent("invalid content".getBytes()); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + OAuth2Error error = readError(response); + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + assertThat(error.getDescription()).startsWith("OAuth 2.0 Client Registration Error: "); + } + + @Test + public void doFilterWhenClientRegistrationRequestInvalidTokenThenUnauthorizedError() throws Exception { + doFilterWhenClientRegistrationRequestInvalidThenError(OAuth2ErrorCodes.INVALID_TOKEN, HttpStatus.UNAUTHORIZED); + } + + @Test + public void doFilterWhenClientRegistrationRequestInsufficientTokenScopeThenForbiddenError() throws Exception { + doFilterWhenClientRegistrationRequestInvalidThenError(OAuth2ErrorCodes.INSUFFICIENT_SCOPE, + HttpStatus.FORBIDDEN); + } + + private void doFilterWhenClientRegistrationRequestInvalidThenError(String errorCode, HttpStatus status) + throws Exception { + Jwt jwt = createJwt("client.create"); + JwtAuthenticationToken principal = new JwtAuthenticationToken(jwt, + AuthorityUtils.createAuthorityList("SCOPE_client.create")); + + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(principal); + SecurityContextHolder.setContext(securityContext); + + given(this.authenticationManager.authenticate(any())).willThrow(new OAuth2AuthenticationException(errorCode)); + + // @formatter:off + OAuth2ClientRegistration clientRegistrationRequest = OAuth2ClientRegistration.builder() + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .scope("scope1") + .scope("scope2") + .build(); + // @formatter:on + + String requestUri = DEFAULT_OAUTH2_CLIENT_REGISTRATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + writeClientRegistrationRequest(request, clientRegistrationRequest); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(status.value()); + OAuth2Error error = readError(response); + assertThat(error.getErrorCode()).isEqualTo(errorCode); + } + + @Test + public void doFilterWhenClientRegistrationRequestValidThenSuccessResponse() throws Exception { + // @formatter:off + OAuth2ClientRegistration expectedClientRegistrationResponse = createClientRegistration(); + + OAuth2ClientRegistration clientRegistrationRequest = OAuth2ClientRegistration.builder() + .clientName(expectedClientRegistrationResponse.getClientName()) + .redirectUris((redirectUris) -> redirectUris.addAll(expectedClientRegistrationResponse.getRedirectUris())) + .grantTypes((grantTypes) -> grantTypes.addAll(expectedClientRegistrationResponse.getGrantTypes())) + .scopes((scopes) -> scopes.addAll(expectedClientRegistrationResponse.getScopes())) + .build(); + // @formatter:on + + Jwt jwt = createJwt("client.create"); + JwtAuthenticationToken principal = new JwtAuthenticationToken(jwt, + AuthorityUtils.createAuthorityList("SCOPE_client.create")); + + OAuth2ClientRegistrationAuthenticationToken clientRegistrationAuthenticationResult = new OAuth2ClientRegistrationAuthenticationToken( + principal, expectedClientRegistrationResponse); + + given(this.authenticationManager.authenticate(any())).willReturn(clientRegistrationAuthenticationResult); + + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(principal); + SecurityContextHolder.setContext(securityContext); + + String requestUri = DEFAULT_OAUTH2_CLIENT_REGISTRATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + writeClientRegistrationRequest(request, clientRegistrationRequest); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.CREATED.value()); + OAuth2ClientRegistration clientRegistrationResponse = readClientRegistrationResponse(response); + assertThat(clientRegistrationResponse.getClientId()) + .isEqualTo(expectedClientRegistrationResponse.getClientId()); + assertThat(clientRegistrationResponse.getClientIdIssuedAt()).isBetween( + expectedClientRegistrationResponse.getClientIdIssuedAt().minusSeconds(1), + expectedClientRegistrationResponse.getClientIdIssuedAt().plusSeconds(1)); + assertThat(clientRegistrationResponse.getClientSecret()) + .isEqualTo(expectedClientRegistrationResponse.getClientSecret()); + assertThat(clientRegistrationResponse.getClientSecretExpiresAt()) + .isEqualTo(expectedClientRegistrationResponse.getClientSecretExpiresAt()); + assertThat(clientRegistrationResponse.getClientName()) + .isEqualTo(expectedClientRegistrationResponse.getClientName()); + assertThat(clientRegistrationResponse.getRedirectUris()) + .containsExactlyInAnyOrderElementsOf(expectedClientRegistrationResponse.getRedirectUris()); + assertThat(clientRegistrationResponse.getGrantTypes()) + .containsExactlyInAnyOrderElementsOf(expectedClientRegistrationResponse.getGrantTypes()); + assertThat(clientRegistrationResponse.getResponseTypes()) + .containsExactlyInAnyOrderElementsOf(expectedClientRegistrationResponse.getResponseTypes()); + assertThat(clientRegistrationResponse.getScopes()) + .containsExactlyInAnyOrderElementsOf(expectedClientRegistrationResponse.getScopes()); + assertThat(clientRegistrationResponse.getTokenEndpointAuthenticationMethod()) + .isEqualTo(expectedClientRegistrationResponse.getTokenEndpointAuthenticationMethod()); + } + + @Test + public void doFilterWhenCustomAuthenticationConverterThenUsed() throws ServletException, IOException { + AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class); + this.filter.setAuthenticationConverter(authenticationConverter); + + String requestUri = DEFAULT_OAUTH2_CLIENT_REGISTRATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(authenticationConverter).convert(request); + } + + @Test + public void doFilterWhenCustomAuthenticationSuccessHandlerThenUsed() throws Exception { + OAuth2ClientRegistration expectedClientRegistrationResponse = createClientRegistration(); + Authentication principal = new TestingAuthenticationToken("principal", "Credentials"); + + OAuth2ClientRegistrationAuthenticationToken clientRegistrationAuthenticationResult = new OAuth2ClientRegistrationAuthenticationToken( + principal, expectedClientRegistrationResponse); + + given(this.authenticationManager.authenticate(any())).willReturn(clientRegistrationAuthenticationResult); + AuthenticationSuccessHandler successHandler = mock(AuthenticationSuccessHandler.class); + this.filter.setAuthenticationSuccessHandler(successHandler); + + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(principal); + SecurityContextHolder.setContext(securityContext); + + // @formatter:off + OAuth2ClientRegistration clientRegistrationRequest = OAuth2ClientRegistration.builder() + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .scope("scope1") + .scope("scope2") + .build(); + // @formatter:on + + String requestUri = DEFAULT_OAUTH2_CLIENT_REGISTRATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + writeClientRegistrationRequest(request, clientRegistrationRequest); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(successHandler).onAuthenticationSuccess(request, response, clientRegistrationAuthenticationResult); + } + + @Test + public void doFilterWhenCustomAuthenticationFailureHandlerThenUsed() throws Exception { + AuthenticationFailureHandler authenticationFailureHandler = mock(AuthenticationFailureHandler.class); + this.filter.setAuthenticationFailureHandler(authenticationFailureHandler); + + String requestUri = DEFAULT_OAUTH2_CLIENT_REGISTRATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + request.setContent("invalid content".getBytes()); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(authenticationFailureHandler).onAuthenticationFailure(eq(request), eq(response), + any(OAuth2AuthenticationException.class)); + } + + private OAuth2Error readError(MockHttpServletResponse response) throws Exception { + MockClientHttpResponse httpResponse = new MockClientHttpResponse(response.getContentAsByteArray(), + HttpStatus.valueOf(response.getStatus())); + return this.errorHttpResponseConverter.read(OAuth2Error.class, httpResponse); + } + + private void writeClientRegistrationRequest(MockHttpServletRequest request, + OAuth2ClientRegistration clientRegistration) throws Exception { + MockClientHttpRequest httpRequest = new MockClientHttpRequest(); + this.clientRegistrationHttpMessageConverter.write(clientRegistration, null, httpRequest); + request.setContent(httpRequest.getBodyAsBytes()); + } + + private OAuth2ClientRegistration readClientRegistrationResponse(MockHttpServletResponse response) throws Exception { + MockClientHttpResponse httpResponse = new MockClientHttpResponse(response.getContentAsByteArray(), + HttpStatus.valueOf(response.getStatus())); + return this.clientRegistrationHttpMessageConverter.read(OAuth2ClientRegistration.class, httpResponse); + } + + private static OAuth2ClientRegistration createClientRegistration() { + // @formatter:off + return OAuth2ClientRegistration.builder() + .clientId("client-id") + .clientIdIssuedAt(Instant.now()) + .clientSecret("client-secret") + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue()) + .responseType(OAuth2AuthorizationResponseType.CODE.getValue()) + .scope("scope1") + .scope("scope2") + .build(); + // @formatter:on + } + + private static Jwt createJwt(String scope) { + // @formatter:off + JwsHeader jwsHeader = TestJwsHeaders.jwsHeader() + .build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet() + .claim(OAuth2ParameterNames.SCOPE, Collections.singleton(scope)) + .build(); + Jwt jwt = Jwt.withTokenValue("jwt-access-token") + .headers((headers) -> headers.putAll(jwsHeader.getHeaders())) + .claims((claims) -> claims.putAll(jwtClaimsSet.getClaims())) + .build(); + // @formatter:on + return jwt; + } + +}