use the new session manager interface
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / common / provider / CookieCache.java
index 0269211..d43fb84 100644 (file)
 
 package edu.internet2.middleware.shibboleth.common.provider;
 
+import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
-import java.io.DataOutputStream;
-import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.StringReader;
+import java.io.StringWriter;
 import java.security.GeneralSecurityException;
-import java.security.KeyException;
 import java.security.SecureRandom;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.Date;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
+import java.util.zip.GZIPInputStream;
 import java.util.zip.GZIPOutputStream;
 
 import javax.crypto.Cipher;
@@ -36,55 +42,95 @@ import javax.crypto.spec.IvParameterSpec;
 import javax.servlet.http.Cookie;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
+import javax.xml.parsers.DocumentBuilderFactory;
+import javax.xml.transform.OutputKeys;
+import javax.xml.transform.Transformer;
+import javax.xml.transform.TransformerFactory;
+import javax.xml.transform.dom.DOMSource;
+import javax.xml.transform.stream.StreamResult;
+
+import org.apache.log4j.Logger;
+import org.w3c.dom.Document;
+import org.w3c.dom.Element;
+import org.w3c.dom.NodeList;
+import org.xml.sax.InputSource;
 
 import edu.internet2.middleware.shibboleth.common.Cache;
+import edu.internet2.middleware.shibboleth.common.CacheException;
 import edu.internet2.middleware.shibboleth.utils.Base32;
 
 /**
  * <code>Cache</code> implementation that uses browser cookies to store data. Symmetric and HMAC algorithms are used
  * to encrypt and verify the data. Due to the size limitations of cookie storage, data may interleaved among multiple
- * cookies.
+ * cookies. NOTE: Using this cache implementation in a standalon tomcat configuration will usually require that the
+ * "maxHttpHeaderSize" parameter be greatly increased.
  * 
  * @author Walter Hoehn
  */
 public class CookieCache extends BaseCache implements Cache {
 
-       // TODO domain limit?
+       private static Logger log = Logger.getLogger(CookieCache.class.getName());
        private HttpServletResponse response;
-       private List<Cookie> myCurrentCookies = new ArrayList<Cookie>();
+       private Collection<Cookie> myCurrentCookies = new ArrayList<Cookie>();
        private Map<String, CacheEntry> dataCache = new HashMap<String, CacheEntry>();
-       private static final int CHUNK_SIZE = 4 * 1024; // in KB, minimal browser requirement
+       private static final int CHUNK_SIZE = 4 * 1024; // minimal browser requirement
        private static final int COOKIE_LIMIT = 20; // minimal browser requirement
        private static final String NAME_PREFIX = "IDP_CACHE:";
+       private static int totalCookies = 0;
        protected SecretKey secret;
        private static SecureRandom random = new SecureRandom();
-       private String cipherAlgorithm = "DESede/CBC/PKCS5Padding";
-       private String macAlgorithm = "HmacSHA1";
-       private String storeType = "JCEKS";
+       private String cipherAlgorithm;
+       private String macAlgorithm;
 
-       CookieCache(String name, HttpServletRequest request, HttpServletResponse response) {
+       public CookieCache(String name, SecretKey key, String cipherAlgorithm, String macAlgorithm,
+                       HttpServletRequest request, HttpServletResponse response) throws CacheException {
 
                super(name, Cache.CacheType.CLIENT_SIDE);
+               this.secret = key;
+               this.cipherAlgorithm = cipherAlgorithm;
+               this.macAlgorithm = macAlgorithm;
                this.response = response;
                Cookie[] requestCookies = request.getCookies();
-               for (int i = 0; i < requestCookies.length; i++) {
-                       if (requestCookies[i].getName().startsWith(NAME_PREFIX)) {
-                               myCurrentCookies.add(requestCookies[i]);
+               if (requestCookies != null) {
+                       for (int i = 0; i < requestCookies.length; i++) {
+                               if (requestCookies[i].getName().startsWith(NAME_PREFIX + getName())
+                                               && requestCookies[i].getValue() != null) {
+                                       myCurrentCookies.add(requestCookies[i]);
+                               }
                        }
                }
 
-               // TODO dechunk, decrypt, and pull in dataCache
+               if (usingDefaultSecret()) {
+                       log.warn("You are running the Cookie Cache with the "
+                                       + "default secret key.  This is UNSAFE!  Please change "
+                                       + "this configuration and restart the IdP.");
+               }
+
+               initFromCookies();
+       }
+
+       public void postProcessing() throws CacheException {
+
+               if (totalCookies > (COOKIE_LIMIT - 1)) {
+                       log.warn("The Cookie Cache mechanism is about to write a large amount of data to the "
+                                       + "client.  This may not work with some browser software, so it is recommended"
+                                       + " that you investigate other caching mechanisms.");
+               }
+
+               flushCache();
        }
 
-       public boolean contains(String key) {
+       public boolean contains(String key) throws CacheException {
 
                CacheEntry entry = dataCache.get(key);
 
                if (entry == null) { return false; }
 
                // Clean cache if it is expired
-               if (new Date().after(((CacheEntry) entry).expiration)) {
-                       deleteFromCache(key);
+               if ((((CacheEntry) entry).isExpired())) {
+                       log.debug("Found expired object.  Deleting...");
+                       totalCookies--;
+                       dataCache.remove(key);
                        return false;
                }
 
@@ -92,21 +138,17 @@ public class CookieCache extends BaseCache implements Cache {
                return true;
        }
 
-       private void deleteFromCache(String key) {
-
-               dataCache.remove(key);
-               flushCache();
-       }
-
-       public Object retrieve(String key) {
+       public String retrieve(String key) throws CacheException {
 
                CacheEntry entry = dataCache.get(key);
 
                if (entry == null) { return null; }
 
                // Clean cache if it is expired
-               if (new Date().after(((CacheEntry) entry).expiration)) {
-                       deleteFromCache(key);
+               if ((((CacheEntry) entry).isExpired())) {
+                       log.debug("Found expired object.  Deleting...");
+                       totalCookies--;
+                       dataCache.remove(key);
                        return null;
                }
 
@@ -114,35 +156,180 @@ public class CookieCache extends BaseCache implements Cache {
                return entry.value;
        }
 
-       public void store(String key, String value, long duration) {
+       public void remove(String key) throws CacheException {
+
+               dataCache.remove(key);
+               totalCookies--;
+       }
+
+       public void store(String key, String value, long duration) throws CacheException {
 
                dataCache.put(key, new CacheEntry(value, duration));
-               flushCache();
+               totalCookies++;
+       }
+
+       private void initFromCookies() throws CacheException {
+
+               log.debug("Attempting to initialize cache from client-supplied cookies.");
+               // Pull data from cookies
+               List<Cookie> relevantCookies = new ArrayList<Cookie>();
+               for (Cookie cookie : myCurrentCookies) {
+                       if (cookie.getName().startsWith(NAME_PREFIX + getName())) {
+                               relevantCookies.add(cookie);
+                       }
+               }
+               if (relevantCookies.isEmpty()) {
+                       log.debug("No applicable cookies found.  Cache is empty.");
+                       return;
+               }
+
+               // Sort
+               String[] sortedCookieValues = new String[relevantCookies.size()];
+               for (Cookie cookie : relevantCookies) {
+                       String[] tokenizedName = cookie.getName().split(":");
+                       sortedCookieValues[Integer.parseInt(tokenizedName[tokenizedName.length - 1]) - 1] = cookie.getValue();
+               }
+               // Concatenate
+               StringBuffer concat = new StringBuffer();
+               for (String cookieValue : sortedCookieValues) {
+                       concat.append(cookieValue);
+               }
+               log.debug("Dumping Encrypted/Encoded Input Cache: " + concat);
+
+               try {
+                       // Decode Base32
+                       byte[] in = Base32.decode(concat.toString());
+
+                       // Decrypt
+                       Cipher cipher = Cipher.getInstance(cipherAlgorithm);
+                       int ivSize = cipher.getBlockSize();
+                       byte[] iv = new byte[ivSize];
+                       Mac mac = Mac.getInstance(macAlgorithm);
+                       mac.init(secret);
+                       int macSize = mac.getMacLength();
+                       if (in.length < ivSize) {
+                               log.error("Cache is malformed (not enough bytes).");
+                               throw new CacheException("Cache is malformed (not enough bytes).");
+                       }
+
+                       // extract the IV, setup the cipher and extract the encrypted data
+                       System.arraycopy(in, 0, iv, 0, ivSize);
+                       IvParameterSpec ivSpec = new IvParameterSpec(iv);
+                       cipher.init(Cipher.DECRYPT_MODE, secret, ivSpec);
+                       byte[] encryptedData = new byte[in.length - iv.length];
+                       System.arraycopy(in, ivSize, encryptedData, 0, in.length - iv.length);
+
+                       // decrypt the rest of the data andsetup the streams
+                       byte[] decryptedBytes = cipher.doFinal(encryptedData);
+                       ByteArrayInputStream byteStream = new ByteArrayInputStream(decryptedBytes);
+                       GZIPInputStream compressedData = new GZIPInputStream(byteStream);
+                       ObjectInputStream dataStream = new ObjectInputStream(compressedData);
+
+                       // extract the components
+                       byte[] decodedMac = new byte[macSize];
+
+                       int bytesRead = dataStream.read(decodedMac);
+                       if (bytesRead != macSize) {
+                               log.error("Error parsing cache: Unable to extract HMAC.");
+                               throw new CacheException("Error parsing cache: Unable to extract HMAC.");
+                       }
+
+                       String decodedData = (String) dataStream.readObject();
+                       log.debug("Dumping Raw Input Cache: " + decodedData);
+
+                       // Verify HMAC
+                       byte[] generatedMac = mac.doFinal(decodedData.getBytes());
+                       if (!Arrays.equals(decodedMac, generatedMac)) {
+                               log.error("Cookie cache data failed integrity  check.");
+                               throw new GeneralSecurityException("Cookie cache data failed integrity check.");
+                       }
+
+                       // Parse XML
+                       DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
+                       factory.setValidating(false);
+                       factory.setNamespaceAware(false);
+                       Element cacheElement = factory.newDocumentBuilder().parse(new InputSource(new StringReader(decodedData)))
+                                       .getDocumentElement();
+                       NodeList items = cacheElement.getElementsByTagName("Item");
+                       for (int i = 0; i < items.getLength(); i++) {
+                               Element item = (Element) items.item(i);
+                               totalCookies++;
+                               dataCache.put(item.getAttribute("key"), new CacheEntry(item.getAttribute("value"), new Date(new Long(
+                                               item.getAttribute("expire")))));
+                       }
+
+               } catch (Exception e) {
+                       log.error("Error decrypting cache data: " + e);
+                       throw new CacheException("Unable to read cached data.");
+               }
+       }
+
+       private boolean usingDefaultSecret() {
+
+               byte[] defaultKey = new byte[]{(byte) 0xC7, (byte) 0x49, (byte) 0x80, (byte) 0xD3, (byte) 0x02, (byte) 0x4A,
+                               (byte) 0x61, (byte) 0xEF, (byte) 0x25, (byte) 0x5D, (byte) 0xE3, (byte) 0x2F, (byte) 0x57, (byte) 0x51,
+                               (byte) 0x20, (byte) 0x15, (byte) 0xC7, (byte) 0x49, (byte) 0x80, (byte) 0xD3, (byte) 0x02, (byte) 0x4A,
+                               (byte) 0x61, (byte) 0xEF};
+               byte[] encodedKey = secret.getEncoded();
+               return Arrays.equals(defaultKey, encodedKey);
        }
 
        /**
         * Secures, encodes, and writes out (to cookies) cached data.
         */
-       private void flushCache() {
+       private void flushCache() throws CacheException {
+
+               log.debug("Flushing cache.");
+               log.debug("Encrypting cache data.");
 
-               // TODO create String representation of all cache data
+               // Create XML/String representation of all cache data
                String stringData = null;
 
                try {
 
+                       DocumentBuilderFactory docFactory = DocumentBuilderFactory.newInstance();
+                       docFactory.setNamespaceAware(false);
+                       Document placeHolder = docFactory.newDocumentBuilder().newDocument();
+
+                       Element cacheNode = placeHolder.createElement("Cache");
+                       for (Entry<String, CacheEntry> entry : dataCache.entrySet()) {
+                               Element itemNode = placeHolder.createElement("Item");
+                               itemNode.setAttribute("key", entry.getKey());
+                               itemNode.setAttribute("value", entry.getValue().value);
+                               itemNode.setAttribute("expire", new Long(entry.getValue().expiration.getTime()).toString());
+                               cacheNode.appendChild(itemNode);
+                       }
+
+                       TransformerFactory factory = TransformerFactory.newInstance();
+                       DOMSource source = new DOMSource(cacheNode);
+                       Transformer transformer = factory.newTransformer();
+                       transformer.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "yes");
+                       StringWriter stringWriter = new StringWriter();
+                       StreamResult result = new StreamResult(stringWriter);
+                       transformer.transform(source, result);
+                       stringData = stringWriter.toString().replaceAll(">\\s<", "><");
+                       log.debug("Dumping Raw Cache: " + stringData);
+
+               } catch (Exception e) {
+                       log.error("Error encoding cache data: " + e);
+                       throw new CacheException("Unable to cache data.");
+               }
+
+               try {
+
                        // Setup a gzipped data stream
                        ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
                        GZIPOutputStream compressedStream = new GZIPOutputStream(byteStream);
-                       DataOutputStream dataStream = new DataOutputStream(compressedStream);
+                       ObjectOutputStream dataStream = new ObjectOutputStream(compressedStream);
 
                        // Write data and HMAC to stream
                        Mac mac = Mac.getInstance(macAlgorithm);
                        mac.init(secret);
                        dataStream.write(mac.doFinal(stringData.getBytes()));
-                       dataStream.writeUTF(stringData);
+                       dataStream.writeObject(stringData);
 
                        // Flush
-                       dataStream.flush();
+                       // dataStream.flush();
                        compressedStream.flush();
                        compressedStream.finish();
                        byteStream.flush();
@@ -164,16 +351,14 @@ public class CookieCache extends BaseCache implements Cache {
 
                        // Base32 encode
                        String encodedData = Base32.encode(cacheBytes);
+                       log.debug("Dumping Encrypted/Encoded Cache: " + encodedData);
 
                        // Put into cookies
                        interleaveInCookies(encodedData);
 
-               } catch (KeyException e) {
-                       // TODO handle
-               } catch (GeneralSecurityException e) {
-                       // TODO handle
-               } catch (IOException e) {
-                       // TODO handle
+               } catch (Exception e) {
+                       log.error("Error encrypting cache data: " + e);
+                       throw new CacheException("Unable to cache data.");
                }
        }
 
@@ -182,45 +367,53 @@ public class CookieCache extends BaseCache implements Cache {
         */
        private void interleaveInCookies(String data) {
 
+               log.debug("Writing cache to cookies.");
+
                // Convert the String data to a list of cookies
-               List<Cookie> cookiesToResponse = new ArrayList<Cookie>();
+               Map<String, Cookie> cookiesToResponse = new HashMap<String, Cookie>();
                StringBuffer bufferredData = new StringBuffer(data);
+               int i = 1;
                while (bufferredData != null && bufferredData.length() > 0) {
                        Cookie cookie = null;
-                       String name = null;
-                       if (bufferredData.length() <= getCookieSpace()) {
+                       String name = NAME_PREFIX + getName() + ":" + i++;
+                       if (bufferredData.length() <= getCookieSpace(name)) {
                                cookie = new Cookie(name, bufferredData.toString());
                                bufferredData = null;
                        } else {
-                               cookie = new Cookie(name, bufferredData.substring(0, getCookieSpace() - 1));
-                               bufferredData.delete(0, getCookieSpace() - 1);
+                               cookie = new Cookie(name, bufferredData.substring(0, getCookieSpace(name) - 1));
+                               bufferredData.delete(0, getCookieSpace(name) - 1);
                        }
-                       cookiesToResponse.add(cookie);
+                       cookiesToResponse.put(cookie.getName(), cookie);
                }
 
-               // We have to null out cookies that are no longer needed
-               for (Cookie previousCookie : myCurrentCookies) {
-                       if (!cookiesToResponse.contains(previousCookie)) {
-                               cookiesToResponse.add(new Cookie(previousCookie.getName(), null));
+               // Expire cookies that we used previously but no longer need
+               for (Cookie currCookie : myCurrentCookies) {
+                       if (!cookiesToResponse.containsKey(currCookie.getName())) {
+                               currCookie.setMaxAge(0);
+                               currCookie.setValue(null);
+                               cookiesToResponse.put(currCookie.getName(), currCookie);
                        }
                }
 
                // Write our cookies to the response object
-               for (Cookie cookie : cookiesToResponse) {
+               for (Cookie cookie : cookiesToResponse.values()) {
                        response.addCookie(cookie);
                }
 
                // Update our cached copy of the cookies
-               myCurrentCookies = cookiesToResponse;
+               myCurrentCookies = cookiesToResponse.values();
        }
 
        /**
         * Returns the amount of value space available in cookies we create
         */
-       private int getCookieSpace() {
+       private int getCookieSpace(String cookieName) {
 
-               // TODO this needs to be better
-               return 3000;
+               // If we add other cookie variables, we would need to adjust this algorithm appropriately
+               StringBuffer used = new StringBuffer();
+               used.append("Set-Cookie: ");
+               used.append(cookieName + "=" + " ");
+               return CHUNK_SIZE - used.length() - 2;
        }
 
 }