C++调用Libtorch常见函数

C++调用Libtorch常见函数,第1张

C++调用Libtorch常见函数
#创建变量 
std::vector inputs;
inputs.push_back(torch::ones({ 1,3,224,224 }));
torch::jit::IValue inputs;

#定义模型变量
torch::jit::script::Module model = torch::jit::load("path");

at::Tensor output = model.forward(inputs).toTensor();
#获取尺寸

ouput.sizes()
int heigh = output.size(0);
int weight = output.size(1);

torch::Tensor out_tensor = output.detach(); # requires_grad为false,
out_tensor = out_tensor.squeeze().detach().permute({ 1, 2, 0 });
// squeeze 减少图像尺寸 permute 交换维度
out_tensor = out_tensor.mul(255).clamp(0, 255).to(torch::kU8); //*255,转uint8 
out_tensor = out_tensor.to(torch::kCPU); //迁移至CPU
cv::Mat resultImg(img_h, img_w, CV_8UC3, out_tensor.data_ptr()); // 将Tensor数据拷贝至Mat
// cv::cvtColor(resultImg, resultImg, CV_RGB2BGR); 


#
cv::Mat tensor2Mat(torch::Tensor &i_tensor)
{
	int height = i_tensor.size(0), width = i_tensor.size(1);
	//i_tensor = i_tensor.to(torch::kF32);
	i_tensor = i_tensor.to(torch::kCPU);
	cv::Mat o_Mat(cv::Size(width, height), CV_32F, i_tensor.data_ptr());
	return o_Mat;
}

欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/zaji/5713950.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-17
下一篇 2022-12-17

发表评论

登录后才能评论

评论列表(0条)

保存