手写字符识别理解

手写字符识别理解,第1张

手写字符识别理解

手写字符识别是我大学做的第一个深度学习

一、下载代码训练素材

本文章用到的代码链接:GitHub - mivlab/AI_course

二、设置环境

可以先直接运行主要训练代码,然后根据报错安装代码所需要的环境,可以根据网上指南来学习安装不同的环境

三、关于数据

训练集和验证集已经和代码打包好一起下载了。

同时被打包的还有性别识别的数据集,这个本次学习中我们不需用到,但了解原理后只需改变数据集就可以进行性别识别了。

总共有60000个数据,都放在了total.txt中

 可以看到我们用了48000个训练数据

同时也准备了12000个验证数据 

对于数据集的处理,我们可以看到文件夹中代码mnist_loader.py里的函数

在shuffle_split函数中,首先对数据集用random进行了打乱,然后取前面80%作为了训练集,后面20%作为了验证集。 

 四、代码理解

在代码mnist_loader.py中我们定义了一个类,其中包括了三个函数,分别是初始化函数,getitem和len三个函数,其中getitem是获取一张图片,以及他的标签。len是用来计算数据集的规模大小,即计算有多少张图、多少个样

本。

 回到训练素材的代码,训练代码主要就是定义了一个train()的训练函数

首先是做好了数据加载器,其中ToTensor()的作用是把形如255的像素除以256使其变量取在0到1之间。

做好数据集后就要定义网络模型了,代码中在models的文件夹中cnn.py的代码里定义了Net(nn.Module)这个类,并定义了foward()函数来输出数据

 定义完了模型就准备开始训练了。

原始下载的代码中采用了CUDA来进行训练,如果没有显卡或显卡不支持,只需讲代码中的CUDA统一改成CPU,让CPU来运行即可。

在代码的第49行开始正式训练。

 五、运行结果

epoch:1/30就代表总共训练30个epoch,一个epoch代表所有训练数据过一遍,总共有188个批次,Train Loss就是训练函数的值,Acc则是训练函数的准确率。

训练视频如下:

20220109

 

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5700507.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-17
下一篇 2022-12-17

发表评论

登录后才能评论

评论列表(0条)

保存