我遇到了同样的问题,我设法通过定义一个
__next__方法解决了这个问题:
class My_Generator(Sequence): def __init__(self, image_filenames, labels, batch_size): self.image_filenames, self.labels = image_filenames, labels self.batch_size = batch_size self.n = 0 self.max = self.__len__() def __len__(self): return np.ceil(len(self.image_filenames) / float(self.batch_size)) def __getitem__(self, idx): batch_x = self.image_filenames[idx * self.batch_size:(idx + 1) * self.batch_size] batch_y = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size] return np.array([ resize(imread(file_name), (200, 200))for file_name in batch_x]), np.array(batch_y) def __next__(self): if self.n >= self.max:self.n = 0 result = self.__getitem__(self.n) self.n += 1 return result
请注意,我在
__init__函数中声明了两个新变量。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)