Add optional, but on by default, check to ensure that IdP session cookie comes from...
[java-idp.git] / src / main / java / edu / internet2 / middleware / shibboleth / idp / session / impl / SessionManagerImpl.java
1 /*
2  * Copyright 2007 University Corporation for Advanced Internet Development, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 package edu.internet2.middleware.shibboleth.idp.session.impl;
18
19 import java.security.NoSuchAlgorithmException;
20 import java.security.SecureRandom;
21 import java.util.List;
22 import java.util.Vector;
23
24 import javax.crypto.KeyGenerator;
25
26 import org.apache.commons.ssl.util.Hex;
27 import org.joda.time.DateTime;
28 import org.opensaml.util.storage.ExpiringObject;
29 import org.opensaml.util.storage.StorageService;
30 import org.opensaml.xml.util.DatatypeHelper;
31 import org.slf4j.Logger;
32 import org.slf4j.LoggerFactory;
33 import org.slf4j.MDC;
34 import org.springframework.context.ApplicationContext;
35 import org.springframework.context.ApplicationContextAware;
36 import org.springframework.context.ApplicationEvent;
37 import org.springframework.context.ApplicationListener;
38
39 import edu.internet2.middleware.shibboleth.common.session.LoginEvent;
40 import edu.internet2.middleware.shibboleth.common.session.LogoutEvent;
41 import edu.internet2.middleware.shibboleth.common.session.SessionManager;
42 import edu.internet2.middleware.shibboleth.common.util.EventingMapBasedStorageService.AddEntryEvent;
43 import edu.internet2.middleware.shibboleth.common.util.EventingMapBasedStorageService.RemoveEntryEvent;
44 import edu.internet2.middleware.shibboleth.idp.session.Session;
45
46 /** Manager of IdP sessions. */
47 public class SessionManagerImpl implements SessionManager<Session>, ApplicationContextAware, ApplicationListener {
48
49     /** Class logger. */
50     private final Logger log = LoggerFactory.getLogger(SessionManagerImpl.class);
51
52     /** Spring context used to publish login and logout events. */
53     private ApplicationContext appCtx;
54
55     /** Generator used to create secret keys associated with the session. */
56     private KeyGenerator secretKeyGen;
57
58     /** Number of random bits within a session ID. */
59     private final int sessionIDSize = 32;
60
61     /** A {@link SecureRandom} PRNG to generate session IDs. */
62     private final SecureRandom prng = new SecureRandom();
63
64     /** Backing service used to store sessions. */
65     private StorageService<String, SessionManagerEntry> sessionStore;
66
67     /** Partition in which entries are stored. */
68     private String partition;
69
70     /** Lifetime, in milliseconds, of session. */
71     private long sessionLifetime;
72
73     /**
74      * Constructor.
75      * 
76      * @param storageService service used to store sessions
77      * @param lifetime lifetime, in milliseconds, of sessions
78      */
79     public SessionManagerImpl(StorageService<String, SessionManagerEntry> storageService, long lifetime) {
80         sessionStore = storageService;
81         partition = "session";
82         sessionLifetime = lifetime;
83
84         try {
85             secretKeyGen = KeyGenerator.getInstance("AES");
86         } catch (NoSuchAlgorithmException e) {
87             log.error("AES key generation is not supported", e);
88         }
89     }
90
91     /**
92      * Constructor.
93      * 
94      * @param storageService service used to store session
95      * @param storageParition partition in which sessions are stored
96      * @param lifetime lifetime, in milliseconds, of sessions
97      */
98     public SessionManagerImpl(StorageService<String, SessionManagerEntry> storageService, String storageParition,
99             long lifetime) {
100         sessionStore = storageService;
101         if (!DatatypeHelper.isEmpty(storageParition)) {
102             partition = DatatypeHelper.safeTrim(storageParition);
103         } else {
104             partition = "session";
105         }
106         sessionLifetime = lifetime;
107     }
108
109     /** {@inheritDoc} */
110     public Session createSession() {
111         // generate a random session ID
112         byte[] sid = new byte[sessionIDSize];
113         prng.nextBytes(sid);
114         String sessionID = Hex.encode(sid);
115
116         Session session = new SessionImpl(sessionID, secretKeyGen.generateKey(), sessionLifetime);
117         SessionManagerEntry sessionEntry = new SessionManagerEntry(session, sessionLifetime);
118         sessionStore.put(partition, sessionID, sessionEntry);
119
120         MDC.put("idpSessionId", sessionID);
121         log.trace("Created session {}", sessionID);
122         appCtx.publishEvent(new LoginEvent(session));
123         return session;
124     }
125
126     /** {@inheritDoc} */
127     public Session createSession(String principal) {
128         // generate a random session ID
129         byte[] sid = new byte[sessionIDSize];
130         prng.nextBytes(sid);
131         String sessionID = Hex.encode(sid);
132
133         MDC.put("idpSessionId", sessionID);
134
135         Session session = new SessionImpl(sessionID, secretKeyGen.generateKey(), sessionLifetime);
136         SessionManagerEntry sessionEntry = new SessionManagerEntry(session, sessionLifetime);
137         sessionStore.put(partition, sessionID, sessionEntry);
138         log.trace("Created session {}", sessionID);
139         return session;
140     }
141
142     /** {@inheritDoc} */
143     public void destroySession(String sessionID) {
144         if (sessionID == null) {
145             return;
146         }
147
148         sessionStore.remove(partition, sessionID);
149     }
150
151     /** {@inheritDoc} */
152     public Session getSession(String sessionID) {
153         if (sessionID == null) {
154             return null;
155         }
156
157         SessionManagerEntry sessionEntry = sessionStore.get(partition, sessionID);
158         if (sessionEntry == null) {
159             return null;
160         }
161
162         if (sessionEntry.isExpired()) {
163             destroySession(sessionEntry.getSessionId());
164             return null;
165         } else {
166             return sessionEntry.getSession();
167         }
168     }
169
170     /** {@inheritDoc} */
171     public boolean indexSession(Session session, String index) {
172         if (sessionStore.contains(partition, index)) {
173             return false;
174         }
175
176         SessionManagerEntry sessionEntry = sessionStore.get(partition, session.getSessionID());
177         if (sessionEntry == null) {
178             return false;
179         }
180
181         if (sessionEntry.getSessionIndexes().contains(index)) {
182             return true;
183         }
184
185         sessionEntry.getSessionIndexes().add(index);
186         sessionStore.put(partition, index, sessionEntry);
187         log.trace("Added index {} to session {}", index, session.getSessionID());
188         return true;
189     }
190
191     /** {@inheritDoc} */
192     public void onApplicationEvent(ApplicationEvent event) {
193         if (event instanceof AddEntryEvent) {
194             AddEntryEvent addEvent = (AddEntryEvent) event;
195             if (addEvent.getValue() instanceof SessionManagerEntry) {
196                 SessionManagerEntry sessionEntry = (SessionManagerEntry) addEvent.getValue();
197                 appCtx.publishEvent(new LoginEvent(sessionEntry.getSession()));
198             }
199         }
200
201         if (event instanceof RemoveEntryEvent) {
202             RemoveEntryEvent removeEvent = (RemoveEntryEvent) event;
203             if (removeEvent.getValue() instanceof SessionManagerEntry) {
204                 SessionManagerEntry sessionEntry = (SessionManagerEntry) removeEvent.getValue();
205                 appCtx.publishEvent(new LogoutEvent(sessionEntry.getSession()));
206             }
207         }
208     }
209
210     /** {@inheritDoc} */
211     public void removeSessionIndex(String index) {
212         SessionManagerEntry sessionEntry = sessionStore.remove(partition, index);
213         if (sessionEntry != null) {
214             log.trace("Removing index {} for session {}", index, sessionEntry.getSessionId());
215             sessionEntry.getSessionIndexes().remove(index);
216         }
217     }
218
219     /** {@inheritDoc} */
220     public void setApplicationContext(ApplicationContext applicationContext) {
221         ApplicationContext rootContext = applicationContext;
222         while (rootContext.getParent() != null) {
223             rootContext = rootContext.getParent();
224         }
225         appCtx = rootContext;
226     }
227
228     /** Session store entry. */
229     public class SessionManagerEntry implements ExpiringObject {
230
231         /** User's session. */
232         private Session userSession;
233
234         /** Indexes for this session. */
235         private List<String> indexes;
236
237         /** Time this entry expires. */
238         private DateTime expirationTime;
239
240         /**
241          * Constructor.
242          * 
243          * @param session user session
244          * @param lifetime lifetime of session
245          */
246         public SessionManagerEntry(Session session, long lifetime) {
247             userSession = session;
248             expirationTime = new DateTime().plus(lifetime);
249             indexes = new Vector<String>();
250             indexes.add(userSession.getSessionID());
251         }
252
253         /** {@inheritDoc} */
254         public DateTime getExpirationTime() {
255             return expirationTime;
256         }
257
258         /**
259          * Gets the user session.
260          * 
261          * @return user session
262          */
263         public Session getSession() {
264             return userSession;
265         }
266
267         /**
268          * Gets the ID of the user session.
269          * 
270          * @return ID of the user session
271          */
272         public String getSessionId() {
273             return userSession.getSessionID();
274         }
275
276         /**
277          * Gets the list of indexes for this session.
278          * 
279          * @return list of indexes for this session
280          */
281         public List<String> getSessionIndexes() {
282             return indexes;
283         }
284
285         /** {@inheritDoc} */
286         public boolean isExpired() {
287             return expirationTime.isBeforeNow();
288         }
289
290         /** {@inheritDoc} */
291         public void onExpire() {
292
293         }
294     }
295 }