(一)思路简介
1.pyTorch会提供TorchScript,它可以生成最有效的C++可读的模型格式。
2.TorchScript的记录法支持Python控制流的子集,由TorchScript创建的中间阶段可以由LibTorch读取。
3.接下来,由LibTorch加载的模型可以保存为C++对象,并且可以在C++程序中运行。
对于TorchScript的记录法,可以从torch.jit.ScriptModule继承。这有助于PyTorch避免使用纯Python方法,而这些方法无法转换为TorchScript。
(二)待训练模型采用
CIFAR10,10分类
按上述源码训练后得到模型参数文件:saveTextOnlyParams.pth
(三)使用跟踪法将模型序列化:
import os.path
from typing import Iterator
import numpy as np
import torch
import cv2
from PIL import Image
from torch.utils.data import Dataset,DataLoader,Subset,random_split
import re
from functools import reduce
from torch.utils.tensorboard import SummaryWriter as Writer
from torchvision import transforms,datasets
import torchvision as tv
from torch import nn
import torch.nn.functional as F
import time
import onnx
import onnxruntime
#查看命令:tensorboard --logdir=./myBorderText
#可用pycharm中code中的generater功能实现:
class myCustomerNetWork(nn.Module):
def __init__(self):
super().__init__()
#输入3通道输出6通道:
self.features=nn.Sequential(nn.Conv2d(3, 64, (3, 3)),nn.ReLU(),nn.Conv2d(64,128,(3,3)),
nn.ReLU(),nn.Conv2d(128,256,(3,3)),nn.ReLU(),nn.AdaptiveAvgPool2d(1))
self.classfired=nn.Sequential(nn.Flatten(),nn.Linear(256,80),nn.Dropout(),nn.Linear(80,10))
def forward(self,x):
return self.classfired(self.features(x))
#网络输入要求为torch.Size([32, 3, 32, 32])格式
myNet=myCustomerNetWork()
pthfile = r'D:\flask_pytorch\saveTextOnlyParams.pth'
#当strict=false时,参数文件匹配得上就加载,没有就默认初始化。
myNet.load_state_dict(torch.load(pthfile),strict=False)
if torch.cuda.is_available():
myNet=myNet.cuda()
myNet.eval()
if __name__ == '__main__':
#对于opencv中的dnn模块,网络输入必须为(n,c,w,h)格式
imagePath = r"C:\Users360\Desktop\monodepth.jpeg"
img = cv2.imdecode(np.fromfile(imagePath, np.uint8), -1)
img = cv2.resize(img, (32, 32))
# bgr转rgb
img = img[:, :, ::-1].copy()
inputX = torch.FloatTensor(img).cuda()
inputX = inputX.permute(2, 0, 1).contiguous()
inputX = inputX.unsqueeze(0)
with torch.no_grad():
jit_model = torch.jit.trace(myNet, inputX)
# 将模型序列化
jit_model.save('jit_model.pth')
可用pytorch直接加载序列化模型并在python环境中使用:
if __name__ == '__main__':
#对于opencv中的dnn模块,网络输入必须为(n,c,w,h)格式
imagePath = r"C:\Users360\Desktop\monodepth.jpeg"
img = cv2.imdecode(np.fromfile(imagePath, np.uint8), -1)
img = cv2.resize(img, (32, 32))
# bgr转rgb
img = img[:, :, ::-1].copy()
inputX = torch.FloatTensor(img).cuda()
inputX = inputX.permute(2, 0, 1).contiguous()
inputX = inputX.unsqueeze(0)
with torch.no_grad():
jit_model = torch.jit.trace(myNet, inputX)
# 将模型序列化
#jit_model.save('jit_model.pth')
jit_model = torch.jit.load('jit_model.pth')
print(jit_model(inputX))
输出:
tensor([[ 71.6224, 10.6505, 165.0648, 313.5768, -148.1144, 329.7959,
109.9136, -266.1085, -171.0974, -272.6216]], device='cuda:0',
grad_fn=<AddmmBackward0>)
(四)使用记录法将模型序列化:
import os.path
from typing import Iterator
import numpy as np
import torch
import cv2
from PIL import Image
from torch.utils.data import Dataset,DataLoader,Subset,random_split
import re
from functools import reduce
from torch.utils.tensorboard import SummaryWriter as Writer
from torchvision import transforms,datasets
import torchvision as tv
from torch import nn
import torch.nn.functional as F
import time
import onnx
import onnxruntime
#查看命令:tensorboard --logdir=./myBorderText
#可用pycharm中code中的generater功能实现:
#向模型添加显式注释:
class myCustomerNetWork(torch.jit.ScriptModule):
def __init__(self):
super().__init__()
#输入3通道输出6通道:
self.features=nn.Sequential(nn.Conv2d(3, 64, (3, 3)),nn.ReLU(),nn.Conv2d(64,128,(3,3)),
nn.ReLU(),nn.Conv2d(128,256,(3,3)),nn.ReLU(),nn.AdaptiveAvgPool2d(1))
self.classfired=nn.Sequential(nn.Flatten(),nn.Linear(256,80),nn.Dropout(),nn.Linear(80,10))
@torch.jit.script_method
def forward(self,x):
return self.classfired(self.features(x))
#网络输入要求为torch.Size([32, 3, 32, 32])格式
myNet=myCustomerNetWork()
pthfile = r'D:\flask_pytorch\saveTextOnlyParams.pth'
#当strict=false时,参数文件匹配得上就加载,没有就默认初始化。
myNet.load_state_dict(torch.load(pthfile),strict=False)
if torch.cuda.is_available():
myNet=myNet.cuda()
myNet.eval()
if __name__ == '__main__':
# 将模型序列化
myNet.save('jit_model2.pth')
可用pytorch直接加载序列化模型并在python环境中使用:
if __name__ == '__main__':
#对于opencv中的dnn模块,网络输入必须为(n,c,w,h)格式
imagePath = r"C:\Users360\Desktop\monodepth.jpeg"
img = cv2.imdecode(np.fromfile(imagePath, np.uint8), -1)
img = cv2.resize(img, (32, 32))
# bgr转rgb
img = img[:, :, ::-1].copy()
inputX = torch.FloatTensor(img).cuda()
inputX = inputX.permute(2, 0, 1).contiguous()
inputX = inputX.unsqueeze(0)
# with torch.no_grad():
# jit_model = torch.jit.trace(myNet, inputX)
# 将模型序列化
#myNet.save('jit_model2.pth')
jit_model = torch.jit.load('jit_model2.pth')
print(jit_model(inputX))
输出:
tensor([[ 71.6224, 10.6505, 165.0648, 313.5768, -148.1144, 329.7959,
109.9136, -266.1085, -171.0974, -272.6216]], device='cuda:0',
grad_fn=<AddmmBackward0>)
未完待续。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)