diff --git a/openid/src/main/java/org/springframework/security/ui/openid/OpenIDAuthenticationProcessingFilter.java b/openid/src/main/java/org/springframework/security/ui/openid/OpenIDAuthenticationProcessingFilter.java index fbafa4187c..8952e84503 100644 --- a/openid/src/main/java/org/springframework/security/ui/openid/OpenIDAuthenticationProcessingFilter.java +++ b/openid/src/main/java/org/springframework/security/ui/openid/OpenIDAuthenticationProcessingFilter.java @@ -15,6 +15,8 @@ package org.springframework.security.ui.openid; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.security.Authentication; import org.springframework.security.AuthenticationException; import org.springframework.security.AuthenticationServiceException; @@ -24,14 +26,16 @@ import org.springframework.security.ui.AbstractProcessingFilter; import org.springframework.security.ui.FilterChainOrder; import org.springframework.security.ui.openid.consumers.OpenID4JavaConsumer; import org.springframework.security.ui.webapp.AuthenticationProcessingFilter; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.springframework.util.StringUtils; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URL; +import java.util.HashMap; +import java.util.Map; /** @@ -50,6 +54,7 @@ public class OpenIDAuthenticationProcessingFilter extends AbstractProcessingFilt private OpenIDConsumer consumer; private String claimedIdentityFieldName = DEFAULT_CLAIMED_IDENTITY_FIELD; + private Map realmMapping = new HashMap(); //~ Methods ======================================================================================================== @@ -79,7 +84,7 @@ public class OpenIDAuthenticationProcessingFilter extends AbstractProcessingFilt } token.setDetails(authenticationDetailsSource.buildDetails(req)); - + // delegate to the auth provider Authentication authentication = this.getAuthenticationManager().authenticate(token); @@ -106,7 +111,8 @@ public class OpenIDAuthenticationProcessingFilter extends AbstractProcessingFilt if (StringUtils.hasText(claimedIdentity)) { try { String returnToUrl = buildReturnToUrl(request); - return consumer.beginConsumption(request, claimedIdentity, returnToUrl); + String realm = lookupRealm(returnToUrl); + return consumer.beginConsumption(request, claimedIdentity, returnToUrl, realm); } catch (OpenIDConsumerException e) { log.error("Unable to consume claimedIdentity [" + claimedIdentity + "]", e); } @@ -116,6 +122,30 @@ public class OpenIDAuthenticationProcessingFilter extends AbstractProcessingFilt return super.determineFailureUrl(request, failed); } + protected String lookupRealm(String returnToUrl) { + + String mapping = (String) realmMapping.get(returnToUrl); + + if (mapping == null) { + try { + + URL url = new URL(returnToUrl); + int port = (url.getPort() == -1) ? 80 : url.getPort(); + StringBuffer realmBuffer = new StringBuffer(returnToUrl.length()) + .append(url.getProtocol()) + .append("://") + .append(url.getHost()) + .append(":").append(port) + .append("/"); + mapping = realmBuffer.toString(); + } catch (MalformedURLException e) { + log.warn("returnToUrl was not a valid URL: [" + returnToUrl + "]", e); + } + } + + return mapping; + } + protected String buildReturnToUrl(HttpServletRequest request) { return request.getRequestURL().toString(); } @@ -199,6 +229,34 @@ public class OpenIDAuthenticationProcessingFilter extends AbstractProcessingFilt } public int getOrder() { - return FilterChainOrder.OPENID_PROCESSING_FILTER; + return FilterChainOrder.OPENID_PROCESSING_FILTER; + } + + /** + * Maps the return_to url to a realm.
+ * For example http://www.example.com/j_spring_openid_security_check -> http://www.example.com/realm
+ * If no mapping is provided then the returnToUrl will be parsed to extract the protocol, hostname and port followed + * by a trailing slash.
+ * This means that http://www.example.com/j_spring_openid_security_check will automatically + * become http://www.example.com:80/ + * + * @return Map containing returnToUrl -> realm mappings + */ + public Map getRealmMapping() { + return realmMapping; + } + + /** + * Maps the return_to url to a realm.
+ * For example http://www.example.com/j_spring_openid_security_check -> http://www.example.com/realm
+ * If no mapping is provided then the returnToUrl will be parsed to extract the protocol, hostname and port followed + * by a trailing slash.
+ * This means that http://www.example.com/j_spring_openid_security_check will automatically + * become http://www.example.com:80/ + * + * @param realmMapping containing returnToUrl -> realm mappings + */ + public void setRealmMapping(Map realmMapping) { + this.realmMapping = realmMapping; } } diff --git a/openid/src/main/java/org/springframework/security/ui/openid/OpenIDConsumer.java b/openid/src/main/java/org/springframework/security/ui/openid/OpenIDConsumer.java index 2db0816277..796b66aaf0 100644 --- a/openid/src/main/java/org/springframework/security/ui/openid/OpenIDConsumer.java +++ b/openid/src/main/java/org/springframework/security/ui/openid/OpenIDConsumer.java @@ -27,9 +27,26 @@ import javax.servlet.http.HttpServletRequest; */ public interface OpenIDConsumer { + /** + * @deprecated Use {@link #beginConsumption(javax.servlet.http.HttpServletRequest, String, String, String)} + */ public String beginConsumption(HttpServletRequest req, String identityUrl, String returnToUrl) throws OpenIDConsumerException; + /** + * Given the request, the claimedIdentity, the return to url, and a realm, lookup the openId authentication + * page the user should be redirected to. + * + * @param req HttpServletRequest + * @param claimedIdentity String URI the user presented during authentication + * @param returnToUrl String URI of the URL we want the user sent back to by the OP + * @param realm URI pattern matching the realm we want the user to see + * @return String URI to redirect user to for authentication + * @throws OpenIDConsumerException if anything bad happens + */ + public String beginConsumption(HttpServletRequest req, String claimedIdentity, String returnToUrl, String realm) + throws OpenIDConsumerException; + public OpenIDAuthenticationToken endConsumption(HttpServletRequest req) throws OpenIDConsumerException; diff --git a/openid/src/main/java/org/springframework/security/ui/openid/consumers/OpenID4JavaConsumer.java b/openid/src/main/java/org/springframework/security/ui/openid/consumers/OpenID4JavaConsumer.java index 84bd0bbb10..f73e822d92 100644 --- a/openid/src/main/java/org/springframework/security/ui/openid/consumers/OpenID4JavaConsumer.java +++ b/openid/src/main/java/org/springframework/security/ui/openid/consumers/OpenID4JavaConsumer.java @@ -61,8 +61,12 @@ public class OpenID4JavaConsumer implements OpenIDConsumer { //~ Methods ======================================================================================================== - public String beginConsumption(HttpServletRequest req, String identityUrl, String returnToUrl) - throws OpenIDConsumerException { + public String beginConsumption(HttpServletRequest req, String identityUrl, String returnToUrl) throws OpenIDConsumerException { + return beginConsumption(req, identityUrl, returnToUrl, returnToUrl); + } + + public String beginConsumption(HttpServletRequest req, String identityUrl, String returnToUrl, String realm) + throws OpenIDConsumerException { List discoveries; try { @@ -78,7 +82,7 @@ public class OpenID4JavaConsumer implements OpenIDConsumer { AuthRequest authReq; try { - authReq = consumerManager.authenticate(information, returnToUrl); + authReq = consumerManager.authenticate(information, returnToUrl, realm); } catch (MessageException e) { throw new OpenIDConsumerException("Error processing ConumerManager authentication", e); } catch (ConsumerException e) { diff --git a/openid/src/test/java/org/springframework/security/ui/openid/OpenIDAuthenticationProcessingFilterTests.java b/openid/src/test/java/org/springframework/security/ui/openid/OpenIDAuthenticationProcessingFilterTests.java new file mode 100644 index 0000000000..30ed6c8b08 --- /dev/null +++ b/openid/src/test/java/org/springframework/security/ui/openid/OpenIDAuthenticationProcessingFilterTests.java @@ -0,0 +1,60 @@ +package org.springframework.security.ui.openid; + +import junit.framework.TestCase; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.MockAuthenticationManager; +import org.springframework.security.ui.openid.consumers.MockOpenIDConsumer; +import org.springframework.security.util.MockFilterChain; + +import javax.servlet.http.HttpServletRequest; + +public class OpenIDAuthenticationProcessingFilterTests extends TestCase { + + OpenIDAuthenticationProcessingFilter filter; + private static final String REDIRECT_URL = "http://www.example.com/redirect"; + private static final String CLAIMED_IDENTITY_URL = "http://www.example.com/identity"; + private static final String REQUEST_PATH = "/j_spring_openid_security_check"; + private static final String FILTER_PROCESS_URL = "http://localhost:80" + REQUEST_PATH; + private static final String DEFAULT_TARGET_URL = FILTER_PROCESS_URL; + + protected void setUp() throws Exception { + filter = new OpenIDAuthenticationProcessingFilter(); + filter.setConsumer(new MockOpenIDConsumer(REDIRECT_URL)); + filter.setDefaultTargetUrl(DEFAULT_TARGET_URL); + filter.setAuthenticationManager(new MockAuthenticationManager()); + filter.afterPropertiesSet(); + } + + public void testNoIdentityCausesException() throws Exception { + try { + MockHttpServletRequest req = new MockHttpServletRequest(); + filter.attemptAuthentication(req); + fail("OpenIDAuthenticationRequiredException expected, no openid.identity parameter"); + } catch (OpenIDAuthenticationRequiredException e) { + //cool + } + } + + public void testFilterOperation() throws Exception { + MockHttpServletRequest req = new MockHttpServletRequest("GET", REQUEST_PATH); + MockHttpServletResponse response = new MockHttpServletResponse(); + + req.setParameter("j_username", CLAIMED_IDENTITY_URL); + req.setRemoteHost("www.example.com"); + + filter.setConsumer(new MockOpenIDConsumer() { + public String beginConsumption(HttpServletRequest req, String claimedIdentity, String returnToUrl, String realm) throws OpenIDConsumerException { + assertEquals(CLAIMED_IDENTITY_URL, claimedIdentity); + assertEquals(DEFAULT_TARGET_URL, returnToUrl); + assertEquals("http://localhost:80/", realm); + return REDIRECT_URL; + } + }); + + filter.doFilter(req, response, new MockFilterChain(false)); + assertEquals(REDIRECT_URL, response.getRedirectedUrl()); + } + + +} diff --git a/openid/src/test/java/org/springframework/security/ui/openid/consumers/MockOpenIDConsumer.java b/openid/src/test/java/org/springframework/security/ui/openid/consumers/MockOpenIDConsumer.java index d81a2e1438..1863f69a07 100644 --- a/openid/src/test/java/org/springframework/security/ui/openid/consumers/MockOpenIDConsumer.java +++ b/openid/src/test/java/org/springframework/security/ui/openid/consumers/MockOpenIDConsumer.java @@ -15,7 +15,6 @@ package org.springframework.security.ui.openid.consumers; import org.springframework.security.providers.openid.OpenIDAuthenticationToken; - import org.springframework.security.ui.openid.OpenIDConsumer; import org.springframework.security.ui.openid.OpenIDConsumerException; @@ -33,21 +32,41 @@ public class MockOpenIDConsumer implements OpenIDConsumer { private OpenIDAuthenticationToken token; private String redirectUrl; + public MockOpenIDConsumer() { + } + + public MockOpenIDConsumer(String redirectUrl, OpenIDAuthenticationToken token) { + this.redirectUrl = redirectUrl; + this.token = token; + } + + public MockOpenIDConsumer(String redirectUrl) { + this.redirectUrl = redirectUrl; + } + + public MockOpenIDConsumer(OpenIDAuthenticationToken token) { + this.token = token; + } + //~ Methods ======================================================================================================== - /* (non-Javadoc) - * @see org.springframework.security.ui.openid.OpenIDConsumer#beginConsumption(javax.servlet.http.HttpServletRequest, java.lang.String) - */ - public String beginConsumption(HttpServletRequest req, String identityUrl, String returnToUrl) - throws OpenIDConsumerException { + public String beginConsumption(HttpServletRequest req, String claimedIdentity, String returnToUrl, String realm) throws OpenIDConsumerException { return redirectUrl; } + /* (non-Javadoc) + * @see org.springframework.security.ui.openid.OpenIDConsumer#beginConsumption(javax.servlet.http.HttpServletRequest, java.lang.String) + */ + public String beginConsumption(HttpServletRequest req, String identityUrl, String returnToUrl) + throws OpenIDConsumerException { + throw new UnsupportedOperationException("This method is deprecated, stop using it"); + } + /* (non-Javadoc) * @see org.springframework.security.ui.openid.OpenIDConsumer#endConsumption(javax.servlet.http.HttpServletRequest) */ public OpenIDAuthenticationToken endConsumption(HttpServletRequest req) - throws OpenIDConsumerException { + throws OpenIDConsumerException { return token; }