Be sure not to add null values to list
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / idp / authn / Saml2LoginContext.java
index b6aaa73..7173e98 100644 (file)
@@ -25,8 +25,8 @@ import java.util.List;
 import javax.xml.parsers.DocumentBuilder;
 import javax.xml.parsers.DocumentBuilderFactory;
 
-import org.apache.log4j.Logger;
 import org.opensaml.Configuration;
+import org.opensaml.saml2.core.AuthnContext;
 import org.opensaml.saml2.core.AuthnContextClassRef;
 import org.opensaml.saml2.core.AuthnContextComparisonTypeEnumeration;
 import org.opensaml.saml2.core.AuthnContextDeclRef;
@@ -36,7 +36,10 @@ import org.opensaml.xml.io.Marshaller;
 import org.opensaml.xml.io.MarshallingException;
 import org.opensaml.xml.io.Unmarshaller;
 import org.opensaml.xml.io.UnmarshallingException;
+import org.opensaml.xml.util.DatatypeHelper;
 import org.opensaml.xml.util.XMLHelper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 import org.w3c.dom.Element;
 import org.xml.sax.InputSource;
 
@@ -51,7 +54,10 @@ public class Saml2LoginContext extends LoginContext implements Serializable {
     private static final long serialVersionUID = -2518779446947534977L;
 
     /** Class logger. */
-    private final Logger log = Logger.getLogger(Saml2LoginContext.class);
+    private final Logger log = LoggerFactory.getLogger(Saml2LoginContext.class);
+    
+    /** Relay state from authentication request. */
+    private String relayState;
 
     /** Serialized authentication request. */
     private String serialAuthnRequest;
@@ -63,22 +69,24 @@ public class Saml2LoginContext extends LoginContext implements Serializable {
      * Creates a new instance of Saml2LoginContext.
      * 
      * @param relyingParty entity ID of the relying party
+     * @param state relay state from incoming authentication request
      * @param request SAML 2.0 Authentication Request
      * 
      * @throws MarshallingException thrown if the given request can not be marshalled and serialized into a string
      */
-    public Saml2LoginContext(String relyingParty, AuthnRequest request) throws MarshallingException {
+    public Saml2LoginContext(String relyingParty, String state, AuthnRequest request) throws MarshallingException {
         super();
         
         if (relyingParty == null || request == null) {
             throw new IllegalArgumentException("SAML 2 authentication request and relying party ID may not be null");
         }
         setRelyingParty(relyingParty);
+        relayState = state;
         authnRequest = request;
         serialAuthnRequest = serializeRequest(request);
         
-        setForceAuth(authnRequest.isForceAuthn());
-        setPassiveAuth(authnRequest.isPassive());
+        setForceAuthRequired(authnRequest.isForceAuthn());
+        setPassiveAuthRequired(authnRequest.isPassive());
         getRequestedAuthenticationMethods().addAll(extractRequestedAuthenticationMethods());
     }
 
@@ -96,6 +104,15 @@ public class Saml2LoginContext extends LoginContext implements Serializable {
 
         return authnRequest;
     }
+    
+    /**
+     * Gets the relay state from the orginating authentication request.
+     * 
+     * @return relay state from the orginating authentication request
+     */
+    public String getRelayState(){
+        return relayState;
+    }
 
     /**
      * Gets the requested authentication context information from the authentication request.
@@ -153,7 +170,7 @@ public class Saml2LoginContext extends LoginContext implements Serializable {
     /**
      * Extracts the authentication methods requested within the request.
      * 
-     * @return requested authentication methods
+     * @return requested authentication methods, or an empty list if no preference
      */
     protected List<String> extractRequestedAuthenticationMethods(){
         ArrayList<String> requestedMethods = new ArrayList<String>();
@@ -173,23 +190,26 @@ public class Saml2LoginContext extends LoginContext implements Serializable {
 
         // build a list of all requested authn classes and declrefs
         List<AuthnContextClassRef> authnClasses = authnContext.getAuthnContextClassRefs();
-        List<AuthnContextDeclRef> authnDeclRefs = authnContext.getAuthnContextDeclRefs();
-
         if (authnClasses != null) {
             for (AuthnContextClassRef classRef : authnClasses) {
-                if (classRef != null) {
+                if (classRef != null && !DatatypeHelper.isEmpty(classRef.getAuthnContextClassRef())) {
                     requestedMethods.add(classRef.getAuthnContextClassRef());
                 }
             }
         }
 
+        List<AuthnContextDeclRef> authnDeclRefs = authnContext.getAuthnContextDeclRefs();
         if (authnDeclRefs != null) {
             for (AuthnContextDeclRef declRef : authnDeclRefs) {
-                if (declRef != null) {
+                if (declRef != null&& !DatatypeHelper.isEmpty(declRef.getAuthnContextDeclRef())) {
                     requestedMethods.add(declRef.getAuthnContextDeclRef());
                 }
             }
         }
+        
+        if(requestedMethods.contains(AuthnContext.UNSPECIFIED_AUTHN_CTX)){
+            requestedMethods.clear();
+        }
 
         return requestedMethods;
     }