PyTorch & Android Development Notes (Part 1)


1. PyTorch Installation

Download the .whl files and install them locally with pip install.

Pitfall: the tar.bz2 packages have no setup.py. They can be installed with conda, but pip install fails on them, and the installation is very slow.

PyTorch & torchvision .whl download page

PyTorch & torchvision version compatibility table

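The project requires mobilenet_v3_small, which is not available in older TorchVision releases: it needs torchvision 0.9.0 at minimum, and 0.10.0 (stable) is the better choice. I went straight to torch==1.9.0, which pairs with torchvision==0.10.0, so the files are torch-1.9.0+cu111-cp37-cp37m-win_amd64.whl and torchvision-0.10.0+cu111-cp37-cp37m-win_amd64.whl (CUDA 11.1 build, CPython 3.7, 64-bit Windows).

To install, open cmd, cd into the directory containing the .whl files, run pip install "torch-1.9.0+cu111-cp37-cp37m-win_amd64.whl", then install the torchvision wheel the same way.

If pip reports that torchvision or another package has a mismatched version, pip uninstall that package first and then install again.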
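The version matching above can be checked mechanically before installing. This is a minimal sketch, assuming the wheel names follow the standard name-version-pythontag-abitag-platform.whl pattern (PEP 427, without the optional build tag). parse_wheel_name, wheels_match and TORCH_TO_TORCHVISION are hypothetical helpers, and the table deliberately holds only two known pairs; look anything else up in the official compatibility table.

```python
from typing import NamedTuple


class WheelTag(NamedTuple):
    name: str          # distribution name, e.g. "torch"
    version: str       # e.g. "1.9.0+cu111"; the part after "+" is the CUDA build
    python_tag: str    # e.g. "cp37" = CPython 3.7
    abi_tag: str       # e.g. "cp37m"
    platform_tag: str  # e.g. "win_amd64"


def parse_wheel_name(filename: str) -> WheelTag:
    """Split name-version-python-abi-platform.whl (no build tag) into its parts."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel file: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) != 5:
        raise ValueError(f"unexpected wheel name: {filename}")
    return WheelTag(*parts)


# Hypothetical, intentionally incomplete torch -> torchvision pairing table.
TORCH_TO_TORCHVISION = {"1.8.0": "0.9.0", "1.9.0": "0.10.0"}


def wheels_match(torch_whl: str, vision_whl: str) -> bool:
    """True if the two wheels are a known pair built for the same CUDA/Python/OS."""
    t = parse_wheel_name(torch_whl)
    v = parse_wheel_name(vision_whl)
    t_ver, _, t_cuda = t.version.partition("+")
    v_ver, _, v_cuda = v.version.partition("+")
    return (
        TORCH_TO_TORCHVISION.get(t_ver) == v_ver
        and t_cuda == v_cuda                  # same CUDA build, e.g. cu111
        and t.python_tag == v.python_tag      # same Python version
        and t.platform_tag == v.platform_tag  # same OS / architecture
    )
```

For the two files chosen above, wheels_match returns True; swapping in a torchvision 0.9.0 wheel makes it return False.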


2. HelloWorld

Requirements: Android SDK & Android NDK

  • Pretrained model setup test
import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

model = torchvision.models.mobilenet_v3_small(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model.save("model.pt")

This writes model.pt (about 9.7 MB) to the directory the script is run from.

  • Android test
git clone https://github.com/pytorch/android-demo-app.git
cd android-demo-app/HelloWorldApp

Replace HelloWorldApp/app/src/main/assets/model.pt with the model.pt generated in the first step.

Then open the HelloWorldApp project in Android Studio and choose Build > Build Bundle(s) / APK(s) > Build APK(s).

If the build fails with "Could not find org.pytorch:pytorch_android:1.8.0-SNAPSHOT", open app/build.gradle and change it as follows:

apply plugin: 'com.android.application'

android {
    compileSdkVersion 28
    buildToolsVersion "29.0.2"
    defaultConfig {
        applicationId "org.pytorch.helloworld"
        minSdkVersion 21
        //noinspection ExpiredTargetSdkVersion
        targetSdkVersion 28
        versionCode 1
        versionName "1.0"
    }
    buildTypes {
        release {
            minifyEnabled false
        }
    }
}

dependencies {
    implementation 'androidx.appcompat:appcompat:1.1.0'
    implementation 'org.pytorch:pytorch_android_lite:1.9.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
}

Then modify app/src/main/java/org/pytorch/helloworld/MainActivity.java:

package org.pytorch.helloworld;

import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;

import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import org.pytorch.MemoryFormat;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

import androidx.appcompat.app.AppCompatActivity;

public class MainActivity extends AppCompatActivity {

  @Override
  protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_main);

    Bitmap bitmap = null;
    Module module = null;
    try {
      // creating bitmap from packaged into app android asset 'image.jpg',
      // app/src/main/assets/image.jpg
      bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
      // loading serialized torchscript module from packaged into app android asset model.pt,
      // app/src/main/assets/model.pt
      module = LiteModuleLoader.load(assetFilePath(this, "model.pt"));
    } catch (IOException e) {
      Log.e("PytorchHelloWorld", "Error reading assets", e);
      finish();
      // finish() does not end onCreate; return so the code below never uses a null bitmap/module
      return;
    }

    // showing image on UI
    ImageView imageView = findViewById(R.id.image);
    imageView.setImageBitmap(bitmap);

    // preparing input tensor
    final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
        TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);

    // running the model
    final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();

    // getting tensor content as java array of floats
    final float[] scores = outputTensor.getDataAsFloatArray();

    // searching for the index with maximum score
    float maxScore = -Float.MAX_VALUE;
    int maxScoreIdx = -1;
    for (int i = 0; i < scores.length; i++) {
      if (scores[i] > maxScore) {
        maxScore = scores[i];
        maxScoreIdx = i;
      }
    }

    String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];

    // showing className on UI
    TextView textView = findViewById(R.id.text);
    textView.setText(className);
  }

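  // Copies an asset into the app's files directory (only once) and returns its absolute path,
  // because LiteModuleLoader.load() needs a real file path, not a stream from inside the APK.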
  public static String assetFilePath(Context context, String assetName) throws IOException {
    File file = new File(context.getFilesDir(), assetName);
    if (file.exists() && file.length() > 0) {
      return file.getAbsolutePath();
    }

    try (InputStream is = context.getAssets().open(assetName)) {
      try (OutputStream os = new FileOutputStream(file)) {
        byte[] buffer = new byte[4 * 1024];
        int read;
        while ((read = is.read(buffer)) != -1) {
          os.write(buffer, 0, read);
        }
        os.flush();
      }
      return file.getAbsolutePath();
    }
  }
}
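To make the pre- and post-processing in MainActivity concrete, here is a plain-Python sketch that needs no PyTorch. As I understand TensorImageUtils.bitmapToFloat32Tensor, each pixel channel becomes (value / 255 - mean) / std using the ImageNet mean/std below, and MemoryFormat.CHANNELS_LAST keeps the logical shape [1, 3, H, W] but stores the floats interleaved per pixel instead of one plane per channel. The function names are illustrative only, not part of any PyTorch API.

```python
import math
from typing import Sequence, Tuple

# ImageNet mean/std used by torchvision; TensorImageUtils.TORCHVISION_NORM_MEAN_RGB
# and TORCHVISION_NORM_STD_RGB hold the same values on Android.
MEAN_RGB = (0.485, 0.456, 0.406)
STD_RGB = (0.229, 0.224, 0.225)


def normalize_pixel(r: int, g: int, b: int) -> Tuple[float, ...]:
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip((r, g, b), MEAN_RGB, STD_RGB))


def offset_nchw(n: int, c: int, h: int, w: int, C: int, H: int, W: int) -> int:
    """Flat buffer index in the default contiguous layout (planar RRR..GGG..BBB..)."""
    return ((n * C + c) * H + h) * W + w


def offset_nhwc(n: int, c: int, h: int, w: int, C: int, H: int, W: int) -> int:
    """Flat buffer index in the channels-last layout (interleaved RGBRGB...)."""
    return ((n * H + h) * W + w) * C + c


def argmax(scores: Sequence[float]) -> int:
    """Same search as MainActivity: index of the largest score, -1 if empty."""
    best_idx, best = -1, -math.inf
    for i, s in enumerate(scores):
        if s > best:  # strict '>' keeps the first index on ties
            best, best_idx = s, i
    return best_idx
```

The index returned by argmax is what MainActivity looks up in ImageNetClasses.IMAGENET_CLASSES. The scores are raw model outputs, so the maximum is a ranking, not a probability.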

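The assetFilePath helper exists because assets are packed inside the APK and have no regular file path. The copy-once pattern it uses can be sketched in Python as follows; asset_file_path is a hypothetical name, and open_asset is an assumed stand-in for context.getAssets().open().

```python
import shutil
from pathlib import Path
from typing import BinaryIO, Callable


def asset_file_path(files_dir: Path, asset_name: str,
                    open_asset: Callable[[str], BinaryIO]) -> str:
    """Copy an asset into files_dir once and return the copy's absolute path."""
    target = files_dir / asset_name
    # Reuse an existing non-empty copy, exactly like the Java version.
    if target.exists() and target.stat().st_size > 0:
        return str(target.resolve())
    # Otherwise stream the asset into the file in 4 KiB chunks.
    with open_asset(asset_name) as src, open(target, "wb") as dst:
        shutil.copyfileobj(src, dst, 4 * 1024)
    return str(target.resolve())
```

Like the Java code, this only checks that the file is non-empty, so a copy interrupted halfway would be reused on the next start; writing to a temporary name and renaming it afterwards would avoid that.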
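Modify trace_model.py so the model is saved with _save_for_lite_interpreter. This is needed because the app now loads it with LiteModuleLoader (pytorch_android_lite), which expects the lite-interpreter format rather than the file written by save():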

import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

model = torchvision.models.mobilenet_v3_small(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model._save_for_lite_interpreter("app/src/main/assets/model.pt")

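Run the script in the Terminal from the HelloWorldApp directory (the output path is relative) to generate model.pt, 9.72 MB (10,199,082 bytes). Then click Sync Now, Build APK, and run the app in the emulator. See GitHub for details of the results.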


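3. Image Segmentation (image segmentation survey)

Original English PDF

CSDN Chinese translation

New features: take a photo, pick an image from the gallery, and Live (real-time) mode.

Declare the required permissions in AndroidManifest.xml with <uses-permission> entries (here CAMERA and READ_EXTERNAL_STORAGE, matching the runtime checks below).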



Check and request the permissions in onCreate of MainActivity.java:

// Request every permission that is still missing in ONE call (needs java.util.List/ArrayList imports).
// Two back-to-back requestPermissions() calls can cancel the second request while the first dialog is showing.
List<String> missing = new ArrayList<>();
for (String permission : new String[]{Manifest.permission.READ_EXTERNAL_STORAGE, Manifest.permission.CAMERA}) {
    if (ContextCompat.checkSelfPermission(this, permission) != PackageManager.PERMISSION_GRANTED) {
        missing.add(permission);
    }
}
if (!missing.isEmpty()) {
    ActivityCompat.requestPermissions(this, missing.toArray(new String[0]), 1);
}

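Fixing app crashes when selecting an image:

(1) First check whether the device's SDK level is 19 (Android 4.4) or higher;

(2) if it is, use Intent.ACTION_PICK to select the image;

(3) below 19, use Intent.ACTION_GET_CONTENT to select the image.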

Intent intent;
if (Build.VERSION.SDK_INT < 19) {
    intent = new Intent(Intent.ACTION_GET_CONTENT);
    intent.setType("image/*");
} else {
    intent = new Intent(Intent.ACTION_PICK, android.provider.MediaStore.Images.Media.EXTERNAL_CONTENT_URI);
}
startActivityForResult(intent, 1);

On Android 6.0 (API 23) and above, request permissions dynamically at runtime.

On Android 10 (API 29) and above, also request permissions at runtime, and add android:requestLegacyExternalStorage="true" to the <application> tag in AndroidManifest.xml.
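The three version thresholds in this section (19 for the picker intent, 23 for runtime permissions, 29 for the legacy-storage flag) can be summarized in one lookup. This is a hypothetical helper written in Python purely to show the decision logic; plan_for and StoragePlan are not Android APIs, though the two action strings are the real values of Intent.ACTION_GET_CONTENT and Intent.ACTION_PICK.

```python
from dataclasses import dataclass

KITKAT = 19       # Android 4.4: use ACTION_PICK from here on
MARSHMALLOW = 23  # Android 6.0: runtime (dynamic) permissions
Q = 29            # Android 10: scoped storage


@dataclass(frozen=True)
class StoragePlan:
    picker_action: str         # Intent action used to choose an image
    runtime_permissions: bool  # must call requestPermissions() at runtime
    legacy_storage_flag: bool  # add android:requestLegacyExternalStorage="true"


def plan_for(sdk_int: int) -> StoragePlan:
    """Hypothetical summary of the SDK_INT checks described in the notes above."""
    picker = ("android.intent.action.GET_CONTENT" if sdk_int < KITKAT
              else "android.intent.action.PICK")
    return StoragePlan(
        picker_action=picker,
        runtime_permissions=sdk_int >= MARSHMALLOW,
        legacy_storage_flag=sdk_int >= Q,
    )
```

Note that the legacy-storage flag is only honored while the app targets API 29; on Android 11 and later with a higher targetSdkVersion it is ignored.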


4. Object Detection

Fixing app crashes when selecting an image: same cause and fix as in Section 3 (check Build.VERSION.SDK_INT; use Intent.ACTION_PICK on SDK 19 and above, and Intent.ACTION_GET_CONTENT with type image/* below 19).

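Fixing crashes caused by incomplete permission requests: request all required permissions dynamically, together in one call.

As in Section 3, runtime requests are required on Android 6.0 and above; on Android 10 and above, also add android:requestLegacyExternalStorage="true" to the <application> tag in AndroidManifest.xml.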


Appendix: 15 ARGB colors used to mark image segmentation classes
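Android stores a color as a single 32-bit ARGB int (0xAARRGGBB), which is what a segmentation palette holds. The sketch below shows that packing in Python, plus a hypothetical colorize helper that maps each class index of a mask to its palette color. The palette itself is an assumption (any list of 15 ARGB ints indexed by class); it does not reproduce the author's colors.

```python
from typing import List, Sequence


def pack_argb(a: int, r: int, g: int, b: int) -> int:
    """Pack four 8-bit channels into an unsigned 0xAARRGGBB value."""
    for v in (a, r, g, b):
        if not 0 <= v <= 255:
            raise ValueError(f"channel out of range: {v}")
    return (a << 24) | (r << 16) | (g << 8) | b


def to_java_int(argb: int) -> int:
    """Reinterpret the unsigned value as a signed 32-bit Java int (as Color.argb returns)."""
    return argb - (1 << 32) if argb >= (1 << 31) else argb


def colorize(mask: Sequence[Sequence[int]], palette: Sequence[int]) -> List[List[int]]:
    """Hypothetical helper: replace each class index in a 2-D mask with its palette color."""
    return [[palette[c] for c in row] for row in mask]
```

For example, pack_argb(255, 255, 0, 0) gives 0xFFFF0000, which to_java_int turns into -65536, the value of Color.RED in Java.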

Feel free to share; please credit the source: 内存溢出 (outofmemory.cn)

Original article: http://outofmemory.cn/zaji/4015048.html
