首页 > 解决方案 > 在张量流中,如何在 y 上使用 tf.where() 从 x 中选择 x 和 y 具有不同形状的行?

问题描述

我有一个大小为 x = 32x512 的矩阵和一个大小为 y = 32x1 的标签向量。标签在 0-3 范围内。我需要从 x 中选择相应 y 中标签为 1 的行。

我尝试使用以下命令: temp_maps = tf.where( tf.equal(y,1) , x , tf.math.scalar_mul(0,x) ) 但这给了我这个错误:

ValueError:尺寸必须相等,但对于 op:SelectV2 输入形状为 32 和 512:[32]、[32,512]、[32,512]

我想要的是 x 中标签为 1 的行。我正在使用 tf.math.scalar_mul(0,x) 因为在条件为 false 的情况下必须选择某些东西,所以我选择了一个零张量。

标签: pythontensorflow

解决方案


创建虚拟矩阵:

        import tensorflow as tf
        tf.enable_eager_execution()
        import numpy as np
        B = tf.convert_to_tensor(np.random.randint(0, 3, 32).reshape((32, 1)))
        A = tf.convert_to_tensor(np.arange(32*512).reshape((32, 512)))

用于tf.equal获取所有 2-标签的布尔张量:

    eq = tf.equal(B, 2)
    In [16]: print(eq)
    tf.Tensor(
    [[ True]
     [False]
     [False]
     [False]
     [False]
     [False]
     [False]
     [False]
     [False]
     [ True]
     [ True]
     [False]
     [ True]
     [False]
     [False]
     [False]
     [False]
     [False]
     [ True]
     [False]
     [False]
     [False]
     [False]
     [False]
     [ True]
     [False]
     [ True]
     [False]
     [False]
     [ True]
     [False]
     [False]], shape=(32, 1), dtype=bool)

现在您可以使用tf.where来获取位置索引:

In [19]: tf.where(eq)
Out[19]: 
<tf.Tensor: id=45, shape=(8, 2), dtype=int64, numpy=
array([[ 0,  0],
       [ 9,  0],
       [10,  0],
       [12,  0],
       [18,  0],
       [24,  0],
       [26,  0],
       [29,  0]])>

如果您想获得 A 的一部分,可以使用tf.gather

In [30]: tf.gather(A, tf.where(tf.equal(B, 2))[:, 0])
Out[30]: 
<tf.Tensor: id=105, shape=(8, 512), dtype=int64, numpy=
array([[    0,     1,     2, ...,   509,   510,   511],
       [ 4608,  4609,  4610, ...,  5117,  5118,  5119],
       [ 5120,  5121,  5122, ...,  5629,  5630,  5631],
       ...,
       [12288, 12289, 12290, ..., 12797, 12798, 12799],
       [13312, 13313, 13314, ..., 13821, 13822, 13823],
       [14848, 14849, 14850, ..., 15357, 15358, 15359]])>

推荐阅读