https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR100.html#torchvision.datasets.CIFAR100
torchvision.datasets中提供了一些经典数据集,其中最为常用的是cifar10/100,mnist,在搓增量学习、领域自适应、主动学习等任务时经常需要打交道。
这里我们以cifar100为例看一下其基本的用法。
首先,下载训练集与测试集:
from torchvision import datasets
train_dataset = datasets.cifar.CIFAR100(root='cifar100', train=True, transform=None, download=True)
test_dataset = datasets.cifar.CIFAR100(root='cifar100', train=False, transform=None, download=True)
可以看到有四个参数:
- root:数据集文件的存储路径。
- train:是否为训练集。
True则表示视为训练集,False表示视为测试集。
- transform:所应用的数据扩充方法。
- download:是否下载。
如果为True且root路径下无相应的数据集文件,则自动从互联网上下载数据集至给定路径。
到这里,严格来讲就算介绍完毕了,因为这里得到的train_dataset和test_dataset都属于torch.utils.data.Dataset对象,剩下来的用法和我们自己手工封装数据集的一致。
现在,我们着重考察下cifar100数据集的结构。
首先,直接用下标去访问一个数据集对象:
print(train_dataset[0])
输出结果如下:
(<PIL.Image.Image image mode=RGB size=32x32 at 0x25568409670>, 19)
可以看到得到的是一个tuple,第一项为image,第二项为ground truth,即图像所属的分类,使用一个int值表示。
在训练计算损失函数的时候,直接使用F.cross_entropy即可,而不需要考虑将int标签转化成独热向量的形式。
而对于image,在实际训练中读数据集时transform一项必然带有transforms.ToTensor(),在这种情况下返回的则是一个Tensor向量以便网络训练。
此外,还有另一种访问方法:
print(train_dataset.data[0])
print(train_dataset.targets[0])
输出结果如下:
[[[255 255 255]
[255 255 255]
[255 255 255]
...
]]]
19
此时,标签仍为int不变,而数据返回的形式为ndarray:
print(train_dataset.data[0].shape)
输出结果如下:
(32, 32, 3)
可以看到cifar100图像的尺寸为32×32×3(3表示通道数)。
这里我们也顺便将这张图像展示:
train_dataset[0][0].show()
因为只有1024像素所以非常小。
查阅类别表可以发现类-19对应cattle(牛),这也与我们的观察相吻合。
现在我们来看另一个问题,可以发现第一张图像的类为19,那么这就表明,整个数据集50000张训练图像并不是按类别顺序进行划分的,即甚至可以在使用dataloader时不打开shuffle。
为了验证这一点,我们直接输出train_dataset.targets中的前10个标签:
print(train_dataset.targets[:10])
结果如下:
[19, 29, 0, 11, 1, 86, 90, 28, 23, 31]
可以发现确实是乱序的。
但是在一些奇怪的任务里面,会要求类别是有序的(实际上我们自己做的数据集也是尽量有序的,需要打乱通过dataloader实现即可),那么这里就看一下怎么去弄。
具体来说,肯定是从targets中所包含的类别信息入手。
首先将其从list转为ndarray方便我们使用numpy去 *** 作:
train_targets = np.array(train_dataset.targets)
那么,比方说我们要取出类别在[10, 19]内的全部样本,就可以把相应的target先取出来。
这里先得到一个长度为50000的包含每个元素是否满足条件的列表:
idx = np.logical_and(train_targets >= 10, train_targets < 20)
print(len(idx))
print(idx)
输出:
50000
[ True False False ... False False False]
然后使用np.where()进行广播,获得具体的下标:
idx = np.where(idx)
print(len(idx[0]))
print(idx[0])
可以看到满足要求的target所对应的下标共有5000个,确实是十分之一:
5000
[ 0 3 13 ... 49977 49981 49991]
最后将这些下标作为索引取出相应的子数据集即可:
train_data = np.array(train_dataset.data)
selected_data = train_data[idx[0]]
selected_target = train_targets[idx[0]]
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)