首页 > 解决方案 > Sample_weights Keras 模型 - IndexError:数组索引过多

问题描述

我有一个相当不平衡的数据集,我想在其中对一些数据进行不同的权衡,以便使用 Keras 实现我的神经网络。
我发现我可以使用 sample_weights 。

我的代码如下所示:

sample_weight = np.ones(shape=(len(y_train),))
sample_weight[y_train < 1] = 2.0


history = model.fit(x_train, y_train, batch_size=32, sample_weight=sample_weight, epochs=100, validation_data=(x_val, y_val))

但它给了我第 2 行的以下错误: IndexError: too many indices for array

如果我打印我的 y_train,它看起来像这样:

           Ertrag
41799      0.979252
48595      1.000000
50681      1.000000
51678      1.000000
4896       1.000000

是因为y_train中的索引列吗?

提前致谢!

标签: pythonpandas

解决方案


该错误可能是由 y_train 和 sample_weight 的维度差异引起的。以下是故障排除的想法:

  1. 打印并检查 y_train ( len(y_train)) 的长度,看看返回的形状是否符合您的预期
  2. len(sample_weight)打印并检查 sample_weight ( )的长度
  3. (1)和(2)的输出应该相同sample_weight[y_train < 1] = 2.0才能工作

推荐阅读