【学习记录】用pytorch自己写数据生成器

【学习记录】用pytorch自己写数据生成器,第1张

【学习记录】用pytorch自己写数据生成器
  • 一、先验知识
    • 1.self是啥
    • 2. __init__方法
    • 3.定制序列\定制容器
    • 4.描述符(property原理)
      • 4.1 先说属性访问
      • 4.2 @property装饰器
    • 5.迭代器:\__iter__()和 \__next__()
    • 6.生成器
    • 8. __call__方法
  • 二、实现数据生成器
    • 1.实现dataset
    • 2.DataLoader
      • 2.1 shuffle和batch
      • 2.2 DataLoader的__iter__
    • 3.在一个类中实现
  • 总结


学习过程中的记录和思考,有问题或不严谨的地方还请指出

一、先验知识 1.self是啥

每天都在定义类,每天都在写self,这到底是什么呢
比如我们定义了一个类

class Solution():
	def nit(self,a):
		self.a=a
	def one(self):
		self.a+=1
		print(self.a)

self在这里类似于C++的指针,有种定位的意思,当我们对这个类进行实例化,调用该对象的方法,比如solution=Solution()将Solution这个类实例化为对象
现在我想调用该对象中的one方法,对象就会将自身的引用作为第一个参数传给该方法,那么python就知道该去处理哪一个对象的方法。
我的理解是:

solution=Solution()
#现在想要调用nit方法,本来我需要传入两个参数,一个self,一个a,
#但是现在只要传入a就好了
solution.nit(2)

这里的solution就已经把自己引用成了self传输了进去,所以不再需要传参,同理,如果想调用对象的one方法,直接solution.one()就好了,不再需要传输参数

2. __init__方法

这个方法很熟悉了,一般在定义一个类时需要初始化 就会重写__init__()方法

3.定制序列\定制容器

也就是说,在类中定义了以下方法后,就可以像在列表字符串中一样使用一些方法,提高使用效率。

方法作用
__getitem__方法定义获取容器中指定元素的行为,这样就可以通过对象+key直接访问 s e l f [ k e y ] self[key] self[key]
__setitem__方法(self,key,value)定义设置容器中指定元素的行为,相当于 s e l f [ k e y ] = v a l u e self[key]=value self[key]=value,可以直接来给指定索引位置赋值
__len__方法定义了当被 l e n ( ) len() len()函数调用时候的行为(返回容器中元素的个数),也就是可以对这样对象做len(object)的 *** 作了
4.描述符(property原理) 4.1 先说属性访问
方法作用/含义
. (例如solution.x)通过点 *** 作符访问对象的属性
_get_()方法用于访问属性,它返回属性的值,当访问对象的属性时,该方法会自动调用。第一个参数:这个描述符的拥有者所在的类的实例,第二个参数是描述符的拥有者所在的类本身
_set_()方法将在属性分配 *** 作中调用,不返回任何内容;对对象的属性进行赋值 *** 作的时候,会自动调用该方法
_delete_()方法删除 *** 作,无返回内容
_getattr_(self,name)当用户试图获取一个不存在的属性时的行为
_getattribute_(self,name)当该类的属性被访问时的行为
_setattr_(self,name,value)定义后属性可以被设置
_delattr_(self,name)定义后可以通过该方法删除属性
_dict_以字典的形式显示出当前对象的所有属性以及相对应的值
4.2 @property装饰器

https://blog.csdn.net/qq_37718687/article/details/123877438?spm=1001.2014.3001.5502

5.迭代器:_iter_()和 _next_()

可迭代对象:提供迭代方法的容器成为可迭代对象,通常接触的可迭代对象有序列(如列表、元组、字符串)、字典等,都支持迭代 *** 作,直观理解就是:你都可以通过for循环访问其中的每一个元素

nums=[1,2,3,4,5]
for num in nums:
	print(num)

关于迭代,Python提供了 iter() 和 next() ,对一个可迭代对象调用 iter() 就可以得到它的迭代器,调用 next() 迭代器就会返回下一个值。

  1. _iter_():一个容器如果是迭代器,那就必须实现__iter__()方法,这个方法实际上就是返回迭代器本身
def __iter__(self)
	return self
  1. _next_():它决定了迭代的规则
    当使用next()迭代完存储的所有元素之后,如果继续迭代,则__next__()方法会抛出StopIteration异常。
    可以说,有个这个方法以后我们就可以对这个对象for循环了
class Solution():
	def __init__(self,a,b):
		self.a=a
		self.b=b
	def __iter__(self):
		return self
	def __next__(self):
		self.a+=1
		if self.a>self.b: #这里b的存在是给迭代器设置一个终点,不让它无线循环下去
			raise StopIteration
		return self.a
solution=Solution(1,10)
for solu in solution:
	print(solu)
#会得到答案 2,3,4,5,6,7,8,9,10
so=iter(solution)
next(so)
next(so)#...也是一样的

6.生成器

