Pytroch模型部署到Android设备

  |   0 评论   |   0 浏览   |   给我丶鼓励

本项目是一个简单的图像分类应用程序,演示了如何使用 PyTorch Android API。此应用程序在静态图像上运行 TorchScript 序列化的 TorchVision 预训练的 resnet18 模型,该模型作为 Android 资产打包在应用程序内部。

1.模型准备

让我们从模型准备开始。如果您熟悉 PyTorch,您可能应该已经知道如何训练和保存模型。如果您不这样做,我们将使用预先训练的图像分类模型(Resnet18),该模型包装在 TorchVision 中。要安装它,请运行以下命令:

pip install torchvision

要序列化模型,可以在 HelloWorld 应用的根文件夹中使用 python 代码:

import torch
import torchvision

model = torchvision.models.resnet18(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("app/src/main/assets/model.pt")

如果一切正常,我们应该拥有我们的模型- model.pt 在 Android 应用程序的 Assets 文件夹中生成。它将被打包为 Android 应用程序内部,asset 并且可以在设备上使用。

2.从 GitHub 克隆

git clone https://github.com/pytorch/android-demo-app.git
cd HelloWorldApp

如果已经安装了 Android SDK 和 Android NDK,则可以使用以下命令将此应用程序安装到连接的 Android 设备或模拟器上:

./gradlew installDebug

我们建议您在 Android Studio 3.5.1+ 中打开此项目。目前,PyTorch Android 和演示应用程序使用版本 3.5.0 的 Android gradle 插件,只有 Android Studio 版本 3.5.1 和更高版本才支持。使用 Android Studio,您将能够通过 Android Studio UI 安装 Android NDK 和 Android SDK。

3. Gradle 依赖

Pytorch Android 作为 build.gradle 中的 gradle 依赖项添加到项目中:

repositories {
    jcenter()
}

dependencies {
    implementation 'org.pytorch:pytorch_android:1.4.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}

org.pytorch:pytorch_android PyTorch Android API 的主要依赖项在哪里,包括所有 4 个 Android abis(armeabi-v7a,arm64-v8a,x86,x86_64)的 libtorch 本机库。此外,在此文档中,您可以找到如何仅针对特定的 Android abis 列表重建它。

org.pytorch:pytorch_android_torchvision-具有实用功能的附加库,用于转换 android.media.Imageandroid.graphics.Bitmap 张量。

4.从 Android Asset 读取图像

所有逻辑都发生在中org.pytorch.helloworld.MainActivity。作为第一步,我们阅读 image.jpgandroid.graphics.Bitmap 使用标准 Android API 的信息。

Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));

5.加载 TorchScript 模块

Module module = Module.load(assetFilePath(this, "model.pt"));

org.pytorch.Module 表示 torch::jit::script::Module 可以使用 load 指定序列化到文件模型的文件路径的方法加载。

6.准备输入

Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
    TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);

org.pytorch.torchvision.TensorImageUtilsorg.pytorch:pytorch_android_torchvision 图书馆的一部分。该 TensorImageUtils#bitmapToFloat32Tensor 方法在创建张量 torchvision 格式使用 android.graphics.Bitmap 作为源。

所有经过预训练的模型都希望输入图像以相同的方式归一化,即形状为(3 x H x W)的 3 通道 RGB 图像的迷你批,其中 H 和 W 至少应为 224。加载到的范围内 [0, 1],然后使用 mean = [0.485, 0.456, 0.406] 和进行归一化 std = [0.229, 0.224, 0.225]

inputTensor 的形状为 1x3xHxW,其中 HW 分别是位图的高度和宽度。

7.运行推理

Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();

org.pytorch.Module.forward 方法运行加载的模块的 forward 方法,并 org.pytorch.Tensor 使用 shape 获得作为 outputTensor 的结果 1x1000

8.处理结果

