During development, the built-in functions sometimes cannot meet our needs, and we have to define our own. Spark supports three kinds of user-defined functions: UDF, UDAF, and UDTF.
UDF: one-to-one (one output value per input row)
UDAF: many-to-one (many input rows aggregated into one output value)
UDTF: one-to-many (one input row expanded into many output rows)
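To make the three shapes concrete, here is a minimal sketch using built-in Spark functions whose shapes match the three categories (the sample rows and column names are invented for illustration, not taken from the examples below):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().appName("shapes").master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq(("alice", "F", 20), ("lina", "F", 30), ("bob", "M", 40)).toDF("name", "gender", "age")

// UDF shape (one-to-one): one output value per input row
val upperName = udf((name: String) => name.toUpperCase)
df.withColumn("NAME", upperName($"name")).show()

// UDAF shape (many-to-one): the built-in avg collapses many rows into one value
df.groupBy("gender").agg(avg("age")).show()

// UDTF shape (one-to-many): explode turns one row into several
Seq(("alice", "jogging,Coding,cooking")).toDF("name", "hobbies")
  .withColumn("hobby", explode(split($"hobbies", ","))).show()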
Contents of hobbies.txt:
alice jogging,Coding,cooking
lina travel,dance
Requirement: count the number of hobbies for each person.
Code:
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions

// case class for the two columns in hobbies.txt
case class Hobbies(name: String, hobbies: String)

val conf: SparkConf = new SparkConf().setAppName("innserdemo").setMaster("local[*]")
val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
val sc: SparkContext = spark.sparkContext
import spark.implicits._

// file path
val hobbyDF: DataFrame = sc.textFile("in/hobbies.txt")
  .map(x => x.split(" "))
  .map(x => Hobbies(x(0), x(1)))
  .toDF()
hobbyDF.createOrReplaceTempView("hobby")

// register the UDF so it can be called from SQL
spark.udf.register("hobby_num", (x: String) => x.split(",").size)

// build the same logic as a UserDefinedFunction for the DataFrame API
val hobby_num: UserDefinedFunction = functions.udf((hobbies: String) => {
  hobbies.split(",").size
})

val newhobbyDF: DataFrame = hobbyDF.withColumn("hobbynum", hobby_num($"hobbies"))
newhobbyDF.printSchema()
newhobbyDF.show(false)
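Note that the code registers the same logic twice: spark.udf.register("hobby_num", ...) makes the function callable from SQL against the hobby view (for example, select name, hobby_num(hobbies) from hobby), while functions.udf(...) returns a UserDefinedFunction value that can be used directly in the DataFrame API, as withColumn does above. Either one alone is enough for its respective API.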
Output:
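The original result screenshot is missing; with the two-line hobbies.txt above, show(false) should print roughly the following:

+-----+----------------------+--------+
|name |hobbies               |hobbynum|
+-----+----------------------+--------+
|alice|jogging,Coding,cooking|3       |
|lina |travel,dance          |2       |
+-----+----------------------+--------+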
UDAF example: a custom UDAF extends UserDefinedAggregateFunction.
Requirement: compute the average age grouped by gender.
Code:
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{DataFrame, SparkSession}

// case class for the student records
case class Student(id: Int, name: String, gender: String, age: Int)

val conf: SparkConf = new SparkConf().setAppName("innserdemo").setMaster("local[*]")
val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
val sc: SparkContext = spark.sparkContext

val students: Seq[Student] = Seq(
  Student(1, "zhangsan", "F", 22),
  Student(2, "lisi", "M", 38),
  Student(3, "wangwu", "M", 13),
  Student(4, "zhaoliu", "F", 17),
  Student(5, "songba", "M", 32),
  Student(6, "sunjiu", "M", 16),
  Student(7, "qianshiyi", "F", 17),
  Student(8, "yinshier", "F", 15),
  Student(9, "fangshisan", "M", 12),
  Student(10, "yeshisan", "F", 11),
  Student(11, "ruishiyi", "F", 26),
  Student(12, "chenshier", "M", 28)
)
val frame: DataFrame = spark.createDataFrame(students)
frame.printSchema()

// register the custom aggregate function and query through SQL
spark.udf.register("myAvg", new MyAgeAvgFunction)
frame.createOrReplaceTempView("students")
val resultDF: DataFrame = spark.sql(
  "select gender, myAvg(age) from students group by gender"
)
resultDF.printSchema()
resultDF.show(false)
The custom class MyAgeAvgFunction:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class MyAgeAvgFunction extends UserDefinedAggregateFunction {

  // schema of the aggregate function's input data
  override def inputSchema: StructType = {
    // new StructType().add("age", LongType)
    StructType(StructField("age", LongType) :: Nil)
  }

  // schema of the data held in the aggregation buffer, e.g. ageSum(1000), ageCount(200)
  // sum records the running total of all ages: 43 + 52 + 61 + 78 = 234 => sum
  // count records how many values have been added: 1 + 1 + 1 + 1 = 4 => count
  override def bufferSchema: StructType = {
    // new StructType().add("sum", LongType).add("count", LongType)
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  }

  // return type of the function: sum / count
  override def dataType: DataType = DoubleType

  // the function is deterministic: the same input always yields the same output
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L // running total of all ages seen so far
    buffer(1) = 0L // number of ages seen so far
  }

  // handle one new input row:
  // add the value in the Row (e.g. Row(63)) to buffer(0), and increment the count in buffer(1)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // merge the buffers from different partitions,
  // e.g. p1(321,6)  p2(128,2)  p3(219,3)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // total of the age sums
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    // total count
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // compute the final result
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0) / buffer.getLong(1).toDouble
  }
}
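For reference, UserDefinedAggregateFunction is deprecated since Spark 3.0 in favor of org.apache.spark.sql.expressions.Aggregator. A minimal sketch of the same average in that style, assuming Spark 3.x (the names AvgBuffer and MyAvg are illustrative, not from the original post):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions

// buffer type: running sum and count
case class AvgBuffer(sum: Long, count: Long)

object MyAvg extends Aggregator[Int, AvgBuffer, Double] {
  def zero: AvgBuffer = AvgBuffer(0L, 0L)                          // like initialize
  def reduce(b: AvgBuffer, age: Int): AvgBuffer =                  // like update
    AvgBuffer(b.sum + age, b.count + 1)
  def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer =             // like merge
    AvgBuffer(b1.sum + b2.sum, b1.count + b2.count)
  def finish(b: AvgBuffer): Double = b.sum.toDouble / b.count      // like evaluate
  def bufferEncoder: Encoder[AvgBuffer] = Encoders.product[AvgBuffer]
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// registered the same way: spark.udf.register("myAvg", functions.udaf(MyAvg))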
Output:
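The original result screenshot is missing; from the twelve students above, the F group averages (22+17+17+15+11+26)/6 = 18.0 and the M group (38+13+32+16+12+28)/6 ≈ 23.17, so the table should look roughly like this (the generated column header may differ by Spark version):

+------+------------------+
|gender|myavg(age)        |
+------+------------------+
|F     |18.0              |
|M     |23.166666666666668|
+------+------------------+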
A custom UDTF extends Hive's GenericUDTF.
Contents of UDTF.txt:
01//zs//Hadoop scala spark hive hbase
02//ls//Hadoop scala kafka hive hbase Oozie
03//ww//Hadoop scala spark hive sqoop
Requirement: list the course information for one particular student.
Code:
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

val conf: SparkConf = new SparkConf().setAppName("UDTFDemo").setMaster("local[*]")
val spark: SparkSession = SparkSession.builder()
  .config(conf)
  .config("hive.metastore.uris", "thrift://192.168.91.135:9083")
  .enableHiveSupport()
  .getOrCreate()
val sc: SparkContext = spark.sparkContext
import spark.implicits._

val rdd: RDD[String] = sc.textFile("in/UDTF.txt")
// keep only the record for "ls" and split it into (id, name, class)
val rdd2: RDD[(String, String, String)] = rdd.map(x => x.split("//"))
  .filter(x => x(1).equals("ls"))
  .map(x => (x(0), x(1), x(2)))
val frame: DataFrame = rdd2.toDF("id", "name", "class")
frame.createOrReplaceTempView("udtftable")

// register the UDTF through Hive SQL, then call it
spark.sql("create temporary function Myudtf as 'day12_13.MyUDTF'")
spark.sql("select Myudtf(class) from udtftable").show(false)
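Because MyUDTF is a Hive GenericUDTF, it cannot be registered with spark.udf.register; it has to go through Hive's create temporary function syntax, which is why this session enables Hive support and points hive.metastore.uris at a running metastore. The string 'day12_13.MyUDTF' is the fully qualified class name of the class defined below.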
The custom class MyUDTF:
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

class MyUDTF extends GenericUDTF {

  // declare the output schema: a single string column named "type"
  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    val fieldName = new java.util.ArrayList[String]()
    val fieldOIS = new java.util.ArrayList[ObjectInspector]()
    fieldName.add("type")
    fieldOIS.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldName, fieldOIS)
  }

  // split the input on spaces and emit one output row per course via forward()
  override def process(objects: Array[AnyRef]): Unit = {
    val strings: Array[String] = objects(0).toString.split(" ")
    for (str <- strings) {
      val temp = new Array[String](1)
      temp(0) = str
      forward(temp)
    }
  }

  override def close(): Unit = {}
}
Output:
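The original result screenshot is missing; since the filter keeps only the "ls" record, its class column "Hadoop scala kafka hive hbase Oozie" is split into six rows, so the output should look roughly like:

+------+
|type  |
+------+
|Hadoop|
|scala |
|kafka |
|hive  |
|hbase |
|Oozie |
+------+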