diff --git a/web/src/main/java/org/springframework/security/web/access/expression/WebSecurityExpressionRoot.java b/web/src/main/java/org/springframework/security/web/access/expression/WebSecurityExpressionRoot.java index 3580cf182e..c0b42d934e 100644 --- a/web/src/main/java/org/springframework/security/web/access/expression/WebSecurityExpressionRoot.java +++ b/web/src/main/java/org/springframework/security/web/access/expression/WebSecurityExpressionRoot.java @@ -1,15 +1,11 @@ package org.springframework.security.web.access.expression; -import java.net.InetAddress; -import java.net.UnknownHostException; -import java.util.Arrays; - import javax.servlet.http.HttpServletRequest; import org.springframework.security.access.expression.SecurityExpressionRoot; import org.springframework.security.core.Authentication; import org.springframework.security.web.FilterInvocation; -import org.springframework.util.StringUtils; +import org.springframework.security.web.util.IpAddressMatcher; /** * @@ -34,57 +30,7 @@ public class WebSecurityExpressionRoot extends SecurityExpressionRoot { * @return true if the IP address of the current request is in the required range. */ public boolean hasIpAddress(String ipAddress) { - int nMaskBits = 0; - - if (ipAddress.indexOf('/') > 0) { - String[] addressAndMask = StringUtils.split(ipAddress, "/"); - ipAddress = addressAndMask[0]; - nMaskBits = Integer.parseInt(addressAndMask[1]); - } - - InetAddress requiredAddress = parseAddress(ipAddress); - InetAddress remoteAddress = parseAddress(request.getRemoteAddr()); - - if (!requiredAddress.getClass().equals(remoteAddress.getClass())) { - throw new IllegalArgumentException("IP Address in expression must be the same type as " + - "version returned by request"); - } - - if (nMaskBits == 0) { - return remoteAddress.equals(requiredAddress); - } - - byte[] remAddr = remoteAddress.getAddress(); - byte[] reqAddr = requiredAddress.getAddress(); - - int oddBits = nMaskBits % 8; - int nMaskBytes = nMaskBits/8 + (oddBits == 0 ? 0 : 1); - byte[] mask = new byte[nMaskBytes]; - - Arrays.fill(mask, 0, oddBits == 0 ? mask.length : mask.length - 1, (byte)0xFF); - - if (oddBits != 0) { - int finalByte = (1 << oddBits) - 1; - finalByte <<= 8-oddBits; - mask[mask.length - 1] = (byte) finalByte; - } - - // System.out.println("Mask is " + new sun.misc.HexDumpEncoder().encode(mask)); - - for (int i=0; i < mask.length; i++) { - if ((remAddr[i] & mask[i]) != (reqAddr[i] & mask[i])) { - return false; - } - } - - return true; + return (new IpAddressMatcher(ipAddress).matches(request)); } - private InetAddress parseAddress(String address) { - try { - return InetAddress.getByName(address); - } catch (UnknownHostException e) { - throw new IllegalArgumentException("Failed to parse address" + address, e); - } - } } diff --git a/web/src/main/java/org/springframework/security/web/util/IpAddressMatcher.java b/web/src/main/java/org/springframework/security/web/util/IpAddressMatcher.java new file mode 100644 index 0000000000..91b2223fab --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/util/IpAddressMatcher.java @@ -0,0 +1,84 @@ +package org.springframework.security.web.util; + +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.Arrays; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.util.StringUtils; + +/** + * Matches a request based on IP Address or subnet mask matching against the remote address. + * + * @author Luke Taylor + * @since 3.0.2 + */ +public class IpAddressMatcher implements RequestMatcher { + private final int nMaskBits; + private final InetAddress requiredAddress; + + /** + * Takes a specific IP address or a range specified using the + * IP/Netmask (e.g. 192.168.1.0/24 or 202.24.0.0/14). + * + * @param ipAddress the address or range of addresses from which the request must come. + */ + public IpAddressMatcher(String ipAddress) { + + if (ipAddress.indexOf('/') > 0) { + String[] addressAndMask = StringUtils.split(ipAddress, "/"); + ipAddress = addressAndMask[0]; + nMaskBits = Integer.parseInt(addressAndMask[1]); + } else { + nMaskBits = 0; + } + requiredAddress = parseAddress(ipAddress); + } + + public boolean matches(HttpServletRequest request) { + InetAddress remoteAddress = parseAddress(request.getRemoteAddr()); + + if (!requiredAddress.getClass().equals(remoteAddress.getClass())) { + throw new IllegalArgumentException("IP Address in expression must be the same type as " + + "version returned by request"); + } + + if (nMaskBits == 0) { + return remoteAddress.equals(requiredAddress); + } + + byte[] remAddr = remoteAddress.getAddress(); + byte[] reqAddr = requiredAddress.getAddress(); + + int oddBits = nMaskBits % 8; + int nMaskBytes = nMaskBits/8 + (oddBits == 0 ? 0 : 1); + byte[] mask = new byte[nMaskBytes]; + + Arrays.fill(mask, 0, oddBits == 0 ? mask.length : mask.length - 1, (byte)0xFF); + + if (oddBits != 0) { + int finalByte = (1 << oddBits) - 1; + finalByte <<= 8-oddBits; + mask[mask.length - 1] = (byte) finalByte; + } + + // System.out.println("Mask is " + new sun.misc.HexDumpEncoder().encode(mask)); + + for (int i=0; i < mask.length; i++) { + if ((remAddr[i] & mask[i]) != (reqAddr[i] & mask[i])) { + return false; + } + } + + return true; + } + + private InetAddress parseAddress(String address) { + try { + return InetAddress.getByName(address); + } catch (UnknownHostException e) { + throw new IllegalArgumentException("Failed to parse address" + address, e); + } + } +}