手写数字识别分类器

手写数字识别分类器,第1张

目录

使用MNIST数据集,这是一组由美国高中生和人口调查局员工手写的70000个数字的图片。
该数据集分成训练集(前6万张图片)和测试集(最后1万张图片)

1.训练二元分类器

先简化问题,只尝试识别一个数字,比如数字5

from sklearn.datasets import fetch_openml  # 我的sklearn版本为1.0.2
from sklearn.linear_model import SGDClassifier
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# 下载mnist数据集,fetch_openml默认返回的是一个DataFrame,设置as_frame=False返回一个Bunch
# mnist.keys() 可查看所有的键
# data键,包含一个数组,每个实例为一行,每个特征为一列。
# target键,包含一个带有标记的数组
mnist = fetch_openml('mnist_784', version=1, as_frame=False)

# 共有7万张图片,因为图片是28×28像素,所以每张图片有784个特征,每个特征代表了一个像素点的强度,从0(白色)到255(黑色)
x, y = mnist["data"], mnist["target"] # x.shape=(70000, 784),y.shape=(70000,)
y = y.astype(np.uint8) # 注意标签是字符,我们把y转换成整数
# 将数据集分为训练集和测试集
x_train, x_test, y_train, y_test = x[:60000], x[60000:], y[:60000], y[60000:]

# 我们可以看到第一张图片是5
some_digit = x[0]
some_digit_image = some_digit.reshape(28, 28) # 把长为784的一维数组转换成28x28的二维数组
# imshow用于生成图像,参数cmap用于设置图的Colormap,如果将当前图窗比作一幅简笔画,则cmap就代表颜料盘的配色
plt.imshow(some_digit_image, cmap=mpl.cm.binary)
plt.axis("off") # 关掉坐标轴
plt.show()

# 使用随机梯度下降(SGD)分类器,比如Scikit-Learn的SGDClassifier类
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
# max_iter最大迭代次数,random_state用于打乱数据,42表示一个随机数种子
sgd_clf = SGDClassifier(max_iter=1000, random_state=42)
sgd_clf.fit(x_train, y_train_5)  # 在整个训练集上进行训练

# 模型预测
print(sgd_clf.predict([some_digit]))  # 返回true
2.性能测量 ①交叉验证(Cross-validation)

交叉验证就是将拿到的训练数据,分为训练和验证集。首先用训练集对模型进行训练,再利用验证集来测试该模型。
n折交叉验证:将数据分成n份,其中1份作为验证集。然后经过n次测试,每次都更换不同的验证集。得到n个结果,取平均值作为最终结果。

from sklearn.model_selection import cross_val_score
# sgd_clf是分类器,x_train表示训练实例,y_train_5表示每个训练实例对应的标签,cv=3表示3-折交叉验证,accuracy表示使用准确率作为结果的度量指标
cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring="accuracy") 

结果

array([0.95035, 0.96035, 0.9604 ])

但是,如果现在我有一个模型,对每张图片都判定为“非5” ,考虑到所有图片中有约10%的图片是5,这种模型训练下来的准确率也能达到90%左右,但是该模型永远正确无法识别“5‘。
这说明准确率通常无法成为分类器的首要性能指标,特别是当你处理有偏数据集时(即某些类比其他类更为频繁)。

②混淆矩阵

评估分类器性能的更好方法是混淆矩阵,其总体思路就是统计A被识别成B的次数。

from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix

# 获取每次预测的结果,cv=3表示3-折交叉验证
y_train_pred = cross_val_predict(sgd_clf, x_train, y_train_5, cv=3)
# 构造混淆矩阵,y_train_5包含目标类别,y_train_pred包含对应的预测类别
confusion_matrix(y_train_5, y_train_pred)

结果

array([[53892,   687],
       [ 1891,  3530]])
# 真负类(TN) 假正类(FP)
# 假负类(FN) 真正类(TP)
# 真负类: 53892张“非5”图片被正确识别
# 假正类: 687张“非5”图片被错误地识别为“5”
# 假负类: 1891张“5”图片被错误地识别为“非5”
# 真正类:3530张“5”图片被正确识别

一个完美的分类器,它的副对角线的值都为0

# 我们可以假设我们都识别对了,打印出来看看
y_train_perfect_predictions = y_train_5
confusion_matrix(y_train_5, y_train_perfect_predictions)
# 结果
array([[54579,     0],
       [    0,  5421]])
③精度和召回率

精度 (Precision)=TP/(TP+FP) :你认为的该类样本,有多少猜对了(猜的精确性如何)。
召回率 (Recall)=TP/(TP+FN)​​​​​​​:该类样本有多少被找出来了(召回了多少)。
F1分数是精度和召回率的谐波平均值,正常的平均值平等对待所有的值,而谐波平均值会给予低值更高的权重。因此,只有当召回率和精度都很高时,分类器才能得到较高的F1分数。

from sklearn.metrics import precision_score, recall_score, f1_score
precision_score(y_train_5, y_train_pred) # 精度
recall_score(y_train_5, y_train_pred) # 召回率
f1_score(y_train_5, y_train_pred) # f1分数 

F1分数对那些具有相近的精度和召回率的分类器更为有利。
这不一定能一直符合你的期望:在某些情况下,你关心的是精度,而另一些情况下,你关心的是召回率。 例如,
①假设你训练一个分类器来检测儿童可以放心观看的视频,那么你可能更青睐那种拦截了很多好视频(低召回率),但是保留下来的视频都是安全(高精度)的分类器。
②如果你训练一个分类器通过图像监控来检测小偷:你大概可以接受精度只有30%,但召回率能达到99%(当然,安保人员会收到一些错误的警报,但是几乎所有的窃贼都在劫难逃)。

精度/召回率权衡:关键要调整阈值(SGDClassifier分类器使用的阈值是0

怎么理解这幅图:以中间阈值为例,就预测结果而言,右边图片我全部预测为“5”,5个猜对4个,精度:4/5。就整个样本而言,共有6个“5”,我只找到4个“5”,召回率:4/6。
 

 

④ROC曲线

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/717078.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-25
下一篇 2022-04-25

发表评论

登录后才能评论

评论列表(0条)

保存