From 4064b7b4f6b4db8b917d381144a0dc6a0963864a Mon Sep 17 00:00:00 2001 From: Luke Taylor Date: Sun, 13 Sep 2009 15:03:14 +0000 Subject: [PATCH] SEC-1167: Introduce more flexible SavedRequest handling. Introduced interface for SavedRequest. --- ...bstractAuthenticationProcessingFilter.java | 2 +- ...uestAwareAuthenticationSuccessHandler.java | 19 +- .../web/savedrequest/DefaultSavedRequest.java | 341 +++++++++++++++++ .../savedrequest/HttpSessionRequestCache.java | 20 +- .../web/savedrequest/RequestCache.java | 3 +- .../web/savedrequest/SavedRequest.java | 362 +----------------- .../SavedRequestAwareWrapper.java | 6 +- .../DefaultSessionAuthenticationStrategy.java | 4 +- .../ExceptionTranslationFilterTests.java | 8 +- .../AbstractProcessingFilterTests.java | 16 +- .../HttpSessionRequestCacheTests.java | 33 ++ .../RequestCacheAwareFilterTests.java | 6 +- .../SavedRequestAwareWrapperTests.java | 4 +- .../web/savedrequest/SavedRequestTests.java | 6 +- ...ultSessionAuthenticationStrategyTests.java | 6 +- 15 files changed, 442 insertions(+), 394 deletions(-) create mode 100644 web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java create mode 100644 web/src/test/java/org/springframework/security/web/savedrequest/HttpSessionRequestCacheTests.java diff --git a/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java b/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java index 75d418902e..17100fc55c 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java @@ -67,7 +67,7 @@ import org.springframework.web.filter.GenericFilterBean; * The configured {@link #setAuthenticationSuccessHandler(AuthenticationSuccessHandler) AuthenticationSuccessHandler} will * then be called to take the redirect to the appropriate destination after a successful login. The default behaviour * is implemented in a {@link SavedRequestAwareAuthenticationSuccessHandler} which will make use of any - * SavedRequest set by the ExceptionTranslationFilter and redirect the user to the URL contained + * DefaultSavedRequest set by the ExceptionTranslationFilter and redirect the user to the URL contained * therein. Otherwise it will redirect to the webapp root "/". You can customize this behaviour by injecting a * differently configured instance of this class, or by using a different implementation. *

diff --git a/web/src/main/java/org/springframework/security/web/authentication/SavedRequestAwareAuthenticationSuccessHandler.java b/web/src/main/java/org/springframework/security/web/authentication/SavedRequestAwareAuthenticationSuccessHandler.java index 6b6c252c8d..604d51a2f2 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/SavedRequestAwareAuthenticationSuccessHandler.java +++ b/web/src/main/java/org/springframework/security/web/authentication/SavedRequestAwareAuthenticationSuccessHandler.java @@ -13,10 +13,11 @@ import org.springframework.security.web.access.ExceptionTranslationFilter; import org.springframework.security.web.savedrequest.HttpSessionRequestCache; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.security.web.savedrequest.SavedRequest; +import org.springframework.security.web.savedrequest.DefaultSavedRequest; import org.springframework.util.StringUtils; /** - * An authentication success strategy which can make use of the {@link SavedRequest} which may have been stored in + * An authentication success strategy which can make use of the {@link DefaultSavedRequest} which may have been stored in * the session by the {@link ExceptionTranslationFilter}. When such a request is intercepted and requires authentication, * the request data is stored to record the original destination before the authentication process commenced, and to * allow the request to be reconstructed when a redirect to the same URL occurs. This class is responsible for @@ -26,21 +27,21 @@ import org.springframework.util.StringUtils; *

* @@ -72,9 +73,9 @@ public class SavedRequestAwareAuthenticationSuccessHandler extends SimpleUrlAuth return; } - // Use the SavedRequest URL - String targetUrl = savedRequest.getFullRequestUrl(); - logger.debug("Redirecting to SavedRequest Url: " + targetUrl); + // Use the DefaultSavedRequest URL + String targetUrl = savedRequest.getRedirectUrl(); + logger.debug("Redirecting to DefaultSavedRequest Url: " + targetUrl); getRedirectStrategy().sendRedirect(request, response, targetUrl); } diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java b/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java new file mode 100644 index 0000000000..41921bb4da --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java @@ -0,0 +1,341 @@ +/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.savedrequest; + +import org.springframework.security.web.PortResolver; +import org.springframework.security.web.util.UrlUtils; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.util.Assert; + +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletRequest; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Enumeration; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.TreeMap; + + +/** + * Represents central information from a HttpServletRequest.

This class is used by {@link + * org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter} and {@link org.springframework.security.web.savedrequest.SavedRequestAwareWrapper} to + * reproduce the request after successful authentication. An instance of this class is stored at the time of an + * authentication exception by {@link org.springframework.security.web.access.ExceptionTranslationFilter}.

+ *

IMPLEMENTATION NOTE: It is assumed that this object is accessed only from the context of a single + * thread, so no synchronization around internal collection classes is performed.

+ *

This class is based on code in Apache Tomcat.

+ * + * @author Craig McClanahan + * @author Andrey Grebnev + * @author Ben Alex + * @version $Id$ + */ +public class DefaultSavedRequest implements SavedRequest { + //~ Static fields/initializers ===================================================================================== + + protected static final Log logger = LogFactory.getLog(DefaultSavedRequest.class); + + public static final String SPRING_SECURITY_SAVED_REQUEST_KEY = "SPRING_SECURITY_SAVED_REQUEST_KEY"; + + //~ Instance fields ================================================================================================ + + private ArrayList cookies = new ArrayList(); + private ArrayList locales = new ArrayList(); + private Map> headers = new TreeMap>(String.CASE_INSENSITIVE_ORDER); + private Map parameters = new TreeMap(String.CASE_INSENSITIVE_ORDER); + private String contextPath; + private String method; + private String pathInfo; + private String queryString; + private String requestURI; + private String requestURL; + private String scheme; + private String serverName; + private String servletPath; + private int serverPort; + + //~ Constructors =================================================================================================== + + @SuppressWarnings("unchecked") + public DefaultSavedRequest(HttpServletRequest request, PortResolver portResolver) { + Assert.notNull(request, "Request required"); + Assert.notNull(portResolver, "PortResolver required"); + + // Cookies + Cookie[] cookies = request.getCookies(); + + if (cookies != null) { + for (int i = 0; i < cookies.length; i++) { + this.addCookie(cookies[i]); + } + } + + // Headers + Enumeration names = request.getHeaderNames(); + + while (names.hasMoreElements()) { + String name = names.nextElement(); + Enumeration values = request.getHeaders(name); + + while (values.hasMoreElements()) { + this.addHeader(name, values.nextElement()); + } + } + + // Locales + Enumeration locales = request.getLocales(); + + while (locales.hasMoreElements()) { + Locale locale = (Locale) locales.nextElement(); + this.addLocale(locale); + } + + // Parameters + Map parameters = request.getParameterMap(); + + for(String paramName : parameters.keySet()) { + Object paramValues = parameters.get(paramName); + if (paramValues instanceof String[]) { + this.addParameter(paramName, (String[]) paramValues); + } else { + if (logger.isWarnEnabled()) { + logger.warn("ServletRequest.getParameterMap() returned non-String array"); + } + } + } + + // Primitives + this.method = request.getMethod(); + this.pathInfo = request.getPathInfo(); + this.queryString = request.getQueryString(); + this.requestURI = request.getRequestURI(); + this.serverPort = portResolver.getServerPort(request); + this.requestURL = request.getRequestURL().toString(); + this.scheme = request.getScheme(); + this.serverName = request.getServerName(); + this.contextPath = request.getContextPath(); + this.servletPath = request.getServletPath(); + } + + //~ Methods ======================================================================================================== + + private void addCookie(Cookie cookie) { + cookies.add(new SavedCookie(cookie)); + } + + private void addHeader(String name, String value) { + List values = headers.get(name); + + if (values == null) { + values = new ArrayList(); + headers.put(name, values); + } + + values.add(value); + } + + private void addLocale(Locale locale) { + locales.add(locale); + } + + private void addParameter(String name, String[] values) { + parameters.put(name, values); + } + + /** + * Determines if the current request matches the DefaultSavedRequest. + *

+ * All URL arguments are considered but not cookies, locales, headers or parameters. + *

+ * + */ + public boolean doesRequestMatch(HttpServletRequest request, PortResolver portResolver) { + + if (!propertyEquals("pathInfo", this.pathInfo, request.getPathInfo())) { + return false; + } + + if (!propertyEquals("queryString", this.queryString, request.getQueryString())) { + return false; + } + + if (!propertyEquals("requestURI", this.requestURI, request.getRequestURI())) { + return false; + } + + if (!"GET".equals(request.getMethod()) && "GET".equals(method)) { + // A save GET should not match an incoming non-GET method + return false; + } + + if (!propertyEquals("serverPort", new Integer(this.serverPort), new Integer(portResolver.getServerPort(request)))) + { + return false; + } + + if (!propertyEquals("requestURL", this.requestURL, request.getRequestURL().toString())) { + return false; + } + + if (!propertyEquals("scheme", this.scheme, request.getScheme())) { + return false; + } + + if (!propertyEquals("serverName", this.serverName, request.getServerName())) { + return false; + } + + if (!propertyEquals("contextPath", this.contextPath, request.getContextPath())) { + return false; + } + + if (!propertyEquals("servletPath", this.servletPath, request.getServletPath())) { + return false; + } + + return true; + } + + public String getContextPath() { + return contextPath; + } + + public List getCookies() { + List cookieList = new ArrayList(cookies.size()); + + for (SavedCookie savedCookie : cookies) { + cookieList.add(savedCookie.getCookie()); + } + + return cookieList; + } + + /** + * Indicates the URL that the user agent used for this request. + * + * @return the full URL of this request + */ + public String getRedirectUrl() { + return UrlUtils.buildFullRequestUrl(scheme, serverName, serverPort, contextPath, servletPath, requestURI, + pathInfo, queryString); + } + + public Iterator getHeaderNames() { + return (headers.keySet().iterator()); + } + + public Iterator getHeaderValues(String name) { + List values = headers.get(name); + + if (values == null) { + values = Collections.emptyList(); + } + + return (values.iterator()); + } + + public Iterator getLocales() { + return (locales.iterator()); + } + + public String getMethod() { + return method; + } + + public Map getParameterMap() { + return parameters; + } + + public Iterator getParameterNames() { + return (parameters.keySet().iterator()); + } + + public String[] getParameterValues(String name) { + return ((String[]) parameters.get(name)); + } + + public String getPathInfo() { + return pathInfo; + } + + public String getQueryString() { + return (this.queryString); + } + + public String getRequestURI() { + return (this.requestURI); + } + + public String getRequestURL() { + return requestURL; + } + + public String getScheme() { + return scheme; + } + + public String getServerName() { + return serverName; + } + + public int getServerPort() { + return serverPort; + } + + public String getServletPath() { + return servletPath; + } + + private boolean propertyEquals(String log, Object arg1, Object arg2) { + if ((arg1 == null) && (arg2 == null)) { + if (logger.isDebugEnabled()) { + logger.debug(log + ": both null (property equals)"); + } + + return true; + } + + if (((arg1 == null) && (arg2 != null)) || ((arg1 != null) && (arg2 == null))) { + if (logger.isDebugEnabled()) { + logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2 + " (property not equals)"); + } + + return false; + } + + if (arg1.equals(arg2)) { + if (logger.isDebugEnabled()) { + logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2 + " (property equals)"); + } + + return true; + } else { + if (logger.isDebugEnabled()) { + logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2 + " (property not equals)"); + } + + return false; + } + } + + public String toString() { + return "DefaultSavedRequest[" + getRedirectUrl() + "]"; + } +} diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/HttpSessionRequestCache.java b/web/src/main/java/org/springframework/security/web/savedrequest/HttpSessionRequestCache.java index cb74c36856..6a76489aa2 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/HttpSessionRequestCache.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/HttpSessionRequestCache.java @@ -10,7 +10,9 @@ import org.springframework.security.web.PortResolver; import org.springframework.security.web.PortResolverImpl; /** - * RequestCache which stores the SavedRequest in the HttpSession. + * RequestCache which stores the SavedRequest in the HttpSession. + * + * The {@link DefaultSavedRequest} class is used as the implementation. * * @author Luke Taylor * @version $Id$ @@ -28,13 +30,13 @@ public class HttpSessionRequestCache implements RequestCache { */ public void saveRequest(HttpServletRequest request, HttpServletResponse response) { if (!justUseSavedRequestOnGet || "GET".equals(request.getMethod())) { - SavedRequest savedRequest = new SavedRequest(request, portResolver); + DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, portResolver); if (createSessionAllowed || request.getSession(false) != null) { // Store the HTTP request itself. Used by AbstractAuthenticationProcessingFilter // for redirection after successful authentication (SEC-29) - request.getSession().setAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY, savedRequest); - logger.debug("SavedRequest added to Session: " + savedRequest); + request.getSession().setAttribute(DefaultSavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY, savedRequest); + logger.debug("DefaultSavedRequest added to Session: " + savedRequest); } } @@ -44,7 +46,7 @@ public class HttpSessionRequestCache implements RequestCache { HttpSession session = currentRequest.getSession(false); if (session != null) { - return (SavedRequest) session.getAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY); + return (DefaultSavedRequest) session.getAttribute(DefaultSavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY); } return null; @@ -54,13 +56,13 @@ public class HttpSessionRequestCache implements RequestCache { HttpSession session = currentRequest.getSession(false); if (session != null) { - logger.debug("Removing SavedRequest from session if present"); - session.removeAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY); + logger.debug("Removing DefaultSavedRequest from session if present"); + session.removeAttribute(DefaultSavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY); } } public HttpServletRequest getMatchingRequest(HttpServletRequest request, HttpServletResponse response) { - SavedRequest saved = getRequest(request, response); + DefaultSavedRequest saved = (DefaultSavedRequest) getRequest(request, response); if (saved == null) { return null; @@ -77,7 +79,7 @@ public class HttpSessionRequestCache implements RequestCache { } /** - * If true, will only use SavedRequest to determine the target URL on successful + * If true, will only use DefaultSavedRequest to determine the target URL on successful * authentication if the request that caused the authentication request was a GET. Defaults to false. */ public void setJustUseSavedRequestOnGet(boolean justUseSavedRequestOnGet) { diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/RequestCache.java b/web/src/main/java/org/springframework/security/web/savedrequest/RequestCache.java index 7f32d15ccc..eee5f696c9 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/RequestCache.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/RequestCache.java @@ -34,7 +34,8 @@ public interface RequestCache { * * @param request * @param response - * @return the wrapped save request, if it matches the + * @return the wrapped save request, if it matches the original, or null if there is no cached request or it doesn't + * match. */ HttpServletRequest getMatchingRequest(HttpServletRequest request, HttpServletResponse response); diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequest.java b/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequest.java index fc6a610e53..8f15e2d6ef 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequest.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequest.java @@ -1,345 +1,17 @@ -/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.web.savedrequest; - -import org.springframework.security.web.PortResolver; -import org.springframework.security.web.util.UrlUtils; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.springframework.util.Assert; - -import javax.servlet.http.Cookie; -import javax.servlet.http.HttpServletRequest; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Enumeration; -import java.util.Iterator; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.TreeMap; - - -/** - * Represents central information from a HttpServletRequest.

This class is used by {@link - * org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter} and {@link org.springframework.security.web.savedrequest.SavedRequestAwareWrapper} to - * reproduce the request after successful authentication. An instance of this class is stored at the time of an - * authentication exception by {@link org.springframework.security.web.access.ExceptionTranslationFilter}.

- *

IMPLEMENTATION NOTE: It is assumed that this object is accessed only from the context of a single - * thread, so no synchronization around internal collection classes is performed.

- *

This class is based on code in Apache Tomcat.

- * - * @author Craig McClanahan - * @author Andrey Grebnev - * @author Ben Alex - * @version $Id$ - */ -public class SavedRequest implements java.io.Serializable { - //~ Static fields/initializers ===================================================================================== - - protected static final Log logger = LogFactory.getLog(SavedRequest.class); - - public static final String SPRING_SECURITY_SAVED_REQUEST_KEY = "SPRING_SECURITY_SAVED_REQUEST_KEY"; - - //~ Instance fields ================================================================================================ - - private ArrayList cookies = new ArrayList(); - private ArrayList locales = new ArrayList(); - private Map> headers = new TreeMap>(String.CASE_INSENSITIVE_ORDER); - private Map parameters = new TreeMap(String.CASE_INSENSITIVE_ORDER); - private String contextPath; - private String method; - private String pathInfo; - private String queryString; - private String requestURI; - private String requestURL; - private String scheme; - private String serverName; - private String servletPath; - private int serverPort; - - //~ Constructors =================================================================================================== - - @SuppressWarnings("unchecked") - public SavedRequest(HttpServletRequest request, PortResolver portResolver) { - Assert.notNull(request, "Request required"); - Assert.notNull(portResolver, "PortResolver required"); - - // Cookies - Cookie[] cookies = request.getCookies(); - - if (cookies != null) { - for (int i = 0; i < cookies.length; i++) { - this.addCookie(cookies[i]); - } - } - - // Headers - Enumeration names = request.getHeaderNames(); - - while (names.hasMoreElements()) { - String name = names.nextElement(); - Enumeration values = request.getHeaders(name); - - while (values.hasMoreElements()) { - this.addHeader(name, values.nextElement()); - } - } - - // Locales - Enumeration locales = request.getLocales(); - - while (locales.hasMoreElements()) { - Locale locale = (Locale) locales.nextElement(); - this.addLocale(locale); - } - - // Parameters - Map parameters = request.getParameterMap(); - - for(String paramName : parameters.keySet()) { - Object paramValues = parameters.get(paramName); - if (paramValues instanceof String[]) { - this.addParameter(paramName, (String[]) paramValues); - } else { - if (logger.isWarnEnabled()) { - logger.warn("ServletRequest.getParameterMap() returned non-String array"); - } - } - } - - // Primitives - this.method = request.getMethod(); - this.pathInfo = request.getPathInfo(); - this.queryString = request.getQueryString(); - this.requestURI = request.getRequestURI(); - this.serverPort = portResolver.getServerPort(request); - this.requestURL = request.getRequestURL().toString(); - this.scheme = request.getScheme(); - this.serverName = request.getServerName(); - this.contextPath = request.getContextPath(); - this.servletPath = request.getServletPath(); - } - - //~ Methods ======================================================================================================== - - private void addCookie(Cookie cookie) { - cookies.add(new SavedCookie(cookie)); - } - - private void addHeader(String name, String value) { - List values = headers.get(name); - - if (values == null) { - values = new ArrayList(); - headers.put(name, values); - } - - values.add(value); - } - - private void addLocale(Locale locale) { - locales.add(locale); - } - - private void addParameter(String name, String[] values) { - parameters.put(name, values); - } - - /** - * Determines if the current request matches the SavedRequest. All URL arguments are - * considered, but not method (POST/GET), cookies, locales, headers or parameters. - */ - public boolean doesRequestMatch(HttpServletRequest request, PortResolver portResolver) { - Assert.notNull(request, "Request required"); - Assert.notNull(portResolver, "PortResolver required"); - - if (!propertyEquals("pathInfo", this.pathInfo, request.getPathInfo())) { - return false; - } - - if (!propertyEquals("queryString", this.queryString, request.getQueryString())) { - return false; - } - - if (!propertyEquals("requestURI", this.requestURI, request.getRequestURI())) { - return false; - } - - if (!propertyEquals("serverPort", new Integer(this.serverPort), new Integer(portResolver.getServerPort(request)))) - { - return false; - } - - if (!propertyEquals("requestURL", this.requestURL, request.getRequestURL().toString())) { - return false; - } - - if (!propertyEquals("scheme", this.scheme, request.getScheme())) { - return false; - } - - if (!propertyEquals("serverName", this.serverName, request.getServerName())) { - return false; - } - - if (!propertyEquals("contextPath", this.contextPath, request.getContextPath())) { - return false; - } - - if (!propertyEquals("servletPath", this.servletPath, request.getServletPath())) { - return false; - } - - return true; - } - - public String getContextPath() { - return contextPath; - } - - public List getCookies() { - List cookieList = new ArrayList(cookies.size()); - - for (SavedCookie savedCookie : cookies) { - cookieList.add(savedCookie.getCookie()); - } - - return cookieList; - } - - /** - * Indicates the URL that the user agent used for this request. - * - * @return the full URL of this request - */ - public String getFullRequestUrl() { - return UrlUtils.buildFullRequestUrl(this.getScheme(), this.getServerName(), this.getServerPort(), this.getContextPath(), - this.getServletPath(), this.getRequestURI(), this.getPathInfo(), this.getQueryString()); - } - - public Iterator getHeaderNames() { - return (headers.keySet().iterator()); - } - - public Iterator getHeaderValues(String name) { - List values = headers.get(name); - - if (values == null) { - values = Collections.emptyList(); - } - - return (values.iterator()); - } - - public Iterator getLocales() { - return (locales.iterator()); - } - - public String getMethod() { - return method; - } - - public Map getParameterMap() { - return parameters; - } - - public Iterator getParameterNames() { - return (parameters.keySet().iterator()); - } - - public String[] getParameterValues(String name) { - return ((String[]) parameters.get(name)); - } - - public String getPathInfo() { - return pathInfo; - } - - public String getQueryString() { - return (this.queryString); - } - - public String getRequestURI() { - return (this.requestURI); - } - - public String getRequestURL() { - return requestURL; - } - - /** - * Obtains the web application-specific fragment of the URL. - * - * @return the URL, excluding any server name, context path or servlet path - */ - public String getRequestUrl() { - return UrlUtils.buildRequestUrl(this.getServletPath(), this.getRequestURI(), this.getContextPath(), this.getPathInfo(), - this.getQueryString()); - } - - public String getScheme() { - return scheme; - } - - public String getServerName() { - return serverName; - } - - public int getServerPort() { - return serverPort; - } - - public String getServletPath() { - return servletPath; - } - - private boolean propertyEquals(String log, Object arg1, Object arg2) { - if ((arg1 == null) && (arg2 == null)) { - if (logger.isDebugEnabled()) { - logger.debug(log + ": both null (property equals)"); - } - - return true; - } - - if (((arg1 == null) && (arg2 != null)) || ((arg1 != null) && (arg2 == null))) { - if (logger.isDebugEnabled()) { - logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2 + " (property not equals)"); - } - - return false; - } - - if (arg1.equals(arg2)) { - if (logger.isDebugEnabled()) { - logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2 + " (property equals)"); - } - - return true; - } else { - if (logger.isDebugEnabled()) { - logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2 + " (property not equals)"); - } - - return false; - } - } - - public String toString() { - return "SavedRequest[" + getFullRequestUrl() + "]"; - } -} +package org.springframework.security.web.savedrequest; + +/** + * Encapsulates the functionality required of a cached request, in order for an authentication mechanism (typically + * form-based login) to redirect to the original URL. + * + * @author Luke Taylor + * @version $Id$ + * @since 3.0 + */ +public interface SavedRequest extends java.io.Serializable { + + /** + * @return the URL for the saved request, allowing a redirect to be performed. + */ + String getRedirectUrl(); +} diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java b/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java index 66f653977c..e8c0163001 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java @@ -65,7 +65,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { //~ Instance fields ================================================================================================ - protected SavedRequest savedRequest = null; + protected DefaultSavedRequest savedRequest = null; /** * The set of SimpleDateFormat formats to use in getDateHeader(). Notice that because SimpleDateFormat is @@ -75,7 +75,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { //~ Constructors =================================================================================================== - public SavedRequestAwareWrapper(SavedRequest saved, HttpServletRequest request) { + public SavedRequestAwareWrapper(DefaultSavedRequest saved, HttpServletRequest request) { super(request); savedRequest = saved; @@ -234,7 +234,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { * In both cases the value from the wrapped request should be used. *

* If the value from the wrapped request is null, an attempt will be made to retrieve the parameter - * from the SavedRequest, if available.. + * from the DefaultSavedRequest, if available.. */ @Override public String getParameter(String name) { diff --git a/web/src/main/java/org/springframework/security/web/session/DefaultSessionAuthenticationStrategy.java b/web/src/main/java/org/springframework/security/web/session/DefaultSessionAuthenticationStrategy.java index 2be9f56819..bd7bd38788 100644 --- a/web/src/main/java/org/springframework/security/web/session/DefaultSessionAuthenticationStrategy.java +++ b/web/src/main/java/org/springframework/security/web/session/DefaultSessionAuthenticationStrategy.java @@ -13,7 +13,7 @@ import javax.servlet.http.HttpSession; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.security.core.Authentication; -import org.springframework.security.web.savedrequest.SavedRequest; +import org.springframework.security.web.savedrequest.DefaultSavedRequest; /** * The default implementation of {@link SessionAuthenticationStrategy}. @@ -45,7 +45,7 @@ public class DefaultSessionAuthenticationStrategy implements SessionAuthenticati * In the case where the attributes will not be migrated, this field allows a list of named attributes * which should not be discarded. */ - private List retainedAttributes = Arrays.asList(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY); + private List retainedAttributes = Arrays.asList(DefaultSavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY); /** * If set to true, a session will always be created, even if one didn't exist at the start of the request. diff --git a/web/src/test/java/org/springframework/security/web/access/ExceptionTranslationFilterTests.java b/web/src/test/java/org/springframework/security/web/access/ExceptionTranslationFilterTests.java index 18515e8aec..8828eb90a4 100644 --- a/web/src/test/java/org/springframework/security/web/access/ExceptionTranslationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/access/ExceptionTranslationFilterTests.java @@ -42,7 +42,7 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.savedrequest.HttpSessionRequestCache; -import org.springframework.security.web.savedrequest.SavedRequest; +import org.springframework.security.web.savedrequest.DefaultSavedRequest; import org.springframework.security.web.util.ThrowableAnalyzer; /** @@ -66,9 +66,9 @@ public class ExceptionTranslationFilterTests { return null; } - SavedRequest savedRequest = (SavedRequest) session.getAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY); + DefaultSavedRequest savedRequest = (DefaultSavedRequest) session.getAttribute(DefaultSavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY); - return savedRequest.getFullRequestUrl(); + return savedRequest.getRedirectUrl(); } @Test @@ -199,7 +199,7 @@ public class ExceptionTranslationFilterTests { doThrow(new BadCredentialsException("")).when(fc).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); request.setMethod("POST"); filter.doFilter(request, new MockHttpServletResponse(), fc); - assertTrue(request.getSession().getAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY) == null); + assertTrue(request.getSession().getAttribute(DefaultSavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY) == null); } @Test(expected=IllegalArgumentException.class) diff --git a/web/src/test/java/org/springframework/security/web/authentication/AbstractProcessingFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/AbstractProcessingFilterTests.java index a116801aca..e34aace005 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/AbstractProcessingFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/AbstractProcessingFilterTests.java @@ -49,7 +49,7 @@ import org.springframework.security.web.authentication.ExceptionMappingAuthentic import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler; import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler; import org.springframework.security.web.authentication.rememberme.TokenBasedRememberMeServices; -import org.springframework.security.web.savedrequest.SavedRequest; +import org.springframework.security.web.savedrequest.DefaultSavedRequest; import org.springframework.security.web.session.SessionAuthenticationStrategy; @@ -83,7 +83,7 @@ public class AbstractProcessingFilterTests extends TestCase { filter.destroy(); } - private SavedRequest makeSavedRequestForUrl() { + private DefaultSavedRequest makeSavedRequestForUrl() { MockHttpServletRequest request = createMockRequest(); request.setMethod("GET"); request.setServletPath("/some_protected_file.html"); @@ -91,10 +91,10 @@ public class AbstractProcessingFilterTests extends TestCase { request.setServerName("www.example.com"); request.setRequestURI("/mycontext/some_protected_file.html"); - return new SavedRequest(request, new PortResolverImpl()); + return new DefaultSavedRequest(request, new PortResolverImpl()); } -// private SavedRequest makePostSavedRequestForUrl() { +// private DefaultSavedRequest makePostSavedRequestForUrl() { // MockHttpServletRequest request = createMockRequest(); // request.setServletPath("/some_protected_file.html"); // request.setScheme("http"); @@ -102,7 +102,7 @@ public class AbstractProcessingFilterTests extends TestCase { // request.setRequestURI("/mycontext/post/some_protected_file.html"); // request.setMethod("POST"); // -// return new SavedRequest(request, new PortResolverImpl()); +// return new DefaultSavedRequest(request, new PortResolverImpl()); // } protected void setUp() throws Exception { @@ -327,7 +327,7 @@ public class AbstractProcessingFilterTests extends TestCase { throws Exception { // Setup our HTTP request MockHttpServletRequest request = createMockRequest(); - request.getSession().setAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY, makeSavedRequestForUrl()); + request.getSession().setAttribute(DefaultSavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY, makeSavedRequestForUrl()); // Setup our filter configuration MockFilterConfig config = new MockFilterConfig(null, null); @@ -352,7 +352,7 @@ public class AbstractProcessingFilterTests extends TestCase { public void testSuccessfulAuthenticationCausesRedirectToSessionSpecifiedUrl() throws Exception { // Setup our HTTP request MockHttpServletRequest request = createMockRequest(); - request.getSession().setAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY, makeSavedRequestForUrl()); + request.getSession().setAttribute(DefaultSavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY, makeSavedRequestForUrl()); // Setup our filter configuration MockFilterConfig config = new MockFilterConfig(null, null); @@ -367,7 +367,7 @@ public class AbstractProcessingFilterTests extends TestCase { // Test executeFilterInContainerSimulator(config, filter, request, response, chain); - assertEquals(makeSavedRequestForUrl().getFullRequestUrl(), response.getRedirectedUrl()); + assertEquals(makeSavedRequestForUrl().getRedirectUrl(), response.getRedirectedUrl()); assertNotNull(SecurityContextHolder.getContext().getAuthentication()); } diff --git a/web/src/test/java/org/springframework/security/web/savedrequest/HttpSessionRequestCacheTests.java b/web/src/test/java/org/springframework/security/web/savedrequest/HttpSessionRequestCacheTests.java new file mode 100644 index 0000000000..802f8939fa --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/savedrequest/HttpSessionRequestCacheTests.java @@ -0,0 +1,33 @@ +package org.springframework.security.web.savedrequest; + +import static org.junit.Assert.*; + +import org.junit.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; + +/** + * + * @author Luke Taylor + * @version $Id$ + * @since 3.0 + */ +public class HttpSessionRequestCacheTests { + + @Test + public void originalGetRequestDoesntMatchIncomingPost() { + HttpSessionRequestCache cache = new HttpSessionRequestCache(); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/destination"); + MockHttpServletResponse response = new MockHttpServletResponse(); + cache.saveRequest(request, response); + assertNotNull(request.getSession().getAttribute(DefaultSavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY)); + assertNotNull(cache.getRequest(request, response)); + + MockHttpServletRequest newRequest = new MockHttpServletRequest("POST", "/destination"); + newRequest.setSession(request.getSession()); + assertNull(cache.getMatchingRequest(newRequest, response)); + + } + +} diff --git a/web/src/test/java/org/springframework/security/web/savedrequest/RequestCacheAwareFilterTests.java b/web/src/test/java/org/springframework/security/web/savedrequest/RequestCacheAwareFilterTests.java index 8e21858033..7c0b5b6311 100644 --- a/web/src/test/java/org/springframework/security/web/savedrequest/RequestCacheAwareFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/savedrequest/RequestCacheAwareFilterTests.java @@ -9,7 +9,6 @@ import org.springframework.mock.web.MockHttpServletResponse; public class RequestCacheAwareFilterTests { - @Test public void savedRequestIsRemovedAfterMatch() throws Exception { RequestCacheAwareFilter filter = new RequestCacheAwareFilter(); @@ -18,10 +17,9 @@ public class RequestCacheAwareFilterTests { MockHttpServletRequest request = new MockHttpServletRequest("POST", "/destination"); MockHttpServletResponse response = new MockHttpServletResponse(); cache.saveRequest(request, response); - assertNotNull(request.getSession().getAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY)); + assertNotNull(request.getSession().getAttribute(DefaultSavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY)); filter.doFilter(request, response, new MockFilterChain()); - assertNull(request.getSession().getAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY)); + assertNull(request.getSession().getAttribute(DefaultSavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY)); } - } diff --git a/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapperTests.java b/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapperTests.java index d8612e5000..2d0412918c 100644 --- a/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapperTests.java +++ b/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapperTests.java @@ -13,13 +13,13 @@ import org.junit.Test; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.web.PortResolverImpl; import org.springframework.security.web.savedrequest.FastHttpDateFormat; -import org.springframework.security.web.savedrequest.SavedRequest; +import org.springframework.security.web.savedrequest.DefaultSavedRequest; import org.springframework.security.web.savedrequest.SavedRequestAwareWrapper; public class SavedRequestAwareWrapperTests { private SavedRequestAwareWrapper createWrapper(MockHttpServletRequest requestToSave, MockHttpServletRequest requestToWrap) { - SavedRequest saved = requestToSave == null ? null : new SavedRequest(requestToSave, new PortResolverImpl()); + DefaultSavedRequest saved = requestToSave == null ? null : new DefaultSavedRequest(requestToSave, new PortResolverImpl()); return new SavedRequestAwareWrapper(saved, requestToWrap); } diff --git a/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestTests.java b/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestTests.java index 60b1f5f241..ee98193587 100644 --- a/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestTests.java +++ b/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestTests.java @@ -4,7 +4,7 @@ import static org.junit.Assert.*; import org.junit.Test; import org.springframework.security.MockPortResolver; -import org.springframework.security.web.savedrequest.SavedRequest; +import org.springframework.security.web.savedrequest.DefaultSavedRequest; import org.springframework.mock.web.MockHttpServletRequest; /** @@ -16,7 +16,7 @@ public class SavedRequestTests { public void headersAreCaseInsensitive() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("USER-aGenT", "Mozilla"); - SavedRequest saved = new SavedRequest(request, new MockPortResolver(8080, 8443)); + DefaultSavedRequest saved = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443)); assertEquals("Mozilla", saved.getHeaderValues("user-agent").next()); } @@ -25,7 +25,7 @@ public class SavedRequestTests { public void parametersAreCaseInsensitive() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); request.addParameter("ThisIsATest", "Hi mom"); - SavedRequest saved = new SavedRequest(request, new MockPortResolver(8080, 8443)); + DefaultSavedRequest saved = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443)); assertEquals("Hi mom", saved.getParameterValues("thisisatest")[0]); } } diff --git a/web/src/test/java/org/springframework/security/web/session/DefaultSessionAuthenticationStrategyTests.java b/web/src/test/java/org/springframework/security/web/session/DefaultSessionAuthenticationStrategyTests.java index 756f980968..5119a384a0 100644 --- a/web/src/test/java/org/springframework/security/web/session/DefaultSessionAuthenticationStrategyTests.java +++ b/web/src/test/java/org/springframework/security/web/session/DefaultSessionAuthenticationStrategyTests.java @@ -10,7 +10,7 @@ import org.junit.Test; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.core.Authentication; -import org.springframework.security.web.savedrequest.SavedRequest; +import org.springframework.security.web.savedrequest.DefaultSavedRequest; /** * @@ -48,12 +48,12 @@ public class DefaultSessionAuthenticationStrategyTests { HttpServletRequest request = new MockHttpServletRequest(); HttpSession session = request.getSession(); session.setAttribute("blah", "blah"); - session.setAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY, "SavedRequest"); + session.setAttribute(DefaultSavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY, "DefaultSavedRequest"); strategy.onAuthentication(mock(Authentication.class), request, new MockHttpServletResponse()); assertNull(request.getSession().getAttribute("blah")); - assertNotNull(request.getSession().getAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY)); + assertNotNull(request.getSession().getAttribute(DefaultSavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY)); } @Test