Spark SQL: Methods and Approaches for Reading Data from a DB

TendToBigData 2017-01-17 21:47
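
The following example uses the Spark 1.x Java API: it loads two MySQL tables (student_info and student_score) as DataFrames through the JDBC data source, registers them as temporary tables, joins them with Spark SQL to pick out students scoring above 80, and prints the result. The large commented-out block shows an alternative that does the same join with JavaPairRDDs and writes the result back to MySQL row by row.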

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.DataFrameReader;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import scala.Tuple2;

/**
 * JDBC data source example
 * 
 * @author Administrator
 *
 */
public class JDBCDataSource {

	public static void main(String[] args) {
		SparkConf conf = new SparkConf().setAppName("JDBCDataSource").setMaster("local");
		JavaSparkContext sc = new JavaSparkContext(conf);
		SQLContext sqlContext = new SQLContext(sc);
		
		// Approach 1: load the two MySQL tables as DataFrames by passing all options in one map
		/*
		 * Map<String, String> options = new HashMap<String, String>();
		 * options.put("url", "jdbc:mysql://hadoop1:3306/testdb");
		 * options.put("driver", "com.mysql.jdbc.Driver"); 
		 * options.put("user","spark");
		 * options.put("password", "spark2016");
		 * options.put("dbtable", "student_info"); 
		 * DataFrame studentInfosDF = sqlContext.read().format("jdbc").options(options).load();
		 * 
		 * options.put("dbtable", "student_score"); 
		 * DataFrame studentScoresDF = sqlContext.read().format("jdbc") .options(options).load();
		 */
		// Approach 2: load the two MySQL tables as DataFrames by setting options on a DataFrameReader
		DataFrameReader reader = sqlContext.read().format("jdbc");
		reader.option("url", "jdbc:mysql://hadoop1:3306/testdb");
		reader.option("dbtable", "student_info");
		reader.option("driver", "com.mysql.jdbc.Driver");
		reader.option("user", "spark");
		reader.option("password", "spark2016");
		DataFrame studentInfosDF = reader.load();

		reader.option("dbtable", "student_score");
		DataFrame studentScoresDF = reader.load();
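
		// (Sketch) For large tables the JDBC source can also read in parallel by splitting on a
		// numeric column; the option names below are standard Spark JDBC options, while the
		// column name and bounds are only illustrative assumptions:
		// reader.option("partitionColumn", "id");
		// reader.option("lowerBound", "1");
		// reader.option("upperBound", "100000");
		// reader.option("numPartitions", "4");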

		// Register the two DataFrames as temporary tables and join them with Spark SQL
		studentInfosDF.registerTempTable("studentInfos");
		studentScoresDF.registerTempTable("studentScores");
		
		String sql = "SELECT studentInfos.name, age, score "
				+ "FROM studentInfos JOIN studentScores "
				+ "ON (studentScores.name = studentInfos.name) "
				+ "WHERE studentScores.score > 80";
		
		DataFrame goodStudentsDF = sqlContext.sql(sql);
		goodStudentsDF.show();
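
		// Alternative (commented out below): the same join done with JavaPairRDDs, followed by
		// writing the qualifying rows back to MySQL one at a time over plain JDBC.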
		
		/*JavaPairRDD<String, Tuple2<Integer, Integer>> studentsRDD =
		studentInfosDF.javaRDD().mapToPair(new PairFunction<Row, String, Integer>() {

					private static final long serialVersionUID = 1L;

					@Override
					public Tuple2<String, Integer> call(Row row) throws Exception {
						return new Tuple2<String, Integer>(row.getString(0),
								Integer.valueOf(String.valueOf(row.get(1))));
					}

				}).join(studentScoresDF.javaRDD().mapToPair(new PairFunction<Row, String, Integer>() {

							private static final long serialVersionUID = 1L;

							@Override
							public Tuple2<String, Integer> call(Row row) throws Exception {
								return new Tuple2<String, Integer>(String.valueOf(row.get(0)),
										Integer.valueOf(String.valueOf(row.get(1))));
							}

						}));

		JavaRDD<Row> studentRowsRDD = studentsRDD.map(
				new Function<Tuple2<String, Tuple2<Integer, Integer>>, Row>() {
					private static final long serialVersionUID = 1L;

					@Override
					public Row call(Tuple2<String, Tuple2<Integer, Integer>> tuple) throws Exception {
						return RowFactory.create(tuple._1, tuple._2._1, tuple._2._2);
					}
				});

		JavaRDD<Row> filteredStudentRowsRDD = studentRowsRDD.filter(

				new Function<Row, Boolean>() {

					private static final long serialVersionUID = 1L;

					@Override
					public Boolean call(Row row) throws Exception {
						return row.getInt(2) > 80;
					}

				});

		List<StructField> structFields = new ArrayList<StructField>();
		structFields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
		structFields.add(DataTypes.createStructField("age", DataTypes.IntegerType, true));
		structFields.add(DataTypes.createStructField("score", DataTypes.IntegerType, true));
		StructType structType = DataTypes.createStructType(structFields);

		DataFrame studentsDF = sqlContext.createDataFrame(filteredStudentRowsRDD, structType);

		Row[] rows = studentsDF.collect();
		for (Row row : rows) {
			System.out.println(row);
		}
		
		studentsDF.javaRDD().foreach(new VoidFunction<Row>() {

			private static final long serialVersionUID = 1L;

			@Override
			public void call(Row row) throws Exception {
				String sql = "insert into good_student_info values(" + "'" + String.valueOf(row.getString(0)) + "',"
						+ Integer.valueOf(String.valueOf(row.get(1))) + ","
						+ Integer.valueOf(String.valueOf(row.get(2))) + ")";
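
				// Note: this opens a new JDBC connection for every row; for real workloads,
				// foreachPartition with one connection per partition would be preferable.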

				Class.forName("com.mysql.jdbc.Driver");

				Connection conn = null;
				Statement stmt = null;
				try {
					conn = DriverManager.getConnection("jdbc:mysql://hadoop1:3306/testdb", "spark", "spark2016");
					stmt = conn.createStatement();
					stmt.executeUpdate(sql);
				} catch (Exception e) {
					e.printStackTrace();
				} finally {
					if (stmt != null) {
						stmt.close();
					}
					if (conn != null) {
						conn.close();
					}
				}
			}

		});*/

		/**
		 * Close the SparkContext and release resources
		 */
		sc.close();
	}

}
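
As a side note, the per-row JDBC inserts in the commented-out block can also be done as a single DataFrame-level write. The snippet below is a minimal sketch that is not part of the original example: it assumes the same connection settings used above and a target table named good_student_info, and relies on DataFrameWriter.jdbc(), available since Spark 1.4.

		// Sketch: write the joined result back to MySQL in one call, using assumed settings
		java.util.Properties connProps = new java.util.Properties();
		connProps.put("user", "spark");
		connProps.put("password", "spark2016");
		connProps.put("driver", "com.mysql.jdbc.Driver");

		// SaveMode.Append adds the rows to the existing good_student_info table
		goodStudentsDF.write()
				.mode(org.apache.spark.sql.SaveMode.Append)
				.jdbc("jdbc:mysql://hadoop1:3306/testdb", "good_student_info", connProps);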
