python - ValueError:检查目标时出错:预期activation_5 的形状为(1,),但得到的数组的形状为(100,)
问题描述
我目前正在尝试使用顺序视频数据训练 LSTM 网络。不过,现在不断发生的问题是数据本身的输出形状错误。从视频中,我生成了 100 个时间步长(帧)的 686 个样本片段。然后使用另一个 CNN,我创建了每个图像的形状 2048 的嵌入。换句话说,在我的例子中,X_train 的形状是 (686,100,2048),而 Y_train 的形状是 (686,100)。现在,当我通过我的网络传递我的数据集时,我得到了这个形状错误。
我的模型:
from keras.layers import Activation, Input, Dense, Lambda, LSTM, Flatten
from keras.models import Model
def model_builder(input_shape):
base_input = Input(shape = input_shape)
x = LSTM(units=50, name='LSTM1', return_sequences=True)(base_input)
x = Flatten()(x)
x = Dense(units = 3)(x)
x = Activation('softmax')(x)
classification_model = Model(base_input, x,name='classifier')
classification_model.compile(loss='sparse_categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
return classification_model
我像这样运行它:
batch_size = 64
epochs = 20
timesteps = 100
embedding_shape=2048
classification_model = model_builder((timesteps,embedding_shape))
try:
Y_train=Y_train.reshape((686,timesteps))
X_train = np.reshape(X_train,(686, timesteps,embedding_shape))
outcome = classification_model.fit(x=X_train, y=Y_train, batch_size=batch_size, epochs=epochs, verbose=1, callbacks=None, validation_split=(6200/68600), validation_data=None, shuffle=False, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None, validation_freq=1)
except KeyboardInterrupt:
pass
关于我可能做错了什么的任何想法?
解决方案
如果没有深入检查您的代码,如果您有错误的损失函数,99% 的错误都会发生。
请使用 loss='categorical_crossentropy' 而不是 loss='sparse_categorical_crossentropy' 修改您的模型编译:
classification_model.compile(loss='categorical_crossentropy',........)
不同之处在于您的目标的编码。如果您的目标是一次性编码的,请使用 categorical_crossentropy。
推荐阅读
- c++ - 为什么 cin>>(string) 在 cin>>(int) 失败后停止?
- javascript - 动态决定使用哪种 Fetch 响应方法
- macos - Visual Studio Mac 调试器——“找不到 FileStream.cs”
- xamarin.forms - Mono.Linker.MarkException:错误处理方法:
- ios - 相同的警报显示多个“取消”按钮
- java - Android Listview - 在线程上加载每一行
- python - ADAL Python 刷新 PowerBI 数据集
- svn - SVN 客户端错误“[...] 处的服务器不支持 HTTP/DAV 协议”
- python - Python3 print (something, flush=True) 仅在 localhost 上工作,在外部缓冲
- android-studio - 在 ABI 的 NDK 工具链文件夹中找不到工具链,前缀为:mipsel-linux-android