手写字符识别是我大学做的第一个深度学习
本文章用到的代码链接:GitHub - mivlab/AI_course
二、设置环境
可以先直接运行主要训练代码,然后根据报错安装代码所需要的环境,可以根据网上指南来学习安装不同的环境
三、关于数据集
训练集和验证集已经和代码打包好一起下载了。
同时被打包的还有性别识别的数据集,这个本次学习中我们不需用到,但了解原理后只需改变数据集就可以进行性别识别了。
总共有60000个数据,都放在了total.txt中
可以看到我们用了48000个训练数据
同时也准备了12000个验证数据
对于数据集的处理,我们可以看到文件夹中代码mnist_loader.py里的函数
在shuffle_split函数中,首先对数据集用random进行了打乱,然后取前面80%作为了训练集,后面20%作为了验证集。
四、代码理解
在代码mnist_loader.py中我们定义了一个类,其中包括了三个函数,分别是初始化函数,getitem和len三个函数,其中getitem是获取一张图片,以及他的标签。len是用来计算数据集的规模大小,即计算有多少张图、多少个样
本。
回到训练素材的代码,训练代码主要就是定义了一个train()的训练函数
首先是做好了数据加载器,其中ToTensor()的作用是把形如255的像素除以256使其变量取在0到1之间。
做好数据集后就要定义网络模型了,代码中在models的文件夹中cnn.py的代码里定义了Net(nn.Module)这个类,并定义了foward()函数来输出数据
定义完了模型就准备开始训练了。
原始下载的代码中采用了CUDA来进行训练,如果没有显卡或显卡不支持,只需讲代码中的CUDA统一改成CPU,让CPU来运行即可。
在代码的第49行开始正式训练。
五、运行结果
epoch:1/30就代表总共训练30个epoch,一个epoch代表所有训练数据过一遍,总共有188个批次,Train Loss就是训练函数的值,Acc则是训练函数的准确率。
训练视频如下:
20220109
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)