hive复杂数据类型GenericUDF编写

hive复杂数据类型GenericUDF编写,第1张

hive复杂数据类型GenericUDF编写 引入依赖

<dependency>
  <groupId>org.apache.hive</groupId>
  <artifactId>hive-exec</artifactId>
  <version>3.0.0</version>
</dependency>

自定义函数
public class GetWorkTimeGenericUDF extends GenericUDF {
    // Inspectors captured by initialize() and read by evaluate(). They are derived state that
    // initialize() rebuilds, so they are transient like argumentOIs.
    transient ListObjectInspector listOI1;      // 3rd argument: array of yyyy-MM-dd date strings
    transient ListObjectInspector listOI2;      // 4th argument: array of per-date work flags
    transient StringObjectInspector elementOI1; // 1st argument: start datetime string
    transient StringObjectInspector elementOI2; // 2nd argument: end datetime string
    private transient ObjectInspector[] argumentOIs;

    /**
     * Validates the arguments of {@code get_work_time(start, end, dates, flags)} and records their
     * inspectors for evaluate().
     *
     * <p>Expected types: {@code start} and {@code end} are strings, {@code dates} is an array of
     * strings, and {@code flags} is an array of integral numbers (tinyint/smallint/int/bigint; note
     * that a Hive literal such as {@code array(1, 0)} is array&lt;int&gt;). If any argument is an
     * untyped NULL (VOID inspector), validation is skipped.
     *
     * @param arguments inspectors for the four call arguments
     * @return a Java {@code Double} inspector, matching the boxed double produced by evaluate()
     * @throws UDFArgumentLengthException if not exactly four arguments are passed
     * @throws UDFArgumentException if an argument or array element has the wrong type
     */
    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        if (arguments.length != 4) {
            throw new UDFArgumentLengthException("get_work_time takes 4 arguments: string,string,array,array");
        }

        this.argumentOIs = arguments;

        // A NULL literal anywhere short-circuits; the other inspectors are not needed then.
        if (isVoidType(arguments)) {
            return PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
        }
        ObjectInspector a = arguments[0];
        ObjectInspector b = arguments[1];
        ObjectInspector c = arguments[2];
        ObjectInspector d = arguments[3];

        // 1. Check the argument kinds BEFORE casting, so a bad call is reported as a
        //    UDFArgumentException rather than a ClassCastException.
        if (!(a instanceof StringObjectInspector) || !(b instanceof StringObjectInspector)
                || !(c instanceof ListObjectInspector) || !(d instanceof ListObjectInspector)) {
            throw new UDFArgumentException("first and second argument must be a string, third and fourth argument must be a list / array");
        }
        ListObjectInspector dateListOI = (ListObjectInspector) c;
        ListObjectInspector flagListOI = (ListObjectInspector) d;

        // 2. The dates array must hold strings and the flags array integral numbers. Any
        //    StringObjectInspector (writable, java, lazy) is accepted, not only the writable one.
        if (!(dateListOI.getListElementObjectInspector() instanceof StringObjectInspector)
                || !isIntegralPrimitive(flagListOI.getListElementObjectInspector())) {
            throw new UDFArgumentException("third argument must be a list/array of strings,fourth argument must be a list/array of tinyint/smallint/int/bigint");
        }

        this.elementOI1 = (StringObjectInspector) a;
        this.elementOI2 = (StringObjectInspector) b;
        this.listOI1 = dateListOI;
        this.listOI2 = flagListOI;