参考‘小甲鱼’

首先什么是生成器呢?

  1. 生成器其实是迭代器的一种实现,生成器就是一类特殊的迭代器。迭代器需要我们去定义一个类并且实现相关的方法(比如上面我们在类中定义__iter__()和__next__()),而生成器则只需要在普通的函数中加上一个yield即可(下文中会有例子)。
  2. 还有一个更重要的方面是,生成器的存在使得Python模仿协同程序的概念得以实现。所谓协同程序,就是可以运行的独立函数调用,函数可以暂停或者挂起,并在需要的时候从程序离开的地方继续或者重新开始。
    Python通过生成器来实现类似于协同程序的概念:生成器可以暂时挂起函数,并保留函数的局部变量等数据,然后再次调用它的时候,从上次暂停的位置继续执行下去。(可能有点绕,再往下看)

https://fishc.com.cn/thread-56023-1-1.html

一个生成器函数的定义很像一个普通的函数,除了当它要生成一个值的时候,使用 yield 关键字而不是 return。如果一个 def 的主体包含 yield,这个函数会自动变成一个生成器(即使它包含一个 return)。
每当生成器被调用的时候,它会返回一个值给调用者。然后在生成器内部使用yield来完成。除了以上内容,创建一个生成器没有什么多余步骤了。举例来说:

#普通函数
def function1():
	return 1
#生成器
def function2():
	yield 1

上面的function2就是生成器函数,生成器函数会返回生成器的迭代器(就是生成器,很绕吧,再来一遍,生成器的迭代器==生成器)。

  • 使用yield的主要目的是为了边用边生成,当一个生成器函数调用yield,生成器函数的“状态”会被冻结,所有的变量的值会被保留下来,下一行要执行的代码的位置也会被记录,直到再次调用 next()。一旦next() 再次被调用,生成器函数会从它上次离开的地方开始。如果永远不调用 next(),yield 保存的状态就被无视了。
def function(a):
    print("第1天在好好学习!")
    yield a
    
    yield a+1
    print("第2天在好好学习!")
    b=yield a+2
    print("第3天在好好学习!",b)
    yield b

#f就是一个迭代器
f=function(1)
#第1次:输出:第1天在好好学习! \  1
next(f())    #到yield a就停止,并记录这个位置,下次从这里往下运行
#第2次:输出:2
next(f())    #yield a+2就停止
#第3次:输出:第2天在好好学习! \  3
next(f())	 #从上次停止的地方开始
#第4次:输出:第3天在好好学习! None
#为什么b是None呢,因为yield a+2就将返回了并没有赋值给b,所以是None

比较特殊的是:用普通小括号括起来的推导式就是生成器表达式,可以用next()来进行迭代

(i for i in range(10))

copy的一些笔记

generator 是用来产生一系列值的
yield 则像是 generator 函数的返回结果
yield 唯一所做的另一件事就是保存一个 generator 函数的状态
generator 就是一个特殊类型的迭代器(iterator)
和迭代器相似,我们可以通过使用 next() 来从 generator 中获取下一个值
通过隐式地调用 next() 来忽略一些值

8. __call__方法

如果类中定义了__call__()方法,那么该类的实例对象也将成为可调用对象。该对象被调用时,将执行__call__()方法中的代码。该方法相当于’()'这个运算符。比如:

class Solution:
	def __call__(self,a):
		print("代码快进步",a)
solu=Solution()
solu("好的!")
#会输出:”代码快进步 好的!“

这样solu实例对象就变成了可调用对象。

Python 中,凡是可以将 () 直接应用到自身并执行,都称为可调用对象。可调用对象包括自定义的函数、Python 内置函数以及本节所讲的类实例对象。

也就是说,()==_call_()

二、实现数据生成器 1.实现dataset
  • transform定义:(假设输入图像是RGB格式的)首先是用torchvision中的transforms包对图像输入进行预处理,transform这个函数把接收到的图像首先缩放到 256 × 256 × 3 256\times 256\times 3 256×256×3,然后随机截取成 224 × 224 × 3 224\times224\times3 224×224×3,然后转化成tensor格式,totensor会把像素归一化到[0,1]范围内,至于后面的标准化(transforms.Normalize),这篇文章讲的很清楚。

https://zhuanlan.zhihu.com/p/476297637

  • 思考1:Resize()和RandomCrop()有什么区别,具体是怎么 *** 作的【详细过程见Resize()和RandomCrop()的源码】
    假设我们输入的图片大小是(height,weight)
