diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JWSAlgorithmMapJWSKeySelector.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JWSAlgorithmMapJWSKeySelector.java new file mode 100644 index 0000000000..2947e90ff5 --- /dev/null +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JWSAlgorithmMapJWSKeySelector.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.jwt; + +import java.security.Key; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.JWSHeader; +import com.nimbusds.jose.KeySourceException; +import com.nimbusds.jose.proc.JWSKeySelector; +import com.nimbusds.jose.proc.SecurityContext; + +/** + * Class for delegating to a Nimbus JWSKeySelector by the given JWSAlgorithm + * + * @author Josh Cummings + */ +class JWSAlgorithmMapJWSKeySelector implements JWSKeySelector { + private Map> jwsKeySelectors; + + JWSAlgorithmMapJWSKeySelector(Map> jwsKeySelectors) { + this.jwsKeySelectors = jwsKeySelectors; + } + + @Override + public List selectJWSKeys(JWSHeader header, C context) throws KeySourceException { + JWSKeySelector keySelector = this.jwsKeySelectors.get(header.getAlgorithm()); + if (keySelector == null) { + throw new IllegalArgumentException("Unsupported algorithm of " + header.getAlgorithm()); + } + return keySelector.selectJWSKeys(header, context); + } + + public Set getExpectedJWSAlgorithms() { + return this.jwsKeySelectors.keySet(); + } +} diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java index 814d5ff06a..4c4df529ef 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java @@ -23,8 +23,12 @@ import java.security.interfaces.RSAPublicKey; import java.text.ParseException; import java.time.Instant; import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; import javax.crypto.SecretKey; import com.nimbusds.jose.JWSAlgorithm; @@ -209,7 +213,7 @@ public final class NimbusJwtDecoder implements JwtDecoder { */ public static final class JwkSetUriJwtDecoderBuilder { private String jwkSetUri; - private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.RS256; + private Set signatureAlgorithms = new HashSet<>(); private RestOperations restOperations = new RestTemplate(); private JwkSetUriJwtDecoderBuilder(String jwkSetUri) { @@ -218,15 +222,30 @@ public final class NimbusJwtDecoder implements JwtDecoder { } /** - * Use the given signing - * algorithm. + * Append the given signing + * algorithm + * to the set of algorithms to use. * * @param signatureAlgorithm the algorithm to use * @return a {@link JwkSetUriJwtDecoderBuilder} for further configurations */ public JwkSetUriJwtDecoderBuilder jwsAlgorithm(SignatureAlgorithm signatureAlgorithm) { Assert.notNull(signatureAlgorithm, "signatureAlgorithm cannot be null"); - this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName()); + this.signatureAlgorithms.add(signatureAlgorithm); + return this; + } + + /** + * Configure the list of + * algorithms + * to use with the given {@link Consumer}. + * + * @param signatureAlgorithmsConsumer a {@link Consumer} for further configuring the algorithm list + * @return a {@link JwkSetUriJwtDecoderBuilder} for further configurations + */ + public JwkSetUriJwtDecoderBuilder jwsAlgorithms(Consumer> signatureAlgorithmsConsumer) { + Assert.notNull(signatureAlgorithmsConsumer, "signatureAlgorithmsConsumer cannot be null"); + signatureAlgorithmsConsumer.accept(this.signatureAlgorithms); return this; } @@ -245,13 +264,27 @@ public final class NimbusJwtDecoder implements JwtDecoder { return this; } + JWSKeySelector jwsKeySelector(JWKSource jwkSource) { + if (this.signatureAlgorithms.isEmpty()) { + return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource); + } else if (this.signatureAlgorithms.size() == 1) { + JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(this.signatureAlgorithms.iterator().next().getName()); + return new JWSVerificationKeySelector<>(jwsAlgorithm, jwkSource); + } else { + Map> jwsKeySelectors = new HashMap<>(); + for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) { + JWSAlgorithm jwsAlg = JWSAlgorithm.parse(signatureAlgorithm.getName()); + jwsKeySelectors.put(jwsAlg, new JWSVerificationKeySelector<>(jwsAlg, jwkSource)); + } + return new JWSAlgorithmMapJWSKeySelector<>(jwsKeySelectors); + } + } + JWTProcessor processor() { ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever(this.restOperations); JWKSource jwkSource = new RemoteJWKSet<>(toURL(this.jwkSetUri), jwkSetRetriever); - JWSKeySelector jwsKeySelector = - new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource); ConfigurableJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); - jwtProcessor.setJWSKeySelector(jwsKeySelector); + jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource)); // Spring Security validates the claim set independent from Nimbus jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { }); diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java index 1faf5e295d..fa5e4062a9 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java @@ -18,8 +18,12 @@ package org.springframework.security.oauth2.jwt; import java.security.interfaces.RSAPublicKey; import java.time.Instant; import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; import java.util.function.Function; import javax.crypto.SecretKey; @@ -31,6 +35,7 @@ import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWKMatcher; import com.nimbusds.jose.jwk.JWKSelector; import com.nimbusds.jose.jwk.source.JWKSecurityContextJWKSet; +import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.proc.BadJOSEException; import com.nimbusds.jose.proc.JWKSecurityContext; import com.nimbusds.jose.proc.JWSKeySelector; @@ -233,7 +238,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { */ public static final class JwkSetUriReactiveJwtDecoderBuilder { private final String jwkSetUri; - private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.RS256; + private Set signatureAlgorithms = new HashSet<>(); private WebClient webClient = WebClient.create(); private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) { @@ -242,15 +247,30 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { } /** - * Use the given signing - * algorithm. + * Append the given signing + * algorithm + * to the set of algorithms to use. * * @param signatureAlgorithm the algorithm to use * @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations */ public JwkSetUriReactiveJwtDecoderBuilder jwsAlgorithm(SignatureAlgorithm signatureAlgorithm) { Assert.notNull(signatureAlgorithm, "sig cannot be null"); - this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName()); + this.signatureAlgorithms.add(signatureAlgorithm); + return this; + } + + /** + * Configure the list of + * algorithms + * to use with the given {@link Consumer}. + * + * @param signatureAlgorithmsConsumer a {@link Consumer} for further configuring the algorithm list + * @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations + */ + public JwkSetUriReactiveJwtDecoderBuilder jwsAlgorithms(Consumer> signatureAlgorithmsConsumer) { + Assert.notNull(signatureAlgorithmsConsumer, "signatureAlgorithmsConsumer cannot be null"); + signatureAlgorithmsConsumer.accept(this.signatureAlgorithms); return this; } @@ -278,28 +298,53 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { return new NimbusReactiveJwtDecoder(processor()); } + JWSKeySelector jwsKeySelector(JWKSource jwkSource) { + if (this.signatureAlgorithms.isEmpty()) { + return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource); + } else if (this.signatureAlgorithms.size() == 1) { + JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(this.signatureAlgorithms.iterator().next().getName()); + return new JWSVerificationKeySelector<>(jwsAlgorithm, jwkSource); + } else { + Map> jwsKeySelectors = new HashMap<>(); + for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) { + JWSAlgorithm jwsAlg = JWSAlgorithm.parse(signatureAlgorithm.getName()); + jwsKeySelectors.put(jwsAlg, new JWSVerificationKeySelector<>(jwsAlg, jwkSource)); + } + return new JWSAlgorithmMapJWSKeySelector<>(jwsKeySelectors); + } + } + Converter> processor() { JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet(); - - JWSKeySelector jwsKeySelector = - new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource); DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); + JWSKeySelector jwsKeySelector = jwsKeySelector(jwkSource); jwtProcessor.setJWSKeySelector(jwsKeySelector); jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {}); ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri); source.setWebClient(this.webClient); + Set expectedJwsAlgorithms = getExpectedJwsAlgorithms(jwsKeySelector); return jwt -> { - JWKSelector selector = createSelector(jwt.getHeader()); + JWKSelector selector = createSelector(expectedJwsAlgorithms, jwt.getHeader()); return source.get(selector) .onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e)) .map(jwkList -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwkList))); }; } - private JWKSelector createSelector(Header header) { - if (!this.jwsAlgorithm.equals(header.getAlgorithm())) { + private Set getExpectedJwsAlgorithms(JWSKeySelector jwsKeySelector) { + if (jwsKeySelector instanceof JWSVerificationKeySelector) { + return Collections.singleton(((JWSVerificationKeySelector) jwsKeySelector).getExpectedJWSAlgorithm()); + } + if (jwsKeySelector instanceof JWSAlgorithmMapJWSKeySelector) { + return ((JWSAlgorithmMapJWSKeySelector) jwsKeySelector).getExpectedJWSAlgorithms(); + } + throw new IllegalArgumentException("Unsupported key selector type " + jwsKeySelector.getClass()); + } + + private JWKSelector createSelector(Set expectedJwsAlgorithms, Header header) { + if (!expectedJwsAlgorithms.contains(header.getAlgorithm())) { throw new JwtException("Unsupported algorithm of " + header.getAlgorithm()); } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java index 8a67875869..e194267569 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java @@ -39,7 +39,10 @@ import com.nimbusds.jose.JWSHeader; import com.nimbusds.jose.JWSSigner; import com.nimbusds.jose.crypto.MACSigner; import com.nimbusds.jose.crypto.RSASSASigner; +import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.proc.BadJOSEException; +import com.nimbusds.jose.proc.JWSKeySelector; +import com.nimbusds.jose.proc.JWSVerificationKeySelector; import com.nimbusds.jose.proc.SecurityContext; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; @@ -357,6 +360,46 @@ public class NimbusJwtDecoderTests { .isEqualTo("test-subject"); } + @Test + public void jwsKeySelectorWhenNoAlgorithmThenReturnsRS256Selector() { + JWKSource jwkSource = mock(JWKSource.class); + JWSKeySelector jwsKeySelector = + withJwkSetUri(JWK_SET_URI).jwsKeySelector(jwkSource); + assertThat(jwsKeySelector instanceof JWSVerificationKeySelector); + JWSVerificationKeySelector jwsVerificationKeySelector = + (JWSVerificationKeySelector) jwsKeySelector; + assertThat(jwsVerificationKeySelector.getExpectedJWSAlgorithm()) + .isEqualTo(JWSAlgorithm.RS256); + } + + @Test + public void jwsKeySelectorWhenOneAlgorithmThenReturnsSingleSelector() { + JWKSource jwkSource = mock(JWKSource.class); + JWSKeySelector jwsKeySelector = + withJwkSetUri(JWK_SET_URI).jwsAlgorithm(SignatureAlgorithm.RS512) + .jwsKeySelector(jwkSource); + assertThat(jwsKeySelector instanceof JWSVerificationKeySelector); + JWSVerificationKeySelector jwsVerificationKeySelector = + (JWSVerificationKeySelector) jwsKeySelector; + assertThat(jwsVerificationKeySelector.getExpectedJWSAlgorithm()) + .isEqualTo(JWSAlgorithm.RS512); + } + + @Test + public void jwsKeySelectorWhenMultipleAlgorithmThenReturnsCompositeSelector() { + JWKSource jwkSource = mock(JWKSource.class); + JWSKeySelector jwsKeySelector = + withJwkSetUri(JWK_SET_URI) + .jwsAlgorithm(SignatureAlgorithm.RS256) + .jwsAlgorithm(SignatureAlgorithm.RS512) + .jwsKeySelector(jwkSource); + assertThat(jwsKeySelector instanceof JWSAlgorithmMapJWSKeySelector); + JWSAlgorithmMapJWSKeySelector jwsAlgorithmMapKeySelector = + (JWSAlgorithmMapJWSKeySelector) jwsKeySelector; + assertThat(jwsAlgorithmMapKeySelector.getExpectedJWSAlgorithms()) + .containsExactlyInAnyOrder(JWSAlgorithm.RS256, JWSAlgorithm.RS512); + } + private RSAPublicKey key() throws InvalidKeySpecException { byte[] decoded = Base64.getDecoder().decode(VERIFY_KEY.getBytes()); EncodedKeySpec spec = new X509EncodedKeySpec(decoded); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java index b7ae963b10..59a361141a 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java @@ -16,31 +16,6 @@ package org.springframework.security.oauth2.jwt; -import com.nimbusds.jose.JWSAlgorithm; -import com.nimbusds.jose.JWSHeader; -import com.nimbusds.jose.JWSSigner; -import com.nimbusds.jose.crypto.MACSigner; -import com.nimbusds.jose.jwk.JWKSet; -import com.nimbusds.jwt.JWTClaimsSet; -import com.nimbusds.jwt.SignedJWT; -import okhttp3.mockwebserver.MockResponse; -import okhttp3.mockwebserver.MockWebServer; -import org.junit.After; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; -import org.springframework.core.convert.converter.Converter; -import org.springframework.security.oauth2.core.OAuth2Error; -import org.springframework.security.oauth2.core.OAuth2TokenValidator; -import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; -import org.springframework.security.oauth2.jose.TestKeys; -import org.springframework.security.oauth2.jose.jws.MacAlgorithm; -import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; -import org.springframework.web.reactive.function.client.WebClient; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import javax.crypto.SecretKey; import java.net.UnknownHostException; import java.security.KeyFactory; import java.security.NoSuchAlgorithmException; @@ -54,12 +29,49 @@ import java.util.Base64; import java.util.Collections; import java.util.Date; import java.util.Map; +import javax.crypto.SecretKey; + +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.JWSHeader; +import com.nimbusds.jose.JWSSigner; +import com.nimbusds.jose.crypto.MACSigner; +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.source.JWKSource; +import com.nimbusds.jose.proc.JWKSecurityContext; +import com.nimbusds.jose.proc.JWSKeySelector; +import com.nimbusds.jose.proc.JWSVerificationKeySelector; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.SignedJWT; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2TokenValidator; +import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; +import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.security.oauth2.jose.jws.MacAlgorithm; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.web.reactive.function.client.WebClient; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; -import static org.mockito.Mockito.*; -import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withJwkSetUri; +import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withJwkSource; +import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withPublicKey; +import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withSecretKey; /** * @author Rob Winch @@ -363,6 +375,46 @@ public class NimbusReactiveJwtDecoderTests { .isInstanceOf(JwtException.class); } + @Test + public void jwsKeySelectorWhenNoAlgorithmThenReturnsRS256Selector() { + JWKSource jwkSource = mock(JWKSource.class); + JWSKeySelector jwsKeySelector = + withJwkSetUri(this.jwkSetUri).jwsKeySelector(jwkSource); + assertThat(jwsKeySelector instanceof JWSVerificationKeySelector); + JWSVerificationKeySelector jwsVerificationKeySelector = + (JWSVerificationKeySelector) jwsKeySelector; + assertThat(jwsVerificationKeySelector.getExpectedJWSAlgorithm()) + .isEqualTo(JWSAlgorithm.RS256); + } + + @Test + public void jwsKeySelectorWhenOneAlgorithmThenReturnsSingleSelector() { + JWKSource jwkSource = mock(JWKSource.class); + JWSKeySelector jwsKeySelector = + withJwkSetUri(this.jwkSetUri).jwsAlgorithm(SignatureAlgorithm.RS512) + .jwsKeySelector(jwkSource); + assertThat(jwsKeySelector instanceof JWSVerificationKeySelector); + JWSVerificationKeySelector jwsVerificationKeySelector = + (JWSVerificationKeySelector) jwsKeySelector; + assertThat(jwsVerificationKeySelector.getExpectedJWSAlgorithm()) + .isEqualTo(JWSAlgorithm.RS512); + } + + @Test + public void jwsKeySelectorWhenMultipleAlgorithmThenReturnsCompositeSelector() { + JWKSource jwkSource = mock(JWKSource.class); + JWSKeySelector jwsKeySelector = + withJwkSetUri(this.jwkSetUri) + .jwsAlgorithm(SignatureAlgorithm.RS256) + .jwsAlgorithm(SignatureAlgorithm.RS512) + .jwsKeySelector(jwkSource); + assertThat(jwsKeySelector instanceof JWSAlgorithmMapJWSKeySelector); + JWSAlgorithmMapJWSKeySelector jwsAlgorithmMapKeySelector = + (JWSAlgorithmMapJWSKeySelector) jwsKeySelector; + assertThat(jwsAlgorithmMapKeySelector.getExpectedJWSAlgorithms()) + .containsExactlyInAnyOrder(JWSAlgorithm.RS256, JWSAlgorithm.RS512); + } + private SignedJWT signedJwt(SecretKey secretKey, MacAlgorithm jwsAlgorithm, JWTClaimsSet claimsSet) throws Exception { SignedJWT signedJWT = new SignedJWT(new JWSHeader(JWSAlgorithm.parse(jwsAlgorithm.getName())), claimsSet); JWSSigner signer = new MACSigner(secretKey);