import java.io.IOException;
import java.io.PrintWriter;
+import java.net.InetAddress;
+import java.net.UnknownHostException;
import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
+import org.apache.commons.httpclient.HttpStatus;
import org.joda.time.DateTime;
import org.joda.time.chrono.ISOChronology;
import org.joda.time.format.DateTimeFormatter;
import org.joda.time.format.ISODateTimeFormat;
import org.opensaml.xml.security.x509.X509Credential;
import org.opensaml.xml.util.Base64;
+import org.opensaml.xml.util.DatatypeHelper;
+import org.opensaml.xml.util.LazyList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import edu.internet2.middleware.shibboleth.common.attribute.resolver.AttributeResolutionException;
import edu.internet2.middleware.shibboleth.common.attribute.resolver.AttributeResolver;
import edu.internet2.middleware.shibboleth.common.relyingparty.RelyingPartyConfiguration;
import edu.internet2.middleware.shibboleth.common.relyingparty.RelyingPartyConfigurationManager;
import edu.internet2.middleware.shibboleth.idp.util.HttpServletHelper;
+import edu.internet2.middleware.shibboleth.idp.util.IPRange;
/** A Servlet for displaying the status of the IdP. */
public class StatusServlet extends HttpServlet {
/** Serial version UID. */
- private static final long serialVersionUID = 7917509317276109266L;
+ private static final long serialVersionUID = -5280549109235107879L;
+
+ private final String IP_PARAM_NAME = "AllowedIPs";
+
+ private final Logger log = LoggerFactory.getLogger(StatusServlet.class);
+
+ private LazyList<IPRange> allowedIPs;
/** Formatter used when print date/times. */
private DateTimeFormatter dateFormat;
public void init(ServletConfig config) throws ServletException {
super.init(config);
+ allowedIPs = new LazyList<IPRange>();
+
+ String cidrBlocks = DatatypeHelper.safeTrimOrNullString(config.getInitParameter(IP_PARAM_NAME));
+ if (cidrBlocks != null) {
+ for (String cidrBlock : cidrBlocks.split(" ")) {
+ allowedIPs.add(IPRange.parseCIDRBlock(cidrBlock));
+ }
+ }
+
dateFormat = ISODateTimeFormat.dateTimeNoMillis();
startTime = new DateTime(ISOChronology.getInstanceUTC());
attributeResolver = HttpServletHelper.getAttributeResolver(config.getServletContext());
}
/** {@inheritDoc} */
- protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
- resp.setContentType("text/plain");
- PrintWriter output = resp.getWriter();
+ protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
+ if (!isAuthenticated(request)) {
+ response.sendError(HttpStatus.SC_UNAUTHORIZED);
+ return;
+ }
+
+ response.setContentType("text/plain");
+ PrintWriter output = response.getWriter();
printOperatingEnvironmentInformation(output);
output.println();
printIdPInformation(output);
output.println();
- printRelyingPartyConfigurationsInformation(output, req.getParameter("relyingParty"));
+ printRelyingPartyConfigurationsInformation(output, request.getParameter("relyingParty"));
output.flush();
}
/**
+ * Checks whether the client is authenticated.
+ *
+ * @param request client request
+ *
+ * @return true if the client is authenticated, false if not
+ */
+ protected boolean isAuthenticated(HttpServletRequest request) throws ServletException {
+ log.debug("Attempting to authenticate client '{}'", request.getRemoteAddr());
+ try {
+ InetAddress clientAddress = InetAddress.getByName(request.getRemoteAddr());
+
+ for (IPRange range : allowedIPs) {
+ if (range.contains(clientAddress)) {
+ return true;
+ }
+ }
+
+ return false;
+ } catch (UnknownHostException e) {
+ throw new ServletException(e);
+ }
+ }
+
+ /**
* Prints out information about the operating environment. This includes the operating system name, version and
* architecture, the JDK version, available CPU cores, memory currently used by the JVM process, the maximum amount
* of memory that may be used by the JVM, and the current time in UTC.
protected void printRelyingPartyConfigurationInformation(PrintWriter out, RelyingPartyConfiguration config) {
out.println("relying_party_id: " + config.getRelyingPartyId());
out.println("idp_entity_id: " + config.getProviderId());
-
+
if (config.getDefaultAuthenticationMethod() != null) {
out.println("default_authentication_method: " + config.getDefaultAuthenticationMethod());
} else {
out.println("default_authentication_method: none");
}
- try{
+ try {
X509Credential signingCredential = (X509Credential) config.getDefaultSigningCredential();
- out.println("default_signing_tls_key: " + Base64.encodeBytes(signingCredential.getEntityCertificate().getEncoded(), Base64.DONT_BREAK_LINES));
- }catch(Throwable t){
+ out
+ .println("default_signing_tls_key: "
+ + Base64.encodeBytes(signingCredential.getEntityCertificate().getEncoded(),
+ Base64.DONT_BREAK_LINES));
+ } catch (Throwable t) {
// swallow error
}
-
+
for (String profileId : config.getProfileConfigurations().keySet()) {
out.println("configured_communication_profile: " + profileId);
}