diff --git a/config/spring-security-config.gradle b/config/spring-security-config.gradle
index 580a2bdbc2..718d717a99 100644
--- a/config/spring-security-config.gradle
+++ b/config/spring-security-config.gradle
@@ -15,10 +15,12 @@ dependencies {
optional project(':spring-security-oauth2-jose')
optional project(':spring-security-oauth2-resource-server')
optional project(':spring-security-openid')
+ optional project(':spring-security-rsocket')
optional project(':spring-security-web')
optional 'io.projectreactor:reactor-core'
optional 'org.aspectj:aspectjweaver'
optional 'org.springframework:spring-jdbc'
+ optional 'org.springframework:spring-messaging'
optional 'org.springframework:spring-tx'
optional 'org.springframework:spring-webmvc'
optional'org.springframework:spring-web'
@@ -39,6 +41,7 @@ dependencies {
testCompile 'com.squareup.okhttp3:mockwebserver'
testCompile 'ch.qos.logback:logback-classic'
testCompile 'io.projectreactor.netty:reactor-netty'
+ testCompile 'io.rsocket:rsocket-transport-netty'
testCompile 'javax.annotation:jsr250-api:1.0'
testCompile 'javax.xml.bind:jaxb-api'
testCompile 'ldapsdk:ldapsdk:4.1'
diff --git a/config/src/main/java/org/springframework/security/config/annotation/rsocket/EnableRSocketSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/rsocket/EnableRSocketSecurity.java
new file mode 100644
index 0000000000..e4dce801f9
--- /dev/null
+++ b/config/src/main/java/org/springframework/security/config/annotation/rsocket/EnableRSocketSecurity.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.annotation.rsocket;
+
+import org.springframework.context.annotation.Import;
+
+import java.lang.annotation.Documented;
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+
+/**
+ * Add this annotation to a {@code Configuration} class to have Spring Security
+ * {@link RSocketSecurity} support added.
+ *
+ * @author Rob Winch
+ * @since 5.2
+ * @see RSocketSecurity
+ */
+@Documented
+@Target(ElementType.TYPE)
+@Retention(RetentionPolicy.RUNTIME)
+@Import({ RSocketSecurityConfiguration.class })
+public @interface EnableRSocketSecurity { }
diff --git a/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurity.java
new file mode 100644
index 0000000000..274dc4b539
--- /dev/null
+++ b/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurity.java
@@ -0,0 +1,313 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.annotation.rsocket;
+
+import org.springframework.beans.BeansException;
+import org.springframework.context.ApplicationContext;
+import org.springframework.core.ResolvableType;
+import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler;
+import org.springframework.security.authentication.ReactiveAuthenticationManager;
+import org.springframework.security.authorization.AuthenticatedReactiveAuthorizationManager;
+import org.springframework.security.authorization.AuthorityReactiveAuthorizationManager;
+import org.springframework.security.authorization.AuthorizationDecision;
+import org.springframework.security.authorization.ReactiveAuthorizationManager;
+import org.springframework.security.config.Customizer;
+import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
+import org.springframework.security.oauth2.server.resource.authentication.JwtReactiveAuthenticationManager;
+import org.springframework.security.rsocket.interceptor.PayloadInterceptor;
+import org.springframework.security.rsocket.interceptor.PayloadSocketAcceptorInterceptor;
+import org.springframework.security.rsocket.interceptor.authentication.AnonymousPayloadInterceptor;
+import org.springframework.security.rsocket.interceptor.authentication.AuthenticationPayloadInterceptor;
+import org.springframework.security.rsocket.interceptor.authentication.BearerPayloadExchangeConverter;
+import org.springframework.security.rsocket.interceptor.authorization.AuthorizationPayloadInterceptor;
+import org.springframework.security.rsocket.interceptor.authorization.PayloadExchangeMatcherReactiveAuthorizationManager;
+import org.springframework.security.rsocket.util.PayloadExchangeAuthorizationContext;
+import org.springframework.security.rsocket.util.PayloadExchangeMatcher;
+import org.springframework.security.rsocket.util.PayloadExchangeMatcherEntry;
+import org.springframework.security.rsocket.util.PayloadExchangeMatchers;
+import org.springframework.security.rsocket.util.RoutePayloadExchangeMatcher;
+import reactor.core.publisher.Mono;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Allows configuring RSocket based security.
+ *
+ * A minimal example can be found below:
+ *
+ *
+ * @EnableRSocketSecurity
+ * public class SecurityConfig {
+ * // @formatter:off
+ * @Bean
+ * PayloadSocketAcceptorInterceptor rsocketInterceptor(RSocketSecurity rsocket) {
+ * rsocket
+ * .authorizePayload(authorize ->
+ * authorize
+ * .anyRequest().authenticated()
+ * );
+ * return rsocket.build();
+ * }
+ * // @formatter:on
+ *
+ * // @formatter:off
+ * @Bean
+ * public MapReactiveUserDetailsService userDetailsService() {
+ * UserDetails user = User.withDefaultPasswordEncoder()
+ * .username("user")
+ * .password("password")
+ * .roles("USER")
+ * .build();
+ * return new MapReactiveUserDetailsService(user);
+ * }
+ * // @formatter:on
+ * }
+ *
+ *
+ * A more advanced configuration can be seen below:
+ *
+ *
+ * @EnableRSocketSecurity
+ * public class SecurityConfig {
+ * // @formatter:off
+ * @Bean
+ * PayloadSocketAcceptorInterceptor rsocketInterceptor(RSocketSecurity rsocket) {
+ * rsocket
+ * .authorizePayload(authorize ->
+ * authorize
+ * // must have ROLE_SETUP to make connection
+ * .setup().hasRole("SETUP")
+ * // must have ROLE_ADMIN for routes starting with "admin."
+ * .route("admin.*").hasRole("ADMIN")
+ * // any other request must be authenticated for
+ * .anyRequest().authenticated()
+ * );
+ * return rsocket.build();
+ * }
+ * // @formatter:on
+ * }
+ *
+ * @author Rob Winch
+ * @since 5.2
+ */
+public class RSocketSecurity {
+
+ private BasicAuthenticationSpec basicAuthSpec;
+
+ private JwtSpec jwtSpec;
+
+ private AuthorizePayloadsSpec authorizePayload;
+
+ private ApplicationContext context;
+
+ private ReactiveAuthenticationManager authenticationManager;
+
+ public RSocketSecurity authenticationManager(ReactiveAuthenticationManager authenticationManager) {
+ this.authenticationManager = authenticationManager;
+ return this;
+ }
+
+ public RSocketSecurity basicAuthentication(Customizer basic) {
+ if (this.basicAuthSpec == null) {
+ this.basicAuthSpec = new BasicAuthenticationSpec();
+ }
+ basic.customize(this.basicAuthSpec);
+ return this;
+ }
+
+ public class BasicAuthenticationSpec {
+ private ReactiveAuthenticationManager authenticationManager;
+
+ public BasicAuthenticationSpec authenticationManager(ReactiveAuthenticationManager authenticationManager) {
+ this.authenticationManager = authenticationManager;
+ return this;
+ }
+
+ private ReactiveAuthenticationManager getAuthenticationManager() {
+ if (this.authenticationManager == null) {
+ return RSocketSecurity.this.authenticationManager;
+ }
+ return this.authenticationManager;
+ }
+
+ protected AuthenticationPayloadInterceptor build() {
+ ReactiveAuthenticationManager manager = getAuthenticationManager();
+ return new AuthenticationPayloadInterceptor(manager);
+ }
+
+ private BasicAuthenticationSpec() {}
+ }
+
+ public RSocketSecurity jwt(Customizer jwt) {
+ if (this.jwtSpec == null) {
+ this.jwtSpec = new JwtSpec();
+ }
+ jwt.customize(this.jwtSpec);
+ return this;
+ }
+
+ public class JwtSpec {
+ private ReactiveAuthenticationManager authenticationManager;
+
+ public JwtSpec authenticationManager(ReactiveAuthenticationManager authenticationManager) {
+ this.authenticationManager = authenticationManager;
+ return this;
+ }
+
+ private ReactiveAuthenticationManager getAuthenticationManager() {
+ if (this.authenticationManager != null) {
+ return this.authenticationManager;
+ }
+ ReactiveJwtDecoder jwtDecoder = getBeanOrNull(ReactiveJwtDecoder.class);
+ if (jwtDecoder != null) {
+ this.authenticationManager = new JwtReactiveAuthenticationManager(jwtDecoder);
+ return this.authenticationManager;
+ }
+ return RSocketSecurity.this.authenticationManager;
+ }
+
+ protected AuthenticationPayloadInterceptor build() {
+ ReactiveAuthenticationManager manager = getAuthenticationManager();
+ AuthenticationPayloadInterceptor result = new AuthenticationPayloadInterceptor(manager);
+ result.setAuthenticationConverter(new BearerPayloadExchangeConverter());
+ return result;
+ }
+
+ private JwtSpec() {}
+ }
+
+ public RSocketSecurity authorizePayload(Customizer authorize) {
+ if (this.authorizePayload == null) {
+ this.authorizePayload = new AuthorizePayloadsSpec();
+ }
+ authorize.customize(this.authorizePayload);
+ return this;
+ }
+
+ public PayloadSocketAcceptorInterceptor build() {
+ PayloadSocketAcceptorInterceptor interceptor = new PayloadSocketAcceptorInterceptor(
+ payloadInterceptors());
+ RSocketMessageHandler handler = getBean(RSocketMessageHandler.class);
+ interceptor.setDefaultDataMimeType(handler.getDefaultDataMimeType());
+ interceptor.setDefaultMetadataMimeType(handler.getDefaultMetadataMimeType());
+ return interceptor;
+ }
+
+ private List payloadInterceptors() {
+ List payloadInterceptors = new ArrayList<>();
+
+ if (this.basicAuthSpec != null) {
+ payloadInterceptors.add(this.basicAuthSpec.build());
+ }
+ if (this.jwtSpec != null) {
+ payloadInterceptors.add(this.jwtSpec.build());
+ }
+ payloadInterceptors.add(new AnonymousPayloadInterceptor("anonymousUser"));
+
+ if (this.authorizePayload != null) {
+ payloadInterceptors.add(this.authorizePayload.build());
+ }
+ return payloadInterceptors;
+ }
+
+ public class AuthorizePayloadsSpec {
+
+ private PayloadExchangeMatcherReactiveAuthorizationManager.Builder authzBuilder =
+ PayloadExchangeMatcherReactiveAuthorizationManager.builder();
+
+ public Access setup() {
+ return matcher(PayloadExchangeMatchers.setup());
+ }
+
+ public Access anyRequest() {
+ return matcher(PayloadExchangeMatchers.anyExchange());
+ }
+
+ protected AuthorizationPayloadInterceptor build() {
+ return new AuthorizationPayloadInterceptor(this.authzBuilder.build());
+ }
+
+ public Access route(String pattern) {
+ RSocketMessageHandler handler = getBean(RSocketMessageHandler.class);
+ PayloadExchangeMatcher matcher = new RoutePayloadExchangeMatcher(
+ handler.getMetadataExtractor(),
+ handler.getRouteMatcher(),
+ pattern);
+ return matcher(matcher);
+ }
+
+ public Access matcher(PayloadExchangeMatcher matcher) {
+ return new Access(matcher);
+ }
+
+ public class Access {
+
+ private final PayloadExchangeMatcher matcher;
+
+ private Access(PayloadExchangeMatcher matcher) {
+ this.matcher = matcher;
+ }
+
+ public AuthorizePayloadsSpec authenticated() {
+ return access(AuthenticatedReactiveAuthorizationManager.authenticated());
+ }
+
+ public AuthorizePayloadsSpec hasRole(String role) {
+ return access(AuthorityReactiveAuthorizationManager.hasRole(role));
+ }
+
+ public AuthorizePayloadsSpec permitAll() {
+ return access((a, ctx) -> Mono
+ .just(new AuthorizationDecision(true)));
+ }
+
+ public AuthorizePayloadsSpec access(
+ ReactiveAuthorizationManager authorization) {
+ AuthorizePayloadsSpec.this.authzBuilder.add(new PayloadExchangeMatcherEntry<>(this.matcher, authorization));
+ return AuthorizePayloadsSpec.this;
+ }
+ }
+ }
+
+ private T getBean(Class beanClass) {
+ if (this.context == null) {
+ return null;
+ }
+ return this.context.getBean(beanClass);
+ }
+
+ private T getBeanOrNull(Class beanClass) {
+ return getBeanOrNull(ResolvableType.forClass(beanClass));
+ }
+
+ private T getBeanOrNull(ResolvableType type) {
+ if (this.context == null) {
+ return null;
+ }
+ String[] names = this.context.getBeanNamesForType(type);
+ if (names.length == 1) {
+ return (T) this.context.getBean(names[0]);
+ }
+ return null;
+ }
+
+ protected void setApplicationContext(ApplicationContext applicationContext)
+ throws BeansException {
+ this.context = applicationContext;
+ }
+}
diff --git a/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurityConfiguration.java
new file mode 100644
index 0000000000..fdf9bd31bc
--- /dev/null
+++ b/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurityConfiguration.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.annotation.rsocket;
+
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.ApplicationContext;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.context.annotation.Scope;
+import org.springframework.security.authentication.ReactiveAuthenticationManager;
+import org.springframework.security.authentication.UserDetailsRepositoryReactiveAuthenticationManager;
+import org.springframework.security.core.userdetails.ReactiveUserDetailsService;
+import org.springframework.security.crypto.password.PasswordEncoder;
+
+/**
+ * @author Rob Winch
+ * @since 5.2
+ */
+@Configuration(proxyBeanMethods = false)
+class RSocketSecurityConfiguration {
+
+ private static final String BEAN_NAME_PREFIX = "org.springframework.security.config.annotation.rsocket.RSocketSecurityConfiguration.";
+ private static final String RSOCKET_SECURITY_BEAN_NAME = BEAN_NAME_PREFIX + "rsocketSecurity";
+
+ private ReactiveAuthenticationManager authenticationManager;
+
+ private ReactiveUserDetailsService reactiveUserDetailsService;
+
+ private PasswordEncoder passwordEncoder;
+
+ @Autowired(required = false)
+ void setAuthenticationManager(
+ ReactiveAuthenticationManager authenticationManager) {
+ this.authenticationManager = authenticationManager;
+ }
+
+ @Autowired(required = false)
+ void setUserDetailsService(ReactiveUserDetailsService userDetailsService) {
+ this.reactiveUserDetailsService = userDetailsService;
+ }
+
+ @Autowired(required = false)
+ void setPasswordEncoder(PasswordEncoder passwordEncoder) {
+ this.passwordEncoder = passwordEncoder;
+ }
+
+ @Bean(name = RSOCKET_SECURITY_BEAN_NAME)
+ @Scope("prototype")
+ public RSocketSecurity rsocketSecurity(ApplicationContext context) {
+ RSocketSecurity security = new RSocketSecurity()
+ .authenticationManager(authenticationManager());
+ security.setApplicationContext(context);
+ return security;
+ }
+
+ private ReactiveAuthenticationManager authenticationManager() {
+ if (this.authenticationManager != null) {
+ return this.authenticationManager;
+ }
+ if (this.reactiveUserDetailsService != null) {
+ UserDetailsRepositoryReactiveAuthenticationManager manager =
+ new UserDetailsRepositoryReactiveAuthenticationManager(this.reactiveUserDetailsService);
+ if (this.passwordEncoder != null) {
+ manager.setPasswordEncoder(this.passwordEncoder);
+ }
+ return manager;
+ }
+ return null;
+ }
+}
diff --git a/config/src/test/java/org/springframework/security/config/annotation/rsocket/HelloHandler.java b/config/src/test/java/org/springframework/security/config/annotation/rsocket/HelloHandler.java
new file mode 100644
index 0000000000..501822a65c
--- /dev/null
+++ b/config/src/test/java/org/springframework/security/config/annotation/rsocket/HelloHandler.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.config.annotation.rsocket;
+
+import io.rsocket.AbstractRSocket;
+import io.rsocket.ConnectionSetupPayload;
+import io.rsocket.Payload;
+import io.rsocket.RSocket;
+import io.rsocket.SocketAcceptor;
+import io.rsocket.util.ByteBufPayload;
+import reactor.core.publisher.Mono;
+
+public class HelloHandler implements SocketAcceptor {
+
+ @Override
+ public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) {
+ return Mono.just(
+ new AbstractRSocket() {
+ @Override
+ public Mono requestResponse(Payload payload) {
+ String data = payload.getDataUtf8();
+ payload.release();
+ System.out.println("Got " + data);
+ return Mono.just(ByteBufPayload.create("Hello " + data));
+ }
+ });
+ }
+}
diff --git a/config/src/test/java/org/springframework/security/config/annotation/rsocket/JwtITests.java b/config/src/test/java/org/springframework/security/config/annotation/rsocket/JwtITests.java
new file mode 100644
index 0000000000..e48f269b16
--- /dev/null
+++ b/config/src/test/java/org/springframework/security/config/annotation/rsocket/JwtITests.java
@@ -0,0 +1,182 @@
+/*
+ * Copyright 2002-2013 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.config.annotation.rsocket;
+
+import io.rsocket.RSocketFactory;
+import io.rsocket.frame.decoder.PayloadDecoder;
+import io.rsocket.transport.netty.server.CloseableChannel;
+import io.rsocket.transport.netty.server.TcpServerTransport;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.messaging.handler.annotation.MessageMapping;
+import org.springframework.messaging.rsocket.RSocketRequester;
+import org.springframework.messaging.rsocket.RSocketStrategies;
+import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler;
+import org.springframework.security.config.Customizer;
+import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
+import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
+import org.springframework.security.rsocket.interceptor.PayloadSocketAcceptorInterceptor;
+import org.springframework.security.rsocket.metadata.BasicAuthenticationEncoder;
+import org.springframework.security.rsocket.metadata.BearerTokenMetadata;
+import org.springframework.stereotype.Controller;
+import org.springframework.test.context.ContextConfiguration;
+import org.springframework.test.context.junit4.SpringRunner;
+import reactor.core.publisher.Mono;
+
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * @author Rob Winch
+ */
+@ContextConfiguration
+@RunWith(SpringRunner.class)
+public class JwtITests {
+ @Autowired
+ RSocketMessageHandler handler;
+
+ @Autowired
+ PayloadSocketAcceptorInterceptor interceptor;
+
+ @Autowired
+ ServerController controller;
+
+ @Autowired
+ ReactiveJwtDecoder decoder;
+
+ private CloseableChannel server;
+
+ private RSocketRequester requester;
+
+ @Before
+ public void setup() {
+ this.server = RSocketFactory.receive()
+ .frameDecoder(PayloadDecoder.ZERO_COPY)
+ .addSocketAcceptorPlugin(this.interceptor)
+ .acceptor(this.handler.responder())
+ .transport(TcpServerTransport.create("localhost", 7000))
+ .start()
+ .block();
+ }
+
+ @After
+ public void dispose() {
+ this.requester.rsocket().dispose();
+ this.server.dispose();
+ this.controller.payloads.clear();
+ }
+
+ @Test
+ public void routeWhenAuthorized() {
+ BearerTokenMetadata credentials =
+ new BearerTokenMetadata("token");
+ when(this.decoder.decode(any())).thenReturn(Mono.just(jwt()));
+ this.requester = requester()
+ .setupMetadata(credentials.getToken(), BearerTokenMetadata.BEARER_AUTHENTICATION_MIME_TYPE)
+ .connectTcp(this.server.address().getHostName(), this.server.address().getPort())
+ .block();
+
+ String hiRob = this.requester.route("secure.retrieve-mono")
+ .data("rob")
+ .retrieveMono(String.class)
+ .block();
+
+ assertThat(hiRob).isEqualTo("Hi rob");
+ }
+
+ private Jwt jwt() {
+ Map claims = new HashMap<>();
+ claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com");
+ claims.put(IdTokenClaimNames.SUB, "rob");
+ claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id"));
+ Instant issuedAt = Instant.now();
+ Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600);
+ return new Jwt("token", issuedAt, expiresAt, claims, claims);
+ }
+
+ private RSocketRequester.Builder requester() {
+ return RSocketRequester.builder()
+ .rsocketStrategies(this.handler.getRSocketStrategies());
+ }
+
+
+ @Configuration
+ @EnableRSocketSecurity
+ static class Config {
+
+ @Bean
+ public ServerController controller() {
+ return new ServerController();
+ }
+
+ @Bean
+ public RSocketMessageHandler messageHandler() {
+ RSocketMessageHandler handler = new RSocketMessageHandler();
+ handler.setRSocketStrategies(rsocketStrategies());
+ return handler;
+ }
+
+ @Bean
+ public RSocketStrategies rsocketStrategies() {
+ return RSocketStrategies.builder()
+ .encoder(new BasicAuthenticationEncoder())
+ .build();
+ }
+
+ @Bean
+ PayloadSocketAcceptorInterceptor rsocketInterceptor(RSocketSecurity rsocket) {
+ rsocket
+ .authorizePayload(authorize ->
+ authorize
+ .route("secure.admin.*").authenticated()
+ .anyRequest().permitAll()
+ )
+ .jwt(Customizer.withDefaults());
+ return rsocket.build();
+ }
+
+ @Bean
+ ReactiveJwtDecoder jwtDecoder() {
+ return mock(ReactiveJwtDecoder.class);
+ }
+ }
+
+ @Controller
+ static class ServerController {
+ private List payloads = new ArrayList<>();
+
+ @MessageMapping("**")
+ String connect(String payload) {
+ return "Hi " + payload;
+ }
+ }
+
+}
diff --git a/config/src/test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerConnectionITests.java b/config/src/test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerConnectionITests.java
new file mode 100644
index 0000000000..f6883310fe
--- /dev/null
+++ b/config/src/test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerConnectionITests.java
@@ -0,0 +1,246 @@
+/*
+ * Copyright 2002-2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.config.annotation.rsocket;
+
+import io.rsocket.RSocketFactory;
+import io.rsocket.exceptions.ApplicationErrorException;
+import io.rsocket.frame.decoder.PayloadDecoder;
+import io.rsocket.transport.netty.server.CloseableChannel;
+import io.rsocket.transport.netty.server.TcpServerTransport;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.messaging.handler.annotation.MessageMapping;
+import org.springframework.messaging.rsocket.RSocketRequester;
+import org.springframework.messaging.rsocket.RSocketStrategies;
+import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler;
+import org.springframework.security.config.Customizer;
+import org.springframework.security.config.annotation.rsocket.EnableRSocketSecurity;
+import org.springframework.security.config.annotation.rsocket.RSocketSecurity;
+import org.springframework.security.core.userdetails.MapReactiveUserDetailsService;
+import org.springframework.security.core.userdetails.User;
+import org.springframework.security.core.userdetails.UserDetails;
+import org.springframework.security.rsocket.interceptor.PayloadSocketAcceptorInterceptor;
+import org.springframework.security.rsocket.metadata.BasicAuthenticationEncoder;
+import org.springframework.security.rsocket.metadata.UsernamePasswordMetadata;
+import org.springframework.stereotype.Controller;
+import org.springframework.test.context.ContextConfiguration;
+import org.springframework.test.context.junit4.SpringRunner;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
+
+/**
+ * @author Rob Winch
+ */
+@ContextConfiguration
+@RunWith(SpringRunner.class)
+public class RSocketMessageHandlerConnectionITests {
+ @Autowired
+ RSocketMessageHandler handler;
+
+ @Autowired
+ PayloadSocketAcceptorInterceptor interceptor;
+
+ @Autowired
+ ServerController controller;
+
+ private CloseableChannel server;
+
+ private RSocketRequester requester;
+
+ @Before
+ public void setup() {
+ this.server = RSocketFactory.receive()
+ .frameDecoder(PayloadDecoder.ZERO_COPY)
+ .addSocketAcceptorPlugin(this.interceptor)
+ .acceptor(this.handler.responder())
+ .transport(TcpServerTransport.create("localhost", 7000))
+ .start()
+ .block();
+ }
+
+ @After
+ public void dispose() {
+ this.requester.rsocket().dispose();
+ this.server.dispose();
+ this.controller.payloads.clear();
+ }
+
+ @Test
+ public void routeWhenAuthorized() {
+ UsernamePasswordMetadata credentials =
+ new UsernamePasswordMetadata("user", "password");
+ this.requester = requester()
+ .setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE)
+ .connectTcp(this.server.address().getHostName(), this.server.address().getPort())
+ .block();
+
+ String hiRob = this.requester.route("secure.retrieve-mono")
+ .data("rob")
+ .retrieveMono(String.class)
+ .block();
+
+ assertThat(hiRob).isEqualTo("Hi rob");
+ }
+
+ @Test
+ public void routeWhenNotAuthorized() {
+ UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("user", "password");
+ this.requester = requester()
+ .setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE)
+ .connectTcp(this.server.address().getHostName(), this.server.address().getPort())
+ .block();
+
+ assertThatCode(() -> this.requester.route("secure.admin.retrieve-mono")
+ .data("data")
+ .retrieveMono(String.class)
+ .block())
+ .isInstanceOf(ApplicationErrorException.class);
+ }
+
+ @Test
+ public void routeWhenStreamCredentialsAuthorized() {
+ UsernamePasswordMetadata connectCredentials = new UsernamePasswordMetadata("user", "password");
+ this.requester = requester()
+ .setupMetadata(connectCredentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE)
+ .connectTcp(this.server.address().getHostName(), this.server.address().getPort())
+ .block();
+
+ String hiRob = this.requester.route("secure.admin.retrieve-mono")
+ .metadata(new UsernamePasswordMetadata("admin", "password"), UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE)
+ .data("rob")
+ .retrieveMono(String.class)
+ .block();
+
+ assertThat(hiRob).isEqualTo("Hi rob");
+ }
+
+ @Test
+ public void connectWhenNotAuthenticated() {
+ this.requester = requester()
+ .connectTcp(this.server.address().getHostName(), this.server.address().getPort())
+ .block();
+
+ assertThatCode(() -> this.requester.route("retrieve-mono")
+ .data("data")
+ .retrieveMono(String.class)
+ .block())
+ .isNotNull();
+ // FIXME: https://github.com/rsocket/rsocket-java/issues/686
+ // .isInstanceOf(RejectedSetupException.class);
+ }
+
+ @Test
+ public void connectWhenNotAuthorized() {
+ UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("evil", "password");
+ this.requester = requester()
+ .setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE)
+ .connectTcp(this.server.address().getHostName(), this.server.address().getPort())
+ .block();
+
+ assertThatCode(() -> this.requester.route("retrieve-mono")
+ .data("data")
+ .retrieveMono(String.class)
+ .block())
+ .isNotNull();
+// FIXME: https://github.com/rsocket/rsocket-java/issues/686
+// .isInstanceOf(RejectedSetupException.class);
+ }
+
+ private RSocketRequester.Builder requester() {
+ return RSocketRequester.builder()
+ .rsocketStrategies(this.handler.getRSocketStrategies());
+ }
+
+
+ @Configuration
+ @EnableRSocketSecurity
+ static class Config {
+
+ @Bean
+ public ServerController controller() {
+ return new ServerController();
+ }
+
+ @Bean
+ public RSocketMessageHandler messageHandler() {
+ RSocketMessageHandler handler = new RSocketMessageHandler();
+ handler.setRSocketStrategies(rsocketStrategies());
+ return handler;
+ }
+
+ @Bean
+ public RSocketStrategies rsocketStrategies() {
+ return RSocketStrategies.builder()
+ .encoder(new BasicAuthenticationEncoder())
+ .build();
+ }
+
+ @Bean
+ MapReactiveUserDetailsService uds() {
+ UserDetails admin = User.withDefaultPasswordEncoder()
+ .username("admin")
+ .password("password")
+ .roles("USER", "ADMIN", "SETUP")
+ .build();
+ UserDetails user = User.withDefaultPasswordEncoder()
+ .username("user")
+ .password("password")
+ .roles("USER", "SETUP")
+ .build();
+
+ UserDetails evil = User.withDefaultPasswordEncoder()
+ .username("evil")
+ .password("password")
+ .roles("EVIL")
+ .build();
+ return new MapReactiveUserDetailsService(admin, user, evil);
+ }
+
+ @Bean
+ PayloadSocketAcceptorInterceptor rsocketInterceptor(RSocketSecurity rsocket) {
+ rsocket
+ .authorizePayload(authorize ->
+ authorize
+ .setup().hasRole("SETUP")
+ .route("secure.admin.*").hasRole("ADMIN")
+ .route("secure.**").hasRole("USER")
+ .anyRequest().permitAll()
+ )
+ .basicAuthentication(Customizer.withDefaults());
+ return rsocket.build();
+ }
+ }
+
+ @Controller
+ static class ServerController {
+ private List payloads = new ArrayList<>();
+
+ @MessageMapping("**")
+ String connect(String payload) {
+ return "Hi " + payload;
+ }
+ }
+
+}
diff --git a/config/src/test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerITests.java b/config/src/test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerITests.java
new file mode 100644
index 0000000000..3cbcc2117d
--- /dev/null
+++ b/config/src/test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerITests.java
@@ -0,0 +1,312 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.annotation.rsocket;
+
+import io.rsocket.RSocketFactory;
+import io.rsocket.exceptions.ApplicationErrorException;
+import io.rsocket.frame.decoder.PayloadDecoder;
+import io.rsocket.transport.netty.server.CloseableChannel;
+import io.rsocket.transport.netty.server.TcpServerTransport;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.messaging.handler.annotation.MessageMapping;
+import org.springframework.messaging.rsocket.RSocketRequester;
+import org.springframework.messaging.rsocket.RSocketStrategies;
+import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler;
+import org.springframework.security.config.Customizer;
+import org.springframework.security.config.annotation.rsocket.EnableRSocketSecurity;
+import org.springframework.security.config.annotation.rsocket.RSocketSecurity;
+import org.springframework.security.core.userdetails.MapReactiveUserDetailsService;
+import org.springframework.security.core.userdetails.User;
+import org.springframework.security.core.userdetails.UserDetails;
+import org.springframework.security.rsocket.interceptor.PayloadSocketAcceptorInterceptor;
+import org.springframework.security.rsocket.metadata.BasicAuthenticationEncoder;
+import org.springframework.security.rsocket.metadata.UsernamePasswordMetadata;
+import org.springframework.stereotype.Controller;
+import org.springframework.test.context.ContextConfiguration;
+import org.springframework.test.context.junit4.SpringRunner;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
+
+/**
+ * @author Rob Winch
+ */
+@ContextConfiguration
+@RunWith(SpringRunner.class)
+public class RSocketMessageHandlerITests {
+ @Autowired
+ RSocketMessageHandler handler;
+
+ @Autowired
+ PayloadSocketAcceptorInterceptor interceptor;
+
+ @Autowired
+ ServerController controller;
+
+ private CloseableChannel server;
+
+ private RSocketRequester requester;
+
+ @Before
+ public void setup() {
+ this.server = RSocketFactory.receive()
+ .frameDecoder(PayloadDecoder.ZERO_COPY)
+ .addSocketAcceptorPlugin(this.interceptor)
+ .acceptor(this.handler.responder())
+ .transport(TcpServerTransport.create("localhost", 7000))
+ .start()
+ .block();
+
+ this.requester = RSocketRequester.builder()
+ // .rsocketFactory(factory -> factory.addRequesterPlugin(payloadInterceptor))
+ .rsocketStrategies(this.handler.getRSocketStrategies())
+ .connectTcp("localhost", 7000)
+ .block();
+ }
+
+ @After
+ public void dispose() {
+ this.requester.rsocket().dispose();
+ this.server.dispose();
+ this.controller.payloads.clear();
+ }
+
+ @Test
+ public void retrieveMonoWhenSecureThenDenied() throws Exception {
+ String data = "rob";
+ assertThatCode(() -> this.requester.route("secure.retrieve-mono")
+ .data(data)
+ .retrieveMono(String.class)
+ .block()
+ ).isInstanceOf(ApplicationErrorException.class);
+ assertThat(this.controller.payloads).isEmpty();
+ }
+
+ @Test
+ public void retrieveMonoWhenAuthenticationFailedThenException() throws Exception {
+ String data = "rob";
+ UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("invalid", "password");
+ assertThatCode(() -> this.requester.route("secure.retrieve-mono")
+ .metadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE)
+ .data(data)
+ .retrieveMono(String.class)
+ .block()
+ ).isInstanceOf(ApplicationErrorException.class);
+ assertThat(this.controller.payloads).isEmpty();
+ }
+
+ @Test
+ public void retrieveMonoWhenAuthorizedThenGranted() throws Exception {
+ String data = "rob";
+ UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("rob", "password");
+ String hiRob = this.requester.route("secure.retrieve-mono")
+ .metadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE)
+ .data(data)
+ .retrieveMono(String.class)
+ .block();
+
+ assertThat(hiRob).isEqualTo("Hi rob");
+ assertThat(this.controller.payloads).containsOnly(data);
+ }
+
+ @Test
+ public void retrieveMonoWhenPublicThenGranted() throws Exception {
+ String data = "rob";
+ String hiRob = this.requester.route("retrieve-mono")
+ .data(data)
+ .retrieveMono(String.class)
+ .block();
+
+ assertThat(hiRob).isEqualTo("Hi rob");
+ assertThat(this.controller.payloads).containsOnly(data);
+ }
+
+ @Test
+ public void retrieveFluxWhenDataFluxAndSecureThenDenied() throws Exception {
+ Flux data = Flux.just("a", "b", "c");
+ assertThatCode(() -> this.requester.route("secure.secure.retrieve-flux")
+ .data(data, String.class)
+ .retrieveFlux(String.class)
+ .collectList()
+ .block()).isInstanceOf(
+ ApplicationErrorException.class);
+
+ assertThat(this.controller.payloads).isEmpty();
+ }
+
+ @Test
+ public void retrieveFluxWhenDataFluxAndPublicThenGranted() throws Exception {
+ Flux data = Flux.just("a", "b", "c");
+ List hi = this.requester.route("retrieve-flux")
+ .data(data, String.class)
+ .retrieveFlux(String.class)
+ .collectList()
+ .block();
+
+ assertThat(hi).containsOnly("hello a", "hello b", "hello c");
+ assertThat(this.controller.payloads).containsOnlyElementsOf(data.collectList().block());
+ }
+
+ @Test
+ public void retrieveFluxWhenDataStringAndSecureThenDenied() throws Exception {
+ String data = "a";
+ assertThatCode(() -> this.requester.route("secure.hello")
+ .data(data)
+ .retrieveFlux(String.class)
+ .collectList()
+ .block()).isInstanceOf(
+ ApplicationErrorException.class);
+
+ assertThat(this.controller.payloads).isEmpty();
+ }
+
+ @Test
+ public void retrieveFluxWhenDataStringAndPublicThenGranted() throws Exception {
+ String data = "a";
+ List hi = this.requester.route("retrieve-flux")
+ .data(data)
+ .retrieveFlux(String.class)
+ .collectList()
+ .block();
+
+ assertThat(hi).contains("hello a");
+ assertThat(this.controller.payloads).containsOnly(data);
+ }
+
+ @Test
+ public void sendWhenSecureThenDenied() throws Exception {
+ String data = "hi";
+ this.requester.route("secure.send")
+ .data(data)
+ .send()
+ .block();
+
+ assertThat(this.controller.payloads).isEmpty();
+ }
+
+ @Test
+ public void sendWhenPublicThenGranted() throws Exception {
+ String data = "hi";
+ this.requester.route("send")
+ .data(data)
+ .send()
+ .block();
+ assertThat(this.controller.awaitPayloads()).containsOnly("hi");
+ }
+
+ @Configuration
+ @EnableRSocketSecurity
+ static class Config {
+
+ @Bean
+ public ServerController controller() {
+ return new ServerController();
+ }
+
+ @Bean
+ public RSocketMessageHandler messageHandler() {
+ RSocketMessageHandler handler = new RSocketMessageHandler();
+ handler.setRSocketStrategies(rsocketStrategies());
+ return handler;
+ }
+
+ @Bean
+ public RSocketStrategies rsocketStrategies() {
+ return RSocketStrategies.builder()
+ .encoder(new BasicAuthenticationEncoder())
+ .build();
+ }
+
+ @Bean
+ MapReactiveUserDetailsService uds() {
+ UserDetails rob = User.withDefaultPasswordEncoder()
+ .username("rob")
+ .password("password")
+ .roles("USER", "ADMIN")
+ .build();
+ UserDetails rossen = User.withDefaultPasswordEncoder()
+ .username("rossen")
+ .password("password")
+ .roles("USER")
+ .build();
+ return new MapReactiveUserDetailsService(rob, rossen);
+ }
+
+ @Bean
+ PayloadSocketAcceptorInterceptor rsocketInterceptor(RSocketSecurity rsocket) {
+ rsocket
+ .authorizePayload(authorize -> {
+ authorize
+ .route("secure.*").authenticated()
+ .anyRequest().permitAll();
+ })
+ .basicAuthentication(Customizer.withDefaults());
+ return rsocket.build();
+ }
+ }
+
+ @Controller
+ static class ServerController {
+ private List payloads = new ArrayList<>();
+
+ @MessageMapping({"secure.retrieve-mono", "retrieve-mono"})
+ String retrieveMono(String payload) {
+ add(payload);
+ return "Hi " + payload;
+ }
+
+ @MessageMapping({"secure.retrieve-flux", "retrieve-flux"})
+ Flux retrieveFlux(Flux payload) {
+ return payload.doOnNext(this::add)
+ .map(p -> "hello " + p);
+ }
+
+ @MessageMapping({"secure.send", "send"})
+ Mono send(Flux payload) {
+ return payload
+ .doOnNext(this::add)
+ .then(Mono.fromRunnable(() -> {
+ doNotifyAll();
+ }));
+ }
+
+ private synchronized void doNotifyAll() {
+ this.notifyAll();
+ }
+
+ private synchronized List awaitPayloads() throws InterruptedException {
+ this.wait();
+ return this.payloads;
+ }
+
+ private void add(String p) {
+ this.payloads.add(p);
+ }
+ }
+
+}
diff --git a/etc/checkstyle/header.txt b/etc/checkstyle/header.txt
index e432c9f5bd..5e5d28b99f 100644
--- a/etc/checkstyle/header.txt
+++ b/etc/checkstyle/header.txt
@@ -1,5 +1,5 @@
^\Q/*\E$
-^\Q * Copyright\E (\d{4}\-\d{4} the original author or authors\.|(\d{4}, )*(\d{4}) Acegi Technology Pty Limited)$
+^\Q * Copyright\E (\d{4}(\-\d{4})? the original author or authors\.|(\d{4}, )*(\d{4}) Acegi Technology Pty Limited)$
^\Q *\E$
^\Q * Licensed under the Apache License, Version 2.0 (the "License");\E$
^\Q * you may not use this file except in compliance with the License.\E$
@@ -13,4 +13,4 @@
^\Q * See the License for the specific language governing permissions and\E$
^\Q * limitations under the License.\E$
^\Q */\E$
-^.*$
\ No newline at end of file
+^.*$
diff --git a/gradle/dependency-management.gradle b/gradle/dependency-management.gradle
index b221730670..76549e80af 100644
--- a/gradle/dependency-management.gradle
+++ b/gradle/dependency-management.gradle
@@ -1,15 +1,17 @@
if (!project.hasProperty('reactorVersion')) {
- ext.reactorVersion = 'Dysprosium-M3'
+ ext.reactorVersion = 'Dysprosium-RC1'
}
if (!project.hasProperty('springVersion')) {
- ext.springVersion = '5.2.0.RC1'
+ ext.springVersion = '5.2.0.BUILD-SNAPSHOT'
}
if (!project.hasProperty('springDataVersion')) {
ext.springDataVersion = 'Moore-RC2'
}
+ext.rsocketVersion = '1.0.0-RC3'
+
dependencyManagement {
imports {
mavenBom "io.projectreactor:reactor-bom:${reactorVersion}"
@@ -71,6 +73,8 @@ dependencyManagement {
dependency 'commons-logging:commons-logging:1.2'
dependency 'dom4j:dom4j:1.6.1'
dependency 'io.projectreactor.tools:blockhound:1.0.0.M4'
+ dependency "io.rsocket:rsocket-core:${rsocketVersion}"
+ dependency "io.rsocket:rsocket-transport-netty:${rsocketVersion}"
dependency 'javax.activation:activation:1.1.1'
dependency 'javax.annotation:jsr250-api:1.0'
dependency 'javax.inject:javax.inject:1'
diff --git a/rsocket/spring-security-rsocket.gradle b/rsocket/spring-security-rsocket.gradle
new file mode 100644
index 0000000000..fa508c2760
--- /dev/null
+++ b/rsocket/spring-security-rsocket.gradle
@@ -0,0 +1,9 @@
+apply plugin: 'io.spring.convention.spring-module'
+
+dependencies {
+ compile project(':spring-security-core')
+ compile 'io.rsocket:rsocket-core'
+ optional project(':spring-security-oauth2-resource-server')
+ optional 'org.springframework:spring-messaging'
+ testCompile 'io.projectreactor:reactor-test'
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/ContextPayloadInterceptorChain.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/ContextPayloadInterceptorChain.java
new file mode 100644
index 0000000000..eb7d3f9494
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/ContextPayloadInterceptorChain.java
@@ -0,0 +1,96 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor;
+
+import reactor.core.publisher.Mono;
+import reactor.util.context.Context;
+
+import java.util.List;
+import java.util.ListIterator;
+
+/**
+ * A {@link PayloadInterceptorChain} which exposes the Reactor {@link Context} via a member variable.
+ * This class is not Thread safe, so a new instance must be created for each Thread.
+ *
+ * Internally {@code ContextPayloadInterceptorChain} is used to ensure that the Reactor
+ * {@code Context} is captured so it can be transferred to subscribers outside of this
+ * {@code Context} in {@code PayloadSocketAcceptor}.
+ *
+ * @author Rob Winch
+ * @since 5.2
+ * @see PayloadSocketAcceptor
+ */
+class ContextPayloadInterceptorChain implements PayloadInterceptorChain {
+
+ private final PayloadInterceptor currentInterceptor;
+
+ private final ContextPayloadInterceptorChain next;
+
+ private Context context;
+
+ ContextPayloadInterceptorChain(List interceptors) {
+ if (interceptors == null) {
+ throw new IllegalArgumentException("interceptors cannot be null");
+ }
+ if (interceptors.isEmpty()) {
+ throw new IllegalArgumentException("interceptors cannot be empty");
+ }
+ ContextPayloadInterceptorChain interceptor = init(interceptors);
+ this.currentInterceptor = interceptor.currentInterceptor;
+ this.next = interceptor.next;
+ }
+
+ private static ContextPayloadInterceptorChain init(List interceptors) {
+ ContextPayloadInterceptorChain interceptor = new ContextPayloadInterceptorChain(null, null);
+ ListIterator extends PayloadInterceptor> iterator = interceptors.listIterator(interceptors.size());
+ while (iterator.hasPrevious()) {
+ interceptor = new ContextPayloadInterceptorChain(iterator.previous(), interceptor);
+ }
+ return interceptor;
+ }
+
+ private ContextPayloadInterceptorChain(PayloadInterceptor currentInterceptor, ContextPayloadInterceptorChain next) {
+ this.currentInterceptor = currentInterceptor;
+ this.next = next;
+ }
+
+ public Mono next(PayloadExchange exchange) {
+ return Mono.defer(() ->
+ shouldIntercept() ?
+ this.currentInterceptor.intercept(exchange, this.next) :
+ Mono.subscriberContext()
+ .doOnNext(c -> this.context = c)
+ .then()
+ );
+ }
+
+ Context getContext() {
+ if (this.next == null) {
+ return this.context;
+ }
+ return this.next.getContext();
+ }
+
+ private boolean shouldIntercept() {
+ return this.currentInterceptor != null && this.next != null;
+ }
+
+ @Override
+ public String toString() {
+ return getClass().getSimpleName() + "[currentInterceptor=" + this.currentInterceptor + "]";
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/DefaultPayloadExchange.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/DefaultPayloadExchange.java
new file mode 100644
index 0000000000..9a289cd2e6
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/DefaultPayloadExchange.java
@@ -0,0 +1,70 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor;
+
+import io.rsocket.Payload;
+import org.springframework.util.Assert;
+import org.springframework.util.MimeType;
+
+/**
+ * Default implementation of {@link PayloadExchange}
+ *
+ * @author Rob Winch
+ * @since 5.2
+ */
+public class DefaultPayloadExchange implements PayloadExchange {
+
+ private final PayloadExchangeType type;
+
+ private final Payload payload;
+
+ private final MimeType metadataMimeType;
+
+ private final MimeType dataMimeType;
+
+ public DefaultPayloadExchange(PayloadExchangeType type, Payload payload, MimeType metadataMimeType,
+ MimeType dataMimeType) {
+ Assert.notNull(type, "type cannot be null");
+ Assert.notNull(payload, "payload cannot be null");
+ Assert.notNull(metadataMimeType, "metadataMimeType cannot be null");
+ Assert.notNull(dataMimeType, "dataMimeType cannot be null");
+ this.type = type;
+ this.payload = payload;
+ this.metadataMimeType = metadataMimeType;
+ this.dataMimeType = dataMimeType;
+ }
+
+ @Override
+ public PayloadExchangeType getType() {
+ return this.type;
+ }
+
+ @Override
+ public Payload getPayload() {
+ return this.payload;
+ }
+
+ @Override
+ public MimeType getMetadataMimeType() {
+ return this.metadataMimeType;
+ }
+
+ @Override
+ public MimeType getDataMimeType() {
+ return this.dataMimeType;
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadExchange.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadExchange.java
new file mode 100644
index 0000000000..7cf8ca4dca
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadExchange.java
@@ -0,0 +1,36 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor;
+
+import io.rsocket.Payload;
+import org.springframework.util.MimeType;
+
+/**
+ * Contract for a Payload interaction.
+ *
+ * @author Rob Winch
+ * @since 5.2
+ */
+public interface PayloadExchange {
+ PayloadExchangeType getType();
+
+ Payload getPayload();
+
+ MimeType getDataMimeType();
+
+ MimeType getMetadataMimeType();
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadExchangeType.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadExchangeType.java
new file mode 100644
index 0000000000..455b0e96ef
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadExchangeType.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor;
+
+/**
+ * The {@link PayloadExchange} type
+ *
+ * @author Rob Winch
+ * @since 5.2
+ */
+public enum PayloadExchangeType {
+ /**
+ * The Setup . Can
+ * be used to determine if a Payload is part of the connection
+ */
+ SETUP(false),
+
+ /**
+ * A Fire and Forget exchange.
+ */
+ FIRE_AND_FORGET(true),
+
+ /**
+ * A Request
+ * Response exchange.
+ */
+ REQUEST_RESPONSE(true),
+
+ /**
+ * A Request Stream
+ * exchange. This is only represents the request portion. The {@link #PAYLOAD} type
+ * represents the data that submitted.
+ */
+ REQUEST_STREAM(true),
+
+ /**
+ * A Request
+ * Channel exchange.
+ */
+ REQUEST_CHANNEL(true),
+
+ /**
+ * A Payload exchange.
+ */
+ PAYLOAD(false),
+
+ /**
+ * A Metadata Push
+ * exchange.
+ */
+ METADATA_PUSH(true);
+
+ private final boolean isRequest;
+
+ PayloadExchangeType(boolean isRequest) {
+ this.isRequest = isRequest;
+ }
+
+ /**
+ * Determines if this exchange is a type of request (i.e. the initial frame).
+ * @return true if it is a request, else false
+ */
+ public boolean isRequest() {
+ return this.isRequest;
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptor.java
new file mode 100644
index 0000000000..8984ef5417
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptor.java
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor;
+
+import reactor.core.publisher.Mono;
+
+/**
+ * Contract for interception-style, chained processing of Payloads that may
+ * be used to implement cross-cutting, application-agnostic requirements such
+ * as security, timeouts, and others.
+ *
+ * @author Rob Winch
+ * @since 5.2
+ */
+public interface PayloadInterceptor {
+ /**
+ * Process the Web request and (optionally) delegate to the next
+ * {@code PayloadInterceptor} through the given {@link PayloadInterceptorChain}.
+ * @param exchange the current payload exchange
+ * @param chain provides a way to delegate to the next interceptor
+ * @return {@code Mono} to indicate when payload processing is complete
+ */
+ Mono intercept(PayloadExchange exchange, PayloadInterceptorChain chain);
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorChain.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorChain.java
new file mode 100644
index 0000000000..97c6ef7487
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorChain.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor;
+
+import reactor.core.publisher.Mono;
+
+/**
+ * Contract to allow a {@link PayloadInterceptor} to delegate to the next in the chain.
+ * *
+ * @author Rob Winch
+ * @since 5.2
+ */
+public interface PayloadInterceptorChain {
+ /**
+ * Process the payload exchange.
+ * @param exchange the current server exchange
+ * @return {@code Mono} to indicate when request processing is complete
+ */
+ Mono next(PayloadExchange exchange);
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorRSocket.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorRSocket.java
new file mode 100644
index 0000000000..8a32f0faef
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorRSocket.java
@@ -0,0 +1,140 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor;
+
+import io.rsocket.Payload;
+import io.rsocket.RSocket;
+import io.rsocket.ResponderRSocket;
+import io.rsocket.util.RSocketProxy;
+import org.reactivestreams.Publisher;
+import org.springframework.util.MimeType;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.util.context.Context;
+
+import java.util.List;
+
+/**
+ * Combines the {@link PayloadInterceptor} with a {@link ResponderRSocket}
+ * @author Rob Winch
+ * @since 5.2
+ */
+class PayloadInterceptorRSocket extends RSocketProxy implements ResponderRSocket {
+ private final List interceptors;
+
+ private final MimeType metadataMimeType;
+
+ private final MimeType dataMimeType;
+
+ private final Context context;
+
+ PayloadInterceptorRSocket(RSocket delegate,
+ List interceptors, MimeType metadataMimeType,
+ MimeType dataMimeType) {
+ this(delegate, interceptors, metadataMimeType, dataMimeType, Context.empty());
+ }
+
+ PayloadInterceptorRSocket(RSocket delegate,
+ List interceptors, MimeType metadataMimeType,
+ MimeType dataMimeType, Context context) {
+ super(delegate);
+ this.metadataMimeType = metadataMimeType;
+ this.dataMimeType = dataMimeType;
+ if (delegate == null) {
+ throw new IllegalArgumentException("delegate cannot be null");
+ }
+ if (interceptors == null) {
+ throw new IllegalArgumentException("interceptors cannot be null");
+ }
+ if (interceptors.isEmpty()) {
+ throw new IllegalArgumentException("interceptors cannot be empty");
+ }
+ this.interceptors = interceptors;
+ this.context = context;
+ }
+
+ @Override
+ public Mono fireAndForget(Payload payload) {
+ return intercept(PayloadExchangeType.FIRE_AND_FORGET, payload)
+ .flatMap(context ->
+ this.source.fireAndForget(payload)
+ .subscriberContext(context)
+ );
+ }
+
+ @Override
+ public Mono requestResponse(Payload payload) {
+ return intercept(PayloadExchangeType.REQUEST_RESPONSE, payload)
+ .flatMap(context ->
+ this.source.requestResponse(payload)
+ .subscriberContext(context)
+ );
+ }
+
+ @Override
+ public Flux requestStream(Payload payload) {
+ return intercept(PayloadExchangeType.REQUEST_STREAM, payload)
+ .flatMapMany(context ->
+ this.source.requestStream(payload)
+ .subscriberContext(context)
+ );
+ }
+
+ @Override
+ public Flux requestChannel(Publisher payloads) {
+ return Flux.from(payloads)
+ .switchOnFirst((signal, innerFlux) -> {
+ Payload firstPayload = signal.get();
+ return intercept(PayloadExchangeType.REQUEST_CHANNEL, firstPayload)
+ .flatMapMany(context ->
+ innerFlux
+ .skip(1)
+ .flatMap(p -> intercept(PayloadExchangeType.PAYLOAD, p).thenReturn(p))
+ .transform(securedPayloads -> Flux.concat(Flux.just(firstPayload), securedPayloads))
+ .transform(securedPayloads -> this.source.requestChannel(securedPayloads))
+ .subscriberContext(context)
+ );
+ });
+ }
+
+ @Override
+ public Mono metadataPush(Payload payload) {
+ return intercept(PayloadExchangeType.METADATA_PUSH, payload)
+ .flatMap(c -> this.source
+ .metadataPush(payload)
+ .subscriberContext(c)
+ );
+ }
+
+ private Mono intercept(PayloadExchangeType type, Payload payload) {
+ return Mono.defer(() -> {
+ ContextPayloadInterceptorChain chain = new ContextPayloadInterceptorChain(this.interceptors);
+ DefaultPayloadExchange exchange = new DefaultPayloadExchange(type, payload,
+ this.metadataMimeType, this.dataMimeType);
+ return chain.next(exchange)
+ .then(Mono.fromCallable(() -> chain.getContext()))
+ .defaultIfEmpty(Context.empty())
+ .subscriberContext(this.context);
+ });
+ }
+
+ @Override
+ public String toString() {
+ return getClass().getSimpleName() + "[source=" + this.source + ",interceptors="
+ + this.interceptors + "]";
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptor.java
new file mode 100644
index 0000000000..333a268f98
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptor.java
@@ -0,0 +1,99 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor;
+
+import io.rsocket.ConnectionSetupPayload;
+import io.rsocket.Payload;
+import io.rsocket.RSocket;
+import io.rsocket.SocketAcceptor;
+import io.rsocket.metadata.WellKnownMimeType;
+import org.springframework.lang.Nullable;
+import org.springframework.util.Assert;
+import org.springframework.util.MimeType;
+import org.springframework.util.MimeTypeUtils;
+import org.springframework.util.StringUtils;
+import reactor.core.publisher.Mono;
+import reactor.util.context.Context;
+
+import java.util.List;
+
+/**
+ * @author Rob Winch
+ * @since 5.2
+ */
+class PayloadSocketAcceptor implements SocketAcceptor {
+ private final SocketAcceptor delegate;
+
+ private final List interceptors;
+
+ @Nullable
+ private MimeType defaultDataMimeType;
+
+ private MimeType defaultMetadataMimeType =
+ MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString());
+
+ PayloadSocketAcceptor(SocketAcceptor delegate, List interceptors) {
+ Assert.notNull(delegate, "delegate cannot be null");
+ if (interceptors == null) {
+ throw new IllegalArgumentException("interceptors cannot be null");
+ }
+ if (interceptors.isEmpty()) {
+ throw new IllegalArgumentException("interceptors cannot be empty");
+ }
+ this.delegate = delegate;
+ this.interceptors = interceptors;
+ }
+
+ @Override
+ public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) {
+ MimeType dataMimeType = parseMimeType(setup.dataMimeType(), this.defaultDataMimeType);
+ Assert.notNull(dataMimeType, "No `dataMimeType` in ConnectionSetupPayload and no default value");
+
+ MimeType metadataMimeType = parseMimeType(setup.metadataMimeType(), this.defaultMetadataMimeType);
+ Assert.notNull(metadataMimeType, "No `metadataMimeType` in ConnectionSetupPayload and no default value");
+
+ // FIXME do we want to make the sendingSocket available in the PayloadExchange
+ return intercept(setup, dataMimeType, metadataMimeType)
+ .flatMap(ctx -> this.delegate.accept(setup, sendingSocket)
+ .map(acceptingSocket -> new PayloadInterceptorRSocket(acceptingSocket, this.interceptors, metadataMimeType, dataMimeType, ctx))
+ );
+ }
+
+ private Mono intercept(Payload payload, MimeType dataMimeType, MimeType metadataMimeType) {
+ return Mono.defer(() -> {
+ ContextPayloadInterceptorChain chain = new ContextPayloadInterceptorChain(this.interceptors);
+ DefaultPayloadExchange exchange = new DefaultPayloadExchange(PayloadExchangeType.SETUP, payload,
+ metadataMimeType, dataMimeType);
+ return chain.next(exchange)
+ .then(Mono.fromCallable(() -> chain.getContext()))
+ .defaultIfEmpty(Context.empty());
+ });
+ }
+
+ private MimeType parseMimeType(String str, MimeType defaultMimeType) {
+ return StringUtils.hasText(str) ? MimeTypeUtils.parseMimeType(str) : defaultMimeType;
+ }
+
+ public void setDefaultDataMimeType(@Nullable MimeType defaultDataMimeType) {
+ this.defaultDataMimeType = defaultDataMimeType;
+ }
+
+ public void setDefaultMetadataMimeType(MimeType defaultMetadataMimeType) {
+ Assert.notNull(defaultMetadataMimeType, "defaultMetadataMimeType cannot be null");
+ this.defaultMetadataMimeType = defaultMetadataMimeType;
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorInterceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorInterceptor.java
new file mode 100644
index 0000000000..35fdb36e14
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorInterceptor.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor;
+
+import io.rsocket.SocketAcceptor;
+import io.rsocket.metadata.WellKnownMimeType;
+import io.rsocket.plugins.SocketAcceptorInterceptor;
+import org.springframework.lang.Nullable;
+import org.springframework.util.Assert;
+import org.springframework.util.MimeType;
+import org.springframework.util.MimeTypeUtils;
+
+import java.util.List;
+
+/**
+ * A {@link SocketAcceptorInterceptor} that applies the {@link PayloadInterceptor}s
+ *
+ * @author Rob Winch
+ * @since 5.2
+ */
+public class PayloadSocketAcceptorInterceptor implements SocketAcceptorInterceptor {
+
+ private final List interceptors;
+
+ @Nullable
+ private MimeType defaultDataMimeType;
+
+ private MimeType defaultMetadataMimeType =
+ MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString());
+
+ public PayloadSocketAcceptorInterceptor(List interceptors) {
+ this.interceptors = interceptors;
+ }
+
+ @Override
+ public SocketAcceptor apply(SocketAcceptor socketAcceptor) {
+ PayloadSocketAcceptor acceptor = new PayloadSocketAcceptor(
+ socketAcceptor, this.interceptors);
+ acceptor.setDefaultDataMimeType(this.defaultDataMimeType);
+ acceptor.setDefaultMetadataMimeType(this.defaultMetadataMimeType);
+ return acceptor;
+ }
+
+ public void setDefaultDataMimeType(@Nullable MimeType defaultDataMimeType) {
+ this.defaultDataMimeType = defaultDataMimeType;
+ }
+
+ public void setDefaultMetadataMimeType(MimeType defaultMetadataMimeType) {
+ Assert.notNull(defaultMetadataMimeType, "defaultMetadataMimeType cannot be null");
+ this.defaultMetadataMimeType = defaultMetadataMimeType;
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/AnonymousPayloadInterceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/AnonymousPayloadInterceptor.java
new file mode 100644
index 0000000000..97f2866d13
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/AnonymousPayloadInterceptor.java
@@ -0,0 +1,83 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor.authentication;
+
+import org.springframework.security.authentication.AnonymousAuthenticationToken;
+import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.util.Assert;
+import reactor.core.publisher.Mono;
+import org.springframework.security.rsocket.interceptor.PayloadInterceptorChain;
+import org.springframework.security.rsocket.interceptor.PayloadExchange;
+import org.springframework.security.rsocket.interceptor.PayloadInterceptor;
+
+import java.util.List;
+
+/**
+ * If {@link ReactiveSecurityContextHolder} is empty populates an
+ * {@code AnonymousAuthenticationToken}
+ *
+ * @author Rob Winch
+ * @since 5.2
+ */
+public class AnonymousPayloadInterceptor implements PayloadInterceptor {
+
+ private String key;
+ private Object principal;
+ private List authorities;
+
+
+ /**
+ * Creates a filter with a principal named "anonymousUser" and the single authority
+ * "ROLE_ANONYMOUS".
+ *
+ * @param key the key to identify tokens created by this filter
+ */
+ public AnonymousPayloadInterceptor(String key) {
+ this(key, "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
+ }
+
+ /**
+ * @param key key the key to identify tokens created by this filter
+ * @param principal the principal which will be used to represent anonymous users
+ * @param authorities the authority list for anonymous users
+ */
+ public AnonymousPayloadInterceptor(String key, Object principal,
+ List authorities) {
+ Assert.hasLength(key, "key cannot be null or empty");
+ Assert.notNull(principal, "Anonymous authentication principal must be set");
+ Assert.notNull(authorities, "Anonymous authorities must be set");
+ this.key = key;
+ this.principal = principal;
+ this.authorities = authorities;
+ }
+
+
+ @Override
+ public Mono intercept(PayloadExchange exchange, PayloadInterceptorChain chain) {
+ return ReactiveSecurityContextHolder.getContext()
+ .switchIfEmpty(Mono.defer(() -> {
+ AnonymousAuthenticationToken authentication = new AnonymousAuthenticationToken(
+ this.key, this.principal, this.authorities);
+ return chain.next(exchange)
+ .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
+ .then(Mono.empty());
+ }))
+ .flatMap(securityContext -> chain.next(exchange));
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/AuthenticationPayloadInterceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/AuthenticationPayloadInterceptor.java
new file mode 100644
index 0000000000..1a988c7e9c
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/AuthenticationPayloadInterceptor.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor.authentication;
+
+import org.springframework.security.authentication.ReactiveAuthenticationManager;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.rsocket.interceptor.PayloadExchange;
+import org.springframework.security.rsocket.interceptor.PayloadInterceptor;
+import org.springframework.security.rsocket.interceptor.PayloadInterceptorChain;
+import org.springframework.util.Assert;
+import reactor.core.publisher.Mono;
+
+/**
+ * Uses the provided {@code ReactiveAuthenticationManager} to authenticate a Payload. If
+ * authentication is successful, then the result is added to
+ * {@link ReactiveSecurityContextHolder}.
+ *
+ * @author Rob Winch
+ * @since 5.2
+ */
+public class AuthenticationPayloadInterceptor implements PayloadInterceptor {
+
+ private final ReactiveAuthenticationManager authenticationManager;
+
+ private PayloadExchangeAuthenticationConverter authenticationConverter =
+ new BasicAuthenticationPayloadExchangeConverter();
+
+ /**
+ * Creates a new instance
+ * @param authenticationManager the manager to use. Cannot be null
+ */
+ public AuthenticationPayloadInterceptor(ReactiveAuthenticationManager authenticationManager) {
+ Assert.notNull(authenticationManager, "authenticationManager cannot be null");
+ this.authenticationManager = authenticationManager;
+ }
+
+ /**
+ * Sets the convert to be used
+ * @param authenticationConverter
+ */
+ public void setAuthenticationConverter(
+ PayloadExchangeAuthenticationConverter authenticationConverter) {
+ Assert.notNull(authenticationConverter, "authenticationConverter cannot be null");
+ this.authenticationConverter = authenticationConverter;
+ }
+
+ public Mono intercept(PayloadExchange exchange, PayloadInterceptorChain chain) {
+ return this.authenticationConverter.convert(exchange)
+ .switchIfEmpty(chain.next(exchange).then(Mono.empty()))
+ .flatMap(a -> this.authenticationManager.authenticate(a))
+ .flatMap(a -> onAuthenticationSuccess(chain.next(exchange), a));
+ }
+
+ private Mono onAuthenticationSuccess(Mono payload, Authentication authentication) {
+ return payload
+ .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication));
+ }
+
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/BasicAuthenticationPayloadExchangeConverter.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/BasicAuthenticationPayloadExchangeConverter.java
new file mode 100644
index 0000000000..c4bce298aa
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/BasicAuthenticationPayloadExchangeConverter.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor.authentication;
+
+import io.rsocket.metadata.WellKnownMimeType;
+import org.springframework.messaging.rsocket.DefaultMetadataExtractor;
+import org.springframework.messaging.rsocket.MetadataExtractor;
+import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.rsocket.interceptor.PayloadExchange;
+import org.springframework.security.rsocket.metadata.BasicAuthenticationDecoder;
+import org.springframework.security.rsocket.metadata.UsernamePasswordMetadata;
+import org.springframework.util.MimeType;
+import org.springframework.util.MimeTypeUtils;
+import reactor.core.publisher.Mono;
+
+/**
+ * Converts from the {@link PayloadExchange} to a
+ * {@link UsernamePasswordAuthenticationToken} by extracting
+ * {@link UsernamePasswordMetadata#BASIC_AUTHENTICATION_MIME_TYPE} from the metadata.
+ *
+ * @author Rob Winch
+ * @since 5.2
+ */
+public class BasicAuthenticationPayloadExchangeConverter implements PayloadExchangeAuthenticationConverter {
+
+ private MimeType metadataMimetype = MimeTypeUtils.parseMimeType(
+ WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString());
+
+ private MetadataExtractor metadataExtractor = createDefaultExtractor();
+
+ @Override
+ public Mono convert(PayloadExchange exchange) {
+ return Mono.fromCallable(() -> this.metadataExtractor
+ .extract(exchange.getPayload(), this.metadataMimetype))
+ .flatMap(metadata -> Mono.justOrEmpty(metadata.get(UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE.toString())))
+ .cast(UsernamePasswordMetadata.class)
+ .map(credentials -> new UsernamePasswordAuthenticationToken(credentials.getUsername(), credentials.getPassword()));
+ }
+
+ private static MetadataExtractor createDefaultExtractor() {
+ DefaultMetadataExtractor result = new DefaultMetadataExtractor(new BasicAuthenticationDecoder());
+ result.metadataToExtract(UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE, UsernamePasswordMetadata.class, (String) null);
+ return result;
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/BearerPayloadExchangeConverter.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/BearerPayloadExchangeConverter.java
new file mode 100644
index 0000000000..cc9db71dfc
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/BearerPayloadExchangeConverter.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor.authentication;
+
+import io.netty.buffer.ByteBuf;
+import io.rsocket.metadata.CompositeMetadata;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken;
+import org.springframework.security.rsocket.interceptor.PayloadExchange;
+import org.springframework.security.rsocket.metadata.BearerTokenMetadata;
+import reactor.core.publisher.Mono;
+
+import java.nio.charset.StandardCharsets;
+
+/**
+ * Converts from the {@link PayloadExchange} to a
+ * {@link BearerTokenAuthenticationToken} by extracting
+ * {@link BearerTokenMetadata#BEARER_AUTHENTICATION_MIME_TYPE} from the metadata.
+ * @author Rob Winch
+ * @since 5.2
+ */
+public class BearerPayloadExchangeConverter implements PayloadExchangeAuthenticationConverter {
+
+ private static final String BEARER_MIME_TYPE_VALUE =
+ BearerTokenMetadata.BEARER_AUTHENTICATION_MIME_TYPE.toString();
+
+ @Override
+ public Mono convert(PayloadExchange exchange) {
+ ByteBuf metadata = exchange.getPayload().metadata();
+ CompositeMetadata compositeMetadata = new CompositeMetadata(metadata, false);
+ for (CompositeMetadata.Entry entry : compositeMetadata) {
+ if (BEARER_MIME_TYPE_VALUE.equals(entry.getMimeType())) {
+ ByteBuf content = entry.getContent();
+ String token = content.toString(StandardCharsets.UTF_8);
+ return Mono.just(new BearerTokenAuthenticationToken(token));
+ }
+ }
+ return Mono.empty();
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/PayloadExchangeAuthenticationConverter.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/PayloadExchangeAuthenticationConverter.java
new file mode 100644
index 0000000000..2713aeb06a
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/PayloadExchangeAuthenticationConverter.java
@@ -0,0 +1,30 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor.authentication;
+
+import org.springframework.security.core.Authentication;
+import org.springframework.security.rsocket.interceptor.PayloadExchange;
+import reactor.core.publisher.Mono;
+
+/**
+ * Converts from a {@link PayloadExchange} to an {@link Authentication}
+ * @author Rob Winch
+ * @since 5.2
+ */
+public interface PayloadExchangeAuthenticationConverter {
+ Mono convert(PayloadExchange exchange);
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authorization/AuthorizationPayloadInterceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authorization/AuthorizationPayloadInterceptor.java
new file mode 100644
index 0000000000..1fc8011594
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authorization/AuthorizationPayloadInterceptor.java
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor.authorization;
+
+import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
+import org.springframework.security.authorization.ReactiveAuthorizationManager;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.util.Assert;
+import reactor.core.publisher.Mono;
+import org.springframework.security.rsocket.interceptor.PayloadInterceptorChain;
+import org.springframework.security.rsocket.interceptor.PayloadExchange;
+import org.springframework.security.rsocket.interceptor.PayloadInterceptor;
+
+/**
+ * Provides authorization of the {@link PayloadExchange}.
+ *
+ * @author Rob Winch
+ * @since 5.2
+ */
+public class AuthorizationPayloadInterceptor implements PayloadInterceptor {
+ private final ReactiveAuthorizationManager authorizationManager;
+
+ public AuthorizationPayloadInterceptor(
+ ReactiveAuthorizationManager authorizationManager) {
+ Assert.notNull(authorizationManager, "authorizationManager cannot be null");
+ this.authorizationManager = authorizationManager;
+ }
+
+ @Override
+ public Mono intercept(PayloadExchange exchange, PayloadInterceptorChain chain) {
+ return ReactiveSecurityContextHolder.getContext()
+ .filter(c -> c.getAuthentication() != null)
+ .map(SecurityContext::getAuthentication)
+ .switchIfEmpty(Mono.error(() -> new AuthenticationCredentialsNotFoundException("An Authentication (possibly AnonymousAuthenticationToken) is required.")))
+ .as(authentication -> this.authorizationManager.verify(authentication, exchange))
+ .then(chain.next(exchange));
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authorization/PayloadExchangeMatcherReactiveAuthorizationManager.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authorization/PayloadExchangeMatcherReactiveAuthorizationManager.java
new file mode 100644
index 0000000000..7fb7096850
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authorization/PayloadExchangeMatcherReactiveAuthorizationManager.java
@@ -0,0 +1,82 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor.authorization;
+
+import org.springframework.security.authorization.AuthorizationDecision;
+import org.springframework.security.authorization.ReactiveAuthorizationManager;
+import org.springframework.security.core.Authentication;
+import org.springframework.util.Assert;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import org.springframework.security.rsocket.interceptor.PayloadExchange;
+import org.springframework.security.rsocket.util.PayloadExchangeAuthorizationContext;
+import org.springframework.security.rsocket.util.PayloadExchangeMatcher;
+import org.springframework.security.rsocket.util.PayloadExchangeMatcherEntry;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Maps a @{code List} of {@link PayloadExchangeMatcher} instances to
+ * @{code ReactiveAuthorizationManager} instances.
+ *
+ * @author Rob Winch
+ * @since 5.2
+ */
+public class PayloadExchangeMatcherReactiveAuthorizationManager implements ReactiveAuthorizationManager {
+ private final List>> mappings;
+
+ private PayloadExchangeMatcherReactiveAuthorizationManager(List>> mappings) {
+ Assert.notEmpty(mappings, "mappings cannot be null");
+ this.mappings = mappings;
+ }
+
+ @Override
+ public Mono check(Mono authentication, PayloadExchange exchange) {
+ return Flux.fromIterable(this.mappings)
+ .concatMap(mapping -> mapping.getMatcher().matches(exchange)
+ .filter(PayloadExchangeMatcher.MatchResult::isMatch)
+ .map(r -> r.getVariables())
+ .flatMap(variables -> mapping.getEntry()
+ .check(authentication, new PayloadExchangeAuthorizationContext(exchange, variables))
+ )
+ )
+ .next()
+ .switchIfEmpty(Mono.fromCallable(() -> new AuthorizationDecision(false)));
+ }
+
+ public static PayloadExchangeMatcherReactiveAuthorizationManager.Builder builder() {
+ return new PayloadExchangeMatcherReactiveAuthorizationManager.Builder();
+ }
+
+ public static class Builder {
+ private final List>> mappings = new ArrayList<>();
+
+ private Builder() {
+ }
+
+ public PayloadExchangeMatcherReactiveAuthorizationManager.Builder add(
+ PayloadExchangeMatcherEntry> entry) {
+ this.mappings.add(entry);
+ return this;
+ }
+
+ public PayloadExchangeMatcherReactiveAuthorizationManager build() {
+ return new PayloadExchangeMatcherReactiveAuthorizationManager(this.mappings);
+ }
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoder.java b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoder.java
new file mode 100644
index 0000000000..5085e5a833
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoder.java
@@ -0,0 +1,76 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.metadata;
+
+import org.reactivestreams.Publisher;
+import org.springframework.core.ResolvableType;
+import org.springframework.core.codec.AbstractDecoder;
+import org.springframework.core.io.buffer.DataBuffer;
+import org.springframework.util.MimeType;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+
+import java.util.Map;
+
+/**
+ * Decodes {@link UsernamePasswordMetadata#BASIC_AUTHENTICATION_MIME_TYPE}
+ *
+ * @author Rob Winch
+ * @since 5.2
+ */
+public class BasicAuthenticationDecoder extends AbstractDecoder {
+ public BasicAuthenticationDecoder() {
+ super(UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE);
+ }
+
+ @Override
+ public Flux decode(Publisher input,
+ ResolvableType elementType, MimeType mimeType, Map hints) {
+ return Flux.from(input)
+ .map(DataBuffer::asByteBuffer)
+ .map(byteBuffer -> {
+ byte[] sizeBytes = new byte[4];
+ byteBuffer.get(sizeBytes);
+
+ int usernameSize = 4;
+ byte[] usernameBytes = new byte[usernameSize];
+ byteBuffer.get(usernameBytes);
+ byte[] passwordBytes = new byte[byteBuffer.remaining()];
+ byteBuffer.get(passwordBytes);
+ String username = new String(usernameBytes);
+ String password = new String(passwordBytes);
+ return new UsernamePasswordMetadata(username, password);
+ });
+ }
+
+ @Override
+ public Mono decodeToMono(Publisher input,
+ ResolvableType elementType, MimeType mimeType, Map hints) {
+ return Mono.from(input)
+ .map(DataBuffer::asByteBuffer)
+ .map(byteBuffer -> {
+ int usernameSize = byteBuffer.getInt();
+ byte[] usernameBytes = new byte[usernameSize];
+ byteBuffer.get(usernameBytes);
+ byte[] passwordBytes = new byte[byteBuffer.remaining()];
+ byteBuffer.get(passwordBytes);
+ String username = new String(usernameBytes);
+ String password = new String(passwordBytes);
+ return new UsernamePasswordMetadata(username, password);
+ });
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationEncoder.java b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationEncoder.java
new file mode 100644
index 0000000000..9d088f5a2a
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationEncoder.java
@@ -0,0 +1,76 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.metadata;
+
+import org.reactivestreams.Publisher;
+import org.springframework.core.ResolvableType;
+import org.springframework.core.codec.AbstractEncoder;
+import org.springframework.core.io.buffer.DataBuffer;
+import org.springframework.core.io.buffer.DataBufferFactory;
+import org.springframework.core.io.buffer.DataBufferUtils;
+import org.springframework.util.MimeType;
+import reactor.core.publisher.Flux;
+
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.Map;
+
+/**
+ * Encodes {@link UsernamePasswordMetadata#BASIC_AUTHENTICATION_MIME_TYPE}
+ *
+ * @author Rob Winch
+ * @since 5.2
+ */
+public class BasicAuthenticationEncoder extends
+ AbstractEncoder {
+
+ public BasicAuthenticationEncoder() {
+ super(UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE);
+ }
+
+ @Override
+ public Flux encode(
+ Publisher extends UsernamePasswordMetadata> inputStream,
+ DataBufferFactory bufferFactory, ResolvableType elementType,
+ MimeType mimeType, Map hints) {
+ return Flux.from(inputStream).map(credentials ->
+ encodeValue(credentials, bufferFactory, elementType, mimeType, hints));
+ }
+
+ @Override
+ public DataBuffer encodeValue(UsernamePasswordMetadata credentials,
+ DataBufferFactory bufferFactory, ResolvableType valueType, MimeType mimeType,
+ Map hints) {
+ String username = credentials.getUsername();
+ String password = credentials.getPassword();
+ byte[] usernameBytes = username.getBytes(StandardCharsets.UTF_8);
+ byte[] usernameBytesLengthBytes = ByteBuffer.allocate(4).putInt(usernameBytes.length).array();
+ DataBuffer metadata = bufferFactory.allocateBuffer();
+ boolean release = true;
+ try {
+ metadata.write(usernameBytesLengthBytes);
+ metadata.write(usernameBytes);
+ metadata.write(password.getBytes(StandardCharsets.UTF_8));
+ release = false;
+ return metadata;
+ } finally {
+ if (release) {
+ DataBufferUtils.release(metadata);
+ }
+ }
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BearerTokenMetadata.java b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BearerTokenMetadata.java
new file mode 100644
index 0000000000..e252fa21f3
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BearerTokenMetadata.java
@@ -0,0 +1,47 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.springframework.security.rsocket.metadata;
+
+import org.springframework.http.MediaType;
+import org.springframework.util.MimeType;
+
+/**
+ * Represents a bearer token that has been encoded into a
+ * {@link Payload#metadata()}.
+ *
+ * @author Rob Winch
+ * @since 5.2
+ */
+public class BearerTokenMetadata {
+ /**
+ * Represents a bearer token which is encoded as a String.
+ *
+ * See rsocket/rsocket#272
+ */
+ public static final MimeType BEARER_AUTHENTICATION_MIME_TYPE = new MediaType("message", "x.rsocket.authentication.bearer.v0");
+
+ private final String token;
+
+ public BearerTokenMetadata(String token) {
+ this.token = token;
+ }
+
+ public String getToken() {
+ return this.token;
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/UsernamePasswordMetadata.java b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/UsernamePasswordMetadata.java
new file mode 100644
index 0000000000..e99e23aa40
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/UsernamePasswordMetadata.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.metadata;
+
+import io.rsocket.Payload;
+import org.springframework.http.MediaType;
+import org.springframework.util.MimeType;
+
+/**
+ * Represents a username and password that have been encoded into a
+ * {@link Payload#metadata()}.
+ *
+ * @author Rob Winch
+ * @since 5.2
+ */
+public final class UsernamePasswordMetadata {
+ /**
+ * Represents a username password which is encoded as
+ * {@code ${username-bytes-length}${username-bytes}${password-bytes}}.
+ *
+ * See rsocket/rsocket#272
+ */
+ public static final MimeType BASIC_AUTHENTICATION_MIME_TYPE = new MediaType("message", "x.rsocket.authentication.basic.v0");
+
+ private final String username;
+
+ private final String password;
+
+ public UsernamePasswordMetadata(String username, String password) {
+ this.username = username;
+ this.password = password;
+ }
+
+ public String getUsername() {
+ return this.username;
+ }
+
+ public String getPassword() {
+ return this.password;
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeAuthorizationContext.java b/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeAuthorizationContext.java
new file mode 100644
index 0000000000..ac01e07f73
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeAuthorizationContext.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.util;
+
+import org.springframework.security.rsocket.interceptor.PayloadExchange;
+
+import java.util.Collections;
+import java.util.Map;
+
+/**
+ * @author Rob Winch
+ * @since 5.2
+ */
+public class PayloadExchangeAuthorizationContext {
+ private final PayloadExchange exchange;
+ private final Map variables;
+
+ public PayloadExchangeAuthorizationContext(PayloadExchange exchange) {
+ this(exchange, Collections.emptyMap());
+ }
+
+ public PayloadExchangeAuthorizationContext(PayloadExchange exchange, Map variables) {
+ this.exchange = exchange;
+ this.variables = variables;
+ }
+
+ public PayloadExchange getExchange() {
+ return this.exchange;
+ }
+
+ public Map getVariables() {
+ return Collections.unmodifiableMap(this.variables);
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatcher.java b/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatcher.java
new file mode 100644
index 0000000000..d5a368a1f8
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatcher.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.rsocket.util;
+
+import org.springframework.security.rsocket.interceptor.PayloadExchange;
+import reactor.core.publisher.Mono;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An interface for determining if a {@link PayloadExchangeMatcher} matches.
+ * @author Rob Winch
+ * @since 5.2
+ */
+public interface PayloadExchangeMatcher {
+
+ /**
+ * Determines if a request matches or not
+ * @param exchange
+ * @return
+ */
+ Mono matches(PayloadExchange exchange);
+
+ /**
+ * The result of matching
+ */
+ class MatchResult {
+ private final boolean match;
+ private final Map variables;
+
+ private MatchResult(boolean match, Map variables) {
+ this.match = match;
+ this.variables = variables;
+ }
+
+ public boolean isMatch() {
+ return match;
+ }
+
+ /**
+ * Gets potential variables and their values
+ * @return
+ */
+ public Map getVariables() {
+ return variables;
+ }
+
+ /**
+ * Creates an instance of {@link MatchResult} that is a match with no variables
+ * @return
+ */
+ public static Mono match() {
+ return match(Collections.emptyMap());
+ }
+
+ /**
+ *
+ * Creates an instance of {@link MatchResult} that is a match with the specified variables
+ * @param variables
+ * @return
+ */
+ public static Mono match(Map variables) {
+ return Mono.just(new MatchResult(true, variables == null ? null : new HashMap(variables)));
+ }
+
+ /**
+ * Creates an instance of {@link MatchResult} that is not a match.
+ * @return
+ */
+ public static Mono notMatch() {
+ return Mono.just(new MatchResult(false, Collections.emptyMap()));
+ }
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatcherEntry.java b/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatcherEntry.java
new file mode 100644
index 0000000000..691033c417
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatcherEntry.java
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.util;
+
+/**
+ * @author Rob Winch
+ */
+public class PayloadExchangeMatcherEntry {
+ private final PayloadExchangeMatcher matcher;
+ private final T entry;
+
+ public PayloadExchangeMatcherEntry(PayloadExchangeMatcher matcher, T entry) {
+ this.matcher = matcher;
+ this.entry = entry;
+ }
+
+ public PayloadExchangeMatcher getMatcher() {
+ return this.matcher;
+ }
+
+ public T getEntry() {
+ return this.entry;
+ }
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatchers.java b/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatchers.java
new file mode 100644
index 0000000000..9202949ac3
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatchers.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.util;
+
+import org.springframework.security.rsocket.interceptor.PayloadExchange;
+import org.springframework.security.rsocket.interceptor.PayloadExchangeType;
+import reactor.core.publisher.Mono;
+
+/**
+ * @author Rob Winch
+ */
+public abstract class PayloadExchangeMatchers {
+
+ public static PayloadExchangeMatcher setup() {
+ return new PayloadExchangeMatcher() {
+ public Mono matches(PayloadExchange exchange) {
+ return PayloadExchangeType.SETUP.equals(exchange.getType()) ?
+ MatchResult.match() :
+ MatchResult.notMatch();
+ }
+ };
+ }
+
+ public static PayloadExchangeMatcher anyRequest() {
+ return new PayloadExchangeMatcher() {
+ public Mono matches(PayloadExchange exchange) {
+ return exchange.getType().isRequest() ?
+ MatchResult.match() :
+ MatchResult.notMatch();
+ }
+ };
+ }
+
+ public static PayloadExchangeMatcher anyExchange() {
+ return new PayloadExchangeMatcher() {
+ public Mono matches(PayloadExchange exchange) {
+ return MatchResult.match();
+ }
+ };
+ }
+
+ private PayloadExchangeMatchers() {}
+}
diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/util/RoutePayloadExchangeMatcher.java b/rsocket/src/main/java/org/springframework/security/rsocket/util/RoutePayloadExchangeMatcher.java
new file mode 100644
index 0000000000..0b711d212b
--- /dev/null
+++ b/rsocket/src/main/java/org/springframework/security/rsocket/util/RoutePayloadExchangeMatcher.java
@@ -0,0 +1,61 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.util;
+
+import org.springframework.messaging.rsocket.MetadataExtractor;
+import org.springframework.security.rsocket.interceptor.PayloadExchange;
+import org.springframework.util.Assert;
+import org.springframework.util.RouteMatcher;
+import reactor.core.publisher.Mono;
+
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * FIXME: Pay attention to the package this goes into. It requires spring-messaging for
+ * the MetadataExtractor.
+ *
+ * @author Rob Winch
+ * @since 5.2
+ */
+public class RoutePayloadExchangeMatcher implements PayloadExchangeMatcher {
+
+ private final String pattern;
+
+ private final MetadataExtractor metadataExtractor;
+
+ private final RouteMatcher routeMatcher;
+
+ public RoutePayloadExchangeMatcher(MetadataExtractor metadataExtractor,
+ RouteMatcher routeMatcher, String pattern) {
+ Assert.notNull(pattern, "pattern cannot be null");
+ this.metadataExtractor = metadataExtractor;
+ this.routeMatcher = routeMatcher;
+ this.pattern = pattern;
+ }
+
+ @Override
+ public Mono matches(PayloadExchange exchange) {
+ Map metadata = this.metadataExtractor
+ .extract(exchange.getPayload(), exchange.getMetadataMimeType());
+ return Optional.ofNullable((String) metadata.get(MetadataExtractor.ROUTE_KEY))
+ .map(routeValue -> this.routeMatcher.parseRoute(routeValue))
+ .map(route -> this.routeMatcher.matchAndExtract(this.pattern, route))
+ .map(v -> MatchResult.match(v))
+ .orElse(MatchResult.notMatch());
+ }
+}
diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AnonymousPayloadInterceptorTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AnonymousPayloadInterceptorTests.java
new file mode 100644
index 0000000000..86a00c93d0
--- /dev/null
+++ b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AnonymousPayloadInterceptorTests.java
@@ -0,0 +1,108 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.authentication;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.runners.MockitoJUnitRunner;
+import org.springframework.security.authentication.AnonymousAuthenticationToken;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.rsocket.interceptor.PayloadExchange;
+import org.springframework.security.rsocket.interceptor.authentication.AnonymousPayloadInterceptor;
+
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.*;
+
+/**
+ * @author Rob Winch
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class AnonymousPayloadInterceptorTests {
+ @Mock
+ private PayloadExchange exchange;
+
+ private AnonymousPayloadInterceptor interceptor;
+
+ @Before
+ public void setup() {
+ this.interceptor = new AnonymousPayloadInterceptor("key");
+ }
+
+ @Test
+ public void constructorKeyWhenKeyNullThenException() {
+ String key = null;
+ assertThatCode(() -> new AnonymousPayloadInterceptor(key))
+ .isInstanceOf(IllegalArgumentException.class);
+ }
+
+ @Test
+ public void constructorKeyPrincipalAuthoritiesWhenKeyNullThenException() {
+ String key = null;
+ assertThatCode(() -> new AnonymousPayloadInterceptor(key, "principal",
+ AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")))
+ .isInstanceOf(IllegalArgumentException.class);
+ }
+
+ @Test
+ public void constructorKeyPrincipalAuthoritiesWhenPrincipalNullThenException() {
+ Object principal = null;
+ assertThatCode(() -> new AnonymousPayloadInterceptor("key", principal,
+ AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")))
+ .isInstanceOf(IllegalArgumentException.class);
+ }
+
+ @Test
+ public void constructorKeyPrincipalAuthoritiesWhenAuthoritiesNullThenException() {
+ List authorities = null;
+ assertThatCode(() -> new AnonymousPayloadInterceptor("key", "principal",
+ authorities))
+ .isInstanceOf(IllegalArgumentException.class);
+ }
+
+ @Test
+ public void interceptWhenNoAuthenticationThenAnonymousAuthentication() {
+ AuthenticationPayloadInterceptorChain chain = new AuthenticationPayloadInterceptorChain();
+
+ this.interceptor.intercept(this.exchange, chain).block();
+
+ Authentication authentication = chain.getAuthentication();
+
+ assertThat(authentication).isInstanceOf(AnonymousAuthenticationToken.class);
+ }
+
+ @Test
+ public void interceptWhenAuthenticationThenOriginalAuthentication() {
+ AuthenticationPayloadInterceptorChain chain = new AuthenticationPayloadInterceptorChain();
+ TestingAuthenticationToken expected =
+ new TestingAuthenticationToken("test", "password");
+
+ this.interceptor.intercept(this.exchange, chain)
+ .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(expected))
+ .block();
+
+ Authentication authentication = chain.getAuthentication();
+
+ assertThat(authentication).isEqualTo(expected);
+ }
+}
\ No newline at end of file
diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorChain.java b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorChain.java
new file mode 100644
index 0000000000..2f5480ab66
--- /dev/null
+++ b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorChain.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.authentication;
+
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.core.context.SecurityContext;
+import reactor.core.publisher.Mono;
+import org.springframework.security.rsocket.interceptor.PayloadInterceptorChain;
+import org.springframework.security.rsocket.interceptor.PayloadExchange;
+
+/**
+ * @author Rob Winch
+ */
+class AuthenticationPayloadInterceptorChain implements PayloadInterceptorChain {
+ private Authentication authentication;
+
+ @Override
+ public Mono next(PayloadExchange exchange) {
+ return ReactiveSecurityContextHolder.getContext().map(SecurityContext::getAuthentication)
+ .doOnNext(a -> this.setAuthentication(a)).then();
+ }
+
+ public Authentication getAuthentication() {
+ return this.authentication;
+ }
+
+ public void setAuthentication(Authentication authentication) {
+ this.authentication = authentication;
+ }
+}
diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorTests.java
new file mode 100644
index 0000000000..e1ed44301c
--- /dev/null
+++ b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorTests.java
@@ -0,0 +1,148 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.authentication;
+
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.buffer.CompositeByteBuf;
+import io.rsocket.Payload;
+import io.rsocket.metadata.CompositeMetadataFlyweight;
+import io.rsocket.metadata.WellKnownMimeType;
+import io.rsocket.util.DefaultPayload;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.springframework.core.ResolvableType;
+import org.springframework.core.io.buffer.DataBuffer;
+import org.springframework.core.io.buffer.DefaultDataBufferFactory;
+import org.springframework.core.io.buffer.NettyDataBufferFactory;
+import org.springframework.http.MediaType;
+import org.springframework.security.authentication.ReactiveAuthenticationManager;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.rsocket.interceptor.PayloadExchangeType;
+import org.springframework.security.rsocket.interceptor.authentication.AuthenticationPayloadInterceptor;
+import org.springframework.security.rsocket.metadata.BasicAuthenticationEncoder;
+import org.springframework.security.rsocket.metadata.UsernamePasswordMetadata;
+import org.springframework.util.MimeType;
+import org.springframework.util.MimeTypeUtils;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+import reactor.test.publisher.PublisherProbe;
+import org.springframework.security.rsocket.interceptor.DefaultPayloadExchange;
+import org.springframework.security.rsocket.interceptor.PayloadInterceptorChain;
+import org.springframework.security.rsocket.interceptor.PayloadExchange;
+
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.*;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * @author Rob Winch
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class AuthenticationPayloadInterceptorTests {
+ static final MimeType COMPOSITE_METADATA = MimeTypeUtils.parseMimeType(
+ WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString());
+ @Mock
+ ReactiveAuthenticationManager authenticationManager;
+
+ @Captor
+ ArgumentCaptor authenticationArg;
+
+ @Test
+ public void constructorWhenAuthenticationManagerNullThenException() {
+ assertThatCode(() -> new AuthenticationPayloadInterceptor(null))
+ .isInstanceOf(IllegalArgumentException.class);
+ }
+
+ @Test
+ public void interceptWhenBasicCredentialsThenAuthenticates() {
+ AuthenticationPayloadInterceptor interceptor = new AuthenticationPayloadInterceptor(
+ this.authenticationManager);
+ PayloadExchange exchange = createExchange();
+ TestingAuthenticationToken expectedAuthentication =
+ new TestingAuthenticationToken("user", "password");
+ when(this.authenticationManager.authenticate(any())).thenReturn(Mono.just(
+ expectedAuthentication));
+
+ AuthenticationPayloadInterceptorChain authenticationPayloadChain = new AuthenticationPayloadInterceptorChain();
+ interceptor.intercept(exchange, authenticationPayloadChain)
+ .block();
+
+ Authentication authentication = authenticationPayloadChain.getAuthentication();
+
+ verify(this.authenticationManager).authenticate(this.authenticationArg.capture());
+ assertThat(this.authenticationArg.getValue()).isEqualToComparingFieldByField(new UsernamePasswordAuthenticationToken("user", "password"));
+ assertThat(authentication).isEqualTo(expectedAuthentication);
+ }
+
+ @Test
+ public void interceptWhenAuthenticationSuccessThenChainSubscribedOnce() {
+ AuthenticationPayloadInterceptor interceptor = new AuthenticationPayloadInterceptor(
+ this.authenticationManager);
+
+ PayloadExchange exchange = createExchange();
+ TestingAuthenticationToken expectedAuthentication =
+ new TestingAuthenticationToken("user", "password");
+ when(this.authenticationManager.authenticate(any())).thenReturn(Mono.just(
+ expectedAuthentication));
+
+ PublisherProbe voidResult = PublisherProbe.empty();
+ PayloadInterceptorChain chain = mock(PayloadInterceptorChain.class);
+ when(chain.next(any())).thenReturn(voidResult.mono());
+
+
+ StepVerifier.create(interceptor.intercept(exchange, chain))
+ .then(() -> assertThat(voidResult.subscribeCount()).isEqualTo(1))
+ .verifyComplete();
+ }
+
+ private Payload createRequestPayload() {
+
+ UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("user", "password");
+ BasicAuthenticationEncoder encoder = new BasicAuthenticationEncoder();
+ DefaultDataBufferFactory factory = new DefaultDataBufferFactory();
+ ResolvableType elementType = ResolvableType
+ .forClass(UsernamePasswordMetadata.class);
+ MimeType mimeType = UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE;
+ Map hints = null;
+ DataBuffer dataBuffer = encoder.encodeValue(credentials, factory,
+ elementType, mimeType, hints);
+
+ ByteBufAllocator allocator = ByteBufAllocator.DEFAULT;
+ CompositeByteBuf metadata = allocator.compositeBuffer();
+ CompositeMetadataFlyweight.encodeAndAddMetadata(
+ metadata, allocator, mimeType.toString(), NettyDataBufferFactory.toByteBuf(dataBuffer));
+
+ return DefaultPayload.create(allocator.buffer(),
+ metadata);
+ }
+
+ private PayloadExchange createExchange() {
+ return new DefaultPayloadExchange(PayloadExchangeType.REQUEST_RESPONSE, createRequestPayload(), COMPOSITE_METADATA,
+ MediaType.APPLICATION_JSON);
+ }
+
+}
diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/authorization/AuthorizationPayloadInterceptorTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/authorization/AuthorizationPayloadInterceptorTests.java
new file mode 100644
index 0000000000..c2c5098afa
--- /dev/null
+++ b/rsocket/src/test/java/org/springframework/security/rsocket/authorization/AuthorizationPayloadInterceptorTests.java
@@ -0,0 +1,118 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.authorization;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.runners.MockitoJUnitRunner;
+import org.springframework.security.access.AccessDeniedException;
+import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.authorization.ReactiveAuthorizationManager;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.rsocket.interceptor.authorization.AuthorizationPayloadInterceptor;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+import reactor.test.publisher.PublisherProbe;
+import reactor.util.context.Context;
+import org.springframework.security.rsocket.interceptor.PayloadInterceptorChain;
+import org.springframework.security.rsocket.interceptor.PayloadExchange;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.when;
+import static org.springframework.security.authorization.AuthenticatedReactiveAuthorizationManager.authenticated;
+import static org.springframework.security.authorization.AuthorityReactiveAuthorizationManager.hasRole;
+
+/**
+ * @author Rob Winch
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class AuthorizationPayloadInterceptorTests {
+ @Mock
+ private ReactiveAuthorizationManager authorizationManager;
+
+ @Mock
+ private PayloadExchange exchange;
+
+ @Mock
+ private PayloadInterceptorChain chain;
+
+ private PublisherProbe managerResult = PublisherProbe.empty();
+
+ private PublisherProbe chainResult = PublisherProbe.empty();
+
+ @Test
+ public void interceptWhenAuthenticationEmptyAndSubscribedThenException() {
+ when(this.chain.next(any())).thenReturn(this.chainResult.mono());
+
+ AuthorizationPayloadInterceptor interceptor =
+ new AuthorizationPayloadInterceptor(authenticated());
+
+ StepVerifier.create(interceptor.intercept(this.exchange, this.chain))
+ .then(() -> this.chainResult.assertWasNotSubscribed())
+ .verifyError(AuthenticationCredentialsNotFoundException.class);
+ }
+
+ @Test
+ public void interceptWhenAuthenticationNotSubscribedAndEmptyThenCompletes() {
+ when(this.chain.next(any())).thenReturn(this.chainResult.mono());
+ when(this.authorizationManager.verify(any(), any()))
+ .thenReturn(this.managerResult.mono());
+
+ AuthorizationPayloadInterceptor interceptor =
+ new AuthorizationPayloadInterceptor(this.authorizationManager);
+
+ StepVerifier.create(interceptor.intercept(this.exchange, this.chain))
+ .then(() -> this.chainResult.assertWasSubscribed())
+ .verifyComplete();
+ }
+
+ @Test
+ public void interceptWhenNotAuthorizedThenException() {
+ when(this.chain.next(any())).thenReturn(this.chainResult.mono());
+
+ AuthorizationPayloadInterceptor interceptor =
+ new AuthorizationPayloadInterceptor(hasRole("USER"));
+ Context userContext = ReactiveSecurityContextHolder
+ .withAuthentication(new TestingAuthenticationToken("user", "password"));
+
+ Mono intercept = interceptor.intercept(this.exchange, this.chain)
+ .subscriberContext(userContext);
+
+ StepVerifier.create(intercept)
+ .then(() -> this.chainResult.assertWasNotSubscribed())
+ .verifyError(AccessDeniedException.class);
+ }
+
+ @Test
+ public void interceptWhenAuthorizedThenContinues() {
+ when(this.chain.next(any())).thenReturn(this.chainResult.mono());
+
+ AuthorizationPayloadInterceptor interceptor =
+ new AuthorizationPayloadInterceptor(authenticated());
+ Context userContext = ReactiveSecurityContextHolder
+ .withAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"));
+
+ Mono intercept = interceptor.intercept(this.exchange, this.chain)
+ .subscriberContext(userContext);
+
+ StepVerifier.create(intercept)
+ .then(() -> this.chainResult.assertWasSubscribed())
+ .verifyComplete();
+ }
+}
\ No newline at end of file
diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorRSocketTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorRSocketTests.java
new file mode 100644
index 0000000000..6dc06fc168
--- /dev/null
+++ b/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorRSocketTests.java
@@ -0,0 +1,509 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor;
+
+import io.rsocket.Payload;
+import io.rsocket.RSocket;
+import io.rsocket.metadata.WellKnownMimeType;
+import io.rsocket.util.RSocketProxy;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.runners.MockitoJUnitRunner;
+import org.mockito.stubbing.Answer;
+import org.reactivestreams.Publisher;
+import org.springframework.http.MediaType;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.util.MimeType;
+import org.springframework.util.MimeTypeUtils;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+import reactor.test.publisher.PublisherProbe;
+import reactor.test.publisher.TestPublisher;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.*;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyZeroInteractions;
+import static org.mockito.Mockito.when;
+
+/**
+ * @author Rob Winch
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class PayloadInterceptorRSocketTests {
+
+ static final MimeType COMPOSITE_METADATA = MimeTypeUtils.parseMimeType(
+ WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString());
+
+ @Mock
+ RSocket delegate;
+
+ @Mock
+ PayloadInterceptor interceptor;
+
+ @Mock
+ PayloadInterceptor interceptor2;
+
+ @Mock
+ Payload payload;
+
+ @Captor
+ private ArgumentCaptor exchange;
+
+ PublisherProbe voidResult = PublisherProbe.empty();
+
+ TestPublisher payloadResult = TestPublisher.createCold();
+
+ private MimeType metadataMimeType = COMPOSITE_METADATA;
+
+ private MimeType dataMimeType = MediaType.APPLICATION_JSON;
+
+ @Test
+ public void constructorWhenNullDelegateThenException() {
+ this.delegate = null;
+ List interceptors = Arrays.asList(this.interceptor);
+ assertThatCode(() -> {
+ new PayloadInterceptorRSocket(this.delegate, interceptors,
+ metadataMimeType, dataMimeType);
+ })
+ .isInstanceOf(IllegalArgumentException.class);
+ }
+
+ @Test
+ public void constructorWhenNullInterceptorsThenException() {
+ List interceptors = null;
+ assertThatCode(() -> new PayloadInterceptorRSocket(this.delegate, interceptors,
+ metadataMimeType, dataMimeType))
+ .isInstanceOf(IllegalArgumentException.class);
+ }
+
+ @Test
+ public void constructorWhenEmptyInterceptorsThenException() {
+ List interceptors = Collections.emptyList();
+ assertThatCode(() -> new PayloadInterceptorRSocket(this.delegate, interceptors,
+ metadataMimeType, dataMimeType))
+ .isInstanceOf(IllegalArgumentException.class);
+ }
+
+ // single interceptor
+
+ @Test
+ public void fireAndForgetWhenInterceptorCompletesThenDelegateSubscribed() {
+ when(this.interceptor.intercept(any(), any())).thenAnswer(withChainNext());
+ when(this.delegate.fireAndForget(any())).thenReturn(this.voidResult.mono());
+
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
+ Arrays.asList(this.interceptor), metadataMimeType, dataMimeType);
+
+ StepVerifier.create(interceptor.fireAndForget(this.payload))
+ .then(() -> this.voidResult.assertWasSubscribed())
+ .verifyComplete();
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ }
+
+ @Test
+ public void fireAndForgetWhenInterceptorErrorsThenDelegateNotSubscribed() {
+ RuntimeException expected = new RuntimeException("Oops");
+ when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected));
+
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
+ Arrays.asList(this.interceptor), metadataMimeType, dataMimeType);
+
+ StepVerifier.create(interceptor.fireAndForget(this.payload))
+ .then(() -> this.voidResult.assertWasNotSubscribed())
+ .verifyErrorSatisfies(e -> assertThat(e).isEqualTo(expected));
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ }
+
+ @Test
+ public void fireAndForgetWhenSecurityContextThenDelegateContext() {
+ TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password");
+ when(this.interceptor.intercept(any(), any())).thenAnswer(withAuthenticated(authentication));
+ when(this.delegate.fireAndForget(any())).thenReturn(Mono.empty());
+
+ RSocket assertAuthentication = new RSocketProxy(this.delegate) {
+ @Override
+ public Mono fireAndForget(Payload payload) {
+ return assertAuthentication(authentication)
+ .flatMap(a -> super.fireAndForget(payload));
+ }
+ };
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(assertAuthentication,
+ Arrays.asList(this.interceptor), metadataMimeType, dataMimeType);
+
+ interceptor.fireAndForget(this.payload).block();
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ verify(this.delegate).fireAndForget(this.payload);
+ }
+
+ @Test
+ public void requestResponseWhenInterceptorCompletesThenDelegateSubscribed() {
+ when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty());
+ when(this.delegate.requestResponse(any())).thenReturn(this.payloadResult.mono());
+
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
+ Arrays.asList(this.interceptor), metadataMimeType, dataMimeType);
+
+ StepVerifier.create(interceptor.requestResponse(this.payload))
+ .then(() -> this.payloadResult.assertSubscribers())
+ .then(() -> this.payloadResult.emit(this.payload))
+ .expectNext(this.payload)
+ .verifyComplete();
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ verify(this.delegate).requestResponse(this.payload);
+ }
+
+ @Test
+ public void requestResponseWhenInterceptorErrorsThenDelegateNotInvoked() {
+ RuntimeException expected = new RuntimeException("Oops");
+ when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected));
+
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
+ Arrays.asList(this.interceptor), metadataMimeType, dataMimeType);
+
+ assertThatCode(() -> interceptor.requestResponse(this.payload).block()).isEqualTo(expected);
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ verifyZeroInteractions(this.delegate);
+ }
+
+ @Test
+ public void requestResponseWhenSecurityContextThenDelegateContext() {
+ TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password");
+ when(this.interceptor.intercept(any(), any())).thenAnswer(withAuthenticated(authentication));
+ when(this.delegate.requestResponse(any())).thenReturn(this.payloadResult.mono());
+
+ RSocket assertAuthentication = new RSocketProxy(this.delegate) {
+ @Override
+ public Mono requestResponse(Payload payload) {
+ return assertAuthentication(authentication)
+ .flatMap(a -> super.requestResponse(payload));
+ }
+ };
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(assertAuthentication,
+ Arrays.asList(this.interceptor), metadataMimeType, dataMimeType);
+
+ StepVerifier.create(interceptor.requestResponse(this.payload))
+ .then(() -> this.payloadResult.assertSubscribers())
+ .then(() -> this.payloadResult.emit(this.payload))
+ .expectNext(this.payload)
+ .verifyComplete();
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ verify(this.delegate).requestResponse(this.payload);
+ }
+
+ @Test
+ public void requestStreamWhenInterceptorCompletesThenDelegateSubscribed() {
+ when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty());
+ when(this.delegate.requestStream(any())).thenReturn(this.payloadResult.flux());
+
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
+ Arrays.asList(this.interceptor), metadataMimeType, dataMimeType);
+
+ StepVerifier.create(interceptor.requestStream(this.payload))
+ .then(() -> this.payloadResult.assertSubscribers())
+ .then(() -> this.payloadResult.emit(this.payload))
+ .expectNext(this.payload)
+ .verifyComplete();
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ }
+
+ @Test
+ public void requestStreamWhenInterceptorErrorsThenDelegateNotSubscribed() {
+ RuntimeException expected = new RuntimeException("Oops");
+ when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected));
+
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
+ Arrays.asList(this.interceptor), metadataMimeType, dataMimeType);
+
+ StepVerifier.create(interceptor.requestStream(this.payload))
+ .then(() -> this.payloadResult.assertNoSubscribers())
+ .verifyErrorSatisfies(e -> assertThat(e).isEqualTo(expected));
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ }
+
+ @Test
+ public void requestStreamWhenSecurityContextThenDelegateContext() {
+ TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password");
+ when(this.interceptor.intercept(any(), any())).thenAnswer(withAuthenticated(authentication));
+ when(this.delegate.requestStream(any())).thenReturn(this.payloadResult.flux());
+
+ RSocket assertAuthentication = new RSocketProxy(this.delegate) {
+ @Override
+ public Flux requestStream(Payload payload) {
+ return assertAuthentication(authentication)
+ .flatMapMany(a -> super.requestStream(payload));
+ }
+ };
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(assertAuthentication,
+ Arrays.asList(this.interceptor), metadataMimeType, dataMimeType);
+
+ StepVerifier.create(interceptor.requestStream(this.payload))
+ .then(() -> this.payloadResult.assertSubscribers())
+ .then(() -> this.payloadResult.emit(this.payload))
+ .expectNext(this.payload)
+ .verifyComplete();
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ verify(this.delegate).requestStream(this.payload);
+ }
+
+ @Test
+ public void requestChannelWhenInterceptorCompletesThenDelegateSubscribed() {
+ when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty());
+ when(this.delegate.requestChannel(any())).thenReturn(this.payloadResult.flux());
+
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
+ Arrays.asList(this.interceptor), metadataMimeType, dataMimeType);
+
+ StepVerifier.create(interceptor.requestChannel(Flux.just(this.payload)))
+ .then(() -> this.payloadResult.assertSubscribers())
+ .then(() -> this.payloadResult.emit(this.payload))
+ .expectNext(this.payload)
+ .verifyComplete();
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ verify(this.delegate).requestChannel(any());
+ }
+
+ @Test
+ public void requestChannelWhenInterceptorErrorsThenDelegateNotSubscribed() {
+ RuntimeException expected = new RuntimeException("Oops");
+ when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected));
+
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
+ Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType);
+
+ StepVerifier.create(interceptor.requestChannel(Flux.just(this.payload)))
+ .then(() -> this.payloadResult.assertNoSubscribers())
+ .verifyErrorSatisfies(e -> assertThat(e).isEqualTo(expected));
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ }
+
+ @Test
+ public void requestChannelWhenSecurityContextThenDelegateContext() {
+ Mono payload = Mono.just(this.payload);
+ TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password");
+ when(this.interceptor.intercept(any(), any())).thenAnswer(withAuthenticated(authentication));
+ when(this.delegate.requestChannel(any())).thenReturn(this.payloadResult.flux());
+
+ RSocket assertAuthentication = new RSocketProxy(this.delegate) {
+ @Override
+ public Flux requestChannel(Publisher payload) {
+ return assertAuthentication(authentication)
+ .flatMapMany(a -> super.requestChannel(payload));
+ }
+ };
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(assertAuthentication,
+ Arrays.asList(this.interceptor), metadataMimeType, dataMimeType);
+
+ StepVerifier.create(interceptor.requestChannel(payload))
+ .then(() -> this.payloadResult.assertSubscribers())
+ .then(() -> this.payloadResult.emit(this.payload))
+ .expectNext(this.payload)
+ .verifyComplete();
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ verify(this.delegate).requestChannel(any());
+ }
+
+ @Test
+ public void metadataPushWhenInterceptorCompletesThenDelegateSubscribed() {
+ when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty());
+ when(this.delegate.metadataPush(any())).thenReturn(this.voidResult.mono());
+
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
+ Arrays.asList(this.interceptor), metadataMimeType, dataMimeType);
+
+ StepVerifier.create(interceptor.metadataPush(this.payload))
+ .then(() -> this.voidResult.assertWasSubscribed())
+ .verifyComplete();
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ }
+
+ @Test
+ public void metadataPushWhenInterceptorErrorsThenDelegateNotSubscribed() {
+ RuntimeException expected = new RuntimeException("Oops");
+ when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected));
+
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
+ Arrays.asList(this.interceptor), metadataMimeType, dataMimeType);
+
+ StepVerifier.create(interceptor.metadataPush(this.payload))
+ .then(() -> this.voidResult.assertWasNotSubscribed())
+ .verifyErrorSatisfies(e -> assertThat(e).isEqualTo(expected));
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ }
+
+ @Test
+ public void metadataPushWhenSecurityContextThenDelegateContext() {
+ TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password");
+ when(this.interceptor.intercept(any(), any())).thenAnswer(withAuthenticated(authentication));
+ when(this.delegate.metadataPush(any())).thenReturn(this.voidResult.mono());
+
+ RSocket assertAuthentication = new RSocketProxy(this.delegate) {
+ @Override
+ public Mono metadataPush(Payload payload) {
+ return assertAuthentication(authentication)
+ .flatMap(a -> super.metadataPush(payload));
+ }
+ };
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(assertAuthentication,
+ Arrays.asList(this.interceptor), metadataMimeType, dataMimeType);
+
+ StepVerifier.create(interceptor.metadataPush(this.payload))
+ .verifyComplete();
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ verify(this.delegate).metadataPush(this.payload);
+ this.voidResult.assertWasSubscribed();
+ }
+
+ // multiple interceptors
+
+ @Test
+ public void fireAndForgetWhenInterceptorsCompleteThenDelegateInvoked() {
+ when(this.interceptor.intercept(any(), any())).thenAnswer(withChainNext());
+ when(this.interceptor2.intercept(any(), any())).thenAnswer(withChainNext());
+ when(this.delegate.fireAndForget(any())).thenReturn(this.voidResult.mono());
+
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
+ Arrays.asList(this.interceptor, this.interceptor2), metadataMimeType,
+ dataMimeType);
+
+ interceptor.fireAndForget(this.payload).block();
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ this.voidResult.assertWasSubscribed();
+ }
+
+
+ @Test
+ public void fireAndForgetWhenInterceptorsMutatesPayloadThenDelegateInvoked() {
+ when(this.interceptor.intercept(any(), any())).thenAnswer(withChainNext());
+ when(this.interceptor2.intercept(any(), any())).thenAnswer(withChainNext());
+ when(this.delegate.fireAndForget(any())).thenReturn(this.voidResult.mono());
+
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
+ Arrays.asList(this.interceptor, this.interceptor2), metadataMimeType,
+ dataMimeType);
+
+ interceptor.fireAndForget(this.payload).block();
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ verify(this.interceptor2).intercept(any(), any());
+ verify(this.delegate).fireAndForget(eq(this.payload));
+ this.voidResult.assertWasSubscribed();
+ }
+
+ @Test
+ public void fireAndForgetWhenInterceptor1ErrorsThenInterceptor2AndDelegateNotInvoked() {
+ RuntimeException expected = new RuntimeException("Oops");
+ when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected));
+
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
+ Arrays.asList(this.interceptor, this.interceptor2), metadataMimeType,
+ dataMimeType);
+
+ assertThatCode(() -> interceptor.fireAndForget(this.payload).block()).isEqualTo(expected);
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ verifyZeroInteractions(this.interceptor2);
+ this.voidResult.assertWasNotSubscribed();
+ }
+
+ @Test
+ public void fireAndForgetWhenInterceptor2ErrorsThenInterceptor2AndDelegateNotInvoked() {
+ RuntimeException expected = new RuntimeException("Oops");
+ when(this.interceptor.intercept(any(), any())).thenAnswer(withChainNext());
+ when(this.interceptor2.intercept(any(), any())).thenReturn(Mono.error(expected));
+
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
+ Arrays.asList(this.interceptor, this.interceptor2), metadataMimeType,
+ dataMimeType);
+
+ assertThatCode(() -> interceptor.fireAndForget(this.payload).block()).isEqualTo(expected);
+
+ verify(this.interceptor).intercept(this.exchange.capture(), any());
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload);
+ verify(this.interceptor2).intercept(any(), any());
+ this.voidResult.assertWasNotSubscribed();
+ }
+
+ private Mono assertAuthentication(Authentication authentication) {
+ return ReactiveSecurityContextHolder.getContext()
+ .map(SecurityContext::getAuthentication)
+ .doOnNext(a -> assertThat(a).isEqualTo(authentication));
+ }
+
+ private Answer withAuthenticated(Authentication authentication) {
+ return invocation -> {
+ PayloadInterceptorChain c = (PayloadInterceptorChain) invocation.getArguments()[1];
+ return c.next(new DefaultPayloadExchange(PayloadExchangeType.REQUEST_CHANNEL, this.payload, this.metadataMimeType,
+ this.dataMimeType))
+ .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication));
+ };
+ }
+
+ private static Answer> withChainNext() {
+ return invocation -> {
+ PayloadExchange exchange = (PayloadExchange) invocation.getArguments()[0];
+ PayloadInterceptorChain chain = (PayloadInterceptorChain) invocation.getArguments()[1];
+ return chain.next(exchange);
+ };
+ }
+}
diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorInterceptorTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorInterceptorTests.java
new file mode 100644
index 0000000000..367a1b9b99
--- /dev/null
+++ b/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorInterceptorTests.java
@@ -0,0 +1,121 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor;
+
+import io.rsocket.ConnectionSetupPayload;
+import io.rsocket.Payload;
+import io.rsocket.RSocket;
+import io.rsocket.SocketAcceptor;
+import io.rsocket.metadata.WellKnownMimeType;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.runners.MockitoJUnitRunner;
+import org.springframework.http.MediaType;
+import reactor.core.publisher.Mono;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * @author Rob Winch
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class PayloadSocketAcceptorInterceptorTests {
+ @Mock
+ private PayloadInterceptor interceptor;
+
+ @Mock
+ private SocketAcceptor socketAcceptor;
+
+ @Mock
+ private ConnectionSetupPayload setupPayload;
+
+ @Mock
+ private RSocket rSocket;
+
+ @Mock
+ private Payload payload;
+
+ private List interceptors;
+
+ private PayloadSocketAcceptorInterceptor acceptorInterceptor;
+
+ @Before
+ public void setup() {
+ this.interceptors = Arrays.asList(this.interceptor);
+ this.acceptorInterceptor = new PayloadSocketAcceptorInterceptor(this.interceptors);
+ }
+
+ @Test
+ public void applyWhenDefaultMetadataMimeTypeThenDefaulted() {
+ when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE);
+
+ PayloadExchange exchange = captureExchange();
+
+ assertThat(exchange.getMetadataMimeType().toString()).isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString());
+ assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON);
+ }
+
+ @Test
+ public void acceptWhenDefaultMetadataMimeTypeOverrideThenDefaulted() {
+ this.acceptorInterceptor.setDefaultMetadataMimeType(MediaType.APPLICATION_JSON);
+ when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE);
+
+ PayloadExchange exchange = captureExchange();
+
+ assertThat(exchange.getMetadataMimeType()).isEqualTo(MediaType.APPLICATION_JSON);
+ assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON);
+ }
+
+ @Test
+ public void acceptWhenDefaultDataMimeTypeThenDefaulted() {
+ this.acceptorInterceptor.setDefaultDataMimeType(MediaType.APPLICATION_JSON);
+
+ PayloadExchange exchange = captureExchange();
+
+ assertThat(exchange.getMetadataMimeType().toString()).isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString());
+ assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON);
+ }
+
+ private PayloadExchange captureExchange() {
+ when(this.socketAcceptor.accept(any(), any())).thenReturn(Mono.just(this.rSocket));
+ when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty());
+
+ SocketAcceptor wrappedAcceptor = this.acceptorInterceptor.apply(this.socketAcceptor);
+ RSocket result = wrappedAcceptor.accept(this.setupPayload, this.rSocket).block();
+
+ assertThat(result).isInstanceOf(PayloadInterceptorRSocket.class);
+
+ when(this.rSocket.fireAndForget(any())).thenReturn(Mono.empty());
+
+ result.fireAndForget(this.payload).block();
+
+ ArgumentCaptor exchangeArg =
+ ArgumentCaptor.forClass(PayloadExchange.class);
+ verify(this.interceptor, times(2)).intercept(exchangeArg.capture(), any());
+ return exchangeArg.getValue();
+ }
+}
\ No newline at end of file
diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorTests.java
new file mode 100644
index 0000000000..af8154fcd6
--- /dev/null
+++ b/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorTests.java
@@ -0,0 +1,160 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor;
+
+import io.rsocket.ConnectionSetupPayload;
+import io.rsocket.Payload;
+import io.rsocket.RSocket;
+import io.rsocket.SocketAcceptor;
+import io.rsocket.metadata.WellKnownMimeType;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.runners.MockitoJUnitRunner;
+import org.springframework.http.MediaType;
+import reactor.core.publisher.Mono;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * @author Rob Winch
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class PayloadSocketAcceptorTests {
+
+ private PayloadSocketAcceptor acceptor;
+
+ private List interceptors;
+
+ @Mock
+ private SocketAcceptor delegate;
+
+ @Mock
+ private PayloadInterceptor interceptor;
+
+ @Mock
+ private ConnectionSetupPayload setupPayload;
+
+ @Mock
+ private RSocket rSocket;
+
+ @Mock
+ private Payload payload;
+
+ @Before
+ public void setup() {
+ this.interceptors = Arrays.asList(this.interceptor);
+ this.acceptor = new PayloadSocketAcceptor(this.delegate, this.interceptors);
+ }
+
+ @Test
+ public void constructorWhenNullDelegateThenException() {
+ this.delegate = null;
+ assertThatCode(() -> new PayloadSocketAcceptor(this.delegate, this.interceptors));
+ }
+
+ @Test
+ public void constructorWhenNullInterceptorsThenException() {
+ this.interceptors = null;
+ assertThatCode(() -> new PayloadSocketAcceptor(this.delegate, this.interceptors));
+ }
+
+ @Test
+ public void constructorWhenEmptyInterceptorsThenException() {
+ this.interceptors = Collections.emptyList();
+ assertThatCode(() -> new PayloadSocketAcceptor(this.delegate, this.interceptors));
+ }
+
+ @Test
+ public void acceptWhenDataMimeTypeNullThenException() {
+ assertThatCode(() -> this.acceptor.accept(this.setupPayload, this.rSocket)
+ .block()).isInstanceOf(IllegalArgumentException.class);
+ }
+
+ @Test
+ public void acceptWhenDefaultMetadataMimeTypeThenDefaulted() {
+ when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE);
+
+ PayloadExchange exchange = captureExchange();
+
+ assertThat(exchange.getMetadataMimeType().toString())
+ .isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString());
+ assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON);
+ }
+
+ @Test
+ public void acceptWhenDefaultMetadataMimeTypeOverrideThenDefaulted() {
+ this.acceptor.setDefaultMetadataMimeType(MediaType.APPLICATION_JSON);
+ when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE);
+
+ PayloadExchange exchange = captureExchange();
+
+ assertThat(exchange.getMetadataMimeType()).isEqualTo(MediaType.APPLICATION_JSON);
+ assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON);
+ }
+
+ @Test
+ public void acceptWhenDefaultDataMimeTypeThenDefaulted() {
+ this.acceptor.setDefaultDataMimeType(MediaType.APPLICATION_JSON);
+
+ PayloadExchange exchange = captureExchange();
+
+ assertThat(exchange.getMetadataMimeType()
+ .toString()).isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString());
+ assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON);
+ }
+
+ @Test
+ public void acceptWhenExplicitMimeTypeThenThenOverrideDefault() {
+ when(this.setupPayload.metadataMimeType()).thenReturn(MediaType.TEXT_PLAIN_VALUE);
+ when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE);
+
+ PayloadExchange exchange = captureExchange();
+
+ assertThat(exchange.getMetadataMimeType()).isEqualTo(MediaType.TEXT_PLAIN);
+ assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON);
+ }
+
+ private PayloadExchange captureExchange() {
+ when(this.delegate.accept(any(), any())).thenReturn(Mono.just(this.rSocket));
+ when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty());
+
+ RSocket result = this.acceptor.accept(this.setupPayload, this.rSocket).block();
+
+ assertThat(result).isInstanceOf(PayloadInterceptorRSocket.class);
+
+ when(this.rSocket.fireAndForget(any())).thenReturn(Mono.empty());
+
+ result.fireAndForget(this.payload).block();
+
+ ArgumentCaptor exchangeArg =
+ ArgumentCaptor.forClass(PayloadExchange.class);
+ verify(this.interceptor, times(2)).intercept(exchangeArg.capture(), any());
+ return exchangeArg.getValue();
+ }
+}
\ No newline at end of file
diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/authorization/PayloadExchangeMatcherReactiveAuthorizationManagerTest.java b/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/authorization/PayloadExchangeMatcherReactiveAuthorizationManagerTest.java
new file mode 100644
index 0000000000..5e214875ed
--- /dev/null
+++ b/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/authorization/PayloadExchangeMatcherReactiveAuthorizationManagerTest.java
@@ -0,0 +1,108 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.interceptor.authorization;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.springframework.security.authorization.AuthorizationDecision;
+import org.springframework.security.authorization.ReactiveAuthorizationManager;
+import org.springframework.security.rsocket.interceptor.PayloadExchange;
+import org.springframework.security.rsocket.util.PayloadExchangeAuthorizationContext;
+import org.springframework.security.rsocket.util.PayloadExchangeMatcher;
+import org.springframework.security.rsocket.util.PayloadExchangeMatcherEntry;
+import org.springframework.security.rsocket.util.PayloadExchangeMatchers;
+import reactor.core.publisher.Mono;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;
+
+/**
+ * @author Rob Winch
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class PayloadExchangeMatcherReactiveAuthorizationManagerTest {
+
+ @Mock
+ private ReactiveAuthorizationManager authz;
+
+ @Mock
+ private ReactiveAuthorizationManager authz2;
+
+ @Mock
+ private PayloadExchange exchange;
+
+ @Test
+ public void checkWhenGrantedThenGranted() {
+ AuthorizationDecision expected = new AuthorizationDecision(true);
+ when(this.authz.check(any(), any())).thenReturn(Mono.just(
+ expected));
+ PayloadExchangeMatcherReactiveAuthorizationManager manager =
+ PayloadExchangeMatcherReactiveAuthorizationManager.builder()
+ .add(new PayloadExchangeMatcherEntry<>(PayloadExchangeMatchers.anyExchange(), this.authz))
+ .build();
+
+ assertThat(manager.check(Mono.empty(), this.exchange).block())
+ .isEqualTo(expected);
+ }
+
+ @Test
+ public void checkWhenDeniedThenDenied() {
+ AuthorizationDecision expected = new AuthorizationDecision(false);
+ when(this.authz.check(any(), any())).thenReturn(Mono.just(
+ expected));
+ PayloadExchangeMatcherReactiveAuthorizationManager manager =
+ PayloadExchangeMatcherReactiveAuthorizationManager.builder()
+ .add(new PayloadExchangeMatcherEntry<>(PayloadExchangeMatchers.anyExchange(), this.authz))
+ .build();
+
+ assertThat(manager.check(Mono.empty(), this.exchange).block())
+ .isEqualTo(expected);
+ }
+
+ @Test
+ public void checkWhenFirstMatchThenSecondUsed() {
+ AuthorizationDecision expected = new AuthorizationDecision(true);
+ when(this.authz.check(any(), any())).thenReturn(Mono.just(
+ expected));
+ PayloadExchangeMatcherReactiveAuthorizationManager manager =
+ PayloadExchangeMatcherReactiveAuthorizationManager.builder()
+ .add(new PayloadExchangeMatcherEntry<>(PayloadExchangeMatchers.anyExchange(), this.authz))
+ .add(new PayloadExchangeMatcherEntry<>(e -> PayloadExchangeMatcher.MatchResult.notMatch(), this.authz2))
+ .build();
+
+ assertThat(manager.check(Mono.empty(), this.exchange).block())
+ .isEqualTo(expected);
+ }
+
+ @Test
+ public void checkWhenSecondMatchThenSecondUsed() {
+ AuthorizationDecision expected = new AuthorizationDecision(true);
+ when(this.authz2.check(any(), any())).thenReturn(Mono.just(
+ expected));
+ PayloadExchangeMatcherReactiveAuthorizationManager manager =
+ PayloadExchangeMatcherReactiveAuthorizationManager.builder()
+ .add(new PayloadExchangeMatcherEntry<>(e -> PayloadExchangeMatcher.MatchResult.notMatch(), this.authz))
+ .add(new PayloadExchangeMatcherEntry<>(PayloadExchangeMatchers.anyExchange(), this.authz2))
+ .build();
+
+ assertThat(manager.check(Mono.empty(), this.exchange).block())
+ .isEqualTo(expected);
+ }
+}
diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoderTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoderTests.java
new file mode 100644
index 0000000000..2654d2378c
--- /dev/null
+++ b/rsocket/src/test/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoderTests.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.metadata;
+
+import org.junit.Test;
+import org.springframework.core.ResolvableType;
+import org.springframework.core.io.buffer.DataBuffer;
+import org.springframework.core.io.buffer.DefaultDataBufferFactory;
+import org.springframework.util.MimeType;
+import reactor.core.publisher.Mono;
+
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * @author Rob Winch
+ */
+public class BasicAuthenticationDecoderTests {
+ @Test
+ public void basicAuthenticationWhenEncodedThenDecodes() {
+ BasicAuthenticationEncoder encoder = new BasicAuthenticationEncoder();
+ BasicAuthenticationDecoder decoder = new BasicAuthenticationDecoder();
+ UsernamePasswordMetadata expectedCredentials =
+ new UsernamePasswordMetadata("rob", "password");
+ DefaultDataBufferFactory factory = new DefaultDataBufferFactory();
+ ResolvableType elementType = ResolvableType
+ .forClass(UsernamePasswordMetadata.class);
+ MimeType mimeType = UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE;
+ Map hints = null;
+
+ DataBuffer dataBuffer = encoder.encodeValue(expectedCredentials, factory,
+ elementType, mimeType, hints);
+ UsernamePasswordMetadata actualCredentials = decoder
+ .decodeToMono(Mono.just(dataBuffer), elementType, mimeType, hints).block();
+
+ assertThat(actualCredentials).isEqualToComparingFieldByField(expectedCredentials);
+ }
+
+}
\ No newline at end of file
diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/util/RoutePayloadExchangeMatcherTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/util/RoutePayloadExchangeMatcherTests.java
new file mode 100644
index 0000000000..8c8c70ac1a
--- /dev/null
+++ b/rsocket/src/test/java/org/springframework/security/rsocket/util/RoutePayloadExchangeMatcherTests.java
@@ -0,0 +1,116 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.util;
+
+import io.rsocket.Payload;
+import io.rsocket.metadata.WellKnownMimeType;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.runners.MockitoJUnitRunner;
+import org.springframework.http.MediaType;
+import org.springframework.messaging.rsocket.MetadataExtractor;
+import org.springframework.security.rsocket.interceptor.DefaultPayloadExchange;
+import org.springframework.security.rsocket.interceptor.PayloadExchange;
+import org.springframework.security.rsocket.interceptor.PayloadExchangeType;
+import org.springframework.util.MimeType;
+import org.springframework.util.MimeTypeUtils;
+import org.springframework.util.RouteMatcher;
+
+import java.util.Collections;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.when;
+
+/**
+ * @author Rob Winch
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class RoutePayloadExchangeMatcherTests {
+ static final MimeType COMPOSITE_METADATA = MimeTypeUtils.parseMimeType(
+ WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString());
+
+ @Mock
+ private MetadataExtractor metadataExtractor;
+
+ @Mock
+ private RouteMatcher routeMatcher;
+
+ private PayloadExchange exchange;
+
+ @Mock
+ private Payload payload;
+
+ @Mock
+ private RouteMatcher.Route route;
+
+ private String pattern;
+
+ private RoutePayloadExchangeMatcher matcher;
+
+ @Before
+ public void setup() {
+ this.pattern = "a.b";
+ this.matcher = new RoutePayloadExchangeMatcher(this.metadataExtractor, this.routeMatcher, this.pattern);
+ this.exchange = new DefaultPayloadExchange(PayloadExchangeType.REQUEST_CHANNEL, this.payload, COMPOSITE_METADATA,
+ MediaType.APPLICATION_JSON);
+ }
+
+ @Test
+ public void matchesWhenNoRouteThenNotMatch() {
+ when(this.metadataExtractor.extract(any(), any()))
+ .thenReturn(Collections.emptyMap());
+ PayloadExchangeMatcher.MatchResult result = this.matcher.matches(this.exchange).block();
+ assertThat(result.isMatch()).isFalse();
+ }
+
+ @Test
+ public void matchesWhenNotMatchThenNotMatch() {
+ String route = "route";
+ when(this.metadataExtractor.extract(any(), any()))
+ .thenReturn(Collections.singletonMap(MetadataExtractor.ROUTE_KEY, route));
+ PayloadExchangeMatcher.MatchResult result = this.matcher.matches(this.exchange).block();
+ assertThat(result.isMatch()).isFalse();
+ }
+
+ @Test
+ public void matchesWhenMatchAndNoVariablesThenMatch() {
+ String route = "route";
+ when(this.metadataExtractor.extract(any(), any()))
+ .thenReturn(Collections.singletonMap(MetadataExtractor.ROUTE_KEY, route));
+ when(this.routeMatcher.parseRoute(any())).thenReturn(this.route);
+ when(this.routeMatcher.matchAndExtract(any(), any())).thenReturn(Collections.emptyMap());
+ PayloadExchangeMatcher.MatchResult result = this.matcher.matches(this.exchange).block();
+ assertThat(result.isMatch()).isTrue();
+ }
+
+ @Test
+ public void matchesWhenMatchAndVariablesThenMatchAndVariables() {
+ String route = "route";
+ Map variables = Collections.singletonMap("a", "b");
+ when(this.metadataExtractor.extract(any(), any()))
+ .thenReturn(Collections.singletonMap(MetadataExtractor.ROUTE_KEY, route));
+ when(this.routeMatcher.parseRoute(any())).thenReturn(this.route);
+ when(this.routeMatcher.matchAndExtract(any(), any())).thenReturn(variables);
+ PayloadExchangeMatcher.MatchResult result = this.matcher.matches(this.exchange).block();
+ assertThat(result.isMatch()).isTrue();
+ assertThat(result.getVariables()).containsAllEntriesOf(variables);
+ }
+}
\ No newline at end of file