1、修改代码
private def amendXGBPred(res: Dataframe): Dataframe = { val columns = res.columns if (columns.contains("rawPrediction")) { val aRes = res.withColumnRenamed("rawPrediction", "rawPrediction_Ori") val code = (arg: Vector) => {//这个函数使原来的vector,变成新的vector val rawPre = arg.apply(0) new DenseVector(Array(-1.0 * rawPre, rawPre)) } val addCol = udf(code) val columns = aRes.columns aRes.selectExpr(columns:_*).withColumn("rawPrediction", addCol(aRes("rawPrediction_Ori"))) } else { res }
2、原由
一般的2分类模型,rawPrediction有2列,一列是分类为0的原始预测数值、一列是1的;
但,XGBoost 0.81 Java 开源版本有bug,二分类预测结果rawPrediction只有一列数据,是分类为1的预测数值;
而,spark内置的交叉验证源代码评估时使用的是rawPrediction列,因此对XGBoost算法产生的rawPrediction列进行下调整修改;
3、图片
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)