More refactoring, moving more into the abstract SAML 2 profile handler
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / idp / authn / Saml2LoginContext.java
1 /*
2  * Copyright [2006] [University Corporation for Advanced Internet Development, Inc.]
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 package edu.internet2.middleware.shibboleth.idp.authn;
18
19 import java.io.Serializable;
20 import java.io.StringReader;
21 import java.io.StringWriter;
22 import java.util.ArrayList;
23 import java.util.List;
24
25 import javax.xml.parsers.DocumentBuilder;
26 import javax.xml.parsers.DocumentBuilderFactory;
27
28 import org.apache.log4j.Logger;
29 import org.opensaml.Configuration;
30 import org.opensaml.saml2.core.AuthnContextClassRef;
31 import org.opensaml.saml2.core.AuthnContextComparisonTypeEnumeration;
32 import org.opensaml.saml2.core.AuthnContextDeclRef;
33 import org.opensaml.saml2.core.AuthnRequest;
34 import org.opensaml.saml2.core.RequestedAuthnContext;
35 import org.opensaml.xml.io.Marshaller;
36 import org.opensaml.xml.io.MarshallingException;
37 import org.opensaml.xml.io.Unmarshaller;
38 import org.opensaml.xml.io.UnmarshallingException;
39 import org.opensaml.xml.util.XMLHelper;
40 import org.w3c.dom.Element;
41 import org.xml.sax.InputSource;
42
43 /**
44  * A SAML 2.0 {@link LoginContext}.
45  * 
46  * This class can interpret {@link RequestedAuthnContext} and act accordingly.
47  */
48 public class Saml2LoginContext extends LoginContext implements Serializable {
49
50     /** Serial version UID. */
51     private static final long serialVersionUID = -2518779446947534977L;
52
53     /** Class logger. */
54     private final Logger log = Logger.getLogger(Saml2LoginContext.class);
55
56     /** Serialized authentication request. */
57     private String serialAuthnRequest;
58
59     /** Unmarshalled authentication request. */
60     private transient AuthnRequest authnRequest;
61
62     /**
63      * Creates a new instance of Saml2LoginContext.
64      * 
65      * @param relyingParty entity ID of the relying party
66      * @param request SAML 2.0 Authentication Request
67      * 
68      * @throws MarshallingException thrown if the given request can not be marshalled and serialized into a string
69      */
70     public Saml2LoginContext(String relyingParty, AuthnRequest request) throws MarshallingException {
71         if (relyingParty == null || request == null) {
72             throw new IllegalArgumentException("SAML 2 authentication request and relying party ID may not be null");
73         }
74
75         serialAuthnRequest = serializeRequest(request);
76         authnRequest = request;
77         setForceAuth(authnRequest.isForceAuthn());
78         setPassiveAuth(authnRequest.isPassive());
79         setRelyingParty(relyingParty);
80     }
81
82     /**
83      * Gets the authentication request that started the login process.
84      * 
85      * @return authentication request that started the login process
86      * 
87      * @throws UnmarshallingException thrown if the serialized form on the authentication request can be unmarshalled
88      */
89     public AuthnRequest getAuthenticationRequest() throws UnmarshallingException {
90         if (authnRequest == null) {
91             authnRequest = deserializeRequest(serialAuthnRequest);
92         }
93
94         return authnRequest;
95     }
96
97     /**
98      * Gets the requested authentication context information from the authentication request.
99      * 
100      * @return requested authentication context information or null
101      */
102     public RequestedAuthnContext getRequestedAuthenticationContext() {
103         try {
104             AuthnRequest request = getAuthenticationRequest();
105             return request.getRequestedAuthnContext();
106         } catch (UnmarshallingException e) {
107             return null;
108         }
109     }
110
111     /**
112      * This method evaluates a SAML2 {@link RequestedAuthnContext} and returns the list of requested authentication
113      * method URIs.
114      * 
115      * If the AuthnQuery did not contain a RequestedAuthnContext, this method will return <code>null</code>.
116      * 
117      * @return An array of authentication method URIs, or <code>null</code>.
118      */
119     public List<String> getRequestedAuthenticationMethods() {
120         ArrayList<String> requestedMethods = new ArrayList<String>();
121
122         RequestedAuthnContext authnContext = getRequestedAuthenticationContext();
123         if (authnContext == null) {
124             return requestedMethods;
125         }
126
127         // For the immediate future, we only support the "exact" comparator.
128         AuthnContextComparisonTypeEnumeration comparator = authnContext.getComparison();
129         if (comparator != null && comparator != AuthnContextComparisonTypeEnumeration.EXACT) {
130             log.error("Unsupported comparision operator ( " + comparator
131                     + ") in RequestedAuthnContext. Only exact comparisions are supported.");
132             return null;
133         }
134
135         // build a list of all requested authn classes and declrefs
136         List<AuthnContextClassRef> authnClasses = authnContext.getAuthnContextClassRefs();
137         List<AuthnContextDeclRef> authnDeclRefs = authnContext.getAuthnContextDeclRefs();
138
139         if (authnClasses != null) {
140             for (AuthnContextClassRef classRef : authnClasses) {
141                 if (classRef != null) {
142                     requestedMethods.add(classRef.getAuthnContextClassRef());
143                 }
144             }
145         }
146
147         if (authnDeclRefs != null) {
148             for (AuthnContextDeclRef declRef : authnDeclRefs) {
149                 if (declRef != null) {
150                     requestedMethods.add(declRef.getAuthnContextDeclRef());
151                 }
152             }
153         }
154
155         return requestedMethods;
156     }
157
158     /**
159      * Serializes an authentication request into a string.
160      * 
161      * @param request the request to serialize
162      * 
163      * @return the serialized form of the string
164      * 
165      * @throws MarshallingException thrown if the request can not be marshalled and serialized
166      */
167     protected String serializeRequest(AuthnRequest request) throws MarshallingException {
168         Marshaller marshaller = Configuration.getMarshallerFactory().getMarshaller(request);
169         Element requestElem = marshaller.marshall(request);
170         StringWriter writer = new StringWriter();
171         XMLHelper.writeNode(requestElem, writer);
172         return writer.toString();
173     }
174
175     /**
176      * Deserailizes an authentication request from a string.
177      * 
178      * @param request request to deserialize
179      * 
180      * @return the request XMLObject
181      * 
182      * @throws UnmarshallingException thrown if the request can no be deserialized and unmarshalled
183      */
184     protected AuthnRequest deserializeRequest(String request) throws UnmarshallingException {
185         DocumentBuilderFactory builderFactory = DocumentBuilderFactory.newInstance();
186         try {
187             DocumentBuilder docBuilder = builderFactory.newDocumentBuilder();
188             InputSource requestInput = new InputSource(new StringReader(request));
189             Element requestElem = docBuilder.parse(requestInput).getDocumentElement();
190             Unmarshaller unmarshaller = Configuration.getUnmarshallerFactory().getUnmarshaller(requestElem);
191             return (AuthnRequest) unmarshaller.unmarshall(requestElem);
192         } catch (Exception e) {
193             throw new UnmarshallingException("Unable to read serialized authentication request");
194         }
195     }
196 }