Create IP range helper class
[java-idp.git] / src / main / java / edu / internet2 / middleware / shibboleth / idp / authn / provider / IPAddressLoginHandler.java
index d6f316e..6e02809 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright [2006] [University Corporation for Advanced Internet Development, Inc.]
+ * Copyright 2006 University Corporation for Advanced Internet Development, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 
 package edu.internet2.middleware.shibboleth.idp.authn.provider;
 
-import java.net.Inet4Address;
-import java.net.Inet6Address;
 import java.net.InetAddress;
 import java.net.UnknownHostException;
-import java.util.BitSet;
+import java.util.ArrayList;
 import java.util.List;
-import java.util.concurrent.CopyOnWriteArrayList;
 
-import javax.servlet.ServletRequest;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
+import org.opensaml.xml.util.DatatypeHelper;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import org.slf4j.helpers.MessageFormatter;
 
 import edu.internet2.middleware.shibboleth.idp.authn.AuthenticationEngine;
+import edu.internet2.middleware.shibboleth.idp.authn.AuthenticationException;
 import edu.internet2.middleware.shibboleth.idp.authn.LoginHandler;
+import edu.internet2.middleware.shibboleth.idp.util.IPRange;
 
 /**
  * IP Address authentication handler.
  * 
  * This "authenticates" a user based on their IP address. It operates in either default deny or default allow mode, and
  * evaluates a given request against a list of blocked or permitted IPs. It supports both IPv4 and IPv6.
- * 
- * If an Authentication Context Class or DeclRef URI is not specified, it will default to
- * "urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocol".
  */
 public class IPAddressLoginHandler extends AbstractLoginHandler {
 
@@ -49,37 +46,26 @@ public class IPAddressLoginHandler extends AbstractLoginHandler {
     private final Logger log = LoggerFactory.getLogger(IPAddressLoginHandler.class);
 
     /** The username to use for IP-address "authenticated" users. */
-    private String username;
-
-    /** Are the IPs in ipList a permitted list or a deny list. */
-    private boolean defaultDeny;
+    private String authenticatedUser;
 
-    /** The list of denied or permitted IPs. */
-    private List<IPEntry> ipList;
+    /** List of configured IP ranged. */
+    private List<IPRange> ipRanges;
 
-    /**
-     * Set the permitted IP addresses.
-     * 
-     * If <code>defaultDeny</code> is <code>true</code> then only the IP addresses in <code>ipList</code> will be
-     * "authenticated." If <code>defaultDeny</code> is <code>false</code>, then all IP addresses except those in
-     * <code>ipList</code> will be authenticated.
-     * 
-     * @param entries A list of IP addresses (with CIDR masks).
-     * @param defaultDeny Does <code>ipList</code> contain a deny or permit list.
-     */
-    public void setEntries(final List<String> entries, boolean defaultDeny) {
+    /** Whether a user is "authenticated" if their IP address is within a configured IP range. */
+    private boolean ipInRangeIsAuthenticated;
 
-        this.defaultDeny = defaultDeny;
-        ipList = new CopyOnWriteArrayList<IPEntry>();
+    public IPAddressLoginHandler(String user, List<IPRange> ranges, boolean ipInRangeIsAuthenticated) {
+        authenticatedUser = DatatypeHelper.safeTrimOrNullString(user);
+        if (authenticatedUser == null) {
+            throw new IllegalArgumentException("The authenticated user ID may not be null or empty");
+        }
 
-        for (String addr : entries) {
-            try {
-                ipList.add(new edu.internet2.middleware.shibboleth.idp.authn.provider.IPAddressLoginHandler.IPEntry(
-                        addr));
-            } catch (UnknownHostException ex) {
-                log.warn("IPAddressHandler: Error parsing IP entry \"" + addr + "\". Ignoring.");
-            }
+        if (ranges == null || ranges.isEmpty()) {
+            throw new IllegalArgumentException("The list of IP ranges may not be null or empty");
         }
+        ipRanges = new ArrayList<IPRange>(ranges);
+
+        this.ipInRangeIsAuthenticated = ipInRangeIsAuthenticated;
     }
 
     /** {@inheritDoc} */
@@ -92,194 +78,51 @@ public class IPAddressLoginHandler extends AbstractLoginHandler {
         return true;
     }
 
-    /**
-     * Get the username for all IP-address authenticated users.
-     * 
-     * @return The username for IP-address authenticated users.
-     */
-    public String getUsername() {
-        return username;
-    }
-
-    /**
-     * Set the username to use for all IP-address authenticated users.
-     * 
-     * @param name The username for IP-address authenticated users.
-     */
-    public void setUsername(String name) {
-        username = name;
-    }
-
     /** {@inheritDoc} */
     public void login(HttpServletRequest httpRequest, HttpServletResponse httpResponse) {
-
-        if (defaultDeny) {
-            handleDefaultDeny(httpRequest, httpResponse);
-        } else {
-            handleDefaultAllow(httpRequest, httpResponse);
-        }
-
-        AuthenticationEngine.returnToAuthenticationEngine(httpRequest, httpResponse);
-    }
-
-    protected void handleDefaultDeny(HttpServletRequest request, HttpServletResponse response) {
-
-        boolean ipAllowed = searchIpList(request);
-
-        if (ipAllowed) {
-            log.debug("Authenticated user by IP address");
-            request.setAttribute(LoginHandler.PRINCIPAL_NAME_KEY, username);
-        }
-    }
-
-    protected void handleDefaultAllow(HttpServletRequest request, HttpServletResponse response) {
-
-        boolean ipDenied = searchIpList(request);
-
-        if (!ipDenied) {
-            log.debug("Authenticated user by IP address");
-            request.setAttribute(LoginHandler.PRINCIPAL_NAME_KEY, username);
-        }
-    }
-
-    /**
-     * Search the list of InetAddresses for the client's address.
-     * 
-     * @param request The ServletReqeust
-     * 
-     * @return <code>true</code> if the client's address is in <code>ipList</code>
-     */
-    private boolean searchIpList(ServletRequest request) {
-
-        boolean found = false;
-
+        log.debug("Attempting to authenticated client '{}'", httpRequest.getRemoteAddr());
         try {
-            InetAddress addr = InetAddress.getByName(request.getRemoteAddr());
-            BitSet addrbits = byteArrayToBitSet(addr.getAddress());
-
-            for (IPEntry entry : ipList) {
-
-                BitSet netaddr = entry.getNetworkAddress();
-                BitSet netmask = entry.getNetmask();
-
-                addrbits.and(netmask);
-                if (addrbits.equals(netaddr)) {
-                    found = true;
-                    break;
-                }
+            InetAddress clientAddress = InetAddress.getByName(httpRequest.getRemoteAddr());
+            if (authenticate(clientAddress)) {
+                log.debug("Authenticated user by IP address");
+                httpRequest.setAttribute(LoginHandler.PRINCIPAL_NAME_KEY, authenticatedUser);
+            } else {
+                log.debug("Client IP address {} failed authentication.", httpRequest.getRemoteAddr());
+                httpRequest.setAttribute(LoginHandler.AUTHENTICATION_ERROR_KEY, new AuthenticationException(
+                        "Client failed IP address authentication"));
             }
-
-        } catch (UnknownHostException ex) {
-            log.error("Error resolving hostname.", ex);
-            return false;
+        } catch (UnknownHostException e) {
+            String msg = MessageFormatter.format("Unable to resolve {} in to an IP address", httpRequest
+                    .getRemoteAddr());
+            log.warn(msg);
+            httpRequest.setAttribute(LoginHandler.AUTHENTICATION_ERROR_KEY, new AuthenticationException(msg));
         }
 
-        return found;
+        AuthenticationEngine.returnToAuthenticationEngine(httpRequest, httpResponse);
     }
 
     /**
-     * Converts a byte array to a BitSet.
+     * Authenticates the client address.
      * 
-     * The supplied byte array is assumed to have the most signifigant bit in element 0.
+     * @param clientAddress the client address
      * 
-     * @param bytes the byte array with most signifigant bit in element 0.
-     * 
-     * @return the BitSet
-     */
-    protected BitSet byteArrayToBitSet(final byte[] bytes) {
-
-        BitSet bits = new BitSet();
-
-        for (int i = 0; i < bytes.length * 8; i++) {
-            if ((bytes[bytes.length - i / 8 - 1] & (1 << (i % 8))) > 0) {
-                bits.set(i);
-            }
-        }
-
-        return bits;
-    }
-
-    /**
-     * Encapsulates a network address and a netmask on ipList.
+     * @return true if the client address is authenticated, false it not
      */
-    protected class IPEntry {
-
-        /** The network address. */
-        private final BitSet networkAddress;
-
-        /** The netmask. */
-        private final BitSet netmask;
-
-        /**
-         * Construct a new IPEntry given a network address in CIDR format.
-         * 
-         * @param entry A CIDR-formatted network address/netmask
-         * 
-         * @throws UnknownHostException If entry is malformed.
-         */
-        public IPEntry(String entry) throws UnknownHostException {
-
-            // quick sanity checks
-            if (entry == null || entry.length() == 0) {
-                throw new UnknownHostException("entry is null.");
-            }
-
-            int cidrOffset = entry.indexOf("/");
-            if (cidrOffset == -1) {
-                log.warn("Invalid entry \"" + entry + "\" -- it lacks a netmask component.");
-                throw new UnknownHostException("entry lacks a netmask component.");
-            }
-
-            // ensure that only one "/" is present.
-            if (entry.indexOf("/", cidrOffset + 1) != -1) {
-                log.warn("Invalid entry \"" + entry + "\" -- too many \"/\" present.");
-                throw new UnknownHostException("entry has too many netmask components.");
-            }
-
-            String networkString = entry.substring(0, cidrOffset);
-            String netmaskString = entry.substring(cidrOffset + 1, entry.length());
-
-            InetAddress tempAddr = InetAddress.getByName(networkString);
-            networkAddress = byteArrayToBitSet(tempAddr.getAddress());
-
-            int masklen = Integer.parseInt(netmaskString);
-
-            int addrlen;
-            if (tempAddr instanceof Inet4Address) {
-                addrlen = 32;
-            } else if (tempAddr instanceof Inet6Address) {
-                addrlen = 128;
-            }else{
-                throw new UnknownHostException("Unable to determine Inet protocol version");
+    protected boolean authenticate(InetAddress clientAddress) {
+        if (ipInRangeIsAuthenticated) {
+            for (IPRange range : ipRanges) {
+                if (range.contains(clientAddress)) {
+                    return true;
+                }
             }
-
-            // ensure that the netmask isn't too large
-            if ((tempAddr instanceof Inet4Address) && (masklen > 32)) {
-                throw new UnknownHostException("Netmask is too large for an IPv4 address: " + masklen);
-            } else if ((tempAddr instanceof Inet6Address) && masklen > 128) {
-                throw new UnknownHostException("Netmask is too large for an IPv6 address: " + masklen);
+        } else {
+            for (IPRange range : ipRanges) {
+                if (!range.contains(clientAddress)) {
+                    return true;
+                }
             }
-
-            netmask = new BitSet(addrlen);
-            netmask.set(addrlen - masklen, addrlen, true);
-        }
-
-        /**
-         * Get the network address.
-         * 
-         * @return the network address.
-         */
-        public BitSet getNetworkAddress() {
-            return networkAddress;
         }
 
-        /**
-         * Get the netmask.
-         * 
-         * @return the netmask.
-         */
-        public BitSet getNetmask() {
-            return netmask;
-        }
+        return false;
     }
 }
\ No newline at end of file