Using TensorFlow, how do I load weights produced by an LSTM into a CudnnLSTM model?

Problem description

I trained an LSTM model with TensorFlow. Can I load the weights produced by the LSTM into a CudnnLSTM model? My LSTM code is

import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib)

lstm_cell = tf.contrib.rnn.LSTMCell(hidden_size)
outputs, (c, h) = tf.nn.dynamic_rnn(lstm_cell,
                                    input_seq,
                                    dtype=tf.float32)

and my CudnnLSTM code is

from tensorflow.contrib import cudnn_rnn  # TensorFlow 1.x

cudnn_cell_fw = cudnn_rnn.CudnnLSTM(num_layers=1,
                                    num_units=hidden_size,
                                    direction=cudnn_rnn.CUDNN_RNN_UNIDIRECTION,
                                    input_mode=cudnn_rnn.CUDNN_INPUT_LINEAR_MODE,
                                    dtype=tf.float32)
outputs, (h, c) = cudnn_cell_fw(inputs=input_seq)
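Both layers compute the same LSTM equations but lay out their parameters differently. A small NumPy sketch of the parameter shapes as I understand them, assuming a single-layer, unidirectional LSTM without peepholes or projection (the sizes 3 and 5 are arbitrary examples): `LSTMCell` keeps one fused kernel and one bias, while cuDNN keeps separate input (`W`) and recurrent (`R`) matrices per gate, plus a bias for each of them.

```python
import numpy as np

def lstmcell_param_shapes(input_size, hidden_size):
    """Shapes of the variables tf.contrib.rnn.LSTMCell creates (no peepholes/projection)."""
    return {
        "kernel": (input_size + hidden_size, 4 * hidden_size),  # multiplies concat([x, h])
        "bias": (4 * hidden_size,),
    }

def cudnn_lstm_canonical_shapes(input_size, hidden_size):
    """Per-gate shapes of a single-layer, unidirectional cuDNN LSTM (gates i, f, c, o)."""
    weights = [(hidden_size, input_size)] * 4 + [(hidden_size, hidden_size)] * 4  # W_*, then R_*
    biases = [(hidden_size,)] * 8  # one bias for each W_* and one for each R_*
    return weights, biases

def num_params(shapes):
    return sum(int(np.prod(s)) for s in shapes)

input_size, hidden_size = 3, 5
tf_shapes = lstmcell_param_shapes(input_size, hidden_size)
cu_w, cu_b = cudnn_lstm_canonical_shapes(input_size, hidden_size)

n_tf = num_params(tf_shapes.values())
n_cu = num_params(cu_w) + num_params(cu_b)
print(n_tf, n_cu)  # → 180 200
```

The extra `4 * hidden_size` values are cuDNN's second (recurrent) bias. cuDNN adds both biases inside each gate, so when converting from `LSTMCell` one of them can simply be zero; this is also why the total matches the `(weight_shape_0 + 2) * weight_shape_1` buffer size used in the solution.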

Tags: tensorflow, lstm

Solution


I tried collecting the LSTM's weights and biases like this:

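import numpy as np
from tensorflow.python.framework import tensor_util

# get_graphdef() (not shown) parses a frozen .pb file into a GraphDef
frozen_graph_path = './train_one_lstm_model.pb'
frozen_graphdef = get_graphdef(frozen_graph_path)
for node in frozen_graphdef.node:
    # In a frozen graph, each variable is a Const node holding its value
    if node.name == 'rnn/lstm_cell/kernel':
        lstm_weight = tensor_util.MakeNdarray(node.attr['value'].tensor)  # [input + hidden, 4 * hidden]
    if node.name == 'rnn/lstm_cell/bias':
        lstm_bias = tensor_util.MakeNdarray(node.attr['value'].tensor)    # [4 * hidden]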
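The fused `kernel` read above stacks the rows for the input `x` on top of the rows for the previous output `h` (because `LSTMCell` multiplies `concat([x, h])` by it), and its columns form four gate blocks in TF's order `i, j, f, o` (input gate, candidate, forget gate, output gate). A NumPy sketch that splits it per gate; `input_size` has to be known separately, and the random arrays are stand-ins for the real `lstm_weight` and `lstm_bias`.

```python
import numpy as np

TF_GATE_ORDER = ("i", "j", "f", "o")  # order of the column blocks in LSTMCell's kernel/bias

def split_lstmcell_params(kernel, bias, input_size):
    """Split LSTMCell's fused kernel/bias into per-gate input, recurrent and bias pieces."""
    hidden_size = kernel.shape[1] // 4
    assert kernel.shape == (input_size + hidden_size, 4 * hidden_size)
    assert bias.shape == (4 * hidden_size,)
    w_x, w_h = kernel[:input_size], kernel[input_size:]  # rows for x, rows for h
    gates = {}
    for k, name in enumerate(TF_GATE_ORDER):
        cols = slice(k * hidden_size, (k + 1) * hidden_size)
        gates[name] = {
            "input": w_x[:, cols],      # [input_size, hidden_size]
            "recurrent": w_h[:, cols],  # [hidden_size, hidden_size]
            "bias": bias[cols],         # [hidden_size]
        }
    return gates

rng = np.random.default_rng(0)
input_size, hidden_size = 3, 5
lstm_weight = rng.standard_normal((input_size + hidden_size, 4 * hidden_size)).astype(np.float32)
lstm_bias = rng.standard_normal(4 * hidden_size).astype(np.float32)
gates = split_lstmcell_params(lstm_weight, lstm_bias, input_size)
print({name: g["input"].shape for name, g in gates.items()})
```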

Then I wrote them into the CudnnLSTM node:

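weight_shape_0 = lstm_weight.shape[0]  # input_size + hidden_size
weight_shape_1 = lstm_weight.shape[1]  # 4 * hidden_size
# Flat buffer with as many entries as cuDNN's opaque params: (input + hidden + 2) * 4 * hidden
new_cudnn_weight = np.zeros((weight_shape_0 + 2) * weight_shape_1, dtype=np.float32)
index = 0
# Copy the fused LSTMCell kernel row by row (row-major order)
for i in range(weight_shape_0):
    for j in range(weight_shape_1):
        new_cudnn_weight[index] = lstm_weight[i][j]
        index += 1

# Append the single LSTMCell bias; the last 4 * hidden_size entries stay zero
for j in range(weight_shape_1):
    new_cudnn_weight[index] = lstm_bias[j]
    index += 1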
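Copying the kernel row by row like this keeps `LSTMCell`'s fused `[input + hidden, 4 * hidden]` layout. As far as I know, cuDNN instead expects per-gate matrices of shape `[hidden, input]` and `[hidden, hidden]`, in gate order `i, f, c, o` (TF's order is `i, j, f, o`), with two biases per gate, and without the `forget_bias=1.0` that `LSTMCell` adds to the forget gate at run time. A NumPy sketch of converting to that canonical per-gate form; `lstmcell_to_cudnn_canonical` is a hypothetical helper name, and the random arrays stand in for the real kernel and bias.

```python
import numpy as np

def lstmcell_to_cudnn_canonical(kernel, bias, input_size, forget_bias=1.0):
    """Hypothetical helper: LSTMCell (kernel, bias) -> cuDNN canonical weights/biases.

    Returns (weights, biases), each a list of 8 arrays in the order
    W_i, W_f, W_c, W_o, R_i, R_f, R_c, R_o (biases in the matching order).
    """
    hidden_size = kernel.shape[1] // 4
    # Column blocks of the fused kernel/bias are in TF order i, j, f, o
    blocks = {name: slice(k * hidden_size, (k + 1) * hidden_size)
              for k, name in enumerate(("i", "j", "f", "o"))}
    cudnn_order = ("i", "f", "j", "o")  # cuDNN's i, f, c, o; TF calls the candidate "j"
    w_x, w_h = kernel[:input_size], kernel[input_size:]

    # cuDNN computes W @ x, so each gate matrix is the transpose of the TF slice
    input_weights = [w_x[:, blocks[g]].T.copy() for g in cudnn_order]      # [hidden, input]
    recurrent_weights = [w_h[:, blocks[g]].T.copy() for g in cudnn_order]  # [hidden, hidden]

    # LSTMCell adds forget_bias to the f pre-activation at run time; cuDNN does not
    input_biases = []
    for g in cudnn_order:
        b = bias[blocks[g]].copy()
        if g == "f":
            b = b + forget_bias
        input_biases.append(b)
    # cuDNN adds the input and recurrent biases, so the recurrent ones can be zero
    recurrent_biases = [np.zeros(hidden_size, dtype=bias.dtype) for _ in cudnn_order]

    return input_weights + recurrent_weights, input_biases + recurrent_biases

rng = np.random.default_rng(0)
input_size, hidden_size = 3, 5
kernel = rng.standard_normal((input_size + hidden_size, 4 * hidden_size)).astype(np.float32)
bias = rng.standard_normal(4 * hidden_size).astype(np.float32)
weights, biases = lstmcell_to_cudnn_canonical(kernel, bias, input_size)
total = sum(w.size for w in weights) + sum(b.size for b in biases)
print(total == (kernel.shape[0] + 2) * kernel.shape[1])  # → True
```

These pieces still have to be packed into the single opaque buffer, whose layout is defined by cuDNN and should not be assumed to be a plain concatenation. As far as I recall, TF 1.x contrib has an op for this, `cudnn_rnn_canonical_to_opaque_params` in `tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops`; check its exact signature and expected weight shapes for your TF version before relying on it.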

frozen_graph_path = './train_one_culstm_model.pb'
frozen_graphdef = get_graphdef(frozen_graph_path)

for node in frozen_graphdef.node:
    # CudnnLSTM keeps all of its weights and biases in one flat "opaque" buffer
    if node.name == 'cudnn_lstm/opaque_kernel':
        ori_cudnn_weight = tensor_util.MakeNdarray(node.attr['value'].tensor)  # keep the original buffer
        node.attr['value'].tensor.CopyFrom(tensor_util.make_tensor_proto(new_cudnn_weight))
        new_cudnn_weight = tensor_util.MakeNdarray(node.attr['value'].tensor)  # read back what was written

So the CudnnLSTM node now holds the same weights and biases as the LSTM. However, when I feed in the same input, the outputs are different.
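One way to check a conversion without touching the frozen graph is to compare a single time step of both formulations in NumPy. The sketch below implements the `LSTMCell` step (gate order `i, j, f, o`, `forget_bias=1.0` added at run time) and the cuDNN-style step (separate `W`/`R`, two biases, gate order `i, f, c, o`), converts the parameters by reordering gates, transposing, and folding the forget bias into the bias, and then compares the results. The helper names are hypothetical and the weights are random.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstmcell_step(x, h, c, kernel, bias, forget_bias=1.0):
    """One step of tf.contrib.rnn.LSTMCell (no peepholes, no projection)."""
    z = np.concatenate([x, h], axis=1) @ kernel + bias
    i, j, f, o = np.split(z, 4, axis=1)  # TF gate order
    c_new = sigmoid(f + forget_bias) * c + sigmoid(i) * np.tanh(j)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, c_new

def cudnn_lstm_step(x, h, c, weights, biases):
    """One step of the cuDNN LSTM equations; W_* at 0..3, R_* at 4..7, order i, f, c, o."""
    i, f, g, o = [x @ weights[k].T + h @ weights[k + 4].T + biases[k] + biases[k + 4]
                  for k in range(4)]
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, c_new

def convert(kernel, bias, input_size, forget_bias=1.0):
    """Reorder TF (i, j, f, o) blocks to cuDNN (i, f, c, o), transpose, fold forget_bias."""
    H = kernel.shape[1] // 4
    block = {n: slice(k * H, (k + 1) * H) for k, n in enumerate("ijfo")}
    order = "ifjo"
    w = [kernel[:input_size, block[g]].T for g in order]
    r = [kernel[input_size:, block[g]].T for g in order]
    b = [bias[block[g]] + (forget_bias if g == "f" else 0.0) for g in order]
    return w + r, b + [np.zeros(H) for _ in order]

rng = np.random.default_rng(1)
batch, input_size, hidden_size = 2, 3, 4
kernel = rng.standard_normal((input_size + hidden_size, 4 * hidden_size))
bias = rng.standard_normal(4 * hidden_size)
x = rng.standard_normal((batch, input_size))
h0 = rng.standard_normal((batch, hidden_size))
c0 = rng.standard_normal((batch, hidden_size))

h_tf, c_tf = lstmcell_step(x, h0, c0, kernel, bias)
weights, biases = convert(kernel, bias, input_size)
h_cu, c_cu = cudnn_lstm_step(x, h0, c0, weights, biases)
print(np.allclose(h_tf, h_cu), np.allclose(c_tf, c_cu))  # → True True

# Leaving out LSTMCell's forget_bias breaks the match
weights0, biases0 = convert(kernel, bias, input_size, forget_bias=0.0)
h_bad, _ = cudnn_lstm_step(x, h0, c0, weights0, biases0)
print(np.allclose(h_tf, h_bad))  # → False
```

Separately from the weights, note that `tf.nn.dynamic_rnn` defaults to batch-major input (`[batch, time, input]`, `time_major=False`), while `CudnnLSTM` treats its input as time-major (`[time, batch, input]`). Feeding the same `input_seq` tensor to both therefore also changes the result unless it is transposed first, for example with `tf.transpose(input_seq, [1, 0, 2])`.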

