diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java b/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java index 41921bb4da..36b4af4e09 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java @@ -15,23 +15,24 @@ package org.springframework.security.web.savedrequest; -import org.springframework.security.web.PortResolver; -import org.springframework.security.web.util.UrlUtils; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.springframework.util.Assert; - -import javax.servlet.http.Cookie; -import javax.servlet.http.HttpServletRequest; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.Enumeration; -import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.TreeMap; +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletRequest; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.security.web.PortResolver; +import org.springframework.security.web.util.UrlUtils; +import org.springframework.util.Assert; + /** * Represents central information from a HttpServletRequest.

This class is used by {@link @@ -237,22 +238,22 @@ public class DefaultSavedRequest implements SavedRequest { pathInfo, queryString); } - public Iterator getHeaderNames() { - return (headers.keySet().iterator()); + public Collection getHeaderNames() { + return headers.keySet(); } - public Iterator getHeaderValues(String name) { + public List getHeaderValues(String name) { List values = headers.get(name); if (values == null) { - values = Collections.emptyList(); + return Collections.emptyList(); } - return (values.iterator()); + return values; } - public Iterator getLocales() { - return (locales.iterator()); + public List getLocales() { + return locales; } public String getMethod() { @@ -263,8 +264,8 @@ public class DefaultSavedRequest implements SavedRequest { return parameters; } - public Iterator getParameterNames() { - return (parameters.keySet().iterator()); + public Collection getParameterNames() { + return parameters.keySet(); } public String[] getParameterValues(String name) { diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequest.java b/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequest.java index 8f15e2d6ef..9d1307db59 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequest.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequest.java @@ -1,8 +1,16 @@ package org.springframework.security.web.savedrequest; +import java.util.Collection; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import javax.servlet.http.Cookie; + /** - * Encapsulates the functionality required of a cached request, in order for an authentication mechanism (typically - * form-based login) to redirect to the original URL. + * Encapsulates the functionality required of a cached request for both an authentication mechanism (typically + * form-based login) to redirect to the original URL and for a RequestCache to build a wrapped request, + * reproducing the original request data. * * @author Luke Taylor * @version $Id$ @@ -14,4 +22,18 @@ public interface SavedRequest extends java.io.Serializable { * @return the URL for the saved request, allowing a redirect to be performed. */ String getRedirectUrl(); + + List getCookies(); + + String getMethod(); + + List getHeaderValues(String name); + + Collection getHeaderNames(); + + List getLocales(); + + String[] getParameterValues(String name); + + Map getParameterMap(); } diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java b/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java index e8c0163001..3e5c452118 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java @@ -21,7 +21,6 @@ import java.util.Arrays; import java.util.Enumeration; import java.util.HashMap; import java.util.HashSet; -import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Map; @@ -45,9 +44,6 @@ import org.apache.commons.logging.LogFactory; * *

* Added into a request by {@link org.springframework.security.web.savedrequest.RequestCacheAwareFilter}. - *

