目录
1. Dataset类 和 Dataloader类 介绍
2. 查看 Dataset 使用
3. Dataset 使用实例
【重点】4. DataLoader 使用示例
1. Dataset类 和 Dataloader类 介绍
DataSet | Dataloader |
提供一种方式去获取数据及其label 如何获取每一个数据及其label 告诉我们总共有多少的数据 | 为后面的网络提供不同的数据形式 |
在 jupyter notebook 中依次输入这三条语句,可以比较清晰的看到Dataset使用说明
from torch.utils.data import Dataset
help(Dataset)
DataSet??
3. Dataset 使用实例
数据集下载:
https://download.pytorch.org/tutorial/hymenoptera_data.ziphttp://xn--ghqv0hnrco1y6hay8g1u5dq0qia59jfip33gbebd90d
# 作者:要努力,努力,再努力
# 开发时间:2022/5/2 8:56
from PIL import Image
from torch.utils.data import Dataset
import os
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.img_path = os.path.join(self.root_dir, self.label_dir)
self.img_path_list = os.listdir(self.img_path)
def __getitem__(self, idx):
img_item_path = os.path.join(self.root_dir, self.label_dir, self.img_path_list[idx])
img = Image.open(img_item_path)
file_name = 'D:/Workspaces/pycharm/pythontorch/' + self.root_dir + '/ants_labels/' + self.img_path_list[idx] + '.txt'
if not os.path.exists(file_name):
with open(file_name, 'w', encoding='utf-8') as rfile:
rfile.write(self.label_dir)
label = self.label_dir
return img
def __len__(self):
return len(self.img_path_list)
root_dir = 'dataset/train'
ants_label_dir = 'ants_images'
bees_root_dir = 'bees_images'
if __name__ == '__main__':
ants_data = MyData(root_dir, ants_label_dir)
print(len(ants_data))
for i in range(len(ants_data)):
ant_img = ants_data[i]
# ant_img.show()
bees_data = MyData(root_dir, bees_root_dir)
bee_img = bees_data[0]
# bee_img.show()
total = ants_data + bees_data # 继承 Dataset 后可以直接相加
# total[1].show()
【重点】4. DataLoader 使用示例
# 作者:要努力,努力,再努力
# 开发时间:2022/5/4 13:52
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision import datasets
test_set = datasets.CIFAR10(root='../chap04dataset_transform/dataset', train=False, transform=transforms.ToTensor(),
download=True)
# batch_size 每次从测试数据集中取四个数据
# shuffle 是否打乱, true 不一样 false 一样
# num_workers 默认为0 表示主进程 num_workers>0 在window可能会为0
# drop_last 最后的数据是否舍弃,true 舍弃,false 不舍弃
# 打断点 可以看到是RandomSample ,默认用随机抓取的方式
test_loader = DataLoader(dataset=test_set, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
# CIFAR10 中 __getitem__ 返回 img, target
img, target = test_set[0] # 测试数据集中第一张图片
print(img.shape) # 图片大小
print(target) # 图片位置
writer = SummaryWriter('logs')
for ep in range(2):
step = 0
for datas in test_loader:
imgs, targets = datas
# print(imgs.shape) # torch.Size([4, 3, 32, 32]) 4张图片-3通道-32*32
# print(targets) # tensor([8, 1, 3, 6]) 取target 为 8 1 3 6 的图片
writer.add_images('ep:{}'.format(ep), imgs, step) # 注意不是 add_image()
step += 1
writer.close()
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)