prevent cast class exception if a user starts a SAML 1 flow, leaves in the middle...
[java-idp.git] / src / main / java / edu / internet2 / middleware / shibboleth / idp / profile / saml2 / SSOProfileHandler.java
index 4aa3ae1..d656df6 100644 (file)
@@ -20,9 +20,9 @@ import java.io.IOException;
 import java.io.StringReader;
 import java.util.ArrayList;
 
-import javax.servlet.RequestDispatcher;
-import javax.servlet.ServletException;
+import javax.servlet.ServletContext;
 import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
 
 import org.joda.time.DateTime;
 import org.joda.time.DateTimeZone;
@@ -37,17 +37,22 @@ import org.opensaml.saml2.core.AuthnContextClassRef;
 import org.opensaml.saml2.core.AuthnContextDeclRef;
 import org.opensaml.saml2.core.AuthnRequest;
 import org.opensaml.saml2.core.AuthnStatement;
+import org.opensaml.saml2.core.NameID;
+import org.opensaml.saml2.core.NameIDPolicy;
 import org.opensaml.saml2.core.RequestedAuthnContext;
 import org.opensaml.saml2.core.Response;
 import org.opensaml.saml2.core.Statement;
 import org.opensaml.saml2.core.StatusCode;
 import org.opensaml.saml2.core.Subject;
 import org.opensaml.saml2.core.SubjectLocality;
+import org.opensaml.saml2.metadata.AffiliateMember;
+import org.opensaml.saml2.metadata.AffiliationDescriptor;
 import org.opensaml.saml2.metadata.AssertionConsumerService;
 import org.opensaml.saml2.metadata.Endpoint;
 import org.opensaml.saml2.metadata.EntityDescriptor;
 import org.opensaml.saml2.metadata.IDPSSODescriptor;
 import org.opensaml.saml2.metadata.SPSSODescriptor;
+import org.opensaml.saml2.metadata.provider.MetadataProviderException;
 import org.opensaml.ws.message.decoder.MessageDecodingException;
 import org.opensaml.ws.transport.http.HTTPInTransport;
 import org.opensaml.ws.transport.http.HTTPOutTransport;
@@ -60,7 +65,6 @@ import org.opensaml.xml.security.SecurityException;
 import org.opensaml.xml.util.DatatypeHelper;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import org.slf4j.helpers.MessageFormatter;
 import org.w3c.dom.Element;
 
 import edu.internet2.middleware.shibboleth.common.profile.ProfileException;
@@ -70,9 +74,9 @@ import edu.internet2.middleware.shibboleth.common.relyingparty.RelyingPartyConfi
 import edu.internet2.middleware.shibboleth.common.relyingparty.provider.SAMLMDRelyingPartyConfigurationManager;
 import edu.internet2.middleware.shibboleth.common.relyingparty.provider.saml2.SSOConfiguration;
 import edu.internet2.middleware.shibboleth.common.util.HttpHelper;
-import edu.internet2.middleware.shibboleth.idp.authn.LoginContext;
 import edu.internet2.middleware.shibboleth.idp.authn.PassiveAuthenticationException;
 import edu.internet2.middleware.shibboleth.idp.authn.Saml2LoginContext;
+import edu.internet2.middleware.shibboleth.idp.authn.LoginContext;
 import edu.internet2.middleware.shibboleth.idp.session.Session;
 import edu.internet2.middleware.shibboleth.idp.util.HttpServletHelper;
 
