首页 > 解决方案 > 如何用张量索引 LSTMStateTuple 列表?

问题描述

我有一个 pythonLSTMStateTuple对象列表,我必须使用张量作为索引来检索它们。例如:

index = tf.constant(0)
lstm = tf.nn.rnn_cell.LSTMCell(128)
states = [lstm.zero_state(10, tf.float32), lstm.zero_state(10, tf.float32)]

如果我尝试state = states[index]我得到一个错误并state = tf.gather(states, index)转换states为张量并返回一个张量 shape [10, 2, 128]

我怎样才能得到 aLSTMStateTuple而不是张量?当我将状态传递给 lstm 时,我想避免从列表LSTMStateTuple到张量以及从张量到的转换。LSTMStateTuple

标签: pythontensorflowlstm

解决方案


您创建两个状态并将它们放在一个LSTMStateTuple.

cell_state = tf.placeholder(tf.float32, [batch_size, state_size])
hidden_state = tf.placeholder(tf.float32, [batch_size, state_size])
init_state = tf.nn.rnn_cell.LSTMStateTuple(cell_state, hidden_state)

推荐阅读