python - 在 Android 中使用 TensorFlow 模型
问题描述
我有一个 TensorFlow 模型,已将其转换为 `.tflite` 格式,但不知道如何在 Android 上使用它。我按照 TensorFlow 指南在 Android 中集成了该模型,但由于 TensorFlow 网站上没有提供 XML 布局代码,我不知道如何把它与前端界面(XML)连接起来。希望有人能清楚地解释如何在 Android Studio 中用 Java 使用我的模型。
我是按照 TensorFlow 网站上的官方说明在 Android 中集成模型的。
解决方案
下面是一个在 Android 上使用 tflite 模型实现对象检测的示例代码。它未必是最佳方案,但恰好是一个能说明你这个问题的简单示例。
注意:它只会检测对象,并使用 `Log.d` 将标签输出到日志(Logcat)中;不会在检测到的对象周围绘制任何边框或标签。
下载入门示例的模型和标签文件(原文此处附有下载链接),即代码中引用的 `mobilenet_v1_1.0_224_quant.tflite` 和 `labels_mobilenet_quant_v1_224.txt`,并将它们放入项目的 `assets` 文件夹中。
Java
import android.content.pm.PackageManager;
import android.media.Image;
import android.os.Bundle;
import android.util.Log;
import android.widget.Toast;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.appcompat.app.AppCompatActivity;
import androidx.camera.core.Camera;
import androidx.camera.core.CameraSelector;
import androidx.camera.core.ExperimentalGetImage;
import androidx.camera.core.ImageAnalysis;
import androidx.camera.core.ImageProxy;
import androidx.camera.core.Preview;
import androidx.camera.lifecycle.ProcessCameraProvider;
import androidx.camera.view.PreviewView;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.mlkit.common.model.LocalModel;
import com.google.mlkit.vision.common.InputImage;
import com.google.mlkit.vision.objects.DetectedObject;
import com.google.mlkit.vision.objects.ObjectDetection;
import com.google.mlkit.vision.objects.ObjectDetector;
import com.google.mlkit.vision.objects.custom.CustomObjectDetectorOptions;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.ExecutionException;
public class ActivityExample extends AppCompatActivity {
private ListenableFuture<ProcessCameraProvider> cameraProviderFuture;
private ObjectDetector objectDetector;
private PreviewView prevView;
private List<String> labels;
private int REQUEST_CODE_PERMISSIONS = 101;
private String[] REQUIRED_PERMISSIONS =
new String[]{"android.permission.CAMERA"};
@Override
protected void onCreate(@Nullable Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_fullscreen);
prevView = findViewById(R.id.viewFinder);
prepareObjectDetector();
prepareLabels();
if (allPermissionsGranted()) {
startCamera();
} else {
ActivityCompat.requestPermissions(this, REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSIONS);
}
}
private void prepareLabels() {
try {
InputStreamReader reader = new InputStreamReader(getAssets().open("labels_mobilenet_quant_v1_224.txt"));
labels = readLines(reader);
} catch (IOException e) {
e.printStackTrace();
}
}
private List<String> readLines(InputStreamReader reader) {
BufferedReader bufferedReader = new BufferedReader(reader, 8 * 1024);
Iterator<String> iterator = new LinesSequence(bufferedReader);
ArrayList<String> list = new ArrayList<>();
while (iterator.hasNext()) {
list.add(iterator.next());
}
return list;
}
private void prepareObjectDetector() {
CustomObjectDetectorOptions options = new CustomObjectDetectorOptions.Builder(loadModel("mobilenet_v1_1.0_224_quant.tflite"))
.setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
.enableMultipleObjects()
.enableClassification()
.setClassificationConfidenceThreshold(0.5f)
.setMaxPerObjectLabelCount(3)
.build();
objectDetector = ObjectDetection.getClient(options);
}
private LocalModel loadModel(String assetFileName) {
return new LocalModel.Builder()
.setAssetFilePath(assetFileName)
.build();
}
private void startCamera() {
cameraProviderFuture = ProcessCameraProvider.getInstance(this);
cameraProviderFuture.addListener(() -> {
try {
ProcessCameraProvider cameraProvider = cameraProviderFuture.get();
bindPreview(cameraProvider);
} catch (ExecutionException e) {
// No errors need to be handled for this Future.
// This should never be reached.
} catch (InterruptedException e) {
}
}, ContextCompat.getMainExecutor(this));
}
private void bindPreview(ProcessCameraProvider cameraProvider) {
Preview preview = new Preview.Builder().build();
CameraSelector cameraSelector = new CameraSelector.Builder()
.requireLensFacing(CameraSelector.LENS_FACING_BACK)
.build();
ImageAnalysis imageAnalysis = new ImageAnalysis.Builder()
.setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)
.build();
YourAnalyzer yourAnalyzer = new YourAnalyzer();
yourAnalyzer.setObjectDetector(objectDetector, labels);
imageAnalysis.setAnalyzer(
ContextCompat.getMainExecutor(this),
yourAnalyzer);
Camera camera =
cameraProvider.bindToLifecycle(
this,
cameraSelector,
preview,
imageAnalysis
);
preview.setSurfaceProvider(prevView.createSurfaceProvider(camera.getCameraInfo()));
}
private Boolean allPermissionsGranted() {
for (String permission : REQUIRED_PERMISSIONS) {
if (ContextCompat.checkSelfPermission(
this,
permission
) != PackageManager.PERMISSION_GRANTED
) {
return false;
}
}
return true;
}
@Override
public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
if (requestCode == REQUEST_CODE_PERMISSIONS) {
if (allPermissionsGranted()) {
startCamera();
} else {
Toast.makeText(this, "Permissions not granted by the user.", Toast.LENGTH_SHORT)
.show();
finish();
}
}
}
private static class YourAnalyzer implements ImageAnalysis.Analyzer {
private ObjectDetector objectDetector;
private List<String> labels;
public void setObjectDetector(ObjectDetector objectDetector, List<String> labels) {
this.objectDetector = objectDetector;
this.labels = labels;
}
@Override
@ExperimentalGetImage
public void analyze(@NonNull ImageProxy imageProxy) {
Image mediaImage = imageProxy.getImage();
if (mediaImage != null) {
InputImage image = InputImage.fromMediaImage(
mediaImage,
imageProxy.getImageInfo().getRotationDegrees()
);
objectDetector
.process(image)
.addOnFailureListener(e -> imageProxy.close())
.addOnSuccessListener(detectedObjects -> {
// list of detectedObjects has all the information you need
StringBuilder builder = new StringBuilder();
for (DetectedObject detectedObject : detectedObjects) {
for (DetectedObject.Label label : detectedObject.getLabels()) {
builder.append(labels.get(label.getIndex()));
builder.append("\n");
}
}
Log.d("OBJECTS DETECTED", builder.toString().trim());
imageProxy.close();
});
}
}
}
static class LinesSequence implements Iterator<String> {
private BufferedReader reader;
private String nextValue;
private Boolean done = false;
public LinesSequence(BufferedReader reader) {
this.reader = reader;
}
@Override
public boolean hasNext() {
if (nextValue == null && !done) {
try {
nextValue = reader.readLine();
} catch (IOException e) {
e.printStackTrace();
nextValue = null;
}
if (nextValue == null) done = true;
}
return nextValue != null;
}
@Override
public String next() {
if (!hasNext()) {
throw new NoSuchElementException();
}
String answer = nextValue;
nextValue = null;
return answer;
}
}
}
XML 布局(res/layout/activity_fullscreen.xml)
<?xml version="1.0" encoding="utf-8"?>
<!-- Full-screen CameraX preview; ActivityExample finds it via R.id.viewFinder. -->
<androidx.camera.view.PreviewView xmlns:android="http://schemas.android.com/apk/res/android"
    android:id="@+id/viewFinder"
    android:layout_height="match_parent"
    android:layout_width="match_parent" />
