import torch import torchvision import torch.nn as nn import torch.nn.functional as F import torch.utils.data as Data import matplotlib.pyplot as plt # define hyper parameters Batch_size = 50 Lr = 0.1 Epoch = 1 train_dataset = torchvision.datasets.SVHN( root='/data/cli/dh/resnet_svhn/result/svhn-data/SVHN', split='train', download=False, transform=torchvision.transforms.ToTensor() ) test_dataset = torchvision.datasets.SVHN( root='/data/cli/dh/resnet_svhn/result/svhn-data/SVHN', split='test', download=False, transform=torchvision.transforms.ToTensor() ) # define train loader train_loader = Data.DataLoader( dataset=train_dataset, shuffle=True, batch_size=Batch_size ) # define test loader test_loader = Data.DataLoader( dataset=test_dataset, shuffle=True, batch_size=Batch_size ) # images_train, labels_train = next(iter(train_loader)) # images_test, labels_test = next(iter(test_loader)) # print("images_train, labels_train",images_train, labels_train) # print("Shape of train inputs: ", images_train.shape, "; Shape of train labels: ", labels_train.shape) # print("Shape of test inputs: ",images_test.shape, "; Shape of test inputs: ", labels_test.shape) # print("Batch size = 64") #取一个固定测试点作为测试数据 dataiter = iter(test_loader) test_x, test_y = dataiter.next() # print("test_x:",test_x) # print("test_y:",test_y) # for i, (x, y) in enumerate(test_loader): # print("i值为:",i) # print("x的值为:",x) # print("y的值为:",y) # print("x:",type(x)) # print("y:",type(y)) # break # construct network class Basicblock(nn.Module): def __init__(self, in_planes, planes, stride=1): super(Basicblock, self).__init__() self.conv1 = nn.Sequential( nn.Conv2d(in_channels=in_planes, out_channels=planes, kernel_size=3, stride=stride, padding=1, bias=False), nn.BatchNorm2d(planes), nn.ReLU() ) self.conv2 = nn.Sequential( nn.Conv2d(in_channels=planes, out_channels=planes, kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(planes), ) if stride != 1 or in_planes != planes: self.shortcut = nn.Sequential( nn.Conv2d(in_channels=in_planes, out_channels=planes, kernel_size=3, stride=stride, padding=1), nn.BatchNorm2d(planes) ) else: self.shortcut = nn.Sequential() def forward(self, x): out = self.conv1(x) out = self.conv2(out) out += self.shortcut(x) out = F.relu(out) return out class ResNet(nn.Module): def __init__(self, block, num_block, num_classes): super(ResNet, self).__init__() self.in_planes = 16 self.conv1 = nn.Sequential( nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(16), nn.ReLU() ) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) self.block1 = self._make_layer(block, 16, num_block[0], stride=1) self.block2 = self._make_layer(block, 32, num_block[1], stride=2) self.block3 = self._make_layer(block, 64, num_block[2], stride=2) # self.block4 = self._make_layer(block, 512, num_block[3], stride=2) self.outlayer = nn.Linear(64, num_classes) def _make_layer(self, block, planes, num_block, stride): layers = [] for i in range(num_block): if i == 0: layers.append(block(self.in_planes, planes, stride)) else: layers.append(block(planes, planes, 1)) self.in_planes = planes return nn.Sequential(*layers) def forward(self, x): x = self.maxpool(self.conv1(x)) x = self.block1(x) # [200, 64, 28, 28] x = self.block2(x) # [200, 128, 14, 14] x = self.block3(x) # [200, 256, 7, 7] # out = self.block4(out) x = F.avg_pool2d(x, 7) # [200, 256, 1, 1] x = x.view(x.size(0), -1) # [200,256] out = self.outlayer(x) return out ResNet18 = ResNet(Basicblock, [1, 1, 1, 1], 10) # print(ResNet18) opt = torch.optim.SGD(ResNet18.parameters(), lr=Lr) loss_fun = nn.CrossEntropyLoss() a = [] ac_list = [] for epoch in range(Epoch): for i, (x, y) in enumerate(train_loader): print("i值为:",i) output = ResNet18(x) loss = loss_fun(output, y) opt.zero_grad() loss.backward() opt.step() if i % 100 == 0: a.append(i) test_output = torch.max(ResNet18(test_x), dim=1)[1] loss = loss_fun(ResNet18(test_x), test_y).item() accuracy = torch.sum(torch.eq(test_y, test_output)).item() / test_y.numpy().size ac_list.append(accuracy) print('Epoch:', Epoch, '|loss%.4f' % loss, '|accuracy%.4f' % accuracy) print('real value', test_y[: 10].numpy()) print('train value', torch.max(ResNet18(test_x)[: 10], dim=1)[1].numpy()) plt.plot(a, ac_list, color='r') plt.show()
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)