645ac954666633c760ad31a69d2bf582706c576b
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / idp / authn / Saml2LoginContext.java
1 /*
2  * Copyright [2006] [University Corporation for Advanced Internet Development, Inc.]
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 package edu.internet2.middleware.shibboleth.idp.authn;
18
19 import java.io.Serializable;
20 import java.io.StringReader;
21 import java.io.StringWriter;
22 import java.util.ArrayList;
23 import java.util.List;
24
25 import javax.xml.parsers.DocumentBuilder;
26 import javax.xml.parsers.DocumentBuilderFactory;
27
28 import org.opensaml.Configuration;
29 import org.opensaml.saml2.core.AuthnContext;
30 import org.opensaml.saml2.core.AuthnContextClassRef;
31 import org.opensaml.saml2.core.AuthnContextComparisonTypeEnumeration;
32 import org.opensaml.saml2.core.AuthnContextDeclRef;
33 import org.opensaml.saml2.core.AuthnRequest;
34 import org.opensaml.saml2.core.RequestedAuthnContext;
35 import org.opensaml.xml.io.Marshaller;
36 import org.opensaml.xml.io.MarshallingException;
37 import org.opensaml.xml.io.Unmarshaller;
38 import org.opensaml.xml.io.UnmarshallingException;
39 import org.opensaml.xml.util.XMLHelper;
40 import org.slf4j.Logger;
41 import org.slf4j.LoggerFactory;
42 import org.w3c.dom.Element;
43 import org.xml.sax.InputSource;
44
45 /**
46  * A SAML 2.0 {@link LoginContext}.
47  * 
48  * This class can interpret {@link RequestedAuthnContext} and act accordingly.
49  */
50 public class Saml2LoginContext extends LoginContext implements Serializable {
51
52     /** Serial version UID. */
53     private static final long serialVersionUID = -2518779446947534977L;
54
55     /** Class logger. */
56     private final Logger log = LoggerFactory.getLogger(Saml2LoginContext.class);
57     
58     /** Relay state from authentication request. */
59     private String relayState;
60
61     /** Serialized authentication request. */
62     private String serialAuthnRequest;
63
64     /** Unmarshalled authentication request. */
65     private transient AuthnRequest authnRequest;
66
67     /**
68      * Creates a new instance of Saml2LoginContext.
69      * 
70      * @param relyingParty entity ID of the relying party
71      * @param state relay state from incoming authentication request
72      * @param request SAML 2.0 Authentication Request
73      * 
74      * @throws MarshallingException thrown if the given request can not be marshalled and serialized into a string
75      */
76     public Saml2LoginContext(String relyingParty, String state, AuthnRequest request) throws MarshallingException {
77         super();
78         
79         if (relyingParty == null || request == null) {
80             throw new IllegalArgumentException("SAML 2 authentication request and relying party ID may not be null");
81         }
82         setRelyingParty(relyingParty);
83         relayState = state;
84         authnRequest = request;
85         serialAuthnRequest = serializeRequest(request);
86         
87         setForceAuthRequired(authnRequest.isForceAuthn());
88         setPassiveAuthRequired(authnRequest.isPassive());
89         getRequestedAuthenticationMethods().addAll(extractRequestedAuthenticationMethods());
90     }
91
92     /**
93      * Gets the authentication request that started the login process.
94      * 
95      * @return authentication request that started the login process
96      * 
97      * @throws UnmarshallingException thrown if the serialized form on the authentication request can be unmarshalled
98      */
99     public AuthnRequest getAuthenticationRequest() throws UnmarshallingException {
100         if (authnRequest == null) {
101             authnRequest = deserializeRequest(serialAuthnRequest);
102         }
103
104         return authnRequest;
105     }
106     
107     /**
108      * Gets the relay state from the orginating authentication request.
109      * 
110      * @return relay state from the orginating authentication request
111      */
112     public String getRelayState(){
113         return relayState;
114     }
115
116     /**
117      * Gets the requested authentication context information from the authentication request.
118      * 
119      * @return requested authentication context information or null
120      */
121     public RequestedAuthnContext getRequestedAuthenticationContext() {
122         try {
123             AuthnRequest request = getAuthenticationRequest();
124             return request.getRequestedAuthnContext();
125         } catch (UnmarshallingException e) {
126             return null;
127         }
128     }
129
130     /**
131      * Serializes an authentication request into a string.
132      * 
133      * @param request the request to serialize
134      * 
135      * @return the serialized form of the string
136      * 
137      * @throws MarshallingException thrown if the request can not be marshalled and serialized
138      */
139     protected String serializeRequest(AuthnRequest request) throws MarshallingException {
140         Marshaller marshaller = Configuration.getMarshallerFactory().getMarshaller(request);
141         Element requestElem = marshaller.marshall(request);
142         StringWriter writer = new StringWriter();
143         XMLHelper.writeNode(requestElem, writer);
144         return writer.toString();
145     }
146
147     /**
148      * Deserailizes an authentication request from a string.
149      * 
150      * @param request request to deserialize
151      * 
152      * @return the request XMLObject
153      * 
154      * @throws UnmarshallingException thrown if the request can no be deserialized and unmarshalled
155      */
156     protected AuthnRequest deserializeRequest(String request) throws UnmarshallingException {
157         DocumentBuilderFactory builderFactory = DocumentBuilderFactory.newInstance();
158         try {
159             DocumentBuilder docBuilder = builderFactory.newDocumentBuilder();
160             InputSource requestInput = new InputSource(new StringReader(request));
161             Element requestElem = docBuilder.parse(requestInput).getDocumentElement();
162             Unmarshaller unmarshaller = Configuration.getUnmarshallerFactory().getUnmarshaller(requestElem);
163             return (AuthnRequest) unmarshaller.unmarshall(requestElem);
164         } catch (Exception e) {
165             throw new UnmarshallingException("Unable to read serialized authentication request");
166         }
167     }
168     
169     /**
170      * Extracts the authentication methods requested within the request.
171      * 
172      * @return requested authentication methods, or an empty list if no preference
173      */
174     protected List<String> extractRequestedAuthenticationMethods(){
175         ArrayList<String> requestedMethods = new ArrayList<String>();
176
177         RequestedAuthnContext authnContext = getRequestedAuthenticationContext();
178         if (authnContext == null) {
179             return requestedMethods;
180         }
181
182         // For the immediate future, we only support the "exact" comparator.
183         AuthnContextComparisonTypeEnumeration comparator = authnContext.getComparison();
184         if (comparator != null && comparator != AuthnContextComparisonTypeEnumeration.EXACT) {
185             log.error("Unsupported comparision operator ( " + comparator
186                     + ") in RequestedAuthnContext. Only exact comparisions are supported.");
187             return requestedMethods;
188         }
189
190         // build a list of all requested authn classes and declrefs
191         List<AuthnContextClassRef> authnClasses = authnContext.getAuthnContextClassRefs();
192         if (authnClasses != null) {
193             for (AuthnContextClassRef classRef : authnClasses) {
194                 if (classRef != null) {
195                     requestedMethods.add(classRef.getAuthnContextClassRef());
196                 }
197             }
198         }
199
200         List<AuthnContextDeclRef> authnDeclRefs = authnContext.getAuthnContextDeclRefs();
201         if (authnDeclRefs != null) {
202             for (AuthnContextDeclRef declRef : authnDeclRefs) {
203                 if (declRef != null) {
204                     requestedMethods.add(declRef.getAuthnContextDeclRef());
205                 }
206             }
207         }
208         
209         if(requestedMethods.contains(AuthnContext.UNSPECIFIED_AUTHN_CTX)){
210             requestedMethods.clear();
211         }
212
213         return requestedMethods;
214     }
215 }