In ImageNet LSVRC-2010, AlexNet classified 1.2 million high-resolution images into 1000 categories and reached top-1 and top-5 error rates of 37.5% and 17.0% on the test set. (Top-5 error: the model predicts 5 categories for each image, and the prediction counts as correct if any of them matches the ground-truth label; top-1 uses only the single highest-scoring prediction.) In the ImageNet LSVRC-2012 competition, AlexNet achieved a top-5 error rate of 15.3%, while the runner-up scored 26.2%, which shows how dominant AlexNet was at the time.
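To make the top-1/top-5 definition concrete, here is a minimal sketch of how both error rates can be computed with PyTorch; the names logits and targets and the random data are purely illustrative and not part of the original post.

import torch

def topk_error(logits, targets, k):
    # logits: [N, num_classes], targets: [N]
    # A sample is correct if the true label is among the k highest-scoring classes.
    topk = logits.topk(k, dim=1).indices               # [N, k]
    correct = (topk == targets.unsqueeze(1)).any(dim=1)
    return 1.0 - correct.float().mean().item()

# toy example: 8 samples, 1000 classes, random scores
logits = torch.randn(8, 1000)
targets = torch.randint(0, 1000, (8,))
print(topk_error(logits, targets, 1), topk_error(logits, targets, 5))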
Network Architecture
The network is split across two GPUs, with half of the kernels placed on each GPU.
Stage 1: convolution
input 224*224*3
Conv(kernel_size=11*11, kernel_num=96, stride=4, padding=2) -> output (224-11+2*2)/4+1=55 (rounded down) -> 55*55*96
ReLU
Pool(kernel_size=3*3, stride=2) -> output (55-3)/2+1=27 -> 27*27*96
Local Response Normalization(local_size=5) -> output 27*27*96

Stage 2: convolution
input 27*27*96
Conv(kernel_size=5*5, kernel_num=256, stride=1, padding=2) -> output (27-5+2*2)/1+1=27 -> 27*27*256
ReLU
Pool(kernel_size=3*3, stride=2) -> output (27-3)/2+1=13 -> 13*13*256
Local Response Normalization(local_size=5) -> output 13*13*256

Stage 3: convolution
input 13*13*256
Conv(kernel_size=3*3, kernel_num=384, stride=1, padding=1) -> output (13-3+2*1)/1+1=13 -> 13*13*384
ReLU -> output 13*13*384

Stage 4: convolution
input 13*13*384
Conv(kernel_size=3*3, kernel_num=384, stride=1, padding=1) -> output (13-3+2*1)/1+1=13 -> 13*13*384
ReLU -> output 13*13*384

Stage 5: convolution
input 13*13*384
Conv(kernel_size=3*3, kernel_num=256, stride=1, padding=1) -> output (13-3+2*1)/1+1=13 -> 13*13*256
ReLU
Pool(kernel_size=3*3, stride=2) -> output (13-3)/2+1=6 -> 6*6*256

Stage 6: fully connected
input 6*6*256
Fc -> ReLU -> Dropout -> output 4096

Stage 7: fully connected
input 4096
Fc -> ReLU -> Dropout -> output 4096

Stage 8: fully connected
input 4096
Fc -> output 1000 (class scores, fed to softmax)
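All spatial sizes above follow the same rule: output = floor((input - kernel + 2*padding) / stride) + 1. The short sketch below (the helper name out_size is just illustrative) reproduces the per-stage sizes listed in the stages:

def out_size(w, kernel, stride=1, padding=0):
    # floor((W - K + 2P) / S) + 1
    return (w - kernel + 2 * padding) // stride + 1

w = 224
w = out_size(w, 11, stride=4, padding=2)   # conv1 -> 55
w = out_size(w, 3, stride=2)               # pool1 -> 27
w = out_size(w, 5, stride=1, padding=2)    # conv2 -> 27
w = out_size(w, 3, stride=2)               # pool2 -> 13
w = out_size(w, 3, stride=1, padding=1)    # conv3 -> 13
w = out_size(w, 3, stride=1, padding=1)    # conv4 -> 13
w = out_size(w, 3, stride=1, padding=1)    # conv5 -> 13
w = out_size(w, 3, stride=2)               # pool5 -> 6
print(w)                                   # 6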
Each layer in AlexNet plays a specific role: the convolution layers extract features, the pooling layers downsample the feature maps, local response normalization improves generalization, and dropout in the fully connected layers reduces overfitting.
PyTorch Implementation
model.py
import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(                               # package the convolutional stages together
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),   # input[3, 224, 224]  output[48, 55, 55] (fractional sizes are floored)
            nn.ReLU(inplace=True),                                   # inplace saves memory so a larger model fits
            nn.MaxPool2d(kernel_size=3, stride=2),                   # output[48, 27, 27]  kernel_num is half of the original paper
            nn.Conv2d(48, 128, kernel_size=5, padding=2),            # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                   # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),           # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),           # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),           # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                   # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(                             # fully connected layers with dropout
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)                            # flatten; view() would also work
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')  # Kaiming He initialization
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)                   # initialize weights from a normal distribution
                nn.init.constant_(m.bias, 0)
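A quick sanity check, not part of the original post, is to instantiate the model and push a dummy batch through it to confirm the output shape:

import torch
from model import AlexNet

net = AlexNet(num_classes=5, init_weights=True)
x = torch.randn(1, 3, 224, 224)   # dummy RGB image batch
print(net(x).shape)               # expected: torch.Size([1, 5])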
Download the dataset from http://download.tensorflow.org/example_images/flower_photos.tgz.
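If you prefer to script the download and extraction, a minimal sketch using only the standard library is shown below; it assumes the same flower_data/ layout that split_data.py expects.

import os
import tarfile
import urllib.request

url = "http://download.tensorflow.org/example_images/flower_photos.tgz"
os.makedirs("flower_data", exist_ok=True)
archive = "flower_data/flower_photos.tgz"
if not os.path.exists(archive):
    urllib.request.urlretrieve(url, archive)      # download the archive
with tarfile.open(archive, "r:gz") as tar:
    tar.extractall("flower_data")                 # creates flower_data/flower_photos/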
Run the following code to split the dataset into a training set and a validation set.
split_data.py
import os
from shutil import copy
import random


def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)


file = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]

mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/'+cla)

mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/'+cla)

split_rate = 0.1
for cla in flower_class:
    cla_path = file + '/' + cla + '/'
    images = os.listdir(cla_path)
    num = len(images)
    eval_index = random.sample(images, k=int(num*split_rate))  # 10% of each class goes to the validation set
    for index, image in enumerate(images):
        if image in eval_index:
            image_path = cla_path + image
            new_path = 'flower_data/val/' + cla
            copy(image_path, new_path)
        else:
            image_path = cla_path + image
            new_path = 'flower_data/train/' + cla
            copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # progress bar
    print()

print("processing done!")
Train the model: train.py
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time

# device: GPU or CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# data transforms
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot be 224, must be (224, 224)
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

# data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
data_root = os.getcwd()
image_path = data_root + "/flower_data/"  # flower data set path

train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=True,
                                              num_workers=0)

test_data_iter = iter(validate_loader)
test_image, test_label = next(test_data_iter)  # test_data_iter.next() was removed in recent PyTorch
# print(test_image[0].size(), type(test_image[0]))
# print(test_label[0], test_label[0].item(), type(test_label[0]))

# To display images, first set batch_size of validate_loader to 4
# def imshow(img):
#     img = img / 2 + 0.5  # unnormalize
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1, 2, 0)))
#     plt.show()
#
# print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
# imshow(utils.make_grid(test_image))

net = AlexNet(num_classes=5, init_weights=True)
net.to(device)

# loss function: cross entropy
loss_function = nn.CrossEntropyLoss()
# optimizer: Adam
optimizer = optim.Adam(net.parameters(), lr=0.0002)
# path for saving the trained weights
save_path = './AlexNet.pth'
# best validation accuracy seen during training
best_acc = 0.0

# train for one epoch, then validate, and repeat
for epoch in range(10):
    # train
    net.train()  # enable the dropout layers defined in the network
    running_loss = 0.0
    t1 = time.perf_counter()
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        # print training progress
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
    print()
    print(time.perf_counter() - t1)

    # validate
    net.eval()  # disable dropout during evaluation, use all neurons
    acc = 0.0  # accumulate the number of correct predictions per epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, val_accurate))

print('Finished Training')
The training output is as follows.
Test with a sunflower image: predict.py
import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image
img = Image.open("./sunflower.jpg")  # test a sunflower image
# img = Image.open("./roses.jpg")    # test a rose image
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model
model = AlexNet(num_classes=5)
# load model weights
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()
The prediction output is as follows.