        // The result is a double, so hand back the matching Java Double inspector.
        return PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
    }

    /** Returns true when {@code oi} inspects a TINYINT, SMALLINT, INT or BIGINT primitive. */
    private static boolean isIntegralPrimitive(ObjectInspector oi) {
        if (oi.getCategory() != ObjectInspector.Category.PRIMITIVE) {
            return false;
        }
        switch (((PrimitiveObjectInspector) oi).getPrimitiveCategory()) {
            case BYTE:
            case SHORT:
            case INT:
            case LONG:
                return true;
            default:
                return false;
        }
    }

    /**
     * Returns the text Hive shows for a call to this UDF (GenericUDF documents it as the string
     * displayed in explain output). The text is fixed and does not include the argument
     * expressions in {@code children}.
     */
    @Override
    public String getDisplayString(String[] children) {
        String display = "返回指定时间段内的工作日时长:get_work_time(string,string,array,array)";
        return display;
    }


    /** Format of the start/end arguments; DateTimeFormatter is immutable, so one shared instance suffices. */
    private static final DateTimeFormatter DATE_TIME_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    /** Seconds in a day, used to turn a time of day into a fraction of a day. */
    private static final double SECONDS_PER_DAY = 60 * 60 * 24;

    /**
     * Computes the working time, in days, between {@code start} and {@code end}.
     *
     * <p>Each element of {@code dates} (a yyyy-MM-dd string) is paired with the element at the same
     * index in {@code flags} (presumably 1 for a working day and 0 otherwise). Dates outside
     * [start day, end day] are ignored. Each remaining flag is weighted by the part of its day that
     * the interval covers:
     * <ul>
     *   <li>the start day counts from the start time to midnight;</li>
     *   <li>the end day counts from midnight to the end time;</li>
     *   <li>a day that is both counts end minus start;</li>
     *   <li>any day in between counts fully.</li>
     * </ul>
     *
     * @return the weighted sum as a {@code Double}. It is 0.0 when any argument is NULL, when end
     *         is before start, or when no listed date falls inside the interval.
     * @throws HiveException if start/end are not "yyyy-MM-dd HH:mm:ss", or if flags has fewer
     *         elements than dates
     */
    @Override
    public Object evaluate(DeferredObject[] arguments) throws HiveException {
        double sum = 0.0D;
        if (isVoidType(argumentOIs)) {
            return sum;
        }
        String start_date = this.elementOI1.getPrimitiveJavaObject(arguments[0].get());
        String end_date = this.elementOI2.getPrimitiveJavaObject(arguments[1].get());
        List<?> dates = this.listOI1.getList(arguments[2].get());
        List<?> workFlags = this.listOI2.getList(arguments[3].get());

        // NULL arguments yield 0.0.
        if (start_date == null || end_date == null || dates == null || workFlags == null) {
            return sum;
        }
        if (workFlags.size() < dates.size()) {
            throw new HiveException("get_work_time: fourth argument has " + workFlags.size()
                    + " elements but third argument has " + dates.size());
        }
        LocalDateTime startDate = parseDateTime(start_date);
        LocalDateTime endDate = parseDateTime(end_date);
        if (endDate.isBefore(startDate)) {
            return sum;
        }

        String startDay = startDate.format(DateTimeFormatter.ISO_LOCAL_DATE);
        String endDay = endDate.format(DateTimeFormatter.ISO_LOCAL_DATE);
        double startFraction = startDate.toLocalTime().toSecondOfDay() / SECONDS_PER_DAY;
        double endFraction = endDate.toLocalTime().toSecondOfDay() / SECONDS_PER_DAY;

        // Read elements through their inspectors instead of assuming Text / LongWritable objects.
        StringObjectInspector dateOI = (StringObjectInspector) this.listOI1.getListElementObjectInspector();
        PrimitiveObjectInspector flagOI = (PrimitiveObjectInspector) this.listOI2.getListElementObjectInspector();

        for (int i = 0, len = dates.size(); i < len; i++) {
            String day = dateOI.getPrimitiveJavaObject(dates.get(i));
            Object flagValue = flagOI.getPrimitiveJavaObject(workFlags.get(i));
            if (day == null || flagValue == null) {
                continue; // a NULL element contributes nothing
            }
            // Zero-padded yyyy-MM-dd strings sort chronologically, so string comparison bounds the range.
            if (day.compareTo(startDay) < 0 || day.compareTo(endDay) > 0) {
                continue;
            }
            long workFlag = ((Number) flagValue).longValue();
            double dayFraction;
            if (day.equals(startDay) && day.equals(endDay)) {
                dayFraction = endFraction - startFraction; // interval starts and ends on the same day
            } else if (day.equals(startDay)) {
                dayFraction = 1 - startFraction;
            } else if (day.equals(endDay)) {
                dayFraction = endFraction;
            } else {
                dayFraction = 1;
            }
            sum += workFlag * dayFraction;
        }
        return sum;
    }

    /** Parses a "yyyy-MM-dd HH:mm:ss" argument, reporting malformed input as a HiveException. */
    private static LocalDateTime parseDateTime(String text) throws HiveException {
        try {
            return LocalDateTime.parse(text, DATE_TIME_FORMAT);
        } catch (java.time.format.DateTimeParseException e) {
            throw new HiveException("get_work_time: cannot parse '" + text + "' as yyyy-MM-dd HH:mm:ss", e);
        }
    }

    /**
     * Reports whether any of the given inspectors is a primitive of category VOID.
     * A VOID inspector is presumably what Hive supplies for an untyped NULL literal argument.
     *
     * @param arguments the argument inspectors to scan
     * @return true as soon as one VOID primitive is found, false otherwise
     */
    protected boolean isVoidType(ObjectInspector[] arguments) {
        for (int idx = 0; idx < arguments.length; idx++) {
            ObjectInspector candidate = arguments[idx];
            boolean primitive = candidate.getCategory() == ObjectInspector.Category.PRIMITIVE;
            if (primitive
                    && ((PrimitiveObjectInspector) candidate).getPrimitiveCategory()
                        == PrimitiveObjectInspector.PrimitiveCategory.VOID) {
                return true;
            }
        }
        return false;
    }