Gradle 配置(模块级 build.gradle)
// Module-level build.gradle additions for the example above; "..." stands for your existing settings.
android {
...
// NOTE(review): newer Android Gradle Plugin versions deprecate aaptOptions in favor of
// androidResources { noCompress ... } — confirm for your AGP version.
aaptOptions {
noCompress "tflite" // Your model's file extension: "tflite", "lite", etc.
}
// Java 8 language level, needed for the lambdas used in ActivityExample.
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
}
dependencies {
...
// ML Kit custom object detection (LocalModel, CustomObjectDetectorOptions, ObjectDetection).
implementation 'com.google.mlkit:object-detection-custom:16.0.0'
def camerax_version = "1.0.0-beta03"
// CameraX core library using camera2 implementation
implementation "androidx.camera:camera-camera2:$camerax_version"
// CameraX Lifecycle Library
implementation "androidx.camera:camera-lifecycle:$camerax_version"
// CameraX View class
// NOTE(review): ActivityExample calls PreviewView.createSurfaceProvider(CameraInfo), which is
// tied to this alpha release; confirm that API still exists before upgrading camera-view.
implementation "androidx.camera:camera-view:1.0.0-alpha10"
}
推荐阅读
- java - 如何使用 Avro (schemaRegistry) 对 Kafka Streams 进行功能测试?
- cmake - Ubuntu 上的 CMake/Mingw/Qt5 无法找到 stdlib.h
- python - Django 中的嵌套视图 - Python
- python - matplotlib 限制大量 y 图的 y 标记数
- python - 显示变量的值
- ios - 如何防止 react-native-device-info 在我的设置中导致 React 歧义?
- java - 弹簧验证不起作用。当我使用 entitymanager.persist() 提交空白表格时;
- angular - 在 ngx-datatable-column 中使用 ngFor 和 ng 模板
- java - 处理程序没有适配器 - DispatcherServlet 配置需要包含支持此处理程序的 HandlerAdapter
- python - 我们如何使用 python pandas 将 csv 并排合并为列?