首页 > 解决方案 > 在android中使用Tensorflow模型

问题描述

我有一个 Tensorflow 模型,我已将其转换为“.tflite”,但我不知道如何在 android 上实现它。我遵循 TensorFlow 指南在 android 中实现它,但由于 TensorFlow 网站没有 XML 代码,我正在努力将它与前端 (XML) 连接起来。我需要清楚地解释如何使用 java 在 android studio 中使用我的模型。

我按照 TensorFlow 网站上给出的官方说明在 android 中实现模型。

标签: python, android, tensorflow, machine-learning, tensorflow2.0

解决方案


下面是一段示例代码，演示如何基于 TensorFlow 的 tflite 模型实现对象检测。我想这类回答算不上最好的回答，但我手头恰好有一个与你的问题完全对应的简单示例。

注意：它确实会检测对象，并使用 Log.d 将其标签输出到日志（可在 Logcat 中查看）；但不会在检测到的对象周围绘制任何边框或标签。

首先从这里下载入门（starter）模型和标签文件，然后将它们放入项目的 assets 文件夹中。

Java

import android.content.pm.PackageManager;
import android.media.Image;
import android.os.Bundle;
import android.util.Log;
import android.widget.Toast;

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.appcompat.app.AppCompatActivity;
import androidx.camera.core.Camera;
import androidx.camera.core.CameraSelector;
import androidx.camera.core.ExperimentalGetImage;
import androidx.camera.core.ImageAnalysis;
import androidx.camera.core.ImageProxy;
import androidx.camera.core.Preview;
import androidx.camera.lifecycle.ProcessCameraProvider;
import androidx.camera.view.PreviewView;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;

import com.google.common.util.concurrent.ListenableFuture;
import com.google.mlkit.common.model.LocalModel;
import com.google.mlkit.vision.common.InputImage;
import com.google.mlkit.vision.objects.DetectedObject;
import com.google.mlkit.vision.objects.ObjectDetection;
import com.google.mlkit.vision.objects.ObjectDetector;
import com.google.mlkit.vision.objects.custom.CustomObjectDetectorOptions;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.ExecutionException;

/**
 * Activity that shows a CameraX preview and feeds every analysed frame to an
 * ML Kit object detector built from a bundled {@code .tflite} model.
 *
 * <p>Detected labels are only written to the log with {@code Log.d}; nothing is
 * drawn on top of the preview.
 */
public class ActivityExample extends AppCompatActivity {
    /** Log tag for diagnostics emitted by this activity and its nested classes. */
    private static final String TAG = "ActivityExample";

    /** Request code used to match the permission result callback. */
    private static final int REQUEST_CODE_PERMISSIONS = 101;

    /** Runtime permissions that must be granted before the camera is started. */
    private static final String[] REQUIRED_PERMISSIONS =
            new String[]{"android.permission.CAMERA"};

    private ListenableFuture<ProcessCameraProvider> cameraProviderFuture;
    private ObjectDetector objectDetector;
    private PreviewView prevView;

    // Starts out empty (never null) so the analyzer stays usable even when the
    // label asset cannot be read.
    private List<String> labels = new ArrayList<>();

    /**
     * Inflates the layout, prepares the detector and labels, then either starts
     * the camera or asks the user for the camera permission.
     */
    @Override
    protected void onCreate(@Nullable Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);

        setContentView(R.layout.activity_fullscreen);
        prevView = findViewById(R.id.viewFinder);

        prepareObjectDetector();
        prepareLabels();

