Python+Android进行TensorFlow开发

Python+Android进行TensorFlow开发,第1张

概述Tensorflow是Google开源的一套机器学习框架,支持GPU、CPU、Android等多种计算平台。本文将介绍在Tensorflow在Android上的使用。Android使用Tensorflow框架需要引入两个文件libtensorflow_inference.so、libandroid_tensorflow_inference_java.jar。这两个文件可以使用官方预编译

Tensorflow是Google开源的一套机器学习框架,支持GPU、cpu、AndroID等多种计算平台。本文将介绍在Tensorflow在AndroID上的使用。

AndroID使用Tensorflow框架需要引入两个文件libtensorflow_inference.so、libandroID_tensorflow_inference_java.jar。这两个文件可以使用官方预编译的文件。如果预编译的so不满足要求(比如不支持训练模型中的某些 *** 作符运算),也可以自己通过bazel编译生成这两个文件。
将libandroID_tensorflow_inference_java.jar放在app下的libs目录下,so文件命名为libtensorflow_jni.so放在src/main/jnilibs目录下对应的ABI文件夹下。目录结构如下:


AndroID目录结构

同时在app的build.gradle中的dependencIEs模块下添加如下配置:

dependencIEs {    ...    compile files('libs/libandroID_tensorflow_inference_java.jar')    ...}

使用tensorflow框架进行机器学习分为四个步骤:

构造神经网络

训练神经网络模型

将训练好的模型输出为pb文件

ndroID上加载pb模型进行计算

前三步是模型的构造,我们通过python实现,下面给出了一个二分类的简单模型的构造过程,首先是训练过程:

# -*-Coding:utf-8 -*-from __future__ import print_functionimport osimport tensorflow as tffrom numpy.random import RandomStateos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'"""训练模型"""def train():    # 定义训练数据集batch大小为8    batch_size = 8    # 定义神经网络参数,参数体现出神经网络结构,一个输入层,一个输出层,一个隐藏层    w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1), name="w1_val")    w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1), name="w2_val")    # 定义输入输出格式    x = tf.placeholder(tf.float32, shape=(None, 2), name='x_input')    y_ = tf.placeholder(tf.float32, shape=(None, 1))    # 定义神经网络前向传播过程    a = tf.matmul(x, w1)    y = tf.matmul(a, w2, name="cal_node")    # 定义交叉熵和反向传播算法    cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))    train_step = tf.train.AdadeltaOptimizer(0.001).minimize(cross_entropy)    # 生成随机训练集    rdm = RandomState(1)    dataset_size = 128    # 定义映射关系    X = rdm.rand(dataset_size, 2)    Y = [[int(x1 + x2 < 1)] for (x1, x2) in X]    with tf.Session() as sess:        # 初始化所有参数        init_op = tf.global_variables_initializer()        sess.run(init_op)        # print sess.run(w1)        # print sess.run(w2)        STEPS = 500        for i in range(STEPS):            start = (i * batch_size) % dataset_size            end = min(start + batch_size, dataset_size)            # 训练神经网络,更新神经网络参数            sess.run(train_step, Feed_dict={x: X[start:end], y_: Y[start:end]})            if i % 100 == 0:                total_cross_entropy = sess.run(cross_entropy, Feed_dict={x: X, y_: Y})                print("After %d training step(s), cross entropy on all data is %g" % (i, total_cross_entropy))            print(sess.run(w1))            print(sess.run(w2))        # 保存check point        saver = tf.train.Saver(tf.trainable_variables())        saver.save(sess, './model/checpt')

上面的代码首先定义神经网络,初始化训练数据,进行500次训练过程,并将训练结果checkpoints保存到model文件夹下,checkpoints包含了训练模型得到的参数信息,共生成四个相关的文件,如下图:

由于checkpoint文件众多,为了方便使用,我们通过下面的代码将它们生成一个pb文件,在androID上只需要这个pb文件即可使用这个训练好的模型:

"""存储pb模型"""def dump_graph_to_pb(pb_path):    with tf.Session() as sess:        check_point = tf.train.get_checkpoint_state("./model/")        if check_point:            saver = tf.train.import_Meta_graph(check_point.model_checkpoint_path + '.Meta')            saver.restore(sess, check_point.model_checkpoint_path)        else:            raise ValueError("Model load Failed from {}".format(check_point.model_checkpoint_path))        graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), "cal_node".split(","))        with tf.gfile.Gfile(pb_path, "wb") as f:            f.write(graph_def.SerializetoString())

拿到生成的pb模型,我们可以在androID上使用了。将pb文件在这main/assets下:

接下来就可以载入pb,进行计算了:

public class MainActivity extends AppCompatActivity {    private Graph graph_;    private Session session_;    private AssetManager assetManager;    private static ExecutorService executorService;    private static Handler handler;    @OverrIDe    protected voID onCreate(Bundle savedInstanceState) {        super.onCreate(savedInstanceState);        setContentVIEw(R.layout.activity_main);        executorService = Executors.newFixedThreadPool(5);        // 初始化tensorflow        initTensorFlow("outmodel.pb");        // 使用tensorflow进行计算        runTensorFlow();    }    ...}

通过如下方式载入pb模型,初始化tensorflow:

private boolean initTensorFlow(String modelfile) {        assetManager = getAssets();        // 新建Graph        graph_ = new Graph();        inputStream is = null;        try {            // 读取Assets pb文件            is = assetManager.open(modelfile);        } catch (IOException e) {            e.printstacktrace();            return false;        }        try {            // 加载pb到Graph            TensorUtil.loadGraph(is, graph_);            is.close();        } catch (IOException e) {            e.printstacktrace();            return false;        }        // 初始化session        session_ = new Session(graph_);        if (session_ == null) {            return false;        }        return true;    }

然后就可以使用tensorflow API进行运算了:

private voID runTensorFlow() {        executorService.execute(generatePredictRunnable(handler));    }    private Runnable generatePredictRunnable(Handler handler) {        return new Runnable() {            @OverrIDe            public voID run() {                float[][] input = new float[1][2];                input[0][0] = 1;                input[0][1] = 2;                // 定义输入tensor                Tensor inputTensor = Tensor.create(input);                // 指定输入,输出节点,运行并得到结果                Tensor resultTensor = session_.runner()                        .Feed("x_input", inputTensor)                        .fetch("cal_node")                        .run()                        .get(0);                float[][] dst = new float[1][1];                resultTensor.copyTo(dst);                // 处理结果                ArrayList<float> resultList = new ArrayList<>();                for (float val : dst[0]) {                    if (val != 0) {                        resultList.add(val);                    } else {                        break;                    }                }            }        };    }

上面就是通过python训练机器学习模型,并在androID平台进行调用的完整流程。

原创作者:JackMeGo,原文链接:https://www.jianshu.com/p/eef4ab014a12


欢迎关注我的微信公众号「码农突围」,分享Python、Java、大数据、机器学习、人工智能等技术,关注码农技术提升•职场突围•思维跃迁,20万+码农成长充电第一站,陪有梦想的你一起成长。

总结

以上是内存溢出为你收集整理的Python+Android进行TensorFlow开发全部内容,希望文章能够帮你解决Python+Android进行TensorFlow开发所遇到的程序开发问题。

如果觉得内存溢出网站内容还不错,欢迎将内存溢出网站推荐给程序员好友。

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/web/1064367.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-26
下一篇 2022-05-26

发表评论

登录后才能评论

评论列表(0条)

保存