之前遇到的问题是,我自己定义了dataset的类,类似于下面的代码
class DealDataset(Dataset): """ 下载数据、初始化数据,都可以在这里完成 """ def __init__(self): xy = np.loadtxt('../dataSet/diabetes.csv.gz', delimiter=',', dtype=np.float32) # 使用numpy读取数据 self.x_data = torch.from_numpy(xy[:, 0:-1]) self.y_data = torch.from_numpy(xy[:, [-1]]) self.len = xy.shape[0] def __getitem__(self, index): x_data=self.x_data[index] y_data=self.y_data[index] return {'x_data':x_data,'y_data':y_data} def __len__(self): return self.len
这样就在读取上非常迷惑,不知道用enumerate和tqdm要怎么读数据,搞清楚后在这里简要记录一下对应关系
1.return {'x_data':x_data,'y_data':y_data},
目前只会用enemerate读取
for i, data in enumerate(train_loader): x_data, y_data= data['x_data'], data['y_data']
2.把return改变,改为return self.x_data[index],self.y_data[index]
这样tqdm读取
for x_data,y_data in tqdm(train_loader):
enumerate读取
for idx,data in enumerate(train_loader): x_label=data[0] y_label=data[1]
另外,除了放在batch那里,tqdm也可以放在epoch的循环那里
for epoch in tqdm(range(100)):
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)