        if (allPermissionsGranted()) {
            startCamera();
        } else {
            ActivityCompat.requestPermissions(this, REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSIONS);
        }
    }

    /**
     * Loads the label list from the assets folder, one label per line.
     *
     * <p>On failure the error is logged and {@link #labels} keeps its previous
     * (initially empty) value.
     */
    private void prepareLabels() {
        // try-with-resources closes the reader (and the underlying asset stream);
        // the charset is explicit so decoding does not depend on the platform default.
        try (InputStreamReader reader = new InputStreamReader(
                getAssets().open("labels_mobilenet_quant_v1_224.txt"),
                StandardCharsets.UTF_8)) {
            labels = readLines(reader);
        } catch (IOException e) {
            Log.e(TAG, "Could not read the label file from assets", e);
        }
    }

    /**
     * Reads every line available from {@code reader}.
     *
     * @param reader source of the text; it is not closed by this method
     * @return the lines in reading order, possibly empty
     */
    private List<String> readLines(InputStreamReader reader) {
        BufferedReader bufferedReader = new BufferedReader(reader, 8 * 1024);
        Iterator<String> iterator = new LinesSequence(bufferedReader);

        List<String> list = new ArrayList<>();
        while (iterator.hasNext()) {
            list.add(iterator.next());
        }
        return list;
    }

    /** Builds the ML Kit detector from the model file stored in the assets folder. */
    private void prepareObjectDetector() {
        CustomObjectDetectorOptions options = new CustomObjectDetectorOptions.Builder(loadModel("mobilenet_v1_1.0_224_quant.tflite"))
                .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
                .enableMultipleObjects()
                .enableClassification()
                .setClassificationConfidenceThreshold(0.5f)
                .setMaxPerObjectLabelCount(3)
                .build();
        objectDetector = ObjectDetection.getClient(options);
    }

    /**
     * Describes a model that lives in the assets folder.
     *
     * @param assetFileName path of the model file relative to the assets folder
     * @return the local model descriptor
     */
    private LocalModel loadModel(String assetFileName) {
        return new LocalModel.Builder()
                .setAssetFilePath(assetFileName)
                .build();
    }

    /** Requests the camera provider and binds the use cases once it is available. */
    private void startCamera() {
        cameraProviderFuture = ProcessCameraProvider.getInstance(this);
        cameraProviderFuture.addListener(() -> {
            try {
                ProcessCameraProvider cameraProvider = cameraProviderFuture.get();
                bindPreview(cameraProvider);
            } catch (ExecutionException e) {
                // Not expected, but a silent failure would leave a blank preview
                // with no hint of what went wrong.
                Log.e(TAG, "Could not obtain the camera provider", e);
            } catch (InterruptedException e) {
                // Restore the flag so code further up the stack can react to it.
                Thread.currentThread().interrupt();
                Log.w(TAG, "Interrupted while waiting for the camera provider", e);
            }
        }, ContextCompat.getMainExecutor(this));
    }

    /**
     * Binds the preview and the image analysis use cases to this activity's
     * lifecycle, using the back camera.
     *
     * @param cameraProvider provider used to bind the use cases
     */
    private void bindPreview(ProcessCameraProvider cameraProvider) {
        Preview preview = new Preview.Builder().build();
        CameraSelector cameraSelector = new CameraSelector.Builder()
                .requireLensFacing(CameraSelector.LENS_FACING_BACK)
                .build();
        ImageAnalysis imageAnalysis = new ImageAnalysis.Builder()
                .setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)
                .build();
        YourAnalyzer yourAnalyzer = new YourAnalyzer();
        yourAnalyzer.setObjectDetector(objectDetector, labels);
        imageAnalysis.setAnalyzer(
                ContextCompat.getMainExecutor(this),
                yourAnalyzer);

        Camera camera =
                cameraProvider.bindToLifecycle(
                        this,
                        cameraSelector,
                        preview,
                        imageAnalysis
                );

        preview.setSurfaceProvider(prevView.createSurfaceProvider(camera.getCameraInfo()));
    }

    /**
     * Checks the current grant state of every entry in {@link #REQUIRED_PERMISSIONS}.
     *
     * @return {@code true} only when all required permissions are granted
     */
    private boolean allPermissionsGranted() {
        for (String permission : REQUIRED_PERMISSIONS) {
            if (ContextCompat.checkSelfPermission(this, permission)
                    != PackageManager.PERMISSION_GRANTED) {
                return false;
            }
        }
        return true;
    }

    /**
     * Starts the camera when the permission request was granted; otherwise shows
     * a toast and finishes the activity.
     */
    @Override
    public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
        // Let the framework (and any fragments) see the result as well.
        super.onRequestPermissionsResult(requestCode, permissions, grantResults);
        if (requestCode == REQUEST_CODE_PERMISSIONS) {
            if (allPermissionsGranted()) {
                startCamera();
            } else {
                Toast.makeText(this, "Permissions not granted by the user.", Toast.LENGTH_SHORT)
                        .show();
                finish();
            }
        }
    }

    /**
     * Image analyzer that runs the object detector on each frame and logs the
     * names of the detected labels.
     */
    private static class YourAnalyzer implements ImageAnalysis.Analyzer {
        private ObjectDetector objectDetector;
        private List<String> labels = new ArrayList<>();

        /**
         * Supplies the detector and the label names used to translate label indices.
         *
         * @param objectDetector detector used to process frames
         * @param labels label names indexed by label index; {@code null} is treated
         *               as an empty list
         */
        public void setObjectDetector(ObjectDetector objectDetector, List<String> labels) {
            this.objectDetector = objectDetector;
            this.labels = (labels != null) ? labels : new ArrayList<>();
        }

        /**
         * Runs detection on one frame. The frame is closed on every path, since an
         * unclosed {@link ImageProxy} prevents further frames from being delivered.
         */
        @Override
        @ExperimentalGetImage
        public void analyze(@NonNull ImageProxy imageProxy) {
            Image mediaImage = imageProxy.getImage();
            if (mediaImage == null) {
                // Nothing to analyse, but the frame must still be released.
                imageProxy.close();
                return;
            }

            InputImage image = InputImage.fromMediaImage(
                    mediaImage,
                    imageProxy.getImageInfo().getRotationDegrees()
            );
            objectDetector
                    .process(image)
                    .addOnFailureListener(e -> {
                        Log.e(TAG, "Object detection failed", e);
                        imageProxy.close();
                    })
                    .addOnSuccessListener(detectedObjects -> {
                        try {
                            // list of detectedObjects has all the information you need
                            StringBuilder builder = new StringBuilder();
                            for (DetectedObject detectedObject : detectedObjects) {
                                for (DetectedObject.Label label : detectedObject.getLabels()) {
                                    builder.append(labelName(label.getIndex()));
                                    builder.append("\n");
                                }
                            }
                            Log.d("OBJECTS DETECTED", builder.toString().trim());
                        } finally {
                            // Release the frame even if building the log line throws.
                            imageProxy.close();
                        }
                    });
        }

        /**
         * Translates a label index into its name.
         *
         * @param index label index reported by the detector
         * @return the matching entry of the label list, or a placeholder text when
         *         the index is outside the list
         */
        private String labelName(int index) {
            if (index >= 0 && index < labels.size()) {
                return labels.get(index);
            }
            return "unknown label #" + index;
        }
    }


    /**
     * Iterator over the lines of a {@link BufferedReader}.
     *
     * <p>An {@link IOException} while reading is logged and ends the iteration.
     */
    static class LinesSequence implements Iterator<String> {
        private final BufferedReader reader;
        private String nextValue;
        private boolean done = false;

        /**
         * @param reader source of the lines; it is not closed by this class
         */
        public LinesSequence(BufferedReader reader) {
            this.reader = reader;
        }

        /** Reads one line ahead when needed and reports whether a line is pending. */
        @Override
        public boolean hasNext() {
            if (nextValue == null && !done) {
                try {
                    nextValue = reader.readLine();
                } catch (IOException e) {
                    Log.e(TAG, "Failed to read the next line", e);
                    nextValue = null;
                }
                if (nextValue == null) {
                    done = true;
                }
            }
            return nextValue != null;
        }

        /**
         * @return the pending line
         * @throws NoSuchElementException when no line is left
         */
        @Override
        public String next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            String answer = nextValue;
            nextValue = null;
            return answer;
        }
    }
}

