【回顾】SparkSQL 之 用户自定义函数

【回顾】SparkSQL 之 用户自定义函数,第1张

【回顾】SparkSQL 之 用户自定义函数

文章目录
  • 1、UDF
    • 1) 创建 Dataframe
    • 2) 注册 UDF
    • 3) 创建临时表
    • 4) 应用 UDF
  • 2、UDAF
    • 1) 实现方式 - RDD
    • 2) 实现方式 - 累加器
    • 3) 实现方式 - UDAF - 弱类型
    • 4) 实现方式 - UDAF - 强类型
      • Spark 3.0
      • 早期版本


1、UDF

UDF(User Defined Function):spark SQL中用户自定义函数,用法和spark SQL中的内置函数类似;是saprk SQL中内置函数无法满足要求,用户根据业务需求自定义的函数

基本使用步骤如下:


1) 创建 Dataframe
scala> val df = spark.read.json("/home/data/spark/user.json")
df: org.apache.spark.sql.Dataframe = [age: bigint, username: string]

2) 注册 UDF
// 自定义udf函数:添加说明词,并注册
scala> spark.udf.register("addName",(x:String)=> "Name:"+x)
res9: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(,StringType,Some(List(StringType)))

3) 创建临时表
scala> df.createOrReplaceTempView("people")

4) 应用 UDF
// 使用的时候直接调用注册名传入参数即可!
scala> spark.sql("Select addName(username),age from people").show()
+---------------------+---+
|UDF:addName(username)|age|
+---------------------+---+
|         Namezhangsan| 20|
|             Namelisi| 30|
|           Namewangwu| 40|
+---------------------+---+

scala> spark.sql("select addName(username) as newName,age from people").show
+------------+---+
|     newName|age|
+------------+---+
|Namezhangsan| 20|
|    Namelisi| 30|
|  Namewangwu| 40|
+------------+---+

注意:当spark-shell重新的启动的时候需要重新注册UDF函数,因为此时的SparkSession重新创建了,是新的入口。org.apache.spark.sql.AnalysisException: Undefined function: 'addName'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.

返回顶部


2、UDAF

类型的 Dataset 和 弱类型的 Dataframe 都提供了相关的聚合函数, 如 count(),countDistinct(),avg(),max(),min()。

除此之外,用户可以设定自己的自定义聚合函数。

  • 通过继承 UserDefinedAggregateFunction 来实现用户自定义弱类型聚合函数
  • 从 Spark3.0 版本后,UserDefinedAggregateFunction 已经不推荐使用了,可以统一采用强类型聚合函数 Aggregator。

需求:计算平均工资


1) 实现方式 - RDD
val conf: SparkConf = new SparkConf().setAppName("app").setMaster("local[*]")
val sc: SparkContext = new SparkContext(conf)
val res: (Int, Int) = sc.makeRDD(List(("zhangsan", 20), ("lisi", 30),("wangw", 40)))
 .map {
    case (name, age) => {
      (age, 1)
    }
 }
 .reduce {
    (t1, t2) => {
      (t1._1 + t2._1, t1._2 + t2._2)
    }
 }

println(res._1/res._2)
// 关闭连接
sc.stop()

30

返回顶部


2) 实现方式 - 累加器
class MyAC extends AccumulatorV2[Int,Int]{

   var sum:Int = 0   
   var count:Int = 0
   
   override def isZero: Boolean = {
     return sum ==0 && count == 0
   }
   
   override def copy(): AccumulatorV2[Int, Int] = {
     val newMyAc = new MyAC
     newMyAc.sum = this.sum
     newMyAc.count = this.count
     newMyAc
   }
   
   override def reset(): Unit = {
     sum =0
     count = 0
   }
   // 求和、计数
   override def add(v: Int): Unit = {
     sum += v
     count += 1
   }
   // 聚合
   override def merge(other: AccumulatorV2[Int, Int]): Unit = {
     other match {
       case o:MyAC => {
         sum += o.sum
         count += o.count
       }
       case _ => {}
     }
   }
   // 计算结果
   override def value: Int = sum/count
}

返回顶部


3) 实现方式 - UDAF - 弱类型

自定义avgUDF函数类,继承UserDefinedAggregateFunction ,并重写方法

  • // 输入的数据的结构
    override def inputSchema: StructType
    // 缓冲区的数据结构
    override def bufferSchema: StructType
    // 函数计算结果的数据类型
    override def dataType: DataType = LongType
    // 函数的稳定性,传入传出数据的类型保持一致
    override def deterministic: Boolean = true
    // 缓冲区初始化
    override def initialize(buffer: MutableAggregationBuffer): Unit
    // 根据输入的值跟新缓冲区
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit
    // 缓冲区数据合并
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
    // 计算平均值
    override def evaluate(buffer: Row): Any
