java - 尝试在空对象引用上调用虚拟方法“void org.tensorflow.lite.Interpreter.run(java.lang.Object, java.lang.Object)”
问题描述
我正在按照 tensorflow 给出的文本分类演示示例在 Android Studio 上运行。但是,在运行应用程序时,点击预测按钮后,应用程序崩溃并出现以下错误。
E/AndroidRuntime: FATAL EXCEPTION: main
Process: com.example.mltest, PID: 6318
java.lang.NullPointerException: Attempt to invoke virtual method 'void org.tensorflow.lite.Interpreter.run(java.lang.Object, java.lang.Object)' on a null object reference
at com.example.mltest.TextClassificationClient.classify(TextClassificationClient.java:154)
at com.example.mltest.MainActivity.lambda$classify$3$MainActivity(MainActivity.java:73)
at com.example.mltest.-$$Lambda$MainActivity$iZpagZiqjnywt769FNidzy-9BHU.run(Unknown Source:4)
at android.os.Handler.handleCallback(Handler.java:873)
at android.os.Handler.dispatchMessage(Handler.java:99)
at android.os.Looper.loop(Looper.java:193)
at android.app.ActivityThread.main(ActivityThread.java:6669)
at java.lang.reflect.Method.invoke(Native Method)
at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:493)
at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:858)
这是 TextClassificationClient java 文件。
package com.example.mltest;
public class TextClassificationClient {
private static final String TAG = "TextClassificationDemo";
private static final String MODEL_PATH = "text_classification.tflite";
private static final String DIC_PATH = "text_classification_vocab.txt";
private static final String LABEL_PATH = "text_classification_labels.txt";
private static final int SENTENCE_LEN = 256;
private static final String SIMPLE_SPACE_OR_PUNCTUATION = " |\\\\,|\\\\.|\\\\!|\\\\?|\\n";
private static final String START = "<START>";
private static final String PAD = "<PAD>";
private static final String UNKNOWN = "<UNKNOWN>";
private static final int MAX_RESULTS = 3;
private final Context context;
private final Map<String, Integer> dic = new HashMap<>();
private final List<String> labels = new ArrayList<>();
private Interpreter tflite;
public static class Result {
private final String id;
private final String title;
private final Float confidence;
public Result(String id, String title, Float confidence) {
this.id = id;
this.title = title;
this.confidence = confidence;
}
public String getId() {
return id;
}
public String getTitle() {
return title;
}
public Float getConfidence() {
return confidence;
}
@SuppressLint("DefaultLocale")
@Override
public String toString() {
String resultString = "";
if (id != null) {
resultString += "[" + id + "] ";
}
if (title != null) {
resultString += title + " ";
}
if (confidence != null) {
resultString += String.format("(%.1f%%) ", confidence * 100.0f);
}
return resultString.trim();
}
};
public TextClassificationClient(Context context) {
this.context = context;
}
@WorkerThread
public void load() {
loadModel();
loadDictionary();
loadLabels();
}
@WorkerThread
private synchronized void loadModel() {
try {
ByteBuffer buffer = loadModelFile(this.context.getAssets());
tflite = new Interpreter(buffer);
Log.v(TAG, "TFLite Model Loaded");
} catch (IOException ex) {
Log.v(TAG, ex.getMessage());
}
}
@WorkerThread
private synchronized void loadDictionary() {
try {
loadDictionaryFile(this.context.getAssets());
Log.v(TAG, "Dictionary Loaded");
} catch (IOException ex) {
Log.v(TAG, ex.getMessage());
}
}
@WorkerThread
private synchronized void loadLabels() {
try {
loadLabelFile(this.context.getAssets());
Log.v(TAG, "Labels Loaded");
} catch (IOException ex) {
Log.v(TAG, ex.getMessage());
}
}
@WorkerThread
private synchronized void unload(){
tflite.close();
dic.clear();
labels.clear();
}
@WorkerThread
public synchronized List<Result> classify(String text) {
float[][] input = tokenizeInputText(text);
Log.v(TAG, "Classifying with TFLite");
float[][] output = new float[1][labels.size()];
System.out.println("input inside classify in textclient" + Arrays.deepToString(input) + " and labels size is " + labels.size());
System.out.println("Out put is " + Arrays.deepToString(output));
tflite.run(input, output);
PriorityQueue<Result> pq = new PriorityQueue<>(
MAX_RESULTS, (lhs, rhs) -> Float.compare(rhs.getConfidence(), lhs.getConfidence()));
for(int i = 0; i < labels.size(); i++) {
pq.add(new Result("" + i, labels.get(i), output[0][i]));
}
final ArrayList<Result> results = new ArrayList<>();
while (!pq.isEmpty()){
results.add(pq.poll());
}
return results;
}
private static MappedByteBuffer loadModelFile(AssetManager assetManager) throws IOException {
try(AssetFileDescriptor fileDescriptor = assetManager.openFd(MODEL_PATH);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
}
private void loadLabelFile(AssetManager assetManager) throws IOException{
try (InputStream ins = assetManager.open(LABEL_PATH);
BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(ins))){
while (bufferedReader.ready()) {
labels.add(bufferedReader.readLine());
}
}
}
private void loadDictionaryFile(AssetManager assetManager) throws IOException{
try (InputStream ins = assetManager.open(DIC_PATH);
BufferedReader reader = new BufferedReader(new InputStreamReader(ins))){
while (reader.ready()){
List<String> line = Arrays.asList(reader.readLine().split(" "));
if (line.size() < 2){
continue;
}
dic.put(line.get(0), Integer.parseInt(line.get(1)));
}
}
}
float[][] tokenizeInputText(String text) {
float[] tmp = new float[SENTENCE_LEN];
List<String> array = Arrays.asList(text.split(SIMPLE_SPACE_OR_PUNCTUATION));
int index = 0;
// Prepend <START> if it is in vocabulary file.
if (dic.containsKey(START)) {
tmp[index++] = dic.get(START);
}
for (String word : array) {
if (index >= SENTENCE_LEN) {
break;
}
tmp[index++] = dic.containsKey(word) ? dic.get(word) : (int) dic.get(UNKNOWN);
}
// Padding and wrapping.
Arrays.fill(tmp, index, SENTENCE_LEN - 1, (int) dic.get(PAD));
float[][] ans = {tmp};
return ans;
}
Map<String, Integer> getDic() {
return this.dic;
}
Interpreter getTflite() {
return this.tflite;
}
List<String> getLabels(){
return this.labels;
}
}
和 MainActivity java 文件。
public class MainActivity extends AppCompatActivity {
private static final String TAG = "TextClassificationDemo";
private TextClassificationClient client;
private TextView resultTextView;
private EditText inputEditText;
private Handler handler;
private ScrollView scrollView;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
Log.v(TAG, "On Create");
client = new TextClassificationClient(getApplicationContext());
handler = new Handler();
Button classifyButton = findViewById(R.id.button);
classifyButton.setOnClickListener( (View v) -> {
classify(inputEditText.getText().toString());
});
resultTextView = findViewById(R.id.result_text_view);
inputEditText = findViewById(R.id.input_text);
scrollView = findViewById(R.id.scroll_view);
}
@Override
protected void onStart(){
super.onStart();
Log.v(TAG, "OnStart");
handler.post(
() -> {
client.load();
}
);
}
@Override
protected void onStop(){
super.onStop();
Log.v(TAG, "OnStop");
handler.post(
() -> {
client.load();
}
);
}
private void classify(final String text) {
System.out.println("Text inside classify of Main Activity " + text);
handler.post(
() -> {
List<TextClassificationClient.Result> results = client.classify(text);
showResult(text, results);
}
);
}
private void showResult(final String inputText, final List<TextClassificationClient.Result> results){
runOnUiThread(
() -> {
String textToShow = "Input : " + inputText + "\nOutput : \n";
for (int i = 0; i < results.size(); i++) {
TextClassificationClient.Result result = results.get(i);
textToShow += String.format(" %s: %s\\n", result.getTitle(), result.getConfidence());
}
textToShow += "---------\\n";
resultTextView.append(textToShow);
inputEditText.getText().clear();
scrollView.post(() -> scrollView.fullScroll(View.FOCUS_DOWN));
}
);
}
}
这是我的 gradle 文件。
apply plugin: 'com.android.application'
android {
compileSdkVersion 28
buildToolsVersion "30.0.2"
defaultConfig {
applicationId "com.example.mltest"
minSdkVersion 28
targetSdkVersion 28
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
aaptOptions {
noCompress "tflite"
noCompress "lite"
}
}
dependencies {
implementation fileTree(dir: "libs", include: ["*.jar"])
implementation 'androidx.appcompat:appcompat:1.2.0'
implementation 'androidx.constraintlayout:constraintlayout:2.0.1'
implementation 'org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly'
implementation 'org.tensorflow:tensorflow-lite-task-text:0.0.0-nightly'
implementation 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'
testImplementation 'junit:junit:4.12'
androidTestImplementation 'androidx.test.ext:junit:1.1.2'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'
}
我已经关注了 SO 的其他链接,其中提出了同样的问题,但它们没有任何帮助。请帮我解决这个问题。先感谢您!
解决方案
解决!tflite 文件未正确添加到资产文件夹。添加后运行流畅
推荐阅读
- sql-update - UPDATE 与 JOIN 不选择期望值
- python - 是否可以更改(更新)已保存在 .npy 文件中的 numpy 数组的条目?如何?
- python-3.x - ModuleNotFoundError:没有名为“cassandra”的模块
- optaplanner - Optaplanner:可重现的解决方案
- html - Angularjs - 如何根据值限制 ng 重复
- algorithm - 模糊字符串记录搜索算法(支持单词转置和字符转置)
- javascript - 如何调试无法正确呈现的日期选择器
- git - 用于简单 git 项目的 Maven 发布插件 - 错误:不是工作副本
- python-3.x - 函数python 3.x中的全局变量
- php - 错误 json 解码 API