Attempt to invoke virtual method 'void org.tensorflow.lite.Interpreter.run(java.lang.Object, java.lang.Object)' on a null object reference

Problem description

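I am running the text classification demo example provided by TensorFlow in Android Studio. However, when I run the app and tap the predict button, the app crashes with the following error: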

E/AndroidRuntime: FATAL EXCEPTION: main
Process: com.example.mltest, PID: 6318
java.lang.NullPointerException: Attempt to invoke virtual method 'void org.tensorflow.lite.Interpreter.run(java.lang.Object, java.lang.Object)' on a null object reference
    at com.example.mltest.TextClassificationClient.classify(TextClassificationClient.java:154)
    at com.example.mltest.MainActivity.lambda$classify$3$MainActivity(MainActivity.java:73)
    at com.example.mltest.-$$Lambda$MainActivity$iZpagZiqjnywt769FNidzy-9BHU.run(Unknown Source:4)
    at android.os.Handler.handleCallback(Handler.java:873)
    at android.os.Handler.dispatchMessage(Handler.java:99)
    at android.os.Looper.loop(Looper.java:193)
    at android.app.ActivityThread.main(ActivityThread.java:6669)
    at java.lang.reflect.Method.invoke(Native Method)
    at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:493)
    at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:858)

Here is the TextClassificationClient Java file:

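package com.example.mltest;

import android.annotation.SuppressLint;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.util.Log;

import androidx.annotation.WorkerThread;

import org.tensorflow.lite.Interpreter;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
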
public class TextClassificationClient {

private static final String TAG = "TextClassificationDemo";
private static final String MODEL_PATH = "text_classification.tflite";
private static final String DIC_PATH = "text_classification_vocab.txt";
private static final String LABEL_PATH = "text_classification_labels.txt";

private static final int SENTENCE_LEN = 256;
private static final String SIMPLE_SPACE_OR_PUNCTUATION = " |\\,|\\.|\\!|\\?|\n";

private static final String START = "<START>";
private static final String PAD = "<PAD>";
private static final String UNKNOWN = "<UNKNOWN>";

private static final int MAX_RESULTS = 3;

private final Context context;
private final Map<String, Integer> dic = new HashMap<>();
private final List<String> labels = new ArrayList<>();
private Interpreter tflite;

public static class Result {

    private final String id;
    private final String title;
    private final Float confidence;

    public Result(String id, String title, Float confidence) {
        this.id = id;
        this.title = title;
        this.confidence = confidence;
    }

    public String getId() {
        return id;
    }

    public String getTitle() {
        return title;
    }

    public Float getConfidence() {
        return confidence;
    }

    @SuppressLint("DefaultLocale")
    @Override
    public String toString() {
        String resultString = "";

        if (id != null) {
            resultString += "[" + id + "] ";
        }

        if (title != null) {
            resultString += title + " ";
        }

        if (confidence != null) {
            resultString += String.format("(%.1f%%) ", confidence * 100.0f);
        }

        return resultString.trim();
    }
}

public TextClassificationClient(Context context) {
    this.context = context;
}

@WorkerThread
public void load() {
    loadModel();
    loadDictionary();
    loadLabels();
}

@WorkerThread
private synchronized void loadModel() {
    try {
        ByteBuffer buffer = loadModelFile(this.context.getAssets());
        tflite = new Interpreter(buffer);
        Log.v(TAG, "TFLite Model Loaded");

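    } catch (IOException ex) {
        // If the model asset is missing, tflite stays null; log at error level with the stack trace.
        Log.e(TAG, "Failed to load " + MODEL_PATH, ex);
    }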
}

@WorkerThread
private synchronized void loadDictionary() {
    try {
        loadDictionaryFile(this.context.getAssets());
        Log.v(TAG, "Dictionary Loaded");
    } catch (IOException ex) {
        Log.e(TAG, "Failed to load " + DIC_PATH, ex);
    }
}

@WorkerThread
private synchronized void loadLabels() {
    try {
        loadLabelFile(this.context.getAssets());
        Log.v(TAG, "Labels Loaded");
    } catch (IOException ex) {
        Log.e(TAG, "Failed to load " + LABEL_PATH, ex);
    }
}

@WorkerThread
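private synchronized void unload(){
    // tflite is null if loadModel() failed, so guard before closing it.
    if (tflite != null) {
        tflite.close();
        tflite = null;
    }
    dic.clear();
    labels.clear();
}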

@WorkerThread
public synchronized List<Result> classify(String text) {
    float[][] input = tokenizeInputText(text);

    Log.v(TAG, "Classifying with TFLite");

    float[][] output = new float[1][labels.size()];
    System.out.println("input inside classify in textclient" + Arrays.deepToString(input) + " and labels size is " + labels.size());
    System.out.println("Output is " + Arrays.deepToString(output));
    tflite.run(input, output);

    PriorityQueue<Result> pq = new PriorityQueue<>(
            MAX_RESULTS, (lhs, rhs) -> Float.compare(rhs.getConfidence(), lhs.getConfidence()));
    for(int i = 0; i < labels.size(); i++) {
        pq.add(new Result("" + i, labels.get(i), output[0][i]));
    }

    final ArrayList<Result> results = new ArrayList<>();
    while (!pq.isEmpty()){
        results.add(pq.poll());
    }

    return results;
}

private static MappedByteBuffer loadModelFile(AssetManager assetManager) throws IOException {

    try(AssetFileDescriptor fileDescriptor = assetManager.openFd(MODEL_PATH);
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }
}

private void loadLabelFile(AssetManager assetManager) throws IOException{
    try (InputStream ins = assetManager.open(LABEL_PATH);
         BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(ins))){
        String line;
        // readLine() returns null at end of stream; ready() only reports whether a read would block.
        while ((line = bufferedReader.readLine()) != null) {
            labels.add(line);
        }
    }
}

private void loadDictionaryFile(AssetManager assetManager) throws IOException{
    try (InputStream ins = assetManager.open(DIC_PATH);
            BufferedReader reader = new BufferedReader(new InputStreamReader(ins))){
        String rawLine;
        while ((rawLine = reader.readLine()) != null){
            // Each line is expected to be "<word> <index>".
            List<String> line = Arrays.asList(rawLine.split(" "));
            if (line.size() < 2){
                continue;
            }

            dic.put(line.get(0), Integer.parseInt(line.get(1)));
        }
    }
}

float[][] tokenizeInputText(String text) {

    float[] tmp = new float[SENTENCE_LEN];
    List<String> array = Arrays.asList(text.split(SIMPLE_SPACE_OR_PUNCTUATION));

    int index = 0;
    // Prepend <START> if it is in vocabulary file.
    if (dic.containsKey(START)) {
        tmp[index++] = dic.get(START);
    }

    for (String word : array) {
        if (index >= SENTENCE_LEN) {
            break;
        }
        tmp[index++] = dic.containsKey(word) ? dic.get(word) : (int) dic.get(UNKNOWN);
    }
    // Padding and wrapping. The end index of Arrays.fill is exclusive, so pass SENTENCE_LEN to pad the last slot too.
    Arrays.fill(tmp, index, SENTENCE_LEN, (int) dic.get(PAD));
    float[][] ans = {tmp};
    return ans;
}

Map<String, Integer> getDic() {
    return this.dic;
}

Interpreter getTflite() {
    return this.tflite;
}

List<String> getLabels(){
    return this.labels;
}
}
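To check what tokenizeInputText() actually feeds the model, the same steps (split on spaces and punctuation, prepend <START>, map unknown words to <UNKNOWN>, pad with <PAD>) can be run on a desktop JVM without a device. The following is a standalone sketch under assumptions: the class name TokenizerSketch and the toy vocabulary are hypothetical, the sentence length is a parameter instead of the fixed SENTENCE_LEN = 256, and padding runs through the last slot of the array.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/** Hypothetical standalone copy of the demo's tokenization, for experimenting off-device. */
final class TokenizerSketch {
    // Split on a space, comma, period, exclamation mark, question mark, or newline.
    static final String SIMPLE_SPACE_OR_PUNCTUATION = " |\\,|\\.|\\!|\\?|\n";

    static float[] tokenize(String text, Map<String, Integer> dic, int sentenceLen) {
        float[] tmp = new float[sentenceLen];
        int index = 0;
        // Prepend <START> if the vocabulary defines it.
        if (dic.containsKey("<START>")) {
            tmp[index++] = dic.get("<START>");
        }
        for (String word : text.split(SIMPLE_SPACE_OR_PUNCTUATION)) {
            if (index >= sentenceLen) {
                break; // Truncate input longer than the model's sequence length.
            }
            // Words missing from the vocabulary (including empty tokens) map to <UNKNOWN>.
            tmp[index++] = dic.getOrDefault(word, dic.get("<UNKNOWN>"));
        }
        // Arrays.fill's end index is exclusive, so sentenceLen pads every remaining slot.
        Arrays.fill(tmp, index, sentenceLen, dic.get("<PAD>").floatValue());
        return tmp;
    }

    public static void main(String[] args) {
        Map<String, Integer> dic = new HashMap<>();
        dic.put("<PAD>", 0);
        dic.put("<START>", 1);
        dic.put("<UNKNOWN>", 2);
        dic.put("hello", 5);
        dic.put("world", 6);
        System.out.println(Arrays.toString(tokenize("hello, world!", dic, 8)));
    }
}
```

With that toy vocabulary, main prints [1.0, 5.0, 2.0, 6.0, 0.0, 0.0, 0.0, 0.0]. The 2.0 comes from the empty string that split() produces between "," and the following space, which the demo's tokenization also maps to <UNKNOWN>.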

And the MainActivity Java file:

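package com.example.mltest;

import android.os.Bundle;
import android.os.Handler;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.EditText;
import android.widget.ScrollView;
import android.widget.TextView;

import androidx.appcompat.app.AppCompatActivity;

import java.util.List;

public class MainActivity extends AppCompatActivity {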

private static final String TAG = "TextClassificationDemo";
private TextClassificationClient client;

private TextView resultTextView;
private EditText inputEditText;
private Handler handler;
private ScrollView scrollView;

@Override
protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_main);
    Log.v(TAG, "On Create");

    client = new TextClassificationClient(getApplicationContext());
    handler = new Handler();
    Button classifyButton = findViewById(R.id.button);

    classifyButton.setOnClickListener( (View v) -> {
        classify(inputEditText.getText().toString());
    });

    resultTextView = findViewById(R.id.result_text_view);
    inputEditText = findViewById(R.id.input_text);
    scrollView = findViewById(R.id.scroll_view);
}

@Override
protected void onStart(){
    super.onStart();
    Log.v(TAG, "OnStart");
    handler.post(
            () -> {
                client.load();
            }
    );
}

@Override
protected void onStop(){
    super.onStop();
    Log.v(TAG, "OnStop");
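    handler.post(
            () -> {
                // Release the model here instead of loading it again: load() appends to the
                // label list, so a second load would duplicate every label.
                if (client.getTflite() != null) {
                    client.getTflite().close();
                }
                client.getDic().clear();
                client.getLabels().clear();
            }
    );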
}

private void classify(final String text) {

    System.out.println("Text inside classify of Main Activity " + text);
    handler.post(
            () -> {
                List<TextClassificationClient.Result> results = client.classify(text);

                showResult(text, results);
            }
    );
}

private void showResult(final String inputText, final List<TextClassificationClient.Result> results){
    runOnUiThread(
            () -> {
                String textToShow = "Input : " + inputText + "\nOutput : \n";
                for (int i = 0; i < results.size(); i++) {
                    TextClassificationClient.Result result = results.get(i);
                    textToShow += String.format("    %s: %s\n", result.getTitle(), result.getConfidence());
                }

                textToShow += "---------\n";

                resultTextView.append(textToShow);
                inputEditText.getText().clear();

                scrollView.post(() -> scrollView.fullScroll(View.FOCUS_DOWN));
            }
    );
}
}
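In this MainActivity, handler is created with new Handler() inside onCreate(), so it is bound to the main thread's Looper and client.load() and client.classify() (both annotated @WorkerThread) actually run on the UI thread. Running them on one background thread instead keeps them off the UI thread while preserving their order: a classify request queued after load() sees the loaded model. The sketch below shows that ordering guarantee using a plain java.util.concurrent single-thread executor as a stand-in for Android's HandlerThread; the class SerialModelWorker and its fake "model" string are hypothetical.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/** Hypothetical sketch: load/classify/unload run one at a time, in submission order, on one thread. */
final class SerialModelWorker {
    // Daemon thread so a forgotten shutdown() does not keep the JVM alive.
    private final ExecutorService worker = Executors.newSingleThreadExecutor(r -> {
        Thread t = new Thread(r, "model-worker");
        t.setDaemon(true);
        return t;
    });
    // Only read and written on the worker thread, so no extra synchronization is needed.
    private String model; // stands in for the TFLite Interpreter

    Future<?> load() {
        return worker.submit(() -> {
            model = "loaded-model"; // stands in for new Interpreter(buffer)
        });
    }

    Future<?> unload() {
        return worker.submit(() -> {
            model = null; // stands in for tflite.close()
        });
    }

    Future<String> classify(String text) {
        return worker.submit(() -> {
            if (model == null) {
                throw new IllegalStateException("Model not loaded");
            }
            return model + ":" + text;
        });
    }

    void shutdown() {
        worker.shutdown();
    }

    /** Blocks for a result; an Activity would post the result to the UI thread instead. */
    static String await(Future<String> future) {
        try {
            return future.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        }
    }

    public static void main(String[] args) {
        SerialModelWorker w = new SerialModelWorker();
        w.load();
        System.out.println(await(w.classify("good movie")));
        w.shutdown();
    }
}
```

On Android, the analogous setup would be a Handler constructed with a HandlerThread's Looper, with results handed back through runOnUiThread() as showResult() already does.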

And here is my Gradle file:

apply plugin: 'com.android.application'

android {
compileSdkVersion 28
buildToolsVersion "30.0.2"

defaultConfig {
    applicationId "com.example.mltest"
    minSdkVersion 28
    targetSdkVersion 28
    versionCode 1
    versionName "1.0"

    testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
}

buildTypes {
    release {
        minifyEnabled false
        proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
    }
}
compileOptions {
    sourceCompatibility JavaVersion.VERSION_1_8
    targetCompatibility JavaVersion.VERSION_1_8
}

aaptOptions {
    noCompress "tflite"
    noCompress "lite"
}
}

dependencies {
implementation fileTree(dir: "libs", include: ["*.jar"])
implementation 'androidx.appcompat:appcompat:1.2.0'
implementation 'androidx.constraintlayout:constraintlayout:2.0.1'
implementation 'org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly'
implementation 'org.tensorflow:tensorflow-lite-task-text:0.0.0-nightly'
implementation 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'
testImplementation 'junit:junit:4.12'
androidTestImplementation 'androidx.test.ext:junit:1.1.2'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'

}

I have gone through other SO links where the same question was asked, but none of them helped. Please help me fix this. Thanks in advance!

Tags: java, android, android-studio, tensorflow, tensorflow-lite

Solution


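Solved! The .tflite file had not been added to the assets folder correctly (in a default Android Studio project that is app/src/main/assets/, with the file named as in MODEL_PATH, text_classification.tflite). After adding it, the app runs smoothly. The missing file explains the crash: loadModel() catches the IOException thrown by assetManager.openFd() and only logs it, so tflite is never assigned and classify() later calls run() on null.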
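Beyond adding the file, the load step can be made to fail loudly instead of silently leaving a null field behind. A minimal pure-Java sketch of that fail-fast pattern: the holder remembers why loading failed and rethrows it, with context, when the model is first needed. ModelSlot, the ModelLoader interface, and the error message are hypothetical; the loader lambda stands in for something like new Interpreter(loadModelFile(assets)).

```java
import java.io.FileNotFoundException;
import java.io.IOException;

/** Hypothetical stand-in for code such as new Interpreter(loadModelFile(assets)). */
interface ModelLoader<M> {
    M load() throws IOException;
}

/** Hypothetical holder that keeps either a loaded model or the reason loading failed. */
final class ModelSlot<T> {
    private T model;
    private IOException loadError;

    void load(ModelLoader<T> loader) {
        try {
            model = loader.load();
            loadError = null;
        } catch (IOException e) {
            // Keep the cause instead of only logging it, so later callers can see it.
            model = null;
            loadError = e;
        }
    }

    /** Returns the model, or throws with the original load failure attached as the cause. */
    T require() {
        if (model == null) {
            throw new IllegalStateException(
                    "Model is not loaded; is the .tflite file in the assets folder?", loadError);
        }
        return model;
    }

    public static void main(String[] args) {
        ModelSlot<String> slot = new ModelSlot<>();
        slot.load(() -> {
            throw new FileNotFoundException("text_classification.tflite");
        });
        try {
            slot.require();
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage() + " Cause: " + e.getCause());
        }
    }
}
```

In the app, classify() would call something like require() instead of dereferencing the field directly, so a missing asset surfaces as an exception whose cause names the file rather than as a bare NullPointerException.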

