TensorDataset和DataLoader基本使用方法与数据集切分函数

TensorDataset和DataLoader基本使用方法与数据集切分函数,第1张

TensorDataset和DataLoader基本使用方法与数据切分函数

       在PyTorch中, *** 作数据所需要使用的模块是torch.utils,其中utils.data类下面有大量用来执行数据预处理的工具。在MBSGD中,我们需要将数据划分为许多组特征张量+对应标签的形式,因此最开始我们要将数据的特征张量与标签打包成一个对象。深度学习中的特征张量与标签几乎总是分开的,不像机器学习中标签常常出现在特征矩阵的最后一列或第一列。

      合并张量与标签,我们所使用的类是utils.data.TensorDataset,这个功能类似于utils.data.TensorDataset,这个功能类似于python中的zip,可以将最外面的维度一致的tensor进行打包,也就是将第一维度一致的tensor进行打包。

      当我们将数据打包成一个对象之后,我们需要使用划分小批量的功能DataLoader。DataLoader是处理训练前专用的功能,它可以接受任意形式的数组、张量作为输入,并把他们一次性转换为神经网络可以接入的tensor。

       PyTorch中所有的数据都是Dataset的子类,换而言之就是在使用PyTorch建模训练数据时,需要创建一个和数据集对应的类来表示该数据集。

一、数据读取方法:

1)使用TensorDataset这个简单的类型转化函数,将数据统一转化为”TensorDataset"类然后带入模型进行计算。但TensorDataset其实使用面较窄,最直接的限制就是该函数只能将张量类型转化为TensorDataset类。

2)更加通用的数据读取方式则是手动创建一个继承自torch.utils.data.dataset的数据类,用来作为当前数据的表述,如下面的乳腺癌数据的读取方式。

                  

       接下来,创建一个用于表示该数据集的Dataset的子类,在创建Dataset的子类过程中,必须要重写getitem方法和len方法,其中getitem方法返回输入索引后对应的特征和标签,而len方法则返回数据集的总数据个数。在必须要进行的init初始化过程中,我们也可以输入可代表数据集基本属性的相关内容,包括数据集的特征、标签、大小等等,视情况而定。

       然后,我们可以使用random_split方法对其进行切分。PyTorch的torch.utils.data中,提供了random_split函数可用于数据集切分

       此时切分的结果是一个映射式的对象,只有dataset和indices两个属性,其中dataset属性用于查看原数据集对象,indices属性用于查看切分后数据集的每一条数据的index(序号)。

二、DatasetLoader进行数据加载

       Dataset类主要负责数据类的生成,在PyTorch中,所有数据集都是Dataset的子类;而DatasetLoader类则是加载模型训练的接口。DataLoader进行数据转化,将一般数据状态转化为“可建模”的状态。所谓“可建模”状态,指的是经过DataLoader处理的数据,不仅包含数据原始的数据信息,还包含数据处理方法信息,如调用几个线程进行训练、分多少批次等,DataLoader常用参数如下:

 

 【小结】Dataset和DatasetLoader及数据集切分的基本使用流程如下:

   

欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/zaji/5721124.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-18
下一篇 2022-12-18

发表评论

登录后才能评论

评论列表(0条)

保存