deep-learning - 我在计算模型的准确率、召回率、精度和 f1 分数时遇到问题
问题描述
我的混淆矩阵工作正常,只是在生成分数时遇到了一些麻烦。一点帮助会大有帮助。我目前收到错误消息。“张量对象不可调用”。
def get_confused(model_ft):
nb_classes = 120
from sklearn.metrics import precision_recall_fscore_support as score
confusion_matrix = torch.zeros(nb_classes, nb_classes)
with torch.no_grad():
for i, (inputs, classes) in enumerate(dataloaders['val']):
inputs = inputs.to(device)
classes = classes.to(device)
outputs = model_ft(inputs)
_, preds = torch.max(outputs, 1)
for t, p in zip(classes.view(-1), preds.view(-1)):
confusion_matrix[t.long(), p.long()] += 1
cm = confusion_matrix(classes, preds)
recall = np.diag(cm) / np.sum(cm, axis = 1)
precision = np.diag(cm) / np.sum(cm, axis = 0)
print(confusion_matrix)
print(confusion_matrix.diag()/confusion_matrix.sum(1))
解决方案
问题在于这条线。
cm = confusion_matrix(classes, preds)
confusion_matrix
是张量,你不能像函数一样调用它。因此Tensor is not callable
。我也不确定你为什么需要这条线。相反,我认为您会想编写cm= confusion_matrix.cpu().data.numpy()
它以使其成为我认为的 numpy 数组。从您的代码来看,似乎cm
是np.array
.
推荐阅读
- web-services - SOAP 和 HTTP 响应代码
- java - Spring:无法找到 XML 模式命名空间的 Spring NamespaceHandler [http://www.springframework.org/schema/util]
- python - 更改 QPushbutton 或 QToolbutton 的图标
- mysql - mysql - 由于高索引基数导致查询缓慢
- laravel - 通过数据库记录进行 Laravel 验证
- android - 如何检查 PeriodicWorkRequest 是否已入队?
- c# - TFS 构建服务器 2015 - 转换 web.config 不正确
- html - 淡出 HTML 元素/CSS 遮罩
- string - Julia:用分隔符连接字符串(相当于 R 的粘贴)
- sql - ssis visual studio 2010和excel时间戳作为日期