首页 > 解决方案 > 具有 onehot 编码输出的 TensorFlowInferenceInterface

问题描述

我已经训练了一个神经网络,它以 4 个浮点值作为输入,并为四个类标签返回一个热编码输出。

例如,{2,12,30,4} -> {0, 0, 1, 0}

训练后的模型生成并保存在 .pb 文件中。然后将该模型导入到我的 android 应用程序的资产文件夹中:

inferenceInterface = new TensorFlowInferenceInterface(getAssets(), "tensorflow_lite_xor_nn.pb");

我有以下功能:

private float[] predict(float[] input){
    float output[] = new float[4];

    inferenceInterface.feed("dense_1_input", input, 4, input.length);
    inferenceInterface.run(new String[]{"dense_2/Sigmoid"});
    inferenceInterface.fetch("dense_2/Sigmoid", output);

    return output;
}

但我收到此错误:

java.lang.IllegalArgumentException: 具有 4 个元素的缓冲区与形状为 [4, 4] 的张量不兼容

标签: androidneural-networktensorflow-lite

解决方案


推荐阅读