b05584b90cbcfba0323a1a5dc6e2b850f1247c77
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / utils / SAML1_0to1_1ConversionFilter.java
1 /*
2  * Copyright [2005] [University Corporation for Advanced Internet Development, Inc.]
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 package edu.internet2.middleware.shibboleth.utils;
18
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletOutputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;

import org.apache.log4j.Logger;
import org.apache.log4j.MDC;
import org.opensaml.SAMLConfig;
import org.opensaml.SAMLException;
import org.opensaml.SAMLIdentifier;
48
49 /**
50  * Servlet filter that intercepts incoming SAML 1.0 requests, converts them to SAML 1.1, and then reverses the
51  * conversion for the subsequent response.
52  * 
53  * @author Walter Hoehn
54  */
55 public class SAML1_0to1_1ConversionFilter implements Filter {
56
57         private static Logger log = Logger.getLogger(SAML1_0to1_1ConversionFilter.class.getName());
58         private SAMLIdentifier idgen = SAMLConfig.instance().getDefaultIDProvider();
59
60         /*
61          * @see javax.servlet.Filter#init(javax.servlet.FilterConfig)
62          */
63         public void init(FilterConfig config) throws ServletException {
64
65         }
66
67         /*
68          * @see javax.servlet.Filter#doFilter(javax.servlet.ServletRequest, javax.servlet.ServletResponse,
69          *      javax.servlet.FilterChain)
70          */
71         public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException,
72                         ServletException {
73
74                 MDC.put("serviceId", "[SAML Conversion Filter]");
75                 if (!(request instanceof HttpServletRequest) || !(response instanceof HttpServletResponse)) {
76                         log.error("Only HTTP(s) requests are supported by the ClientCertTrustFilter.");
77                         return;
78                 }
79                 HttpServletRequest httpRequest = (HttpServletRequest) request;
80                 HttpServletResponse httpResponse = (HttpServletResponse) response;
81
82                 if (!httpRequest.getMethod().equals("POST")) {
83                         log.debug("Skipping SAML conversion because request method is not (POST).");
84                         chain.doFilter(httpRequest, httpResponse);
85                 }
86
87                 log.debug("Added SAML conversion wrapper to request.");
88
89                 StringBuffer stringBuffer = new StringBuffer();
90                 BufferedReader reader = request.getReader();
91                 for (String line = reader.readLine(); line != null; line = reader.readLine()) {
92                         stringBuffer.append(line);
93                 }
94                 reader.reset();
95
96                 String input = stringBuffer.toString();
97
98                 if (!isSAML1_0(input)) {
99                         log.debug("Skipping SAML conversion because the input does not contain a SAML 1.0 request.");
100                         chain.doFilter(new NoConversionRequestWrapper(httpRequest, input), httpResponse);
101                         return;
102                 }
103
104                 ConversionRequestWrapper requestWrapper = new ConversionRequestWrapper(httpRequest, input);
105                 ConversionResponseWrapper responseWrapper = new ConversionResponseWrapper(httpResponse, requestWrapper
106                                 .getOriginalRequestId());
107                 chain.doFilter(requestWrapper, responseWrapper);
108
109                 responseWrapper.localFlush();
110         }
111
112         /**
113          * @param input
114          */
115         private boolean isSAML1_0(String input) {
116
117                 Pattern majorRegex = Pattern.compile("<(.+:)?Request[^>]+(MajorVersion=['\"]1['\"])");
118                 Pattern minorRegex = Pattern.compile("<(.+:)?Request[^>]+(MinorVersion=['\"]0['\"])");
119                 Matcher majorMatcher = majorRegex.matcher(input);
120                 Matcher minorMatcher = minorRegex.matcher(input);
121
122                 if (!minorMatcher.find() || !majorMatcher.find()) { return false; }
123                 return true;
124         }
125
126         /*
127          * @see javax.servlet.Filter#destroy()
128          */
129         public void destroy() {
130
131         }
132
133         private class ConversionResponseWrapper extends HttpServletResponseWrapper {
134
135                 private ByteArrayOutputStream output = new ByteArrayOutputStream();
136                 private String originalRequestId;
137
138                 private ConversionResponseWrapper(HttpServletResponse response, String originalRequestId) {
139
140                         super(response);
141                         this.originalRequestId = originalRequestId;
142                 }
143
144                 private void localFlush() throws IOException {
145
146                         String result = output.toString();
147
148                         // Fail if we encounter XML Dsig, since the conversion would break it anyway
149                         Pattern regex = Pattern.compile("<(.+:)?Signature");
150                         Matcher matcher = regex.matcher(result);
151                         if (matcher.find()) {
152                                 log.error("Unable to convert SAML request from 1.0 to 1.1.");
153                                 throw new IOException("Unable to auto-convert SAML messages containing digital signatures.");
154                         }
155
156                         // Update SAML minor verion on Response and assertions
157                         regex = Pattern.compile("<(.+:)?Response[^>]+(MinorVersion=['\"]1['\"])");
158                         matcher = regex.matcher(result);
159                         if (matcher.find()) {
160                                 StringBuffer buff = new StringBuffer();
161                                 int start = matcher.start(2);
162                                 int end = matcher.end(2);
163                                 buff.append(result.subSequence(0, start));
164                                 buff.append("MinorVersion=\"0\"");
165                                 buff.append(result.substring(end));
166                                 result = buff.toString();
167                         }
168
169                         regex = Pattern.compile("<(.+:)?Assertion[^>]+(MinorVersion=['\"]1['\"])");
170                         matcher = regex.matcher(result);
171                         StringBuffer buff = new StringBuffer();
172                         int end = 0;
173                         while (matcher.find()) {
174                                 int start = matcher.start(2);
175                                 buff.append(result.subSequence(end, start));
176                                 end = matcher.end(2);
177                                 buff.append("MinorVersion=\"0\"");
178                         }
179                         if (buff.length() > 0) {
180                                 buff.append(result.substring(end));
181                                 result = buff.toString();
182                         }
183
184                         // Substitue in the real identifier from the original request
185                         regex = Pattern.compile("<(.+:)?Response[^>]+InResponseTo=['\"]([^\"]+)['\"]");
186                         matcher = regex.matcher(result);
187                         if (matcher.find()) {
188                                 buff = new StringBuffer();
189                                 int start = matcher.start(2);
190                                 end = matcher.end(2);
191                                 buff.append(result.subSequence(0, start));
192                                 buff.append(originalRequestId);
193                                 buff.append(result.substring(end));
194                                 result = buff.toString();
195                         }
196
197                         // Replace deprecated artifact confirmation method
198                         regex = Pattern
199                                         .compile("<(.+:)?ConfirmationMethod>(urn:oasis:names:tc:SAML:1.0:cm:artifact)</(.+:)?ConfirmationMethod>");
200                         matcher = regex.matcher(result);
201                         buff = new StringBuffer();
202                         end = 0;
203                         while (matcher.find()) {
204                                 int start = matcher.start(2);
205                                 buff.append(result.subSequence(end, start));
206                                 end = matcher.end(2);
207                                 buff.append("urn:oasis:names:tc:SAML:1.0:cm:artifact-01");
208                         }
209                         if (buff.length() > 0) {
210                                 buff.append(result.substring(end));
211                                 result = buff.toString();
212                         }
213
214                         super.getOutputStream().write(result.getBytes());
215                         output.reset();
216                 }
217
218                 public ServletOutputStream getOutputStream() {
219
220                         return new ModifiableOutputStream(output);
221                 }
222
223                 public PrintWriter getWriter() {
224
225                         return new PrintWriter(getOutputStream(), true);
226                 }
227
228                 public void reset() {
229
230                         super.reset();
231                         output.reset();
232                 }
233
234                 public void resetBuffer() {
235
236                         output.reset();
237                 }
238
239                 public void flushBuffer() throws IOException {
240
241                         localFlush();
242                         super.flushBuffer();
243                 }
244
245                 private class ModifiableOutputStream extends ServletOutputStream {
246
247                         private DataOutputStream stream;
248
249                         public ModifiableOutputStream(OutputStream output) {
250
251                                 stream = new DataOutputStream(output);
252                         }
253
254                         public void write(int b) throws IOException {
255
256                                 stream.write(b);
257                         }
258
259                         public void write(byte[] b) throws IOException {
260
261                                 stream.write(b);
262                         }
263
264                         public void write(byte[] b, int off, int len) throws IOException {
265
266                                 stream.write(b, off, len);
267                         }
268
269                 }
270         }
271
272         private class ConversionRequestWrapper extends HttpServletRequestWrapper {
273
274                 private ServletInputStream stream;
275                 private boolean accessed = false;
276                 private String method;
277                 private String originalRequestId;
278                 private int newLength;
279
280                 private ConversionRequestWrapper(HttpServletRequest request, String input) throws IOException {
281
282                         super(request);
283
284                         // Fail if we encounter XML Dsig, since the conversion would break it anyway
285                         Pattern regex = Pattern.compile("<(.+:)?Signature");
286                         Matcher matcher = regex.matcher(input);
287                         if (matcher.find()) {
288                                 log.error("Unable to convert SAML request from 1.0 to 1.1.");
289                                 throw new IOException("Unable to auto-convert SAML messages containing digital signatures.");
290                         }
291
292                         // Update SAML minor verion on Request
293                         regex = Pattern.compile("<(.+:)?Request[^>]+(MinorVersion=['\"]0['\"])");
294                         matcher = regex.matcher(input);
295                         if (matcher.find()) {
296                                 StringBuffer buff = new StringBuffer();
297                                 int start = matcher.start(2);
298                                 int end = matcher.end(2);
299                                 buff.append(input.subSequence(0, start));
300                                 buff.append("MinorVersion=\"1\"");
301                                 buff.append(input.substring(end));
302                                 input = buff.toString();
303                         }
304
305                         // Substitute in a fake request id that is valid in SAML 1.1, but save the original so that we can put it
306                         // back later
307                         regex = Pattern.compile("<(.+:)?Request[^>]+RequestID=['\"]([^'\"]+)['\"]");
308                         matcher = regex.matcher(input);
309                         if (matcher.find()) {
310                                 StringBuffer buff = new StringBuffer();
311                                 originalRequestId = matcher.group(2);
312                                 int start = matcher.start(2);
313                                 int end = matcher.end(2);
314                                 buff.append(input.subSequence(0, start));
315                                 try {
316                                         buff.append(idgen.getIdentifier());
317                                 } catch (SAMLException e) {
318                                         throw new IOException("Unable to obtain a new SAML message ID from provider");
319                                 }
320                                 buff.append(input.substring(end));
321                                 input = buff.toString();
322                         }
323
324                         newLength = input.length();
325                         stream = new ModifiedInputStream(new ByteArrayInputStream(input.getBytes()));
326                 }
327
328                 /*
329                  * (non-Javadoc)
330                  * 
331                  * @see javax.servlet.ServletRequest#getInputStream()
332                  */
333                 public ServletInputStream getInputStream() throws IOException {
334
335                         if (accessed) { throw new IllegalStateException(method + " has already been called for this request"); }
336                         accessed = true;
337                         method = "getInputStream()";
338                         return stream;
339                 }
340
341                 /*
342                  * (non-Javadoc)
343                  * 
344                  * @see javax.servlet.ServletRequest#getReader()
345                  */
346                 public BufferedReader getReader() throws IOException {
347
348                         if (accessed) { throw new IllegalStateException(method + " has already been called for this request"); }
349                         accessed = true;
350                         method = "getReader()";
351                         return new BufferedReader(new InputStreamReader(stream));
352                 }
353
354                 private String getOriginalRequestId() {
355
356                         return originalRequestId;
357
358                 }
359
360                 /*
361                  * (non-Javadoc)
362                  * 
363                  * @see javax.servlet.ServletRequest#getContentLength()
364                  */
365                 public int getContentLength() {
366
367                         return newLength;
368                 }
369
370         }
371
372         private class NoConversionRequestWrapper extends HttpServletRequestWrapper {
373
374                 private ServletInputStream stream;
375                 private boolean accessed = false;
376                 private String method;
377
378                 private NoConversionRequestWrapper(HttpServletRequest request, String input) {
379
380                         super(request);
381                         stream = new ModifiedInputStream(new ByteArrayInputStream(input.getBytes()));
382                 }
383
384                 /*
385                  * (non-Javadoc)
386                  * 
387                  * @see javax.servlet.ServletRequest#getInputStream()
388                  */
389                 public ServletInputStream getInputStream() throws IOException {
390
391                         if (accessed) { throw new IllegalStateException(method + " has already been called for this request"); }
392                         accessed = true;
393                         method = "getInputStream()";
394                         return stream;
395                 }
396
397                 /*
398                  * (non-Javadoc)
399                  * 
400                  * @see javax.servlet.ServletRequest#getReader()
401                  */
402                 public BufferedReader getReader() throws IOException {
403
404                         if (accessed) { throw new IllegalStateException(method + " has already been called for this request"); }
405                         accessed = true;
406                         method = "getReader()";
407                         return new BufferedReader(new InputStreamReader(stream));
408                 }
409
410         }
411
412         private class ModifiedInputStream extends ServletInputStream {
413
414                 private ByteArrayInputStream stream;
415
416                 private ModifiedInputStream(ByteArrayInputStream stream) {
417
418                         this.stream = stream;
419                 }
420
421                 /*
422                  * (non-Javadoc)
423                  * 
424                  * @see javax.servlet.ServletInputStream#readLine(byte[], int, int)
425                  */
426                 public int readLine(byte[] b, int off, int len) throws IOException {
427
428                         if (len <= 0) { return 0; }
429                         int count = 0, c;
430
431                         while ((c = stream.read()) != -1) {
432                                 b[off++] = (byte) c;
433                                 count++;
434                                 if (c == '\n' || count == len) {
435                                         break;
436                                 }
437                         }
438                         return count > 0 ? count : -1;
439                 }
440
441                 /*
442                  * (non-Javadoc)
443                  * 
444                  * @see java.io.InputStream#available()
445                  */
446                 public int available() throws IOException {
447
448                         return stream.available();
449                 }
450
451                 /*
452                  * (non-Javadoc)
453                  * 
454                  * @see java.io.InputStream#close()
455                  */
456                 public void close() throws IOException {
457
458                         stream.close();
459                 }
460
461                 /*
462                  * (non-Javadoc)
463                  * 
464                  * @see java.io.InputStream#mark(int)
465                  */
466                 public synchronized void mark(int readlimit) {
467
468                         stream.mark(readlimit);
469                 }
470
471                 /*
472                  * (non-Javadoc)
473                  * 
474                  * @see java.io.InputStream#markSupported()
475                  */
476                 public boolean markSupported() {
477
478                         return stream.markSupported();
479                 }
480
481                 /*
482                  * (non-Javadoc)
483                  * 
484                  * @see java.io.InputStream#read(byte[], int, int)
485                  */
486                 public int read(byte[] b, int off, int len) throws IOException {
487
488                         return stream.read(b, off, len);
489                 }
490
491                 /*
492                  * (non-Javadoc)
493                  * 
494                  * @see java.io.InputStream#read(byte[])
495                  */
496                 public int read(byte[] b) throws IOException {
497
498                         return stream.read(b);
499                 }
500
501                 /*
502                  * (non-Javadoc)
503                  * 
504                  * @see java.io.InputStream#reset()
505                  */
506                 public synchronized void reset() throws IOException {
507
508                         stream.reset();
509                 }
510
511                 /*
512                  * (non-Javadoc)
513                  * 
514                  * @see java.io.InputStream#skip(long)
515                  */
516                 public long skip(long n) throws IOException {
517
518                         return stream.skip(n);
519                 }
520
521                 /*
522                  * (non-Javadoc)
523                  * 
524                  * @see java.io.InputStream#read()
525                  */
526                 public int read() throws IOException {
527
528                         return stream.read();
529                 }
530
531         }
532
533 }