1
0
mirror of synced 2026-05-22 21:33:16 +00:00

Add AuthnRequestConsumerResolver

Closes gh-8141
This commit is contained in:
Josh Cummings
2020-07-16 14:51:47 -06:00
parent 2e5c87dc75
commit 2c960d2ad1
3 changed files with 124 additions and 3 deletions
@@ -21,6 +21,8 @@ import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.function.Consumer;
import java.util.function.Function;
import org.joda.time.DateTime;
import org.opensaml.saml.common.xml.SAMLConstants;
@@ -43,6 +45,9 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance();
private String protocolBinding = SAMLConstants.SAML2_POST_BINDING_URI;
private Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver
= context -> authnRequest -> {};
@Override
@Deprecated
public String createAuthenticationRequest(Saml2AuthenticationRequest request) {
@@ -95,8 +100,10 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
}
private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
return createAuthnRequest(context.getIssuer(),
AuthnRequest authnRequest = createAuthnRequest(context.getIssuer(),
context.getDestination(), context.getAssertionConsumerServiceUrl());
this.authnRequestConsumerResolver.apply(context).accept(authnRequest);
return authnRequest;
}
private AuthnRequest createAuthnRequest(String issuer, String destination, String assertionConsumerServiceUrl) {
@@ -114,6 +121,18 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
return auth;
}
/**
* Set the {@link AuthnRequest} post-processor resolver
*
* @param authnRequestConsumerResolver
* @since 5.4
*/
public void setAuthnRequestConsumerResolver(
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver) {
Assert.notNull(authnRequestConsumerResolver, "authnRequestConsumerResolver cannot be null");
this.authnRequestConsumerResolver = authnRequestConsumerResolver;
}
/**
* '
* Use this {@link Clock} with {@link Instant#now()} for generating
@@ -16,6 +16,9 @@
package org.springframework.security.saml2.provider.service.authentication;
import java.util.function.Consumer;
import java.util.function.Function;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
@@ -29,9 +32,13 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.hamcrest.CoreMatchers.containsString;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
@@ -160,6 +167,34 @@ public class OpenSamlAuthenticationRequestFactoryTests {
factory.setProtocolBinding("my-invalid-binding");
}
@Test
public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
mock(Function.class);
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
this.factory.createPostAuthenticationRequest(this.context);
verify(authnRequestConsumerResolver).apply(this.context);
}
@Test
public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
mock(Function.class);
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
this.factory.createRedirectAuthenticationRequest(this.context);
verify(authnRequestConsumerResolver).apply(this.context);
}
@Test
public void setAuthnRequestConsumerResolverWhenNullThenException() {
assertThatCode(() -> this.factory.setAuthnRequestConsumerResolver(null))
.isInstanceOf(IllegalArgumentException.class);
}
private AuthnRequest getAuthNRequest(Saml2MessageBinding binding) {
AbstractSaml2AuthenticationRequest result = (binding == REDIRECT) ?
factory.createRedirectAuthenticationRequest(context) :