Fixed a regex typo in the conversion filter from the last patch.
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / utils / SAML1_0to1_1ConversionFilter.java
1 /*
2  * The Shibboleth License, Version 1. Copyright (c) 2002 University Corporation for Advanced Internet Development, Inc.
3  * All rights reserved Redistribution and use in source and binary forms, with or without modification, are permitted
4  * provided that the following conditions are met: Redistributions of source code must retain the above copyright
5  * notice, this list of conditions and the following disclaimer. Redistributions in binary form must reproduce the above
6  * copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials
7  * provided with the distribution, if any, must include the following acknowledgment: "This product includes software
8  * developed by the University Corporation for Advanced Internet Development <http://www.ucaid.edu>Internet2 Project.
9  * Alternately, this acknowledegement may appear in the software itself, if and wherever such third-party
10  * acknowledgments normally appear. Neither the name of Shibboleth nor the names of its contributors, nor Internet2, nor
11  * the University Corporation for Advanced Internet Development, Inc., nor UCAID may be used to endorse or promote
12  * products derived from this software without specific prior written permission. For written permission, please contact
13  * shibboleth@shibboleth.org Products derived from this software may not be called Shibboleth, Internet2, UCAID, or the
14  * University Corporation for Advanced Internet Development, nor may Shibboleth appear in their name, without prior
15  * written permission of the University Corporation for Advanced Internet Development. THIS SOFTWARE IS PROVIDED BY THE
16  * COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND WITH ALL FAULTS. ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT ARE
18  * DISCLAIMED AND THE ENTIRE RISK OF SATISFACTORY QUALITY, PERFORMANCE, ACCURACY, AND EFFORT IS WITH LICENSEE. IN NO
19  * EVENT SHALL THE COPYRIGHT OWNER, CONTRIBUTORS OR THE UNIVERSITY CORPORATION FOR ADVANCED INTERNET DEVELOPMENT, INC.
20  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
21  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
23  * OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 package edu.internet2.middleware.shibboleth.utils;
27
28 import java.io.BufferedReader;
29 import java.io.ByteArrayInputStream;
30 import java.io.ByteArrayOutputStream;
31 import java.io.DataOutputStream;
32 import java.io.IOException;
33 import java.io.InputStreamReader;
34 import java.io.OutputStream;
35 import java.io.PrintWriter;
36 import java.util.regex.Matcher;
37 import java.util.regex.Pattern;
38
39 import javax.servlet.Filter;
40 import javax.servlet.FilterChain;
41 import javax.servlet.FilterConfig;
42 import javax.servlet.ServletException;
43 import javax.servlet.ServletInputStream;
44 import javax.servlet.ServletOutputStream;
45 import javax.servlet.ServletRequest;
46 import javax.servlet.ServletResponse;
47 import javax.servlet.http.HttpServletRequest;
48 import javax.servlet.http.HttpServletRequestWrapper;
49 import javax.servlet.http.HttpServletResponse;
50 import javax.servlet.http.HttpServletResponseWrapper;
51
52 import org.apache.log4j.Logger;
53 import org.apache.log4j.MDC;
54 import org.opensaml.SAMLConfig;
55 import org.opensaml.SAMLException;
56 import org.opensaml.SAMLIdentifier;
57
58 /**
59  * Servlet filter that intercepts incoming SAML 1.0 requests, converts them to SAML 1.1, and then reverses the
60  * conversion for the subsequent response.
61  * 
62  * @author Walter Hoehn
63  */
64 public class SAML1_0to1_1ConversionFilter implements Filter {
65
66         private static Logger log = Logger.getLogger(SAML1_0to1_1ConversionFilter.class.getName());
67     private SAMLIdentifier idgen = SAMLConfig.instance().getDefaultIDProvider();
68
69         /*
70          * @see javax.servlet.Filter#init(javax.servlet.FilterConfig)
71          */
72         public void init(FilterConfig config) throws ServletException {
73         }
74
75         /*
76          * @see javax.servlet.Filter#doFilter(javax.servlet.ServletRequest, javax.servlet.ServletResponse,
77          *      javax.servlet.FilterChain)
78          */
79         public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException,
80                         ServletException {
81
82                 MDC.put("serviceId", "[SAML Conversion Filter]");
83                 if (!(request instanceof HttpServletRequest) || !(response instanceof HttpServletResponse)) {
84                         log.error("Only HTTP(s) requests are supported by the ClientCertTrustFilter.");
85                         return;
86                 }
87                 HttpServletRequest httpRequest = (HttpServletRequest) request;
88                 HttpServletResponse httpResponse = (HttpServletResponse) response;
89
90                 if (!httpRequest.getMethod().equals("POST")) {
91                         log.debug("Skipping SAML conversion because request method is not (POST).");
92                         chain.doFilter(httpRequest, httpResponse);
93                 }
94
95                 log.debug("Added SAML conversion wrapper to request.");
96
97                 StringBuffer stringBuffer = new StringBuffer();
98                 BufferedReader reader = request.getReader();
99                 for (String line = reader.readLine(); line != null; line = reader.readLine()) {
100                         stringBuffer.append(line);
101                 }
102                 reader.reset();
103
104                 String input = stringBuffer.toString();
105
106                 if (!isSAML1_0(input)) {
107                         log.debug("Skipping SAML conversion because the input does not contain a SAML 1.0 request.");
108                         chain.doFilter(new NoConversionRequestWrapper(httpRequest, input), httpResponse);
109                         return;
110                 }
111
112                 ConversionRequestWrapper requestWrapper = new ConversionRequestWrapper(httpRequest, input);
113                 ConversionResponseWrapper responseWrapper = new ConversionResponseWrapper(httpResponse, requestWrapper
114                                 .getOriginalRequestId());
115                 chain.doFilter(requestWrapper, responseWrapper);
116
117                 responseWrapper.localFlush();
118         }
119
120         /**
121          * @param input
122          */
123         private boolean isSAML1_0(String input) {
124
125                 Pattern majorRegex = Pattern.compile("<(.+:)?Request[^>]+(MajorVersion=['\"]1['\"])");
126                 Pattern minorRegex = Pattern.compile("<(.+:)?Request[^>]+(MinorVersion=['\"]0['\"])");
127                 Matcher majorMatcher = majorRegex.matcher(input);
128                 Matcher minorMatcher = minorRegex.matcher(input);
129
130                 if (!minorMatcher.find() || !majorMatcher.find()) { return false; }
131                 return true;
132         }
133
134         /*
135          * @see javax.servlet.Filter#destroy()
136          */
137         public void destroy() {
138
139         }
140
141         private class ConversionResponseWrapper extends HttpServletResponseWrapper {
142
143                 private ByteArrayOutputStream output = new ByteArrayOutputStream();
144                 private boolean localFlush = false;
145                 private String originalRequestId;
146
147                 private ConversionResponseWrapper(HttpServletResponse response, String originalRequestId) {
148
149                         super(response);
150                         this.originalRequestId = originalRequestId;
151                 }
152
153                 private void localFlush() throws IOException {
154
155                         String result = output.toString();
156
157                         //Fail if we encounter XML Dsig, since the conversion would break it anyway
158                         Pattern regex = Pattern.compile("<(.+:)?Signature");
159                         Matcher matcher = regex.matcher(result);
160                         if (matcher.find()) {
161                                 log.error("Unable to convert SAML request from 1.0 to 1.1.");
162                                 throw new IOException("Unable to auto-convert SAML messages containing digital signatures.");
163                         }
164
165                         //Update SAML minor verion on Response and assertions
166                         regex = Pattern.compile("<(.+:)?Response[^>]+(MinorVersion=['\"]1['\"])");
167                         matcher = regex.matcher(result);
168                         if (matcher.find()) {
169                                 StringBuffer buff = new StringBuffer();
170                                 int start = matcher.start(2);
171                                 int end = matcher.end(2);
172                                 buff.append(result.subSequence(0, start));
173                                 buff.append("MinorVersion=\"0\"");
174                                 buff.append(result.substring(end));
175                                 result = buff.toString();
176                         }
177
178                         regex = Pattern.compile("<(.+:)?Assertion[^>]+(MinorVersion=['\"]1['\"])");
179                         matcher = regex.matcher(result);
180                         StringBuffer buff = new StringBuffer();
181                         int end = 0;
182                         while (matcher.find()) {
183                                 int start = matcher.start(2);
184                                 buff.append(result.subSequence(end, start));
185                                 end = matcher.end(2);
186                                 buff.append("MinorVersion=\"0\"");
187                         }
188                         if (buff.length() > 0) {
189                                 buff.append(result.substring(end));
190                                 result = buff.toString();
191                         }
192
193                         //Substitue in the real identifier from the original request
194                         regex = Pattern.compile("<(.+:)?Response[^>]+InResponseTo=['\"]([^\"]+)['\"]");
195                         matcher = regex.matcher(result);
196                         if (matcher.find()) {
197                                 buff = new StringBuffer();
198                                 int start = matcher.start(2);
199                                 end = matcher.end(2);
200                                 buff.append(result.subSequence(0, start));
201                                 buff.append(originalRequestId);
202                                 buff.append(result.substring(end));
203                                 result = buff.toString();
204                         }
205
206                         //Replace deprecated artifact confirmation method
207                         regex = Pattern
208                                         .compile("<(.+:)?ConfirmationMethod>(urn:oasis:names:tc:SAML:1.0:cm:artifact)</(.+:)?ConfirmationMethod>");
209                         matcher = regex.matcher(result);
210                         buff = new StringBuffer();
211                         end = 0;
212                         while (matcher.find()) {
213                                 int start = matcher.start(2);
214                                 buff.append(result.subSequence(end, start));
215                                 end = matcher.end(2);
216                                 buff.append("urn:oasis:names:tc:SAML:1.0:cm:artifact-01");
217                         }
218                         if (buff.length() > 0) {
219                                 buff.append(result.substring(end));
220                                 result = buff.toString();
221                         }
222
223                         super.getOutputStream().write(result.getBytes());
224                         output.reset();
225                 }
226
227                 public ServletOutputStream getOutputStream() {
228
229                         return new ModifiableOutputStream(output);
230                 }
231
232                 public PrintWriter getWriter() {
233
234                         return new PrintWriter(getOutputStream(), true);
235                 }
236
237                 public void reset() {
238
239                         super.reset();
240                         output.reset();
241                 }
242
243                 public void resetBuffer() {
244
245                         output.reset();
246                 }
247
248                 public void flushBuffer() throws IOException {
249
250                         localFlush();
251                         super.flushBuffer();
252                 }
253
254                 private class ModifiableOutputStream extends ServletOutputStream {
255
256                         private DataOutputStream stream;
257
258                         public ModifiableOutputStream(OutputStream output) {
259
260                                 stream = new DataOutputStream(output);
261                         }
262
263                         public void write(int b) throws IOException {
264
265                                 stream.write(b);
266                         }
267
268                         public void write(byte[] b) throws IOException {
269
270                                 stream.write(b);
271                         }
272
273                         public void write(byte[] b, int off, int len) throws IOException {
274
275                                 stream.write(b, off, len);
276                         }
277
278                 }
279         }
280
281         private class ConversionRequestWrapper extends HttpServletRequestWrapper {
282
283                 private ServletInputStream stream;
284                 private boolean accessed = false;
285                 private String method;
286                 private String originalRequestId;
287                 private int newLength;
288
289                 private ConversionRequestWrapper(HttpServletRequest request, String input) throws IOException {
290
291                         super(request);
292
293                         //Fail if we encounter XML Dsig, since the conversion would break it anyway
294                         Pattern regex = Pattern.compile("<(.+:)?Signature");
295                         Matcher matcher = regex.matcher(input);
296                         if (matcher.find()) {
297                                 log.error("Unable to convert SAML request from 1.0 to 1.1.");
298                                 throw new IOException("Unable to auto-convert SAML messages containing digital signatures.");
299                         }
300
301                         //Update SAML minor verion on Request
302                         regex = Pattern.compile("<(.+:)?Request[^>]+(MinorVersion=['\"]0['\"])");
303                         matcher = regex.matcher(input);
304                         if (matcher.find()) {
305                                 StringBuffer buff = new StringBuffer();
306                                 int start = matcher.start(2);
307                                 int end = matcher.end(2);
308                                 buff.append(input.subSequence(0, start));
309                                 buff.append("MinorVersion=\"1\"");
310                                 buff.append(input.substring(end));
311                                 input = buff.toString();
312                         }
313
314                         //Substitute in a fake request id that is valid in SAML 1.1, but save the original so that we can put it
315                         // back later
316                         regex = Pattern.compile("<(.+:)?Request[^>]+RequestID=['\"]([^\"]+)['\"]");
317                         matcher = regex.matcher(input);
318                         if (matcher.find()) {
319                                 StringBuffer buff = new StringBuffer();
320                                 originalRequestId = matcher.group(2);
321                                 int start = matcher.start(2);
322                                 int end = matcher.end(2);
323                                 buff.append(input.subSequence(0, start));
324                                 try {
325                     buff.append(idgen.getIdentifier());
326                 }
327                 catch (SAMLException e) {
328                     throw new IOException("Unable to obtain a new SAML message ID from provider");
329                 }
330                                 buff.append(input.substring(end));
331                                 input = buff.toString();
332                         }
333
334                         newLength = input.length();
335
336                         stream = new ModifiedInputStream(new ByteArrayInputStream(input.getBytes()));
337                 }
338
339                 /*
340                  * (non-Javadoc)
341                  * 
342                  * @see javax.servlet.ServletRequest#getInputStream()
343                  */
344                 public ServletInputStream getInputStream() throws IOException {
345
346                         if (accessed) { throw new IllegalStateException(method + " has already been called for this request"); }
347                         accessed = true;
348                         method = "getInputStream()";
349                         return stream;
350                 }
351
352                 /*
353                  * (non-Javadoc)
354                  * 
355                  * @see javax.servlet.ServletRequest#getReader()
356                  */
357                 public BufferedReader getReader() throws IOException {
358
359                         if (accessed) { throw new IllegalStateException(method + " has already been called for this request"); }
360                         accessed = true;
361                         method = "getReader()";
362                         return new BufferedReader(new InputStreamReader(stream));
363                 }
364
365                 private String getOriginalRequestId() {
366
367                         return originalRequestId;
368
369                 }
370
371                 /*
372                  * (non-Javadoc)
373                  * 
374                  * @see javax.servlet.ServletRequest#getContentLength()
375                  */
376                 public int getContentLength() {
377
378                         return newLength;
379                 }
380
381         }
382
383         private class NoConversionRequestWrapper extends HttpServletRequestWrapper {
384
385                 private ServletInputStream stream;
386                 private boolean accessed = false;
387                 private String method;
388
389                 private NoConversionRequestWrapper(HttpServletRequest request, String input) {
390
391                         super(request);
392                         stream = new ModifiedInputStream(new ByteArrayInputStream(input.getBytes()));
393                 }
394
395                 /*
396                  * (non-Javadoc)
397                  * 
398                  * @see javax.servlet.ServletRequest#getInputStream()
399                  */
400                 public ServletInputStream getInputStream() throws IOException {
401
402                         if (accessed) { throw new IllegalStateException(method + " has already been called for this request"); }
403                         accessed = true;
404                         method = "getInputStream()";
405                         return stream;
406                 }
407
408                 /*
409                  * (non-Javadoc)
410                  * 
411                  * @see javax.servlet.ServletRequest#getReader()
412                  */
413                 public BufferedReader getReader() throws IOException {
414
415                         if (accessed) { throw new IllegalStateException(method + " has already been called for this request"); }
416                         accessed = true;
417                         method = "getReader()";
418                         return new BufferedReader(new InputStreamReader(stream));
419                 }
420
421         }
422
423         private class ModifiedInputStream extends ServletInputStream {
424
425                 private ByteArrayInputStream stream;
426
427                 private ModifiedInputStream(ByteArrayInputStream stream) {
428
429                         this.stream = stream;
430                 }
431
432                 /*
433                  * (non-Javadoc)
434                  * 
435                  * @see javax.servlet.ServletInputStream#readLine(byte[], int, int)
436                  */
437                 public int readLine(byte[] b, int off, int len) throws IOException {
438
439                         if (len <= 0) { return 0; }
440                         int count = 0, c;
441
442                         while ((c = stream.read()) != -1) {
443                                 b[off++] = (byte) c;
444                                 count++;
445                                 if (c == '\n' || count == len) {
446                                         break;
447                                 }
448                         }
449                         return count > 0 ? count : -1;
450                 }
451
452                 /*
453                  * (non-Javadoc)
454                  * 
455                  * @see java.io.InputStream#available()
456                  */
457                 public int available() throws IOException {
458
459                         return stream.available();
460                 }
461
462                 /*
463                  * (non-Javadoc)
464                  * 
465                  * @see java.io.InputStream#close()
466                  */
467                 public void close() throws IOException {
468
469                         stream.close();
470                 }
471
472                 /*
473                  * (non-Javadoc)
474                  * 
475                  * @see java.io.InputStream#mark(int)
476                  */
477                 public synchronized void mark(int readlimit) {
478
479                         stream.mark(readlimit);
480                 }
481
482                 /*
483                  * (non-Javadoc)
484                  * 
485                  * @see java.io.InputStream#markSupported()
486                  */
487                 public boolean markSupported() {
488
489                         return stream.markSupported();
490                 }
491
492                 /*
493                  * (non-Javadoc)
494                  * 
495                  * @see java.io.InputStream#read(byte[], int, int)
496                  */
497                 public int read(byte[] b, int off, int len) throws IOException {
498
499                         return stream.read(b, off, len);
500                 }
501
502                 /*
503                  * (non-Javadoc)
504                  * 
505                  * @see java.io.InputStream#read(byte[])
506                  */
507                 public int read(byte[] b) throws IOException {
508
509                         return stream.read(b);
510                 }
511
512                 /*
513                  * (non-Javadoc)
514                  * 
515                  * @see java.io.InputStream#reset()
516                  */
517                 public synchronized void reset() throws IOException {
518
519                         stream.reset();
520                 }
521
522                 /*
523                  * (non-Javadoc)
524                  * 
525                  * @see java.io.InputStream#skip(long)
526                  */
527                 public long skip(long n) throws IOException {
528
529                         return stream.skip(n);
530                 }
531
532                 /*
533                  * (non-Javadoc)
534                  * 
535                  * @see java.io.InputStream#read()
536                  */
537                 public int read() throws IOException {
538
539                         return stream.read();
540                 }
541
542         }
543
544 }