TensorRT: a simple example using the ONNX parser

This example uses TensorRT's ONNX parser to build an engine from a VGG-16 ONNX model, serializes the engine to disk ('s' mode), then deserializes it and runs inference on a dummy all-ones input ('d' mode).

#include "NvInfer.h"
#include 
#include "NvOnnxParser.h"
#include 
#include "cuda_runtime_api.h"
#include 

using namespace nvonnxparser;
using namespace nvinfer1;
using namespace std;

class Logger : public ILogger
{
    void log(Severity severity, const char* msg) noexcept override
    {
        // suppress info-level messages
        if (severity <= Severity::kWARNING)
            std::cout << msg << std::endl;
    }
} logger;


int main(int argc, char** argv) {

    if (argc < 2) {
        cout << "usage: ./testrt s|d" << endl;
        return -1;
    }

    if (argv[1][0] == 's') {
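        // --- 's' mode: build an engine from the ONNX model and serialize it ---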
        IBuilder *builder = createInferBuilder(logger);

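        // TensorRT 8 requires explicit-batch network definitions.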
        uint32_t flag = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);

        INetworkDefinition *network = builder->createNetworkV2(flag);

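        // Attach an ONNX parser to the network definition.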
        IParser *parser = createParser(*network, logger);

        parser->parseFromFile("/home/wjd1/Downloads/vgg16-7.onnx",
                              static_cast<int32_t>(ILogger::Severity::kWARNING));
        for (int32_t i = 0; i < parser->getNbErrors(); ++i) {
            std::cout << parser->getError(i)->desc() << std::endl;
        }

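        // Builder configuration. Note: with kEXPLICIT_BATCH, setMaxBatchSize is
        // effectively ignored, and setMaxWorkspaceSize is deprecated in later
        // TensorRT releases in favor of setMemoryPoolLimit.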
        IBuilderConfig *config = builder->createBuilderConfig();

        builder->setMaxBatchSize(1);
        config->setMaxWorkspaceSize(1 << 30);

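        // Build and serialize the engine in a single call (TensorRT 8 API).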
        IHostMemory *serializedModel = builder->buildSerializedNetwork(*network, *config);

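        // Persist the serialized engine so the 'd' path can reload it.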
        ofstream p("/home/wjd1/tensorrttest/test.engine", ios::binary);
        if (!p.good()) {
            cout << "open failed" << endl;
        }
        p.write(reinterpret_cast<const char *>(serializedModel->data()), serializedModel->size());
        delete parser;
        delete network;
        delete config;
        delete builder;
        delete serializedModel;

    } else if (argv[1][0] == 'd') {
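        // --- 'd' mode: load the serialized engine and run inference ---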
        // Read the serialized engine file into memory.
        size_t size{0};
        char* trtModelStream{nullptr};
        ifstream file("/home/wjd1/tensorrttest/test.engine", ios::binary);
        if (file.good()) {
            file.seekg(0, file.end);
            size = file.tellg();
            file.seekg(0, file.beg);
            trtModelStream = new char[size];
            file.read(trtModelStream, size);
            file.close();
        }


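        // Recreate the engine and an execution context from the serialized bytes.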
        IRuntime* runtime = createInferRuntime(logger);
        ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size);
        IExecutionContext* context = engine->createExecutionContext();
        delete[] trtModelStream;


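        // Device buffers are looked up by binding name; "data" and
        // "vgg0_dense2_fwd" are the input/output tensor names of this
        // particular VGG-16 ONNX export.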
        void* buffers[2];
        const char* INPUT_NAME = "data";
        const char* OUTPUT_NAME = "vgg0_dense2_fwd";
        int32_t inputIndex = engine->getBindingIndex(INPUT_NAME);
        int32_t outputIndex = engine->getBindingIndex(OUTPUT_NAME);

        // constexpr so the host arrays below have compile-time sizes
        // (the original non-const ints made them non-standard VLAs).
        constexpr int batchSize = 1;
        constexpr int INPUT_H = 224;
        constexpr int INPUT_W = 224;
        constexpr int OUTPUT_SIZE = 1000;

        cudaMalloc(&buffers[inputIndex], batchSize * 3 * INPUT_H * INPUT_W * sizeof(float));
        cudaMalloc(&buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float));

        float data[3 * INPUT_H * INPUT_W];
        float prob[OUTPUT_SIZE];

        // Fill the input with a dummy all-ones image.
        for (int i = 0; i < 3 * INPUT_H * INPUT_W; ++i) {
            data[i] = 1;
        }

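        // Copy input to the device and run inference asynchronously.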
        cudaStream_t stream;
        cudaStreamCreate(&stream);

        cudaMemcpyAsync(buffers[inputIndex], data, batchSize * 3 * INPUT_H * INPUT_W * sizeof(float), cudaMemcpyHostToDevice, stream);
        context->enqueueV2(buffers, stream, nullptr);

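        // Copy the result back and wait for the stream to finish.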
        cudaMemcpyAsync(prob, buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream);
        cudaStreamSynchronize(stream);

        cudaStreamDestroy(stream);
        cudaFree(buffers[inputIndex]);
        cudaFree(buffers[outputIndex]);

        delete context;
        delete engine;
        delete runtime;

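        // Dump all 1000 raw logits (see the argmax sketch after the listing).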
        for (int i = 0; i < OUTPUT_SIZE; ++i) {
            cout << prob[i] << ",";
        }
        cout << endl;


    }

    return 0;
}
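The program prints all 1000 raw logits, which is hard to eyeball. As a purely illustrative addition (not part of the original program), the top-1 class can be recovered with std::max_element; prob and OUTPUT_SIZE refer to the variables in the listing above:

#include <algorithm>

// Hypothetical helper: index of the largest logit, i.e. the predicted class.
static int argmax(const float* logits, int n) {
    return static_cast<int>(std::max_element(logits, logits + n) - logits);
}

// Usage, after cudaStreamSynchronize(stream):
//   int best = argmax(prob, OUTPUT_SIZE);
//   cout << "top-1 class: " << best << " logit: " << prob[best] << endl;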
The CMakeLists.txt used to build the example:

cmake_minimum_required(VERSION 3.0)
project(test)
aux_source_directory(. src)

set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/)
set(CMAKE_BUILD_TYPE Debug)

include_directories(/usr/local/cuda-11.3/include)
link_directories(/usr/local/cuda-11.3/lib64)

include_directories(/home/wjd1/TensorRT-8.0.0.3/include)
link_directories(/home/wjd1/TensorRT-8.0.0.3/lib)

add_executable(testrt ${src})

target_link_libraries(testrt cudart cudnn cublas nvinfer nvparsers nvinfer_plugin nvonnxparser nvrtc)
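With the paths above adjusted to your own CUDA and TensorRT installs (the listing assumes /usr/local/cuda-11.3 and /home/wjd1/TensorRT-8.0.0.3), the typical flow is: configure and build with cmake and make, run ./testrt s once to parse the ONNX model and write test.engine, then run ./testrt d to deserialize the engine and run inference on the all-ones input.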

Original article: http://outofmemory.cn/langs/675979.html