tensorflow - 张量流教程中位置编码的大小
问题描述
我正在尝试理解和玩这个关于变压器架构的 tensorflow 教程,但我在类解码器中发现了一些我不理解的东西。为什么 self.pos_encoding = positional_encoding(target_vocab_size, self.d_model) 用 targe_vocab_size 而不是序列的最大长度调用?请参阅以下课程的此链接和代码。任何想法?https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/text/transformer.ipynb
class Decoder(tf.keras.layers.Layer):
def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
rate=0.1):
super(Decoder, self).__init__()
self.d_model = d_model
self.num_layers = num_layers
self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
self.pos_encoding = positional_encoding(target_vocab_size, self.d_model)
self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate)
for _ in range(num_layers)]
self.dropout = tf.keras.layers.Dropout(rate)
def call(self, x, enc_output, training,
look_ahead_mask, padding_mask):
seq_len = tf.shape(x)[1]
attention_weights = {}
x = self.embedding(x) # (batch_size, target_seq_len, d_model)
x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
x += self.pos_encoding[:, :seq_len, :]
x = self.dropout(x, training=training)
for i in range(self.num_layers):
x, block1, block2 = self.dec_layers[i](x, enc_output, training,
look_ahead_mask, padding_mask)
attention_weights['decoder_layer{}_block1'.format(i+1)] = block1
attention_weights['decoder_layer{}_block2'.format(i+1)] = block2
# x.shape == (batch_size, target_seq_len, d_model)
return x, attention_weights
解决方案
好的,我想我说服自己该教程有一个错误。在构建位置嵌入时self.pos_encoding = positional_encoding(target_vocab_size, self.d_model)
,您应该使用 MAX_LENGTH 而不是 target_vocab_size。这解决了我在使用较小的词汇和较长的句子时遇到的一些问题。教程中的示例并没有中断,因为在他们的示例中target_vocab_size > MAX_LENGTH
,所以他们的设置没有问题。
推荐阅读
- graphics - 使用 Nvidia 驱动程序运行 Intel GPU 工具
- android - 如何在谷歌照片中为应用程序提供多个共享选项?
- python - Python生成器在应该提供的时候没有提供yield
- rx-java - 如何在 RxJava 中仅获取 zip 的最后一个值?
- reactjs - React:如何在渲染之前处理从 API 获得的数据?
- php - 有没有办法在 Laravel 5.5 中更改 url?
- regex - 为什么不能在 oracle 中将 ']' 的字符与 regexp_like 匹配?
- python - 将数据框从 Python 写入 html 时修复表头
- java - org.hibernate.property.access.spi.PropertyAccessException:访问字段时出错 [private java.lang.String,
- css - 尝试在画布中垂直对齐文本时,`context.textBaseline = 'middle'` 不起作用