shuffle用来将数据随机打乱,batch_size为2,则两个为一组
codeimport torch from torch.utils.data import Dataset #抽象类,需要继承 from torch.utils.data import DataLoader import numpy as np #此处的Dataset,表示继承这个类 class DiabetesDataset(Dataset): def __init__(self,filepath): xy=np.loadtxt(filepath,delimiter=',',dtype=np.float32) self.len = xy.shape[0] #xy.shape 是一个元祖,该元祖包含xy的维数:N行9列 self.x_data = torch.from_numpy(xy[:,:-1]) self.y_data = torch.from_numpy(xy[:,[-1]]) def __getitem__(self, index): return self.x_data[index],self.y_data[index] def __len__(self): return self.len class Model(torch.nn.Module): def __init__(self): super(Model, self).__init__() self.linear1 = torch.nn.Linear(8,6) self.linear2 = torch.nn.Linear(6,4) self.linear3 = torch.nn.Linear(4,1) self.sigmoid = torch.nn.Sigmoid() def forward(self,x): x = self.sigmoid(self.linear1(x)) x = self.sigmoid(self.linear2(x)) x = self.sigmoid(self.linear3(x)) return x model = Model() dataset = DiabetesDataset('diabetes.csv.gz') #第四个参数,需要多少个并行进程来读取数据,2个进程,每一个读取32个数据 train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2) criterion = torch.nn.BCELoss(size_average=True) optimizer = torch.optim.SGD(model.parameters(),lr=0.01) if __name__ == '__main__': for epoch in range(100): #enumerate 可以反馈当前的迭代次数,即第几次迭代 for i,data in enumerate(train_loader,0): inputs,labels = data y_pred = model(inputs) loss = criterion(y_pred,labels) print(epoch,i,loss.item()) optimizer.zero_grad() loss.backward() optimizer.step()
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)