7cf139eb6402058a3e29b9228057b44bb825dffa
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / aa / attrresolv / provider / JDBCDataConnector.java
1 /*
2  * Copyright (c) 2003 National Research Council of Canada
3  *
4  * Permission is hereby granted, free of charge, to any person 
5  * obtaining a copy of this software and associated documentation 
6  * files (the "Software"), to deal in the Software without 
7  * restriction, including without limitation the rights to use, 
8  * copy, modify, merge, publish, distribute, sublicense, and/or 
9  * sell copies of the Software, and to permit persons to whom the 
10  * Software is furnished to do so, subject to the following conditions:
11  *
12  * The above copyright notice and this permission notice shall be 
13  * included in all copies or substantial portions of the Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 
16  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 
17  * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
18  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
19  * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
20  * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 
21  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 
22  * OTHER DEALINGS IN THE SOFTWARE.
23  *
24  */
25
26 package edu.internet2.middleware.shibboleth.aa.attrresolv.provider;
27
28 import java.io.PrintWriter;
29 import java.lang.reflect.Constructor;
30 import java.security.Principal;
31 import java.sql.Connection;
32 import java.sql.ResultSet;
33 import java.sql.ResultSetMetaData;
34 import java.sql.SQLException;
35 import java.sql.Statement;
36 import java.util.Iterator;
37 import java.util.Properties;
38 import java.util.regex.Matcher;
39 import java.util.regex.Pattern;
40
41 import javax.naming.NamingEnumeration;
42 import javax.naming.NamingException;
43 import javax.naming.directory.Attribute;
44 import javax.naming.directory.Attributes;
45 import javax.naming.directory.BasicAttribute;
46 import javax.naming.directory.BasicAttributes;
47 import javax.sql.DataSource;
48
49 import org.apache.commons.dbcp.ConnectionFactory;
50 import org.apache.commons.dbcp.DriverManagerConnectionFactory;
51 import org.apache.commons.dbcp.PoolableConnectionFactory;
52 import org.apache.commons.dbcp.PoolingDataSource;
53 import org.apache.commons.pool.ObjectPool;
54 import org.apache.commons.pool.impl.GenericObjectPool;
55 import org.apache.log4j.Logger;
56 import org.apache.log4j.Priority;
57 import org.w3c.dom.Element;
58 import org.w3c.dom.Node;
59 import org.w3c.dom.NodeList;
60
61 import edu.internet2.middleware.shibboleth.aa.attrresolv.AttributeResolver;
62 import edu.internet2.middleware.shibboleth.aa.attrresolv.DataConnectorPlugIn;
63 import edu.internet2.middleware.shibboleth.aa.attrresolv.Dependencies;
64 import edu.internet2.middleware.shibboleth.aa.attrresolv.ResolutionPlugInException;
65 import edu.internet2.middleware.shibboleth.aa.attrresolv.ResolverAttribute;
66
/*
 * Built at the Canada Institute for Scientific and Technical Information (CISTI
 * <a href="http://www.cisti-icist.nrc-cnrc.gc.ca/">http://www.cisti-icist.nrc-cnrc.gc.ca/</a>),
 * the National Research Council Canada
 * (NRC <a href="http://www.nrc-cnrc.gc.ca/">http://www.nrc-cnrc.gc.ca/</a>)
 * by David Dearman, COOP student from Dalhousie University,
 * under the direction of Glen Newton, Head research (IT)
 * <a href="mailto:glen.newton@nrc-cnrc.gc.ca">glen.newton@nrc-cnrc.gc.ca</a>.
 */
