使用TensorFlow Lite在Android手机上实现图像分类

  |   0 评论   |   0 浏览   |   夜雨飘零

前言

TensorFlow Lite 是一款专门针对移动设备的深度学习框架,移动设备深度学习框架是部署在手机或者树莓派等小型移动设备上的深度学习框架,可以使用训练好的模型在手机等设备上完成推理任务。这一类框架的出现,可以使得一些推理的任务可以在本地执行,不需要再调用服务器的网络接口,大大减少了预测时间。在前几篇文章中已经介绍了百度的 paddle-mobile,小米的 mace,还有腾讯的 ncnn。这在本章中我们将介绍谷歌的 TensorFlow Lite。

Tensorflow Lite 的 GitHub 地址:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite
Tensorflow 的版本为:Tensorflow 1.14.0

转换模型

手机上执行预测,首先需要一个训练好的模型,这个模型不能是 TensorFlow 原来格式的模型,TensorFlow Lite 使用的模型格式是另一种格式的模型。下面就介绍如何使用这个格式的模型。

获取模型主要有三种方法,第一种是在训练的时候就保存 tflite 模型,另外一种就是使用其他格式的 TensorFlow 模型转换成 tflite 模型,第三中是检查点模型转换。

  1. 最方便的就是在训练的时候保存 tflite 格式的模型,主要是使用到 tf.contrib.lite.toco_convert() 接口,下面就是一个简单的例子:
import tensorflow as tf

img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")

with tf.Session() as sess:
  tflite_model = tf.lite.toco_convert(sess.graph_def, [img], [out])
  open("converteds_model.tflite", "wb").write(tflite_model)

最后获得的 converteds_model.tflite 文件就可以直接在 TensorFlow Lite 上使用。

  1. 第二种就是把 tensorflow 保存的其他模型转换成 tflite,我们可以在以下的链接下载模型:

tensorflow 模型:https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models

上面提供的模型同时也包括了 tflite 模型,我们可以直接拿来使用,但是我们也可以使用其他格式的模型来转换。比如我们下载一个 mobilenet_v1_1.0_224.tgz,解压之后获得以下文件:

mobilenet_v1_1.0_224.ckpt.data-00000-of-00001  mobilenet_v1_1.0_224_eval.pbtxt  mobilenet_v1_1.0_224.tflite
mobilenet_v1_1.0_224.ckpt.index                mobilenet_v1_1.0_224_frozen.pb
mobilenet_v1_1.0_224.ckpt.meta                 mobilenet_v1_1.0_224_info.txt

首先要安装 Bazel,可以参考:https://docs.bazel.build/versions/master/install-ubuntu.html ,只需要完成 Installing using binary installer 这一部分即可。

然后克隆 TensorFlow 的源码:

git clone https://github.com/tensorflow/tensorflow.git

接着编译转换工具,这个编译时间可能比较长:

cd tensorflow/
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/lite/toco:toco

获得到转换工具之后,我们就可以开始转换模型了,以下操作是冻结图。

  • input_graph 对应的是 .pb 文件;
  • input_checkpoint 对应的是 mobilenet_v1_1.0_224.ckpt.data-00000-of-00001,但是在使用的使用是去掉后缀名的。
  • output_node_names 这个可以在 mobilenet_v1_1.0_224_info.txt 中获取。

不过要注意的是我们下载的模型已经是冻结过来,所以不用再执行这个操作。但如果是其他的模型,要先冻结图,然后再执行之后的操作。

./freeze_graph --input_graph=/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_frozen.pb \
  --input_checkpoint=/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt \
  --input_binary=true \
  --output_graph=/tmp/frozen_mobilenet_v1_224.pb \
  --output_node_names=MobilenetV1/Predictions/Reshape_1

以下操作就是把已经冻结的图转换成 .tflite

  • input_file 是已经冻结的图;
  • output_file 是转换后输出的路径;
  • output_arrays 这个可以在 mobilenet_v1_1.0_224_info.txt 中获取;
  • input_shapes 这个是预测数据的 shape