XML 布局

<?xml version="1.0" encoding="utf-8"?>
<!-- Camera preview surface that fills its parent. The Activity looks this view up
     with findViewById(R.id.viewFinder), so the id must stay "viewFinder".
     NOTE(review): the Activity calls setContentView(R.layout.activity_fullscreen),
     so this file is presumably res/layout/activity_fullscreen.xml; confirm. -->
<androidx.camera.view.PreviewView
    xmlns:android="http://schemas.android.com/apk/res/android"
    android:id="@+id/viewFinder"
    android:layout_width="match_parent"
    android:layout_height="match_parent" />

Gradle 文件配置

android {
    ...
    // Leave files with the listed extension uncompressed when the app is packaged.
    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
    // Compile with Java 8 language level; the Activity code uses lambdas.
    compileOptions {
        sourceCompatibility JavaVersion.VERSION_1_8
        targetCompatibility JavaVersion.VERSION_1_8
    }
}


dependencies {
    ...
    
    // ML Kit object detection with custom model support
    // (the Activity imports com.google.mlkit.vision.objects.custom.*)
    implementation 'com.google.mlkit:object-detection-custom:16.0.0'
    // Version shared by the camera-camera2 and camera-lifecycle artifacts below
    def camerax_version = "1.0.0-beta03"
    // CameraX core library using camera2 implementation
    implementation "androidx.camera:camera-camera2:$camerax_version"
    // CameraX Lifecycle Library
    implementation "androidx.camera:camera-lifecycle:$camerax_version"
    // CameraX View class (PreviewView); pinned to its own version, not $camerax_version
    implementation "androidx.camera:camera-view:1.0.0-alpha10"
}

推荐阅读