tensorflow自定义算子开发1:CPU实例

tensorflow自定义算子开发1:CPU实例,第1张

本文将介绍如果用C++在tensorflow中新建一个算子,参考官方文档通过一个简单的例子来说明。 *** 作系统是Ubuntu,且系统已经安装tensorflow。

首先,创建一个名为 zero_out.cc 的文件,所有内容均在本文件中实现

定义运算接口

对于一个新的 *** 作,首先要在C++中定义这个 *** 作,通过将接口注册到 TensorFlow 系统来定义运算的接口。注册中需要指定该运算的名称、输入类型和名称以及输出类型和名称,还有文档字符串和该运算可能需要的任意特性。下面给出示例:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

这里给出了具体的注册过程,这里调用REGISTER_OP宏注册了一个ZeroOut的 *** 作,输入命名为to_zero,类型为int32,输出命名zeroed,类型为int32,最后set_output用来保证输入输出的维度是一致的。

实现运算内核

定义完接口之后,可以为此 *** 作定义一个或多个内核实现,内核的实现需要继承OpKernel类,并且重载Compute方法,Compute参数中类型为OpKernelContext* 的参数context,从中可以访问输入输出张量等有用信息。内核代码如下:

#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // 得到输入张量
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat();

    // 创建输出张量
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat();

    // 除了第一个元素,其他元素全部置为0
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // 如果输入维度大于0,输出第1维度等于输入第1维度
    if (N > 0) output_flat(0) = input(0);
  }
};
内核注册

内核实现后,需要将其注册到tensorflow系统,需要指定内核执行时的不同约束条件,例如针对CPU和GPU通常会有两个内核。

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

需要注意OpKernel可能并行方位,需要保证线程安全。这里只给出了CPU内核的实现,GPU内核将在下一节介绍。

完整代码

给出完整代码:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });


class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat();

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat();

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible.
    if (N > 0) output_flat(0) = input(0);
  }
};
构建运算库

这里采用系统编译库来实现,采用g++编译器,且系统已经安装二进制tensorflow, PIP 包管理器来安装二进制 TensorFlow 时,已经包含了编译 *** 作所需的头文件和库文件。但是,TensorFlow Python 库已经提供了 get_include 函数来获取头文件目录,以及 get_lib 函数来回去库目录。 U可以测试如下函数的输出

$ python
>>> import tensorflow as tf
>>> tf.sysconfig.get_include()
/usr/local/lib/python3.5/dist-packages/tensorflow/include
>>> tf.sysconfig.get_lib()
'/usr/local/lib/python3.6/site-packages/tensorflow'
构建

基于tensorflow的两个接口函数,我们用g++来编译新的算子

TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
g++ -std=c++11 -shared zero_out.cc -o zero_out.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2

正常情况下会生成动态库:

python中使用

tensorflow的python接口,提供了函数 tf.load_op_library 来加载动态库并向tensorflow 框架注册运算。load_op_library 会返回一个 python 模块,其中包含运算和内核的 Python 封装容器。因此,在构建此运算后,就可以执行刚才定义的算子

import tensorflow as tf
zero_out_module = tf.load_op_library('./zero_out.so')
sess = tf.Session()
result = zero_out_module.zero_out([[1, 2], [3, 4]])
print("****************")
print(sess.run(result))

执行结果如下:

[[1 0]
 [0 0]]

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/876281.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-13
下一篇 2022-05-13

发表评论

登录后才能评论

评论列表(0条)

保存