# K-means算法(Spark Demo)

```scala
import java.util.Random
import spark.SparkContext
import spark.SparkContext._
import spark.examples.Vector._

object SparkKMeans {
  // Unless the file declares `package spark.examples`, the wildcard import
  // `spark.examples.Vector._` only brings the companion object's members into scope, and the
  // bare name `Vector` would otherwise resolve to `scala.Vector` (the immutable collection).
  import spark.examples.Vector

  /** Arguments that must always be supplied: master, file, dimensions, k, iters. */
  private val RequiredArgCount = 5

  /**
   * Parses one line of input into a point.
   *
   * Coordinates are separated by whitespace. Leading/trailing whitespace and runs of several
   * spaces or tabs are tolerated; splitting on a single ' ' would yield empty tokens for those,
   * and `"".toDouble` throws.
   *
   * @param line a line of whitespace-separated numbers, e.g. "0.5 -1.2 3.0"
   * @return the coordinates as a [[Vector]]
   * @throws NumberFormatException if a token is not a valid double
   */
  def parseVector(line: String): Vector =
    new Vector(line.trim.split("\\s+").map(_.toDouble))

  /**
   * Returns the index of the center nearest to `p`, measured by `squaredDist`
   * (the sum of squared coordinate differences). Ties go to the lowest index.
   *
   * @param p       the point to classify
   * @param centers candidate centers; must be non-empty (an empty array makes `centers(0)` throw)
   * @return index into `centers` of the closest center
   */
  def closestCenter(p: Vector, centers: Array[Vector]): Int = {
    var bestIndex = 0
    var bestDist = p.squaredDist(centers(0))
    // A plain loop (rather than minBy) evaluates each distance exactly once; this runs per point.
    for (i <- 1 until centers.length) {
      val dist = p.squaredDist(centers(i))
      if (dist < bestDist) {
        bestDist = dist
        bestIndex = i
      }
    }
    bestIndex
  }

  /**
   * Runs k-means (Lloyd's algorithm) over a text file of points, one point per line.
   *
   * Usage: SparkKMeans <master> <file> <dimensions> <k> <iters> [<slices>]
   *
   * Centers start at pseudo-random positions in [-1, 1)^dimensions, using a fixed seed of 42
   * so runs are reproducible. Each iteration assigns every point to its nearest center and moves
   * each center to the mean of the points assigned to it. A center that receives no points keeps
   * its previous position.
   */
  def main(args: Array[String]): Unit = {
    if (args.length < RequiredArgCount) {
      System.err.println("Usage: SparkKMeans <master> <file> <dimensions> <k> <iters> [<slices>]")
      System.exit(1)
    }
    val dimensions = args(2).toInt // dimensionality of every point
    val k = args(3).toInt          // number of clusters
    val iterations = args(4).toInt // number of refinement passes
    if (dimensions < 1 || k < 1) {
      System.err.println("SparkKMeans: <dimensions> and <k> must both be at least 1")
      System.exit(1)
    }

    val sc = new SparkContext(args(0), "SparkKMeans")
    try {
      // The slice count is optional; when absent, Spark's default split count is used.
      val lines =
        if (args.length > RequiredArgCount) sc.textFile(args(1), args(5).toInt)
        else sc.textFile(args(1))
      // One point per non-blank line; cached because every iteration re-scans the points.
      val points = lines.filter(line => !line.trim.isEmpty).map(parseVector(_)).cache()

      // Randomly initialize the k centers.
      val rand = new Random(42)
      val centers = Array.fill(k)(Vector(dimensions, _ => 2 * rand.nextDouble - 1))
      println("Initial centers: " + centers.mkString(", "))
      val time1 = System.currentTimeMillis()
      for (i <- 1 to iterations) {
        println("On iteration " + i)

        // Map each point to (index of its closest center, (point, 1)) so per-center sums and
        // counts can be reduced and then averaged.
        val mappedPoints = points.map { p => (closestCenter(p, centers), (p, 1)) }

        val newCenters = mappedPoints.reduceByKey {
          case ((sum1, count1), (sum2, count2)) => (sum1 + sum2, count1 + count2)
        }.map {
          case (id, (sum, count)) => (id, sum / count) // new center = mean of its points
        }.collect()

        // Only centers that received at least one point appear in newCenters.
        for ((id, value) <- newCenters) {
          centers(id) = value
        }
      }
      val time2 = System.currentTimeMillis()
      println("Final centers: " + centers.mkString(", ") + ", time: " + (time2 - time1))
    } finally {
      // Release the cluster connection and executors even if the job fails.
      sc.stop()
    }
  }
}
```

1人收藏

0

0