Create IP range helper class
[java-idp.git] / src / main / java / edu / internet2 / middleware / shibboleth / idp / StatusServlet.java
index a63cdfb..9b3afe5 100644 (file)
@@ -18,6 +18,8 @@ package edu.internet2.middleware.shibboleth.idp;
 
 import java.io.IOException;
 import java.io.PrintWriter;
+import java.net.InetAddress;
+import java.net.UnknownHostException;
 
 import javax.servlet.ServletConfig;
 import javax.servlet.ServletException;
@@ -25,24 +27,36 @@ import javax.servlet.http.HttpServlet;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
+import org.apache.commons.httpclient.HttpStatus;
 import org.joda.time.DateTime;
 import org.joda.time.chrono.ISOChronology;
 import org.joda.time.format.DateTimeFormatter;
 import org.joda.time.format.ISODateTimeFormat;
 import org.opensaml.xml.security.x509.X509Credential;
 import org.opensaml.xml.util.Base64;
+import org.opensaml.xml.util.DatatypeHelper;
+import org.opensaml.xml.util.LazyList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import edu.internet2.middleware.shibboleth.common.attribute.resolver.AttributeResolutionException;
 import edu.internet2.middleware.shibboleth.common.attribute.resolver.AttributeResolver;
 import edu.internet2.middleware.shibboleth.common.relyingparty.RelyingPartyConfiguration;
 import edu.internet2.middleware.shibboleth.common.relyingparty.RelyingPartyConfigurationManager;
 import edu.internet2.middleware.shibboleth.idp.util.HttpServletHelper;
+import edu.internet2.middleware.shibboleth.idp.util.IPRange;
 
 /** A Servlet for displaying the status of the IdP. */
 public class StatusServlet extends HttpServlet {
 
     /** Serial version UID. */
-    private static final long serialVersionUID = 7917509317276109266L;
+    private static final long serialVersionUID = -5280549109235107879L;
+
+    private final String IP_PARAM_NAME = "AllowedIPs";
+    
+    private final Logger log = LoggerFactory.getLogger(StatusServlet.class);
+
+    private LazyList<IPRange> allowedIPs;
 
     /** Formatter used when print date/times. */
     private DateTimeFormatter dateFormat;
@@ -60,6 +74,15 @@ public class StatusServlet extends HttpServlet {
     public void init(ServletConfig config) throws ServletException {
         super.init(config);
 
+        allowedIPs = new LazyList<IPRange>();
+
+        String cidrBlocks = DatatypeHelper.safeTrimOrNullString(config.getInitParameter(IP_PARAM_NAME));
+        if (cidrBlocks != null) {
+            for (String cidrBlock : cidrBlocks.split(" ")) {
+                allowedIPs.add(IPRange.parseCIDRBlock(cidrBlock));
+            }
+        }
+
         dateFormat = ISODateTimeFormat.dateTimeNoMillis();
         startTime = new DateTime(ISOChronology.getInstanceUTC());
         attributeResolver = HttpServletHelper.getAttributeResolver(config.getServletContext());
@@ -67,20 +90,49 @@ public class StatusServlet extends HttpServlet {
     }
 
     /** {@inheritDoc} */
-    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
-        resp.setContentType("text/plain");
-        PrintWriter output = resp.getWriter();
+    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
+        if (!isAuthenticated(request)) {
+            response.sendError(HttpStatus.SC_UNAUTHORIZED);
+            return;
+        }
+
+        response.setContentType("text/plain");
+        PrintWriter output = response.getWriter();
 
         printOperatingEnvironmentInformation(output);
         output.println();
         printIdPInformation(output);
         output.println();
-        printRelyingPartyConfigurationsInformation(output, req.getParameter("relyingParty"));
+        printRelyingPartyConfigurationsInformation(output, request.getParameter("relyingParty"));
 
         output.flush();
     }
 
     /**
+     * Checks whether the client is authenticated.
+     * 
+     * @param request client request
+     * 
+     * @return true if the client is authenticated, false if not
+     */
+    protected boolean isAuthenticated(HttpServletRequest request) throws ServletException {
+        log.debug("Attempting to authenticate client '{}'", request.getRemoteAddr());
+        try {
+            InetAddress clientAddress = InetAddress.getByName(request.getRemoteAddr());
+
+            for (IPRange range : allowedIPs) {
+                if (range.contains(clientAddress)) {
+                    return true;
+                }
+            }
+
+            return false;
+        } catch (UnknownHostException e) {
+            throw new ServletException(e);
+        }
+    }
+
+    /**
      * Prints out information about the operating environment. This includes the operating system name, version and
      * architecture, the JDK version, available CPU cores, memory currently used by the JVM process, the maximum amount
      * of memory that may be used by the JVM, and the current time in UTC.
@@ -155,20 +207,23 @@ public class StatusServlet extends HttpServlet {
     protected void printRelyingPartyConfigurationInformation(PrintWriter out, RelyingPartyConfiguration config) {
         out.println("relying_party_id: " + config.getRelyingPartyId());
         out.println("idp_entity_id: " + config.getProviderId());
-        
+
         if (config.getDefaultAuthenticationMethod() != null) {
             out.println("default_authentication_method: " + config.getDefaultAuthenticationMethod());
         } else {
             out.println("default_authentication_method: none");
         }
 
-        try{
+        try {
             X509Credential signingCredential = (X509Credential) config.getDefaultSigningCredential();
-            out.println("default_signing_tls_key: " + Base64.encodeBytes(signingCredential.getEntityCertificate().getEncoded(), Base64.DONT_BREAK_LINES));
-        }catch(Throwable t){
+            out
+                    .println("default_signing_tls_key: "
+                            + Base64.encodeBytes(signingCredential.getEntityCertificate().getEncoded(),
+                                    Base64.DONT_BREAK_LINES));
+        } catch (Throwable t) {
             // swallow error
         }
-        
+
         for (String profileId : config.getProfileConfigurations().keySet()) {
             out.println("configured_communication_profile: " + profileId);
         }