spark可以定义包括new类,uadf类,mongodb类,scala类的方法。
有时候看到 new 类().{方法定义}.方法 这种怪异的代码,是匿名内部类的用法。
public class practice {
public static void main(String[] args) {
flatMap(new FlatMapFunction() {
@Override
public void call() {
System.out.println("spark接口就是这么调用的");
}
});
}
static void flatMap(FlatMapFunction tf) {
tf.call();
}
}
interface FlatMapFunction {
void call();
}
要使用接口,就必须实现接口的方法再调用方法。
匿名内部类语法,允许我们不需要单独定义接口,而是在main方法中来实现这个过程。这使得接口实现的修改像if和for一样随意了,比如spark中的一些接口的重载方法。
所谓“内部”是指在方法内调用,”匿名”是指没有给接口的实现类具体命名。
再看看spark的java版本算子,就是用到了匿名内部类,对于没有接触过匿名内部类的人,一定搞得云里雾里。
public class FlatMapOperator {
public static void main(String[] args){
SparkConf conf = new SparkConf().setMaster("local").setAppName("flatmap");
JavaSparkContext sc = new JavaSparkContext(conf);
List<String> list = Arrays.asList("w1 1","w2 2","w3 3","w4 4");
JavaRDD<String> listRdd = sc.parallelize(list);
JavaRDD<String> result = listRdd.flatMap(new FlatMapFunction<String, String>() {
@Override
public Iterator<String> call(String s) throws Exception {
return Arrays.asList(s.split(" ")).iterator();
}
});
result.foreach(new VoidFunction<String>() {
@Override
public void call(String s) throws Exception {
System.err.println(s);
}
});
}
}
FlatMapFunction是内部匿名类的声明,<String, String> 是接口的模板,call是重载的接口方法。
object SparkUDAFDemo {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master("local[*]").appName("UDAF").getOrCreate()
import spark.implicits._
val df: DataFrame = spark.read.json("in/user.json")
//创建并注册自定义UDAF函数
val function = new MyAgeAvgFunction
spark.udf.register("myAvgAge",function)
//创建视图
df.createTempView("userinfo")
//查询男女平均年龄
val df2: DataFrame = spark.sql("select sex,myAvgAge(age) from userinfo group by sex")
df2.show()
}
}
//实现UDAF类
//实现的功能是对传入的数值进行累加,并且计数传入的个数,最后相除得到平均数
class MyAgeAvgFunction extends UserDefinedAggregateFunction{
//聚合函数的输入数据结构
override def inputSchema: StructType = {
new StructType().add(StructField("age",LongType))
}
//缓存区数据结构
override def bufferSchema: StructType = {
new StructType().add(StructField("sum",LongType)).add(StructField("count",LongType))
}
//聚合函数返回值数据结构
override def dataType: DataType = DoubleType
//聚合函数是否是幂等的,即相同输入是否能得到相同输出
override def deterministic: Boolean = true
//设定默认值
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//sum
buffer(0)=0L
//count
buffer(1)=0L
}
//给聚合函数传入一条新数据时所需要进行的操作
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
//将传入的数据进行累加
buffer(0)=buffer.getLong(0)+input.getLong(0)
//每传入一次计数加一
buffer(1)=buffer.getLong(1)+1
}
//合并聚合函数的缓冲区(不同分区)
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
//不同分区的数据进行累加
buffer1(0)=buffer1.getLong(0)+buffer2.getLong(0)
buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)
}
//计算最终结果
override def evaluate(buffer: Row): Any = {
//将sum/count的得到平均数
buffer.getLong(0).toDouble/buffer.getLong(1)
}
}。