首页 > 解决方案 > 我不明白使用 mxnet 的“形状不一致”错误

问题描述

来自 Keras,我尝试使用 MXNet 重现我的简单模型,以使用 Module 进行预测。

我正在使用那个简单的数据集:https ://archive.ics.uci.edu/ml/machine-learning-databases/wine/wine.data

我有 13 个输入(从酒精到脯氨酸)要发送到模型,我需要对第一列“葡萄酒类型”进行分类,因此我创建了一个包含 3 个条目的 nd.array。


x = data.values[: , 1:14]
y = data.values[:, 0]

X = mx.nd.array(x)
Y = []
for i, v in enumerate(y):
    d = [0,0,0]
    d[int(v)-1] = 1
    Y.append(d)
Y = mx.nd.array(Y)
Y.shape, X.shape
# ((178, 3), (178, 13))

然后我创建模型和一个 NDIterator:


net = mx.symbol.Variable('winechemical')
net = mx.symbol.FullyConnected(net, num_hidden=64)
net = mx.symbol.Activation(net, act_type='relu')
net = mx.symbol.FullyConnected(net, num_hidden=32)
net = mx.symbol.Activation(net, act_type='relu')
net = mx.symbol.FullyConnected(net, num_hidden=16)
net = mx.symbol.SoftmaxOutput(net, name='wineclass')

model = Module(symbol=net, context=mx.cpu(),
                  data_names=['winechemical'],
                  label_names=['wineclass_label'])

gen = mx.io.NDArrayIter(X, label=Y, 
                        batch_size=10, 
                        shuffle=True, data_name='winechemical', 
                        label_name='wineclass_label')

但是当我尝试使用“fit”方法“训练”模型时,我得到了这个错误:

model.fit(gen, num_epoch=5)

[...]
运算符 wineclass 中的错误:形状不一致,提供 = [10,3],推断形状 = [10]

我很确定我不了解要使用的形状,因为我来自使用不同形状的 Keras……但我错在哪里?

谢谢你的帮助。

标签: pythonmxnet

解决方案


您已经自己找到了解决方案。但是如果你再次遇到类似的问题,你可以使用mx.visualization.print_summary()这个函数对于检查模型中不同层的形状非常有用。


推荐阅读