需要数据集（即代码中读取的 adult.data 文件）的读者可以给我留言。
object DecisionTreeClass {

  /**
   * Trains a decision-tree classifier on "adult.data" that predicts the marriage
   * column (field 5) from seven other categorical columns, tunes impurity / depth /
   * bins on a cross-validation split, and finally reports the accuracy of the best
   * configuration on a held-out test split.
   *
   * Fields used (0-based, after splitting each line on ", "):
   * 1 occupation category, 3 education, 5 marriage (label), 6 occupation,
   * 7 family condition, 8 race, 9 sex, 13 nationality.
   *
   * NOTE(review): relies on `acquireDict`, which is not shown here; it is assumed to
   * map each distinct string to a 0-based numeric index (0..n-1) — TODO confirm.
   *
   * @param args unused
   */
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[*]", "AdultData")
    // Always release the local Spark context, even if a job fails.
    try {
      val raw_data = sc.textFile("adult.data")
      // Keep only records with exactly 15 fields (this also drops blank lines).
      val data = raw_data.map(line => line.split(", ")).filter(fields => fields.length == 15)
      data.cache()

      // Index the categorical columns: collect each column's distinct values and
      // turn them into value -> index dictionaries.
      val occupation_category_types = data.map(fields => fields(1)).distinct.collect()
      val education_types = data.map(fields => fields(3)).distinct.collect()
      val marriage_types = data.map(fields => fields(5)).distinct.collect()
      val family_condition_types = data.map(fields => fields(7)).distinct.collect()
      val occupation_types = data.map(fields => fields(6)).distinct.collect()
      val racial_types = data.map(fields => fields(8)).distinct.collect()
      val nationality_types = data.map(fields => fields(13)).distinct.collect()
      println(marriage_types.length)

      val education_dict = acquireDict(education_types)
      val marriage_dict = acquireDict(marriage_types)
      val family_condition_dict = acquireDict(family_condition_types)
      val occupation_category_dict = acquireDict(occupation_category_types)
      val occupation_dict = acquireDict(occupation_types)
      val racial_dict = acquireDict(racial_types)
      val nationality_dict = acquireDict(nationality_types)
      val sex_dict = Map("Male" -> 1, "Female" -> 0)

      // The label is the marriage index, so the class count is the number of distinct
      // marriage values. trainClassifier requires every label to be < numClasses, which
      // holds if acquireDict assigns indices 0..n-1 (see NOTE above).
      val numClasses = marriage_types.length

      val data_set = data.map { fields =>
        val education = education_dict(fields(3))
        val marriage = marriage_dict(fields(5))
        val family_condition = family_condition_dict(fields(7))
        val occupation_category = occupation_category_dict(fields(1))
        val occupation = occupation_dict(fields(6))
        val sex = sex_dict(fields(9))
        val race = racial_dict(fields(8))
        val nationality = nationality_dict(fields(13))
        val featureVector = Vectors.dense(education, occupation, occupation_category, sex,
          family_condition, race, nationality)
        val label = marriage
        LabeledPoint(label, featureVector)
      }
      data_set.take(10).foreach(println)

      // 80% training, 10% for choosing hyper-parameters, 10% held out for the final estimate.
      val Array(trainData, cvData, testData) = data_set.randomSplit(Array(0.8, 0.1, 0.1))
      trainData.cache()
      cvData.cache()
      testData.cache()

      // Fraction of `samples` whose label `model` predicts correctly. For single-label
      // multiclass data, MulticlassMetrics' overall (no-argument) precision equals accuracy.
      def accuracyOf(model: org.apache.spark.mllib.tree.model.DecisionTreeModel,
                     samples: org.apache.spark.rdd.RDD[LabeledPoint]): Double = {
        val predictionsAndLabels =
          samples.map(example => (model.predict(example.features), example.label))
        new MulticlassMetrics(predictionsAndLabels).precision
      }

      // Baseline model. trainClassifier arguments:
      //  - training data
      //  - numClasses: number of label values (here, marriage categories)
      //  - categoricalFeaturesInfo: featureIndex -> arity for features to treat as
      //    categorical. It is empty here, so every integer-coded feature is treated as
      //    continuous/ordered. NOTE(review): these features are nominal; declaring them
      //    would also require maxBins >= the largest arity.
      //  - impurity: "gini" or "entropy" for classification ("variance" is for regression)
      //  - maxDepth: deeper trees fit the training data more closely but cost more
      //    and can overfit
      //  - maxBins: maximum number of bins used when searching split points; more bins
      //    allow finer splits at higher computational cost
      val model = DecisionTree.trainClassifier(trainData, numClasses, Map[Int, Int](), "entropy", 10, 100)
      println(accuracyOf(model, cvData))

      // With only ~30k records, an exhaustive grid over impurity / depth / bins is affordable.
      val evaluations =
        for {
          impurity <- Array("gini", "entropy")
          depth <- Array(1, 10, 25)
          bins <- Array(10, 50, 150)
        } yield {
          val candidate =
            DecisionTree.trainClassifier(trainData, numClasses, Map[Int, Int](), impurity, depth, bins)
          ((depth, bins, impurity), accuracyOf(candidate, cvData))
        }
      evaluations.sortBy(_._2).reverse.foreach(println)

      // The CV split was used to pick the hyper-parameters, so its accuracy is optimistic.
      // Retrain the winner on train + CV and report it on the untouched test split.
      val ((bestDepth, bestBins, bestImpurity), bestCvAccuracy) = evaluations.maxBy(_._2)
      val finalModel = DecisionTree.trainClassifier(
        trainData.union(cvData), numClasses, Map[Int, Int](), bestImpurity, bestDepth, bestBins)
      println(s"best (depth, bins, impurity) = ($bestDepth, $bestBins, $bestImpurity), " +
        s"cv accuracy = $bestCvAccuracy, test accuracy = ${accuracyOf(finalModel, testData)}")
    } finally {
      sc.stop()
    }
  }
欢迎分享，转载请注明来源：内存溢出
评论列表(0条)