首页 > 解决方案 > 无法在 tensorflow lite 1.15 中运行 LSTM

问题描述

TLDR:有人可以展示如何创建 LSTM,将其转换为 TFLite,并在 android 版本 1.15 中运行它吗?

我正在尝试创建一个简单的 LSTM 模型并在带有 tensorflow v115 的 android 应用程序中运行。

** 使用 GRU 和 SimpleRNN 层时情况相同 **

创建简单的 LSTM 模型

我在 Python 中工作,尝试了两个 tensorflow 和 keras 版本:最新(2.4.1 内置 keras)和 1.1.5(我安装了 keras 版本 2.2.4)。

我创建了这个简单的模型:

model = keras.Sequential()
model.add(layers.Embedding(input_dim=1000, output_dim=64))
model.add(layers.LSTM(128))
model.add(layers.Dense(10))
model.summary()

保存它

我将其保存为“SavedModel”和“h5”格式:

model.save(f'output_models/simple_lstm_saved_model_format_{tf.__version__}', save_format='tf')
model.save(f'output_models/simple_lstm_{tf.__version__}.h5', save_format='h5')

转换为 TFLite

我尝试在 v115 和 v2 版本中创建并保存模型。

然后,我尝试通过几种方法将其转换为 TFLite。

在 TF2 中:

  1. 我尝试从 keras 模型转换:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open(f"output_models/simple_lstm_tf_v{tf.__version__}.tflite", 'wb') as f:
    f.write(tflite_model)
  1. 我尝试从保存的模型转换:
converter_saved_model = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
tflite_model_from_saved_model = converter_saved_model.convert()
with open(f"{saved_model_path}_converted_tf_v{tf.__version__}.tflite", 'wb') as f:
    f.write(tflite_model_from_saved_model)
  1. 我尝试从 keras 保存的模型 (h5) 转换 - 我尝试同时使用tf.compat.v1.lite.TFLiteConverter和 tf..lite.TFLiteConverter。
converter_h5 = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(h5_model_path)
# converter_h5 = tf.lite.TFLiteConverter.from_keras_model_file(h5_model_path) # option 2
tflite_model_from_h5 = converter_h5.convert()
with open(f{h5_model_path.replace('.h5','')}_converted_tf_v1_lite_from_keras_model_file_v{tf.__version__}.tflite", 'wb') as f:
f.write(tflite_model_from_h5)

安卓应用

build.gradle(模块:app)

当我想使用 v2 时,我使用:

    implementation 'org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly'
    implementation 'org.tensorflow:tensorflow-lite-task-text:0.0.0-nightly'

当我想使用 v115 时,我implementation 'org.tensorflow:tensorflow-lite:1.15.0' 在构建等级中使用。

然后,我按照 android 中常见的 tflite 加载代码:

private MappedByteBuffer loadModelFile(Activity activity) throws IOException {

        AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(getModelPath());
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

    LoadLSTM(Activity activity) {
        try {
            tfliteModel = loadModelFile(activity);
        } catch (IOException e) {
            e.printStackTrace();
        }
        tflite = new Interpreter(tfliteModel, tfliteOptions);
        Log.d(TAG, "*** Loaded model *** " + getModelPath());
    }

当我使用 v2 时,模型已加载。当我使用 v115 时,在我尝试过的所有选项中,我收到如下错误: A/libc: Fatal signal 11 (SIGSEGV), code 1 (SEGV_MAPERR), fault addr 0x70 in tid 17686 (CameraBackgroun), pid 17643 (flitecamerademo)

我需要一个简单的结果 - 创建 LSTM 并使其在 android v115 中工作。

我错过了什么?谢谢

标签: androidtensorflowlstmtensorflow-litelstm-stateful

解决方案


推荐阅读