python-3.x - “布尔”对象不可下标 - Keras
问题描述
我目前正在参与一个深度学习项目,我需要使用敏感性和特异性指标进行评估,这些指标不包含在开箱即用的 keras 中。
我已经实现了灵敏度功能如下:
from keras import backend as K
def sensitivity(y, y_pred):
TP = 0
FP = 0
TN = 0
FN = 0
for i in range(5):
true = (y == i)
preds = (y_pred == i)
print(preds)
TP += K.sum(preds[true == 1])
FP += K.sum(true[np.invert(preds) == 1])
TN += K.sum(np.invert(preds)[true == 1])
FN += K.sum(true[preds == 0])
return TP / (TP + FN)
这本身就可以正常工作。但是,当我在编译模型时尝试使用它时,我收到错误“'bool' object is not subscriptable”。
我该如何解决这个问题?
编译代码和完整的错误消息包含在下面。谢谢!编辑:通过下面的回复和一些研究,我能够修复我粘贴在帖子底部的代码。在 Keras 中编译模型时,它由 Theano 或 TensorFlow 评估,因此您不能使用 Numpy 命令来制作自己的指标。
from keras.datasets import mnist
from keras.applications.vgg16 import VGG16
(x_train, y_train), (x_test, y_test) = mnist.load_data()
model = VGG16(
include_top=False,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=10)
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy', sensitivity])
TypeError Traceback (most recent call last)
<ipython-input-20-46c5d210c7ea> in <module>()
9 model.compile(optimizer='rmsprop',
10 loss='categorical_crossentropy',
---> 11 metrics=['accuracy', sensitivity])
/home/USER/Documents/deep_learning/custom_metrics/venv/lib/python3.6/site-packages/keras/engine/training.py in compile(self, optimizer, loss, metrics, loss_weights, sample_weight_mode, weighted_metrics, target_tensors, **kwargs)
449 output_metrics = nested_metrics[i]
450 output_weighted_metrics = nested_weighted_metrics[i]
--> 451 handle_metrics(output_metrics)
452 handle_metrics(output_weighted_metrics, weights=weights)
453
/home/USER/Documents/deep_learning/custom_metrics/venv/lib/python3.6/site-packages/keras/engine/training.py in handle_metrics(metrics, weights)
418 metric_result = weighted_metric_fn(y_true, y_pred,
419 weights=weights,
--> 420 mask=masks[i])
421
422 # Append to self.metrics_names, self.metric_tensors,
/home/USER/Documents/deep_learning/custom_metrics/venv/lib/python3.6/site-packages/keras/engine/training_utils.py in weighted(y_true, y_pred, weights, mask)
402 """
403 # score_array has ndim >= 2
--> 404 score_array = fn(y_true, y_pred)
405 if mask is not None:
406 # Cast the mask to floatX to avoid float64 upcasting in Theano
<ipython-input-18-caa661fb93ac> in sensitivity(y, y_pred)
11 preds = (y_pred == i)
12 print(preds)
---> 13 TP += K.sum(preds[true == 1])
14 FP += K.sum(true[np.invert(preds) == 1])
15 TN += K.sum(np.invert(preds)[true == 1])
TypeError: 'bool' object is not subscriptable
def sensitivity(y, y_pred):
TP = 0
FP = 0
TN = 0
FN = 0
for i in range(5):
true = K.equal(y, i)
preds = K.equal(y_pred, i)
TP += K.sum(K.cast(tf.boolean_mask(preds, tf.math.equal(true, True)), 'int32'))
FP += K.sum(K.cast(tf.boolean_mask(true, tf.math.equal(~preds, True)), 'int32'))
TN += K.sum(K.cast(tf.boolean_mask(~preds, tf.math.equal(true, True)), 'int32'))
FN += K.sum(K.cast(tf.boolean_mask(true, tf.math.equal(preds, False)), 'int32'))
return TP / (TP + FN)
解决方案
推荐阅读
- sql - 我需要将 SQL 中的所有活动分配给一个 ID,但目前每个活动都有三个 ID
- php - 如何制作显示所选 ID 的视图?
- php - 如何从 Laravel 5.8 中的动态按钮传递正确的数据
- html - 通过将鼠标悬停在子元素上来影响父元素
- visual-studio-code - 如何停止VS。破坏我的代码格式的代码?
- git - 如何推送文件并在之后忽略它
- javascript - 我们如何在反应原生的两个选项卡更改上更新单个状态
- angular - 如何加载 config.json 并在另一个模块中使用配置值?
- mockito - 如何模拟自定义 util 类
- php - CakePHP 2.9 连接查询返回错误结果