首页 > 解决方案 > 是否可以训练具体功能?

问题描述

我想将模型转换为 tflite 格式。但是,我不断收到不支持运算符 BroadcastTo 的错误。我能够解决此错误的唯一方法是按模型定义为具体函数。我如何训练一个具体的功能,甚至可能吗?

(不是我的实际模型,只是错误的一个最小示例)

    # -------------------- 不起作用 --------------------

    类CustomLayer(tf.keras.layers.Layer):
      def __init__(self, num_outputs):
        super(CustomLayer, self).__init__()

      def 调用(自我,输入):
        trans = tf.ones([1, 25, 37, 12])
        trans = tf.math.add(trans, 输入)
        m1s = tf.ones([1, 25, 37, 12, 5, 5])
        reshape = tf.reshape(trans, [1, 25, 37, 12, 1, 1])
        f = tf.multiply(reshape, m1s)
        返回 f

    输入 = tf.keras.Input(shape=(1), dtype=tf.float32)
    f = CustomLayer(1)(输入)
    模型 = tf.keras.Model(输入=输入,输出=f)
    转换器 = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = 转换器.convert()
    打开(“model.tflite”,“wb”).write(tflite_model)

    #-------------------- 具体功能(作品)--------------------

    根 = tf.Module()
    root.var = 无

    @tf.function
    def 示例(数字):
      trans = tf.ones([1, 25, 37, 12])
      trans = tf.add(trans, number)
      m1s = tf.ones([1, 25, 37, 12, 5, 5])
      reshape = tf.reshape(trans, [1, 25, 37, 12, 1, 1])
      f = tf.multiply(reshape, m1s)
      返回 f

    root.func = 示例
    具体函数 = root.func.get_concrete_function(3)
    转换器 = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    tflite_model = 转换器.convert()
    打开(“model.tflite”,“wb”).write(tflite_model)

请注意,我已经尝试了以下解决方案:

  1. 在Keras中定义模型(因此可以轻松训练)并使用
    tf.lite.TFLiteConverter.from_keras_model
  2. 将 Keras 模型保存为SavedModel并使用
    tf.lite.TFLiteConverter.from_saved_model
  3. 将 Keras 模型保存为 SavedModel 并使用从中获取具体功能
    concrete_func = model.signatures[
    tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

我知道也可以制作自定义运算符,但这需要对 tensorflow 的 C++ API 有深入了解,了解 BroadcastTo 在内部如何工作,知道将文件放在哪里,编译自定义 AAR,以及构建自定义 JNI 层。

标签: tensorflow-litetensorflow2.0

解决方案


试试这个代码!!

import tensorflow as tf
from tensorflow import keras
import tensorflow_hub as hub

model_path='/content/model.h5'
model=keras.models.load_model(model_path)
reloaded = keras.models.load_model(model_path,custom_objects{'KerasLayer':hub.KerasLayer})

TFLITE_MODEL = f"path/model.tflite"


# Get the concrete function from the Keras model.
run_model = tf.function(lambda x : reloaded(x))

# Save the concrete function.
concrete_func = run_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
)

# Convert the model to standard TensorFlow Lite model
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converted_tflite_model = converter.convert()
open(TFLITE_MODEL, "wb").write(converted_tflite_model)

推荐阅读