./toco --input_file=/tmp/mobilenet_v1_1.0_224_frozen.pb \
  --input_format=TENSORFLOW_GRAPHDEF \
  --output_format=TFLITE \
  --output_file=/tmp/mobilenet_v1_1.0_224.tflite \
  --inference_type=FLOAT \
  --input_type=FLOAT \
  --input_arrays=input \
  --output_arrays=MobilenetV1/Predictions/Reshape_1 \
  --input_shapes=1,224,224,3
  1. 检查点模型转换,使用训练保存的检查点和 export_inference_graph.py 输出的预测图,来冻结模型。

在冻结之前需要知道模型最后一层输出层的名称,通过以下命令可以得到:

bazel build tensorflow/tools/graph_transforms:summarize_graph

bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
  --in_graph=/tmp/output_file/mobilenet_v2_inf_graph.pb

开始冻结图:

bazel build tensorflow/python/tools:freeze_graph

bazel-bin/tensorflow/python/tools/freeze_graph \
  --input_graph=/tmp/output_file/mobilenet_v2_inf_graph.pb \
  --input_checkpoint=/tmp/ckpt/mobilenet_v2.ckpt-6900 \
  --input_binary=true \
  --output_graph=/tmp/mobilenet_v2.pb \
  --output_node_names=MobilenetV2/Predictions/Reshape_1

冻结图之后使用输入层的名称和输出层的名称生成 lite 模型

bazel build tensorflow/lite/toco:toco

bazel-bin/tensorflow/lite/toco/toco --input_file=/tmp/mobilenet_v2.pb \
  --input_format=TENSORFLOW_GRAPHDEF \
  --output_format=TFLITE \
  --output_file=/tmp/mobilenet_v1_1.0_224.tflite \
  --inference_type=FLOAT \
  --input_type=FLOAT \
  --input_arrays=image \
  --output_arrays=MobilenetV2/Predictions/Reshape_1 \
  --input_shapes=1,224,224,3

经过上面的步骤就可以获取到 mobilenet_v1_1.0_224.tflite 模型了,之后我们会在 Android 项目中使用它。

开发 Android 项目

有了上面的模型之后,我们就使用 Android Studio 创建一个 Android 项目,一路默认就可以了,并不需要 C++ 的支持,因为我们使用到的 TensorFlow Lite 是 Java 代码的,开发起来非常方便。

  1. 创建完成之后,在 app 目录下的 build.gradle 配置文件加上以下配置信息:
    dependencies 下加上包的引用,第一个是图片加载框架 Glide,第二个就是我们这个项目的核心 TensorFlow Lite:
    implementation 'com.github.bumptech.glide:glide:4.3.1'
    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'

然后在 android 下加上以下代码,这个主要是限制不要对 tensorflow lite 的模型进行压缩,压缩之后就无法加载模型了:

    //set no compress models
    aaptOptions {
        noCompress "tflite"
    }
  1. main 目录下创建 assets 文件夹,这个文件夹主要是存放 tflite 模型和 label 名称文件。
  2. 以下是主界面的代码 MainActivity.java,这个代码比较长,我们来分析这段代码,重要的方法介绍如下:
  • loadModelFile() 方法是把模型文件读取成 MappedByteBuffer,之后给 Interpreter 类初始化模型,这个模型存放在 mainassets 目录下。
  • load_model() 方法是加载模型,并得到一个对象 tflite,之后就是使用这个对象来预测图像,同时可以使用这个对象设置一些参数,比如设置使用的线程数量 tflite.setNumThreads(4);
  • showDialog() 方法是显示弹窗,通过这个弹窗的选择不同的模型。
  • readCacheLabelFromLocalFile() 方法是读取文件种分类标签对应的名称,这个文件比较长,可以参考这篇文章获取标签名称,也可以下载笔者的项目,里面有对用的文件。这个文件 cacheLabel.txt 跟模型一样存放在 assets 目录下。
  • predict_image() 方法是预测图片并显示结果的,预测的流程是:获取图片的路径,然后使用对图片进行压缩,之后把图片转换成 ByteBuffer 格式的数据,最后调用 tflite.run() 方法进行预测。
  • get_max_result() 方法是获取最大概率的标签。
