欢迎关注我的公众号 [极智视界],回复001获取Google编程规范
O_o >_< o_O O_o ~_~ o_O
大家好,我是极智视界,本文介绍一下 TensorRT8 自定义算子 Plugin 实现方法。
TensorRT8 中对于 Plugin 的开发做了较大变更,这里介绍一下。
在 TensorRT8 中实现 plugin 需要构造两个类:custom_plugin、create_plugin,在 create_plugin 中调用 custom_plugin。下面开始。
文章目录1、构造 custom_plugin2、构造 create_plugin3、调用
1、构造 custom_pluginClass custom_plugin 继承于 IPluginV2Ext or IPluginV2 or IPluginV2DynamicExt。
class ClipPlugin : public IPluginV2 { public: ClipPlugin(const std::string name, float clipMin, float clipMax); ClipPlugin(const std::string name, const void* data, size_t length); // ClipPlugin 不带参是没有意义的,所以删除了默认构造. ClipPlugin() = delete; int getNbOutputs() const noexcept override; Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept override; int initialize() noexcept override; void terminate() noexcept override; size_t getWorkspaceSize(int) const noexcept override { return 0; }; int enqueue(int batchSize, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; // 通过 cuda c 实现 size_t getSerializationSize() const noexcept override; void serialize(void* buffer) const noexcept override; void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) noexcept override; bool supportsFormat(DataType type, PluginFormat format) const noexcept override; const char* getPluginType() const noexcept override; const char* getPluginVersion() const noexcept override; void destroy() noexcept override; nvinfer1::IPluginV2* clone() const noexcept override; // addPluginV2时会调用,该方法会去调用构造函数 void setPluginNamespace(const char* pluginNamespace) noexcept override; const char* getPluginNamespace() const noexcept override; private: // 算子参数 const std::string mLayerName; float mClipMin, mClipMax; size_t mInputVolume; std::string mNamespace; };
2、构造 create_plugin
class ClipPluginCreator : public IPluginCreator { public: ClipPluginCreator(); const char* getPluginName() const noexcept override; const char* getPluginVersion() const noexcept override; const PluginFieldCollection* getFieldNames() noexcept override; IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override; //实现该方法时调用 custom_plugin IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override; void setPluginNamespace(const char* pluginNamespace) noexcept override; const char* getPluginNamespace() const noexcept override; private: static PluginFieldCollection mFC; static std::vectormPluginAttributes; std::string mNamespace; };
3、调用
官方给的方法是通过 create_plugin 去调用 custom_plugin,其本质上就是调用 custom_plugin,可以按照下面来写:
nvinfer1::IPluginV2 *clip = new ClipPlugin(scale, 512, Dtype); nvinfer1::IPluginV2Layer *Clip = m_network->addPluginV2(&Layers[inputName], 1, *clip);
其中 ClipPlugin 就是 custom_plugin。
以上分享了 TensorRT8 中自定义算子 Plugin 的实现方法,希望我的分享能对你的学习有一点帮助。
【公众号传送】
《【模型推理】TensorRT8 自定义算子 Plugin 实现方法》
扫描下方二维码即可关注我的微信公众号【极智视界】,获取更多AI经验分享,让我们用极致+极客的心态来迎接AI !
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)