python - 用于多标签问题的 keras 模型的 scikit 学习链分类器的拟合方法错误
问题描述
我正在为使用 KerasClassifier 模型的多类问题构建链分类器。我有 17 个标签作为分类目标,X_train 的形状是 (111300,107),y_train 是 (111300,17) 我的代码在这里:
def create_model():
input_size=length_long_sentence
embedding_size=128
lstm_size=64
output_size=len(unique_tag_set)
#----------------------------Model -------------------------------
current_input=Input(shape=(input_size,))
emb_current = Embedding(vocab_size, embedding_size, input_length=input_size)(current_input)
out_current=Bidirectional(LSTM(units=lstm_size))(emb_current )
#out_current = Reshape((1,2*lstm_size))(out_current)
output = Dense(units=len(unique_tag_set), activation='softmax')(out_current)
model = Model(inputs=current_input, outputs=output)
model.compile(optimizer='Adam', loss='categorical_crossentropy', metrics=['accuracy'])
print(model.summary())
return model
model = KerasClassifier(build_fn=create_model, epochs=1,batch_size=256)
print(type(model))
chain=ClassifierChain(model, order='random', random_state=42)
history=chain.fit(X_train, y_train)
模型摘要在这里:
当尝试在 ClassifierChain 上使用 fit 方法时,我收到此错误:
任何人都可以指导我这个错误,什么是(无,2)?
解决方案
来自链分类器的文档:
将二元分类器排列成链的多标签模型。
因此,使用最后一层中的单个节点和损失函数将您的 keras 模型转换为二元分类器作为 binary_crossentropy
推荐阅读
- r - 如何在 R 中循环修改多个单元格而不丢失格式?
- python - Python如果名称主要变量未定义
- javascript - 为fullcalendar js的每个事件插入一个谷歌地图链接
- python - 具有数字线间距分布的 Seaborn 条形图
- javascript - 如何在javascript中更改json文件导入路径
- javascript - 当屏幕宽度小于数量时更改文本不起作用
- python - 如何使用 python-sounddevice 录制音频输出?
- sql-server - SQL查找哪个表填充了diag_code,然后在lookup_table中查找
- javascript - 使用数据表导出信息时出错
- python - 为线性回归计算 R^2:SSreg/SStot 与 1-(SSSres/SStot) 导致不同的结果