diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/AbstractAuthenticationFilterConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/AbstractAuthenticationFilterConfigurer.java index 6e33d02183..0eed65e834 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/AbstractAuthenticationFilterConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/AbstractAuthenticationFilterConfigurer.java @@ -21,6 +21,7 @@ import java.util.Collections; import jakarta.servlet.http.HttpServletRequest; +import org.springframework.context.ApplicationContext; import org.springframework.http.MediaType; import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; @@ -28,6 +29,7 @@ import org.springframework.security.config.annotation.web.HttpSecurityBuilder; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.PortMapper; +import org.springframework.security.web.PortResolver; import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; @@ -272,6 +274,10 @@ public abstract class AbstractAuthenticationFilterConfigurer C getBeanOrNull(B http, Class clazz) { + ApplicationContext context = http.getSharedObject(ApplicationContext.class); + if (context == null) { + return null; + } + return context.getBeanProvider(clazz).getIfUnique(); + } + @SuppressWarnings("unchecked") private T getSelf() { return (T) this; diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index 4dc1af29e3..ad93db75d3 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -83,6 +83,7 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.jwt.JwtDecoderFactory; import org.springframework.security.web.AuthenticationEntryPoint; +import org.springframework.security.web.PortResolver; import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.authentication.DelegatingAuthenticationEntryPoint; import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint; @@ -578,8 +579,13 @@ public final class OAuth2LoginConfigurer> new RequestHeaderRequestMatcher("X-Requested-With", "XMLHttpRequest")); RequestMatcher formLoginNotEnabled = getFormLoginNotEnabledRequestMatcher(http); LinkedHashMap entryPoints = new LinkedHashMap<>(); + LoginUrlAuthenticationEntryPoint loginUrlEntryPoint = new LoginUrlAuthenticationEntryPoint(providerLoginPage); + PortResolver portResolver = getBeanOrNull(ResolvableType.forClass(PortResolver.class)); + if (portResolver != null) { + loginUrlEntryPoint.setPortResolver(portResolver); + } entryPoints.put(new AndRequestMatcher(notXRequestedWith, new NegatedRequestMatcher(defaultLoginPageMatcher), - formLoginNotEnabled), new LoginUrlAuthenticationEntryPoint(providerLoginPage)); + formLoginNotEnabled), loginUrlEntryPoint); DelegatingAuthenticationEntryPoint loginEntryPoint = new DelegatingAuthenticationEntryPoint(entryPoints); loginEntryPoint.setDefaultEntryPoint(this.getAuthenticationEntryPoint()); return loginEntryPoint; diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java index b5d683a3ce..1f7efc1829 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java @@ -52,6 +52,7 @@ import org.springframework.security.saml2.provider.service.web.authentication.Op import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver; import org.springframework.security.saml2.provider.service.web.authentication.Saml2WebSsoAuthenticationFilter; import org.springframework.security.web.AuthenticationEntryPoint; +import org.springframework.security.web.PortResolver; import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.DelegatingAuthenticationEntryPoint; import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint; @@ -344,8 +345,13 @@ public final class Saml2LoginConfigurer> RequestMatcher notXRequestedWith = new NegatedRequestMatcher( new RequestHeaderRequestMatcher("X-Requested-With", "XMLHttpRequest")); LinkedHashMap entryPoints = new LinkedHashMap<>(); + LoginUrlAuthenticationEntryPoint loginUrlEntryPoint = new LoginUrlAuthenticationEntryPoint(providerLoginPage); + PortResolver portResolver = getBeanOrNull(http, PortResolver.class); + if (portResolver != null) { + loginUrlEntryPoint.setPortResolver(portResolver); + } entryPoints.put(new AndRequestMatcher(notXRequestedWith, new NegatedRequestMatcher(defaultLoginPageMatcher)), - new LoginUrlAuthenticationEntryPoint(providerLoginPage)); + loginUrlEntryPoint); DelegatingAuthenticationEntryPoint loginEntryPoint = new DelegatingAuthenticationEntryPoint(entryPoints); loginEntryPoint.setDefaultEntryPoint(this.getAuthenticationEntryPoint()); return loginEntryPoint; diff --git a/config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java index 78eb4a6d0a..333601385e 100644 --- a/config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java @@ -240,6 +240,10 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { } private RuntimeBeanReference createPortResolver(BeanReference portMapper, ParserContext pc) { + String beanName = "portResolver"; + if (pc.getRegistry().containsBeanDefinition(beanName)) { + return new RuntimeBeanReference(beanName); + } RootBeanDefinition portResolver = new RootBeanDefinition(PortResolverImpl.class); portResolver.getPropertyValues().addPropertyValue("portMapper", portMapper); String portResolverName = pc.getReaderContext().generateBeanName(portResolver); diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/FormLoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/FormLoginConfigurerTests.java index 49b8ed2a1a..bea1f943d8 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/FormLoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/FormLoginConfigurerTests.java @@ -38,6 +38,7 @@ import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders; import org.springframework.security.web.PortMapper; +import org.springframework.security.web.PortResolver; import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.access.ExceptionTranslationFilter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; @@ -378,6 +379,13 @@ public class FormLoginConfigurerTests { verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(ExceptionTranslationFilter.class)); } + @Test + public void configureWhenPortResolverBeanThenPortResolverUsed() throws Exception { + this.spring.register(CustomPortResolverConfig.class).autowire(); + this.mockMvc.perform(get("/requires-authentication")).andExpect(status().is3xxRedirection()); + verify(this.spring.getContext().getBean(PortResolver.class)).getServerPort(any()); + } + @Configuration @EnableWebSecurity static class RequestCacheConfig { @@ -723,6 +731,35 @@ public class FormLoginConfigurerTests { } + @Configuration + @EnableWebSecurity + static class CustomPortResolverConfig { + + @Bean + SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeHttpRequests((requests) -> requests + .anyRequest().authenticated() + ) + .formLogin(withDefaults()) + .requestCache(withDefaults()); + return http.build(); + // @formatter:on + } + + @Bean + PortResolver portResolver() { + return mock(PortResolver.class); + } + + @Bean + UserDetailsService userDetailsService() { + return new InMemoryUserDetailsManager(PasswordEncodedUser.user()); + } + + } + static class ReflectingObjectPostProcessor implements ObjectPostProcessor { @Override diff --git a/config/src/test/java/org/springframework/security/config/http/FormLoginConfigTests.java b/config/src/test/java/org/springframework/security/config/http/FormLoginConfigTests.java index 52237273df..9088bb443f 100644 --- a/config/src/test/java/org/springframework/security/config/http/FormLoginConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/FormLoginConfigTests.java @@ -35,6 +35,7 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.web.FilterChainProxy; +import org.springframework.security.web.PortResolver; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; @@ -45,6 +46,7 @@ import org.springframework.web.bind.annotation.RestController; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.verify; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; @@ -210,6 +212,17 @@ public class FormLoginConfigTests { // @formatter:on } + @Test + public void portResolver() throws Exception { + this.spring.configLocations(this.xml("PortResolverBean")).autowire(); + // @formatter:off + this.mvc.perform(get("/requires-authentication")) + .andExpect(status().is3xxRedirection()); + // @formatter:on + PortResolver portResolver = this.spring.getContext().getBean(PortResolver.class); + verify(portResolver, atLeastOnce()).getServerPort(any()); + } + private Filter getFilter(ApplicationContext context, Class filterClass) { FilterChainProxy filterChain = context.getBean(BeanIds.FILTER_CHAIN_PROXY, FilterChainProxy.class); List filters = filterChain.getFilters("/any"); diff --git a/config/src/test/kotlin/org/springframework/security/config/annotation/web/FormLoginDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/annotation/web/FormLoginDslTests.kt index 965c361b4a..3d35b8c2fe 100644 --- a/config/src/test/kotlin/org/springframework/security/config/annotation/web/FormLoginDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/annotation/web/FormLoginDslTests.kt @@ -34,6 +34,7 @@ import org.springframework.security.core.userdetails.User import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf import org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated +import org.springframework.security.web.PortResolver import org.springframework.security.web.SecurityFilterChain import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler @@ -240,6 +241,29 @@ class FormLoginDslTests { } } + @Test + fun `portResolerBean is used`() { + this.spring.register(PortResolverBeanConfig::class.java, AllSecuredConfig::class.java, UserConfig::class.java).autowire() + + val portResolver = this.spring.context.getBean(PortResolver::class.java) + every { portResolver.getServerPort(any()) }.returns(1234) + this.mockMvc.get("/") + .andExpect { + status().isFound + redirectedUrl("http://localhost:1234/login") + } + + verify { portResolver.getServerPort(any()) } + } + + @Configuration + open class PortResolverBeanConfig { + @Bean + open fun portResolverBean(): PortResolver { + return mockk() + } + } + @Test fun `login when custom failure url then used`() { this.spring.register(FailureHandlerConfig::class.java, UserConfig::class.java).autowire() diff --git a/config/src/test/resources/org/springframework/security/config/http/FormLoginConfigTests-PortResolverBean.xml b/config/src/test/resources/org/springframework/security/config/http/FormLoginConfigTests-PortResolverBean.xml new file mode 100644 index 0000000000..0bd1f85ddb --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/FormLoginConfigTests-PortResolverBean.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + +