部署篇:使用pytorch实现垃圾分类并部署使用,浏览器访问,前端Vue,后台Flask

部署篇:使用pytorch实现垃圾分类并部署使用,浏览器访问,前端Vue,后台Flask,第1张

部署篇:使用pytorch实现垃圾分类并部署使用,浏览器访问,前端Vue,后台Flask

部署篇:使用pytorch实现垃圾分类并部署使用,浏览器访问,前端Vue,后台Flask
  • 1.创建Flask工程
  • 2.创建Vue工程

1.创建Flask工程
  • 打开pycharm,新建工程,选择Flask
  • 在工程中新建一个文件夹model用来存放之前训练的模型,从之前的训练工程中把model_best_checkpoint_resnet50.pth.tar和label.txt文件拷贝到model文件夹中,
  • 在工程根目录中新建一个python文件,命名为rubbish.py,这个文件用来调用模型进行预测。代码如下,接下来,只需要调用getRubbish这个方法,将图片传入就可以得到预测值。
from torchvision import models
import torch.nn as nn
from torchvision import transforms
import torch
import os

def getLabel(line_num):
    f = open("ai/label.txt",encoding='utf-8')
    if line_num == 1:
        line = f.readline()
    else:
        line = f.readline()
        while line:
            line = f.readline()
            line_num -= 1
            if (line_num == 1):
                break
    return line

def getRubbish(img):
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
    t = transforms.Compose([transforms.ToPILImage(),
                            transforms.Resize((224, 224)),
                            transforms.ToTensor()])
    num_classes = 123
    # 加载数据
    map_location = "cpu"
    model = models.resnet50(pretrained=False)
    fc_inputs = model.fc.in_features
    model.fc = nn.Linear(fc_inputs, num_classes)
    resume = 'model/model_best_checkpoint_resnet50.pth.tar'
    checkpoint = torch.load(resume, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    image = t(img).unsqueeze(0)
    with torch.no_grad():
        pred = model(image)
        _, pred = torch.max(pred, 1)
        pred = getLabel(int(str(pred.data)[8:-2])+1)
    return pred
  • 在根目录的app.py文件中,创建前端调用的接口,代码如下,前端只需要发送post请求,将图片传过来,就会得到一个json数据,其中data保存的就是预测值。
# 垃圾分类
@app.route('/rubbish', methods=['post'])
def rubbish():
    img = request.files.get('file')
    # 转成BufferedReader格式
    img_buff = BufferedReader(img)
    # 得到byte数据
    img_byte = BufferedReader.read(img_buff)
    # 转成numpy数组
    nparr = np.frombuffer(img_byte, dtype=np.uint8)
    # 转成cv2.imread相同效果的数据
    img_decode = cv2.imdecode(nparr, 1)
    pred = getRubbish(img_decode)
    result = {'data': pred}
    return json.dumps(result, ensure_ascii=False)
2.创建Vue工程
  • vue的安装网上教程很多,随便找个看看安装一下就好,这里不再赘述
  • 安装完Vue环境后,创建一个Vue工程,由于需要使用elementUI的上传组件,所以在安装一下elementUI
  • 新建一个页面,在这个页面中实现图片的上传和结果的展示,具体怎么实现,可以根据你们自己的想法,下面是我自己写的一个简单页面