Ensure login context is bound to the request by the profile handlers
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / idp / authn / AuthenticationEngine.java
index 6350c8f..a997626 100644 (file)
@@ -139,8 +139,17 @@ public class AuthenticationEngine extends HttpServlet {
             LOG.error("HTTP Response already committed");
         }
 
-        HttpSession httpSession = httpRequest.getSession();
-        LoginContext loginContext = (LoginContext) httpSession.getAttribute(LoginContext.LOGIN_CONTEXT_KEY);
+        LoginContext loginContext = (LoginContext) httpRequest.getAttribute(LoginContext.LOGIN_CONTEXT_KEY);
+        if (loginContext == null) {
+            // When the login context comes from the profile handlers its attached to the request
+            // The authn engine attaches it to the session to allow the handlers to do any number of
+            // request/response pairs without maintaining or losing the login context
+            loginContext = (LoginContext) httpRequest.getSession().getAttribute(LoginContext.LOGIN_CONTEXT_KEY);
+        } else {
+            // Clean out any old state that might be lying around
+            httpRequest.getSession().removeAttribute(LoginContext.LOGIN_CONTEXT_KEY);
+        }
+
         if (loginContext == null) {
             LOG.error("Incoming request does not have attached login context");
             throw new ServletException("Incoming request does not have attached login context");
@@ -181,7 +190,8 @@ public class AuthenticationEngine extends HttpServlet {
                 LOG.debug("Forced authentication not required, trying existing authentication methods");
                 for (AuthenticationMethodInformation activeAuthnMethod : activeAuthnMethods) {
                     if (possibleLoginHandlers.containsKey(activeAuthnMethod.getAuthenticationMethod())) {
-                        completeAuthenticationWithActiveMethod(activeAuthnMethod, httpRequest, httpResponse);
+                        completeAuthenticationWithActiveMethod(loginContext, activeAuthnMethod, httpRequest,
+                                httpResponse);
                         return;
                     }
                 }
@@ -316,23 +326,23 @@ public class AuthenticationEngine extends HttpServlet {
         loginContext.setAuthenticationMethod(authnMethod);
         loginContext.setAuthenticationEngineURL(HttpHelper.getRequestUriWithoutContext(httpRequest));
         logingHandler.login(httpRequest, httpResponse);
+        httpRequest.getSession().setAttribute(LoginContext.LOGIN_CONTEXT_KEY, loginContext);
     }
 
     /**
      * Completes the authentication request using an existing, active, authentication method for the current user.
      * 
+     * @param loginContext current login context
      * @param authenticationMethod authentication method to use to complete the request
      * @param httpRequest current HTTP request
      * @param httpResponse current HTTP response
      */
-    protected void completeAuthenticationWithActiveMethod(AuthenticationMethodInformation authenticationMethod,
-            HttpServletRequest httpRequest, HttpServletResponse httpResponse) {
-        HttpSession httpSession = httpRequest.getSession();
-
+    protected void completeAuthenticationWithActiveMethod(LoginContext loginContext,
+            AuthenticationMethodInformation authenticationMethod, HttpServletRequest httpRequest,
+            HttpServletResponse httpResponse) {
         Session shibSession = (Session) httpRequest.getAttribute(Session.HTTP_SESSION_BINDING_ATTRIBUTE);
 
         LOG.debug("Populating login context with existing session and authentication method information.");
-        LoginContext loginContext = (LoginContext) httpSession.getAttribute(LoginContext.LOGIN_CONTEXT_KEY);
         loginContext.setAuthenticationDuration(authenticationMethod.getAuthenticationDuration());
         loginContext.setAuthenticationInstant(authenticationMethod.getAuthenticationInstant());
         loginContext.setAuthenticationMethod(authenticationMethod.getAuthenticationMethod());