PyTorch-Tutorials【pytorch官方教程中英文详解】- 3 Datasets&DataLoaders

PyTorch-Tutorials【pytorch官方教程中英文详解】- 3 Datasets&DataLoaders,第1张

PyTorch-Tutorials【pytorch官方教程中英文详解】- 3 Datasets&DataLoaders

【2022真的开始了,第一条博客。】

在文章PyTorch-Tutorials【pytorch官方教程中英文详解】- 2 Tensors中介绍了张量,接下来看看pytorch中读取数据集的Datasets&DataLoaders类。

原文链接:Datasets & DataLoaders — PyTorch Tutorials 1.10.1+cu102 documentation

Code for processing data samples can get messy and hard to maintain; we ideally want our dataset code to be decoupled from our model training code for better readability and modularity. PyTorch provides two data primitives: torch.utils.data.DataLoader and torch.utils.data.Dataset that allow you to use pre-loaded datasets as well as your own data. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.

【处理数据样本的代码可能会变得混乱且难以维护;理想情况下,我们希望数据集代码与模型训练代码解耦,以获得更好的可读性和模块化。PyTorch提供了两个数据原始组件:torch.utils.data.DataLoader和torch.utils.data.Dataset,它们允许您使用预加载的数据集以及您自己的数据。数据集存储样本及其对应的标签,DataLoader在Dataset周围包装了一个可迭代对象,以方便访问样本。】

PyTorch domain libraries provide a number of pre-loaded datasets (such as FashionMNIST) that subclass torch.utils.data.Dataset and implement functions specific to the particular data. They can be used to prototype and benchmark your model. You can find them here: Image Datasets, Text Datasets, and Audio Datasets

【PyTorch域库提供了许多预加载的数据集(如FashionMNIST),它们是torch.util .data. dataset的子类,并实现特定于特定数据的函数。它们可以用于原型和基准测试你的模型。你可以在这里找到它们:Image Datasets、Text Datasets和Audio Datasets】

1 Loading a Dataset

Here is an example of how to load the Fashion-MNIST dataset from TorchVision. Fashion-MNIST is a dataset of Zalando’s article images consisting of 60,000 training examples and 10,000 test examples. Each example comprises a 28×28 grayscale image and an associated label from one of 10 classes.

【下面是一个如何从TorchVision加载Fashion-MNIST数据集的例子。fashionmnist是Zalando文章图像的数据集,包含6万个训练示例和1万个测试示例。每个示例包含28×28灰度图像和来自10个类之一的相关标签。】

We load the FashionMNIST Dataset with the following parameters:

  • root is the path where the train/test data is stored,
  • train specifies training or test dataset,
  • download=True downloads the data from the internet if it’s not available at root.
  • transform and target_transform specify the feature and label transformations

【我们用以下参数加载FashionMNIST数据集:

  • root是训练/测试数据存储的路径,
  • train指定训练或测试数据集,
  • download=True表示如果无法从根目录获取数据,则从Internet下载数据。
  • transform和target_transform指定特性和标签转换】
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

输出:

2 Iterating and Visualizing the Dataset

We can index Datasets manually like a list: training_data[index]. We use matplotlib to visualize some samples in our training data.

【我们可以像列表一样手动索引数据集:training_data[index]。我们使用matplotlib将训练数据中的一些样本可视化。】

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

输出:

3 Creating a Custom Dataset for your files

A custom Dataset class must implement three functions: __init__, __len__, and __getitem__. Take a look at this implementation; the FashionMNIST images are stored in a directory img_dir, and their labels are stored separately in a CSV file annotations_file.

【自定义Dataset类必须实现三个函数:__init__、__len__和__getitem__。看看这个实现;FashionMNIST图像存储在目录img_dir中,它们的标签分别存储在CSV文件annotations_file中。】

In the next sections, we’ll break down what’s happening in each of these functions.

【在下一节中,我们将详细分析每个函数中发生的事情。】

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
__init__

The __init__ function is run once when instantiating the Dataset object. We initialize the directory containing the images, the annotations file, and both transforms (covered in more detail in the next section).

【当实例化Dataset对象时,__init__函数运行一次。我们初始化包含图像、注释文件和两个转换的目录(下一节将详细介绍)。】

The labels.csv file looks like:

【labels.csv文件如下所示:】

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file, names=['file_name', 'label'])
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform
__len__

The __len__ function returns the number of samples in our dataset.

【__len__函数返回数据集中的样本数量。】

Example:

def __len__(self):
    return len(self.img_labels)
__getitem__

The __getitem__ function loads and returns a sample from the dataset at the given index idx. based on the index, it identifies the image’s location on disk, converts that to a tensor using read_image, retrieves the corresponding label from the csv data in self.img_labels, calls the transform functions on them (if applicable), and returns the tensor image and corresponding label in a tuple.

【__getitem__函数从给定索引idx处的数据集加载并返回一个示例。基于索引,它识别图像在磁盘上的位置,使用read_image将其转换为一个张量,从self.Img_labels中的csv数据中检索相应的标签,调用它们上的转换函数(如果适用的话),并返回一个元组中的张量图像和相应的标签。】

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label
4 Preparing your data for training with DataLoaders

The Dataset retrieves our dataset’s features and labels one sample at a time. While training a model, we typically want to pass samples in “minibatches”, reshuffle the data at every epoch to reduce model overfitting, and use Python’s multiprocessing to speed up data retrieval.

【数据集每次检索一个样本的数据集的特征和标签。在训练模型时,我们通常希望以“小批量”传递样本,在每个epoch重新打乱数据以减少模型过拟合,并使用Python的多进程处理来加速数据检索。】

DataLoader is an iterable that abstracts this complexity for us in an easy API.

【DataLoader是一个可迭代对象,它通过一个简单的API为我们抽象了这种复杂性。】

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
5 Iterate through the DataLoader

We have loaded that dataset into the DataLoader and can iterate through the dataset as needed. Each iteration below returns a batch of train_features and train_labels (containing batch_size=64 features and labels respectively). Because we specified shuffle=True, after we iterate over all batches the data is shuffled (for finer-grained control over the data loading order, take a look at Samplers).

【我们已经将该数据集加载到DataLoader中,并可以根据需要遍历该数据集。下面的每次迭代都返回一批train_features和train_labels(分别包含batch_size=64个特性和标签)。因为我们指定了shuffle=True,所以在我们遍历所有批次之后,数据就会被打乱(对于更细粒度的数据加载顺序的控制,请查看Sampler)。】

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: 
")

输出结果:
 
Feature batch shape: torch.Size([64, 1, 28, 28]) Labels batch shape: torch.Size([64]) Label: 7
     6 Further Reading
  • torch.utils.data API

说明:记录学习笔记,如果错误欢迎指正!写文章不易,转载请联系我。

欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/zaji/5689287.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-17
下一篇 2022-12-17

发表评论
请登录后评论...
登录
后才能评论 提交

评论列表(0条)
保存
{label}