python - TensorFlow:获取非零张量中最小元素索引的有效方法?
问题描述
我使用 TensorFlow 1.12。我有一个一维张量tag_mask_sizes
,它主要包含零,但也包含一些正整数。如何有效地获取不为零的最小元素的索引?我尝试了以下方法:
tag_mask_sizes_suppressed = tf.map_fn(lambda x: x if tf.not_equal(x, tf.constant(0, dtype=tf.uint8)) else 9999999, tag_mask_sizes)
smallest_mask_index = tf.argmin(tag_mask_sizes_suppressed)
但是,tf.not_equal()
会产生一个布尔张量,我无法在 lambda 内的 if-else 条件下有效地评估它。还有其他像这样优雅的解决方案吗?
虽然我通常急切地执行,但这个问题发生在我使用的函数中tf.Dataset.map()
,该函数没有急切地执行。
解决方案
事实上,你的代码等价于下面的代码。
tag_mask_sizes_suppressed = tf.where(tf.not_equal(tag_mask_sizes, 0),tag_mask_sizes,tag_mask_sizes+9999999)
smallest_mask_index1 = tf.argmin(tag_mask_sizes_suppressed)
矢量化方法将明显快于tf.map_fn()
. 此外,还有一些矢量化方法可以获取一维张量中不为零的最小元素的索引。一个例子:
import tensorflow as tf
# tf.enable_eager_execution()
tag_mask_sizes = tf.constant([2,0,1,3,1,32,0,0,0], dtype=tf.int32)
# approach 1, the disadvantage is that the maximum must be specified and only the first minimum can be found.
tag_mask_sizes_suppressed = tf.where(tf.not_equal(tag_mask_sizes, 0),tag_mask_sizes,tag_mask_sizes+9999999)
smallest_mask_index1 = tf.argmin(tag_mask_sizes_suppressed)
# approach 2, only the first minimum can be found.
tag_mask_sizes_nozeroidx = tf.where(tf.not_equal(tag_mask_sizes, 0))
tag_mask_sizes_suppressed = tf.gather_nd(tag_mask_sizes,tag_mask_sizes_nozeroidx)
smallest_mask_index2 = tag_mask_sizes_nozeroidx[tf.argmin(tag_mask_sizes_suppressed)]
# approach 3, find all minimum
tag_mask_sizes_suppressed = tf.boolean_mask(tag_mask_sizes,tf.not_equal(tag_mask_sizes, 0))
smallest_mask_index3 = tf.squeeze(tf.where(tf.equal(tag_mask_sizes,tf.reduce_min(tag_mask_sizes_suppressed))))
with tf.Session() as sess:
print(sess.run(smallest_mask_index1))
print(sess.run(smallest_mask_index2))
print(sess.run(smallest_mask_index3))
# print
2
[2]
[2 4]
推荐阅读
- hierarchical-data - 检查复杂层次模型的收敛性 JAGS
- javascript - Firebase - 如何在数组内的时间戳字段中插入当前日期?
- javascript - 如何从一个大数组中获取一些数据
- php - 如何用 php 创建图片幻灯片?
- javascript - polling-xhr.js:229 GET http://localhost:3000/socket.io/?EIO=4&transport=polling&t=NMtL4rR net::ERR_CONNECTION_REFUSED
- javascript - 正则表达式查看字符串是整数还是分数
- javascript - QML中矩形上的两种不同阴影
- javascript - 在数字输入上使用按钮不会更新第一个输入
- ios - 将 ipa 文件上传到 Testflight 时 iOS Fastlane 构建失败
- java - JPA,Hibernate:子类的字段为空,@MappedSuperclass 和 InheritanceType.JOINED 策略