Be sure not to add null values to list
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / idp / authn / Saml2LoginContext.java
index 8c47253..7173e98 100644 (file)
 
 package edu.internet2.middleware.shibboleth.idp.authn;
 
-
+import java.io.Serializable;
+import java.io.StringReader;
+import java.io.StringWriter;
+import java.util.ArrayList;
 import java.util.List;
 
-import javolution.util.FastList;
-
-import org.apache.log4j.Logger;
+import javax.xml.parsers.DocumentBuilder;
+import javax.xml.parsers.DocumentBuilderFactory;
 
+import org.opensaml.Configuration;
+import org.opensaml.saml2.core.AuthnContext;
 import org.opensaml.saml2.core.AuthnContextClassRef;
-import org.opensaml.saml2.core.AuthnContextDeclRef;
 import org.opensaml.saml2.core.AuthnContextComparisonTypeEnumeration;
+import org.opensaml.saml2.core.AuthnContextDeclRef;
 import org.opensaml.saml2.core.AuthnRequest;
 import org.opensaml.saml2.core.RequestedAuthnContext;
-
+import org.opensaml.xml.io.Marshaller;
+import org.opensaml.xml.io.MarshallingException;
+import org.opensaml.xml.io.Unmarshaller;
+import org.opensaml.xml.io.UnmarshallingException;
+import org.opensaml.xml.util.DatatypeHelper;
+import org.opensaml.xml.util.XMLHelper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.w3c.dom.Element;
+import org.xml.sax.InputSource;
 
 /**
- * A SAML 2.0 {@link LoginContext}. 
+ * A SAML 2.0 {@link LoginContext}.
  * 
  * This class can interpret {@link RequestedAuthnContext} and act accordingly.
  */
