prevent cast class exception if a user starts a SAML 1 flow, leaves in the middle...
[java-idp.git] / src / main / java / edu / internet2 / middleware / shibboleth / idp / profile / saml1 / ShibbolethSSOProfileHandler.java
index 177e85b..7ceba20 100644 (file)
@@ -19,8 +19,7 @@ package edu.internet2.middleware.shibboleth.idp.profile.saml1;
 import java.io.IOException;
 import java.util.ArrayList;
 
-import javax.servlet.RequestDispatcher;
-import javax.servlet.ServletException;
+import javax.servlet.ServletContext;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
@@ -40,6 +39,7 @@ import org.opensaml.saml2.metadata.Endpoint;
 import org.opensaml.saml2.metadata.EntityDescriptor;
 import org.opensaml.saml2.metadata.IDPSSODescriptor;
 import org.opensaml.saml2.metadata.SPSSODescriptor;
+import org.opensaml.util.URLBuilder;
 import org.opensaml.ws.message.decoder.MessageDecodingException;
 import org.opensaml.ws.transport.http.HTTPInTransport;
 import org.opensaml.ws.transport.http.HTTPOutTransport;
@@ -49,7 +49,6 @@ import org.opensaml.xml.security.SecurityException;
 import org.opensaml.xml.util.DatatypeHelper;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import org.slf4j.helpers.MessageFormatter;
 
 import edu.internet2.middleware.shibboleth.common.ShibbolethConstants;
 import edu.internet2.middleware.shibboleth.common.profile.ProfileException;
@@ -59,8 +58,8 @@ import edu.internet2.middleware.shibboleth.common.relyingparty.RelyingPartyConfi
 import edu.internet2.middleware.shibboleth.common.relyingparty.provider.SAMLMDRelyingPartyConfigurationManager;
 import edu.internet2.middleware.shibboleth.common.relyingparty.provider.saml1.ShibbolethSSOConfiguration;
 import edu.internet2.middleware.shibboleth.common.util.HttpHelper;
-import edu.internet2.middleware.shibboleth.idp.authn.LoginContext;
 import edu.internet2.middleware.shibboleth.idp.authn.ShibbolethSSOLoginContext;
+import edu.internet2.middleware.shibboleth.idp.authn.LoginContext;
 import edu.internet2.middleware.shibboleth.idp.util.HttpServletHelper;
 
 /** Shibboleth SSO request profile handler. */