测试
public static void main(String[] args) throws Exception {
        // Ad-hoc smoke test: build the inspectors for
        // get_work_time(string, string, array<string>, array<bigint>) and evaluate two rows.
        GetWorkTimeGenericUDF example = new GetWorkTimeGenericUDF();
        ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        // To exercise the NULL-literal path, pass PrimitiveObjectInspectorFactory.javaVoidObjectInspector
        // as one of the inspectors instead.

        // The list elements below are Text / LongWritable objects, so the element inspectors must be
        // the writable ones. java* inspectors would not match those objects.
        ObjectInspector dateElementOI = PrimitiveObjectInspectorFactory.writableStringObjectInspector;
        ObjectInspector flagElementOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector;
        ObjectInspector listOI1 = ObjectInspectorFactory.getStandardListObjectInspector(dateElementOI);
        ObjectInspector listOI2 = ObjectInspectorFactory.getStandardListObjectInspector(flagElementOI);
        JavaDoubleObjectInspector resultInspector = (JavaDoubleObjectInspector) example.initialize(
                new ObjectInspector[]{stringOI, stringOI, listOI1, listOI2});

        // Build the actual UDF arguments: parallel lists of dates and work flags.
        List<Text> list1 = new ArrayList<>();
        for (String day : new String[]{"2022-01-01", "2022-01-05", "2022-01-04", "2022-01-03", "2022-01-02"}) {
            list1.add(new Text(day));
        }
        List<LongWritable> list2 = new ArrayList<>();
        for (long flag : new long[]{0, 1, 1, 0, 0}) {
            list2.add(new LongWritable(flag));
        }

        // All arguments present. Expected: 01-04 counts fully, plus 14:11:21/24h of 01-05, about 1.5912.
        Object result = example.evaluate(new DeferredObject[]{new DeferredJavaObject("2022-01-01 13:05:12"),
                new DeferredJavaObject("2022-01-05 14:11:21"),
                new DeferredJavaObject(list1), new DeferredJavaObject(list2)});
        System.out.println(result);
        System.out.println(resultInspector.get(result));

        // NULL arguments: expected 0.0.
        Object result1 = example.evaluate(new DeferredObject[]{new DeferredJavaObject(null),
                new DeferredJavaObject("2022-01-05 14:11:21"),
                new DeferredJavaObject(list1), new DeferredJavaObject(null)});
        System.out.println(result1);
        System.out.println(resultInspector.get(result1));
    }

注意：evaluate 方法返回值的 Java 类型，必须与 initialize 返回的 ObjectInspector 所描述的类型一致（本例 initialize 返回 javaDoubleObjectInspector，因此 evaluate 必须返回 Double）。

创建函数

创建临时函数
hive> add jar /data/sql_data/udf-1.jar;
hive> create temporary function f1 as 'org.pony.hive.udf.GetWorkTimeGenericUDF';
hive> show functions;
创建永久函数
hive> create database udf;
hive> DROP FUNCTION IF EXISTS udf.get_work_time;
hive> create function udf.get_work_time as 'org.pony.hive.udf.GetWorkTimeGenericUDF' using jar 'hdfs:/user/hive/udf/alg-training-5.0-SNAPSHOT.jar';
hive> show functions like udf.get_work_time;
使用函数

（Hive 中整数字面量默认是 INT，后缀 L 表示 BIGINT。）

select f1("2022-01-01 13:05:12", "2022-01-05 14:11:21", array("2022-01-05", "2022-01-04", "2022-01-03", "2022-01-02", "2022-01-01"), array(1L, 1L, 0L, 0L, 0L));

select f1("2022-01-01 13:05:12", "2022-01-05 14:11:21", array("2022-01-05", "2022-01-04", "2022-01-03", "2022-01-02", "2022-01-01"), null);

select f1(null, "2022-01-05 14:11:21", array("2022-01-05", "2022-01-04", "2022-01-03", "2022-01-02", "2022-01-01"), array());

select f1(null, "2022-01-05 14:11:21", null, null);

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5706640.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-17
下一篇 2022-12-17

发表评论

登录后才能评论

评论列表(0条)

保存