Spark 3.0 User-Defined Aggregate Functions
Overriding the Aggregator methods

In Spark 3.0, a strongly typed UDAF is written by extending org.apache.spark.sql.expressions.Aggregator[IN, BUF, OUT] (input type, buffer type, output type) and overriding its methods; the functions.udaf wrapper, new in 3.0, makes the aggregator callable from SQL.
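The example below reads datas/user.json as line-delimited JSON. The original post does not show the file's contents; a file along these lines would work (any field other than age is an assumption for illustration):

{"username": "zhangsan", "age": 20}
{"username": "lisi", "age": 30}
{"username": "wangwu", "age": 40}

With this input, the query at the end would print mean = 30.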
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator

object Spark_basic {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[*]").setAppName("waj")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    // Create a DataFrame from the JSON source
    val df = spark.read.json("datas/user.json")
    // Register a temporary view so the DataFrame can be queried with SQL
    df.createTempView("user")

    // Register the UDAF; functions.udaf (Spark 3.0+) wraps the strongly
    // typed Aggregator so it can be called from SQL
    spark.udf.register("mymean", functions.udaf(new MyAvgUDAF()))
    spark.sql("select mymean(age) as mean from user").show()

    spark.close()
  }

  // Aggregation buffer: running sum and element count
  case class Buff(var total: Long, var count: Long)

  // Aggregator[IN, BUF, OUT]: input type, buffer type, output type
  class MyAvgUDAF extends Aggregator[Long, Buff, Long] {
    // Initial (zero) value of the buffer
    override def zero: Buff = Buff(0L, 0L)

    // Fold one input value into the buffer
    override def reduce(b: Buff, a: Long): Buff = {
      b.total += a
      b.count += 1
      b
    }

    // Merge two partial buffers (e.g. from different partitions)
    override def merge(b1: Buff, b2: Buff): Buff = {
      b1.total += b2.total
      b1.count += b2.count
      b1
    }

    // Produce the final result; note that integer division truncates
    override def finish(reduction: Buff): Long = reduction.total / reduction.count

    // Encoders for the intermediate buffer and the output value
    override def bufferEncoder: Encoder[Buff] = Encoders.product
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }
}
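Besides registering the aggregator as a SQL function through functions.udaf, the same class can be applied directly to a typed Dataset via Aggregator.toColumn. A minimal sketch, reusing the df and MyAvgUDAF defined above:

import spark.implicits._

// Convert the single "age" column to a Dataset[Long]; the JSON reader
// infers integral numbers as Long, so the conversion matches this input
val ds = df.select("age").as[Long]

// toColumn turns the Aggregator into a TypedColumn[Long, Long]
val meanCol = (new MyAvgUDAF()).toColumn.name("mean")

ds.select(meanCol).show()

The typed path skips SQL registration entirely; functions.udaf is only needed when the aggregator must be callable from SQL or on untyped DataFrame columns.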