使用以下 org.pytorch.Tensor.getDataAsFloatArray() 方法检索其内容:该方法返回浮点数的 Java 数组,并为每个图像网络类分配分数。

之后,我们只找到具有最高分数的索引,然后从 ImageNetClasses.IMAGENET_CLASSES 包含所有 ImageNet 类的数组中检索预测的类名。

float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
  if (scores[i] > maxScore) {
    maxScore = scores[i];
    maxScoreIdx = i;
  }
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];

在以下各节中,您可以找到 PyTorch Android API 的详细说明,用于更大的演示应用程序的代码演练,API 的实现细节,如何从源代码进行自定义和构建。

PYTORCH 演示应用程序

我们还创建了另一个更复杂的 PyTorch Android 演示应用程序,该应用程序从同一 GitHub 存储库中的摄像头输出和文本分类进行图像分类。

要获取设备相机的输出,它使用 Android CameraX API。与 CameraX 一起使用的所有逻辑都被org.pytorch.demo.vision.AbstractCameraXActivity分类。

void setupCameraX() {
    final PreviewConfig previewConfig = new PreviewConfig.Builder().build();
    final Preview preview = new Preview(previewConfig);
    preview.setOnPreviewOutputUpdateListener(output -> mTextureView.setSurfaceTexture(output.getSurfaceTexture()));

    final ImageAnalysisConfig imageAnalysisConfig =
        new ImageAnalysisConfig.Builder()
            .setTargetResolution(new Size(224, 224))
            .setCallbackHandler(mBackgroundHandler)
            .setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
            .build();
    final ImageAnalysis imageAnalysis = new ImageAnalysis(imageAnalysisConfig);
    imageAnalysis.setAnalyzer(
        (image, rotationDegrees) -> {
          analyzeImage(image, rotationDegrees);
        });

    CameraX.bindToLifecycle(this, preview, imageAnalysis);
  }

  void analyzeImage(android.media.Image, int rotationDegrees)

analyzeImage 方法处理相机输出的位置 android.media.Image

它使用上述TensorImageUtils.imageYUV420CenterCropToFloat32Tensor转换方法 android.media.ImageYUV420 格式输入张量。

从模型中获得预测分数后,它会找到分数最高的前 K 个类别,并在用户界面上显示。

语言处理示例

另一个示例是基于 LSTM 模型的自然语言处理,并在 reddit 注释数据集上进行了训练。逻辑发生在中TextClassificattionActivity

结果类名称打包在 TorchScript 模型中,并在初始模块初始化后立即进行初始化。该模块具有一个 get_classes return 的方法,List[str] 可以使用 method 进行调用 Module.runMethod(methodName)

    mModule = Module.load(moduleFileAbsoluteFilePath);
    IValue getClassesOutput = mModule.runMethod("get_classes");

IValue 可以将返回的值转换为 IValue using 的 Java 数组,IValue.toList() 并使用以下方法处理为字符串数组 IValue.toStr()

    IValue[] classesListIValue = getClassesOutput.toList();
    String[] moduleClasses = new String[classesListIValue.length];
    int i = 0;
    for (IValue iv : classesListIValue) {
      moduleClasses[i++] = iv.toStr();
    }

输入的文本将转换为带有 UTF-8 编码的 Java 字节数组。从该字节数组 Tensor.fromBlobUnsigned 创建张量 dtype=uint8

    byte[] bytes = text.getBytes(Charset.forName("UTF-8"));
    final long[] shape = new long[]{1, bytes.length};
    final Tensor inputTensor = Tensor.fromBlobUnsigned(bytes, shape);

模型的运行推断与前面的示例相似:

Tensor outputTensor = mModule.forward(IValue.from(inputTensor)).toTensor()

之后,代码处理输出,找到得分最高的类。


标题:Pytroch模型部署到Android设备
作者:给我丶鼓励
地址:https://blog.doiduoyi.com/articles/1591263975715.html

评论

发表评论