UDAF Case Studies in Spark SQL
1. Counting the number of occurrences of each word
package com.bynear.spark_sql;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;

public class Spark_UDAF extends UserDefinedAggregateFunction {

    /**
     * inputSchema: the data type of the input column(s).
     */
    @Override
    public StructType inputSchema() {
        ArrayList<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("str", DataTypes.StringType, true));
        return DataTypes.createStructType(fields);
    }

    /**
     * bufferSchema: the data type of the intermediate aggregation buffer.
     */
    @Override
    public StructType bufferSchema() {
        ArrayList<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("count", DataTypes.IntegerType, true));
        return DataTypes.createStructType(fields);
    }

    /**
     * dataType: the return type of the function.
     */
    @Override
    public DataType dataType() {
        return DataTypes.IntegerType;
    }

    /**
     * Consistency flag: if true, the function always produces the same result
     * for the same input.
     */
    @Override
    public boolean deterministic() {
        return true;
    }

    /**
     * Initializes the aggregation buffer for each group. The initial value must
     * satisfy the "zero value" semantics: merging two initial buffers must again
     * yield an initial buffer. For example, if the initial value were 1 and merge
     * added the two buffers, merging two initial buffers would give 2, which is no
     * longer the initial value, so 1 would be an invalid zero value.
     */
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0);
    }

    /**
     * Updates the buffer with one input row, similar to combineByKey: defines how
     * a group's aggregate changes whenever a new value arrives for that group.
     */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        buffer.update(0, buffer.getInt(0) + 1);
    }

    /**
     * Merges two buffers by folding buffer2 into buffer1, similar to reduceByKey.
     * Note that this method has no return value: the merged result must be written
     * back into buffer1. Because Spark is distributed, each group may be partially
     * aggregated (update) on different nodes, and those partial results are then
     * combined here (merge).
     */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0));
    }

    /**
     * Computes the final result of a group from its aggregation buffer.
     */
    @Override
    public Object evaluate(Row buffer) {
        return buffer.getInt(0);
    }
}
package com.bynear.spark_sql;

import com.clearspring.analytics.util.Lists;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.Arrays;
import java.util.List;

public class UDAF {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("UDAF").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(sc);
        List<String> nameList = Arrays.asList("xiaoming", "xiaoming", "刘德华", "古天乐",
                "feifei", "feifei", "feifei", "katong");
        // Convert to a JavaRDD
        JavaRDD<String> nameRDD = sc.parallelize(nameList, 3);
        // Convert to a JavaRDD<Row>
        JavaRDD<Row> nameRowRDD = nameRDD.map(new Function<String, Row>() {
            public Row call(String name) throws Exception {
                return RowFactory.create(name);
            }
        });
        List<StructField> fields = Lists.newArrayList();
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        StructType structType = DataTypes.createStructType(fields);
        DataFrame namesDF = sqlContext.createDataFrame(nameRowRDD, structType);
        namesDF.registerTempTable("names");
        // Register the UDAF so it can be used in SQL
        sqlContext.udf().register("countString", new Spark_UDAF());
        sqlContext.sql("select name, countString(name) as count from names group by name").show();
        List<Row> rows = sqlContext
                .sql("select name, countString(name) as count from names group by name")
                .javaRDD().collect();
        for (Row row : rows) {
            System.out.println(row);
        }
    }
}

Run result:
+--------+-----+
| name|count|
+--------+-----+
| feifei| 3|
|xiaoming| 2|
| 刘德华| 1|
| katong| 1|
| 古天乐| 1|
+--------+-----+
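A quick way to sanity-check the UDAF (an addition of mine, assuming the same sqlContext and the registered "names" temp table from the driver above) is to compare it against Spark's built-in count(), which should return the same numbers:

// Built-in count() should match the custom countString UDAF for every group
sqlContext.sql("select name, count(name) as count from names group by name").show();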
2. Computing the average price for each brand
package com.bynear.spark_sql;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;

public class MyUDAF extends UserDefinedAggregateFunction {
    private StructType inputSchema;
    private StructType bufferSchema;

    public MyUDAF() {
        ArrayList<StructField> inputFields = new ArrayList<StructField>();
        inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.DoubleType, true));
        inputSchema = DataTypes.createStructType(inputFields);
        ArrayList<StructField> bufferFields = new ArrayList<StructField>();
        bufferFields.add(DataTypes.createStructField("sum", DataTypes.DoubleType, true));
        bufferFields.add(DataTypes.createStructField("count", DataTypes.DoubleType, true));
        bufferSchema = DataTypes.createStructType(bufferFields);
    }

    @Override
    public StructType inputSchema() {
        return inputSchema;
    }

    @Override
    public StructType bufferSchema() {
        return bufferSchema;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        // The buffer has two fields: field 0 holds the running sum, field 1 holds
        // the count; both start at 0.
        buffer.update(0, 0.0);
        buffer.update(1, 0.0);
    }

    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        // Only update when the incoming value is not null
        if (!input.isNullAt(0)) {
            // Field 0 (sum): buffered sum + incoming value
            double updatedSum = buffer.getDouble(0) + input.getDouble(0);
            // Field 1 (count): buffered count + 1
            double updatedCount = buffer.getDouble(1) + 1;
            buffer.update(0, updatedSum);
            buffer.update(1, updatedCount);
        }
    }

    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        // Fold the partial sum and count from buffer2 into buffer1
        double mergedSum = buffer1.getDouble(0) + buffer2.getDouble(0);
        double mergedCount = buffer1.getDouble(1) + buffer2.getDouble(1);
        buffer1.update(0, mergedSum);
        buffer1.update(1, mergedCount);
    }

    @Override
    public Object evaluate(Row buffer) {
        // Final result of a group: average = sum / count
        return buffer.getDouble(0) / buffer.getDouble(1);
    }
}
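One small variation worth noting (my addition, not part of the original code): if every salary in a group were null, update() never fires, the count stays 0, and evaluate() returns 0.0 / 0.0 = NaN. A defensive version of evaluate() could return null instead:

    @Override
    public Object evaluate(Row buffer) {
        double count = buffer.getDouble(1);
        // If the group contained only nulls, avoid 0/0 and return null
        return count == 0 ? null : buffer.getDouble(0) / count;
    }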
package com.bynear.spark_sql;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.math.BigDecimal;
import java.util.ArrayList;

public class MyUDAF_SQL {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("myUDAF").setMaster("local");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(jsc);
        JavaRDD<String> lines = jsc.textFile("C://Users//Administrator//Desktop//fastJSon//sales.txt");
        // Each line has the form "brand,price"; parse it into a Row(name, salary)
        JavaRDD<Row> map = lines.map(new Function<String, Row>() {
            @Override
            public Row call(String line) throws Exception {
                String[] lineSplit = line.split(",");
                return RowFactory.create(String.valueOf(lineSplit[0]), Double.valueOf(lineSplit[1]));
            }
        });
        ArrayList<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("salary", DataTypes.DoubleType, true));
        StructType structType = DataTypes.createStructType(fields);
        DataFrame df = sqlContext.createDataFrame(map, structType);
        // Register the UDAF that computes the average
        sqlContext.udf().register("myAverage", new MyUDAF());
        df.registerTempTable("zjs_table");
        df.show();
        // A plain UDF that rounds a double to two decimal places
        sqlContext.udf().register("twoDecimal", new UDF1<Double, Double>() {
            @Override
            public Double call(Double in) throws Exception {
                BigDecimal b = new BigDecimal(in);
                return b.setScale(2, BigDecimal.ROUND_HALF_DOWN).doubleValue();
            }
        }, DataTypes.DoubleType);
        DataFrame resultDF = sqlContext.sql(
                "select name, twoDecimal(myAverage(salary)) as 平均值 from zjs_table group by name");
        resultDF.show();
    }
}
Input file (sales.txt):
三星,12345
三星,4521
三星,7895
华为,5421
华为,4521
华为,5648
苹果,12548
苹果,7856
苹果,45217
苹果,89654
Run results:
+----+-------+
|name| salary|
+----+-------+
| 三星|12345.0|
| 三星| 4521.0|
| 三星| 7895.0|
| 华为| 5421.0|
| 华为| 4521.0|
| 华为| 5648.0|
| 苹果|12548.0|
| 苹果| 7856.0|
| 苹果|45217.0|
| 苹果|89654.0|
+----+-------+
+----+--------+
|name| 平均值|
+----+--------+
| 三星| 8253.67|
| 华为| 5196.67|
| 苹果|38818.75|
+----+--------+
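As a cross-check (my own addition, assuming the same sqlContext and the "zjs_table" temp table registered above, and that this Spark version ships the built-in round() SQL function, available since Spark 1.5), Spark's built-in avg() should yield the same averages as the custom UDAF:

// Built-in avg() + round() should agree with twoDecimal(myAverage(salary))
sqlContext.sql("select name, round(avg(salary), 2) as 平均值 from zjs_table group by name").show();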
Notes: pay attention to the text file's encoding (save it in an encoding Spark can read, e.g. UTF-8, otherwise the Chinese brand names come out garbled), and to DataTypes.DoubleType in the Java code: the runtime type of each value placed into a Row must match the type declared in the schema.
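For example (an illustrative snippet of mine, not taken from the original code): since the schema declares the salary column as DataTypes.DoubleType, the value put into the Row must actually be a java.lang.Double; passing the raw string field instead typically fails at runtime with a type error when Spark evaluates the DataFrame.

// correct: the schema declares salary as DoubleType, so store a Double
RowFactory.create(lineSplit[0], Double.valueOf(lineSplit[1]));
// wrong: stores a String where the schema expects a Double
// RowFactory.create(lineSplit[0], lineSplit[1]);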