朴素贝叶斯分类器是基于贝叶斯定理所实现的一种对于数据分类的算法,本文章将先对贝叶斯定理和对于机器学习(数据分类)上的作用进行简单介绍,然后通过代码(python)实现朴素贝叶斯分类器识别MNIST数据集的手写数字。
贝叶斯定理
在机器学习中, 通常为数据(feature),为数据预测的模型(label)。
对于贝叶斯定理每一项的解释:
: 先验概率,在观测具体数据前预测某种数据标签(label)的概率。
: 似然值,在预测h下数据的概率。
: 数据出现的概率。
: 后验概率,在观测到数据后预测某种数据标签(label)的概率。
贝叶斯分类器的目标:
挑出能最大化后验概率的数据标签,通俗点解释就是贝叶斯分类器看过了训练是的数据特征(feature)和对应的数据标签(label),对于给出的一组数据特征,贝叶斯分类器可以通过概率计算(基于贝叶斯定理),找出最可能符合的数据标签,也就是能最大化后验概率的标签。
最大后验估计(Maximun A Posteriori)
独立于 ,即可以忽略。
最大似然估计(Maximum Likelihood)
在最大后验估计的基础上,假设每种标签出现的概率相等,即,忽略,则目标变为:
例子
令数据特征为 ,数据标签为 ,假设特征有四种,标签有两种,则贝叶斯分类器的目标为:
其中,
则有,
m-估计
由于我们是通过训练的数据计算出来的一系列概率,如果数据的某项特征在训练数据中从未出现,而需要预测标签的数据出现了这种特征,那通过贝叶斯分类器计算出来的项就会为0,从而导致分类器的预测出现问题,m-估计旨在解决该情况。
通常情况的计算的方法为:
,分子为该特征的个数,分母为特定标签下数据的数量。
使用m-估计的计算方法为:
,代表一种先验估计,如果该特征有 种可能的取值,则 ,是手动设置的参数,表示这种先验估计的比重,值越大代表越倾向于使用这种先验估计的方式。
MNIST手写数字识别应用
简单提取数据前1000张图片,900张作为训练分类器的数据,100张作为测试集来测试分类器。
数据预处理
该例子中下载的MNIST数据集如下,将数据除以255使每个像素的值位于0,1之间。该代码获取的X,y数据均为10000组,由于仅用于演示算法,我们只取了其中的前1000张保存为npy文件作为后续使用。
""" File name: lab_supp.py """ import numpy as np from sklearn.datasets import fetch_openml # load the npz file that saved def load_local_data(filename): lab_data = np.load(filename, allow_pickle=True) # return the saved data return lab_data['arr_0'], lab_data['arr_1'] def load_remote_data(): # directly fetch the data from the website X, y = fetch_openml('mnist_784', data_home='./', return_X_y=True) X = X / 255. # save the first 1000 data as npz file X_small = X[:1000] y_small = y[:1000] np.savez("1k_data.npz", X_small, y_small) # return whole data (10000) return X, y
算法应用部分
首先把0-1的像素数据化为0或1的像素数据(方便处理),对于给定的训练数据(特征为28*28的图片,即空间大小为784,像素点的可能取值为0,1),计算:
- ,0-9共十个数字在训练集中的概率分布。(代码中体现为变量 label_matrix)
- ,当数据标签为时,像素的取值(0或1),若使用m-估计, (因为像素只有0,1两种取值),。(代码中体现为变量 feature_matrix)
- 对于 ,计算后验概率:,选出最大的 ,即为预测的数字。(代码中体现为函数 predict)
以下为代码实现:
import lab_supp # lab_supp.py file, used for data preprocessing import numpy as np from collections import Counter # categories of Label, totally 10 (10 numbers) LABEL_CAT = 10 # m-estimate parameter PRIOR_EST = 0.5 m = 3 # normalize the picture pixel to 0 or 1 def normalization_0_1(imgs): num = imgs.shape[0] dim = imgs.shape[1] result = np.zeros((num, dim)) for i in range(num): for j in range(dim): if imgs[i][j] != 0: result[i][j] = 1 else: result[i][j] = 0 return result # calculate required matrix for Naive Bayesian def cal_prob(train_features, train_labels): num = train_features.shape[0] dim = train_features.shape[1] label_matrix = np.zeros((LABEL_CAT,)) feature_matrix = np.zeros((LABEL_CAT, dim)) label_stat = Counter(train_labels) for i in range(LABEL_CAT): label_matrix[i] = label_stat[str(i)] / num for label in range(LABEL_CAT): indexes = np.where(train_labels == str(label)) # get the index of specific label num_label = indexes[0].shape[0] for j in range(dim): times = 0 for index in indexes[0]: if train_features[index][j] == 0: times += 1 feature_matrix[label][j] = (times + m * PRIOR_EST) / (num_label + m) return label_matrix, feature_matrix # prediction process by using required matrix def predict(img, label_matrix, feature_matrix): possibility = np.ones((LABEL_CAT,)) possibility = possibility * label_matrix for i in range(LABEL_CAT): index = 0 for pixel in img: if pixel == 0: possibility[i] = possibility[i] * feature_matrix[i][index] else: possibility[i] = possibility[i] * (1 - feature_matrix[i][index]) index += 1 label = np.where(possibility == max(possibility)) return str(label[0][0]) # program entrance if __name__ == '__main__': X, y = lab_supp.load_local_data('1kpic.npz') # load data X_ = normalization_0_1(X) # normalize picture # split train & test set - 0.9,0.1 X_train = X_[0:900] y_train = y[0:900] X_test = X_[900:] y_test = y[900:] # calculate probability matrix m_label, m_feature = cal_prob(X_train, y_train) # predict the test set num = X_test.shape[0] result = np.zeros((num,), dtype=str) for i in range(num): result[i] = predict(X_test[i], m_label, m_feature) # calculate the error rate error = 0 for i in range(num): if result[i] != y_test[i]: error += 1 error_rate = error / num print("Error rate: ", error_rate)
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)