首页 > 解决方案 > 如何增加 BERT keras hub 层输入的秩(ndim)以进行学习排名

问题描述

我正在尝试使用 tensorflow hub 上可用的预训练 BERT 来实现学习排序模型。我正在使用 ListNet 损失函数的变体,它要求每个训练实例都是与查询相关的几个排名文档的列表。我需要模型能够接受形状(batch_size、list_size、sentence_length)的数据,其中模型在每个训练实例中的“list_size”轴上循环,返回排名并将它们传递给损失函数。在仅由密集层组成的简单模型中,这很容易通过增加输入层的维度来完成。例如:

from tensorflow.keras.layers import Dense, Input
from tensorflow.keras import Model

input = Input([6,10])
x = Dense(20,activation='relu')(input)
output = Dense(1, activation='sigmoid')(x)
model = Model(inputs=input, outputs=output)

...现在模型将在计算损失和更新梯度之前对长度为 10 的向量执行 6 次前向传递。

我正在尝试对 BERT 模型及其预处理层做同样的事情:

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text

bert_preprocess_model = hub.KerasLayer('https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1')
bert_model = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3')
    
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
processed_input = bert_preprocess_model(text_input)
output = bert_model(processed_input)
model = tf.keras.Model(text_input, output)

但是,当我尝试将“text_input”的形状更改为(6)或以任何方式干预它时,它总是会导致相同类型的错误:

 ValueError: Could not find matching function to call loaded from the SavedModel. Got:
      Positional arguments (3 total):
        * Tensor("inputs:0", shape=(None, 6), dtype=string)
        * False
        * None
      Keyword arguments: {}
    
    Expected these arguments to match one of the following 4 option(s):
    
    Option 1:
      Positional arguments (3 total):
        * TensorSpec(shape=(None,), dtype=tf.string, name='sentences')
        * False
        * None
      Keyword arguments: {}
     ....

根据https://www.tensorflow.org/hub/api_docs/python/hub/KerasLayer,您似乎可以通过 tf.keras.layers.InputSpec 配置 hub.KerasLayer 的输入形状。就我而言,我想它会是这样的:

bert_preprocess_model.input_spec = tf.keras.layers.InputSpec(ndim=2)
bert_model.input_spec = tf.keras.layers.InputSpec(ndim=2)

当我运行上面的代码时,属性确实发生了变化,但是在尝试构建模型时,出现了同样的错误。

有没有什么方法可以轻松解决这个问题而无需创建自定义训练循环?

标签: pythonkerasinformation-retrievalbert-language-modeltensorflow-hub

解决方案


假设您有一批 B 个示例,每个示例恰好有 N 个文本字符串,这构成了一个形状为 [B, N] 的二维张量。使用tf.reshape(),您可以将其转换为形状为 [B*N] 的一维张量,通过 BERT(保留输入顺序)将其发送,然后将其重新整形为 [B,N]。(还有tf.keras.layers.Reshape,但这对您隐藏了批处理维度。)

如果每次都不完全是 N 个文本字符串,则您必须在旁边做一些簿记(例如,将输入存储在tf.RaggedTensor中,在其上运行 BERT .values,并从结果中构造一个具有相同内容的新 RaggedTensor .row_splits。)


推荐阅读