首页 > 解决方案 > Tensorflow:如何转换 tensorflow LSTM 的输入数据?

问题描述

所以,我正在尝试使用 tensorflow 进行简单分类,我的疑问是

如果我使用 LSTM 进行文本分类(例如:情感分类),那么我们对数据进行填充,之后我们使用 word_embedding 来馈送 LSTM tensorflow,因此在 word_embedding 查找后,2 维数据变为 3 维或 2 级矩阵变为 3 级:

就像我有两个文本:

import tensorflow as tf

text_seq=[[11,21,43,22,11,4,1,3,5,2,8],[4,2,11,4,11,0,0,0,0,0,0]]  #2x11 

#text_seq are index of words from word_to_index dict

a=tf.get_variable('word_embedding',shape=[50,50],dtype=tf.float32,initializer=tf.random_uniform_initializer(-0.01,0.01))

lookup=tf.nn.embedding_lookup(a,text_seq)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(lookup).shape)

我会得到 :

(2, 11, 50)

我可以很容易地喂给 LSTM,因为 LSTM 接受排名 3

但是如果我有数字浮点数据而不是文本数据并且我想使用 RNN 进行分类,我的问题是假设的,

所以假设我的数据是:

import numpy as np

float_data=[[11.1,21.5,43.6,22.1,11.44],[33.5,12.7,7.4,73.1,89.1],[33.5,12.7,7.4,73.1,89.1],[33.5,12.7,7.4,73.1,89.1],[33.5,12.7,7.4,73.1,89.1],[33.5,12.7,7.4,73.1,89.1]]


labels=[1,2,3,4,5,6]
#2x5

batch_size=2

input_data_batch=[[11.1,21.5,43.6,22.1,11.44],[33.5,12.7,7.4,73.1,89.1]]


 #now should I reshape my data to make it rank 3 like this 


reshape_one=np.reshape(input_data_batch,[-1,batch_size,5])
print(reshape_one)


# or like this ?



reshape_two=np.reshape(input_data_batch,[batch_size,-1,5])

print(reshape_two)

输出:

first one

[[[11.1  21.5  43.6  22.1  11.44]
  [33.5  12.7   7.4  73.1  89.1 ]]]

second one


[[[11.1  21.5  43.6  22.1  11.44]]

 [[33.5  12.7   7.4  73.1  89.1 ]]]

标签: pythontensorflowkeraslstm

解决方案


LSTM 和其他序列模型可以采用时间主要(即维度是时间、批次、通道)或批次主要(维度是批次、时间、通道)的输入。我不知道您将哪些标志传递给 tf 的哪个实现,因此我无法从您提供的代码中判断您需要批量输入还是时间输入。


推荐阅读