python - tensorflow boolean_mask: how is the mask applied between two tensors?
Problem description
I have the following code:
# Imports are not shown in the original post; these are the assumed ones.
import numpy as np
import tensorflow as tf
from keras import backend as K

def yolo_filter_boxes(box_confidence, boxes, box_class_probs, threshold = .6):
    """Filters YOLO boxes by thresholding on object and class confidence.
    Arguments:
    box_confidence -- tensor of shape (3, 3, 5, 1)
    boxes -- tensor of shape (3, 3, 5, 4)
    box_class_probs -- tensor of shape (3, 3, 5, 80)
    threshold -- real value, if [ highest class probability score < threshold], then get rid of the corresponding box
    Returns:
    scores -- tensor of shape (None,), containing the class probability score for selected boxes
    boxes -- tensor of shape (None, 4), containing (b_x, b_y, b_h, b_w) coordinates of selected boxes
    classes -- tensor of shape (None,), containing the index of the class detected by the selected boxes
    Note: "None" is here because you don't know the exact number of selected boxes, as it depends on the threshold.
    For example, the actual output size of scores would be (10,) if there are 10 boxes.
    """
    # Step 1: Compute box scores
    box_scores = np.multiply(box_confidence, box_class_probs)
    # Step 2: Find the box_classes thanks to the max box_scores, keep track of the corresponding score
    box_classes = K.argmax(box_scores, -1)
    box_class_scores = K.max(box_scores, -1)
    # Step 3: Create a filtering mask based on "box_class_scores" by using "threshold". The mask should have the
    # same dimension as box_class_scores, and be True for the boxes you want to keep (with probability >= threshold)
    filtering_mask = K.greater_equal(box_class_scores, threshold)
    # Step 4: Apply the mask to scores, boxes and classes
    print(filtering_mask.shape)
    print(filtering_mask.eval())
    print(box_class_scores.shape)
    print(box_class_scores.eval())
    scores = tf.boolean_mask(box_class_scores, filtering_mask)
    print(scores.eval())
    boxes = tf.boolean_mask(boxes, filtering_mask)
    classes = tf.boolean_mask(box_classes, filtering_mask)
    return scores, boxes, classes

with tf.Session() as test_a:
    box_confidence = tf.random_normal([3, 3, 5, 1], mean=1, stddev=4, seed = 1)
    boxes = tf.random_normal([3, 3, 5, 4], mean=1, stddev=4, seed = 1)
    box_class_probs = tf.random_normal([3, 3, 5, 80], mean=1, stddev=4, seed = 1)
    scores, boxes, classes = yolo_filter_boxes(box_confidence, boxes, box_class_probs, threshold = 0.5)
    print("scores[2] = " + str(scores[2].eval()))
    print("boxes[2] = " + str(boxes[2].eval()))
    print("classes[2] = " + str(classes[2].eval()))
    print("scores.shape = " + str(scores.shape))
    print("boxes.shape = " + str(boxes.shape))
    print("classes.shape = " + str(classes.shape))
This is the output:
(3, 3, 5)
[[[ True True True True True]
[ True True True True True]
[ True False True True True]]
[[ True True True True True]
[ True True True True True]
[ True True True True True]]
[[ True True True True False]
[ True True True True True]
[ True True True True True]]]
(3, 3, 5)
[[[ 45.00004959 21.20238304 17.39275742 26.73288918 49.47431946]
[ 22.16205978 27.96604347 12.38916492 33.66600418 62.04590225]
[ 113.03194427 2.68868852 6.33391762 45.17211914 10.5103178 ]]
[[ 8.22186852 35.88579941 48.54780579 12.48789883 32.40937042]
[ 75.73269653 17.52830696 62.99983597 29.0468502 42.82471848]
[ 72.42234039 108.19727325 36.93912888 40.9789238 36.91137314]]
[[ 1.57321405 3.35663748 16.33576775 5.16499805 19.43038177]
[ 48.13769913 68.20082092 47.06818008 1.82166731 67.30760956]
[ 33.01203537 63.93298721 9.71860027 49.06838989 60.74739456]]]
[ 22.63684464 10.29589462 58.76845551 74.67560577 20.25722504
47.24279022 6.96320772 22.59087944 86.61974335 1.05248117
57.47060394 92.50878143 16.8335762 23.29385757 78.58971405
6.95861435 65.61254883 45.47106171 43.53435135 10.0660677
60.34520721 28.5535984 15.9668026 45.14865494 5.49425364
2.35473752 29.40540886 2.5579865 46.96302032 9.39739799
45.78501892 49.42660904 34.68322754 40.72031784 58.91592407
35.39850616 56.24537277 6.80519342 9.52552414 138.54457092
14.07888412 56.37608719 69.59171295 25.83714676]
scores[2] = 62.0051
boxes[2] = [-1.89158893 0.7749185 3.57417917 -0.05729628]
classes[2] = 36
scores.shape = (?,)
boxes.shape = (?, 4)
classes.shape = (?,)
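I have a simple question: how is the resulting scores computed? It has 44 elements, while filtering_mask and box_class_scores each have 45 elements (3 * 3 * 5). filtering_mask contains 2 False values, so scores should have 43 elements. And even if filtering_mask had only 1 False value, none of the numbers in scores match those in box_class_scores. Can anyone explain to me how scores is computed?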
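Solution
The problem is not the masking, which works the way you expect. The problem is that you are using random values in your graph, and their behaviour can be somewhat surprising. Every call to eval() is actually a call to run on the default session, and every time run is called on a session, the random ops produce new values. That means each call to eval computes its result from different values of box_confidence, boxes and box_class_probs, so the printed mask, the printed box_class_scores and the printed scores all come from different random draws.
There are a few ways to fix this: either do not use random value generators as inputs at all, or evaluate all of your outputs in a single call to run (instead of separate calls to eval). Since you seem to be writing test code, one simple fix is to replace the inputs with constants built from NumPy random values.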
import tensorflow as tf
import numpy as np

def yolo_filter_boxes(box_confidence, boxes, box_class_probs, threshold = .6):
    # ... (function body unchanged from the question)

with tf.Session() as test_a:
    # NumPy draws the values once; the graph then only sees fixed constants.
    np.random.seed(1)
    box_confidence = tf.constant(np.random.normal(loc=1, scale=4, size=[3, 3, 5, 1]), dtype=tf.float32)
    boxes = tf.constant(np.random.normal(loc=1, scale=4, size=[3, 3, 5, 4]), dtype=tf.float32)
    box_class_probs = tf.constant(np.random.normal(loc=1, scale=4, size=[3, 3, 5, 80]), dtype=tf.float32)
    scores, boxes, classes = yolo_filter_boxes(box_confidence, boxes, box_class_probs, threshold = 0.5)
    print("scores[2] = " + str(scores[2].eval()))
    print("boxes[2] = " + str(boxes[2].eval()))
    print("classes[2] = " + str(classes[2].eval()))
    print("scores.shape = " + str(scores.shape))
    print("boxes.shape = " + str(boxes.shape))
    print("classes.shape = " + str(classes.shape))
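The other fix mentioned above, evaluating every output in a single run call, might look roughly like the sketch below. It is not part of the original answer. It assumes the v1 compatibility API (tensorflow.compat.v1) is available, replaces the Keras backend calls with plain TensorFlow ops, and leaves out boxes and classes for brevity. The names mask_val, class_scores_val and scores_val are made up for this illustration.

```python
# Assumption: TensorFlow 1.14+ or 2.x, which ship the v1 compatibility API.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode, as in the tf.Session code above

# Random inputs, exactly the kind that changes on every run call.
box_confidence = tf.random_normal([3, 3, 5, 1], mean=1, stddev=4, seed=1)
box_class_probs = tf.random_normal([3, 3, 5, 80], mean=1, stddev=4, seed=1)

box_scores = box_confidence * box_class_probs
box_class_scores = tf.reduce_max(box_scores, axis=-1)
filtering_mask = tf.greater_equal(box_class_scores, 0.5)
scores = tf.boolean_mask(box_class_scores, filtering_mask)

with tf.Session() as sess:
    # A single run call: the random ops execute once, so all three
    # fetched values are computed from the same random draw.
    mask_val, class_scores_val, scores_val = sess.run(
        [filtering_mask, box_class_scores, scores])

# The number of kept scores now equals the number of True entries in the mask.
print(int(mask_val.sum()), scores_val.shape)
```

Because everything is fetched together, the printed mask, scores and shapes are consistent with each other, although the values still change from one run call to the next.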
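Alternatively, you can keep using TensorFlow random numbers but feed them in through variables. Variables differ in that their initial value is evaluated only once, when they are initialized. After that they keep their value across run calls in the session (until it is changed again), so you do not generate new random values every time.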
import tensorflow as tf

def yolo_filter_boxes(box_confidence, boxes, box_class_probs, threshold = .6):
    # ... (function body unchanged from the question)

with tf.Session() as test_a:
    box_confidence = tf.Variable(tf.random_normal([3, 3, 5, 1], mean=1, stddev=4, seed = 1))
    boxes = tf.Variable(tf.random_normal([3, 3, 5, 4], mean=1, stddev=4, seed = 1))
    box_class_probs = tf.Variable(tf.random_normal([3, 3, 5, 80], mean=1, stddev=4, seed = 1))
    # You must initialize the variables
    test_a.run(tf.global_variables_initializer())
    scores, boxes, classes = yolo_filter_boxes(box_confidence, boxes, box_class_probs, threshold = 0.5)
    print("scores[2] = " + str(scores[2].eval()))
    print("boxes[2] = " + str(boxes[2].eval()))
    print("classes[2] = " + str(classes[2].eval()))
    print("scores.shape = " + str(scores.shape))
    print("boxes.shape = " + str(boxes.shape))
    print("classes.shape = " + str(classes.shape))
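As a cross-check on the masking itself, the same bookkeeping can be reproduced in plain NumPy, where boolean indexing behaves like tf.boolean_mask. The sketch below is not from the original answer, and the array contents are arbitrary random numbers used only for illustration. With fixed inputs, the number of selected scores always equals the number of True entries in the mask, which is the 45 - 2 = 43 reasoning from the question.

```python
import numpy as np

# Fixed inputs: drawn once, then reused for every computation below.
rng = np.random.RandomState(1)
box_class_scores = rng.normal(loc=1, scale=4, size=(3, 3, 5))
boxes = rng.normal(loc=1, scale=4, size=(3, 3, 5, 4))

# True where the box should be kept.
filtering_mask = box_class_scores >= 0.5

# Boolean indexing flattens the masked (3, 3, 5) dimensions into one axis.
scores = box_class_scores[filtering_mask]   # shape (num_kept,)
kept_boxes = boxes[filtering_mask]          # shape (num_kept, 4)

print(int(filtering_mask.sum()), scores.shape, kept_boxes.shape)
```

Every value in scores is taken directly from box_class_scores, in row-major order, so a mismatch like the one in the question can only come from the inputs changing between evaluations.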