diff --git a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java
index 6b820bbd3b..00e0169100 100644
--- a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java
+++ b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java
@@ -24,6 +24,7 @@ import org.springframework.beans.factory.support.*;
import org.springframework.beans.factory.xml.BeanDefinitionParser;
import org.springframework.beans.factory.xml.ParserContext;
import org.springframework.beans.factory.xml.XmlReaderContext;
+import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler;
import org.springframework.security.access.vote.ConsensusBased;
import org.springframework.security.config.Elements;
@@ -33,8 +34,10 @@ import org.springframework.security.messaging.access.intercept.ChannelSecurityIn
import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver;
import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
import org.springframework.security.messaging.util.matcher.SimpDestinationMessageMatcher;
+import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher;
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor;
+import org.springframework.util.AntPathMatcher;
import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils;
import org.w3c.dom.Element;
@@ -87,6 +90,8 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements
private static final String ACCESS_ATTR = "access";
+ private static final String TYPE_ATTR = "type";
+
/**
* @param element
@@ -105,9 +110,10 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements
for(Element interceptMessage : interceptMessages) {
String matcherPattern = interceptMessage.getAttribute(PATTERN_ATTR);
String accessExpression = interceptMessage.getAttribute(ACCESS_ATTR);
- BeanDefinitionBuilder matcher = BeanDefinitionBuilder.rootBeanDefinition(SimpDestinationMessageMatcher.class);
- matcher.addConstructorArgValue(matcherPattern);
- matcherToExpression.put(matcher.getBeanDefinition(), accessExpression);
+ String messageType = interceptMessage.getAttribute(TYPE_ATTR);
+
+ BeanDefinition matcher = createMatcher(matcherPattern, messageType, parserContext, interceptMessage);
+ matcherToExpression.put(matcher, accessExpression);
}
BeanDefinitionBuilder mds = BeanDefinitionBuilder.rootBeanDefinition(ExpressionBasedMessageSecurityMetadataSourceFactory.class);
@@ -137,6 +143,34 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements
return null;
}
+ private BeanDefinition createMatcher(String matcherPattern, String messageType, ParserContext parserContext, Element interceptMessage) {
+ boolean hasPattern = StringUtils.hasText(matcherPattern);
+ boolean hasMessageType = StringUtils.hasText(messageType);
+ if(!hasPattern) {
+ BeanDefinitionBuilder matcher = BeanDefinitionBuilder.rootBeanDefinition(SimpMessageTypeMatcher.class);
+ matcher.addConstructorArgValue(messageType);
+ return matcher.getBeanDefinition();
+ }
+
+
+ String factoryName = null;
+ if(hasPattern && hasMessageType) {
+ SimpMessageType type = SimpMessageType.valueOf(messageType);
+ if(SimpMessageType.MESSAGE == type) {
+ factoryName = "createMessageMatcher";
+ } else if(SimpMessageType.SUBSCRIBE == type) {
+ factoryName = "createSubscribeMatcher";
+ } else {
+ parserContext.getReaderContext().error("Cannot use intercept-websocket@message-type="+messageType+" with a pattern because the type does not have a destination.", interceptMessage);
+ }
+ }
+ BeanDefinitionBuilder matcher = BeanDefinitionBuilder.rootBeanDefinition(SimpDestinationMessageMatcher.class);
+ matcher.setFactoryMethod(factoryName);
+ matcher.addConstructorArgValue(matcherPattern);
+ matcher.addConstructorArgValue(new RootBeanDefinition(AntPathMatcher.class));
+ return matcher.getBeanDefinition();
+ }
+
static class MessageSecurityPostProcessor implements BeanDefinitionRegistryPostProcessor {
private static final String CLIENT_INBOUND_CHANNEL_BEAN_ID = "clientInboundChannel";
diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-4.0.rnc b/config/src/main/resources/org/springframework/security/config/spring-security-4.0.rnc
index 6546af2f5a..e1671b6922 100644
--- a/config/src/main/resources/org/springframework/security/config/spring-security-4.0.rnc
+++ b/config/src/main/resources/org/springframework/security/config/spring-security-4.0.rnc
@@ -290,6 +290,9 @@ intercept-message.attrlist &=
intercept-message.attrlist &=
## The access configuration attributes that apply for the configured message. For example, permitAll grants access to anyone, hasRole('ROLE_ADMIN') requires the user have the role 'ROLE_ADMIN'.
attribute access {xsd:token}?
+intercept-message.attrlist &=
+ ## The type of message to match on. Valid values are defined in SimpMessageType (i.e. CONNECT, CONNECT_ACK, HEARTBEAT, MESSAGE, SUBSCRIBE, UNSUBSCRIBE, DISCONNECT, DISCONNECT_ACK, OTHER).
+ attribute type {"CONNECT" | "CONNECT_ACK" | "HEARTBEAT" | "MESSAGE" | "SUBSCRIBE"| "UNSUBSCRIBE" | "DISCONNECT" | "DISCONNECT_ACK" | "OTHER"}?
http-firewall =
## Allows a custom instance of HttpFirewall to be injected into the FilterChainProxy created by the namespace.
diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-4.0.xsd b/config/src/main/resources/org/springframework/security/config/spring-security-4.0.xsd
index 22988b0743..fd21d91486 100644
--- a/config/src/main/resources/org/springframework/security/config/spring-security-4.0.xsd
+++ b/config/src/main/resources/org/springframework/security/config/spring-security-4.0.xsd
@@ -897,6 +897,27 @@
+
+
+ The type of message to match on. Valid values are defined in SimpMessageType (i.e.
+ CONNECT, CONNECT_ACK, HEARTBEAT, MESSAGE, SUBSCRIBE, UNSUBSCRIBE, DISCONNECT,
+ DISCONNECT_ACK, OTHER).
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/config/src/test/groovy/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.groovy b/config/src/test/groovy/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.groovy
index f29808ffe9..7d41bd199d 100644
--- a/config/src/test/groovy/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.groovy
+++ b/config/src/test/groovy/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.groovy
@@ -3,6 +3,7 @@ package org.springframework.security.config.websocket
import org.springframework.beans.BeansException
import org.springframework.beans.factory.config.BeanDefinition
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory
+import org.springframework.beans.factory.parsing.BeanDefinitionParsingException
import org.springframework.beans.factory.support.BeanDefinitionRegistry
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor
import org.springframework.beans.factory.support.RootBeanDefinition
@@ -17,8 +18,10 @@ import org.springframework.mock.web.MockHttpServletRequest
import org.springframework.mock.web.MockHttpServletResponse
import org.springframework.security.core.Authentication
import org.springframework.security.core.annotation.AuthenticationPrincipal
+import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher
import org.springframework.security.web.csrf.CsrfToken
import org.springframework.security.web.csrf.DefaultCsrfToken
+import org.springframework.security.web.csrf.InvalidCsrfTokenException
import org.springframework.security.web.csrf.MissingCsrfTokenException
import org.springframework.stereotype.Controller
import org.springframework.web.servlet.HandlerMapping
@@ -30,6 +33,7 @@ import org.springframework.web.socket.server.support.HttpSessionHandshakeInterce
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler
import org.springframework.web.socket.sockjs.transport.handler.SockJsWebSocketHandler
+import spock.lang.Unroll
import static org.mockito.Mockito.*
@@ -50,6 +54,7 @@ import org.springframework.security.core.context.SecurityContextHolder
class WebSocketMessageBrokerConfigTests extends AbstractXmlConfigTests {
Authentication messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER')
boolean useSockJS = false
+ CsrfToken csrfToken = new DefaultCsrfToken('headerName', 'paramName', 'token')
def cleanup() {
SecurityContextHolder.clearContext()
@@ -89,6 +94,75 @@ class WebSocketMessageBrokerConfigTests extends AbstractXmlConfigTests {
noExceptionThrown()
}
+ @Unroll
+ def "message type - #type"(SimpMessageType type) {
+ setup:
+ websocket {
+ 'intercept-message'('type': type.toString(), access:'permitAll')
+ 'intercept-message'(pattern:'/**', access:'denyAll')
+ }
+ messageUser = null
+ SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(type)
+ if(SimpMessageType.CONNECT == type) {
+ headers.setNativeHeader(csrfToken.headerName, csrfToken.token)
+ }
+ Message message = message(headers, '/permitAll')
+
+ when: 'message is sent to the permitAll endpoint with no user'
+ clientInboundChannel.send(message)
+
+ then: 'access is granted'
+ noExceptionThrown()
+
+ where:
+ type << SimpMessageType.values()
+ }
+
+ @Unroll
+ def "pattern and message type - #type"(SimpMessageType type) {
+ setup:
+ websocket {
+ 'intercept-message'(pattern: '/permitAll', 'type': type.toString(), access:'permitAll')
+ 'intercept-message'(pattern:'/**', access:'denyAll')
+ }
+
+ when: 'message is sent to the permitAll endpoint with no user'
+ clientInboundChannel.send(message('/permitAll', type))
+
+ then: 'access is granted'
+ noExceptionThrown()
+
+ when: 'message sent to other message type'
+ clientInboundChannel.send(message('/permitAll', SimpMessageType.UNSUBSCRIBE))
+
+ then: 'does not match'
+ MessageDeliveryException e = thrown()
+ e.cause instanceof AccessDeniedException
+
+ when: 'message is sent to other pattern'
+ clientInboundChannel.send(message('/other', type))
+
+ then: 'does not match'
+ MessageDeliveryException eOther = thrown()
+ eOther.cause instanceof AccessDeniedException
+
+ where:
+ type << [SimpMessageType.MESSAGE, SimpMessageType.SUBSCRIBE]
+ }
+
+ @Unroll
+ def "intercept-message with invalid type and pattern - #type"(SimpMessageType type) {
+ when:
+ websocket {
+ 'intercept-message'(pattern : '/**', 'type': type.toString(), access:'permitAll')
+ }
+ then:
+ thrown(BeanDefinitionParsingException)
+
+ where:
+ type << [SimpMessageType.CONNECT, SimpMessageType.CONNECT_ACK, SimpMessageType.DISCONNECT, SimpMessageType.DISCONNECT_ACK, SimpMessageType.HEARTBEAT, SimpMessageType.OTHER, SimpMessageType.UNSUBSCRIBE ]
+ }
+
def 'messages with no id automatically adds Authentication argument resolver'() {
setup:
def id = 'authenticationController'
@@ -186,7 +260,7 @@ class WebSocketMessageBrokerConfigTests extends AbstractXmlConfigTests {
then: 'CSRF Protection blocks the Message'
MessageDeliveryException expected = thrown()
- expected.cause instanceof MissingCsrfTokenException
+ expected.cause instanceof InvalidCsrfTokenException
}
def 'websocket with no id does not override customArgumentResolvers'() {
@@ -314,8 +388,8 @@ class WebSocketMessageBrokerConfigTests extends AbstractXmlConfigTests {
appContext.getBean("clientInboundChannel")
}
- def message(String destination) {
- SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create()
+ def message(String destination, SimpMessageType type=SimpMessageType.MESSAGE) {
+ SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(type)
message(headers, destination)
}
@@ -327,6 +401,9 @@ class WebSocketMessageBrokerConfigTests extends AbstractXmlConfigTests {
if(messageUser != null) {
headers.user = messageUser
}
+ if(csrfToken != null) {
+ headers.sessionAttributes[CsrfToken.name] = csrfToken
+ }
new GenericMessage("hi",headers.messageHeaders)
}
diff --git a/docs/manual/src/docs/asciidoc/index.adoc b/docs/manual/src/docs/asciidoc/index.adoc
index f2634f2f83..360e765507 100644
--- a/docs/manual/src/docs/asciidoc/index.adoc
+++ b/docs/manual/src/docs/asciidoc/index.adoc
@@ -7840,6 +7840,10 @@ Defines an authorization rule for a message.
[[nsa-intercept-message-pattern]]
* **pattern** An ant based pattern that matches on the Message destination. For example, "/**" matches any Message with a destination; "/admin/**" matches any Message that has a destination that starts with "/admin/**".
+[[nsa-intercept-message-type]]
+* **type** The type of message to match on.
+Valid values are defined in SimpMessageType (i.e. CONNECT, CONNECT_ACK, HEARTBEAT, MESSAGE, SUBSCRIBE, UNSUBSCRIBE, DISCONNECT, DISCONNECT_ACK, OTHER).
+
[[nsa-intercept-message-access]]
* **access** The expression used to secure the Message. For example, "denyAll" will deny access to all of the matching Messages; "permitAll" will grant access to all of the matching Messages; "hasRole('ADMIN') requires the current user to have the role 'ROLE_ADMIN' for the matching Messages.