-public class Saml2LoginContext extends LoginContext {
-    
-    private static final Logger log = Logger.getLogger(Saml2LoginContext.class);
-    
-    /** The {@link RequestedAuthnContext} */
-    private RequestedAuthnContext ctx;
-    
+public class Saml2LoginContext extends LoginContext implements Serializable {
+
+    /** Serial version UID. */
+    private static final long serialVersionUID = -2518779446947534977L;
+
+    /** Class logger. */
+    private final Logger log = LoggerFactory.getLogger(Saml2LoginContext.class);
     
+    /** Relay state from authentication request. */
+    private String relayState;
+
+    /** Serialized authentication request. */
+    private String serialAuthnRequest;
+
+    /** Unmarshalled authentication request. */
+    private transient AuthnRequest authnRequest;
+
     /**
      * Creates a new instance of Saml2LoginContext.
-     *
-     * @param authnRequest A SAML 2.0 Authentication Request.
+     * 
+     * @param relyingParty entity ID of the relying party
+     * @param state relay state from incoming authentication request
+     * @param request SAML 2.0 Authentication Request
+     * 
+     * @throws MarshallingException thrown if the given request can not be marshalled and serialized into a string
      */
-    public Saml2LoginContext(AuthnRequest authnRequest) {
+    public Saml2LoginContext(String relyingParty, String state, AuthnRequest request) throws MarshallingException {
+        super();
         
-        if (authnRequest != null) {
-            forceAuth = authnRequest.isForceAuthn();
-            passiveAuth = authnRequest.isPassive();
-            ctx = authnRequest.getRequestedAuthnContext();
+        if (relyingParty == null || request == null) {
+            throw new IllegalArgumentException("SAML 2 authentication request and relying party ID may not be null");
         }
+        setRelyingParty(relyingParty);
+        relayState = state;
+        authnRequest = request;
+        serialAuthnRequest = serializeRequest(request);
+        
+        setForceAuthRequired(authnRequest.isForceAuthn());
+        setPassiveAuthRequired(authnRequest.isPassive());
+        getRequestedAuthenticationMethods().addAll(extractRequestedAuthenticationMethods());
+    }
+
+    /**
+     * Gets the authentication request that started the login process.
+     * 
+     * @return authentication request that started the login process
+     * 
+     * @throws UnmarshallingException thrown if the serialized form on the authentication request can be unmarshalled
+     */
+    public AuthnRequest getAuthenticationRequest() throws UnmarshallingException {
+        if (authnRequest == null) {
+            authnRequest = deserializeRequest(serialAuthnRequest);
+        }
+
+        return authnRequest;
     }
     
-    
-    /** 
-     * This method evaluates a SAML2 {@link RequestedAuthnContext}
-     * and returns the list of requested authentication method URIs.
-     *
-     * If the AuthnQuery did not contain a RequestedAuthnContext,
-     * this method will return <code>null</code>.
-     *
-     * @return An array of authentication method URIs, or <code>null</code>.
+    /**
+     * Gets the relay state from the orginating authentication request.
+     * 
+     * @return relay state from the orginating authentication request
      */
-    public String[] getRequestedAuthenticationMethods() {
+    public String getRelayState(){
+        return relayState;
+    }
 
-        if (ctx == null)
+    /**
+     * Gets the requested authentication context information from the authentication request.
+     * 
+     * @return requested authentication context information or null
+     */
+    public RequestedAuthnContext getRequestedAuthenticationContext() {
+        try {
+            AuthnRequest request = getAuthenticationRequest();
+            return request.getRequestedAuthnContext();
+        } catch (UnmarshallingException e) {
             return null;
+        }
+    }
+
+    /**
+     * Serializes an authentication request into a string.
+     * 
+     * @param request the request to serialize
+     * 
+     * @return the serialized form of the string
+     * 
+     * @throws MarshallingException thrown if the request can not be marshalled and serialized
+     */
+    protected String serializeRequest(AuthnRequest request) throws MarshallingException {
+        Marshaller marshaller = Configuration.getMarshallerFactory().getMarshaller(request);
+        Element requestElem = marshaller.marshall(request);
+        StringWriter writer = new StringWriter();
+        XMLHelper.writeNode(requestElem, writer);
+        return writer.toString();
+    }
+
+    /**
+     * Deserailizes an authentication request from a string.
+     * 
+     * @param request request to deserialize
+     * 
+     * @return the request XMLObject
+     * 
+     * @throws UnmarshallingException thrown if the request can no be deserialized and unmarshalled
+     */
+    protected AuthnRequest deserializeRequest(String request) throws UnmarshallingException {
+        DocumentBuilderFactory builderFactory = DocumentBuilderFactory.newInstance();
+        try {
+            DocumentBuilder docBuilder = builderFactory.newDocumentBuilder();
+            InputSource requestInput = new InputSource(new StringReader(request));
+            Element requestElem = docBuilder.parse(requestInput).getDocumentElement();
+            Unmarshaller unmarshaller = Configuration.getUnmarshallerFactory().getUnmarshaller(requestElem);
+            return (AuthnRequest) unmarshaller.unmarshall(requestElem);
+        } catch (Exception e) {
+            throw new UnmarshallingException("Unable to read serialized authentication request");
+        }
+    }
+    
+    /**
+     * Extracts the authentication methods requested within the request.
+     * 
+     * @return requested authentication methods, or an empty list if no preference
+     */
+    protected List<String> extractRequestedAuthenticationMethods(){
+        ArrayList<String> requestedMethods = new ArrayList<String>();
+
+        RequestedAuthnContext authnContext = getRequestedAuthenticationContext();
+        if (authnContext == null) {
+            return requestedMethods;
+        }
 
         // For the immediate future, we only support the "exact" comparator.
-        // XXX: we should probably throw an exception or somehow indicate this 
-        // as an error to the caller.
-        AuthnContextComparisonTypeEnumeration comparator = ctx.getComparison();
+        AuthnContextComparisonTypeEnumeration comparator = authnContext.getComparison();
         if (comparator != null && comparator != AuthnContextComparisonTypeEnumeration.EXACT) {
             log.error("Unsupported comparision operator ( " + comparator
-                + ") in RequestedAuthnContext. Only exact comparisions are supported.");
-            return null;
+                    + ") in RequestedAuthnContext. Only exact comparisions are supported.");
+            return requestedMethods;
         }
-    
+
         // build a list of all requested authn classes and declrefs
-        List<String> requestedAuthnMethods = new FastList<String>();
-        List<AuthnContextClassRef> authnClasses = ctx.getAuthnContextClassRefs();
-        List<AuthnContextDeclRef> authnDeclRefs = ctx.getAuthnContextDeclRefs();
-    
+        List<AuthnContextClassRef> authnClasses = authnContext.getAuthnContextClassRefs();
         if (authnClasses != null) {
             for (AuthnContextClassRef classRef : authnClasses) {
-                if (classRef != null) {
-                    String s = classRef.getAuthnContextClassRef();
-                    if (s != null) {
-                        requestedAuthnMethods.add(s);
-                    }
+                if (classRef != null && !DatatypeHelper.isEmpty(classRef.getAuthnContextClassRef())) {
+                    requestedMethods.add(classRef.getAuthnContextClassRef());
                 }
             }
         }
-    
+
+        List<AuthnContextDeclRef> authnDeclRefs = authnContext.getAuthnContextDeclRefs();
         if (authnDeclRefs != null) {
             for (AuthnContextDeclRef declRef : authnDeclRefs) {
-                if (declRef != null) {
-                    String s = declRef.getAuthnContextDeclRef();
-                    if (s != null) {
-                        requestedAuthnMethods.add(s);
-                    }
+                if (declRef != null&& !DatatypeHelper.isEmpty(declRef.getAuthnContextDeclRef())) {
+                    requestedMethods.add(declRef.getAuthnContextDeclRef());
                 }
             }
         }
-    
-        if (requestedAuthnMethods.size() == 0) {
-            return null;
-        } else {
-            String[] methods = new String[requestedAuthnMethods.size()];
-            return requestedAuthnMethods.toArray(methods);
+        
+        if(requestedMethods.contains(AuthnContext.UNSPECIFIED_AUTHN_CTX)){
+            requestedMethods.clear();
         }
 
+        return requestedMethods;
     }
-}
+}
\ No newline at end of file