首页 > 解决方案 > np.where IndexError 异常

问题描述

我有一个非常简单的代码如下:

import numpy as np
num_classes = 12
im_pred = np.random.randint(0, num_classes, (224, 244))
img = np.zeros((224, 224, 3))
print(im_pred.shape)
#(224, 244)
print(img.shape)
#(224, 224, 3)
for i in range(num_classes):
    img[np.where(im_pred==i), :] = [225, 0, 0]

Traceback(最近一次调用最后一次):
文件“”,第 2 行,在 <module>
IndexError:索引 227 超出轴 0 的范围,大小为 224

x, y = np.where(im_pred==i)
print(np.max(x), np.max(y))
#223 243

为什么我得到一个IndexError?至于我的理解np.where,返回的索引值应该小于224.

让我知道。我开始怀疑numpy安装是否有问题。

谢谢。

标签: pythonnumpy

解决方案


No Numpy is not buggy. Look at how you defined im_pred for a second, you are drawing a random integer between 0 and 11 for an array which has size 224 by 244. So the reason it is throwing an error is because the dimension of size 244 is too large for your variable img which is only 224 by 224 by 3. I think you may have meant for both to have the same 1rst and second dimensions, something like

img = np.zeros((224,244,3)) 

推荐阅读