// 自定义聚合函数类计算年龄平均值
class avgUDF extends UserDefinedAggregateFunction {
  // 输入的数据的结构
  override def inputSchema: StructType = {
    StructType(Array(StructField("age",LongType)))
  }
  // 缓冲区的数据结构
  override def bufferSchema: StructType = {
    StructType(
      Array(
        StructField("total",LongType),
        StructField("count",LongType)
      )
    )
  }
  // 函数计算结果的数据类型
  override def dataType: DataType = LongType
  // 函数的稳定性,传入传出数据的类型保持一致
  override def deterministic: Boolean = true
  // 缓冲区初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    
    buffer.update(0,0L)
    buffer.update(1,0L)
  }
  // 根据输入的值跟新缓冲区
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // 缓冲区例的 total + 传进来的 age
    buffer.update(0,buffer.getLong(0)+input.getLong(0))
    // count + 1
    buffer.update(1,buffer.getLong(1)+1)
  }
  // 缓冲区数据合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0,buffer1.getLong(0) + buffer2.getLong(0))
    buffer1.update(1,buffer1.getLong(1) + buffer2.getLong(1))
  }
  // 计算平均值
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0)/buffer.getLong(1)
  }
}
// 创建视图
df.createOrReplaceTempView("people")
// 调用udf函数
spark.sql("select avgAge(age) from people").show()

+-----------+
|avgudf(age)|
+-----------+
|         23|
+-----------+

返回顶部


4) 实现方式 - UDAF - 强类型 Spark 3.0

Spark 3.0 中 使用Aggregator替代了原来的 UserDefinedAggregateFunction,具体使用如下:

  • 自定义聚合类继承 Aggregator => 定义泛型
 // 样例类
 case class Buff( var sum:Long, var cnt:Long )
 // 自定义聚合类
 class MyAvgAgeUDAF extends Aggregator[Long, Buff, Double]{
   // 初始值
   override def zero: Buff = Buff(0,0)
   // 根据输入数据跟新缓冲区的数据
   override def reduce(b: Buff, a: Long): Buff = {
        b.sum += a
        b.cnt += 1
        b
   }
   // 聚合:合并缓冲区
   override def merge(b1: Buff, b2: Buff): Buff = {
         b1.sum += b2.sum
         b1.cnt += b2.cnt
         b1
   }
   // 计算结果
   override def finish(reduction: Buff): Double = {
         reduction.sum.toDouble/reduction.cnt
   }
   // 网络传输缓冲区编码  自定义的类型就选product
   override def bufferEncoder: Encoder[Buff] = Encoders.product
   // 网络传输缓冲区解码  spark原有的就选相应的
   override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  }
// TODO 创建 UDAF 函数
val udaf = new MyAvgAgeUDAF
// TODO 注册到 SparkSQL 中
spark.udf.register("avgAge", functions.udaf(new MyAvgAgeUDAF()))
// TODO 在 SQL 中使用聚合函数
spark.sql("select avgAge(age) from people").show()

+-----------+
|avgudf(age)|
+-----------+
|         23|
+-----------+

那么如果在早期版本中使用强类型的UDAF,该怎样使用呢?

返回顶部


早期版本

早期版本在使用 Aggregator 的时候基本步骤不变定义泛型的时候指定为User类型。

// 样例类
case class User(username:String, age:Long)
case class Buff( var total:Long, var count:Long )
class MyAvgAgeUDAF extends Aggregator[User, Buff, Long]{
  // z & zero : 初始值或零值
  // 缓冲区的初始化
  override def zero: Buff = {
    Buff(0L,0L)
  }
  // 根据输入的数据更新缓冲区的数据
  override def reduce(buff: Buff, in: User): Buff = {
    buff.total = buff.total + in.age
    buff.count = buff.count + 1
    buff
  }
  // 合并缓冲区
  override def merge(buff1: Buff, buff2: Buff): Buff = {
    buff1.total = buff1.total + buff2.total
    buff1.count = buff1.count + buff2.count
    buff1
  }
  //计算结果
  override def finish(buff: Buff): Long = {
    buff.total / buff.count
  }
  // 缓冲区的编码 *** 作
  override def bufferEncoder: Encoder[Buff] = Encoders.product
  // 输出的编码 *** 作
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

使用UDFA函数的时候 *** 作有所不同,需要使用DSL语法 *** 作

将udfa函数转为查询列的对象,进行查询

// TODO:早期版本的强类型 UDAF 使用DSL语法 *** 作
// 读取数据
val df = spark.read.json("data/user.json")
val ds = df.as[User]
// 将udfa函数转换为查询的列对象
val udafColumn = new MyAvgAgeUDAF().toColumn

ds.select(udafColumn).show()

+------------------------------------------------------+
|MyAvgAgeUDAF(test02_UDF.Spark02_sql_UDF_avgAge04$User)|
+------------------------------------------------------+
|                                                    23|
+------------------------------------------------------+

返回顶部


欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5697091.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-17
下一篇 2022-12-17

发表评论

登录后才能评论

评论列表(0条)

保存