- * - * TODO: savedRequest cannot now be null, so convert the tests to reflect this and remove the null checks. * * @author Andrey Grebnev * @author Ben Alex @@ -65,7 +61,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { //~ Instance fields ================================================================================================ - protected DefaultSavedRequest savedRequest = null; + protected SavedRequest savedRequest = null; /** * The set of SimpleDateFormat formats to use in getDateHeader(). Notice that because SimpleDateFormat is @@ -75,7 +71,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { //~ Constructors =================================================================================================== - public SavedRequestAwareWrapper(DefaultSavedRequest saved, HttpServletRequest request) { + public SavedRequestAwareWrapper(SavedRequest saved, HttpServletRequest request) { super(request); savedRequest = saved; @@ -92,9 +88,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { @Override public Cookie[] getCookies() { - if (savedRequest == null) { - return super.getCookies(); - } List cookies = savedRequest.getCookies(); return cookies.toArray(new Cookie[cookies.size()]); @@ -102,9 +95,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { @Override public long getDateHeader(String name) { - if (savedRequest == null) { - return super.getDateHeader(name); - } String value = getHeader(name); if (value == null) { @@ -123,128 +113,79 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { @Override public String getHeader(String name) { - if (savedRequest == null) { - return super.getHeader(name); - } + List values = savedRequest.getHeaderValues(name); - String header = null; - Iterator iterator = savedRequest.getHeaderValues(name); - - while (iterator.hasNext()) { - header = iterator.next(); - - break; - } - - return header; + return values.isEmpty() ? null : values.get(0); } @Override @SuppressWarnings("unchecked") public Enumeration getHeaderNames() { - if (savedRequest == null) { - return super.getHeaderNames(); - } - return new Enumerator(savedRequest.getHeaderNames()); } @Override @SuppressWarnings("unchecked") public Enumeration getHeaders(String name) { - if (savedRequest == null) { - return super.getHeaders(name); - } else { - return new Enumerator(savedRequest.getHeaderValues(name)); - } + return new Enumerator(savedRequest.getHeaderValues(name)); } @Override public int getIntHeader(String name) { - if (savedRequest == null) { - return super.getIntHeader(name); - } else { - String value = getHeader(name); + String value = getHeader(name); - if (value == null) { - return -1; - } else { - return Integer.parseInt(value); - } + if (value == null) { + return -1; + } else { + return Integer.parseInt(value); } } @Override public Locale getLocale() { - if (savedRequest == null) { - return super.getLocale(); - } else { - Locale locale = null; - Iterator iterator = savedRequest.getLocales(); + List locales = savedRequest.getLocales(); - while (iterator.hasNext()) { - locale = (Locale) iterator.next(); - - break; - } - - if (locale == null) { - return defaultLocale; - } else { - return locale; - } - } + return locales.isEmpty() ? Locale.getDefault() : locales.get(0); } @Override @SuppressWarnings("unchecked") public Enumeration getLocales() { - if (savedRequest == null) { - return super.getLocales(); + List locales = savedRequest.getLocales(); + + if (locales.isEmpty()) { + // Fall back to default locale + locales = new ArrayList(1); + locales.add(Locale.getDefault()); } - Iterator iterator = savedRequest.getLocales(); - - if (iterator.hasNext()) { - return new Enumerator(iterator); - } - // Fall back to default locale - ArrayList results = new ArrayList(1); - results.add(defaultLocale); - - return new Enumerator(results.iterator()); + return new Enumerator(locales); } @Override public String getMethod() { - if (savedRequest == null) { - return super.getMethod(); - } else { - return savedRequest.getMethod(); - } + return savedRequest.getMethod(); } /** - * If the parameter is available from the wrapped request then either - *
    - *
  1. There is no saved request (it a normal request)
  2. - *
  3. There is a saved request, but the request has been forwarded/included to a URL with parameters, either - * supplementing or overriding the saved request values.
  4. - *
- * In both cases the value from the wrapped request should be used. + * If the parameter is available from the wrapped request then the request has been forwarded/included to a URL + * with parameters, either supplementing or overriding the saved request values. + *

+ * In this case, the value from the wrapped request should be used. *

* If the value from the wrapped request is null, an attempt will be made to retrieve the parameter - * from the DefaultSavedRequest, if available.. + * from the saved request. */ @Override public String getParameter(String name) { String value = super.getParameter(name); - if (value != null || savedRequest == null) { + if (value != null) { return value; } String[] values = savedRequest.getParameterValues(name); + if (values == null || values.length == 0) { return null; } @@ -255,10 +196,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { @Override @SuppressWarnings("unchecked") public Map getParameterMap() { - if (savedRequest == null) { - return super.getParameterMap(); - } - Set names = getCombinedParameterNames(); Map parameterMap = new HashMap(names.size()); @@ -273,10 +210,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { private Set getCombinedParameterNames() { Set names = new HashSet(); names.addAll(super.getParameterMap().keySet()); - - if (savedRequest != null) { - names.addAll(savedRequest.getParameterMap().keySet()); - } + names.addAll(savedRequest.getParameterMap().keySet()); return names; } @@ -289,10 +223,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { @Override public String[] getParameterValues(String name) { - if (savedRequest == null) { - return super.getParameterValues(name); - } - String[] savedRequestParams = savedRequest.getParameterValues(name); String[] wrappedRequestParams = super.getParameterValues(name); diff --git a/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestTests.java b/web/src/test/java/org/springframework/security/web/savedrequest/DefaultSavedRequestTests.java similarity index 95% rename from web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestTests.java rename to web/src/test/java/org/springframework/security/web/savedrequest/DefaultSavedRequestTests.java index ee98193587..47a33ab526 100644 --- a/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestTests.java +++ b/web/src/test/java/org/springframework/security/web/savedrequest/DefaultSavedRequestTests.java @@ -10,14 +10,14 @@ import org.springframework.mock.web.MockHttpServletRequest; /** * */ -public class SavedRequestTests { +public class DefaultSavedRequestTests { @Test public void headersAreCaseInsensitive() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("USER-aGenT", "Mozilla"); DefaultSavedRequest saved = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443)); - assertEquals("Mozilla", saved.getHeaderValues("user-agent").next()); + assertEquals("Mozilla", saved.getHeaderValues("user-agent").get(0)); } // TODO: Why are parameters case insensitive. I think this is a mistake diff --git a/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapperTests.java b/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapperTests.java index 2d0412918c..a1563b4809 100644 --- a/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapperTests.java +++ b/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapperTests.java @@ -19,19 +19,10 @@ import org.springframework.security.web.savedrequest.SavedRequestAwareWrapper; public class SavedRequestAwareWrapperTests { private SavedRequestAwareWrapper createWrapper(MockHttpServletRequest requestToSave, MockHttpServletRequest requestToWrap) { - DefaultSavedRequest saved = requestToSave == null ? null : new DefaultSavedRequest(requestToSave, new PortResolverImpl()); + DefaultSavedRequest saved = new DefaultSavedRequest(requestToSave, new PortResolverImpl()); return new SavedRequestAwareWrapper(saved, requestToWrap); } - @Test - public void wrappedRequestCookiesAreReturnedIfNoSavedRequestIsSet() throws Exception { - MockHttpServletRequest wrappedRequest = new MockHttpServletRequest(); - wrappedRequest.setCookies(new Cookie[] {new Cookie("cookie", "fromwrapped")}); - SavedRequestAwareWrapper wrapper = createWrapper(null, wrappedRequest); - assertEquals(1, wrapper.getCookies().length); - assertEquals("fromwrapped", wrapper.getCookies()[0].getValue()); - } - @Test public void savedRequestCookiesAreReturnedIfSavedRequestIsSet() throws Exception { MockHttpServletRequest savedRequest = new MockHttpServletRequest(); @@ -61,27 +52,6 @@ public class SavedRequestAwareWrapperTests { assertEquals("header", wrapper.getHeaderNames().nextElement()); } - @Test - @SuppressWarnings("unchecked") - public void wrappedRequestHeaderIsReturnedIfSavedRequestIsNotSet() throws Exception { - MockHttpServletRequest wrappedRequest = new MockHttpServletRequest(); - wrappedRequest.addHeader("header", "wrappedheader"); - SavedRequestAwareWrapper wrapper = createWrapper(null, wrappedRequest); - - assertNull(wrapper.getHeader("nonexistent")); - Enumeration headers = wrapper.getHeaders("nonexistent"); - assertFalse(headers.hasMoreElements()); - - assertEquals("wrappedheader", wrapper.getHeader("header")); - headers = wrapper.getHeaders("header"); - assertTrue(headers.hasMoreElements()); - assertEquals("wrappedheader", headers.nextElement()); - assertFalse(headers.hasMoreElements()); - assertTrue(wrapper.getHeaderNames().hasMoreElements()); - assertEquals("header", wrapper.getHeaderNames().nextElement()); - } - - @Test /* SEC-830. Assume we have a request to /someUrl?action=foo (the saved request) * and then RequestDispatcher.forward() it to /someUrl?action=bar. @@ -125,8 +95,7 @@ public class SavedRequestAwareWrapperTests { @Test public void getParameterValuesReturnsNullIfParameterIsntSet() { - MockHttpServletRequest wrappedRequest = new MockHttpServletRequest(); - SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(null, wrappedRequest); + SavedRequestAwareWrapper wrapper = createWrapper(new MockHttpServletRequest(), new MockHttpServletRequest()); assertNull(wrapper.getParameterValues("action")); assertNull(wrapper.getParameterMap().get("action")); } @@ -148,7 +117,7 @@ public class SavedRequestAwareWrapperTests { } @Test - public void expecteDateHeaderIsReturnedFromSavedAndWrappedRequests() throws Exception { + public void expecteDateHeaderIsReturnedFromSavedRequest() throws Exception { SimpleDateFormat formatter = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US); String nowString = FastHttpDateFormat.getCurrentDate(); Date now = formatter.parse(nowString); @@ -158,12 +127,6 @@ public class SavedRequestAwareWrapperTests { assertEquals(now.getTime(), wrapper.getDateHeader("header")); assertEquals(-1L, wrapper.getDateHeader("nonexistent")); - - // Now try with no saved request - request = new MockHttpServletRequest(); - request.addHeader("header", now); - wrapper = createWrapper(null, request); - assertEquals(now.getTime(), wrapper.getDateHeader("header")); } @Test(expected=IllegalArgumentException.class) @@ -179,8 +142,6 @@ public class SavedRequestAwareWrapperTests { MockHttpServletRequest request = new MockHttpServletRequest("PUT", "/notused"); SavedRequestAwareWrapper wrapper = createWrapper(request, new MockHttpServletRequest("GET", "/notused")); assertEquals("PUT", wrapper.getMethod()); - wrapper = createWrapper(null, request); - assertEquals("PUT", wrapper.getMethod()); } @Test @@ -192,9 +153,6 @@ public class SavedRequestAwareWrapperTests { assertEquals(999, wrapper.getIntHeader("header")); assertEquals(-1, wrapper.getIntHeader("nonexistent")); - - wrapper = createWrapper(null, request); - assertEquals(999, wrapper.getIntHeader("header")); } }