python - 在张量流中,如何在 y 上使用 tf.where() 从 x 中选择 x 和 y 具有不同形状的行?
问题描述
我有一个大小为 x = 32x512 的矩阵和一个大小为 y = 32x1 的标签向量。标签在 0-3 范围内。我需要从 x 中选择相应 y 中标签为 1 的行。
我尝试使用以下命令: temp_maps = tf.where( tf.equal(y,1) , x , tf.math.scalar_mul(0,x) ) 但这给了我这个错误:
ValueError:尺寸必须相等,但对于 op:SelectV2 输入形状为 32 和 512:[32]、[32,512]、[32,512]
我想要的是 x 中标签为 1 的行。我正在使用 tf.math.scalar_mul(0,x) 因为在条件为 false 的情况下必须选择某些东西,所以我选择了一个零张量。
解决方案
创建虚拟矩阵:
import tensorflow as tf
tf.enable_eager_execution()
import numpy as np
B = tf.convert_to_tensor(np.random.randint(0, 3, 32).reshape((32, 1)))
A = tf.convert_to_tensor(np.arange(32*512).reshape((32, 512)))
用于tf.equal
获取所有 2-标签的布尔张量:
eq = tf.equal(B, 2)
In [16]: print(eq)
tf.Tensor(
[[ True]
[False]
[False]
[False]
[False]
[False]
[False]
[False]
[False]
[ True]
[ True]
[False]
[ True]
[False]
[False]
[False]
[False]
[False]
[ True]
[False]
[False]
[False]
[False]
[False]
[ True]
[False]
[ True]
[False]
[False]
[ True]
[False]
[False]], shape=(32, 1), dtype=bool)
现在您可以使用tf.where
来获取位置索引:
In [19]: tf.where(eq)
Out[19]:
<tf.Tensor: id=45, shape=(8, 2), dtype=int64, numpy=
array([[ 0, 0],
[ 9, 0],
[10, 0],
[12, 0],
[18, 0],
[24, 0],
[26, 0],
[29, 0]])>
如果您想获得 A 的一部分,可以使用tf.gather
:
In [30]: tf.gather(A, tf.where(tf.equal(B, 2))[:, 0])
Out[30]:
<tf.Tensor: id=105, shape=(8, 512), dtype=int64, numpy=
array([[ 0, 1, 2, ..., 509, 510, 511],
[ 4608, 4609, 4610, ..., 5117, 5118, 5119],
[ 5120, 5121, 5122, ..., 5629, 5630, 5631],
...,
[12288, 12289, 12290, ..., 12797, 12798, 12799],
[13312, 13313, 13314, ..., 13821, 13822, 13823],
[14848, 14849, 14850, ..., 15357, 15358, 15359]])>
推荐阅读
- javascript - 如何在vuejs的内部循环中从点击事件传递数据
- nginx - 使用 Nginx 的 Canary 部署 - 流量未路由到 Canary 服务器
- ruby-on-rails - 如何选择尚未保存在 Rails 中的关联记录?
- sql - 在子过程中创建临时表
- python - 为什么小于运算符不能在这两个字符串之间工作?
- ruby - 如果用户输入无效输入,程序不应崩溃并再次提示用户 - 重构 cli 项目
- javascript - 以嵌套形式提取反应挂钩时挂钩调用无效
- javascript - 无法读取未定义的属性“fetchBans”
- javascript - 为什么当我删除回车时我的 javascript 不起作用?
- php - 自定义/扩展 laravel 分页方法