The HS now validates providerId -> authN consumer URL unions against federation metadata.
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / hs / HandleServlet.java
index 888295b..00f242f 100644 (file)
@@ -33,7 +33,6 @@ import java.util.Date;
 import javax.servlet.RequestDispatcher;
 import javax.servlet.ServletException;
 import javax.servlet.UnavailableException;
-import javax.servlet.http.HttpServlet;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
@@ -52,16 +51,26 @@ import org.w3c.dom.Element;
 import org.w3c.dom.NodeList;
 
 import sun.misc.BASE64Decoder;
+
+import com.sun.corba.se.internal.core.EndPoint;
+
 import edu.internet2.middleware.shibboleth.common.AuthNPrincipal;
 import edu.internet2.middleware.shibboleth.common.Credentials;
 import edu.internet2.middleware.shibboleth.common.NameIdentifierMapping;
 import edu.internet2.middleware.shibboleth.common.NameIdentifierMappingException;
 import edu.internet2.middleware.shibboleth.common.OriginConfig;
+import edu.internet2.middleware.shibboleth.common.RelyingParty;
 import edu.internet2.middleware.shibboleth.common.ServiceProviderMapperException;
 import edu.internet2.middleware.shibboleth.common.ShibPOSTProfile;
 import edu.internet2.middleware.shibboleth.common.ShibbolethConfigurationException;
+import edu.internet2.middleware.shibboleth.common.ShibbolethOriginConfig;
+import edu.internet2.middleware.shibboleth.common.TargetFederationComponent;
+import edu.internet2.middleware.shibboleth.metadata.Endpoint;
+import edu.internet2.middleware.shibboleth.metadata.Provider;
+import edu.internet2.middleware.shibboleth.metadata.ProviderRole;
+import edu.internet2.middleware.shibboleth.metadata.SPProviderRole;
 
-public class HandleServlet extends HttpServlet {
+public class HandleServlet extends TargetFederationComponent {
 
        private static Logger                   log                             = Logger.getLogger(HandleServlet.class.getName());
        private static Logger                   transactionLog  = Logger.getLogger("Shibboleth-TRANSACTION");
@@ -116,6 +125,16 @@ public class HandleServlet extends HttpServlet {
                        throw new ShibbolethConfigurationException("Could not load origin configuration.");
                }
 
+               //Load metadata
+               itemElements = originConfig.getDocumentElement().getElementsByTagNameNS(
+                               ShibbolethOriginConfig.originConfigNamespace, "FederationProvider");
+               for (int i = 0; i < itemElements.getLength(); i++) {
+                       addFederationProvider((Element) itemElements.item(i));
+               }
+               if (providerCount() < 1) {
+                       log.error("No Federation Provider metadata loaded.");
+                       throw new ShibbolethConfigurationException("Could not load federation metadata.");
+               }
        }
 
        public void init() throws ServletException {
@@ -153,10 +172,23 @@ public class HandleServlet extends HttpServlet {
 
                        HSRelyingParty relyingParty = targetMapper.getRelyingParty(req.getParameter("providerId"));
 
+                       //Get the authN info
                        String username = configuration.getAuthHeaderName().equalsIgnoreCase("REMOTE_USER")
                                        ? req.getRemoteUser()
                                        : req.getHeader(configuration.getAuthHeaderName());
 
+                       //Make sure that the selected relying party configuration is appropriate for this
+                       //acceptance URL
+                       if (!relyingParty.isLegacyProvider()) {
+                               if (isValidAssertionConsumerURL(relyingParty, req.getParameter("shire"))) {
+                                       log.info("Supplied consumer URL validated for this provider.");
+                               } else {
+                                       log.error("Supplied assertion consumer service URL (" + req.getParameter("shire")
+                                                       + ") is NOT valid for provider (" + relyingParty.getProviderId() + ").");
+                                       throw new InvalidClientDataException("Invalid assertion consumer service URL.");
+                               }
+                       }
+
                        SAMLNameIdentifier nameId = nameMapper.getNameIdentifierName(relyingParty.getHSNameFormatId(),
                                        new AuthNPrincipal(username), relyingParty, relyingParty.getIdentityProvider());
 
@@ -175,9 +207,9 @@ public class HandleServlet extends HttpServlet {
                        createForm(req, res, buf);
 
                        if (relyingParty.isLegacyProvider()) {
-                               transactionLog.info("Authentication assertion issued to legacy provider (SHIRE: " + req.getParameter("shire")
-                                               + ") on behalf of principal (" + username
-                                               + ") for resource (" + req.getParameter("target") + "). Name Identifier: (" + nameId.getName()
+                               transactionLog.info("Authentication assertion issued to legacy provider (SHIRE: "
+                                               + req.getParameter("shire") + ") on behalf of principal (" + username + ") for resource ("
+                                               + req.getParameter("target") + "). Name Identifier: (" + nameId.getName()
                                                + "). Name Identifier Format: (" + nameId.getFormat() + ").");
                        } else {
                                transactionLog.info("Authentication assertion issued to provider (" + req.getParameter("providerId")
@@ -268,6 +300,35 @@ public class HandleServlet extends HttpServlet {
                }
        }
 
+       protected boolean isValidAssertionConsumerURL(RelyingParty relyingParty, String shireURL)
+                       throws InvalidClientDataException {
+
+               Provider provider = lookup(relyingParty.getProviderId());
+               if (provider == null) {
+                       log.info("No metadata found for provider: (" + relyingParty.getProviderId() + ").");
+                       throw new InvalidClientDataException("Request if from an unkown Service Provider.");
+               }
+
+               ProviderRole[] roles = provider.getRoles();
+               if (roles.length == 0) {
+                       log.info("Inappropriate metadata for provider.");
+                       return false;
+               }
+
+               for (int i = 0; roles.length > i; i++) {
+                       if (roles[i] instanceof SPProviderRole) {
+                               Endpoint[] endpoints = ((SPProviderRole) roles[i]).getAssertionConsumerServiceURLs();
+                               for (int j = 0; endpoints.length > j; j++) {
+                                       if (shireURL.equals(endpoints[j].getLocation())) {
+                                               return true;
+                                       }
+                               }
+                       }
+               }
+               log.info("Supplied consumer URL not found in metadata.");
+               return false;
+       }
+
        class InvalidClientDataException extends Exception {
 
                public InvalidClientDataException(String message) {