diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java b/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java
index 41921bb4da..36b4af4e09 100644
--- a/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java
+++ b/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java
@@ -15,23 +15,24 @@
package org.springframework.security.web.savedrequest;
-import org.springframework.security.web.PortResolver;
-import org.springframework.security.web.util.UrlUtils;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.springframework.util.Assert;
-
-import javax.servlet.http.Cookie;
-import javax.servlet.http.HttpServletRequest;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
-import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;
+import javax.servlet.http.Cookie;
+import javax.servlet.http.HttpServletRequest;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.springframework.security.web.PortResolver;
+import org.springframework.security.web.util.UrlUtils;
+import org.springframework.util.Assert;
+
/**
* Represents central information from a HttpServletRequest.
This class is used by {@link
@@ -237,22 +238,22 @@ public class DefaultSavedRequest implements SavedRequest {
pathInfo, queryString);
}
- public Iterator getHeaderNames() {
- return (headers.keySet().iterator());
+ public Collection getHeaderNames() {
+ return headers.keySet();
}
- public Iterator getHeaderValues(String name) {
+ public List getHeaderValues(String name) {
List values = headers.get(name);
if (values == null) {
- values = Collections.emptyList();
+ return Collections.emptyList();
}
- return (values.iterator());
+ return values;
}
- public Iterator getLocales() {
- return (locales.iterator());
+ public List getLocales() {
+ return locales;
}
public String getMethod() {
@@ -263,8 +264,8 @@ public class DefaultSavedRequest implements SavedRequest {
return parameters;
}
- public Iterator getParameterNames() {
- return (parameters.keySet().iterator());
+ public Collection getParameterNames() {
+ return parameters.keySet();
}
public String[] getParameterValues(String name) {
diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequest.java b/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequest.java
index 8f15e2d6ef..9d1307db59 100644
--- a/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequest.java
+++ b/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequest.java
@@ -1,8 +1,16 @@
package org.springframework.security.web.savedrequest;
+import java.util.Collection;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+
+import javax.servlet.http.Cookie;
+
/**
- * Encapsulates the functionality required of a cached request, in order for an authentication mechanism (typically
- * form-based login) to redirect to the original URL.
+ * Encapsulates the functionality required of a cached request for both an authentication mechanism (typically
+ * form-based login) to redirect to the original URL and for a RequestCache to build a wrapped request,
+ * reproducing the original request data.
*
* @author Luke Taylor
* @version $Id$
@@ -14,4 +22,18 @@ public interface SavedRequest extends java.io.Serializable {
* @return the URL for the saved request, allowing a redirect to be performed.
*/
String getRedirectUrl();
+
+ List getCookies();
+
+ String getMethod();
+
+ List getHeaderValues(String name);
+
+ Collection getHeaderNames();
+
+ List getLocales();
+
+ String[] getParameterValues(String name);
+
+ Map getParameterMap();
}
diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java b/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java
index e8c0163001..3e5c452118 100644
--- a/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java
+++ b/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java
@@ -21,7 +21,6 @@ import java.util.Arrays;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
-import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
@@ -45,9 +44,6 @@ import org.apache.commons.logging.LogFactory;
*
*
* Added into a request by {@link org.springframework.security.web.savedrequest.RequestCacheAwareFilter}.
- *
- *
- * TODO: savedRequest cannot now be null, so convert the tests to reflect this and remove the null checks.
*
* @author Andrey Grebnev
* @author Ben Alex
@@ -65,7 +61,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
//~ Instance fields ================================================================================================
- protected DefaultSavedRequest savedRequest = null;
+ protected SavedRequest savedRequest = null;
/**
* The set of SimpleDateFormat formats to use in getDateHeader(). Notice that because SimpleDateFormat is
@@ -75,7 +71,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
//~ Constructors ===================================================================================================
- public SavedRequestAwareWrapper(DefaultSavedRequest saved, HttpServletRequest request) {
+ public SavedRequestAwareWrapper(SavedRequest saved, HttpServletRequest request) {
super(request);
savedRequest = saved;
@@ -92,9 +88,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
@Override
public Cookie[] getCookies() {
- if (savedRequest == null) {
- return super.getCookies();
- }
List cookies = savedRequest.getCookies();
return cookies.toArray(new Cookie[cookies.size()]);
@@ -102,9 +95,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
@Override
public long getDateHeader(String name) {
- if (savedRequest == null) {
- return super.getDateHeader(name);
- }
String value = getHeader(name);
if (value == null) {
@@ -123,128 +113,79 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
@Override
public String getHeader(String name) {
- if (savedRequest == null) {
- return super.getHeader(name);
- }
+ List values = savedRequest.getHeaderValues(name);
- String header = null;
- Iterator iterator = savedRequest.getHeaderValues(name);
-
- while (iterator.hasNext()) {
- header = iterator.next();
-
- break;
- }
-
- return header;
+ return values.isEmpty() ? null : values.get(0);
}
@Override
@SuppressWarnings("unchecked")
public Enumeration getHeaderNames() {
- if (savedRequest == null) {
- return super.getHeaderNames();
- }
-
return new Enumerator(savedRequest.getHeaderNames());
}
@Override
@SuppressWarnings("unchecked")
public Enumeration getHeaders(String name) {
- if (savedRequest == null) {
- return super.getHeaders(name);
- } else {
- return new Enumerator(savedRequest.getHeaderValues(name));
- }
+ return new Enumerator(savedRequest.getHeaderValues(name));
}
@Override
public int getIntHeader(String name) {
- if (savedRequest == null) {
- return super.getIntHeader(name);
- } else {
- String value = getHeader(name);
+ String value = getHeader(name);
- if (value == null) {
- return -1;
- } else {
- return Integer.parseInt(value);
- }
+ if (value == null) {
+ return -1;
+ } else {
+ return Integer.parseInt(value);
}
}
@Override
public Locale getLocale() {
- if (savedRequest == null) {
- return super.getLocale();
- } else {
- Locale locale = null;
- Iterator iterator = savedRequest.getLocales();
+ List locales = savedRequest.getLocales();
- while (iterator.hasNext()) {
- locale = (Locale) iterator.next();
-
- break;
- }
-
- if (locale == null) {
- return defaultLocale;
- } else {
- return locale;
- }
- }
+ return locales.isEmpty() ? Locale.getDefault() : locales.get(0);
}
@Override
@SuppressWarnings("unchecked")
public Enumeration getLocales() {
- if (savedRequest == null) {
- return super.getLocales();
+ List locales = savedRequest.getLocales();
+
+ if (locales.isEmpty()) {
+ // Fall back to default locale
+ locales = new ArrayList(1);
+ locales.add(Locale.getDefault());
}
- Iterator iterator = savedRequest.getLocales();
-
- if (iterator.hasNext()) {
- return new Enumerator(iterator);
- }
- // Fall back to default locale
- ArrayList results = new ArrayList(1);
- results.add(defaultLocale);
-
- return new Enumerator(results.iterator());
+ return new Enumerator(locales);
}
@Override
public String getMethod() {
- if (savedRequest == null) {
- return super.getMethod();
- } else {
- return savedRequest.getMethod();
- }
+ return savedRequest.getMethod();
}
/**
- * If the parameter is available from the wrapped request then either
- *
- * - There is no saved request (it a normal request)
- * - There is a saved request, but the request has been forwarded/included to a URL with parameters, either
- * supplementing or overriding the saved request values.
- *
- * In both cases the value from the wrapped request should be used.
+ * If the parameter is available from the wrapped request then the request has been forwarded/included to a URL
+ * with parameters, either supplementing or overriding the saved request values.
+ *
+ * In this case, the value from the wrapped request should be used.
*
* If the value from the wrapped request is null, an attempt will be made to retrieve the parameter
- * from the DefaultSavedRequest, if available..
+ * from the saved request.
*/
@Override
public String getParameter(String name) {
String value = super.getParameter(name);
- if (value != null || savedRequest == null) {
+ if (value != null) {
return value;
}
String[] values = savedRequest.getParameterValues(name);
+
if (values == null || values.length == 0) {
return null;
}
@@ -255,10 +196,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
@Override
@SuppressWarnings("unchecked")
public Map getParameterMap() {
- if (savedRequest == null) {
- return super.getParameterMap();
- }
-
Set names = getCombinedParameterNames();
Map parameterMap = new HashMap(names.size());
@@ -273,10 +210,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
private Set getCombinedParameterNames() {
Set names = new HashSet();
names.addAll(super.getParameterMap().keySet());
-
- if (savedRequest != null) {
- names.addAll(savedRequest.getParameterMap().keySet());
- }
+ names.addAll(savedRequest.getParameterMap().keySet());
return names;
}
@@ -289,10 +223,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
@Override
public String[] getParameterValues(String name) {
- if (savedRequest == null) {
- return super.getParameterValues(name);
- }
-
String[] savedRequestParams = savedRequest.getParameterValues(name);
String[] wrappedRequestParams = super.getParameterValues(name);
diff --git a/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestTests.java b/web/src/test/java/org/springframework/security/web/savedrequest/DefaultSavedRequestTests.java
similarity index 95%
rename from web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestTests.java
rename to web/src/test/java/org/springframework/security/web/savedrequest/DefaultSavedRequestTests.java
index ee98193587..47a33ab526 100644
--- a/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestTests.java
+++ b/web/src/test/java/org/springframework/security/web/savedrequest/DefaultSavedRequestTests.java
@@ -10,14 +10,14 @@ import org.springframework.mock.web.MockHttpServletRequest;
/**
*
*/
-public class SavedRequestTests {
+public class DefaultSavedRequestTests {
@Test
public void headersAreCaseInsensitive() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
request.addHeader("USER-aGenT", "Mozilla");
DefaultSavedRequest saved = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443));
- assertEquals("Mozilla", saved.getHeaderValues("user-agent").next());
+ assertEquals("Mozilla", saved.getHeaderValues("user-agent").get(0));
}
// TODO: Why are parameters case insensitive. I think this is a mistake
diff --git a/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapperTests.java b/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapperTests.java
index 2d0412918c..a1563b4809 100644
--- a/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapperTests.java
+++ b/web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapperTests.java
@@ -19,19 +19,10 @@ import org.springframework.security.web.savedrequest.SavedRequestAwareWrapper;
public class SavedRequestAwareWrapperTests {
private SavedRequestAwareWrapper createWrapper(MockHttpServletRequest requestToSave, MockHttpServletRequest requestToWrap) {
- DefaultSavedRequest saved = requestToSave == null ? null : new DefaultSavedRequest(requestToSave, new PortResolverImpl());
+ DefaultSavedRequest saved = new DefaultSavedRequest(requestToSave, new PortResolverImpl());
return new SavedRequestAwareWrapper(saved, requestToWrap);
}
- @Test
- public void wrappedRequestCookiesAreReturnedIfNoSavedRequestIsSet() throws Exception {
- MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
- wrappedRequest.setCookies(new Cookie[] {new Cookie("cookie", "fromwrapped")});
- SavedRequestAwareWrapper wrapper = createWrapper(null, wrappedRequest);
- assertEquals(1, wrapper.getCookies().length);
- assertEquals("fromwrapped", wrapper.getCookies()[0].getValue());
- }
-
@Test
public void savedRequestCookiesAreReturnedIfSavedRequestIsSet() throws Exception {
MockHttpServletRequest savedRequest = new MockHttpServletRequest();
@@ -61,27 +52,6 @@ public class SavedRequestAwareWrapperTests {
assertEquals("header", wrapper.getHeaderNames().nextElement());
}
- @Test
- @SuppressWarnings("unchecked")
- public void wrappedRequestHeaderIsReturnedIfSavedRequestIsNotSet() throws Exception {
- MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
- wrappedRequest.addHeader("header", "wrappedheader");
- SavedRequestAwareWrapper wrapper = createWrapper(null, wrappedRequest);
-
- assertNull(wrapper.getHeader("nonexistent"));
- Enumeration headers = wrapper.getHeaders("nonexistent");
- assertFalse(headers.hasMoreElements());
-
- assertEquals("wrappedheader", wrapper.getHeader("header"));
- headers = wrapper.getHeaders("header");
- assertTrue(headers.hasMoreElements());
- assertEquals("wrappedheader", headers.nextElement());
- assertFalse(headers.hasMoreElements());
- assertTrue(wrapper.getHeaderNames().hasMoreElements());
- assertEquals("header", wrapper.getHeaderNames().nextElement());
- }
-
-
@Test
/* SEC-830. Assume we have a request to /someUrl?action=foo (the saved request)
* and then RequestDispatcher.forward() it to /someUrl?action=bar.
@@ -125,8 +95,7 @@ public class SavedRequestAwareWrapperTests {
@Test
public void getParameterValuesReturnsNullIfParameterIsntSet() {
- MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
- SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(null, wrappedRequest);
+ SavedRequestAwareWrapper wrapper = createWrapper(new MockHttpServletRequest(), new MockHttpServletRequest());
assertNull(wrapper.getParameterValues("action"));
assertNull(wrapper.getParameterMap().get("action"));
}
@@ -148,7 +117,7 @@ public class SavedRequestAwareWrapperTests {
}
@Test
- public void expecteDateHeaderIsReturnedFromSavedAndWrappedRequests() throws Exception {
+ public void expecteDateHeaderIsReturnedFromSavedRequest() throws Exception {
SimpleDateFormat formatter = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US);
String nowString = FastHttpDateFormat.getCurrentDate();
Date now = formatter.parse(nowString);
@@ -158,12 +127,6 @@ public class SavedRequestAwareWrapperTests {
assertEquals(now.getTime(), wrapper.getDateHeader("header"));
assertEquals(-1L, wrapper.getDateHeader("nonexistent"));
-
- // Now try with no saved request
- request = new MockHttpServletRequest();
- request.addHeader("header", now);
- wrapper = createWrapper(null, request);
- assertEquals(now.getTime(), wrapper.getDateHeader("header"));
}
@Test(expected=IllegalArgumentException.class)
@@ -179,8 +142,6 @@ public class SavedRequestAwareWrapperTests {
MockHttpServletRequest request = new MockHttpServletRequest("PUT", "/notused");
SavedRequestAwareWrapper wrapper = createWrapper(request, new MockHttpServletRequest("GET", "/notused"));
assertEquals("PUT", wrapper.getMethod());
- wrapper = createWrapper(null, request);
- assertEquals("PUT", wrapper.getMethod());
}
@Test
@@ -192,9 +153,6 @@ public class SavedRequestAwareWrapperTests {
assertEquals(999, wrapper.getIntHeader("header"));
assertEquals(-1, wrapper.getIntHeader("nonexistent"));
-
- wrapper = createWrapper(null, request);
- assertEquals(999, wrapper.getIntHeader("header"));
}
}