PyTorch 4.DataLoader与Dataset

PyTorch 4.DataLoader与Dataset,第1张

PyTorch 4.DataLoader与Dataset

PyTorch 4.DataLoader与Dataset
  • torch.utils.data.DataLoader
  • torch.utils.data.Dataset

torch.utils.data.DataLoader
DataLoader(
			dataset,
			batch_size=1,
			shuffle=False,
			sampler=None,
			batch_sampler=None,
			num_workers=0,
			collate_fn=None,
			pin_memory=False,
			timeout=0,
			worker_init_fn=None,
			multiprocessing_context=None
)

功能:构建可迭代的数据装载器
dataset:Dataset类,决定数据从哪读取几如何读取
batchsize:批大小
num_works:是否多进程读取数据
shuffle:每个epoch是否乱序
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

  1. 读哪些数据?->Sampler输出的index
  2. 从哪读数据?->Dataset中的data_dir
  3. 怎么读数据? ->Dataset中的getitem
torch.utils.data.Dataset

功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()
getitem:接收一个索引,返回一个样本
Dataset的目标是根据输入的索引输出对应的image和label,而且这个功能是要在__getitem__()函数中完成的,所以当自定义数据集时,首先要继承Dataset类,还要复写__getitem__()函数

例子:

import torch
from torch import nn
from torch.utils.data import Dataset,DataLoader

class My_dataset(Dataset):
	def __init__(self):
		super().__init__()
		# 使用sin函数返回10000个时间序列,如果不自己构造数据,就使用numpy,pandas等读取自己的数据为x即可。
		# 以下数据组织这块即可以放在init方法,也可以放在getitem方法里
		self.x = torch.randn(1000,3)
		self.y = self.x.sum(axis=1)
		self.src, self.trg = [], []
		for i in range(1000):
			self.src.append(self.x[i])
			self.trg.append(self.y[i])
	def __getitem__(self,index):
		return self.src[index], self.trg[index]
	
	def __len__(self):
		return len(self.src)		

参考:
https://zhuanlan.zhihu.com/p/144373921

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5572431.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-14
下一篇 2022-12-14

发表评论

登录后才能评论

评论列表(0条)

保存