package com.yeyupiaoling.testtflite;

import android.Manifest;
import android.app.Activity;
import android.content.DialogInterface;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.net.Uri;
import android.os.Bundle;
import android.support.annotation.NonNull;
import android.support.annotation.Nullable;
import android.support.v4.app.ActivityCompat;
import android.support.v4.content.ContextCompat;
import android.support.v7.app.AlertDialog;
import android.support.v7.app.AppCompatActivity;
import android.text.method.ScrollingMovementMethod;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;

import com.bumptech.glide.Glide;
import com.bumptech.glide.load.engine.DiskCacheStrategy;
import com.bumptech.glide.request.RequestOptions;
import org.tensorflow.lite.Interpreter;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.List;

public class MainActivity extends AppCompatActivity {
    private static final String TAG = MainActivity.class.getName();
    private static final int USE_PHOTO = 1001;
    private static final int START_CAMERA = 1002;
    private String camera_image_path;
    private ImageView show_image;
    private TextView result_text;
    private String assets_path = "lite_images";
    private boolean load_result = false;
    private int[] ddims = {1, 3, 224, 224};
    private int model_index = 0;
    private List<String> resultLabel = new ArrayList<>();
    private Interpreter tflite = null;

    private static final String[] PADDLE_MODEL = {
            "mobilenet_v1",
            "mobilenet_v2"
    };


    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        init_view();
        readCacheLabelFromLocalFile();
    }

    // initialize view
    private void init_view() {
        request_permissions();
        show_image = (ImageView) findViewById(R.id.show_image);
        result_text = (TextView) findViewById(R.id.result_text);
        result_text.setMovementMethod(ScrollingMovementMethod.getInstance());
        Button load_model = (Button) findViewById(R.id.load_model);
        Button use_photo = (Button) findViewById(R.id.use_photo);
        Button start_photo = (Button) findViewById(R.id.start_camera);

        load_model.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                showDialog();
            }
        });

        // use photo click
        use_photo.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                if (!load_result) {
                    Toast.makeText(MainActivity.this, "never load model", Toast.LENGTH_SHORT).show();
                    return;
                }
                PhotoUtil.use_photo(MainActivity.this, USE_PHOTO);
            }
        });

        // start camera click
        start_photo.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                if (!load_result) {
                    Toast.makeText(MainActivity.this, "never load model", Toast.LENGTH_SHORT).show();
                    return;
                }
                camera_image_path = PhotoUtil.start_camera(MainActivity.this, START_CAMERA);
            }
        });
    }

    /**
     * Memory-map the model file in Assets.
     */
    private MappedByteBuffer loadModelFile(String model) throws IOException {
        AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite");
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }


    // load infer model
    private void load_model(String model) {
        try {
            tflite = new Interpreter(loadModelFile(model));
            Toast.makeText(MainActivity.this, model + " model load success", Toast.LENGTH_SHORT).show();
            Log.d(TAG, model + " model load success");
            tflite.setNumThreads(4);
            load_result = true;
        } catch (IOException e) {
            Toast.makeText(MainActivity.this, model + " model load fail", Toast.LENGTH_SHORT).show();
            Log.d(TAG, model + " model load fail");
            load_result = false;
            e.printStackTrace();
        }
    }

    public void showDialog() {
        AlertDialog.Builder builder = new AlertDialog.Builder(MainActivity.this);

        // set dialog title
        builder.setTitle("Please select model");

        // set dialog icon
        builder.setIcon(android.R.drawable.ic_dialog_alert);

        // able click other will cancel
        builder.setCancelable(true);

        // cancel button
        builder.setNegativeButton("cancel", null);

        // set list
        builder.setSingleChoiceItems(PADDLE_MODEL, model_index, new DialogInterface.OnClickListener() {
            @Override
            public void onClick(DialogInterface dialog, int which) {
                model_index = which;
                load_model(PADDLE_MODEL[model_index]);
                dialog.dismiss();
            }
        });

        // show dialog
        builder.show();
    }


    private void readCacheLabelFromLocalFile() {
        try {
            AssetManager assetManager = getApplicationContext().getAssets();
            BufferedReader reader = new BufferedReader(new InputStreamReader(assetManager.open("cacheLabel.txt")));
            String readLine = null;
            while ((readLine = reader.readLine()) != null) {
                resultLabel.add(readLine);
            }
            reader.close();
        } catch (Exception e) {
            Log.e("labelCache", "error " + e);
        }
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, @Nullable Intent data) {
        String image_path;
        RequestOptions options = new RequestOptions().skipMemoryCache(true).diskCacheStrategy(DiskCacheStrategy.NONE);
        if (resultCode == Activity.RESULT_OK) {
            switch (requestCode) {
                case USE_PHOTO:
                    if (data == null) {
                        Log.w(TAG, "user photo data is null");
                        return;
                    }
                    Uri image_uri = data.getData();
                    Glide.with(MainActivity.this).load(image_uri).apply(options).into(show_image);
                    // get image path from uri
                    image_path = PhotoUtil.get_path_from_URI(MainActivity.this, image_uri);
                    // predict image
                    predict_image(image_path);
                    break;
                case START_CAMERA:
                    // show photo
                    Glide.with(MainActivity.this).load(camera_image_path).apply(options).into(show_image);
                    // predict image
                    predict_image(camera_image_path);
                    break;
            }
        }
    }

    //  predict image
    private void predict_image(String image_path) {
        // picture to float array
        Bitmap bmp = PhotoUtil.getScaleBitmap(image_path);
        ByteBuffer inputData = PhotoUtil.getScaledMatrix(bmp, ddims);
        try {
            // Data format conversion takes too long
            // Log.d("inputData", Arrays.toString(inputData));
            float[][] labelProbArray = new float[1][1001];
            long start = System.currentTimeMillis();
            // get predict result
            tflite.run(inputData, labelProbArray);
            long end = System.currentTimeMillis();
            long time = end - start;
            float[] results = new float[labelProbArray[0].length];
            System.arraycopy(labelProbArray[0], 0, results, 0, labelProbArray[0].length);
            // show predict result and time
            int r = get_max_result(results);
            String show_text = "result:" + r + "\nname:" + resultLabel.get(r) + "\nprobability:" + results[r] + "\ntime:" + time + "ms";
            result_text.setText(show_text);
        } catch (Exception e) {
            e.printStackTrace();
        }

    // get max probability label
    private int get_max_result(float[] result) {
        float probability = result[0];
        int r = 0;
        for (int i = 0; i < result.length; i++) {
            if (probability < result[i]) {
                probability = result[i];
                r = i;
            }
        }
        return r;
    }

    // request permissions
    private void request_permissions() {

        List<String> permissionList = new ArrayList<>();
        if (ContextCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
            permissionList.add(Manifest.permission.CAMERA);
        }

        if (ContextCompat.checkSelfPermission(this, Manifest.permission.WRITE_EXTERNAL_STORAGE) != PackageManager.PERMISSION_GRANTED) {
            permissionList.add(Manifest.permission.WRITE_EXTERNAL_STORAGE);
        }

        if (ContextCompat.checkSelfPermission(this, Manifest.permission.READ_EXTERNAL_STORAGE) != PackageManager.PERMISSION_GRANTED) {
            permissionList.add(Manifest.permission.READ_EXTERNAL_STORAGE);
        }

        // if list is not empty will request permissions
        if (!permissionList.isEmpty()) {
            ActivityCompat.requestPermissions(this, permissionList.toArray(new String[permissionList.size()]), 1);
        }
    }

    @Override
    public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults);
        switch (requestCode) {
            case 1:
                if (grantResults.length > 0) {
                    for (int i = 0; i < grantResults.length; i++) {

                        int grantResult = grantResults[i];
                        if (grantResult == PackageManager.PERMISSION_DENIED) {
                            String s = permissions[i];
                            Toast.makeText(this, s + " permission was denied", Toast.LENGTH_SHORT).show();
                        }
                    }
                }
                break;
        }
    }
}
  1. 以下的代码片段是一个工具类 PhotoUtil.java,各方法功能如下:
  • start_camera() 方法是启动相机拍照并返回图片的路径,兼容了 Android 7.0。
  • use_photo() 方法是打开相册,获取选择的图片的 URI。
  • get_path_from_URI() 方法是把图片的 URI 转换成图片路径。
  • getScaledMatrix() 方法是把图片的 Bitmap 格式转换成 TensorFlow Lite 所需的数据格式。
  • getScaleBitmap() 方法是压缩图片,防止内存溢出。
