学习来源:日撸 Java 三百行(51-60天,kNN 与 NB)
第 56 天: kMeans 聚类。K均值聚类算法(k-means clustering algorithm,KMeans)是一种无监督学习的算法,它解决的是聚类问题。
算法思想:
KMeans是典型的基于距离的聚类算法,采用距离作为相似性的评价指标,即认为两个对象的距离越近,其相似度就越大。该算法认为簇是由距离靠近的对象组成的,因此把得到紧凑且独立的簇作为最终目标。主要思想是:在给定聚类簇数K和K个初始聚类中心点的情况下,把每个点分到离其最近的聚类中心点所代表的簇中。所有点分配完毕之后,根据一个簇内的所有点来重新计算该簇的聚类中心点,然后再迭代地进行分配点和更新聚类中心点的步骤,直至各聚类中心点不再变化。
算法步骤:
1.设置所期望的聚类簇数 K的值,代码中设置K值为3。
2.读入iris.arff中的数据。
3.随机初始化K个聚类中心点。
4.计算每个点到各聚类中心点的欧式距离,并将其分到距离最小的聚类中心点所在的簇中。
5.重新计算每个簇的聚类中心点。
6.重复上面4、5两步操作,直到各聚类中心点不再变化。
代码:
package machine_learning;
import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;
import weka.core.Instances;
/**
* @Description: kMeans clustering.
* @author: Xin-Yu Li
* @time: May 6(th),2022
*/
public class KMeans {
public static final int MANHATTAN = 0;
public static final int EUCLIDEAN = 1;
public int distanceMeasure = EUCLIDEAN;
public static final Random random = new Random();
Instances dataset;
int numClusters = 2;
int[][] clusters;
public KMeans(String paraFilename) {
dataset = null;
try {
FileReader fileReader = new FileReader(paraFilename);
dataset = new Instances(fileReader);
fileReader.close();
} catch (Exception ee) {
System.out.println("Cannot read the file: " + paraFilename + "\r\n" + ee);
System.exit(0);
} // Of try
}// Of the first constructor
public void setNumClusters(int paraNumClusters) {
numClusters = paraNumClusters;
}// Of the setter
public static int[] getRandomIndices(int paraLength) {
int[] resultIndices = new int[paraLength];
for (int i = 0; i < paraLength; i++) {
resultIndices[i] = i;
} // Of for i
int tempFirst, tempSecond, tempValue;
for (int i = 0; i < paraLength; i++) {
tempFirst = random.nextInt(paraLength);
tempSecond = random.nextInt(paraLength);
tempValue = resultIndices[tempFirst];
resultIndices[tempFirst] = resultIndices[tempSecond];
resultIndices[tempSecond] = tempValue;
} // Of for i
return resultIndices;
}// Of getRandomIndices
public double distance(int paraI, double[] paraArray) {
int resultDistance = 0;
double tempDifference;
switch (distanceMeasure) {
case MANHATTAN:
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
tempDifference = dataset.instance(paraI).value(i) - paraArray[i];
if (tempDifference < 0) {
resultDistance -= tempDifference;
} else {
resultDistance += tempDifference;
} // Of if
} // Of for i
break;
case EUCLIDEAN:
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
tempDifference = dataset.instance(paraI).value(i) - paraArray[i];
resultDistance += tempDifference * tempDifference;
} // Of for i
break;
default:
System.out.println("Unsupported distance measure: " + distanceMeasure);
}// Of switch
return resultDistance;
}// Of distance
public void clustering() {
int[] tempOldClusterArray = new int[dataset.numInstances()];
tempOldClusterArray[0] = -1;
int[] tempClusterArray = new int[dataset.numInstances()];
Arrays.fill(tempClusterArray, 0);
double[][] tempCenters = new double[numClusters][dataset.numAttributes() - 1];
int[] tempRandomOrders = getRandomIndices(dataset.numInstances());
for (int i = 0; i < numClusters; i++) {
for (int j = 0; j < tempCenters[0].length; j++) {
tempCenters[i][j] = dataset.instance(tempRandomOrders[i]).value(j);
} // Of for j
} // Of for i
int[] tempClusterLengths = null;
while (!Arrays.equals(tempOldClusterArray, tempClusterArray)) {
System.out.println("New loop ...");
tempOldClusterArray = tempClusterArray;
tempClusterArray = new int[dataset.numInstances()];
int tempNearestCenter;
double tempNearestDistance;
double tempDistance;
for (int i = 0; i < dataset.numInstances(); i++) {
tempNearestCenter = -1;
tempNearestDistance = Double.MAX_VALUE;
for (int j = 0; j < numClusters; j++) {
tempDistance = distance(i, tempCenters[j]);
if (tempNearestDistance > tempDistance) {
tempNearestDistance = tempDistance;
tempNearestCenter = j;
} // Of if
} // Of for j
tempClusterArray[i] = tempNearestCenter;
} // Of for i
tempClusterLengths = new int[numClusters];
Arrays.fill(tempClusterLengths, 0);
double[][] tempNewCenters = new double[numClusters][dataset.numAttributes() - 1];
for (int i = 0; i < dataset.numInstances(); i++) {
for (int j = 0; j < tempNewCenters[0].length; j++) {
tempNewCenters[tempClusterArray[i]][j] += dataset.instance(i).value(j);
} // Of for j
tempClusterLengths[tempClusterArray[i]]++;
} // Of for i
for (int i = 0; i < tempNewCenters.length; i++) {
for (int j = 0; j < tempNewCenters[0].length; j++) {
tempNewCenters[i][j] /= tempClusterLengths[i];
} // Of for j
} // Of for i
System.out.println("Now the new centers are: " + Arrays.deepToString(tempNewCenters));
tempCenters = tempNewCenters;
} // Of while
clusters = new int[numClusters][];
int[] tempCounters = new int[numClusters];
for (int i = 0; i < numClusters; i++) {
clusters[i] = new int[tempClusterLengths[i]];
} // Of for i
for (int i = 0; i < tempClusterArray.length; i++) {
clusters[tempClusterArray[i]][tempCounters[tempClusterArray[i]]] = i;
tempCounters[tempClusterArray[i]]++;
} // Of for i
System.out.println("The clusters are: " + Arrays.deepToString(clusters));
}// Of clustering
public static void testClustering() {
KMeans tempKMeans = new KMeans("C:\Users\LXY\Desktop\iris.arff");
tempKMeans.setNumClusters(3);
tempKMeans.clustering();
}// Of testClustering
public static void main(String arags[]) {
testClustering();
}// Of main
}// Of class KMeans
运行截图:
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)