More work converting to multi-federation HS.
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / hs / HandleServlet.java
1 /*
2  * The Shibboleth License, Version 1. Copyright (c) 2002 University Corporation
3  * for Advanced Internet Development, Inc. All rights reserved
4  * 
5  * 
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  * 
9  * Redistributions of source code must retain the above copyright notice, this
10  * list of conditions and the following disclaimer.
11  * 
12  * Redistributions in binary form must reproduce the above copyright notice,
13  * this list of conditions and the following disclaimer in the documentation
14  * and/or other materials provided with the distribution, if any, must include
15  * the following acknowledgment: "This product includes software developed by
16  * the University Corporation for Advanced Internet Development
17  * <http://www.ucaid.edu> Internet2 Project. Alternately, this acknowledegement
18  * may appear in the software itself, if and wherever such third-party
19  * acknowledgments normally appear.
20  * 
21  * Neither the name of Shibboleth nor the names of its contributors, nor
22  * Internet2, nor the University Corporation for Advanced Internet Development,
23  * Inc., nor UCAID may be used to endorse or promote products derived from this
24  * software without specific prior written permission. For written permission,
25  * please contact shibboleth@shibboleth.org
26  * 
27  * Products derived from this software may not be called Shibboleth, Internet2,
28  * UCAID, or the University Corporation for Advanced Internet Development, nor
29  * may Shibboleth appear in their name, without prior written permission of the
30  * University Corporation for Advanced Internet Development.
31  * 
32  * 
33  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
34  * AND WITH ALL FAULTS. ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
35  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
36  * PARTICULAR PURPOSE, AND NON-INFRINGEMENT ARE DISCLAIMED AND THE ENTIRE RISK
37  * OF SATISFACTORY QUALITY, PERFORMANCE, ACCURACY, AND EFFORT IS WITH LICENSEE.
38  * IN NO EVENT SHALL THE COPYRIGHT OWNER, CONTRIBUTORS OR THE UNIVERSITY
39  * CORPORATION FOR ADVANCED INTERNET DEVELOPMENT, INC. BE LIABLE FOR ANY
40  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
41  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
42  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
43  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
44  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
45  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
46  */
47
48 package edu.internet2.middleware.shibboleth.hs;
49
50 import java.io.IOException;
51 import java.io.InputStream;
52 import java.util.Collections;
53 import java.util.Date;
54 import java.util.StringTokenizer;
55
56 import javax.servlet.RequestDispatcher;
57 import javax.servlet.ServletException;
58 import javax.servlet.UnavailableException;
59 import javax.servlet.http.HttpServlet;
60 import javax.servlet.http.HttpServletRequest;
61 import javax.servlet.http.HttpServletResponse;
62
63 import org.apache.log4j.Logger;
64 import org.apache.log4j.MDC;
65 import org.apache.xerces.parsers.DOMParser;
66 import org.doomdark.uuid.UUIDGenerator;
67 import org.opensaml.QName;
68 import org.opensaml.SAMLAuthorityBinding;
69 import org.opensaml.SAMLBinding;
70 import org.opensaml.SAMLException;
71 import org.opensaml.SAMLNameIdentifier;
72 import org.opensaml.SAMLResponse;
73 import org.w3c.dom.Element;
74 import org.w3c.dom.NodeList;
75 import org.xml.sax.EntityResolver;
76 import org.xml.sax.ErrorHandler;
77 import org.xml.sax.InputSource;
78 import org.xml.sax.SAXException;
79 import org.xml.sax.SAXParseException;
80
81 import sun.misc.BASE64Decoder;
82 import edu.internet2.middleware.shibboleth.common.AuthNPrincipal;
83 import edu.internet2.middleware.shibboleth.common.Credentials;
84 import edu.internet2.middleware.shibboleth.common.IdentityProvider;
85 import edu.internet2.middleware.shibboleth.common.NameIdentifierMapping;
86 import edu.internet2.middleware.shibboleth.common.NameIdentifierMappingException;
87 import edu.internet2.middleware.shibboleth.common.RelyingParty;
88 import edu.internet2.middleware.shibboleth.common.ShibPOSTProfile;
89 import edu.internet2.middleware.shibboleth.common.ShibResource;
90 import edu.internet2.middleware.shibboleth.common.ShibbolethOriginConfig;
91
92 public class HandleServlet extends HttpServlet {
93
94         private static Logger log = Logger.getLogger(HandleServlet.class.getName());
95         private Semaphore throttle;
96         private ShibbolethOriginConfig configuration;
97         private Credentials credentials;
98         private HSNameMapper nameMapper;
99         private ShibPOSTProfile postProfile = new ShibPOSTProfile();
100         private ServiceProviderMapper targetMapper = new ServiceProviderMapper();
101
102         protected void loadConfiguration() throws HSConfigurationException {
103
104                 DOMParser parser = loadParser(true);
105
106                 String originConfigFile = getInitParameter("OriginConfigFile");
107                 if (originConfigFile == null) {
108                         originConfigFile = "/conf/origin.xml";
109                 }
110                 log.debug("Loading Configuration from (" + originConfigFile + ").");
111
112                 try {
113                         parser.parse(new InputSource(new ShibResource(originConfigFile, this.getClass()).getInputStream()));
114
115                 } catch (SAXException e) {
116                         log.error("Error while parsing origin configuration: " + e);
117                         throw new HSConfigurationException("Error while parsing origin configuration.");
118                 } catch (IOException e) {
119                         log.error("Could not load origin configuration: " + e);
120                         throw new HSConfigurationException("Could not load origin configuration.");
121                 }
122
123                 //Load global configuration properties
124                 configuration = new ShibbolethOriginConfig(parser.getDocument().getDocumentElement());
125
126                 //Load signing credentials
127                 NodeList itemElements =
128                         parser.getDocument().getDocumentElement().getElementsByTagNameNS(
129                                 Credentials.credentialsNamespace,
130                                 "Credentials");
131                 if (itemElements.getLength() < 1) {
132                         log.error("Credentials not specified.");
133                         throw new HSConfigurationException("The Handle Service requires that signing credentials be supplied in the <Credentials> configuration element.");
134                 }
135
136                 if (itemElements.getLength() > 1) {
137                         log.error("Multiple Credentials specifications found, using first.");
138                 }
139
140                 credentials = new Credentials((Element) itemElements.item(0));
141
142                 //Load name mappings
143                 itemElements =
144                         parser.getDocument().getDocumentElement().getElementsByTagNameNS(
145                                 NameIdentifierMapping.mappingNamespace,
146                                 "NameMapping");
147
148                 for (int i = 0; i < itemElements.getLength(); i++) {
149                         try {
150                                 nameMapper.addNameMapping((Element) itemElements.item(i));
151                         } catch (NameIdentifierMappingException e) {
152                                 log.error("Name Identifier mapping could not be loaded: " + e);
153                         }
154                 }
155         }
156
157         private DOMParser loadParser(boolean schemaChecking) throws HSConfigurationException {
158
159                 DOMParser parser = new DOMParser();
160
161                 if (!schemaChecking) {
162                         return parser;
163                 }
164
165                 try {
166                         parser.setFeature("http://xml.org/sax/features/validation", true);
167                         parser.setFeature("http://apache.org/xml/features/validation/schema", true);
168
169                         parser.setEntityResolver(new EntityResolver() {
170                                 public InputSource resolveEntity(String publicId, String systemId) throws SAXException {
171                                         log.debug("Resolving entity for System ID: " + systemId);
172                                         if (systemId != null) {
173                                                 StringTokenizer tokenString = new StringTokenizer(systemId, "/");
174                                                 String xsdFile = "";
175                                                 while (tokenString.hasMoreTokens()) {
176                                                         xsdFile = tokenString.nextToken();
177                                                 }
178                                                 if (xsdFile.endsWith(".xsd")) {
179                                                         InputStream stream;
180                                                         try {
181                                                                 stream = new ShibResource("/schemas/" + xsdFile, this.getClass()).getInputStream();
182                                                         } catch (IOException ioe) {
183                                                                 log.error("Error loading schema: " + xsdFile + ": " + ioe);
184                                                                 return null;
185                                                         }
186                                                         if (stream != null) {
187                                                                 return new InputSource(stream);
188                                                         }
189                                                 }
190                                         }
191                                         return null;
192                                 }
193                         });
194
195                         parser.setErrorHandler(new ErrorHandler() {
196                                 public void error(SAXParseException arg0) throws SAXException {
197                                         throw new SAXException("Error parsing xml file: " + arg0);
198                                 }
199                                 public void fatalError(SAXParseException arg0) throws SAXException {
200                                         throw new SAXException("Error parsing xml file: " + arg0);
201                                 }
202                                 public void warning(SAXParseException arg0) throws SAXException {
203                                         throw new SAXException("Error parsing xml file: " + arg0);
204                                 }
205                         });
206
207                 } catch (SAXException e) {
208                         log.error("Unable to setup a workable XML parser: " + e);
209                         throw new HSConfigurationException("Unable to setup a workable XML parser.");
210                 }
211                 return parser;
212         }
213
214         public void init() throws ServletException {
215                 super.init();
216                 MDC.put("serviceId", "[HS] Core");
217                 try {
218                         log.info("Initializing Handle Service.");
219
220                         nameMapper = new HSNameMapper();
221                         loadConfiguration();
222
223                         throttle =
224                                 new Semaphore(
225                                         Integer.parseInt(
226                                                 configuration.getConfigProperty(
227                                                         "edu.internet2.middleware.shibboleth.hs.HandleServlet.maxThreads")));
228
229                         log.info("Handle Service initialization complete.");
230
231                 } catch (HSConfigurationException ex) {
232                         log.fatal("Handle Service runtime configuration error.  Please fix and re-initialize. Cause: " + ex);
233                         throw new UnavailableException("Handle Service failed to initialize.");
234                 }
235         }
236
237         public void doGet(HttpServletRequest req, HttpServletResponse res) throws ServletException, IOException {
238
239                 MDC.put("serviceId", "[HS] " + UUIDGenerator.getInstance().generateRandomBasedUUID());
240                 MDC.put("remoteAddr", req.getRemoteAddr());
241                 log.info("Handling request.");
242
243                 try {
244                         throttle.enter();
245                         checkRequestParams(req);
246
247                         req.setAttribute("shire", req.getParameter("shire"));
248                         req.setAttribute("target", req.getParameter("target"));
249
250                         RelyingParty relyingParty = targetMapper.getRelyingParty(req.getParameter("providerId"));
251
252                         String header =
253                                 relyingParty.getConfigProperty("edu.internet2.middleware.shibboleth.hs.HandleServlet.username");
254                         String username = header.equalsIgnoreCase("REMOTE_USER") ? req.getRemoteUser() : req.getHeader(header);
255
256                         SAMLNameIdentifier nameId =
257                                 nameMapper.getNameIdentifierName(
258                                         relyingParty.getHSNameFormatId(),
259                                         new AuthNPrincipal(username),
260                                         relyingParty,
261                                         relyingParty.getIdentityProvider());
262
263                         //Print out something better here
264                         //log.info("Issued Handle (" + handle + ") to (" + username +
265                         // ")");
266
267                         //TODO decide what to do about authMethod
268                         byte[] buf =
269                                 generateAssertion(
270                                         relyingParty,
271                                         nameId,
272                                         req.getParameter("shire"),
273                                         req.getRemoteAddr(),
274                                         relyingParty.getConfigProperty("edu.internet2.middleware.shibboleth.hs.HandleServlet.authMethod"));
275
276                         createForm(req, res, buf);
277
278                 } catch (NameIdentifierMappingException ex) {
279                         log.error(ex);
280                         handleError(req, res, ex);
281                         return;
282                 } catch (InvalidClientDataException ex) {
283                         log.error(ex);
284                         handleError(req, res, ex);
285                         return;
286                 } catch (SAMLException ex) {
287                         log.error(ex);
288                         handleError(req, res, ex);
289                         return;
290                 } catch (InterruptedException ex) {
291                         log.error(ex);
292                         handleError(req, res, ex);
293                         return;
294                 } finally {
295                         throttle.exit();
296                 }
297         }
298
299         protected byte[] generateAssertion(
300                 RelyingParty relyingParty,
301                 SAMLNameIdentifier nameId,
302                 String shireURL,
303                 String clientAddress,
304                 String authType)
305                 throws SAMLException, IOException {
306
307                 SAMLAuthorityBinding binding =
308                         new SAMLAuthorityBinding(
309                                 SAMLBinding.SAML_SOAP_HTTPS,
310                                 relyingParty.getConfigProperty("edu.internet2.middleware.shibboleth.hs.HandleServlet.AAUrl"),
311                                 new QName(org.opensaml.XML.SAMLP_NS, "AttributeQuery"));
312
313                 //TODO Scott mentioned the clientAddress should be optional at some
314                 // point
315                 SAMLResponse r =
316                         postProfile.prepare(
317                                 shireURL,
318                                 relyingParty,
319                                 nameId,
320                                 clientAddress,
321                                 authType,
322                                 new Date(System.currentTimeMillis()),
323                                 Collections.singleton(binding));
324
325                 return r.toBase64();
326         }
327
328         protected void createForm(HttpServletRequest req, HttpServletResponse res, byte[] buf)
329                 throws IOException, ServletException {
330
331                 //Hardcoded to ASCII to ensure Base64 encoding compatibility
332                 req.setAttribute("assertion", new String(buf, "ASCII"));
333
334                 if (log.isDebugEnabled()) {
335                         try {
336                                 log.debug(
337                                         "Dumping generated SAML Response:"
338                                                 + System.getProperty("line.separator")
339                                                 + new String(new BASE64Decoder().decodeBuffer(new String(buf, "ASCII")), "UTF8"));
340                         } catch (IOException e) {
341                                 log.error("Encountered an error while decoding SAMLReponse for logging purposes.");
342                         }
343                 }
344
345                 RequestDispatcher rd = req.getRequestDispatcher("/hs.jsp");
346                 rd.forward(req, res);
347         }
348
349         protected void handleError(HttpServletRequest req, HttpServletResponse res, Exception e)
350                 throws ServletException, IOException {
351
352                 req.setAttribute("errorText", e.toString());
353                 req.setAttribute("requestURL", req.getRequestURI().toString());
354                 RequestDispatcher rd = req.getRequestDispatcher("/hserror.jsp");
355
356                 rd.forward(req, res);
357         }
358
359         protected void checkRequestParams(HttpServletRequest req) throws InvalidClientDataException {
360
361                 if (req.getParameter("target") == null || req.getParameter("target").equals("")) {
362                         throw new InvalidClientDataException("Invalid data from SHIRE: no target URL received.");
363                 }
364                 if ((req.getParameter("shire") == null) || (req.getParameter("shire").equals(""))) {
365                         throw new InvalidClientDataException("Invalid data from SHIRE: No acceptance URL received.");
366                 }
367                 if ((req.getRemoteUser() == null) || (req.getRemoteUser().equals(""))) {
368                         throw new InvalidClientDataException("Unable to authenticate remote user");
369                 }
370                 if ((req.getRemoteAddr() == null) || (req.getRemoteAddr().equals(""))) {
371                         throw new InvalidClientDataException("Unable to obtain client address.");
372                 }
373         }
374
375         class InvalidClientDataException extends Exception {
376                 public InvalidClientDataException(String message) {
377                         super(message);
378                 }
379         }
380
381         private class Semaphore {
382                 private int value;
383
384                 public Semaphore(int value) {
385                         this.value = value;
386                 }
387
388                 public synchronized void enter() throws InterruptedException {
389                         --value;
390                         if (value < 0) {
391                                 wait();
392                         }
393                 }
394
395                 public synchronized void exit() {
396                         ++value;
397                         notify();
398                 }
399         }
400         //TODO This is just a stub... and should be moved out when meat is added
401         class ServiceProviderMapper {
402
403                 /**
404                  * @param providerIdFromTarget
405                  * @return
406                  */
407                 public RelyingParty getRelyingParty(String providerIdFromTarget) {
408
409                         return new RelyingParty(null, configuration, credentials);
410                 }
411         }
412 }