Populate subject locality info from configuration or request
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / idp / profile / saml2 / SSOProfileHandler.java
1 /*
2  * Copyright [2007] [University Corporation for Advanced Internet Development, Inc.]
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 package edu.internet2.middleware.shibboleth.idp.profile.saml2;
18
19 import java.io.IOException;
20 import java.util.ArrayList;
21
22 import javax.servlet.RequestDispatcher;
23 import javax.servlet.ServletException;
24 import javax.servlet.ServletRequest;
25 import javax.servlet.ServletResponse;
26 import javax.servlet.http.HttpServletRequest;
27 import javax.servlet.http.HttpSession;
28
29 import org.apache.log4j.Logger;
30 import org.opensaml.common.SAMLObjectBuilder;
31 import org.opensaml.common.binding.BindingException;
32 import org.opensaml.common.binding.decoding.MessageDecoder;
33 import org.opensaml.common.binding.encoding.MessageEncoder;
34 import org.opensaml.common.binding.security.SAMLSecurityPolicy;
35 import org.opensaml.common.xml.SAMLConstants;
36 import org.opensaml.saml2.core.SubjectLocality;
37 import org.opensaml.saml2.binding.AuthnResponseEndpointSelector;
38 import org.opensaml.saml2.core.AuthnContext;
39 import org.opensaml.saml2.core.AuthnContextClassRef;
40 import org.opensaml.saml2.core.AuthnContextDeclRef;
41 import org.opensaml.saml2.core.AuthnRequest;
42 import org.opensaml.saml2.core.AuthnStatement;
43 import org.opensaml.saml2.core.RequestedAuthnContext;
44 import org.opensaml.saml2.core.Response;
45 import org.opensaml.saml2.core.Statement;
46 import org.opensaml.saml2.core.StatusCode;
47 import org.opensaml.saml2.core.Subject;
48 import org.opensaml.saml2.metadata.AssertionConsumerService;
49 import org.opensaml.saml2.metadata.Endpoint;
50 import org.opensaml.saml2.metadata.provider.MetadataProviderException;
51 import org.opensaml.ws.security.SecurityPolicyException;
52 import org.opensaml.xml.io.MarshallingException;
53 import org.opensaml.xml.io.UnmarshallingException;
54
55 import edu.internet2.middleware.shibboleth.common.profile.ProfileException;
56 import edu.internet2.middleware.shibboleth.common.profile.ProfileRequest;
57 import edu.internet2.middleware.shibboleth.common.profile.ProfileResponse;
58 import edu.internet2.middleware.shibboleth.common.relyingparty.RelyingPartyConfiguration;
59 import edu.internet2.middleware.shibboleth.common.relyingparty.provider.saml2.SSOConfiguration;
60 import edu.internet2.middleware.shibboleth.common.util.HttpHelper;
61 import edu.internet2.middleware.shibboleth.idp.authn.LoginContext;
62 import edu.internet2.middleware.shibboleth.idp.authn.Saml2LoginContext;
63
64 /** SAML 2.0 SSO request profile handler. */
65 public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
66
67     /** Class logger. */
68     private final Logger log = Logger.getLogger(SSOProfileHandler.class);
69
70     /** Builder of AuthnStatement objects. */
71     private SAMLObjectBuilder<AuthnStatement> authnStatementBuilder;
72
73     /** Builder of AuthnContext objects. */
74     private SAMLObjectBuilder<AuthnContext> authnContextBuilder;
75
76     /** Builder of AuthnContextClassRef objects. */
77     private SAMLObjectBuilder<AuthnContextClassRef> authnContextClassRefBuilder;
78
79     /** Builder of AuthnContextDeclRef objects. */
80     private SAMLObjectBuilder<AuthnContextDeclRef> authnContextDeclRefBuilder;
81
82     /** Builder of SubjectLocality objects. */
83     private SAMLObjectBuilder<SubjectLocality> subjectLocalityBuilder;
84
85     /** URL of the authentication manager servlet. */
86     private String authenticationManagerPath;
87
88     /** URI of request decoder. */
89     private String decodingBinding;
90
91     /**
92      * Constructor.
93      * 
94      * @param authnManagerPath path to the authentication manager servlet
95      * @param decoder URI of the request decoder to use
96      */
97     @SuppressWarnings("unchecked")
98     public SSOProfileHandler(String authnManagerPath, String decoder) {
99         super();
100
101         if (authnManagerPath == null || decoder == null) {
102             throw new IllegalArgumentException("AuthN manager path or decoding bindings URI may not be null");
103         }
104
105         authenticationManagerPath = authnManagerPath;
106         decodingBinding = decoder;
107
108         authnStatementBuilder = (SAMLObjectBuilder<AuthnStatement>) getBuilderFactory().getBuilder(
109                 AuthnStatement.DEFAULT_ELEMENT_NAME);
110         authnContextBuilder = (SAMLObjectBuilder<AuthnContext>) getBuilderFactory().getBuilder(
111                 AuthnContext.DEFAULT_ELEMENT_NAME);
112         authnContextClassRefBuilder = (SAMLObjectBuilder<AuthnContextClassRef>) getBuilderFactory().getBuilder(
113                 AuthnContextClassRef.DEFAULT_ELEMENT_NAME);
114         authnContextDeclRefBuilder = (SAMLObjectBuilder<AuthnContextDeclRef>) getBuilderFactory().getBuilder(
115                 AuthnContextDeclRef.DEFAULT_ELEMENT_NAME);
116         subjectLocalityBuilder = (SAMLObjectBuilder<SubjectLocality>) getBuilderFactory().getBuilder(
117                 SubjectLocality.DEFAULT_ELEMENT_NAME);
118     }
119
120     /**
121      * Convenience method for getting the SAML 2 AuthnStatement builder.
122      * 
123      * @return SAML 2 AuthnStatement builder
124      */
125     public SAMLObjectBuilder<AuthnStatement> getAuthnStatementBuilder() {
126         return authnStatementBuilder;
127     }
128
129     /**
130      * Convenience method for getting the SAML 2 AuthnContext builder.
131      * 
132      * @return SAML 2 AuthnContext builder
133      */
134     public SAMLObjectBuilder<AuthnContext> getAuthnContextBuilder() {
135         return authnContextBuilder;
136     }
137
138     /**
139      * Convenience method for getting the SAML 2 AuthnContextClassRef builder.
140      * 
141      * @return SAML 2 AuthnContextClassRef builder
142      */
143     public SAMLObjectBuilder<AuthnContextClassRef> getAuthnContextClassRefBuilder() {
144         return authnContextClassRefBuilder;
145     }
146
147     /**
148      * Convenience method for getting the SAML 2 AuthnContextDeclRef builder.
149      * 
150      * @return SAML 2 AuthnContextDeclRef builder
151      */
152     public SAMLObjectBuilder<AuthnContextDeclRef> getAuthnContextDeclRefBuilder() {
153         return authnContextDeclRefBuilder;
154     }
155
156     /** {@inheritDoc} */
157     public String getProfileId() {
158         return "urn:mace:shibboleth:2.0:idp:profiles:saml2:request:sso";
159     }
160
161     /** {@inheritDoc} */
162     public void processRequest(ProfileRequest<ServletRequest> request, ProfileResponse<ServletResponse> response)
163             throws ProfileException {
164
165         HttpSession httpSession = ((HttpServletRequest) request.getRawRequest()).getSession(true);
166         if (httpSession.getAttribute(LoginContext.LOGIN_CONTEXT_KEY) == null) {
167             performAuthentication(request, response);
168         } else {
169             completeAuthenticationRequest(request, response);
170         }
171     }
172
173     /**
174      * Creates a {@link Saml2LoginContext} an sends the request off to the AuthenticationManager to begin the process of
175      * authenticating the user.
176      * 
177      * @param request current request
178      * @param response current response
179      * 
180      * @throws ProfileException thrown if there is a problem creating the login context and transferring control to the
181      *             authentication manager
182      */
183     protected void performAuthentication(ProfileRequest<ServletRequest> request,
184             ProfileResponse<ServletResponse> response) throws ProfileException {
185         HttpServletRequest httpRequest = (HttpServletRequest) request.getRawRequest();
186
187         AuthnRequest authnRequest = null;
188         try {
189             MessageDecoder<ServletRequest> decoder = decodeRequest(request);
190             SAMLSecurityPolicy securityPolicy = decoder.getSecurityPolicy();
191
192             String relyingParty = securityPolicy.getIssuer();
193             RelyingPartyConfiguration rpConfig = getRelyingPartyConfiguration(relyingParty);
194             if (rpConfig == null) {
195                 log.error("No relying party configuration for " + relyingParty);
196                 throw new ProfileException("No relying party configuration for " + relyingParty);
197             }
198
199             authnRequest = (AuthnRequest) decoder.getSAMLMessage();
200
201             Saml2LoginContext loginContext = new Saml2LoginContext(relyingParty, authnRequest);
202             loginContext.setAuthenticationEngineURL(authenticationManagerPath);
203             loginContext.setProfileHandlerURL(HttpHelper.getRequestUriWithoutContext(httpRequest));
204             if (loginContext.getRequestedAuthenticationMethods().size() == 0) {
205                 loginContext.getRequestedAuthenticationMethods().add(rpConfig.getDefaultAuthenticationMethod());
206             }
207
208             HttpSession httpSession = httpRequest.getSession();
209             httpSession.setAttribute(Saml2LoginContext.LOGIN_CONTEXT_KEY, loginContext);
210             RequestDispatcher dispatcher = httpRequest.getRequestDispatcher(authenticationManagerPath);
211             dispatcher.forward(httpRequest, response.getRawResponse());
212         } catch (MarshallingException e) {
213             log.error("Unable to marshall authentication request context");
214             throw new ProfileException("Unable to marshall authentication request context", e);
215         } catch (IOException ex) {
216             log.error("Error forwarding SAML 2 AuthnRequest " + authnRequest.getID() + " to AuthenticationManager", ex);
217             throw new ProfileException("Error forwarding SAML 2 AuthnRequest " + authnRequest.getID()
218                     + " to AuthenticationManager", ex);
219         } catch (ServletException ex) {
220             log.error("Error forwarding SAML 2 AuthnRequest " + authnRequest.getID() + " to AuthenticationManager", ex);
221             throw new ProfileException("Error forwarding SAML 2 AuthnRequest " + authnRequest.getID()
222                     + " to AuthenticationManager", ex);
223         }
224     }
225
226     /**
227      * Creates a response to the {@link AuthnRequest} and sends the user, with response in tow, back to the relying
228      * party after they've been authenticated.
229      * 
230      * @param request current request
231      * @param response current response
232      * 
233      * @throws ProfileException thrown if the response can not be created and sent back to the relying party
234      */
235     protected void completeAuthenticationRequest(ProfileRequest<ServletRequest> request,
236             ProfileResponse<ServletResponse> response) throws ProfileException {
237
238         HttpSession httpSession = ((HttpServletRequest) request.getRawRequest()).getSession(true);
239
240         Saml2LoginContext loginContext = (Saml2LoginContext) httpSession.getAttribute(LoginContext.LOGIN_CONTEXT_KEY);
241         httpSession.removeAttribute(LoginContext.LOGIN_CONTEXT_KEY);
242
243         SSORequestContext requestContext = buildRequestContext(loginContext, request, response);
244
245         checkSamlVersion(requestContext);
246
247         Response samlResponse;
248         try {
249             if (!loginContext.isPrincipalAuthenticated()) {
250                 requestContext
251                         .setFailureStatus(buildStatus(StatusCode.RESPONDER_URI, StatusCode.AUTHN_FAILED_URI, null));
252                 throw new ProfileException("User failed authentication");
253             }
254
255             ArrayList<Statement> statements = new ArrayList<Statement>();
256             statements.add(buildAuthnStatement(requestContext));
257             if (requestContext.getProfileConfiguration().includeAttributeStatement()) {
258                 statements.add(buildAttributeStatement(requestContext));
259             }
260
261             Subject assertionSubject = buildSubject(requestContext, "urn:oasis:names:tc:SAML:2.0:cm:bearer");
262
263             samlResponse = buildResponse(requestContext, assertionSubject, statements);
264         } catch (ProfileException e) {
265             samlResponse = buildErrorResponse(requestContext);
266         }
267
268         requestContext.setSamlResponse(samlResponse);
269         encodeResponse(requestContext);
270         writeAuditLogEntry(requestContext);
271     }
272
273     /**
274      * Creates an appropriate message decoder, populates it, and decodes the incoming request.
275      * 
276      * @param request current request
277      * 
278      * @return message decoder containing the decoded message and other stateful information
279      * 
280      * @throws ProfileException thrown if the incomming message failed decoding
281      */
282     protected MessageDecoder<ServletRequest> decodeRequest(ProfileRequest<ServletRequest> request)
283             throws ProfileException {
284         MessageDecoder<ServletRequest> decoder = getMessageDecoderFactory().getMessageDecoder(decodingBinding);
285         if (decoder == null) {
286             log.error("No request decoder was registered for binding type: " + decodingBinding);
287             throw new ProfileException("No request decoder was registered for binding type: " + decodingBinding);
288         }
289
290         populateMessageDecoder(decoder);
291         decoder.setRequest(request.getRawRequest());
292         try {
293             decoder.decode();
294             return decoder;
295         } catch (BindingException e) {
296             log.error("Error decoding authentication request message", e);
297             throw new ProfileException("Error decoding authentication request message", e);
298         } catch (SecurityPolicyException e) {
299             log.error("Message did not meet security policy requirements", e);
300             throw new ProfileException("Message did not meet security policy requirements", e);
301         }
302     }
303
304     /**
305      * Creates an authentication request context from the current environmental information.
306      * 
307      * @param loginContext current login context
308      * @param request current request
309      * @param response current response
310      * 
311      * @return created authentication request context
312      * 
313      * @throws ProfileException thrown if there is a problem creating the context
314      */
315     protected SSORequestContext buildRequestContext(Saml2LoginContext loginContext,
316             ProfileRequest<ServletRequest> request, ProfileResponse<ServletResponse> response) throws ProfileException {
317         SSORequestContext requestContext = new SSORequestContext(request, response);
318
319         try {
320             requestContext.setMessageDecoder(getMessageDecoderFactory().getMessageDecoder(decodingBinding));
321
322             requestContext.setLoginContext(loginContext);
323
324             String relyingPartyId = loginContext.getRelyingPartyId();
325             AuthnRequest authnRequest = loginContext.getAuthenticationRequest();
326
327             requestContext.setRelyingPartyId(relyingPartyId);
328
329             requestContext.setRelyingPartyMetadata(getMetadataProvider().getEntityDescriptor(relyingPartyId));
330
331             requestContext.setRelyingPartyRoleMetadata(requestContext.getRelyingPartyMetadata().getSPSSODescriptor(
332                     SAMLConstants.SAML20P_NS));
333
334             RelyingPartyConfiguration rpConfig = getRelyingPartyConfiguration(relyingPartyId);
335             requestContext.setRelyingPartyConfiguration(rpConfig);
336
337             requestContext.setAssertingPartyId(requestContext.getRelyingPartyConfiguration().getProviderId());
338
339             requestContext.setAssertingPartyMetadata(getMetadataProvider().getEntityDescriptor(
340                     requestContext.getAssertingPartyId()));
341
342             requestContext.setAssertingPartyRoleMetadata(requestContext.getRelyingPartyMetadata().getIDPSSODescriptor(
343                     SAMLConstants.SAML20P_NS));
344
345             requestContext.setPrincipalName(loginContext.getPrincipalName());
346
347             requestContext.setProfileConfiguration((SSOConfiguration) rpConfig
348                     .getProfileConfiguration(SSOConfiguration.PROFILE_ID));
349
350             requestContext.setSamlRequest(authnRequest);
351
352             return requestContext;
353         } catch (UnmarshallingException e) {
354             log.error("Unable to unmarshall authentication request context");
355             requestContext.setFailureStatus(buildStatus(StatusCode.RESPONDER_URI, null,
356                     "Error recovering request state"));
357             throw new ProfileException("Error recovering request state", e);
358         } catch (MetadataProviderException e) {
359             log.error("Unable to locate metadata for asserting or relying party");
360             requestContext
361                     .setFailureStatus(buildStatus(StatusCode.RESPONDER_URI, null, "Error locating party metadata"));
362             throw new ProfileException("Error locating party metadata");
363         }
364     }
365
366     /**
367      * Creates an authentication statement for the current request.
368      * 
369      * @param requestContext current request context
370      * 
371      * @return constructed authentication statement
372      */
373     protected AuthnStatement buildAuthnStatement(SSORequestContext requestContext) {
374         Saml2LoginContext loginContext = requestContext.getLoginContext();
375
376         AuthnContext authnContext = buildAuthnContext(requestContext);
377
378         AuthnStatement statement = getAuthnStatementBuilder().buildObject();
379         statement.setAuthnContext(authnContext);
380         statement.setAuthnInstant(loginContext.getAuthenticationInstant());
381
382         // TODO
383         statement.setSessionIndex(null);
384
385         if (loginContext.getAuthenticationDuration() > 0) {
386             statement.setSessionNotOnOrAfter(loginContext.getAuthenticationInstant().plus(
387                     loginContext.getAuthenticationDuration()));
388         }
389
390         statement.setSubjectLocality(buildSubjectLocality(requestContext));
391
392         return statement;
393     }
394
395     /**
396      * Creates an {@link AuthnContext} for a succesful authentication request.
397      * 
398      * @param requestContext current request
399      * 
400      * @return the built authn context
401      */
402     protected AuthnContext buildAuthnContext(SSORequestContext requestContext) {
403         AuthnContext authnContext = getAuthnContextBuilder().buildObject();
404
405         Saml2LoginContext loginContext = requestContext.getLoginContext();
406         AuthnRequest authnRequest = requestContext.getSamlRequest();
407         RequestedAuthnContext requestedAuthnContext = authnRequest.getRequestedAuthnContext();
408         if (requestedAuthnContext != null) {
409             if (requestedAuthnContext.getAuthnContextClassRefs() != null) {
410                 for (AuthnContextClassRef classRef : requestedAuthnContext.getAuthnContextClassRefs()) {
411                     if (classRef.getAuthnContextClassRef().equals(loginContext.getAuthenticationMethod())) {
412                         AuthnContextClassRef ref = getAuthnContextClassRefBuilder().buildObject();
413                         ref.setAuthnContextClassRef(loginContext.getAuthenticationMethod());
414                         authnContext.setAuthnContextClassRef(ref);
415                     }
416                 }
417             } else if (requestedAuthnContext.getAuthnContextDeclRefs() != null) {
418                 for (AuthnContextDeclRef declRef : requestedAuthnContext.getAuthnContextDeclRefs()) {
419                     if (declRef.getAuthnContextDeclRef().equals(loginContext.getAuthenticationMethod())) {
420                         AuthnContextDeclRef ref = getAuthnContextDeclRefBuilder().buildObject();
421                         ref.setAuthnContextDeclRef(loginContext.getAuthenticationMethod());
422                         authnContext.setAuthnContextDeclRef(ref);
423                     }
424                 }
425             }
426         } else {
427             AuthnContextDeclRef ref = getAuthnContextDeclRefBuilder().buildObject();
428             ref.setAuthnContextDeclRef(loginContext.getAuthenticationMethod());
429             authnContext.setAuthnContextDeclRef(ref);
430         }
431
432         return authnContext;
433     }
434
435     /**
436      * Constructs the subject locality for the authentication statement.
437      * 
438      * @param requestContext curent request context
439      * 
440      * @return subject locality for the authentication statement
441      */
442     protected SubjectLocality buildSubjectLocality(SSORequestContext requestContext) {
443         SubjectLocality subjectLocality = subjectLocalityBuilder.buildObject();
444
445         SSOConfiguration profileConfig = requestContext.getProfileConfiguration();
446         HttpServletRequest httpRequest = (HttpServletRequest) requestContext.getProfileRequest().getRawRequest();
447
448         if (profileConfig.getLocalityAddress() != null) {
449             subjectLocality.setAddress(profileConfig.getLocalityAddress());
450         } else {
451             subjectLocality.setAddress(httpRequest.getLocalAddr());
452         }
453
454         if (profileConfig.getLocalityDNSName() != null) {
455             subjectLocality.setDNSName(profileConfig.getLocalityDNSName());
456         } else {
457             subjectLocality.setDNSName(httpRequest.getLocalName());
458         }
459
460         return subjectLocality;
461     }
462
463     /**
464      * Encodes the request's SAML response and writes it to the servlet response.
465      * 
466      * @param requestContext current request context
467      * 
468      * @throws ProfileException thrown if no message encoder is registered for this profiles binding
469      */
470     protected void encodeResponse(SSORequestContext requestContext) throws ProfileException {
471         if (log.isDebugEnabled()) {
472             log.debug("Encoding response to SAML request " + requestContext.getSamlRequest().getID()
473                     + " from relying party " + requestContext.getRelyingPartyId());
474         }
475         AuthnResponseEndpointSelector endpointSelector = new AuthnResponseEndpointSelector();
476         endpointSelector.setEndpointType(AssertionConsumerService.DEFAULT_ELEMENT_NAME);
477         endpointSelector.setMetadataProvider(getMetadataProvider());
478         endpointSelector.setRelyingParty(requestContext.getRelyingPartyMetadata());
479         endpointSelector.setRelyingPartyRole(requestContext.getRelyingPartyRoleMetadata());
480         endpointSelector.setSamlRequest(requestContext.getSamlRequest());
481         endpointSelector.getSupportedIssuerBindings().addAll(getMessageEncoderFactory().getEncoderBuilders().keySet());
482         Endpoint relyingPartyEndpoint = endpointSelector.selectEndpoint();
483
484         MessageEncoder<ServletResponse> encoder = getMessageEncoderFactory().getMessageEncoder(
485                 relyingPartyEndpoint.getBinding());
486         if (encoder == null) {
487             log.error("No response encoder was registered for binding type: " + relyingPartyEndpoint.getBinding());
488             throw new ProfileException("No response encoder was registered for binding type: "
489                     + relyingPartyEndpoint.getBinding());
490         }
491
492         super.populateMessageEncoder(encoder);
493         encoder.setIssuer(requestContext.getAssertingPartyId());
494         encoder.setRelyingParty(requestContext.getRelyingPartyMetadata());
495         encoder.setRelyingPartyEndpoint(relyingPartyEndpoint);
496         encoder.setRelyingPartyRole(requestContext.getRelyingPartyRoleMetadata());
497         ProfileResponse<ServletResponse> profileResponse = requestContext.getProfileResponse();
498         encoder.setResponse(profileResponse.getRawResponse());
499         encoder.setSamlMessage(requestContext.getSamlResponse());
500         requestContext.setMessageEncoder(encoder);
501
502         try {
503             encoder.encode();
504         } catch (BindingException e) {
505             throw new ProfileException("Unable to encode response to relying party: "
506                     + requestContext.getRelyingPartyId(), e);
507         }
508     }
509
510     /** Represents the internal state of a SAML 2.0 SSO Request while it's being processed by the IdP. */
511     protected class SSORequestContext extends SAML2ProfileRequestContext<AuthnRequest, Response, SSOConfiguration> {
512
513         /** Current login context. */
514         private Saml2LoginContext loginContext;
515
516         /**
517          * Constructor.
518          * 
519          * @param request current profile request
520          * @param response current profile response
521          */
522         public SSORequestContext(ProfileRequest<ServletRequest> request, ProfileResponse<ServletResponse> response) {
523             super(request, response);
524         }
525
526         /**
527          * Gets the current login context.
528          * 
529          * @return current login context
530          */
531         public Saml2LoginContext getLoginContext() {
532             return loginContext;
533         }
534
535         /**
536          * Sets the current login context.
537          * 
538          * @param context current login context
539          */
540         public void setLoginContext(Saml2LoginContext context) {
541             loginContext = context;
542         }
543     }
544 }