Create IP range helper class
[java-idp.git] / src / main / java / edu / internet2 / middleware / shibboleth / idp / StatusServlet.java
index e092b3d..9b3afe5 100644 (file)
@@ -18,6 +18,8 @@ package edu.internet2.middleware.shibboleth.idp;
 
 import java.io.IOException;
 import java.io.PrintWriter;
+import java.net.InetAddress;
+import java.net.UnknownHostException;
 
 import javax.servlet.ServletConfig;
 import javax.servlet.ServletException;
@@ -25,20 +27,37 @@ import javax.servlet.http.HttpServlet;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
+import org.apache.commons.httpclient.HttpStatus;
 import org.joda.time.DateTime;
 import org.joda.time.chrono.ISOChronology;
 import org.joda.time.format.DateTimeFormatter;
 import org.joda.time.format.ISODateTimeFormat;
+import org.opensaml.xml.security.x509.X509Credential;
+import org.opensaml.xml.util.Base64;
 import org.opensaml.xml.util.DatatypeHelper;
+import org.opensaml.xml.util.LazyList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import edu.internet2.middleware.shibboleth.common.attribute.resolver.AttributeResolutionException;
 import edu.internet2.middleware.shibboleth.common.attribute.resolver.AttributeResolver;
 import edu.internet2.middleware.shibboleth.common.relyingparty.RelyingPartyConfiguration;
 import edu.internet2.middleware.shibboleth.common.relyingparty.RelyingPartyConfigurationManager;
+import edu.internet2.middleware.shibboleth.idp.util.HttpServletHelper;
+import edu.internet2.middleware.shibboleth.idp.util.IPRange;
 
-/** A servlet for displaying the status of the IdP. */
+/** A Servlet for displaying the status of the IdP. */
 public class StatusServlet extends HttpServlet {
 
+    /** Serial version UID. */
+    private static final long serialVersionUID = -5280549109235107879L;
+
+    private final String IP_PARAM_NAME = "AllowedIPs";
+    
+    private final Logger log = LoggerFactory.getLogger(StatusServlet.class);
+
+    private LazyList<IPRange> allowedIPs;
+
     /** Formatter used when print date/times. */
     private DateTimeFormatter dateFormat;
 
@@ -55,37 +74,65 @@ public class StatusServlet extends HttpServlet {
     public void init(ServletConfig config) throws ServletException {
         super.init(config);
 
-        dateFormat = ISODateTimeFormat.dateTimeNoMillis();
+        allowedIPs = new LazyList<IPRange>();
 
-        startTime = new DateTime(ISOChronology.getInstanceUTC());
-
-        String attributeResolverId = config.getInitParameter("attributeResolverId");
-        if (DatatypeHelper.isEmpty(attributeResolverId)) {
-            attributeResolverId = "shibboleth.AttributeResolver";
+        String cidrBlocks = DatatypeHelper.safeTrimOrNullString(config.getInitParameter(IP_PARAM_NAME));
+        if (cidrBlocks != null) {
+            for (String cidrBlock : cidrBlocks.split(" ")) {
+                allowedIPs.add(IPRange.parseCIDRBlock(cidrBlock));
+            }
         }
-        attributeResolver = (AttributeResolver<?>) getServletContext().getAttribute(attributeResolverId);
 
-        String rpConfigManagerId = config.getInitParameter("rpConfigManagerId");
-        if (DatatypeHelper.isEmpty(rpConfigManagerId)) {
-            rpConfigManagerId = "shibboleth.RelyingPartyConfigurationManager";
-        }
-        rpConfigManager = (RelyingPartyConfigurationManager) getServletContext().getAttribute(rpConfigManagerId);
+        dateFormat = ISODateTimeFormat.dateTimeNoMillis();
+        startTime = new DateTime(ISOChronology.getInstanceUTC());
+        attributeResolver = HttpServletHelper.getAttributeResolver(config.getServletContext());
+        rpConfigManager = HttpServletHelper.getRelyingPartyConfirmationManager(config.getServletContext());
     }
 
     /** {@inheritDoc} */
-    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
-        PrintWriter output = resp.getWriter();
+    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
+        if (!isAuthenticated(request)) {
+            response.sendError(HttpStatus.SC_UNAUTHORIZED);
+            return;
+        }
+
+        response.setContentType("text/plain");
+        PrintWriter output = response.getWriter();
 
         printOperatingEnvironmentInformation(output);
         output.println();
         printIdPInformation(output);
         output.println();
-        printRelyingPartyConfigurationsInformation(output, req.getParameter("relyingParty"));
+        printRelyingPartyConfigurationsInformation(output, request.getParameter("relyingParty"));
 
         output.flush();
     }
 
     /**
+     * Checks whether the client is authenticated.
+     * 
+     * @param request client request
+     * 
+     * @return true if the client is authenticated, false if not
+     */
+    protected boolean isAuthenticated(HttpServletRequest request) throws ServletException {
+        log.debug("Attempting to authenticate client '{}'", request.getRemoteAddr());
+        try {
+            InetAddress clientAddress = InetAddress.getByName(request.getRemoteAddr());
+
+            for (IPRange range : allowedIPs) {
+                if (range.contains(clientAddress)) {
+                    return true;
+                }
+            }
+
+            return false;
+        } catch (UnknownHostException e) {
+            throw new ServletException(e);
+        }
+    }
+
+    /**
      * Prints out information about the operating environment. This includes the operating system name, version and
      * architecture, the JDK version, available CPU cores, memory currently used by the JVM process, the maximum amount
      * of memory that may be used by the JVM, and the current time in UTC.
@@ -167,6 +214,16 @@ public class StatusServlet extends HttpServlet {
             out.println("default_authentication_method: none");
         }
 
+        try {
+            X509Credential signingCredential = (X509Credential) config.getDefaultSigningCredential();
+            out
+                    .println("default_signing_tls_key: "
+                            + Base64.encodeBytes(signingCredential.getEntityCertificate().getEncoded(),
+                                    Base64.DONT_BREAK_LINES));
+        } catch (Throwable t) {
+            // swallow error
+        }
+
         for (String profileId : config.getProfileConfigurations().keySet()) {
             out.println("configured_communication_profile: " + profileId);
         }