@@ -90,7 +89,11 @@ public class ShibbolethSSOProfileHandler extends AbstractSAML1ProfileHandler {
         if (DatatypeHelper.isEmpty(authnManagerPath)) {
             throw new IllegalArgumentException("Authentication manager path may not be null");
         }
-        authenticationManagerPath = authnManagerPath;
+        if (authnManagerPath.startsWith("/")) {
+            authenticationManagerPath = authnManagerPath;
+        } else {
+            authenticationManagerPath = "/" + authnManagerPath;
+        }
 
         authnStatementBuilder = (SAMLObjectBuilder<AuthenticationStatement>) getBuilderFactory().getBuilder(
                 AuthenticationStatement.DEFAULT_ELEMENT_NAME);
@@ -112,20 +115,28 @@ public class ShibbolethSSOProfileHandler extends AbstractSAML1ProfileHandler {
         log.debug("Processing incoming request");
 
         HttpServletRequest httpRequest = ((HttpServletRequestAdapter) inTransport).getWrappedRequest();
-        LoginContext loginContext = HttpServletHelper.getLoginContext(httpRequest);
+        HttpServletResponse httpResponse = ((HttpServletResponseAdapter) outTransport).getWrappedResponse();
+        ServletContext servletContext = httpRequest.getSession().getServletContext();
 
-        if (loginContext == null) {
+       LoginContext loginContext = HttpServletHelper.getLoginContext(
+                getStorageService(), servletContext, httpRequest);
+
+        if (loginContext == null || !(loginContext instanceof ShibbolethSSOLoginContext)) {
             log.debug("Incoming request does not contain a login context, processing as first leg of request");
             performAuthentication(inTransport, outTransport);
-        } else {
+        } else if (loginContext.isPrincipalAuthenticated() || loginContext.getAuthenticationFailure() != null) {
             log.debug("Incoming request contains a login context, processing as second leg of request");
-            completeAuthenticationRequest(inTransport, outTransport);
+            HttpServletHelper.unbindLoginContext(getStorageService(), servletContext, httpRequest, httpResponse);
+            completeAuthenticationRequest((ShibbolethSSOLoginContext)loginContext, inTransport, outTransport);
+        } else {
+            log.debug("Incoming request contained a login context but principal was not authenticated, processing as first leg of request");
+            performAuthentication(inTransport, outTransport);
         }
     }
 
     /**
-     * Creates a {@link LoginContext} an sends the request off to the AuthenticationManager to begin the process of
-     * authenticating the user.
+     * Creates a {@link ShibbolethSSOLoginContext} an sends the request off to the AuthenticationManager to begin the
+     * process of authenticating the user.
      * 
      * @param inTransport inbound message transport
      * @param outTransport outbound message transport
@@ -147,26 +158,24 @@ public class ShibbolethSSOProfileHandler extends AbstractSAML1ProfileHandler {
         loginContext.setDefaultAuthenticationMethod(rpConfig.getDefaultAuthenticationMethod());
         ProfileConfiguration ssoConfig = rpConfig.getProfileConfiguration(ShibbolethSSOConfiguration.PROFILE_ID);
         if (ssoConfig == null) {
-            String msg = MessageFormatter.format("Shibboleth SSO profile is not configured for relying party '{}'",
-                    loginContext.getRelyingPartyId());
+            String msg = "Shibboleth SSO profile is not configured for relying party "
+                    + loginContext.getRelyingPartyId();
             log.warn(msg);
             throw new ProfileException(msg);
         }
 
-        HttpServletHelper.bindLoginContext(loginContext, httpRequest);
+        HttpServletHelper.bindLoginContext(loginContext, getStorageService(), httpRequest.getSession()
+                .getServletContext(), httpRequest, httpResponse);
 
         try {
-            RequestDispatcher dispatcher = httpRequest.getRequestDispatcher(authenticationManagerPath);
-            dispatcher.forward(httpRequest, httpResponse);
-            return;
+            String authnEngineUrl = HttpServletHelper.getContextRelativeUrl(httpRequest, authenticationManagerPath)
+                    .buildURL();
+            log.debug("Redirecting user to authentication engine at {}", authnEngineUrl);
+            httpResponse.sendRedirect(authnEngineUrl);
         } catch (IOException e) {
             String msg = "Error forwarding Shibboleth SSO request to AuthenticationManager";
             log.error(msg, e);
             throw new ProfileException(msg, e);
-        } catch (ServletException e) {
-            String msg = "Error forwarding Shibboleth SSO request to AuthenticationManager";
-            log.error(msg, e);
-            throw new ProfileException(msg, e);
         }
     }
 
@@ -182,8 +191,8 @@ public class ShibbolethSSOProfileHandler extends AbstractSAML1ProfileHandler {
     protected void decodeRequest(ShibbolethSSORequestContext requestContext, HTTPInTransport inTransport,
             HTTPOutTransport outTransport) throws ProfileException {
         if (log.isDebugEnabled()) {
-            log.debug("Decoding message with decoder binding {}",
-                    getInboundMessageDecoder(requestContext).getBindingURI());
+            log.debug("Decoding message with decoder binding {}", getInboundMessageDecoder(requestContext)
+                    .getBindingURI());
         }
 
         HttpServletRequest httpRequest = ((HttpServletRequestAdapter) inTransport).getWrappedRequest();
@@ -205,8 +214,8 @@ public class ShibbolethSSOProfileHandler extends AbstractSAML1ProfileHandler {
         requestContext.setMessageDecoder(decoder);
         try {
             decoder.decode(requestContext);
-            log.debug("Decoded Shibboleth SSO request from relying party '{}'", requestContext
-                    .getInboundMessageIssuer());
+            log.debug("Decoded Shibboleth SSO request from relying party '{}'",
+                    requestContext.getInboundMessageIssuer());
         } catch (MessageDecodingException e) {
             String msg = "Error decoding Shibboleth SSO request";
             log.warn(msg, e);
@@ -230,16 +239,14 @@ public class ShibbolethSSOProfileHandler extends AbstractSAML1ProfileHandler {
      * Creates a response to the Shibboleth SSO and sends the user, with response in tow, back to the relying party
      * after they've been authenticated.
      * 
+     * @param loginContext login context for this request
      * @param inTransport inbound message transport
      * @param outTransport outbound message transport
      * 
      * @throws ProfileException thrown if the response can not be created and sent back to the relying party
      */
-    protected void completeAuthenticationRequest(HTTPInTransport inTransport, HTTPOutTransport outTransport)
-            throws ProfileException {
-        HttpServletRequest httpRequest = ((HttpServletRequestAdapter) inTransport).getWrappedRequest();
-        ShibbolethSSOLoginContext loginContext = (ShibbolethSSOLoginContext) HttpServletHelper.getLoginContext(httpRequest);
-
+    protected void completeAuthenticationRequest(ShibbolethSSOLoginContext loginContext, HTTPInTransport inTransport,
+            HTTPOutTransport outTransport) throws ProfileException {
         ShibbolethSSORequestContext requestContext = buildRequestContext(loginContext, inTransport, outTransport);
 
         Response samlResponse;
@@ -250,16 +257,16 @@ public class ShibbolethSSOProfileHandler extends AbstractSAML1ProfileHandler {
             }
 
             resolveAttributes(requestContext);
-            
+
             ArrayList<Statement> statements = new ArrayList<Statement>();
             statements.add(buildAuthenticationStatement(requestContext));
             if (requestContext.getProfileConfiguration().includeAttributeStatement()) {
-                    AttributeStatement attributeStatement = buildAttributeStatement(requestContext,
-                            "urn:oasis:names:tc:SAML:1.0:cm:bearer");
-                    if (attributeStatement != null) {
-                        requestContext.setReleasedAttributes(requestContext.getAttributes().keySet());
-                        statements.add(attributeStatement);
-                    }
+                AttributeStatement attributeStatement = buildAttributeStatement(requestContext,
+                        "urn:oasis:names:tc:SAML:1.0:cm:bearer");
+                if (attributeStatement != null) {
+                    requestContext.setReleasedAttributes(requestContext.getAttributes().keySet());
+                    statements.add(attributeStatement);
+                }
             }
 
             samlResponse = buildResponse(requestContext, statements);
@@ -299,7 +306,7 @@ public class ShibbolethSSOProfileHandler extends AbstractSAML1ProfileHandler {
         requestContext.setInboundSAMLProtocol(ShibbolethConstants.SHIB_SSO_PROFILE_URI);
 
         requestContext.setOutboundMessageTransport(out);
-        requestContext.setOutboundSAMLProtocol(SAMLConstants.SAML20P_NS);
+        requestContext.setOutboundSAMLProtocol(SAMLConstants.SAML11P_NS);
 
         requestContext.setMetadataProvider(getMetadataProvider());