- 一、具体含义
- 二、特点
- 三、基本思路
- 四、具体步骤
- 五、实现代码
学习来源: 日撸 Java 三百行(51-60天,kNN 与 NB) 一、具体含义
- KNN是K-Nearest Neighbor的英文简写,中文直译就是K个最近邻,有人干脆称之为“最近邻算法”。
- 字母“K”也许看着新鲜,不过作用其实早在中学就接触过。在学习排列组合时,教材都喜欢用字母“n”来指代多个,譬如“求n个数的和”,这里面也没有什么秘密,就是约定俗成的用法。而KNN算法的字母K扮演的就是与n同样的角色。K的值是多少,就代表使用了多少个最近邻。机器学习总要有自己的约定俗成,没来由地就是喜爱用“K”而不是“n”来指代多个,类似的命名方法还有后面将要提到的K-means算法。
- KNN的关键在于最近邻,光看名字似乎与分类没有什么关系,但前面我们介绍了, KNN的核心在于多数表决,而谁有投票表决权呢?就是这个“最近邻”,也就是以待分类样本点为中心,距离最近的K个点。这K个点中什么类别的占比最多,待分类样本点就属于什么类别。
- 简单. 没有学习过程, 也被称为惰性学习 lazy learning. 类似于开卷考试, 在已有数据中去找答案.
- 本源. 找相似, 正是人类认识事物的常用方法, 隐藏于人类或者其他动物的基因里面. 当然, 人类也会上当, 例如有人把邻居的滴水观音误认为是芋头, 偷食后中毒.
- 效果好. 永远不要小视 kNN, 对于很多数据, 你很难设计算法超越它.
- 适应性强. 可用于分类, 回归. 可用于各种数据.
- 可扩展性强. 设计不同的度量, 可获得意想不到的效果.
- 一般需要对数据归一化.
- 复杂度高. 这也是 kNN 最重要的缺点. 对于每一个测试数据, 复杂度为 O ( ( m + k ) n ) O((m+k)n)O((m+k)n), 其中 n nn 为训练数据个数, m mm 为条件属性个数, k kk 为邻居个数. 代码见 computeNearests().
- KNN最核心的功能“分类”是通过多数表决“投票”来完成的,具体方法是在待分类点的K个最近邻中查看哪个类别占比最多。哪个类别多,待分类点就属于哪个类别。
- 怎样确定K——是一个需要根据实际情况调节以便取得更好拟合效果的参数,可以根据交叉验证等实验方法,结合工作经验进行设置。
- KNN中常用的度量方法——欧几里得距离和曼哈顿距离。
- 找K个最近邻。KNN分类算法的核心就是找最近的K个点,选定度量距离的方法之后,以待分类样本点为中心,分别测量它到其他点的距离,找出其中的距离最近的 K个,这就是K个最近邻。
- 统计最近邻的类别占比。确定了最近邻之后,统计出每种类别在最近邻中的占比。
- 选取占比最多的类别作为待分类样本的类别。
package machinelearning.knn;
import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;
import weka.core.Instances;
/**
*
* @author Ling Lin E-mail:linling0.0@foxmail.com
*
* @version 创建时间:2022年4月30日 上午10:19:24
*
*/
public class KnnClassification {
// Manhattan distance.曼哈顿距离
public static final int MANHATTAN = 0;
// Euclidean distance.欧几里得距离
public static final int EUCLIDEAN = 1;
// The distance measure.
public int distanceMeasure = EUCLIDEAN;
// A random instance
public static final Random random = new Random();
// The number of neighbors
int numNeighbors = 7;
// The whole dataset.
Instances dataset;
// The training set. Represented by the indices of the data.
int[] trainingSet;
// The testing set. Represented by the indices of the data.
int[] testingSet;
// The predictions.
int[] predictions;
/**
* The first constructor.
*
* @param paraFilename
* The arff filename.
*
*/
public KnnClassification(String paraFilename) {
try {
FileReader fileReader = new FileReader(paraFilename);
dataset = new Instances(fileReader);
// The last attribute is the decision class.
dataset.setClassIndex(dataset.numAttributes() - 1);
fileReader.close();
} catch (Exception ee) {
System.out.println("Error occurred while trying to read \'" + paraFilename
+ "\' in KnnClassification constructor.\r\n" + ee);
System.exit(0);
} // Of try
}// Of the first constructor
/**
* Get a random indices for data randomization.
*
* @param paraLength
* The length of the sequence.
* @return An array of indices,e.g., {4, 3, 1, 5, 0, 2} with length 6.
*/
public static int[] getRandomIndices(int paraLength) {
int[] resultIndices = new int[paraLength];
// Step 1. Initialize.
for (int i = 0; i < paraLength; i++) {
resultIndices[i] = i;
} // Of for i
// Step 2. Randomly swap.
int tempFirst, tempSecond, tempValue;
for (int i = 0; i < paraLength; i++) {
// Generate two random indices.
tempFirst = random.nextInt(paraLength);
tempSecond = random.nextInt(paraLength);
// Swap.
tempValue = resultIndices[tempFirst];
resultIndices[tempFirst] = resultIndices[tempSecond];
resultIndices[tempSecond] = tempValue;
} // Of for i
return resultIndices;
}// Of getRandomInndices
/**
* Split the data into training and testing parts.
*
* @param paraTrainingFraction
* The fraction of the training set.
*/
public void splitTrainingTesting(double paraTrainingFraction) {
int tempSize = dataset.numInstances();
int[] tempIndices = getRandomIndices(tempSize);
int tempTrainingSize = (int) (tempSize * paraTrainingFraction);
trainingSet = new int[tempTrainingSize];
testingSet = new int[tempSize - tempTrainingSize];
for (int i = 0; i < tempTrainingSize; i++) {
trainingSet[i] = tempIndices[i];
} // Of for i
for (int i = 0; i < tempSize - tempTrainingSize; i++) {
testingSet[i] = tempIndices[tempTrainingSize + i];
} // Of for i
}// Of splitTrainingTesting
/**
* Predict for the whole testing set. The results are stored in predictions.
* #see predictions.
*/
public void predict() {
predictions = new int[testingSet.length];
for (int i = 0; i < predictions.length; i++) {
predictions[i] = predict(testingSet[i]);
} // Of for i
}// Of predict
/**
* Predict for given instance.
*
* @return The prediction.
*/
public int predict(int paraIndex) {
int[] tempNeighbors = computeNearests(paraIndex);
int resultPrediction = simpleVoting(tempNeighbors);
return resultPrediction;
}// Of predict
/**
* The distance between two instances.
*
* @param paraI
* The index of the first instance.
* @param paraJ
* The index of the second instance.
* @return The distance.
*/
public double distance(int paraI, int paraJ) {
double resultDistance = 0;
double tempDifference;
switch (distanceMeasure) {
case MANHATTAN:
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
if (tempDifference < 0) {
resultDistance -= tempDifference;
} else {
resultDistance += tempDifference;
} // Of if
} // Of for i
break;
case EUCLIDEAN:
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
resultDistance += tempDifference * tempDifference;
} // Of for i
break;
default:
System.out.println("Unsupported distance measure: " + distanceMeasure);
}// Of switch
return resultDistance;
}// Of distance
/**
* Get the accuracy of the classifier.
*
* @return The accuracy.
*/
public double getAccuracy() {
// A double divides an int gets another double.
double tempCorrect = 0;
for (int i = 0; i < predictions.length; i++) {
if (predictions[i] == dataset.instance(testingSet[i]).classValue()) {
tempCorrect++;
} // Of if
} // Of for i
return tempCorrect / testingSet.length;
}// Of getAccuracy
/**
* Compute the nearest k neighbors. Select one neighbor in each scan. In
* fact we can scan only once. You may implement it by yourself.
*
* @param paraK
* the k value for kNN.
* @param paraCurrent
* current instance. We are comparing it with all others.
* @return the indices of the nearest instances.
*/
public int[] computeNearests(int paraCurrent) {
int[] resultNearests = new int[numNeighbors];
boolean[] tempSelected = new boolean[trainingSet.length];
double tempDistance;
double tempMinimalDistance;
int tempMinimalIndex = 0;
// Select the nearest paraK indices.
for (int i = 0; i < numNeighbors; i++) {
tempMinimalDistance = Double.MAX_VALUE;
for (int j = 0; j < trainingSet.length; j++) {
if (tempSelected[j]) {
continue;
} // Of if
tempDistance = distance(paraCurrent, trainingSet[j]);
if (tempDistance < tempMinimalDistance) {
tempMinimalDistance = tempDistance;
tempMinimalIndex = j;
} // Of if
} // Of for j
resultNearests[i] = trainingSet[tempMinimalIndex];
tempSelected[tempMinimalIndex] = true;
} // Of for i
System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
return resultNearests;
}// Of computeNearests
/**
* Voting using the instances.
*
* @param paraNeighbors
* The indices of the neighbors.
* @return The predicted label.
*/
public int simpleVoting(int[] paraNeighbors) {
int[] tempVotes = new int[dataset.numClasses()];
for (int i = 0; i < paraNeighbors.length; i++) {
tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
} // Of for i
int tempMaximalVotingIndex = 0;
int tempMaximalVoting = 0;
for (int i = 0; i < dataset.numClasses(); i++) {
if (tempVotes[i] > tempMaximalVoting) {
tempMaximalVoting = tempVotes[i];
tempMaximalVotingIndex = i;
} // Of if
} // Of for i
return tempMaximalVotingIndex;
}// Of simpleVoting
/**
* The entrance of the program.
*
* @param args
* Not used now.
*/
public static void main(String args[]) {
KnnClassification tempClassifier = new KnnClassification("D:/00/data/iris.arff");
tempClassifier.splitTrainingTesting(0.8);
tempClassifier.predict();
System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
}// Of main
}// Of class KnnClassification
运行结果:
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)