@@ -112,7 +116,14 @@ public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
     public SSOProfileHandler(String authnManagerPath) {
         super();
 
-        authenticationManagerPath = authnManagerPath;
+        if (DatatypeHelper.isEmpty(authnManagerPath)) {
+            throw new IllegalArgumentException("Authentication manager path may not be null");
+        }
+        if (authnManagerPath.startsWith("/")) {
+            authenticationManagerPath = authnManagerPath;
+        } else {
+            authenticationManagerPath = "/" + authnManagerPath;
+        }
 
         authnStatementBuilder = (SAMLObjectBuilder<AuthnStatement>) getBuilderFactory().getBuilder(
                 AuthnStatement.DEFAULT_ELEMENT_NAME);
@@ -136,14 +147,21 @@ public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
     /** {@inheritDoc} */
     public void processRequest(HTTPInTransport inTransport, HTTPOutTransport outTransport) throws ProfileException {
         HttpServletRequest httpRequest = ((HttpServletRequestAdapter) inTransport).getWrappedRequest();
+        HttpServletResponse httpResponse = ((HttpServletResponseAdapter) outTransport).getWrappedResponse();
+        ServletContext servletContext = httpRequest.getSession().getServletContext();
 
-        LoginContext loginContext = HttpServletHelper.getLoginContext(httpRequest);
-        if (loginContext == null) {
+        LoginContext loginContext = HttpServletHelper.getLoginContext(getStorageService(),
+                servletContext, httpRequest);
+        if (loginContext == null || !(loginContext instanceof Saml2LoginContext)) {
             log.debug("Incoming request does not contain a login context, processing as first leg of request");
             performAuthentication(inTransport, outTransport);
-        } else {
+        } else if (loginContext.isPrincipalAuthenticated() || loginContext.getAuthenticationFailure() != null) {
             log.debug("Incoming request contains a login context, processing as second leg of request");
-            completeAuthenticationRequest(inTransport, outTransport);
+            HttpServletHelper.unbindLoginContext(getStorageService(), servletContext, httpRequest, httpResponse);
+            completeAuthenticationRequest((Saml2LoginContext)loginContext, inTransport, outTransport);
+        } else {
+            log.debug("Incoming request contained a login context but principal was not authenticated, processing as first leg of request");
+            performAuthentication(inTransport, outTransport);
         }
     }
 
@@ -159,7 +177,9 @@ public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
      */
     protected void performAuthentication(HTTPInTransport inTransport, HTTPOutTransport outTransport)
             throws ProfileException {
-        HttpServletRequest servletRequest = ((HttpServletRequestAdapter) inTransport).getWrappedRequest();
+        HttpServletRequest httpRequest = ((HttpServletRequestAdapter) inTransport).getWrappedRequest();
+        HttpServletResponse httpResponse = ((HttpServletResponseAdapter) outTransport).getWrappedResponse();
+
         SSORequestContext requestContext = new SSORequestContext();
 
         try {
@@ -169,8 +189,8 @@ public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
             RelyingPartyConfiguration rpConfig = getRelyingPartyConfiguration(relyingPartyId);
             ProfileConfiguration ssoConfig = rpConfig.getProfileConfiguration(getProfileId());
             if (ssoConfig == null) {
-                String msg = MessageFormatter.format("SAML 2 SSO profile is not configured for relying party '{}'",
-                        requestContext.getInboundMessageIssuer());
+                String msg = "SAML 2 SSO profile is not configured for relying party "
+                        + requestContext.getInboundMessageIssuer();
                 log.warn(msg);
                 throw new ProfileException(msg);
             }
@@ -179,21 +199,22 @@ public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
             Saml2LoginContext loginContext = new Saml2LoginContext(relyingPartyId, requestContext.getRelayState(),
                     requestContext.getInboundSAMLMessage());
             loginContext.setAuthenticationEngineURL(authenticationManagerPath);
-            loginContext.setProfileHandlerURL(HttpHelper.getRequestUriWithoutContext(servletRequest));
+            loginContext.setProfileHandlerURL(HttpHelper.getRequestUriWithoutContext(httpRequest));
             loginContext.setDefaultAuthenticationMethod(rpConfig.getDefaultAuthenticationMethod());
 
-            HttpServletHelper.bindLoginContext(loginContext, servletRequest);
-            RequestDispatcher dispatcher = servletRequest.getRequestDispatcher(authenticationManagerPath);
-            dispatcher.forward(servletRequest, ((HttpServletResponseAdapter) outTransport).getWrappedResponse());
+            HttpServletHelper.bindLoginContext(loginContext, getStorageService(), httpRequest.getSession()
+                    .getServletContext(), httpRequest, httpResponse);
+
+            String authnEngineUrl = HttpServletHelper.getContextRelativeUrl(httpRequest, authenticationManagerPath)
+                    .buildURL();
+            log.debug("Redirecting user to authentication engine at {}", authnEngineUrl);
+            httpResponse.sendRedirect(authnEngineUrl);
         } catch (MarshallingException e) {
             log.error("Unable to marshall authentication request context");
             throw new ProfileException("Unable to marshall authentication request context", e);
         } catch (IOException ex) {
             log.error("Error forwarding SAML 2 AuthnRequest to AuthenticationManager", ex);
             throw new ProfileException("Error forwarding SAML 2 AuthnRequest to AuthenticationManager", ex);
-        } catch (ServletException ex) {
-            log.error("Error forwarding SAML 2 AuthnRequest to AuthenticationManager", ex);
-            throw new ProfileException("Error forwarding SAML 2 AuthnRequest to AuthenticationManager", ex);
         }
     }
 
@@ -201,22 +222,21 @@ public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
      * Creates a response to the {@link AuthnRequest} and sends the user, with response in tow, back to the relying
      * party after they've been authenticated.
      * 
+     * @param loginContext login context for this request
      * @param inTransport inbound message transport
      * @param outTransport outbound message transport
      * 
      * @throws ProfileException thrown if the response can not be created and sent back to the relying party
      */
-    protected void completeAuthenticationRequest(HTTPInTransport inTransport, HTTPOutTransport outTransport)
-            throws ProfileException {
-        HttpServletRequest httpRequest = ((HttpServletRequestAdapter) inTransport).getWrappedRequest();
-        Saml2LoginContext loginContext = (Saml2LoginContext) HttpServletHelper.getLoginContext(httpRequest);
-
+    protected void completeAuthenticationRequest(Saml2LoginContext loginContext, HTTPInTransport inTransport,
+            HTTPOutTransport outTransport) throws ProfileException {
         SSORequestContext requestContext = buildRequestContext(loginContext, inTransport, outTransport);
 
-        checkSamlVersion(requestContext);
-
         Response samlResponse;
         try {
+            checkSamlVersion(requestContext);
+            checkNameIDPolicy(requestContext);
+
             if (loginContext.getAuthenticationFailure() != null) {
                 if (loginContext.getAuthenticationFailure() instanceof PassiveAuthenticationException) {
                     requestContext.setFailureStatus(buildStatus(StatusCode.RESPONDER_URI, StatusCode.NO_PASSIVE_URI,
@@ -233,8 +253,9 @@ public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
                 resolvePrincipal(requestContext);
                 String requestedPrincipalName = requestContext.getPrincipalName();
                 if (!DatatypeHelper.safeEquals(loginContext.getPrincipalName(), requestedPrincipalName)) {
-                    log.warn("Authentication request identified principal {} but authentication mechanism identified principal {}",
-                                    requestedPrincipalName, loginContext.getPrincipalName());
+                    log.warn(
+                            "Authentication request identified principal {} but authentication mechanism identified principal {}",
+                            requestedPrincipalName, loginContext.getPrincipalName());
                     requestContext.setFailureStatus(buildStatus(StatusCode.RESPONDER_URI, StatusCode.AUTHN_FAILED_URI,
                             null));
                     throw new ProfileException("User failed authentication");
@@ -242,7 +263,7 @@ public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
             }
 
             resolveAttributes(requestContext);
-            
+
             ArrayList<Statement> statements = new ArrayList<Statement>();
             statements.add(buildAuthnStatement(requestContext));
             if (requestContext.getProfileConfiguration().includeAttributeStatement()) {
@@ -319,6 +340,54 @@ public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
     }
 
     /**
+     * Checks to see, if present, if the affiliation associated with the SPNameQualifier given in the AuthnRequest
+     * NameIDPolicy lists the inbound message issuer as a member.
+     * 
+     * @param requestContext current request context
+     * 
+     * @throws ProfileException thrown if there the request is not a member of the affiliation or if there was a problem
+     *             determining membership
+     */
+    protected void checkNameIDPolicy(SSORequestContext requestContext) throws ProfileException {
+        AuthnRequest request = requestContext.getInboundSAMLMessage();
+
+        NameIDPolicy nameIdPolcy = request.getNameIDPolicy();
+        if (nameIdPolcy == null) {
+            return;
+        }
+
+        String spNameQualifier = DatatypeHelper.safeTrimOrNullString(nameIdPolcy.getSPNameQualifier());
+        if (spNameQualifier == null) {
+            return;
+        }
+
+        log.debug("Checking if message issuer is a member of affiliation '{}'", spNameQualifier);
+        try {
+            EntityDescriptor affiliation = getMetadataProvider().getEntityDescriptor(spNameQualifier);
+            if (affiliation != null) {
+                AffiliationDescriptor affiliationDescriptor = affiliation.getAffiliationDescriptor();
+                if (affiliationDescriptor != null && affiliationDescriptor.getMembers() != null) {
+                    for (AffiliateMember member : affiliationDescriptor.getMembers()) {
+                        if (DatatypeHelper.safeEquals(member.getID(), requestContext.getInboundMessageIssuer())) {
+                            return;
+                        }
+                    }
+                }
+            }
+
+            requestContext.setFailureStatus(buildStatus(StatusCode.REQUESTER_URI, StatusCode.INVALID_NAMEID_POLICY_URI,
+                    "Invalid SPNameQualifier for this request"));
+            throw new ProfileException("Relying party '" + requestContext.getInboundMessageIssuer()
+                    + "' is not a member of the affiliation " + spNameQualifier);
+        } catch (MetadataProviderException e) {
+            requestContext.setFailureStatus(buildStatus(StatusCode.RESPONDER_URI, null, "Internal service error"));
+            log.error("Error looking up metadata for affiliation", e);
+            throw new ProfileException("Relying party '" + requestContext.getInboundMessageIssuer()
+                    + "' is not a member of the affiliation " + spNameQualifier);
+        }
+    }
+
+    /**
      * Creates an authentication request context from the current environmental information.
      * 
      * @param loginContext current login context
@@ -464,7 +533,8 @@ public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
         if (requestedAuthnContext != null) {
             if (requestedAuthnContext.getAuthnContextClassRefs() != null) {
                 for (AuthnContextClassRef classRef : requestedAuthnContext.getAuthnContextClassRefs()) {
-                    if (classRef.getAuthnContextClassRef().equals(loginContext.getAuthenticationMethod())) {
+                    if (DatatypeHelper.safeEquals(classRef.getAuthnContextClassRef(),
+                            loginContext.getAuthenticationMethod())) {
                         AuthnContextClassRef ref = authnContextClassRefBuilder.buildObject();
                         ref.setAuthnContextClassRef(loginContext.getAuthenticationMethod());
                         authnContext.setAuthnContextClassRef(ref);
@@ -472,7 +542,8 @@ public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
                 }
             } else if (requestedAuthnContext.getAuthnContextDeclRefs() != null) {
                 for (AuthnContextDeclRef declRef : requestedAuthnContext.getAuthnContextDeclRefs()) {
-                    if (declRef.getAuthnContextDeclRef().equals(loginContext.getAuthenticationMethod())) {
+                    if (DatatypeHelper.safeEquals(declRef.getAuthnContextDeclRef(),
+                            loginContext.getAuthenticationMethod())) {
                         AuthnContextDeclRef ref = authnContextDeclRefBuilder.buildObject();
                         ref.setAuthnContextDeclRef(loginContext.getAuthenticationMethod());
                         authnContext.setAuthnContextDeclRef(ref);
@@ -505,6 +576,42 @@ public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
         return subjectLocality;
     }
 
+    /** {@inheritDoc} */
+    protected String getRequiredNameIDFormat(BaseSAMLProfileRequestContext requestContext) {
+        String requiredNameFormat = null;
+        AuthnRequest authnRequest = (AuthnRequest) requestContext.getInboundSAMLMessage();
+        NameIDPolicy nameIdPolicy = authnRequest.getNameIDPolicy();
+        if (nameIdPolicy != null) {
+            requiredNameFormat = DatatypeHelper.safeTrimOrNullString(nameIdPolicy.getFormat());
+            // Check for unspec'd or encryption formats, which aren't relevant for this section of code.
+            if (requiredNameFormat != null
+                    && (NameID.ENCRYPTED.equals(requiredNameFormat) || NameID.UNSPECIFIED.equals(requiredNameFormat))) {
+                requiredNameFormat = null;
+            }
+        }
+
+        return requiredNameFormat;
+    }
+
+    /** {@inheritDoc} */
+    protected NameID buildNameId(BaseSAML2ProfileRequestContext<?, ?, ?> requestContext) throws ProfileException {
+        NameID nameId = super.buildNameId(requestContext);
+        if (nameId != null) {
+            AuthnRequest authnRequest = (AuthnRequest) requestContext.getInboundSAMLMessage();
+            NameIDPolicy nameIdPolicy = authnRequest.getNameIDPolicy();
+            if (nameIdPolicy != null) {
+                String spNameQualifier = DatatypeHelper.safeTrimOrNullString(nameIdPolicy.getSPNameQualifier());
+                if (spNameQualifier != null) {
+                    nameId.setSPNameQualifier(spNameQualifier);
+                } else {
+                    nameId.setSPNameQualifier(requestContext.getInboundMessageIssuer());
+                }
+            }
+        }
+
+        return nameId;
+    }
+
     /**
      * Selects the appropriate endpoint for the relying party and stores it in the request context.
      * 
@@ -525,11 +632,10 @@ public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
                 } else {
                     endpoint.setBinding(getSupportedOutboundBindings().get(0));
                 }
-                log
-                        .warn(
-                                "Generating endpoint for anonymous relying party self-identified as '{}', ACS url '{}' and binding '{}'",
-                                new Object[] { requestContext.getInboundMessageIssuer(), endpoint.getLocation(),
-                                        endpoint.getBinding(), });
+                log.warn(
+                        "Generating endpoint for anonymous relying party self-identified as '{}', ACS url '{}' and binding '{}'",
+                        new Object[] { requestContext.getInboundMessageIssuer(), endpoint.getLocation(),
+                                endpoint.getBinding(), });
             } else {
                 log.warn("Unable to generate endpoint for anonymous party.  No ACS url provided.");
             }