- PyTorch深度学习(五):PyTorch进阶与核心(上)
- 一、torch.utils.data.Dataset类——数据载入
- 1.Dataset类简介
- 2.继承自Dataset的子类,重写方法
- 二、torch.utils.tensorboard——可视化工具
- 1.函数图表可视化
- 2.图片可视化
- 三、torchvision.transforms——图片的变换
- 1.图像转换
- 2.图像裁剪
- 3.旋转和翻转
- 4.数据增强
- 四、torchvision——数据集资源和模型
- 1.torchvision的数据集
- 2.torchvision的模型
- 五、DataLoader——数据集洗牌
首先介绍 Python 学习中两大法宝函数:
dir() 用于显示库/类/对象中可以使用的方法/属性
help() 用于显示方法的官方解释
例如:
一、torch.utils.data.Dataset类——数据载入 1.Dataset类简介
Dataset 提供一种方式去获取 数据data 及其 标签label,并且提供各种数据量指标,我们需要导入这个类:
from torch.utils.data import Dataset
其中, u t i l s tt utils utils 表示工具的意思;由 h e l p ( D a t a s e t ) tt help(Dataset) help(Dataset) 的官方文档可知,当子类需要继承 Dataset 这个类时,需要重写 __getitem__() 和 __len__() 这两个方法,前者用于获取集合中某一个样本的 数据data 及其 标签label,后者用于统计数据集的长度
这里提供一个【公开的花卉分类数据集】用于练习:GitHub上也可以找到
【公开数据集——花卉数据集】,提取码:rbwg
该数据集共分为 17 类不同的花卉,文件的分布结构如下:
F
l
o
w
e
r
s
{
t
r
a
i
n
{
1
{
i
.
j
p
g
(
i
∈
N
1
)
2
{
i
.
j
p
g
(
i
∈
N
2
)
⋮
17
{
i
.
j
p
g
(
i
∈
N
17
)
t
e
s
t
{
1
{
i
.
j
p
g
(
i
∈
N
1
′
)
2
{
i
.
j
p
g
(
i
∈
N
2
′
)
⋮
17
{
i
.
j
p
g
(
i
∈
N
17
′
)
tt Flowersbegin{cases}tt train begin{cases}tt 1 begin{cases} i.jpg(iin N_1) end{cases}\ tt 2 begin{cases} i.jpg(iin N_2) end{cases}\ tt vdots\ tt 17begin{cases} i.jpg(iin N_{17}) end{cases}end{cases}\ tt test begin{cases}tt 1 begin{cases} i.jpg(iin N'_1) end{cases}\ tt 2 begin{cases} i.jpg(iin N'_2) end{cases}\ tt vdots\ tt 17begin{cases} i.jpg(iin N'_{17}) end{cases}end{cases}\end{cases}
Flowers⎩⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎨⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎧train⎩⎪⎪⎪⎪⎪⎪⎨⎪⎪⎪⎪⎪⎪⎧1 {i.jpg(i∈N1)2 {i.jpg(i∈N2)⋮17{i.jpg(i∈N17)test ⎩⎪⎪⎪⎪⎪⎪⎨⎪⎪⎪⎪⎪⎪⎧1 {i.jpg(i∈N1′)2 {i.jpg(i∈N2′)⋮17{i.jpg(i∈N17′)
可以看到,训练集 t r a i n tt train train 和测试集 t e s t tt test test 分别有 17 类不同的花卉,每一类中都含有数量不等的 . J P G tt .JPG .JPG 图片
2.继承自Dataset的子类,重写方法由于 __getitem__() 方法的作用是获取所有数据集及其 标签label,因此我们需要读取数据集目录下的所有图片:
import os # 导入os image_path_list = os.listdir("./Flowers/train/1") # 获取目录下的所有文件名,并返回列表 print(image_path_list)
尝试将图片显示出来:
import cv2 as cv # 导入opencv-python path = os.path.join("./Flowers/train/1", image_path_list[0]) # 路径拼接 image = cv.imread(path , flags=1) # 读取图片,打开方式为全彩 cv.imshow('Flowers_train_1_1', image) # 显示图片 cv.waitKey(0) # 停留时间无限长
构建自己的数据集子类:
用这个类创建的对象通过下标即可访问到某一类的一张图片及其 标签label: i m a g e , l a b e l = o b j e c t tt image, label = object image,label=object;还可以通过 len(object) 来获取图片的数量
from torch.utils.data import Dataset import os import cv2 as cv class Mydata(Dataset): def __init__(self, data_dir_path, label_dir_name): """初始化类属性""" self.data_dir_path = data_dir_path # 数据所在目录的路径 self.label_dir_name = label_dir_name # 某一类的目录名 self.path = os.path.join(self.data_dir_path, # 合并路径 self.label_dir_name) self.image_path = os.listdir(self.path) # 遍历某一类的目录,获取所有图片名 def __getitem__(self, index): """获取一张图片及其标签""" image_name = self.image_path[index] # 获取图片名 image_item_path = os.path.join(self.data_dir_path, # 合并路径 self.label_dir_name, image_name) image = cv.imread(image_item_path, flags=1) # 读取图片,返回对象 label = self.label_dir_name # 获取图片标签 return image, label def __len__(self): """统计图片数量""" return len(self.image_path)
尝试访问训练集中 label 为 1 1 1(第一类)的第一张图片:
data_dir_path = "./Flowers/train" label_dir_name = "1" train1 = Mydata(data_dir_path, label_dir_name) image, label = train1[0] cv.imshow("Flowers_train_1_1", image) cv.waitKey(0)
还可以分别查看每个 label 里图片的数目,可以看到每个 label 都有60张图片
对于
M
y
d
a
t
a
tt Mydata
Mydata 返回的对象,还可以通过
+
tt +
+ 来合并起来,如下:
每个 label 有
60
60
60 张图片,共
17
17
17 个 label,因此合并后的对象
X
_
t
r
a
i
n
tt X_train
X_train 共有
17
×
20
=
1020
17×20=1020
17×20=1020 张图片,包含了所有的训练集
二、torch.utils.tensorboard——可视化工具
tensorboard 是一个 PyTorch 中的可视化工具(tensorboard 原本是 Tensorflow 的工具,后来给 PyTorch 也加入了这个工具),它可以用来展示网络图、张量的指标变化、张量的分布情况等。特别是在训练网络的时候,可以设置不同的参数(例如:权重、偏移、卷积层数、全连接层数等),使用tensorboard 可以很直观的帮我们进行参数的选择。tensorboard 通过运行一个本地服务器,来监听 6006 6006 6006 端口;在浏览器发出请求时,它会自动分析训练时的数据,绘制训练过程中的可视化图像
tensorboard 需要单独安装,因此需要在对应的虚拟环境下进行安装:
conda install tensorboard
这里面主要使用的是 SummaryWriter 这个函数
from torch.utils.tensorboard import SummaryWriter
如果导入时提示版本过旧,则需要更新:
conda upgrade tensorboard1.函数图表可视化
writer = SummaryWriter(log_dir=None, comment=’’, purge_step=None, max_queue=10, flush_secs=120, filename_suffix=’’)
用于创建一个 Tensorboard 日志的文件夹,并返回一个对象
log_dir (string) 存储 tensorboard 日志等文件的文件夹的位置,default=“runs/当前日期时间”,它在每次运行后更改,习惯上我们通常存放在 “logs” 目录下
comment (string) 在 log_dir 为默认值时,comment 是默认值后添加的后缀,即文件名为:"runs/当前日期时间-comment"
flush_secs (int) 刷新的频率,单位为:秒,默认为 2 2 2 分钟
writer.add_scalar(tag, scalar_value, global_step=None, walltime=None)
向 tensorboard 日志中添加 “一个标量” 的可视化图表,该方法一般用来记录训练过程的 Loss、Accuracy、Learning Rate 等数值的变化,可视化训练过程
tag (string) 图表的标题(也是事件的名称,用于标识一个图表)
scalar_value (float or string/blobname) 相当于因变量/纵坐标,一般取浮点数
global_step 相当于自变量/横坐标,一般指训练的步长
writer.add_scalars(main_tag, tag_scalar_dict, global_step=None, walltime=None)
向 tensorboard 日志中添加 “多个标量” 的可视化图表,与 a d d _ s c a l a r ( ) tt add_scalar() add_scalar() 不同的是,该方法主要在一个图表中绘制多条曲线
main_tag (string) 图表的标题(也是事件的名称,用于标识一个图表)
tag_scalar_dict ((dict)) 多条曲线需要用字典键值对来表示,如:tag_scalar_dict={“sin”: np.sin(x), “cos”: np.cos(x), “tan”: np.tan(x)},即多个因变量/纵坐标
global_step 多个因变量共同使用一个自变量/横坐标,一般指训练的步长
writer.close() 结束使用
在代码执行后,还需要打开 tensorboard 日志才能看到结果(tensorboard 通过运行一个本地服务器,来监听 6006 6006 6006 端口;在浏览器发出请求时,它会自动分析训练时的数据,绘制训练过程中的可视化图像),在终端输入以下指令:
tensorboard --logdir=logs
如图,进入 http://localhost:6006/ 即可查看日志
- 注: 该指令需要使终端处于当前虚拟环境下,路径可以是相对路径,也可以是绝对路径;其中,logs 表示路径
- 绘制出的图表会自动连续化
- tag 或 main_tag 用于标识一个事件/图表,如果在调用 add_scalar() 或 add_scalars() 出现重复使用同一 tag 或 main_tag 时(也包括 Jupyter Notebook 重复编译的情况),则曲线就会出现在同一张图表上,并且会自动连续化,这样出来的效果会十分乱,解决方法:(1) 删除 logs/ 下的文件、(2) SummaryWriter() 指定新的文件夹
- 在打开日志后,图表下方会出现三个小按钮,其作用如下:
例如: 分别调用 add_scalar() 和 add_scalars(),分别绘制 y = x 2 y=x^2 y=x2 和 y = s i n ( x ) 、 y = c o s ( x ) y=sin(x)、y=cos(x) y=sin(x)、y=cos(x)
from torch.utils.tensorboard import SummaryWriter import numpy as np for i in range(100): writer.add_scalar("y=x^2", i**2, i) # 绘制y=x^2曲线 writer.add_scalars("trigonometric functions", # 分别绘制sin(x)和cos(x)曲线 {"sin": np.sin(i), "cos": np.cos(i)}, i) writer.close()
writer.add_image(tag, img_tensor, global_step=None, walltime=None, dataformats=‘CHW’)
向 Tensorboard 日志中添加一张或多张图片,该方法一般用于实时观察生成式模型的生成效果、可视化分割、目标检测的效果
tag (string) 图片的标题(也是事件的名称,用于标识一个图片显示项)
img_tensor (torch.Tensor, numpy.array, or string/blobname) 图片矩阵,类型可以是:PyTorch 张量、numpy.array…
global_step (int) 当多张图片共用一个 tag 时,global_step 就表示图片的顺序,它一般代表每个训练步长的输出图片
dataformats (string, optional) 可选的,图像数据的格式,default=‘CHW’,即 ( C h a n n e l × H e i g h t × W i d t h ) tt (Channel×Height×Width) (Channel×Height×Width),如: ( 3 × 224 × 224 ) (3×224×224) (3×224×224)
writer.add_image(tag, img_tensor, global_step=None, walltime=None, dataformats=‘NCHW’)
向 Tensorboard 日志中添加一组或多组小批量图片,每组为一个 batch_size 小批量图片堆,tensorboard 会将一个 batch_size小批量图片拼接在一起来显示
注意:一组小批量图片的形状为 ( b a t c h _ s i z e , C , H , W ) ({tt batch_size},C,H,W) (batch_size,C,H,W),这一个批量的图片会被拼接在一起同时显示
tag (string) 组图标题(也是事件的名称,用于标识一个组图显示项)
img_tensor (torch.Tensor, numpy.array, or string/blobname) 形状为一组小批量图片 ( b a t c h _ s i z e , C , H , W ) ({tt batch_size},C,H,W) (batch_size,C,H,W) 或多组小批量图片 ( K , b a t c h _ s i z e , C , H , W ) (K,{tt batch_size},C,H,W) (K,batch_size,C,H,W)
global_step (int) 当多组小批量图片共用一个 tag 时,global_step 就表示组的顺序
dataformats (string, optional) 可选的,图像数据的格式,default=‘NCHW’,即 ( B a t c h _ s i z e × C h a n n e l × H e i g h t × W i d t h ) tt (Batch_size×Channel×Height×Width) (Batch_size×Channel×Height×Width),如: ( 64 × 3 × 224 × 224 ) (64×3×224×224) (64×3×224×224)
例如: 调用 add_image() 来可视化两张图片
from torch.utils.tensorboard import SummaryWriter import numpy as np import cv2 as cv import torch writer = SummaryWriter("logs") image1 = cv.imread("./Flowers/train/1/1.jpg", flags=1) # 读取图片,opencv的图片打开通道默认为BGR image1 = cv.cvtColor(image1, cv.COLOR_BGR2RGB) # 将BGR转换为RGB图像 image1 = torch.tensor(image1) # 将矩阵转换为张量 writer.add_image("Flowers_train_1", image1, 1, dataformats='HWC') # 顺序为1 image2 = cv.imread("./Flowers/train/1/1.jpg", flags=1) # 读取图片,opencv的图片打开通道默认为BGR image2 = cv.cvtColor(image2, cv.COLOR_BGR2RGB) # 将BGR转换为RGB图像 image2 = torch.tensor(image2) # 将矩阵转换为张量 writer.add_image("Flowers_train_1", image2, 2, dataformats='HWC') # 顺序为2
- 注: OpenCV 读取的图片的通道顺序为 BGR,且形状分布为:shape=(H,W,C),数据类型为 numpy.ndarray
可以看到,图片上方有一个滑块
三、torchvision.transforms——图片的变换
transforms 可以将一种特定格式的图片转换为另一种特定格式的图片,其次还包括:图片裁剪、旋转和翻转、数字图像处理和数据增强等功能,transforms 需要在 torchvision 库中导入:
from torchvision import transforms
注意: Transforms 中的很多工具一般都是一个类,它们具有 __call__ 特性(即使用该类实例化的对象可以像调用函数那样被调用,函数功能由类中的 __call__( ) 方法决定),当需要使用这些工具时,需要先使用指定的类来实例化一个对象,再调用对象
1.图像转换2.图像裁剪创建对象:object = transforms.ToTensor()
调用对象:object(image)——将 PIL.Image 类型或 numpy.ndarray 类型的图片转换成张量类型
注意:在调用该对象进行转换时,会将形状为 (H,W,C) 的图片矩阵形状自动转换为 (C,H,W) 的形状,并且将数值归一化到 [ 0 , 1 ] [0,1] [0,1] !
image 图片对象,可以是 PIL.Image 类型或 numpy.ndarray (OpenCV) 类型的图片
例如:from torchvision import transforms import cv2 as cv image = cv.imread("./Flowers/train/1/1.jpg", flags=1) # 读取图片 image = cv.cvtColor(image, cv.COLOR_BGR2RGB) # 通道转换为RGB totensor = transforms.ToTensor() # 创建一个对象 image = totensor(image) # 将numpy.ndarray类型的图片转换为张量 print(type(image)) # 查看转换后的类型 print(image)可以看到类型为:torch.Tensor,并且已归一化到 [ 0 , 1 ] [0,1] [0,1]
创建对象:object = transforms.Normalize(mean, std, inplace=False)
mean (sequence) 均值,需要输入一个序列,分别代表每个通道标准化时的均值,如:[0.5, 0.5, 0.5]
std (sequence) 标准差,需要输入一个序列,分别代表每个通道标准化时的标准差,如:[0.5, 0.5, 0.5]
inplace (bool,optional) 可选的,是否直接原地替换
调用对象:object(tensor)——用均值和标准差来标准化张量图像,标准化公式为: o u t p u t = ( i n p u t − m e a n ) / s t d output=(input-mean)/std output=(input−mean)/std,三个通道分别计算
注意:标准化可以加快模型的收敛,一般要将通道标准化到 [ − 1 , 1 ] [-1,1] [−1,1],因此一般选择:mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5]
tensor 张量类型的图片 (C,H,W)
创建对象:object = transforms.Resize(size, interpolation=InterpolationMode.BILINEAR)
size (sequence or int) 图像调整后的大小,输入是一个序列或整型
如果 size 是一个序列:(H,W),将图像大小调整为该参数;
如果 size 是一个整形:int,则图像较短的边缘将调整为该数字,如果 height > width,那么图像将重新缩放到 (size × height / width,size)
调用对象:object(image)——将输入图像的大小调整为给定的大小,如果图像是张量,则可以是 (…,H,W) 的形状,其中 “…” 可以表示任意数量的更高维度
image 图片对象,可以是 PIL.Image 类型或 torch.Tensor 张量,目前无法处理 numpy.ndarray (OpenCV) 类型的图片
上面这些类的使用都只针对图像进行某一转换,如果想让图像实现多种不同的转换,则 PyTorch 提供了一个 Compose 类,它专门把上面这些类所实例化的对象串联起来,以使图像实现多个不同的转换
创建对象:object = transforms.Compose(transforms)
transforms (list or transforms object) 可以是由上述类所实例化的对象组成的列表,也可以是单个对象,一般都是列表
调用对象:object(image)——可以将多个“图片变换”所实例化的对象组合在一起,来使图像实现多种不同的转换
image 图片对象,可以是 PIL.Image 类型或 numpy.ndarray (OpenCV) 类型的图片,还可以是 torch.Tensor 张量
例如:# 两个transforms对象 totensor = transforms.ToTensor() resize = transforms.Resize([512, 512]) print(type(image), image.shape) # 串联起来 compose = transforms.Compose([totensor, resize]) image = compose(image) print(type(image), image.shape)
transforms.Compose(transforms) 对后面的对象同样适用
3.旋转和翻转 4.数据增强创建对象:object = transforms.CenterCrop(size)
size (sequence or int) 预期裁剪后的大小,输入一个序列或整型
如果 size 是一个序列:(H,W),将裁剪后的大小为该参数,若 (H,W) 大于原图像,则使用 0 0 0 来填充空白位置;
如果 size 是一个整形:int,则裁剪为宽为该参数的正方形;
如果 size 是一个长度为 1 1 1 的序列:[N],则和整型的情况一致
调用对象:object(image)——裁剪图像的中心部分,如果图像是张量,则可以是 (…,H,W) 的形状,其中 “…” 可以表示任意数量的更高维度
image 图片对象,可以是 PIL.Image 类型或 torch.Tensor 张量,目前无法处理 numpy.ndarray (OpenCV) 类型的图片
创建对象:object = transforms.RandomAffine(degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0)
degrees (sequence or number) 要旋转的度数范围,如果 degrees 是一个数字而不是序列,则度数范围将是 (-degrees,+degrees),设置为 0 0 0 可停用旋转
translate (tuple, optional) 可选的,元素分别为水平平移和垂直平移的最大倍率的元组,如:translate=(a,b),则表示 Δ x Delta x Δx 在范围 [-width×a,+width×a] 中随机采样进行水平位移,并且 Δ y Delta y Δy 在范围 [-height×b,+height×b] 中随机采样进行垂直位移
scale (tuple, optional) 可选的,缩放系数间隔,如:scale=(a,b),将从范围 (a,b) 中随机采取缩放系数进行缩放,默认保持原比例
shear (sequence or number, optional) 可选的,可以选择的度数范围
如果 shear 是一个数字,则执行平行于 x x x 轴角度范围在 (-shear,+shear) 的剪切;
如果 shear 是 2 2 2 个元素的序列,则执行平行于 x x x 轴角度范围在 (shear[0],shear[1]) 的剪切;
如果 shear 是 4 4 4 个元素的序列,则执行平行于 x x x 轴角度范围在 (shear[0],shear[1]) 和平行于 y y y 轴角度范围在 **(shear[2],shear[3]) 的剪切;默认情况下不会应用该剪切
fill (sequence or number) 转换后的图像区域外的像素填充值,default=“0”,如果给定一个值,则该值会分别应用于所有像素点
fillcolor (sequence or number, optional) 已弃用
resample (int, optional) 已弃用
调用对象:object(image)——保持图像中心不变的情况下对图像进行随机的仿射变换
image 图片对象,可以是 PIL.Image 类型或 torch.Tensor 张量,目前无法处理 numpy.ndarray (OpenCV) 类型的图片
四、torchvision——数据集资源和模型 1.torchvision的数据集
导入相关库:
import torchvision
torchvision 数据集在 torchvision.datasets 库中,可直接在线下载的,我们可以在 PyTorch 的官网中看到,torchvision.datasets 中收录了许多数据集,涵盖了图像分类、目标检测和语义分割等数据集,详细文档可见 【PyTorch官方文档】
例如: 以下载 CIFAR 数据集为例
object(PIL.Image.Image) = torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
加载(需要先下载) CIFAR 数据集到内存中
root (string) CIFAR 数据集的的路径,若 download=True,且指定路径找不到目标,则在该目录中下载该数据集
train (bool, optional) 是否加载训练集,若 train=True,则加载的是训练集,若 train=False,则加载的是测试集,default=True
transform (callable, optional) 可选的,加载的数据集需要进行的图片的变换,可以是 torchvision.transforms 中的对象
target_transform (callable, optional) 可选的,加载的数据集的标签需要进行的变换
download (bool, optional) 如果 download=True,则下载数据集并存放在 root 中,如果数据集已经下载,则不会再次下载,在解释器中的下载速度较慢,建议直接在 PyTorch 官网中下载:【PyTorch官方文档】(阅读数据集加载函数的源码,其中可以找到下载网站)
返回值 返回 PIL.Image.Image 对象
✅提示:其他数据集的下载参数也是类似的
object[n] 可以通过下标来顺序访问数据集,表示第 n 张图片对象,该对象包含了图像数据 data 和标签 label
object[n][0] 表示第 n 张图片的 data
object[n][1] 表示第 n 张图片的 label
例如:
object.classes 可以访问所有类的名称
例如:
每一个类所对应的图片如下图所示:
例如: 通过 tensorboard 显示训练集的前
10
10
10 张图片
import torchvision from torchvision import transforms from torch.utils.tensorboard import SummaryWriter totensor = transforms.ToTensor() # transforms的ToTensor对象 train_set = torchvision.datasets.CIFAR10("./CIFAR10", train=True, transform=totensor, download=True) # 训练集 test_set = torchvision.datasets.CIFAR10("./CIFAR10", train=False, transform=totensor, download=True) # 测试集 writer = SummaryWriter("CIFAR10") # 建立tensorboard日志 for i in range(10): image, label = train_set[i] writer.add_image("train_set", image, i)
由于图片的像素是
32
×
32
32×32
32×32 的,因此清晰度较低
torchvision 中提供了大量的现成的经过预训练的神经网络模型
五、DataLoader——数据集洗牌
导入相关库:
from torch.utils.data import DataLoader
object(torch.utils.data.dataloader.DataLoader) =
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)
对数据集进行打包处理
dataset (Dataset) 需要打包的数据集
batch_size (int, optional) 每次获取数据样本时的批量数
shuffle (bool, optional) 决定是否打乱数据集,若 shuffle=True,则打乱数据集,default=False
sampler (Sampler or Iterable, optional) 当 shuffle=True 时,调用该随机采样器
batch_sampler (Sampler or Iterable, optional)
num_workers (int, optional) 读取数据的进程个数,当数据量较大时,可以使用多个进程来读取数据, 0 0 0 表示默认使用主进程加载数据
drop_last (bool, optional) 若 batch_size 和数据集总数不是整除关系,这意味着最后剩余的一些数据无法组成一个 batch_size,而 drop_last 就决定了是否舍弃剩余的数据集,若 drop_last=True,则舍弃剩余数据集;若 drop_last=False,则不舍弃
注意:打包后会将原数据 ( C , H , W ) (C,H,W) (C,H,W) 打包成 ( b a t c h _ s i z e , C , H , W ) ({tt batch_size},C,H,W) (batch_size,C,H,W),并将标签打包成一维张量: ( l a b e l 1 , l a b e l 2 , . . . , l a b e l b a t c h _ s i z e ) (tt label_1,label_2,...,label_{batch_size}) (label1,label2,...,labelbatch_size);
打包后的数据集通过 tensorboard 来显示图片时需要调用 add_images() 而不是 add_image()
例如: 使用 DataLoader 处理 CIFAR10 测试集(为了加载速度更快,我们使用测试集来举例),要求 batch_size 为 4 4 4
import torchvision from torchvision import transforms from torch.utils.data import DataLoader # 加载CIFAR10数据集 totensor = transforms.ToTensor() train_set = torchvision.datasets.CIFAR10("./CIFAR10", train=True, transform=totensor, download=True) test_set = torchvision.datasets.CIFAR10("./CIFAR10", train=False, transform=totensor, download=True) test_loader = DataLoader(dataset=test_set, batch_size=4, shuffle=True, num_workers=0, drop_last=False) for data in test_loader: # 打印数据集 image, label = data print("image:", image.shape, "label:", label)
输出的结果如下(输出的结果过多,这里只展示部分结果),可见数据已打包成
(
4
,
C
,
H
,
W
)
(4,C,H,W)
(4,C,H,W)
我们使用 tensorboard 来可视化一下,观察两次 epoch 中,测试集的顺序是否相同:
writer = SummaryWriter("logs") for epoch in range(2): step = 0 for data in test_loader: images, labels = data writer.add_images("test_loader", images, step) step = step + 1
通过对比可以发现,每个 epoch 所使用的测试集的顺序是不一样的
参考资料:
[1]PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】
[2]PyTorch官网文档
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)