import torch import torch.nn as nn import torchvision from #所在程序名字 import #自己网络保存的名字 import matplotlib.pyplot as plt from PIL import Image from torchvision import transforms import numpy as np def test_mydata(): im = plt.imread('33333.jpg') #自己手写的图片读入 images = Image.open('33333.jpg') #可以转换为黑底白字,读取更准确 images = images.resize((28,28)) images = images.convert('L') transform = transforms.ToTensor() images = transform(images) images = images.resize(1,1,28,28) # 加载网络和参数 model = #加上自己网络的名字() model.load_state_dict(torch.load('pathh')) model.eval() outputs = model(images) values, indices = outputs.data.max(1) # 返回最大概率值和下标 plt.title('{}'.format((int(indices[0])))) plt.imshow(im) plt.show() test_mydata()
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)