76
77 /**
78  * Data Connector that uses JDBC to access user attributes stored in databases.
79  *
80  * @author David Dearman (dearman@cs.dal.ca)
81  */
82
83 public class JDBCDataConnector extends BaseResolutionPlugIn implements DataConnectorPlugIn {
84
85         private static Logger log = Logger.getLogger(JDBCDataConnector.class.getName());
86         protected Properties props = new Properties();
87         protected String searchVal;
88         protected DataSource dataSource;
89         protected JDBCAttributeExtractor extractor;
90
91         public JDBCDataConnector(Element element) throws ResolutionPlugInException {
92
93                 super(element);
94
95                 //Get the query string
96                 NodeList searchNode = element.getElementsByTagNameNS(AttributeResolver.resolverNamespace, "Search");
97                 searchVal = ((Element) searchNode.item(0)).getAttribute("query");
98
99                 if (searchVal == null || searchVal.equals("")) {
100                         Node tnode = searchNode.item(0).getFirstChild();
101                         if (tnode != null && tnode.getNodeType() == Node.TEXT_NODE) {
102                                 searchVal = tnode.getNodeValue();
103                         }
104                         if (searchVal == null || searchVal.equals("")) {
105                                 log.error("Search requires a specified query field");
106                                 //TODO stinky error message
107                                 throw new ResolutionPlugInException("mySQLDataConnection requires a \"Search\" specification");
108                         }
109                 } else {
110                         log.debug("Search Query: (" + searchVal + ")");
111                 }
112
113                 //Instantiate an attribute extractor, using the default if none is specified
114                 String aeClassName = ((Element) searchNode.item(0)).getAttribute("attributeExtractor");
115                 if (aeClassName == null || aeClassName.equals("")) {
116                         aeClassName = DefaultAE.class.getName();
117                 }
118                 try {
119                         Class aeClass = Class.forName(aeClassName);
120                         Constructor constructor = aeClass.getConstructor(null);
121                         extractor = (JDBCAttributeExtractor) constructor.newInstance(null);
122                         log.debug("Supplied attributeExtractor class loaded.");
123
124                 } catch (ClassNotFoundException e) {
125                         log.error("The supplied Attribute Extractor class could not be found: " + e);
126                         throw new ResolutionPlugInException(
127                                 "The supplied Attribute Extractor class could not be found: " + e.getMessage());
128                 } catch (Exception e) {
129                         log.error("Unable to instantiate Attribute Extractor implementation: " + e);
130                         throw new ResolutionPlugInException(
131                                 "Unable to instantiate Attribute Extractor implementation: " + e.getMessage());
132                 }
133
134                 //Grab all other properties
135                 NodeList propertiesNode = element.getElementsByTagNameNS(AttributeResolver.resolverNamespace, "Property");
136                 for (int i = 0; propertiesNode.getLength() > i; i++) {
137                         Element property = (Element) propertiesNode.item(i);
138                         String propertiesName = property.getAttribute("name");
139                         String propertiesValue = property.getAttribute("value");
140
141                         if (propertiesName != null
142                                 && !propertiesName.equals("")
143                                 && propertiesValue != null
144                                 && !propertiesValue.equals("")) {
145                                 props.setProperty(propertiesName, propertiesValue);
146                                 log.debug("Property: (" + propertiesName + ")");
147                                 log.debug("   Value: (" + propertiesValue + ")");
148                         } else {
149                                 log.error("Property is malformed.");
150                                 throw new ResolutionPlugInException("Property is malformed.");
151                         }
152                 }
153
154                 if (props.getProperty("dbURL") == null) {
155                         log.error("JDBC connection requires a dbURL property");
156                         throw new ResolutionPlugInException("JDBCDataConnection requires a \"dbURL\" property");
157                 }
158
159                 //Load the supplied JDBC driver
160                 loadDriver((String) props.get("dbDriver"));
161                 
162                 //Setup the Pool
163                 GenericObjectPool genericObjectPool = new GenericObjectPool(null);
164
165                 try {
166                         if (props.getProperty("maxActiveConnections") != null) {
167                                 genericObjectPool.setMaxActive(Integer.parseInt(props.getProperty("maxActiveConnections")));
168                         }
169                         if (props.getProperty("maxIdleConnections") != null) {
170                                 genericObjectPool.setMaxIdle(Integer.parseInt(props.getProperty("maxIdleConnections")));
171                         }
172                 } catch (NumberFormatException e) {
173                         log.error("Malformed pooling configuration settings: using defaults.");
174                 }
175                 genericObjectPool.setWhenExhaustedAction(GenericObjectPool.WHEN_EXHAUSTED_BLOCK);
176
177                 ObjectPool connPool = genericObjectPool;
178                 ConnectionFactory connFactory = null;
179                 PoolableConnectionFactory poolConnFactory = null;
180
181                 try {
182                         connFactory = new DriverManagerConnectionFactory(props.getProperty("dbURL"), null);
183                         log.debug("Connection factory initialized.");
184                 } catch (Exception ex) {
185                         log.error(
186                                 "Connection factory couldn't be initialized, ensure database URL, username and password are correct.");
187                         throw new ResolutionPlugInException("Connection facotry couldn't be initialized: " + ex.getMessage());
188                 }
189
190                 try {
191                         poolConnFactory = new PoolableConnectionFactory(connFactory, connPool, null, null, false, true);
192                 } catch (Exception ex) {
193                         log.debug("Poolable connection factory error");
194                 }
195
196                 dataSource = new PoolingDataSource(connPool);
197                 try {
198                         dataSource.setLogWriter(
199                                 new Log4jPrintWriter(Logger.getLogger(JDBCDataConnector.class.getName() + ".Pool"), Priority.DEBUG));
200                 } catch (SQLException e) {
201                         log.error("Coudn't setup logger for database connection pool.");
202                 }
203         }
204
205         protected String substitute(String source, String pattern, boolean quote, Dependencies depends) {
206                 Matcher m = Pattern.compile(pattern).matcher(source);
207                 while (m.find()) {
208                         String field = source.substring(m.start() + 1, m.end() - 1);
209                         if (field != null && field.length() > 0) {
210                                 StringBuffer buf = new StringBuffer();
211
212                                 //Look for an attribute dependency.
213                                 ResolverAttribute dep = depends.getAttributeResolution(field);
214                                 if (dep != null) {
215                                         Iterator iter = dep.getValues();
216                                         while (iter.hasNext()) {
217                                                 if (buf.length() > 0)
218                                                         buf = buf.append(',');
219                                                 if (quote)
220                                                         buf = buf.append("'");
221                                                 buf = buf.append(iter.next());
222                                                 if (quote)
223                                                         buf = buf.append("'");
224                                         }
225                                 }
226
227                                 //If no values found, cycle over the connectors.
228                                 Iterator connDeps = connectorDependencyIds.iterator();
229                                 while (buf.length() == 0 && connDeps.hasNext()) {
230                                         Attributes attrs = depends.getConnectorResolution((String) connDeps.next());
231                                         if (attrs != null) {
232                                                 Attribute attr = attrs.get(field);
233                                                 if (attr != null) {
234                                                         try {
235                                                                 NamingEnumeration vals = attr.getAll();
236                                                                 while (vals.hasMore()) {
237                                                                         if (buf.length() > 0)
238                                                                                 buf = buf.append(',');
239                                                                         if (quote)
240                                                                                 buf = buf.append("'");
241                                                                         buf = buf.append(vals.next());
242                                                                         if (quote)
243                                                                                 buf = buf.append("'");
244                                                                 }
245                                                         } catch (NamingException e) {
246                                                                 // Auto-generated catch block
247                                                         }
248                                                 }
249                                         }
250                                 }
251
252                                 if (buf.length() == 0) {
253                                         log.warn(
254                                                 "Unable to find any values to substitute in query for "
255                                                         + field
256                                                         + ", so using the empty string");
257                                 }
258                                 source = source.replaceAll(m.group(), buf.toString());
259                                 m.reset(source);
260                         }
261                 }
262                 return source;
263         }
264
265         public Attributes resolve(Principal principal, String requester, Dependencies depends)
266                 throws ResolutionPlugInException {
267
268                 log.debug("Resolving connector: (" + getId() + ")");
269                 log.debug(getId() + " resolving for principal: (" + principal.getName() + ")");
270                 log.debug("The query string before inserting substitutions: " + searchVal);
271
272                 //Replaces %PRINCIPAL% in the query string with its value
273                 String convertedSearchVal = searchVal.replaceAll("%PRINCIPAL%", principal.getName());
274                 convertedSearchVal = convertedSearchVal.replaceAll("@PRINCIPAL@", "'" + principal.getName() + "'");
275
276                 //Find all delimited substitutions and replace with the named attribute value(s).
277                 convertedSearchVal = substitute(convertedSearchVal, "%.+%", false, depends);
278                 convertedSearchVal = substitute(convertedSearchVal, "@.+@", true, depends);
279
280                 //Replace any escaped substitution delimiters.
281                 convertedSearchVal = convertedSearchVal.replaceAll("\\%", "%");
282                 convertedSearchVal = convertedSearchVal.replaceAll("\\@", "@");
283
284                 log.debug("The query string after inserting substitutions: " + convertedSearchVal);
285
286                 /**
287                  * Retrieves a connection from the connection pool
288                  */
289                 Connection conn = null;
290                 try {
291                         conn = dataSource.getConnection();
292                         log.debug("Connection retrieved from pool");
293                 } catch (Exception e) {
294                         log.error("Unable to fetch a connection from the pool");
295                         throw new ResolutionPlugInException("Unable to fetch a connection from the pool: " + e.getMessage());
296                 }
297                 if (conn == null) {
298                         log.error("Pool didn't return a propertly initialized connection.");
299                         throw new ResolutionPlugInException("Pool didn't return a propertly initialized connection.");
300                 }
301
302                 ResultSet rs = null;
303                 try {
304                         //Gets the results set for the query
305                         rs = executeQuery(conn, convertedSearchVal);
306                         if (!rs.next())
307                                 return new BasicAttributes();
308
309                 } catch (SQLException e) {
310                         log.error("An ERROR occured while executing the query");
311                         throw new ResolutionPlugInException("An ERROR occured while executing the query: " + e.getMessage());
312                 }
313
314                 try {
315                         return extractor.extractAttributes(rs);
316
317                 } catch (JDBCAttributeExtractorException e) {
318                         log.error("An ERROR occured while extracting attributes from result set");
319                         throw new ResolutionPlugInException(
320                                 "An ERROR occured while extracting attributes from result set: " + e.getMessage());
321                 } finally {
322                         try {
323                                 rs.close();
324                         } catch (SQLException e) {
325                                 log.error("An error occured while closing the result set: " + e);
326                                 throw new ResolutionPlugInException("An error occured while closing the result set: " + e);
327                         }
328
329                         try {
330                                 conn.close();
331                         } catch (SQLException e) {
332                                 log.error("An error occured while closing the database connection: " + e);
333                                 throw new ResolutionPlugInException("An error occured while closing the database connection: " + e);
334                         }
335                 }
336         }
337
338         /** 
339          * Loads the driver used to access the database
340          * @param driver The driver used to access the database
341          * @throws ResolutionPlugInException If there is a failure to load the driver
342          */
343         public void loadDriver(String driver) throws ResolutionPlugInException {
344                 try {
345                         Class.forName(driver).newInstance();
346                         log.debug("Loading JDBC driver: " + driver);
347                 } catch (Exception e) {
348                         log.error("An error loading database driver: " + e);
349                         throw new ResolutionPlugInException(
350                                 "An IllegalAccessException occured while loading database driver: " + e.getMessage());
351                 }
352                 log.debug("Driver loaded.");
353         }
354
355         /**
356          * Execute the users query
357          * @param query The query the user wishes to execute
358          * @return The result of the users <code>query</code>
359          * @return null if an error occurs during execution
360          * @throws SQLException If an error occurs while executing the query
361         */
362         public ResultSet executeQuery(Connection conn, String query) throws SQLException {
363                 log.debug("Users Query: " + query);
364                 Statement stmt = conn.createStatement();
365                 return stmt.executeQuery(query);
366         }
367
368         private class Log4jPrintWriter extends PrintWriter {
369
370                 private Priority level;
371                 private Logger logger;
372                 private StringBuffer text = new StringBuffer("");
373
374                 private Log4jPrintWriter(Logger logger, org.apache.log4j.Priority level) {
375                         super(System.err);
376                         this.level = level;
377                         this.logger = logger;
378                 }
379
380                 public void close() {
381                         flush();
382                 }
383
384                 public void flush() {
385                         if (!text.toString().equals("")) {
386                                 logger.log(level, text.toString());
387                                 text.setLength(0);
388                         }
389                 }
390
391                 public void print(boolean b) {
392                         text.append(b);
393                 }
394
395                 public void print(char c) {
396                         text.append(c);
397                 }
398
399                 public void print(char[] s) {
400                         text.append(s);
401                 }
402
403                 public void print(double d) {
404                         text.append(d);
405                 }
406
407                 public void print(float f) {
408                         text.append(f);
409                 }
410
411                 public void print(int i) {
412                         text.append(i);
413                 }
414
415                 public void print(long l) {
416                         text.append(l);
417                 }
418
419                 public void print(Object obj) {
420                         text.append(obj);
421                 }
422
423                 public void print(String s) {
424                         text.append(s);
425                 }
426
427                 public void println() {
428                         if (!text.toString().equals("")) {
429                                 logger.log(level, text.toString());
430                                 text.setLength(0);
431                         }
432                 }
433
434                 public void println(boolean x) {
435                         text.append(x);
436                         logger.log(level, text.toString());
437                         text.setLength(0);
438                 }
439
440                 public void println(char x) {
441                         text.append(x);
442                         logger.log(level, text.toString());
443                         text.setLength(0);
444                 }
445
446                 public void println(char[] x) {
447                         text.append(x);
448                         logger.log(level, text.toString());
449                         text.setLength(0);
450                 }
451
452                 public void println(double x) {
453                         text.append(x);
454                         logger.log(level, text.toString());
455                         text.setLength(0);
456                 }
457
458                 public void println(float x) {
459                         text.append(x);
460                         logger.log(level, text.toString());
461                         text.setLength(0);
462                 }
463
464                 public void println(int x) {
465                         text.append(x);
466                         logger.log(level, text.toString());
467                         text.setLength(0);
468                 }
469
470                 public void println(long x) {
471                         text.append(x);
472                         logger.log(level, text.toString());
473                         text.setLength(0);
474                 }
475
476                 public void println(Object x) {
477                         text.append(x);
478                         logger.log(level, text.toString());
479                         text.setLength(0);
480                 }
481
482                 public void println(String x) {
483                         text.append(x);
484                         logger.log(level, text.toString());
485                         text.setLength(0);
486                 }
487         }
488 }
489
490 /**
491  * The default attribute extractor. 
492  */
493
494 class DefaultAE implements JDBCAttributeExtractor {
495
496         private static Logger log = Logger.getLogger(DefaultAE.class.getName());
497
498         // Constructor
499         public DefaultAE() {
500         }
501
502         /**
503          * Method of extracting the attributes from the supplied result set.
504          *
505          * @param ResultSet The result set from the query which contains the attributes
506          * @return BasicAttributes as objects containing all the attributes
507          * @throws JDBCAttributeExtractorException If there is a complication in retrieving the attributes
508          */
509         public BasicAttributes extractAttributes(ResultSet rs) throws JDBCAttributeExtractorException {
510                 BasicAttributes attributes = new BasicAttributes();
511
512                 log.debug("Using default Attribute Extractor");
513
514                 try {
515                         ResultSetMetaData rsmd = rs.getMetaData();
516                         int numColumns = rsmd.getColumnCount();
517                         log.debug("Number of returned columns: " + numColumns);
518
519                         for (int i = 1; i <= numColumns; i++) {
520                                 String columnName = rsmd.getColumnName(i);
521                                 String columnType = rsmd.getColumnTypeName(i);
522                                 Object columnValue = rs.getObject(columnName);
523                                 log.debug(
524                                         "("
525                                                 + i
526                                                 + ". ColumnType = "
527                                                 + columnType
528                                                 + ") "
529                                                 + columnName
530                                                 + " -> "
531                                                 + (columnValue != null ? columnValue.toString() : "(null)"));
532                                 attributes.put(new BasicAttribute(columnName, columnValue));
533                         }
534                 } catch (SQLException e) {
535                         log.error("An ERROR occured while retrieving result set meta data");
536                         throw new JDBCAttributeExtractorException(
537                                 "An ERROR occured while retrieving result set meta data: " + e.getMessage());
538                 }
539
540                 // Check for multiple rows.
541                 try {
542                         if (rs.next())
543                                 throw new JDBCAttributeExtractorException("Query returned more than one row.");
544                 } catch (SQLException e) {
545                         //TODO don't squelch this error!!!
546                 }
547
548                 return attributes;
549         }
550 }