transforms方法具体作用和实现
Resize()缩放,图像的长宽比变换前后没有变化,在下面的例子中,我们输入Resize((256,256)),那么图像会直接缩放成256*256,如果我们只输入一个树,Resize(256),那么resize会找到较短的那个边(比如height>width),把width变成256,然后对height等比缩放,最后变为( 256 × h e i g h t w i d t h 256\times \frac{height}{width} 256×widthheight,size)
RandomCrop()随机裁剪,还是下面的代码,RandomCrop()首先会生成两个数 i , j i,j i,j,这两个数分别在 ( 0 , h e i g h t − 224 ) , ( 0 , w e i g h t − 224 ) (0,height-224),(0,weight-224) (0,height224),(0,weight224)之间,作为裁剪后图片不同方向的下届,最后裁剪出来的图片就是i~i+224, j~j+224
CentorCrop()同RandomCrop(),只是 i 和 j 不再随机生成
  • 又有一个问题,都是变化,它们有什么不同呢,假设原图像是 300 × 300 300 \times 300 300×300,缩放后成了 256 × 256 256 \times 256 256×256,这其中去掉的44个像素是怎么没有的呢?是使用了一些插值方法,具体就不说了,又是很琐碎的一堆知识。相比之下,裁剪和它的名字一样,就是把多余的部分直接剪掉,把i~i+224, j~j+224之外的像素直接丢掉就好了。
    更多的方法细节见源码,源码真的很清楚。
from PIL.Image import Image
from torch.utils.data import Dataset
from torchvision.transforms import transforms

transform=transforms.Compose([transforms.Resize((256,256)),
                              transforms.RandomCrop(224),
                              transforms.ToTensor(),
                              transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])

def default_loader(path):
    return Image.open(path).convert('RGB')

class MyDataset(Dataset):
    def __init__(self,images,labels,loader=default_loader,transform=None):
        self.images=images
        self.labels=labels
        self.loader=loader
        self.transform=transform

    def __getitem__(self, index):  #返回tensor
        img,target=self.images[index],self.labels[index]
        img=self.loader(img)
        if self.transform is not None:
            img=self.transform(img)
        return img,target

    def __len__(self):
        return len(self.images)

写完MyDataset,接下来就可以直接使用pytorch中的dataloader了

from torch.utils.data import DataLoader
#train_list和train_labels分别是一个图片路径组成的数组和一个标签类别组成的数组
train_loader=DataLoader(MyDataset(train_list,train_labels,transform=centre_crop),
batch_size=batch_size,shuffle=True,num_workers=8)
2.DataLoader

为什么写成上面的Mydataset的形式就可以直接使用DataLoder了呢,这就又要说到DataLoader的原理了(两行搞定的代码原理怎么这么复杂这么复杂)
首先MyDataset()的输出:是索引 i n d e x index index i m g , t a r g e t img,target img,target
接下来看DataLoader的init中比较重要的几个部分:

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None):
        torch._C._log_api_usage_once("python.data_loader")
        self.dataset = dataset
        ##省略了一些代码
        #重要的部分来了
        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler    
2.1 shuffle和batch

有三个部分是对数据的shuffle和batch处理

方法作用\含义
RandomSampler随机采样来返回DataSet的索引位置 ,返回: [ 2 , 3 , 5 , 4 , 6 , 1 , 0 ] [2,3,5,4,6,1,0] [2,3,5,4,6,1,0]
SequentialSampler顺序采样来返回DataSet的索引位置 ,返回: [ 0 , 1 , 2 , 3 , 4 , 5 , 6 ] [0,1,2,3,4,5,6] [0,1,2,3,4,5,6]
BatchSampler当达到一个batch的容量,就会被yield出去,假设batch为3,返回: [ [ 0 , 1 , 2 ] , [ 3 , 4 , 5 ] , [ 6 ] [[0,1,2],[3,4,5],[6] [[0,1,2],[3,4,5],[6]
  • 看一下BatchSampler中__iter__(self)的代码
ef __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

2.2 DataLoader的__iter__
    def __iter__(self):
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)
    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

大致思路就是,随机生成数据长度大小的索引数,把读取进来的数据按照前面的索引放入一个又一个的batch中
上面没有考虑num_works>0的情况,有关于多进程的部分感觉这篇文章讲的很仔细

https://blog.csdn.net/g11d111/article/details/81504637

创建了这个DataLoader对象后,就可以循环这个对象加载到模型中训练了

3.在一个类中实现

’来自一位优秀的同门’

class DataLoader(object):
    def __init__(self,X,y,batch_size):
        self.X = X
        self.y = y
        self.length = len(y)
        self.arr = np.array(range(self.length))
        self.batch_size = batch_size
        
    def __iter__(self):
        self.num = 0
        self.seq = np.random.permutation(self.arr)
        return self
        
    def __next__(self):
        if self.num+self.batch_size <= self.length:
            sample = self.seq[self.num:(self.num+self.batch_size)]
            self.image = self.X[sample]
            self.label = self.y[sample]
            self.num += self.batch_size
            return self.image, self.label
        else:
            raise StopIteration
            
    def __len__(self):
        return len(self.y)

总结

平时很简单几行的代码内在逻辑和实现细节真的是琐碎啊,道阻且长,加油喽!

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/755938.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-30
下一篇 2022-04-30

发表评论

登录后才能评论

评论列表(0条)

保存