Add AuthnRequestConsumerResolver
Closes gh-8141
This commit is contained in:
+20
-1
@@ -21,6 +21,8 @@ import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Function;
|
||||
|
||||
import org.joda.time.DateTime;
|
||||
import org.opensaml.saml.common.xml.SAMLConstants;
|
||||
@@ -43,6 +45,9 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
|
||||
private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance();
|
||||
private String protocolBinding = SAMLConstants.SAML2_POST_BINDING_URI;
|
||||
|
||||
private Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver
|
||||
= context -> authnRequest -> {};
|
||||
|
||||
@Override
|
||||
@Deprecated
|
||||
public String createAuthenticationRequest(Saml2AuthenticationRequest request) {
|
||||
@@ -95,8 +100,10 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
|
||||
}
|
||||
|
||||
private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
|
||||
return createAuthnRequest(context.getIssuer(),
|
||||
AuthnRequest authnRequest = createAuthnRequest(context.getIssuer(),
|
||||
context.getDestination(), context.getAssertionConsumerServiceUrl());
|
||||
this.authnRequestConsumerResolver.apply(context).accept(authnRequest);
|
||||
return authnRequest;
|
||||
}
|
||||
|
||||
private AuthnRequest createAuthnRequest(String issuer, String destination, String assertionConsumerServiceUrl) {
|
||||
@@ -114,6 +121,18 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
|
||||
return auth;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the {@link AuthnRequest} post-processor resolver
|
||||
*
|
||||
* @param authnRequestConsumerResolver
|
||||
* @since 5.4
|
||||
*/
|
||||
public void setAuthnRequestConsumerResolver(
|
||||
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver) {
|
||||
Assert.notNull(authnRequestConsumerResolver, "authnRequestConsumerResolver cannot be null");
|
||||
this.authnRequestConsumerResolver = authnRequestConsumerResolver;
|
||||
}
|
||||
|
||||
/**
|
||||
* '
|
||||
* Use this {@link Clock} with {@link Instant#now()} for generating
|
||||
|
||||
+36
-1
@@ -16,6 +16,9 @@
|
||||
|
||||
package org.springframework.security.saml2.provider.service.authentication;
|
||||
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Function;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Rule;
|
||||
@@ -29,9 +32,13 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes
|
||||
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatCode;
|
||||
import static org.hamcrest.CoreMatchers.containsString;
|
||||
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
|
||||
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
|
||||
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
|
||||
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
|
||||
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
|
||||
@@ -160,6 +167,34 @@ public class OpenSamlAuthenticationRequestFactoryTests {
|
||||
factory.setProtocolBinding("my-invalid-binding");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
|
||||
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
|
||||
mock(Function.class);
|
||||
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
|
||||
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
|
||||
|
||||
this.factory.createPostAuthenticationRequest(this.context);
|
||||
verify(authnRequestConsumerResolver).apply(this.context);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
|
||||
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
|
||||
mock(Function.class);
|
||||
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
|
||||
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
|
||||
|
||||
this.factory.createRedirectAuthenticationRequest(this.context);
|
||||
verify(authnRequestConsumerResolver).apply(this.context);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setAuthnRequestConsumerResolverWhenNullThenException() {
|
||||
assertThatCode(() -> this.factory.setAuthnRequestConsumerResolver(null))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
private AuthnRequest getAuthNRequest(Saml2MessageBinding binding) {
|
||||
AbstractSaml2AuthenticationRequest result = (binding == REDIRECT) ?
|
||||
factory.createRedirectAuthenticationRequest(context) :
|
||||
|
||||
Reference in New Issue
Block a user