Be sure not to add null values to list
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / idp / authn / Saml2LoginContext.java
index 945716a..7173e98 100644 (file)
 
 package edu.internet2.middleware.shibboleth.idp.authn;
 
+import java.io.Serializable;
+import java.io.StringReader;
+import java.io.StringWriter;
+import java.util.ArrayList;
 import java.util.List;
-import java.util.LinkedList;
 
-import org.apache.log4j.Logger;
+import javax.xml.parsers.DocumentBuilder;
+import javax.xml.parsers.DocumentBuilderFactory;
 
+import org.opensaml.Configuration;
+import org.opensaml.saml2.core.AuthnContext;
 import org.opensaml.saml2.core.AuthnContextClassRef;
-import org.opensaml.saml2.core.AuthnContextDeclRef;
 import org.opensaml.saml2.core.AuthnContextComparisonTypeEnumeration;
+import org.opensaml.saml2.core.AuthnContextDeclRef;
 import org.opensaml.saml2.core.AuthnRequest;
 import org.opensaml.saml2.core.RequestedAuthnContext;
+import org.opensaml.xml.io.Marshaller;
+import org.opensaml.xml.io.MarshallingException;
+import org.opensaml.xml.io.Unmarshaller;
+import org.opensaml.xml.io.UnmarshallingException;
+import org.opensaml.xml.util.DatatypeHelper;
+import org.opensaml.xml.util.XMLHelper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.w3c.dom.Element;
+import org.xml.sax.InputSource;
 
 /**
  * A SAML 2.0 {@link LoginContext}.
  * 
  * This class can interpret {@link RequestedAuthnContext} and act accordingly.
  */
