- torch.utils.data.DataLoader
- torch.utils.data.Dataset
DataLoader( dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, timeout=0, worker_init_fn=None, multiprocessing_context=None )
功能:构建可迭代的数据装载器
dataset:Dataset类,决定数据从哪读取几如何读取
batchsize:批大小
num_works:是否多进程读取数据
shuffle:每个epoch是否乱序
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
- 读哪些数据?->Sampler输出的index
- 从哪读数据?->Dataset中的data_dir
- 怎么读数据? ->Dataset中的getitem
功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()
getitem:接收一个索引,返回一个样本
Dataset的目标是根据输入的索引输出对应的image和label,而且这个功能是要在__getitem__()函数中完成的,所以当自定义数据集时,首先要继承Dataset类,还要复写__getitem__()函数
例子:
import torch from torch import nn from torch.utils.data import Dataset,DataLoader class My_dataset(Dataset): def __init__(self): super().__init__() # 使用sin函数返回10000个时间序列,如果不自己构造数据,就使用numpy,pandas等读取自己的数据为x即可。 # 以下数据组织这块即可以放在init方法,也可以放在getitem方法里 self.x = torch.randn(1000,3) self.y = self.x.sum(axis=1) self.src, self.trg = [], [] for i in range(1000): self.src.append(self.x[i]) self.trg.append(self.y[i]) def __getitem__(self,index): return self.src[index], self.trg[index] def __len__(self): return len(self.src)
参考:
https://zhuanlan.zhihu.com/p/144373921
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)