package com.yeyupiaoling.testtflite;

import android.app.Activity;
import android.content.Context;
import android.content.Intent;
import android.database.Cursor;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.net.Uri;
import android.os.Build;
import android.os.Environment;
import android.provider.MediaStore;
import android.support.v4.content.FileProvider;
import android.util.Log;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;


public class PhotoUtil {

    // start camera
    public static String start_camera(Activity activity, int requestCode) {
        Uri imageUri;
        // save image in cache path
        File outputImage = new File(Environment.getExternalStorageDirectory().getAbsolutePath()
                + "/lite_mobile/", System.currentTimeMillis() + ".jpg");
        Log.d("outputImage", outputImage.getAbsolutePath());
        try {
            if (outputImage.exists()) {
                outputImage.delete();
            }
            File out_path = new File(Environment.getExternalStorageDirectory().getAbsolutePath()
                    + "/lite_mobile/");
            if (!out_path.exists()) {
                out_path.mkdirs();
            }
            outputImage.createNewFile();
        } catch (IOException e) {
            e.printStackTrace();
        }
        if (Build.VERSION.SDK_INT >= 24) {
            // compatible with Android 7.0 or over
            imageUri = FileProvider.getUriForFile(activity,
                    "com.yeyupiaoling.testtflite.fileprovider", outputImage);
        } else {
            imageUri = Uri.fromFile(outputImage);
        }
        // set system camera Action
        Intent intent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
        intent.addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION);
        // set save photo path
        intent.putExtra(MediaStore.EXTRA_OUTPUT, imageUri);
        // set photo quality, min is 0, max is 1
        intent.putExtra(MediaStore.EXTRA_VIDEO_QUALITY, 0);
        activity.startActivityForResult(intent, requestCode);
        // return image absolute path
        return outputImage.getAbsolutePath();
    }

    // get picture in photo
    public static void use_photo(Activity activity, int requestCode) {
        Intent intent = new Intent(Intent.ACTION_PICK);
        intent.setType("image/*");
        activity.startActivityForResult(intent, requestCode);
    }

    // get photo from Uri
    public static String get_path_from_URI(Context context, Uri uri) {
        String result;
        Cursor cursor = context.getContentResolver().query(uri, null, null, null, null);
        if (cursor == null) {
            result = uri.getPath();
        } else {
            cursor.moveToFirst();
            int idx = cursor.getColumnIndex(MediaStore.Images.ImageColumns.DATA);
            result = cursor.getString(idx);
            cursor.close();
        }
        return result;
    }

    // TensorFlow model,get predict data
    public static ByteBuffer getScaledMatrix(Bitmap bitmap, int[] ddims) {
        ByteBuffer imgData = ByteBuffer.allocateDirect(ddims[0] * ddims[1] * ddims[2] * ddims[3] * 4);
        imgData.order(ByteOrder.nativeOrder());
        // get image pixel
        int[] pixels = new int[ddims[2] * ddims[3]];
        Bitmap bm = Bitmap.createScaledBitmap(bitmap, ddims[2], ddims[3], false);
        bm.getPixels(pixels, 0, bm.getWidth(), 0, 0, ddims[2], ddims[3]);
        int pixel = 0;
        for (int i = 0; i < ddims[2]; ++i) {
            for (int j = 0; j < ddims[3]; ++j) {
                final int val = pixels[pixel++];
                imgData.putFloat(((((val >> 16) & 0xFF) - 128f) / 128f));
                imgData.putFloat(((((val >> 8) & 0xFF) - 128f) / 128f));
                imgData.putFloat((((val & 0xFF) - 128f) / 128f));
            }
        }

        if (bm.isRecycled()) {
            bm.recycle();
        }
        return imgData;
    }

    // compress picture
    public static Bitmap getScaleBitmap(String filePath) {
        BitmapFactory.Options opt = new BitmapFactory.Options();
        opt.inJustDecodeBounds = true;
        BitmapFactory.decodeFile(filePath, opt);

        int bmpWidth = opt.outWidth;
        int bmpHeight = opt.outHeight;
        int maxSize = 500;
        // compress picture with inSampleSize
        opt.inSampleSize = 1;
        while (true) {
            if (bmpWidth / opt.inSampleSize < maxSize || bmpHeight / opt.inSampleSize < maxSize) {
                break;
            }
            opt.inSampleSize *= 2;
        }
        opt.inJustDecodeBounds = false;
        return BitmapFactory.decodeFile(filePath, opt);
    }
}
  1. AndroidManifest.xml 下加上申请的权限,用到了相机和读取外部存储的内存:
    <uses-permission android:name="android.permission.CAMERA"/>
    <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
    <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>

