/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.mapreduce.lib.db;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.MRJobConfig;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;

/**
 * An InputFormat that reads input data from an SQL table.
 * <p>
 * DBInputFormat emits LongWritables containing the record number as
 * key and DBWritables as value.
 *
 * The SQL query and input class can be set using one of the two
 * setInput methods.
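 * <p>
 * A minimal usage sketch is shown below. The JDBC driver, connection URL,
 * credentials, the {@code MyRecord} class (a user class implementing
 * {@link DBWritable}), and all table and field names are illustrative
 * placeholders, not part of this API:
 * <pre>{@code
 * Job job = Job.getInstance(new Configuration());
 * // Configure the JDBC driver and connection string.
 * DBConfiguration.configureDB(job.getConfiguration(),
 *     "com.mysql.jdbc.Driver", "jdbc:mysql://localhost/mydb",
 *     "username", "password");
 * // Read selected fields of the "employees" table, ordered by "id".
 * DBInputFormat.setInput(job, MyRecord.class,
 *     "employees", "salary > 0", "id",
 *     "id", "name", "salary");
 * }</pre>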
 */
@InterfaceAudience.Public
@InterfaceStability.Stable
public class DBInputFormat<T extends DBWritable>
    extends InputFormat<LongWritable, T> implements Configurable {

  private static final Log LOG = LogFactory.getLog(DBInputFormat.class);

  private String dbProductName = "DEFAULT";

  /**
   * A class that does nothing, implementing DBWritable and Writable.
   */
  @InterfaceStability.Evolving
  public static class NullDBWritable implements DBWritable, Writable {
    @Override
    public void readFields(DataInput in) throws IOException { }
    @Override
    public void readFields(ResultSet arg0) throws SQLException { }
    @Override
    public void write(DataOutput out) throws IOException { }
    @Override
    public void write(PreparedStatement arg0) throws SQLException { }
  }

  /**
   * An InputSplit that spans a set of rows.
   */
  @InterfaceStability.Evolving
  public static class DBInputSplit extends InputSplit implements Writable {

    private long end = 0;
    private long start = 0;

    /**
     * Default constructor.
     */
    public DBInputSplit() {
    }

    /**
     * Convenience constructor.
     * @param start the index of the first row to select
     * @param end the index of the last row to select
     */
    public DBInputSplit(long start, long end) {
      this.start = start;
      this.end = end;
    }

    /** {@inheritDoc} */
    @Override
    public String[] getLocations() throws IOException {
      // TODO Add a layer to enable SQL "sharding" and support locality
      return new String[] {};
    }

    /**
     * @return The index of the first row to select
     */
    public long getStart() {
      return start;
    }

    /**
     * @return The index of the last row to select
     */
    public long getEnd() {
      return end;
    }

    /**
     * @return The total row count in this split
     */
    @Override
    public long getLength() throws IOException {
      return end - start;
    }

    /** {@inheritDoc} */
    @Override
    public void readFields(DataInput input) throws IOException {
      start = input.readLong();
      end = input.readLong();
    }

    /** {@inheritDoc} */
    @Override
    public void write(DataOutput output) throws IOException {
      output.writeLong(start);
      output.writeLong(end);
    }
  }

  private String conditions;

  private Connection connection;

  private String tableName;

  private String[] fieldNames;

  private DBConfiguration dbConf;

  /** {@inheritDoc} */
  @Override
  public void setConf(Configuration conf) {

    dbConf = new DBConfiguration(conf);

    try {
      getConnection();

      DatabaseMetaData dbMeta = connection.getMetaData();
      this.dbProductName = dbMeta.getDatabaseProductName().toUpperCase();
    }
    catch (Exception ex) {
      throw new RuntimeException(ex);
    }

    tableName = dbConf.getInputTableName();
    fieldNames = dbConf.getInputFieldNames();
    conditions = dbConf.getInputConditions();
  }

  @Override
  public Configuration getConf() {
    return dbConf.getConf();
  }

  public DBConfiguration getDBConf() {
    return dbConf;
  }

  public Connection getConnection() {
    try {
      if (null == this.connection) {
        // The connection was closed; reinstantiate it.
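        // Auto-commit is disabled so that the count query in getSplits() and
        // the subsequent record reads run inside explicit transactions, and
        // SERIALIZABLE isolation gives each transaction a consistent view of
        // the table while it is being read.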
        this.connection = dbConf.getConnection();
        this.connection.setAutoCommit(false);
        this.connection.setTransactionIsolation(
            Connection.TRANSACTION_SERIALIZABLE);
      }
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
    return connection;
  }

  public String getDBProductName() {
    return dbProductName;
  }

  protected RecordReader<LongWritable, T> createDBRecordReader(DBInputSplit split,
      Configuration conf) throws IOException {

    @SuppressWarnings("unchecked")
    Class<T> inputClass = (Class<T>) (dbConf.getInputClass());
    try {
      // use database product name to determine appropriate record reader.
      if (dbProductName.startsWith("ORACLE")) {
        // use Oracle-specific db reader.
        return new OracleDBRecordReader<T>(split, inputClass,
            conf, getConnection(), getDBConf(), conditions, fieldNames,
            tableName);
      } else if (dbProductName.startsWith("MYSQL")) {
        // use MySQL-specific db reader.
        return new MySQLDBRecordReader<T>(split, inputClass,
            conf, getConnection(), getDBConf(), conditions, fieldNames,
            tableName);
      } else {
        // Generic reader.
        return new DBRecordReader<T>(split, inputClass,
            conf, getConnection(), getDBConf(), conditions, fieldNames,
            tableName);
      }
    } catch (SQLException ex) {
      // Preserve the underlying SQLException as the cause.
      throw new IOException(ex.getMessage(), ex);
    }
  }

  /** {@inheritDoc} */
  @Override
  public RecordReader<LongWritable, T> createRecordReader(InputSplit split,
      TaskAttemptContext context) throws IOException, InterruptedException {

    return createDBRecordReader((DBInputSplit) split, context.getConfiguration());
  }

  /** {@inheritDoc} */
  @Override
  public List<InputSplit> getSplits(JobContext job) throws IOException {

    ResultSet results = null;
    Statement statement = null;
    try {
      // Use getConnection() rather than the field directly, in case the
      // connection was closed and needs to be reinstantiated.
      statement = getConnection().createStatement();

      results = statement.executeQuery(getCountQuery());
      results.next();

      long count = results.getLong(1);
      int chunks = job.getConfiguration().getInt(MRJobConfig.NUM_MAPS, 1);
      long chunkSize = (count / chunks);

      results.close();
      statement.close();

      List<InputSplit> splits = new ArrayList<InputSplit>();

      // Split the rows into n chunks and adjust the last chunk
      // accordingly.
      for (int i = 0; i < chunks; i++) {
        DBInputSplit split;

        if ((i + 1) == chunks)
          split = new DBInputSplit(i * chunkSize, count);
        else
          split = new DBInputSplit(i * chunkSize, (i * chunkSize)
              + chunkSize);

        splits.add(split);
      }

      connection.commit();
      return splits;
    } catch (SQLException e) {
      throw new IOException("Got SQLException", e);
    } finally {
      try {
        if (results != null) { results.close(); }
      } catch (SQLException e1) {}
      try {
        if (statement != null) { statement.close(); }
      } catch (SQLException e1) {}

      closeConnection();
    }
  }

  /** Returns the query for getting the total number of rows.
   * Subclasses can override this for custom behaviour. */
  protected String getCountQuery() {

    if (dbConf.getInputCountQuery() != null) {
      return dbConf.getInputCountQuery();
    }

    StringBuilder query = new StringBuilder();
    query.append("SELECT COUNT(*) FROM ").append(tableName);

    if (conditions != null && conditions.length() > 0) {
      query.append(" WHERE ").append(conditions);
    }
    return query.toString();
  }

  /**
   * Initializes the map-part of the job with the appropriate input settings.
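   * <p>
   * For example, assuming a hypothetical {@code MyRecord} class implementing
   * {@link DBWritable} and illustrative table and field names:
   * <pre>{@code
   * DBInputFormat.setInput(job, MyRecord.class,
   *     "employees",                    // tableName
   *     "updated > 20070101",           // conditions
   *     "id",                           // orderBy
   *     "id", "name", "updated");       // fieldNames
   * }</pre>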
   *
   * @param job The map-reduce job
   * @param inputClass the class object implementing DBWritable, which is the
   * Java object holding tuple fields.
   * @param tableName The table to read data from
   * @param conditions The conditions with which to select data,
   * e.g. '(updated > 20070101 AND length > 0)'
   * @param orderBy the field names to use in the ORDER BY clause.
   * @param fieldNames The field names in the table
   * @see #setInput(Job, Class, String, String)
   */
  public static void setInput(Job job,
      Class<? extends DBWritable> inputClass,
      String tableName, String conditions,
      String orderBy, String... fieldNames) {
    job.setInputFormatClass(DBInputFormat.class);
    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
    dbConf.setInputClass(inputClass);
    dbConf.setInputTableName(tableName);
    dbConf.setInputFieldNames(fieldNames);
    dbConf.setInputConditions(conditions);
    dbConf.setInputOrderBy(orderBy);
  }

  /**
   * Initializes the map-part of the job with the appropriate input settings.
   *
   * @param job The map-reduce job
   * @param inputClass the class object implementing DBWritable, which is the
   * Java object holding tuple fields.
   * @param inputQuery the input query to select fields. Example:
   * "SELECT f1, f2, f3 FROM Mytable ORDER BY f1"
   * @param inputCountQuery the input query that returns
   * the number of records in the table.
   * Example: "SELECT COUNT(f1) FROM Mytable"
   * @see #setInput(Job, Class, String, String, String, String...)
   */
  public static void setInput(Job job,
      Class<? extends DBWritable> inputClass,
      String inputQuery, String inputCountQuery) {
    job.setInputFormatClass(DBInputFormat.class);
    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
    dbConf.setInputClass(inputClass);
    dbConf.setInputQuery(inputQuery);
    dbConf.setInputCountQuery(inputCountQuery);
  }

  protected void closeConnection() {
    try {
      if (null != this.connection) {
        this.connection.close();
        this.connection = null;
      }
    } catch (SQLException sqlE) {
      LOG.debug("Exception on close", sqlE);
    }
  }
}