1
0
mirror of synced 2026-05-22 13:23:17 +00:00

Mark Observations with Firewall Failures

Closes gh-11994
This commit is contained in:
Josh Cummings
2022-10-11 17:44:29 -06:00
parent 46ab84684b
commit 2713075d08
4 changed files with 70 additions and 4 deletions
@@ -57,6 +57,7 @@ import org.springframework.security.web.access.intercept.AuthorizationFilter;
import org.springframework.security.web.access.intercept.FilterSecurityInterceptor;
import org.springframework.security.web.debug.DebugFilter;
import org.springframework.security.web.firewall.HttpFirewall;
import org.springframework.security.web.firewall.ObservationMarkingRequestRejectedHandler;
import org.springframework.security.web.firewall.RequestRejectedHandler;
import org.springframework.security.web.firewall.StrictHttpFirewall;
import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -307,6 +308,10 @@ public final class WebSecurity extends AbstractConfiguredSecurityBuilder<Filter,
if (this.requestRejectedHandler != null) {
filterChainProxy.setRequestRejectedHandler(this.requestRejectedHandler);
}
else if (!this.observationRegistry.isNoop()) {
filterChainProxy
.setRequestRejectedHandler(new ObservationMarkingRequestRejectedHandler(this.observationRegistry));
}
filterChainProxy.setFilterChainDecorator(getFilterChainDecorator());
filterChainProxy.afterPropertiesSet();
@@ -319,6 +324,7 @@ public final class WebSecurity extends AbstractConfiguredSecurityBuilder<Filter,
+ "********************************************************************\n\n");
result = new DebugFilter(filterChainProxy);
}
this.postBuildAction.run();
return result;
}
@@ -40,7 +40,7 @@ public class HttpFirewallBeanDefinitionParser implements BeanDefinitionParser {
pc.getReaderContext().error("ref attribute is required", pc.extractSource(element));
}
// Ensure the FCP is registered.
HttpSecurityBeanDefinitionParser.registerFilterChainProxyIfNecessary(pc, pc.extractSource(element));
HttpSecurityBeanDefinitionParser.registerFilterChainProxyIfNecessary(pc, element);
BeanDefinition filterChainProxy = pc.getRegistry().getBeanDefinition(BeanIds.FILTER_CHAIN_PROXY);
filterChainProxy.getPropertyValues().addPropertyValue("firewall", new RuntimeBeanReference(ref));
return null;
@@ -58,6 +58,7 @@ import org.springframework.security.web.DefaultSecurityFilterChain;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.ObservationFilterChainDecorator;
import org.springframework.security.web.PortResolverImpl;
import org.springframework.security.web.firewall.ObservationMarkingRequestRejectedHandler;
import org.springframework.security.web.util.matcher.AnyRequestMatcher;
import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils;
@@ -120,7 +121,7 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
CompositeComponentDefinition compositeDef = new CompositeComponentDefinition(element.getTagName(),
pc.extractSource(element));
pc.pushContainingComponent(compositeDef);
registerFilterChainProxyIfNecessary(pc, pc.extractSource(element));
registerFilterChainProxyIfNecessary(pc, element);
// Obtain the filter chains and add the new chain to it
BeanDefinition listFactoryBean = pc.getRegistry().getBeanDefinition(BeanIds.FILTER_CHAINS);
List<BeanReference> filterChains = (List<BeanReference>) listFactoryBean.getPropertyValues()
@@ -351,7 +352,8 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
return customFilters;
}
static void registerFilterChainProxyIfNecessary(ParserContext pc, Object source) {
static void registerFilterChainProxyIfNecessary(ParserContext pc, Element element) {
Object source = pc.extractSource(element);
BeanDefinitionRegistry registry = pc.getRegistry();
if (registry.containsBeanDefinition(BeanIds.FILTER_CHAIN_PROXY)) {
return;
@@ -378,6 +380,7 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
requestRejected.addConstructorArgValue("requestRejectedHandler");
requestRejected.addConstructorArgValue(BeanIds.FILTER_CHAIN_PROXY);
requestRejected.addConstructorArgValue("requestRejectedHandler");
requestRejected.addPropertyValue("observationRegistry", getObservationRegistry(element));
AbstractBeanDefinition requestRejectedBean = requestRejected.getBeanDefinition();
String requestRejectedPostProcessorName = pc.getReaderContext().generateBeanName(requestRejectedBean);
registry.registerBeanDefinition(requestRejectedPostProcessorName, requestRejectedBean);
@@ -391,7 +394,7 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
return BeanDefinitionBuilder.rootBeanDefinition(ObservationRegistryFactory.class).getBeanDefinition();
}
static class RequestRejectedHandlerPostProcessor implements BeanDefinitionRegistryPostProcessor {
public static class RequestRejectedHandlerPostProcessor implements BeanDefinitionRegistryPostProcessor {
private final String beanName;
@@ -399,6 +402,8 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
private final String targetPropertyName;
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
RequestRejectedHandlerPostProcessor(String beanName, String targetBeanName, String targetPropertyName) {
this.beanName = beanName;
this.targetBeanName = targetBeanName;
@@ -412,6 +417,13 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
beanDefinition.getPropertyValues().add(this.targetPropertyName,
new RuntimeBeanReference(this.beanName));
}
else if (!this.observationRegistry.isNoop()) {
BeanDefinition observable = BeanDefinitionBuilder
.rootBeanDefinition(ObservationMarkingRequestRejectedHandler.class)
.addConstructorArgValue(this.observationRegistry).getBeanDefinition();
BeanDefinition beanDefinition = registry.getBeanDefinition(this.targetBeanName);
beanDefinition.getPropertyValues().add(this.targetPropertyName, observable);
}
}
@Override
@@ -419,6 +431,10 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
}
public void setObservationRegistry(ObservationRegistry registry) {
this.observationRegistry = registry;
}
}
/**