diff --git a/openid/src/main/java/org/springframework/security/ui/openid/OpenIDAuthenticationProcessingFilter.java b/openid/src/main/java/org/springframework/security/ui/openid/OpenIDAuthenticationProcessingFilter.java
index fbafa4187c..8952e84503 100644
--- a/openid/src/main/java/org/springframework/security/ui/openid/OpenIDAuthenticationProcessingFilter.java
+++ b/openid/src/main/java/org/springframework/security/ui/openid/OpenIDAuthenticationProcessingFilter.java
@@ -15,6 +15,8 @@
package org.springframework.security.ui.openid;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.springframework.security.Authentication;
import org.springframework.security.AuthenticationException;
import org.springframework.security.AuthenticationServiceException;
@@ -24,14 +26,16 @@ import org.springframework.security.ui.AbstractProcessingFilter;
import org.springframework.security.ui.FilterChainOrder;
import org.springframework.security.ui.openid.consumers.OpenID4JavaConsumer;
import org.springframework.security.ui.webapp.AuthenticationProcessingFilter;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.springframework.util.StringUtils;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.io.IOException;
+import java.net.MalformedURLException;
+import java.net.URL;
+import java.util.HashMap;
+import java.util.Map;
/**
@@ -50,6 +54,7 @@ public class OpenIDAuthenticationProcessingFilter extends AbstractProcessingFilt
private OpenIDConsumer consumer;
private String claimedIdentityFieldName = DEFAULT_CLAIMED_IDENTITY_FIELD;
+ private Map realmMapping = new HashMap();
//~ Methods ========================================================================================================
@@ -79,7 +84,7 @@ public class OpenIDAuthenticationProcessingFilter extends AbstractProcessingFilt
}
token.setDetails(authenticationDetailsSource.buildDetails(req));
-
+
// delegate to the auth provider
Authentication authentication = this.getAuthenticationManager().authenticate(token);
@@ -106,7 +111,8 @@ public class OpenIDAuthenticationProcessingFilter extends AbstractProcessingFilt
if (StringUtils.hasText(claimedIdentity)) {
try {
String returnToUrl = buildReturnToUrl(request);
- return consumer.beginConsumption(request, claimedIdentity, returnToUrl);
+ String realm = lookupRealm(returnToUrl);
+ return consumer.beginConsumption(request, claimedIdentity, returnToUrl, realm);
} catch (OpenIDConsumerException e) {
log.error("Unable to consume claimedIdentity [" + claimedIdentity + "]", e);
}
@@ -116,6 +122,30 @@ public class OpenIDAuthenticationProcessingFilter extends AbstractProcessingFilt
return super.determineFailureUrl(request, failed);
}
+ protected String lookupRealm(String returnToUrl) {
+
+ String mapping = (String) realmMapping.get(returnToUrl);
+
+ if (mapping == null) {
+ try {
+
+ URL url = new URL(returnToUrl);
+ int port = (url.getPort() == -1) ? 80 : url.getPort();
+ StringBuffer realmBuffer = new StringBuffer(returnToUrl.length())
+ .append(url.getProtocol())
+ .append("://")
+ .append(url.getHost())
+ .append(":").append(port)
+ .append("/");
+ mapping = realmBuffer.toString();
+ } catch (MalformedURLException e) {
+ log.warn("returnToUrl was not a valid URL: [" + returnToUrl + "]", e);
+ }
+ }
+
+ return mapping;
+ }
+
protected String buildReturnToUrl(HttpServletRequest request) {
return request.getRequestURL().toString();
}
@@ -199,6 +229,34 @@ public class OpenIDAuthenticationProcessingFilter extends AbstractProcessingFilt
}
public int getOrder() {
- return FilterChainOrder.OPENID_PROCESSING_FILTER;
+ return FilterChainOrder.OPENID_PROCESSING_FILTER;
+ }
+
+ /**
+ * Maps the return_to url to a realm.
+ * For example http://www.example.com/j_spring_openid_security_check -> http://www.example.com/realm
+ * If no mapping is provided then the returnToUrl will be parsed to extract the protocol, hostname and port followed
+ * by a trailing slash.
+ * This means that http://www.example.com/j_spring_openid_security_check will automatically
+ * become http://www.example.com:80/
+ *
+ * @return Map containing returnToUrl -> realm mappings
+ */
+ public Map getRealmMapping() {
+ return realmMapping;
+ }
+
+ /**
+ * Maps the return_to url to a realm.
+ * For example http://www.example.com/j_spring_openid_security_check -> http://www.example.com/realm
+ * If no mapping is provided then the returnToUrl will be parsed to extract the protocol, hostname and port followed
+ * by a trailing slash.
+ * This means that http://www.example.com/j_spring_openid_security_check will automatically
+ * become http://www.example.com:80/
+ *
+ * @param realmMapping containing returnToUrl -> realm mappings
+ */
+ public void setRealmMapping(Map realmMapping) {
+ this.realmMapping = realmMapping;
}
}
diff --git a/openid/src/main/java/org/springframework/security/ui/openid/OpenIDConsumer.java b/openid/src/main/java/org/springframework/security/ui/openid/OpenIDConsumer.java
index 2db0816277..796b66aaf0 100644
--- a/openid/src/main/java/org/springframework/security/ui/openid/OpenIDConsumer.java
+++ b/openid/src/main/java/org/springframework/security/ui/openid/OpenIDConsumer.java
@@ -27,9 +27,26 @@ import javax.servlet.http.HttpServletRequest;
*/
public interface OpenIDConsumer {
+ /**
+ * @deprecated Use {@link #beginConsumption(javax.servlet.http.HttpServletRequest, String, String, String)}
+ */
public String beginConsumption(HttpServletRequest req, String identityUrl, String returnToUrl)
throws OpenIDConsumerException;
+ /**
+ * Given the request, the claimedIdentity, the return to url, and a realm, lookup the openId authentication
+ * page the user should be redirected to.
+ *
+ * @param req HttpServletRequest
+ * @param claimedIdentity String URI the user presented during authentication
+ * @param returnToUrl String URI of the URL we want the user sent back to by the OP
+ * @param realm URI pattern matching the realm we want the user to see
+ * @return String URI to redirect user to for authentication
+ * @throws OpenIDConsumerException if anything bad happens
+ */
+ public String beginConsumption(HttpServletRequest req, String claimedIdentity, String returnToUrl, String realm)
+ throws OpenIDConsumerException;
+
public OpenIDAuthenticationToken endConsumption(HttpServletRequest req)
throws OpenIDConsumerException;
diff --git a/openid/src/main/java/org/springframework/security/ui/openid/consumers/OpenID4JavaConsumer.java b/openid/src/main/java/org/springframework/security/ui/openid/consumers/OpenID4JavaConsumer.java
index 84bd0bbb10..f73e822d92 100644
--- a/openid/src/main/java/org/springframework/security/ui/openid/consumers/OpenID4JavaConsumer.java
+++ b/openid/src/main/java/org/springframework/security/ui/openid/consumers/OpenID4JavaConsumer.java
@@ -61,8 +61,12 @@ public class OpenID4JavaConsumer implements OpenIDConsumer {
//~ Methods ========================================================================================================
- public String beginConsumption(HttpServletRequest req, String identityUrl, String returnToUrl)
- throws OpenIDConsumerException {
+ public String beginConsumption(HttpServletRequest req, String identityUrl, String returnToUrl) throws OpenIDConsumerException {
+ return beginConsumption(req, identityUrl, returnToUrl, returnToUrl);
+ }
+
+ public String beginConsumption(HttpServletRequest req, String identityUrl, String returnToUrl, String realm)
+ throws OpenIDConsumerException {
List discoveries;
try {
@@ -78,7 +82,7 @@ public class OpenID4JavaConsumer implements OpenIDConsumer {
AuthRequest authReq;
try {
- authReq = consumerManager.authenticate(information, returnToUrl);
+ authReq = consumerManager.authenticate(information, returnToUrl, realm);
} catch (MessageException e) {
throw new OpenIDConsumerException("Error processing ConumerManager authentication", e);
} catch (ConsumerException e) {
diff --git a/openid/src/test/java/org/springframework/security/ui/openid/OpenIDAuthenticationProcessingFilterTests.java b/openid/src/test/java/org/springframework/security/ui/openid/OpenIDAuthenticationProcessingFilterTests.java
new file mode 100644
index 0000000000..30ed6c8b08
--- /dev/null
+++ b/openid/src/test/java/org/springframework/security/ui/openid/OpenIDAuthenticationProcessingFilterTests.java
@@ -0,0 +1,60 @@
+package org.springframework.security.ui.openid;
+
+import junit.framework.TestCase;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.MockAuthenticationManager;
+import org.springframework.security.ui.openid.consumers.MockOpenIDConsumer;
+import org.springframework.security.util.MockFilterChain;
+
+import javax.servlet.http.HttpServletRequest;
+
+public class OpenIDAuthenticationProcessingFilterTests extends TestCase {
+
+ OpenIDAuthenticationProcessingFilter filter;
+ private static final String REDIRECT_URL = "http://www.example.com/redirect";
+ private static final String CLAIMED_IDENTITY_URL = "http://www.example.com/identity";
+ private static final String REQUEST_PATH = "/j_spring_openid_security_check";
+ private static final String FILTER_PROCESS_URL = "http://localhost:80" + REQUEST_PATH;
+ private static final String DEFAULT_TARGET_URL = FILTER_PROCESS_URL;
+
+ protected void setUp() throws Exception {
+ filter = new OpenIDAuthenticationProcessingFilter();
+ filter.setConsumer(new MockOpenIDConsumer(REDIRECT_URL));
+ filter.setDefaultTargetUrl(DEFAULT_TARGET_URL);
+ filter.setAuthenticationManager(new MockAuthenticationManager());
+ filter.afterPropertiesSet();
+ }
+
+ public void testNoIdentityCausesException() throws Exception {
+ try {
+ MockHttpServletRequest req = new MockHttpServletRequest();
+ filter.attemptAuthentication(req);
+ fail("OpenIDAuthenticationRequiredException expected, no openid.identity parameter");
+ } catch (OpenIDAuthenticationRequiredException e) {
+ //cool
+ }
+ }
+
+ public void testFilterOperation() throws Exception {
+ MockHttpServletRequest req = new MockHttpServletRequest("GET", REQUEST_PATH);
+ MockHttpServletResponse response = new MockHttpServletResponse();
+
+ req.setParameter("j_username", CLAIMED_IDENTITY_URL);
+ req.setRemoteHost("www.example.com");
+
+ filter.setConsumer(new MockOpenIDConsumer() {
+ public String beginConsumption(HttpServletRequest req, String claimedIdentity, String returnToUrl, String realm) throws OpenIDConsumerException {
+ assertEquals(CLAIMED_IDENTITY_URL, claimedIdentity);
+ assertEquals(DEFAULT_TARGET_URL, returnToUrl);
+ assertEquals("http://localhost:80/", realm);
+ return REDIRECT_URL;
+ }
+ });
+
+ filter.doFilter(req, response, new MockFilterChain(false));
+ assertEquals(REDIRECT_URL, response.getRedirectedUrl());
+ }
+
+
+}
diff --git a/openid/src/test/java/org/springframework/security/ui/openid/consumers/MockOpenIDConsumer.java b/openid/src/test/java/org/springframework/security/ui/openid/consumers/MockOpenIDConsumer.java
index d81a2e1438..1863f69a07 100644
--- a/openid/src/test/java/org/springframework/security/ui/openid/consumers/MockOpenIDConsumer.java
+++ b/openid/src/test/java/org/springframework/security/ui/openid/consumers/MockOpenIDConsumer.java
@@ -15,7 +15,6 @@
package org.springframework.security.ui.openid.consumers;
import org.springframework.security.providers.openid.OpenIDAuthenticationToken;
-
import org.springframework.security.ui.openid.OpenIDConsumer;
import org.springframework.security.ui.openid.OpenIDConsumerException;
@@ -33,21 +32,41 @@ public class MockOpenIDConsumer implements OpenIDConsumer {
private OpenIDAuthenticationToken token;
private String redirectUrl;
+ public MockOpenIDConsumer() {
+ }
+
+ public MockOpenIDConsumer(String redirectUrl, OpenIDAuthenticationToken token) {
+ this.redirectUrl = redirectUrl;
+ this.token = token;
+ }
+
+ public MockOpenIDConsumer(String redirectUrl) {
+ this.redirectUrl = redirectUrl;
+ }
+
+ public MockOpenIDConsumer(OpenIDAuthenticationToken token) {
+ this.token = token;
+ }
+
//~ Methods ========================================================================================================
- /* (non-Javadoc)
- * @see org.springframework.security.ui.openid.OpenIDConsumer#beginConsumption(javax.servlet.http.HttpServletRequest, java.lang.String)
- */
- public String beginConsumption(HttpServletRequest req, String identityUrl, String returnToUrl)
- throws OpenIDConsumerException {
+ public String beginConsumption(HttpServletRequest req, String claimedIdentity, String returnToUrl, String realm) throws OpenIDConsumerException {
return redirectUrl;
}
+ /* (non-Javadoc)
+ * @see org.springframework.security.ui.openid.OpenIDConsumer#beginConsumption(javax.servlet.http.HttpServletRequest, java.lang.String)
+ */
+ public String beginConsumption(HttpServletRequest req, String identityUrl, String returnToUrl)
+ throws OpenIDConsumerException {
+ throw new UnsupportedOperationException("This method is deprecated, stop using it");
+ }
+
/* (non-Javadoc)
* @see org.springframework.security.ui.openid.OpenIDConsumer#endConsumption(javax.servlet.http.HttpServletRequest)
*/
public OpenIDAuthenticationToken endConsumption(HttpServletRequest req)
- throws OpenIDConsumerException {
+ throws OpenIDConsumerException {
return token;
}