Java(19)

Java(19),第1张

学习来源:日撸 Java 三百行(51-60天,kNN 与 NB)

第 56 天: kMeans 聚类

K均值聚类算法(k-means clustering algorithm,KMeans)是一种无监督学习的算法,它解决的是聚类问题。
算法思想:
KMeans是典型的基于距离的聚类算法,采用距离作为相似性的评价指标,即认为两个对象的距离越近,其相似度就越大。该算法认为簇是由距离靠近的对象组成的,因此把得到紧凑且独立的簇作为最终目标。主要思想是:在给定聚类簇数K和K个初始聚类中心点的情况下,把每个点分到离其最近的聚类中心点所代表的簇中。所有点分配完毕之后,根据一个簇内的所有点来重新计算该簇的聚类中心点,然后再迭代地进行分配点和更新聚类中心点的步骤,直至各聚类中心点不再变化。

算法步骤:
1.设置所期望的聚类簇数 K的值,代码中设置K值为3。
2.读入iris.arff中的数据。
3.随机初始化K个聚类中心点。
4.计算每个点到各聚类中心点的欧式距离,并将其分到距离最小的聚类中心点所在的簇中。
5.重新计算每个簇的聚类中心点。
6.重复上面4、5两步 *** 作,直到各聚类中心点不再变化。

代码:

package machine_learning;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;
import weka.core.Instances;

/**
 * @Description: kMeans clustering.
 * @author: Xin-Yu Li
 * @time: May 6(th),2022
 */
public class KMeans {
	public static final int MANHATTAN = 0;
	public static final int EUCLIDEAN = 1;
	public int distanceMeasure = EUCLIDEAN;

	public static final Random random = new Random();
	Instances dataset;
	int numClusters = 2;
	int[][] clusters;

	public KMeans(String paraFilename) {
		dataset = null;
		try {
			FileReader fileReader = new FileReader(paraFilename);
			dataset = new Instances(fileReader);
			fileReader.close();
		} catch (Exception ee) {
			System.out.println("Cannot read the file: " + paraFilename + "\r\n" + ee);
			System.exit(0);
		} // Of try
	}// Of the first constructor

	public void setNumClusters(int paraNumClusters) {
		numClusters = paraNumClusters;
	}// Of the setter

	
	public static int[] getRandomIndices(int paraLength) {
		int[] resultIndices = new int[paraLength];

		for (int i = 0; i < paraLength; i++) {
			resultIndices[i] = i;
		} // Of for i

		int tempFirst, tempSecond, tempValue;
		for (int i = 0; i < paraLength; i++) {
			tempFirst = random.nextInt(paraLength);
			tempSecond = random.nextInt(paraLength);

			tempValue = resultIndices[tempFirst];
			resultIndices[tempFirst] = resultIndices[tempSecond];
			resultIndices[tempSecond] = tempValue;
		} // Of for i

		return resultIndices;
	}// Of getRandomIndices

	public double distance(int paraI, double[] paraArray) {
		int resultDistance = 0;
		double tempDifference;
		switch (distanceMeasure) {
		case MANHATTAN:
			for (int i = 0; i < dataset.numAttributes() - 1; i++) {
				tempDifference = dataset.instance(paraI).value(i) - paraArray[i];
				if (tempDifference < 0) {
					resultDistance -= tempDifference;
				} else {
					resultDistance += tempDifference;
				} // Of if
			} // Of for i
			break;

		case EUCLIDEAN:
			for (int i = 0; i < dataset.numAttributes() - 1; i++) {
				tempDifference = dataset.instance(paraI).value(i) - paraArray[i];
				resultDistance += tempDifference * tempDifference;
			} // Of for i
			break;
		default:
			System.out.println("Unsupported distance measure: " + distanceMeasure);
		}// Of switch

		return resultDistance;
	}// Of distance

	public void clustering() {
		int[] tempOldClusterArray = new int[dataset.numInstances()];
		tempOldClusterArray[0] = -1;
		int[] tempClusterArray = new int[dataset.numInstances()];
		Arrays.fill(tempClusterArray, 0);
		double[][] tempCenters = new double[numClusters][dataset.numAttributes() - 1];

		int[] tempRandomOrders = getRandomIndices(dataset.numInstances());
		for (int i = 0; i < numClusters; i++) {
			for (int j = 0; j < tempCenters[0].length; j++) {
				tempCenters[i][j] = dataset.instance(tempRandomOrders[i]).value(j);
			} // Of for j
		} // Of for i

		int[] tempClusterLengths = null;
		while (!Arrays.equals(tempOldClusterArray, tempClusterArray)) {
			System.out.println("New loop ...");
			tempOldClusterArray = tempClusterArray;
			tempClusterArray = new int[dataset.numInstances()];

			int tempNearestCenter;
			double tempNearestDistance;
			double tempDistance;

			for (int i = 0; i < dataset.numInstances(); i++) {
				tempNearestCenter = -1;
				tempNearestDistance = Double.MAX_VALUE;

				for (int j = 0; j < numClusters; j++) {
					tempDistance = distance(i, tempCenters[j]);
					if (tempNearestDistance > tempDistance) {
						tempNearestDistance = tempDistance;
						tempNearestCenter = j;
					} // Of if
				} // Of for j
				tempClusterArray[i] = tempNearestCenter;
			} // Of for i

			tempClusterLengths = new int[numClusters];
			Arrays.fill(tempClusterLengths, 0);
			double[][] tempNewCenters = new double[numClusters][dataset.numAttributes() - 1];
			for (int i = 0; i < dataset.numInstances(); i++) {
				for (int j = 0; j < tempNewCenters[0].length; j++) {
					tempNewCenters[tempClusterArray[i]][j] += dataset.instance(i).value(j);
				} // Of for j
				tempClusterLengths[tempClusterArray[i]]++;
			} // Of for i

			for (int i = 0; i < tempNewCenters.length; i++) {
				for (int j = 0; j < tempNewCenters[0].length; j++) {
					tempNewCenters[i][j] /= tempClusterLengths[i];
				} // Of for j
			} // Of for i

			System.out.println("Now the new centers are: " + Arrays.deepToString(tempNewCenters));
			tempCenters = tempNewCenters;
		} // Of while

		clusters = new int[numClusters][];
		int[] tempCounters = new int[numClusters];
		for (int i = 0; i < numClusters; i++) {
			clusters[i] = new int[tempClusterLengths[i]];
		} // Of for i

		for (int i = 0; i < tempClusterArray.length; i++) {
			clusters[tempClusterArray[i]][tempCounters[tempClusterArray[i]]] = i;
			tempCounters[tempClusterArray[i]]++;
		} // Of for i

		System.out.println("The clusters are: " + Arrays.deepToString(clusters));
	}// Of clustering

	public static void testClustering() {
		KMeans tempKMeans = new KMeans("C:\Users\LXY\Desktop\iris.arff");
		tempKMeans.setNumClusters(3);
		tempKMeans.clustering();
	}// Of testClustering

	
	public static void main(String arags[]) {
		testClustering();
	}// Of main

}// Of class KMeans

运行截图:

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/871269.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-13
下一篇 2022-05-13

发表评论

登录后才能评论

评论列表(0条)

保存