返回顶部

收藏

K-means算法(Spark Demo)

更多
import java.util.Random
import spark.SparkContext
import spark.SparkContext._
import spark.examples.Vector._

/**
 * K-means clustering example on Spark.
 *
 * Each non-blank line of the input file is one point: whitespace-separated
 * numbers parsed into a `Vector`. The k centers are initialised with seeded
 * random coordinates in [-1, 1), then refined for a fixed number of
 * iterations. Each iteration assigns every point to its nearest center and
 * moves each center to the mean of its assigned points.
 *
 * NOTE(review): `Vector` must resolve to `spark.examples.Vector` (the class
 * whose companion members are imported above). The upstream example lives in
 * `package spark.examples`; confirm that package clause is present, otherwise
 * `Vector` resolves to `scala.collection.immutable.Vector`.
 */
object SparkKMeans {

  /**
   * Parses one input line into a point.
   *
   * Tokens may be separated by any run of whitespace; leading and trailing
   * whitespace is ignored.
   *
   * @param line a line of numeric tokens
   * @return the point as a `Vector`
   * @throws NumberFormatException if a token is not a valid double
   */
  def parseVector(line: String): Vector =
    new Vector(line.trim.split("\\s+").map(_.toDouble))

  /**
   * Returns the index of the center nearest to `p`, measured by
   * `Vector.squaredDist` (sum of squared coordinate differences).
   * On ties the lowest index wins, because only a strictly smaller
   * distance replaces the current best.
   *
   * @param p       the point to classify
   * @param centers candidate centers; must be non-empty
   * @return index into `centers` of the closest center
   */
  def closestCenter(p: Vector, centers: Array[Vector]): Int = {
    var bestIndex = 0
    var bestDist = p.squaredDist(centers(0))
    for (i <- 1 until centers.length) {
      val dist = p.squaredDist(centers(i))
      if (dist < bestDist) {
        bestDist = dist
        bestIndex = i
      }
    }
    bestIndex
  }

  /**
   * Entry point.
   *
   * Usage: SparkKMeans <master> <file> <dimensions> <k> <iters> [<minSplits>]
   *
   * @param args master URL, input path, point dimensionality, number of
   *             clusters, number of iterations and, optionally, the minimum
   *             number of input splits passed to `textFile`
   */
  def main(args: Array[String]): Unit = {
    // Five arguments are required; the sixth (minSplits) is optional.
    if (args.length < 5) {
      System.err.println("Usage: SparkKMeans <master> <file> <dimensions> <k> <iters> [<minSplits>]")
      System.exit(1)
    }
    // Parse and validate before starting Spark so bad arguments fail fast.
    val dimensions = args(2).toInt // dimensionality of each point
    val k = args(3).toInt          // number of clusters
    val iterations = args(4).toInt // number of refinement iterations
    require(dimensions > 0, "dimensions must be positive")
    require(k > 0, "k must be positive")
    require(iterations >= 0, "iters must be non-negative")

    val sc = new SparkContext(args(0), "SparkKMeans")
    try {
      val lines =
        if (args.length > 5) sc.textFile(args(1), args(5).toInt)
        else sc.textFile(args(1))
      // One point per non-blank line; cached because every iteration rescans it.
      val points = lines.filter(_.trim.nonEmpty).map(parseVector).cache()

      // Random initial centers with coordinates in [-1, 1).
      // The seed is fixed so runs are reproducible.
      val rand = new Random(42)
      val centers = Array.fill(k)(Vector(dimensions, _ => 2 * rand.nextDouble - 1))
      println("Initial centers: " + centers.mkString(", "))

      val time1 = System.currentTimeMillis()
      for (i <- 1 to iterations) {
        println("On iteration " + i)

        // Map each point to the index of its closest center and a (point, 1)
        // pair that is used to compute an average later.
        val mappedPoints = points.map { p => (closestCenter(p, centers), (p, 1)) }

        // Per cluster: sum the member vectors and counts, then divide to get
        // the mean, which becomes that cluster's new center.
        val newCenters = mappedPoints.reduceByKey {
          case ((sum1, count1), (sum2, count2)) => (sum1 + sum2, count1 + count2)
        }.map {
          case (id, (sum, count)) => (id, sum / count)
        }.collect()

        // Update centers in place. A cluster that received no points is absent
        // from newCenters and keeps its previous center.
        for ((id, value) <- newCenters) {
          centers(id) = value
        }
      }
      val time2 = System.currentTimeMillis()
      println("Final centers: " + centers.mkString(", ") + ", time: " + (time2 - time1))
    } finally {
      sc.stop()
    }
  }
}

标签:scala, spark

收藏

1人收藏

支持

0

反对

0

发表评论