python - TensorFlow input dimension mismatch
Question
I created an LSTM network. When I pass it an input, I get this error:
ValueError: Input 0 of layer lstm is incompatible with the layer: expected ndim=3, found ndim=1. Full shape received: (7,)
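To fix this, I consulted a Stack Overflow post that mentions using the input_shape parameter. I don't understand it well enough, so I still can't solve my problem. Please help. Here is my code: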
import tensorflow as tf

# This is the definition of the model
class LSTMmodel(tf.Module):
    def __init__(self, arg_name=None):
        super().__init__(name=arg_name)
        self.__input = tf.Variable(initial_value=[0 for x in range(7)])
        self.__network = tf.keras.layers.LSTM(units=7, input_shape=(7,))
        self.__output = tf.Variable(initial_value=[0 for x in range(7)])

    @tf.function
    def train(self, arg_data_train, labels, learning_rate):
        with tf.GradientTape() as t:
            self.__input = tf.Variable(initial_value=[0 for x in range(7)])
            self.__output = self.__network(self.__input)
            loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=None, logits=None)
        dw, db = t.gradient(loss, [self.w, self.b])
        self.w.assign_sub(learning_rate * dw)
        self.b.assign_sub(learning_rate * db)

    @tf.function
    def __call__(self, arg_input=[0 for x in range(7)]):
        self.__input = tf.Variable(arg_input)
        self.__output = self.__network(self.__input)
        return self.__output

# This is the input I provide for training where the problem occurs.
# The two vars `cgm` and `labels` are length 9222 lists.
# Each element of the list is a list with length 7 filled with only integers.
modela = LSTMmodel(arg_name='namea')
modela.train(cgm, labels, 0.4)
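Solution
An LSTM layer takes a 3D tensor as input, with shape (batch, timesteps, features), and you are passing a 1D tensor of shape (7,). You need to reshape the input to the appropriate shape.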
self.__input = tf.Variable(initial_value=[0 for x in range(7)])
self.__input_reshaped = tf.reshape(self.__input, [1, 7, 1]) # shape of (1, 7, 1)
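The same reshape applies to the whole training set, not just a single sample. The following is a minimal sketch under stated assumptions: `cgm` is replaced by random stand-in data with the dimensions described in the question (9222 rows of 7 integers), each row is treated as 7 timesteps of 1 feature, and the values are cast to float32 because the LSTM's weights are floating point.

```python
import numpy as np

# Stand-in for the question's `cgm`: 9222 rows of 7 integers (values are made up).
rng = np.random.default_rng(0)
cgm = rng.integers(0, 400, size=(9222, 7)).tolist()

# Convert the nested list to a float32 array of shape (9222, 7).
cgm_array = np.asarray(cgm, dtype=np.float32)

# Add a trailing feature axis: (batch=9222, timesteps=7, features=1).
cgm_3d = cgm_array.reshape(-1, 7, 1)

print(cgm_array.shape, cgm_3d.shape)
```

Whether each row should be read as 7 timesteps of 1 feature, or as 1 timestep of 7 features (shape `(9222, 1, 7)`), depends on what the data represents; the question does not say.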
You also need to change the LSTM's input_shape:
self.__network = tf.keras.layers.LSTM(units=7, input_shape=(7,1))
Then pass the reshaped input to the network to get the output:
self.__output = self.__network(self.__input_reshaped)
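Put together outside the class, the steps above can be sketched as a short standalone script. This is an illustration rather than the asker's full model: the sample values are made up, floats are used instead of integers, and `input_shape` is left out on the assumption that the layer infers its input shape the first time it is called directly.

```python
import tensorflow as tf

# One sample of 7 readings (hypothetical values), stored as floats.
sample = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])

# Reshape (7,) -> (batch=1, timesteps=7, features=1).
sample_3d = tf.reshape(sample, [1, 7, 1])

# With the default return_sequences=False, the layer returns only the
# last timestep's output, one value per unit.
lstm = tf.keras.layers.LSTM(units=7)
output = lstm(sample_3d)

print(sample.shape, sample_3d.shape, output.shape)
```

The output has shape (1, 7): one row per batch element and one column per LSTM unit.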