spark中调用xgboost实现2分类时需对rawPrediction修改

spark中调用xgboost实现2分类时需对rawPrediction修改,第1张

spark中调用xgboost实现2分类时需对rawPrediction修改

1、修改代码


  private def amendXGBPred(res: Dataframe): Dataframe = {
    val columns = res.columns

    if (columns.contains("rawPrediction")) {
      val aRes  = res.withColumnRenamed("rawPrediction", "rawPrediction_Ori")
      val code = (arg: Vector) => {//这个函数使原来的vector,变成新的vector
        val rawPre = arg.apply(0)
        new DenseVector(Array(-1.0 * rawPre, rawPre))
      }
      val addCol = udf(code)
      val columns = aRes.columns
      aRes.selectExpr(columns:_*).withColumn("rawPrediction", addCol(aRes("rawPrediction_Ori")))
    } else {
      res
    }

2、原由

一般的2分类模型,rawPrediction有2列,一列是分类为0的原始预测数值、一列是1的;

但,XGBoost 0.81 Java 开源版本有bug,二分类预测结果rawPrediction只有一列数据,是分类为1的预测数值;

而,spark内置的交叉验证源代码评估时使用的是rawPrediction列,因此对XGBoost算法产生的rawPrediction列进行下调整修改;

3、图片

 

 

欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/zaji/5688641.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-17
下一篇 2022-12-17

发表评论

登录后才能评论

评论列表(0条)

保存