python - 如何增加 BERT keras hub 层输入的秩(ndim)以进行学习排名
问题描述
我正在尝试使用 tensorflow hub 上可用的预训练 BERT 来实现学习排序模型。我正在使用 ListNet 损失函数的变体,它要求每个训练实例都是与查询相关的几个排名文档的列表。我需要模型能够接受形状(batch_size、list_size、sentence_length)的数据,其中模型在每个训练实例中的“list_size”轴上循环,返回排名并将它们传递给损失函数。在仅由密集层组成的简单模型中,这很容易通过增加输入层的维度来完成。例如:
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras import Model
input = Input([6,10])
x = Dense(20,activation='relu')(input)
output = Dense(1, activation='sigmoid')(x)
model = Model(inputs=input, outputs=output)
...现在模型将在计算损失和更新梯度之前对长度为 10 的向量执行 6 次前向传递。
我正在尝试对 BERT 模型及其预处理层做同样的事情:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
bert_preprocess_model = hub.KerasLayer('https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1')
bert_model = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3')
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
processed_input = bert_preprocess_model(text_input)
output = bert_model(processed_input)
model = tf.keras.Model(text_input, output)
但是,当我尝试将“text_input”的形状更改为(6)或以任何方式干预它时,它总是会导致相同类型的错误:
ValueError: Could not find matching function to call loaded from the SavedModel. Got:
Positional arguments (3 total):
* Tensor("inputs:0", shape=(None, 6), dtype=string)
* False
* None
Keyword arguments: {}
Expected these arguments to match one of the following 4 option(s):
Option 1:
Positional arguments (3 total):
* TensorSpec(shape=(None,), dtype=tf.string, name='sentences')
* False
* None
Keyword arguments: {}
....
根据https://www.tensorflow.org/hub/api_docs/python/hub/KerasLayer,您似乎可以通过 tf.keras.layers.InputSpec 配置 hub.KerasLayer 的输入形状。就我而言,我想它会是这样的:
bert_preprocess_model.input_spec = tf.keras.layers.InputSpec(ndim=2)
bert_model.input_spec = tf.keras.layers.InputSpec(ndim=2)
当我运行上面的代码时,属性确实发生了变化,但是在尝试构建模型时,出现了同样的错误。
有没有什么方法可以轻松解决这个问题而无需创建自定义训练循环?
解决方案
假设您有一批 B 个示例,每个示例恰好有 N 个文本字符串,这构成了一个形状为 [B, N] 的二维张量。使用tf.reshape(),您可以将其转换为形状为 [B*N] 的一维张量,通过 BERT(保留输入顺序)将其发送,然后将其重新整形为 [B,N]。(还有tf.keras.layers.Reshape,但这对您隐藏了批处理维度。)
如果每次都不完全是 N 个文本字符串,则您必须在旁边做一些簿记(例如,将输入存储在tf.RaggedTensor中,在其上运行 BERT .values
,并从结果中构造一个具有相同内容的新 RaggedTensor .row_splits
。)
推荐阅读
- python - 在 IntelliJ 中使用烧瓶“处理完成,退出代码 0”;服务器不运行
- amazon-web-services - Terragrunt - 更有效地重用模块
- model - OpenUI5 格式化程序对模型更改没有反应
- r - 通过 R 中的 ID 识别数据框中的任何 NA
- python - 写到excel给我重复
- excel - 由另一个标准VBA合并后按标准求和
- python-3.x - 用于在括号内获取多个字符串的正则表达式
- awk - 无法在 gensub (awk) 中使用 sprint
- r - 阻止 Shiny sidebarPanel 中的元素在列中重叠
- firebase - 在新的 FlutterFire API 中使用 setData 和 merge: true