seq2seq - Implementing an attention mechanism in the Maluuba seq2seq model
Problem
Hi, I am trying to add attention to the simple Maluuba/qgen-workshop seq2seq model, but I cannot figure out the correct batch_size to pass to the initial state. This is what I tried:
# Attention
# attention_states: [batch_size, max_time, num_units]
attention_states = tf.transpose(encoder_outputs, [1, 0, 2])

# Create an attention mechanism
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    encoder_cell.state_size, attention_states,
    memory_sequence_length=None)

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    decoder_cell, attention_mechanism,
    attention_layer_size=encoder_cell.state_size)

batch = next(training_data())
batch = collapse_documents(batch)

initial_state = decoder_cell.zero_state(batch["size"], tf.float32).clone(
    cell_state=encoder_state)

decoder = seq2seq.BasicDecoder(decoder_cell, helper, initial_state,
                               output_layer=projection)
It gives me this error:
InvalidArgumentError (see above for traceback): assertion failed: [When applying AttentionWrapper attention_wrapper_1: Non-matching batch sizes between the memory (encoder output) and the query (decoder output).
Are you using the BeamSearchDecoder? You may need to tile your memory input via the tf.contrib.seq2seq.tile_batch function with argument multiple=beam_width.] [Condition x == y did not hold element-wise:] [x (decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/x:0) = ] [99] [y (LuongAttention/strided_slice_1:0) = ] [29]
[[Node: decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/Assert/Assert = Assert[T=[DT_STRING, DT_STRING, DT_STRING, DT_INT32, DT_STRING, DT_INT32], summarize=3, _device="/job:localhost/replica:0/task:0/cpu:0"](decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/All, decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/Assert/Assert/data_0, decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/Assert/Assert/data_1, decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/Assert/Assert/data_2, decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/x, decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/Assert/Assert/data_4, decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/Equal/Enter)]]
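The assertion itself names the mismatch: the memory (encoder output) passed to LuongAttention has batch size 29, while the query (decoder output) arriving at each step has batch size 99, so the two tensors came from different batches. A minimal NumPy sketch of the check the AttentionWrapper performs before scoring (illustrative only, not the TensorFlow implementation; shapes chosen to mirror the error):

```python
import numpy as np

def luong_score(query, memory):
    # query:  [batch, num_units]            decoder cell output at one step
    # memory: [batch, max_time, num_units]  encoder outputs
    # AttentionWrapper asserts the batch dimensions match before scoring.
    assert query.shape[0] == memory.shape[0], (
        "Non-matching batch sizes between the memory (encoder output) "
        "and the query (decoder output)")
    # Luong (dot) attention: score[b, t] = query[b] . memory[b, t]
    return np.einsum("bu,btu->bt", query, memory)

memory = np.random.rand(29, 7, 16)  # memory captured from one batch (size 29)
query = np.random.rand(99, 16)      # query coming from a different batch (size 99)
try:
    luong_score(query, memory)
except AssertionError as e:
    print(e)  # the same condition the TensorFlow graph assertion reports
```

With matching batch sizes the score tensor comes out as [batch, max_time], exactly what the softmax over encoder time steps needs.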
Solution
We currently have _MAX_BATCH_SIZE = 128, but every batch ends up with a different size, because we want to make sure all the questions for a given story land in the same batch. That is why each batch carries a 'size' key indicating its actual size.
It sounds like you already knew that, so I think the problem is something else. Could it be that encoder_cell.state_size (or the attention memory built from it) was captured with the batch size of an earlier batch?
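For reference, the computation the LuongAttention/AttentionWrapper pair performs at each decoder step can be sketched framework-free; once memory and query come from the same batch, every shape lines up (a NumPy sketch with illustrative sizes, not the TensorFlow code):

```python
import numpy as np

def luong_attention(query, memory):
    """One step of Luong (dot-product) attention.

    query:  [batch, num_units]            decoder cell output
    memory: [batch, max_time, num_units]  encoder outputs
    Returns the context vector, shape [batch, num_units].
    """
    scores = np.einsum("bu,btu->bt", query, memory)        # [batch, max_time]
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)          # softmax over time
    return np.einsum("bt,btu->bu", weights, memory)        # weighted sum of memory

batch, max_time, num_units = 29, 7, 16
memory = np.random.rand(batch, max_time, num_units)  # same batch as the query
query = np.random.rand(batch, num_units)
context = luong_attention(query, memory)
print(context.shape)  # (29, 16)
```

In the TensorFlow graph this is why the attention memory must be rebuilt (or fed) from the same batch the decoder runs on, rather than frozen from an earlier one.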