然后还要在 application 下加上以下的配置信息,这个主要是为了兼容 Android 7.0 的相机:

        <!-- FileProvider配置访问路径,适配7.0及其以上 -->
        <provider
            android:name="android.support.v4.content.FileProvider"
            android:authorities="com.yeyupiaoling.testtflite.fileprovider"
            android:exported="false"
            android:grantUriPermissions="true">
            <meta-data
                android:name="android.support.FILE_PROVIDER_PATHS"
                android:resource="@xml/file_paths"/>
        </provider>
  1. 之后在 res 创建一个 xml 目录,然后创建一个 file_paths.xml 文件,在这个文件中加上以下代码,这个是我们拍照之后图片存放的位置:
<?xml version="1.0" encoding="utf-8"?>
<resources>
    <external-path
        name="images"
        path="lite_mobile/" />
</resources>
  1. 主界面布局代码 activity_main.xml
<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <LinearLayout
        android:id="@+id/btn1_ll"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_alignParentBottom="true"
        android:orientation="horizontal">

        <Button
            android:id="@+id/use_photo"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:layout_weight="1"
            android:text="相册" />

        <Button
            android:id="@+id/start_camera"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:layout_weight="1"
            android:text="拍照" />
    </LinearLayout>

    <LinearLayout
        android:id="@+id/btn2_ll"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_above="@id/btn1_ll"
        android:orientation="horizontal">

        <Button
            android:id="@+id/load_model"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:layout_weight="1"
            android:text="加载模型" />
    </LinearLayout>

    <TextView
        android:id="@+id/result_text"
        android:layout_width="match_parent"
        android:layout_height="150dp"
        android:layout_above="@id/btn2_ll"
        android:hint="预测结果会在这里显示"
        android:inputType="textMultiLine"
        android:textSize="16sp"
        tools:ignore="TextViewEdits" />

    <ImageView
        android:id="@+id/show_image"
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:layout_above="@id/result_text"
        android:layout_alignParentTop="true" />
</RelativeLayout>

以下就是效果图片:
在这里插入图片描述

上面已经提高了全部代码,这里为了方便读者调试,这里可以在这里源码下载,然后使用 Android Studio 打开。
源码地址:https://resource.doiduoyi.com/#c1uo2s4

参考资料

  1. https://www.tensorflow.org/mobile/tflite/devguide?hl=zh-cn
  2. https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo
  3. https://docs.bazel.build/versions/master/install.html
  4. https://blog.csdn.net/computerme/article/details/80699671

标题:使用TensorFlow Lite在Android手机上实现图像分类
作者:夜雨飘零
地址:https://blog.doiduoyi.com/articles/1584973744126.html

评论

发表评论