简单实现Dataset和Dataloader |
- 一. 定义Dataset
- 二. 结合Dataloader
- 三. Python魔术方法
- 3.1. 运算符重载
- 3.2. 打印 *** 作的魔法方法
- 3.3. 属性 *** 作的魔法方法
- 3.4. 描述符
- 3.5. 定制序列(常用)
- 3.6. 迭代器(常用)
- 3.7. 补充:Python生成器(yield)
- 四. 参考文章
- 在Python中,所有以
“__”
双下划线包起来的方法,都统称为“Magic Method”
,中文称魔术方法
,例如类的初始化方法__init__
。Python的魔法方法会在特定的情况下自动调用。- 详解Python常用的魔法方法:https://www.jb51.net/article/214097.htm
import random
class MyDataset:
def __init__(self, all_datas, batch_size, shuffle=True):
self.all_datas = all_datas
self.batch_size = batch_size
self.shuffle = shuffle
self.cursor = 0 # 边界
# Python的魔术方法,在某种场景自动触发的方法
def __iter__(self): #### 返回一个具有__next__的对象。
print("hello __iter__")
if self.shuffle:
random.shuffle(self.all_datas)
self.cursor = 0
return self
def __next__(self):
if self.cursor >= len(self.all_datas):
raise StopIteration
batch_data = self.all_datas[self.cursor:self.cursor + self.batch_size]
self.cursor += self.batch_size
return batch_data
if __name__ == '__main__':
all_datas = [1, 2, 3, 4, 5, 6, 7]
batch_size = 2
shuffle = True
dataset = MyDataset(all_datas, batch_size, shuffle)
for e in range(2):
#### 如若dataset对象放在for循环上时会自动调用这个对象的__iter__,但只会触发一次
for batch_data in dataset:
print(batch_data)
# hello
# __iter__
# [2, 7]
# [1, 3]
# [4, 6]
# [5]
# hello
# __iter__
# [5, 6]
# [1, 4]
# [2, 3]
# [7]
二. 结合Dataloader
import random
import numpy as np
class MyDataset:
def __init__(self, all_datas, batch_size, shuffle=True):
self.all_datas = all_datas
self.batch_size = batch_size
self.shuffle = shuffle
self.cursor = 0 # 边界
# Python的魔术方法,在某种场景自动触发的方法
def __iter__(self): #### 返回一个具有__next__的对象。
return DataLoader(self)
def __len__(self):
return len(self.all_datas)
class DataLoader:
def __init__(self, dataset):
self.dataset = dataset
self.indexs = [i for i in range(len(self.dataset.all_datas))]
if self.dataset.shuffle == True:
random.shuffle(self.indexs) # 打乱索引
self.cursor = 0
def __next__(self):
if self.cursor >= len(self.dataset.all_datas):
raise StopIteration
index = self.indexs[self.cursor:self.cursor + self.dataset.batch_size]
batch_data = self.dataset.all_datas[index]
self.cursor += self.dataset.batch_size
return batch_data
if __name__ == '__main__':
all_datas = np.array([1, 2, 3, 4, 5, 6, 7])
batch_size = 2
shuffle = True
dataset = MyDataset(all_datas, batch_size, shuffle)
for e in range(2):
for batch_data in dataset: #### 如若dataset对象放在for循环上时会自动调用这个对象的__iter__,只会触发一次
print(batch_data)
三. Python魔术方法
3.1. 运算符重载
- 在Python中,所有以
“__”
双下划线包起来的方法,都统称为“Magic Method”
,中文称魔术方法
,例如类的初始化方法__init__
。Python的魔法方法会在特定的情况下自动调用。
- Python中同样有运算符重载,其实所有的运算符都是使用了对应的魔法方法来处理的对象的,魔法方法对应的 *** 作符如下
class A:
def __init__(self, x):
self.x = x
def __add__(self, other):
return int(self.x) + int(other.x)
a = A(1.2)
b = A(2.5)
print(a + b)
3.2. 打印 *** 作的魔法方法
- 在幕后,
for
语句会在容器对象上调用iter()
。 该函数返回一个定义了__next__()
方法的迭代器对象,此方法将逐一访问容器中的元素。 当元素用尽时,__next__()
将引发StopIteration
异常来通知终止 for 循环。
import time
import memory_profiler as mem
# Python Generator
# 1. 什么是Generator: 是一个生成器,可以生成一个个东西,通过next()方法,是一个iterable
# 2. 为什么需要用Generator?列表很大的时候,Generator按需给你产生,不会一次性生成而占用大量内存;数据不是一次性读入,节省内存。
# 最简单的例子
nums = [1, 2, 3, 4, 5]
squred_nums = [i ** 2 for i in nums] # 列表推导式
print(squred_nums)
squred_nums_2 = (i ** 2 for i in nums) # 用括号括起来,就是Gennerator
print(squred_nums_2)
print(next(squred_nums_2))
print(next(squred_nums_2))
print(next(squred_nums_2))
print(next(squred_nums_2))
print(next(squred_nums_2))
# 数量大一些
yi = 10000000
start = time.time()
print(f"内存前:{mem.memory_usage()}")
nums = list(range(10000000))
squred_nums = [i ** 2 * yi for i in nums] # 生成1千万数字存在内存中,消耗大约900兆
end1 = time.time()
print(end1 - start)
print(f"内存后1:{mem.memory_usage()}")
# 运行可以发现Generator速度快,消耗内存低,比如:列表推导式可以看成买了1000个汉堡放家里,占地方;生成器表示买了一个做汉堡的机器,想吃了做一点
# 调用next()方法,生成一个汉堡。
squred_nums_1 = (i ** 2 * yi for i in nums) # Generator并没有在内存中生成,而是生成了一个生成器,需要找他要就行了。
print(time.time() - end1)
print(f"内存后2:{mem.memory_usage()}")
print("==========" * 10)
## yield生成器的例子:yield和return一样的地方都返回了值,不一样的地方yield所在的函数并没有真正的结束,
## 下次可以继续调用next,会在上次执行的yield位置继续执行。
def gen_nums(nums1):
for n in nums1:
if n % 3 == 0:
yield 3 * yi
elif n % 5 == 0:
yield 5 * yi
else:
yield n * yi
def calc_nums(nums):
new_nums = []
for n in nums:
if n % 3 == 0:
new_nums.append(3 * yi)
elif n % 5 == 0:
new_nums.append(5 * yi)
else:
new_nums.append(n * yi)
return new_nums
nums = list(range(10))
cnums = calc_nums(nums)
gnums = gen_nums(nums)
for n in cnums: # 效果一样
print(n)
print("*" * 10)
for n in gnums: # 效果一样,gnums是一个生成器,For中调用next()这个时候送入生成器,yield给一个返回值,继续调用next()
print(n)
[1, 4, 9, 16, 25]
<generator object <genexpr> at 0x7f07df610250>
1
4
9
16
25
内存前:[61.30078125]
3.3733038902282715
内存后1:[989.6875]
0.10051727294921875
内存后2:[989.6875]
=====================================================================================
30000000
10000000
20000000
30000000
40000000
50000000
30000000
70000000
80000000
30000000
**********
30000000
10000000
20000000
30000000
40000000
50000000
30000000
70000000
80000000
30000000
Process finished with exit code 0
四. 参考文章
- 详解Python常用的魔法方法:https://www.jb51.net/article/214097.htm
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)