python - 空维度形状“ValueError:检查目标时出错:预期dense_2有2个维度,但得到了形状()的数组”
问题描述
我在创建我的 Keras CNN+LSTM 模型时遇到问题:
ValueError:检查目标时出错:预期dense_2有2维,但得到了形状为()的数组
我已经删除了一些层来测试这个问题。但什么都没有改变。
背景:我正在尝试分析 4D 数据:3D 图像 + 1D 时间序列。我的工作方式是一次添加一张图像,由我的 CNN+LSTM 模型进行分析。我设法使尺寸正确并流经模型。但是后来我遇到了上面提到的错误。
# define CNN model
model = Sequential()
#Layer 1
model.add(TimeDistributed(Conv3D(32, kernel_size=(5, 5, 5), strides=(1, 1, 1),
activation='relu'),
input_shape=input_shape))
model.add(TimeDistributed(MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2))))
#Layer 2
model.add(TimeDistributed(Conv3D(64, (5, 5, 5), activation='relu')))
model.add(TimeDistributed(MaxPooling3D(pool_size=(2, 2, 2))))
#Layer 3
model.add(TimeDistributed(Conv3D(128, (5, 5, 5), activation='relu')))
model.add(TimeDistributed(MaxPooling3D(pool_size=(2, 2, 2))))
#Flatten
model.add(TimeDistributed(Flatten()))
# LSTM
model.add(LSTM(512, return_sequences=True))
model.add(LSTM(512))
model.add(Dense(1201))
#Dense
model.add(Dense(num_classes, activation='sigmoid'))
#Compile
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
model.summary()
train_gen = generate_images(train_dataset)
model.fit_generator(train_gen, samples_per_epoch=5, nb_epoch=10)
###################################################################
#Data Generator
def generate_images(dataframe):
while True:
sub_dataframe = dataframe.sample(n=1)
batch_input = []
# batch_output = []
# for index, row in sub_dataframe.iterrows(): # iterate through each row
input_path = os.path.join(base_directory, sub_dataframe['Image'].values[0])
img = get_fmri_sequence(input_path)
img = np.expand_dims(img, axis=-1)
batch_input.append(img)
batch_input = np.array(batch_input)
yield (batch_input, sub_dataframe["DX"].values[0])
###################################################################
以下是模型摘要:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
time_distributed_1 (TimeDist (None, None, 43, 54, 45, 4032
_________________________________________________________________
time_distributed_2 (TimeDist (None, None, 21, 27, 22, 0
_________________________________________________________________
time_distributed_3 (TimeDist (None, None, 17, 23, 18, 256064
_________________________________________________________________
time_distributed_4 (TimeDist (None, None, 8, 11, 9, 64 0
_________________________________________________________________
time_distributed_5 (TimeDist (None, None, 4, 7, 5, 128 1024128
_________________________________________________________________
time_distributed_6 (TimeDist (None, None, 2, 3, 2, 128 0
_________________________________________________________________
time_distributed_7 (TimeDist (None, None, 1536) 0
_________________________________________________________________
lstm_1 (LSTM) (None, None, 512) 4196352
_________________________________________________________________
lstm_2 (LSTM) (None, 512) 2099200
_________________________________________________________________
dense_1 (Dense) (None, 1201) 616113
_________________________________________________________________
dense_2 (Dense) (None, 2) 2404
=================================================================
Total params: 8,198,293
Trainable params: 8,198,293
Non-trainable params: 0
_________________________________________________________________
解决方案
推荐阅读
- jenkins - 将时间戳添加到从节点项目生成的战争文件到 Jenkins 管道中的工件
- wordpress - 木材:使用 ACF 的精选帖子除外是原始帖子
- linux - 从 apt-get install 安装时 parrot 4.7 上的 wireguard-dkms 错误
- javascript - 为什么我按下按钮时总是收到此错误?
- c++11 - 如何使用带有指向非 const 对象的 const 共享指针的重置?
- spring-integration - “文件夹处于只读模式”异常
- java - Spring boot Hibernate 一对多关系
- wpf - 如何将 WPF 框架扩展到其父宽度?
- multithreading - 阻止除一个之外的所有线程的最佳方法是什么?
- javascript - Material UI Theme - 交换原色和副色