-public class Saml2LoginContext extends LoginContext {
-
-       private static final Logger log = Logger.getLogger(Saml2LoginContext.class);
-
-       /** The {@link RequestedAuthnContext} */
-       private RequestedAuthnContext ctx;
-
-       /**
-        * Creates a new instance of Saml2LoginContext.
-        * 
-        * @param authnRequest
-        *            A SAML 2.0 Authentication Request.
-        */
-       public Saml2LoginContext(AuthnRequest authnRequest) {
-
-               if (authnRequest != null) {
-                       forceAuth = authnRequest.isForceAuthn();
-                       passiveAuth = authnRequest.isPassive();
-                       ctx = authnRequest.getRequestedAuthnContext();
-               }
-       }
-
-       /**
-        * This method evaluates a SAML2 {@link RequestedAuthnContext} and returns
-        * the list of requested authentication method URIs.
-        * 
-        * If the AuthnQuery did not contain a RequestedAuthnContext, this method
-        * will return <code>null</code>.
-        * 
-        * @return An array of authentication method URIs, or <code>null</code>.
-        */
-       public String[] getRequestedAuthenticationMethods() {
-
-               if (ctx == null)
-                       return null;
-
-               // For the immediate future, we only support the "exact" comparator.
-               // XXX: we should probably throw an exception or somehow indicate this
-               // as an error to the caller.
-               AuthnContextComparisonTypeEnumeration comparator = ctx.getComparison();
-               if (comparator != null
-                               && comparator != AuthnContextComparisonTypeEnumeration.EXACT) {
-                       log
-                                       .error("Unsupported comparision operator ( "
-                                                       + comparator
-                                                       + ") in RequestedAuthnContext. Only exact comparisions are supported.");
-                       return null;
-               }
-
-               // build a list of all requested authn classes and declrefs
-               List<String> requestedAuthnMethods = new LinkedList<String>();
-               List<AuthnContextClassRef> authnClasses = ctx
-                               .getAuthnContextClassRefs();
-               List<AuthnContextDeclRef> authnDeclRefs = ctx.getAuthnContextDeclRefs();
-
-               if (authnClasses != null) {
-                       for (AuthnContextClassRef classRef : authnClasses) {
-                               if (classRef != null) {
-                                       String s = classRef.getAuthnContextClassRef();
-                                       if (s != null) {
-                                               requestedAuthnMethods.add(s);
-                                       }
-                               }
-                       }
-               }
-
-               if (authnDeclRefs != null) {
-                       for (AuthnContextDeclRef declRef : authnDeclRefs) {
-                               if (declRef != null) {
-                                       String s = declRef.getAuthnContextDeclRef();
-                                       if (s != null) {
-                                               requestedAuthnMethods.add(s);
-                                       }
-                               }
-                       }
-               }
-
-               if (requestedAuthnMethods.size() == 0) {
-                       return null;
-               } else {
-                       String[] methods = new String[requestedAuthnMethods.size()];
-                       return requestedAuthnMethods.toArray(methods);
-               }
-
-       }
-}
+public class Saml2LoginContext extends LoginContext implements Serializable {
+
+    /** Serial version UID. */
+    private static final long serialVersionUID = -2518779446947534977L;
+
+    /** Class logger. */
+    private final Logger log = LoggerFactory.getLogger(Saml2LoginContext.class);
+    
+    /** Relay state from authentication request. */
+    private String relayState;
+
+    /** Serialized authentication request. */
+    private String serialAuthnRequest;
+
+    /** Unmarshalled authentication request. */
+    private transient AuthnRequest authnRequest;
+
+    /**
+     * Creates a new instance of Saml2LoginContext.
+     * 
+     * @param relyingParty entity ID of the relying party
+     * @param state relay state from incoming authentication request
+     * @param request SAML 2.0 Authentication Request
+     * 
+     * @throws MarshallingException thrown if the given request can not be marshalled and serialized into a string
+     */
+    public Saml2LoginContext(String relyingParty, String state, AuthnRequest request) throws MarshallingException {
+        super();
+        
+        if (relyingParty == null || request == null) {
+            throw new IllegalArgumentException("SAML 2 authentication request and relying party ID may not be null");
+        }
+        setRelyingParty(relyingParty);
+        relayState = state;
+        authnRequest = request;
+        serialAuthnRequest = serializeRequest(request);
+        
+        setForceAuthRequired(authnRequest.isForceAuthn());
+        setPassiveAuthRequired(authnRequest.isPassive());
+        getRequestedAuthenticationMethods().addAll(extractRequestedAuthenticationMethods());
+    }
+
+    /**
+     * Gets the authentication request that started the login process.
+     * 
+     * @return authentication request that started the login process
+     * 
+     * @throws UnmarshallingException thrown if the serialized form on the authentication request can be unmarshalled
+     */
+    public AuthnRequest getAuthenticationRequest() throws UnmarshallingException {
+        if (authnRequest == null) {
+            authnRequest = deserializeRequest(serialAuthnRequest);
+        }
+
+        return authnRequest;
+    }
+    
+    /**
+     * Gets the relay state from the orginating authentication request.
+     * 
+     * @return relay state from the orginating authentication request
+     */
+    public String getRelayState(){
+        return relayState;
+    }
+
+    /**
+     * Gets the requested authentication context information from the authentication request.
+     * 
+     * @return requested authentication context information or null
+     */
+    public RequestedAuthnContext getRequestedAuthenticationContext() {
+        try {
+            AuthnRequest request = getAuthenticationRequest();
+            return request.getRequestedAuthnContext();
+        } catch (UnmarshallingException e) {
+            return null;
+        }
+    }
+
+    /**
+     * Serializes an authentication request into a string.
+     * 
+     * @param request the request to serialize
+     * 
+     * @return the serialized form of the string
+     * 
+     * @throws MarshallingException thrown if the request can not be marshalled and serialized
+     */
+    protected String serializeRequest(AuthnRequest request) throws MarshallingException {
+        Marshaller marshaller = Configuration.getMarshallerFactory().getMarshaller(request);
+        Element requestElem = marshaller.marshall(request);
+        StringWriter writer = new StringWriter();
+        XMLHelper.writeNode(requestElem, writer);
+        return writer.toString();
+    }
+
+    /**
+     * Deserailizes an authentication request from a string.
+     * 
+     * @param request request to deserialize
+     * 
+     * @return the request XMLObject
+     * 
+     * @throws UnmarshallingException thrown if the request can no be deserialized and unmarshalled
+     */
+    protected AuthnRequest deserializeRequest(String request) throws UnmarshallingException {
+        DocumentBuilderFactory builderFactory = DocumentBuilderFactory.newInstance();
+        try {
+            DocumentBuilder docBuilder = builderFactory.newDocumentBuilder();
+            InputSource requestInput = new InputSource(new StringReader(request));
+            Element requestElem = docBuilder.parse(requestInput).getDocumentElement();
+            Unmarshaller unmarshaller = Configuration.getUnmarshallerFactory().getUnmarshaller(requestElem);
+            return (AuthnRequest) unmarshaller.unmarshall(requestElem);
+        } catch (Exception e) {
+            throw new UnmarshallingException("Unable to read serialized authentication request");
+        }
+    }
+    
+    /**
+     * Extracts the authentication methods requested within the request.
+     * 
+     * @return requested authentication methods, or an empty list if no preference
+     */
+    protected List<String> extractRequestedAuthenticationMethods(){
+        ArrayList<String> requestedMethods = new ArrayList<String>();
+
+        RequestedAuthnContext authnContext = getRequestedAuthenticationContext();
+        if (authnContext == null) {
+            return requestedMethods;
+        }
+
+        // For the immediate future, we only support the "exact" comparator.
+        AuthnContextComparisonTypeEnumeration comparator = authnContext.getComparison();
+        if (comparator != null && comparator != AuthnContextComparisonTypeEnumeration.EXACT) {
+            log.error("Unsupported comparision operator ( " + comparator
+                    + ") in RequestedAuthnContext. Only exact comparisions are supported.");
+            return requestedMethods;
+        }
+
+        // build a list of all requested authn classes and declrefs
+        List<AuthnContextClassRef> authnClasses = authnContext.getAuthnContextClassRefs();
+        if (authnClasses != null) {
+            for (AuthnContextClassRef classRef : authnClasses) {
+                if (classRef != null && !DatatypeHelper.isEmpty(classRef.getAuthnContextClassRef())) {
+                    requestedMethods.add(classRef.getAuthnContextClassRef());
+                }
+            }
+        }
+
+        List<AuthnContextDeclRef> authnDeclRefs = authnContext.getAuthnContextDeclRefs();
+        if (authnDeclRefs != null) {
+            for (AuthnContextDeclRef declRef : authnDeclRefs) {
+                if (declRef != null&& !DatatypeHelper.isEmpty(declRef.getAuthnContextDeclRef())) {
+                    requestedMethods.add(declRef.getAuthnContextDeclRef());
+                }
+            }
+        }
+        
+        if(requestedMethods.contains(AuthnContext.UNSPECIFIED_AUTHN_CTX)){
+            requestedMethods.clear();
+        }
+
+        return requestedMethods;
+    }
+}
\ No newline at end of file