python - 创建级联神经网络以及如何训练它
问题描述
我有一个关于级联神经网络的输入形状的问题。我有一个标记化的文本列,使用函数 pad_sequences 所有数据点的长度为 2395,用于训练我有 6493 个数据点。所以文本部分的形状是 (6493, 2395),不是吗?我有 17 个附加列要放入模型中。所以这个附加数据的形状是 (6493, 17)。
对于神经网络,我有以下代码:
embedding_dim = 300
inp_dim = X_train.shape[1]
text_data = Input(shape=(max_length,), name="X_train")
meta_data = Input(shape=X_train_zusatz.shape, name="X_train_zusatz")
x1 = (Embedding(input_dim=vocab_size, output_dim=embedding_dim, input_length=max_length))(text_data)
x2 = (LSTM(300, dropout = 0.2, recurrent_dropout = 0.2, return_sequences=True))(x1)
x3 = (Dense(300, activation = "relu"))(meta_data)
x4 = concatenate([x2, x3], axis = 1)
x5 = (Dense(300, activation = "relu"))(x4)
x6 = Dropout(0.25)(x5)
x7 = (Dense(300, activation = "relu"))(x6)
x8 = BatchNormalization()(x7)
x9 = (Dense(4, activation='softmax'))(x8)
model = Model(inputs = [text_data, meta_data], outputs = x9)
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
print(model.summary())
model.summary 如下所示:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
X_train (InputLayer) (None, 2395) 0
__________________________________________________________________________________________________
embedding_19 (Embedding) (None, 2395, 300) 23400600 X_train[0][0]
__________________________________________________________________________________________________
X_train_zusatz (InputLayer) (None, 6493, 17) 0
__________________________________________________________________________________________________
lstm_19 (LSTM) (None, 2395, 300) 721200 embedding_19[0][0]
__________________________________________________________________________________________________
dense_44 (Dense) (None, 6493, 300) 5400 X_train_zusatz[0][0]
__________________________________________________________________________________________________
concatenate_18 (Concatenate) (None, 8888, 300) 0 lstm_19[0][0]
dense_44[0][0]
__________________________________________________________________________________________________
dense_45 (Dense) (None, 8888, 300) 90300 concatenate_18[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout) (None, 8888, 300) 0 dense_45[0][0]
__________________________________________________________________________________________________
dense_46 (Dense) (None, 8888, 300) 90300 dropout_1[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 8888, 300) 1200 dense_46[0][0]
__________________________________________________________________________________________________
dense_47 (Dense) (None, 8888, 4) 1204 batch_normalization_7[0][0]
==================================================================================================
Total params: 24,310,204
Trainable params: 24,309,604
Non-trainable params: 600
__________________________________________________________________________________________________
None
所以我的问题是,为什么我在嵌入层之后看不到数据点的数量(6493)。我在这一层犯了什么错误吗?因为在密集层中我得到了形状 (None, 6493, 300),但在嵌入层中我得到了 (None, 2395, 300)。恐怕,这里的列和行是混合的,对吗?
除此之外,我无法训练模型。编码:
model.fit([X_train, X_train_zusatz], y_train, epochs=100, batch_size=500, validation_data=[[X_test, X_test_zusatz], y_test], class_weight=class_weight)
会导致错误:
ValueError: Error when checking input: expected X_train_zusatz to have 3 dimensions, but got array with shape (6493, 17)
我该如何解决?因为 (6493, 17) 是附加数据的正确形状,但我的神经网络不会接受它。
太感谢了!
最好的问候,丹尼尔
解决方案
您正在使用需要 3 维输入的 LSTM 层,但您知道您的输入是 2 维的。不要使用 LSTM,而是使用 Dense。
推荐阅读
- python-2.7 - 如何使用 microsoft graph API 下载 oneDrive 中的文件夹?蟒蛇请求?
- opencv - 是否有一个 Halide::BoundaryConditions 来模仿 OpenCV 默认边框类型?
- java - mvel2 在几分钟不活动后需要时间来评估
- c++ - 只改变成员的子类是有效的做法吗?
- python - Python多处理队列异步工作者
- c# - 反序列化 Json .NET CORE 5 时出错 - JsonException:检测到不支持的可能对象循环
- python - 如何解决 PyTorch 内存分配不足的错误?
- kernel - 在 buildroot 中使用两个不同的工具链编译内核和用户空间
- swift - 函数重载歧义使用 - Xcode 12.5
- objective-c - 选择器如何发送参数