python - 如何在 CNN 评估期间存储错误的预测
问题描述
在评估期间,我想存储被错误预测为进行更多处理的唯一 ID。这是一个多类预测问题
这是评估期间的代码:
outputs = model(imgs)
loss = criterion(outputs, targets) # Prediction error
val_loss += loss.item()
predicted = torch.argmax(outputs, dim=1)
t_predicted +=predicted.cpu().tolist()
total += targets.size(0)
good_answers = (predicted == targets)
correct += good_answers.sum().item()
知道这ids
是图像 ID 的列表,当我尝试获取错误的 ID 时:
wrong_ids += ids[~(good_answers.to('cpu'))]
我收到此错误:
add(): argument 'other' (position 1) must be Tensor
解决方案
这个问题包含一个tensorflow
标签,所以我正在准备一个答案。完成我的写作后,我发现这个标签被删除了。但是,我相信我的回答可以深入了解他们是否使用tf
或pytorch
.
数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# train set / data
x_train = x_train.astype('float32') / 255
# validation set / data
x_test = x_test.astype('float32') / 255
# train set / target
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
# validation set / target
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
火车
import tensorflow as tf
# declare input shape
input = tf.keras.Input(shape=(32,32,3))
# Block 1
x = tf.keras.layers.Conv2D(32, 3, strides=2, activation="relu")(input)
x = tf.keras.layers.MaxPooling2D(3)(x)
# Now that we apply global max pooling.
gap = tf.keras.layers.GlobalMaxPooling2D()(x)
# Finally, we add a classification layer.
output = tf.keras.layers.Dense(10, activation='softmax')(gap)
# bind all
func_model = tf.keras.Model(input, output)
print('\nFunctional API')
func_model.compile(
metrics=['accuracy'],
loss = 'categorical_crossentropy',
optimizer = tf.keras.optimizers.Adam()
)
func_model.fit(x_train, y_train)
错误预测
# Predict the values from the validation dataset
y_pred = func_model.predict(x_test)
# Convert predictions classes to one hot vectors
y_pred_classes = np.argmax(y_pred, axis = 1)
y_test = np.argmax(y_test, axis=1)
# Errors are difference between predicted labels and true labels
errors = (y_pred_classes - y_test != 0)
y_pred_classes_errors = y_pred_classes[errors]
y_pred_errors = y_pred[errors]
y_true_errors = y_test[errors]
x_test_errors = x_test[errors]
# Probabilities of the wrong predicted numbers
y_pred_errors_prob = np.max(y_pred_errors, axis = 1)
# Predicted probabilities of the true values in the error set
true_prob_errors = np.diagonal(np.take(y_pred_errors, y_true_errors, axis=1))
# Difference between the probability of the predicted label and the true label
delta_pred_true_errors = y_pred_errors_prob - true_prob_errors
# Sorted list of the delta prob errors
sorted_dela_errors = np.argsort(delta_pred_true_errors)
# Top 6 errors
most_important_errors = sorted_dela_errors[-6:]
展示
import matplotlib.pyplot as plt
def display_errors(errors_index,img_errors,pred_errors, obs_errors):
""" This function shows 6 images with their predicted and real labels"""
n = 0
nrows = 2
ncols = 3
fig, ax = plt.subplots(nrows,ncols,sharex=True,sharey=True)
for row in range(nrows):
for col in range(ncols):
error = errors_index[n]
ax[row,col].imshow((img_errors[error]).reshape((32,32,3)))
ax[row,col].set_title("Predicted label :{}\nTrue label :{}".format(pred_errors[error],obs_errors[error]))
n += 1
# Show the top 6 errors
display_errors(most_important_errors, x_test_errors, y_pred_classes_errors, y_true_errors)
推荐阅读
- string - 为什么使用 -n 测试非空字符串总是正确的?
- java - SimpleDateFormat 对象更改时区而不添加/减去正确的小时数
- reactjs - React Jest:如何模拟返回数据的异步 API 调用?
- python - ImportError:无法从“树”导入名称“_assert_shallow_structure”
- postgresql - 如何为 postgres 安装 plpython?
- javascript - 循环遍历数组并显示数据
- azure - 如何使用azure-ad-b2c自定义登录页面的html和css
- rest - 使用rest api将任务分配给salesforce中的多个潜在客户
- unity3d - 如何动态地在 3d 模型上应用纹理?
- debian - Vagrant 用户失去了 sudo 权限并且服务停止工作