diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java index 498fa81a57..848094274c 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,23 +43,33 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource { */ private final AtomicReference> cachedJWKSet = new AtomicReference<>(Mono.empty()); + /** + * cached url for jwk set. + */ + private final AtomicReference cachedJwkSetUrl = new AtomicReference<>(); + private WebClient webClient = WebClient.create(); - private final String jwkSetURL; + private Mono jwkSetURLProvider; ReactiveRemoteJWKSource(String jwkSetURL) { Assert.hasText(jwkSetURL, "jwkSetURL cannot be empty"); - this.jwkSetURL = jwkSetURL; + this.cachedJwkSetUrl.set(jwkSetURL); + } + + ReactiveRemoteJWKSource(Mono jwkSetURLProvider) { + Assert.notNull(jwkSetURLProvider, "jwkSetURLProvider cannot be null"); + this.jwkSetURLProvider = jwkSetURLProvider; } @Override public Mono> get(JWKSelector jwkSelector) { // @formatter:off return this.cachedJWKSet.get() - .switchIfEmpty(Mono.defer(() -> getJWKSet())) + .switchIfEmpty(Mono.defer(this::getJWKSet)) .flatMap((jwkSet) -> get(jwkSelector, jwkSet)) .switchIfEmpty(Mono.defer(() -> getJWKSet() - .map((jwkSet) -> jwkSelector.select(jwkSet))) + .map(jwkSelector::select)) ); // @formatter:on } @@ -95,13 +105,18 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource { */ private Mono getJWKSet() { // @formatter:off - return this.webClient.get() - .uri(this.jwkSetURL) - .retrieve() - .bodyToMono(String.class) + return Mono.justOrEmpty(this.cachedJwkSetUrl.get()) + .switchIfEmpty(Mono.defer(() -> this.jwkSetURLProvider + .doOnNext(this.cachedJwkSetUrl::set)) + ) + .flatMap((jwkSetURL) -> this.webClient.get() + .uri(jwkSetURL) + .retrieve() + .bodyToMono(String.class) + ) .map(this::parse) .doOnNext((jwkSet) -> this.cachedJWKSet - .set(Mono.just(jwkSet)) + .set(Mono.just(jwkSet)) ) .cache(); // @formatter:on diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java index ddcc1c913f..2f0881fc22 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java @@ -18,6 +18,7 @@ package org.springframework.security.oauth2.jwt; import java.util.Collections; import java.util.List; +import java.util.function.Supplier; import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWKMatcher; @@ -31,10 +32,15 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.web.reactive.function.client.WebClientResponseException; +import reactor.core.publisher.Mono; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; /** * @author Rob Winch @@ -52,6 +58,9 @@ public class ReactiveRemoteJWKSourceTests { private MockWebServer server; + @Mock + private Supplier mockStringSupplier; + // @formatter:off private String keys = "{\n" + " \"keys\": [\n" @@ -156,4 +165,18 @@ public class ReactiveRemoteJWKSourceTests { assertThat(this.source.get(this.selector).block()).isEmpty(); } + @Test + public void getShouldRecoverAndReturnKeysAfterErrorCase() { + given(this.matcher.matches(any())).willReturn(true); + this.source = new ReactiveRemoteJWKSource(Mono.fromSupplier(mockStringSupplier)); + doThrow(WebClientResponseException.ServiceUnavailable.class).when(this.mockStringSupplier).get(); + // first case: id provider has error state + assertThatThrownBy(() -> this.source.get(this.selector).block()) + .isExactlyInstanceOf(WebClientResponseException.ServiceUnavailable.class); + // second case: id provider is healthy again + doReturn(this.server.url("/").toString()).when(this.mockStringSupplier).get(); + var actual = this.source.get(this.selector).block(); + assertThat(actual).isNotEmpty(); + } + }