首页 > 解决方案 > 如何在tensorflow的vectorized_map中使用TopK算子

问题描述

我试图在形状[batchsize, listsize]的张量上执行tf.vectorized_map并将tf.math.top_k运算符应用于批处理中的每一行,但没有成功。

例如,数据可能是:

[ [1,2,4,5,6], [9,5,4,2,1] ]

我想申请 topk on[1,2,4,5,6]和 on [9,5,4,2,1]

然而,我成功地做同样的事情,tf.map_fnvectorized_map应该跑得更快。我使用tensorflow 1.15

import tensorflow as tf
import numpy as np

# create fake data
x = tf.convert_to_tensor([
    [1,2,4,5,6],
    [9,5,4,2,1],
], dtype=tf.float32)
x = tf.reshape(x, (2, -1))

B = x.shape[0] # batchsize
L = x.shape[1] # list size

print(f"B {B}, L {L}")

sess = tf.Session()

print(f"x tensor: {sess.run(x)}\n")

def fv(_x):
    #_tensor = tf.reshape(_x, (L,))  # doesnt work (1)
    _tensor = tf.reshape(tf.convert_to_tensor([9,5,4,2,1], dtype=tf.float32), (L,)) # work (2)
    #_tensor = tf.convert_to_tensor([9,5,4,2,1], dtype=tf.float32) # work (3)

    print(f"_tensor: {_tensor}")
    values, indices = tf.math.top_k(_tensor, k=3)
    # i just need the indices
    return indices

indices = tf.vectorized_map(
        fv,
        x,
)
print("\nindices ")
print(sess.run(indices))

正如我们所看到的 (2) 和 (3) 运行,所以 topk 运算符应该是可用的。此外,即使 (1) 不起作用,我也可以使用 _x 并返回它,例如:

def fv(_x):
    return _x * 10

所以 _x 是可用的。

因此,当我使用 (1) 运行代码时,出现错误:

ValueError: No converter defined for TopKV2
name: "loop_body/TopKV2"
op: "TopKV2"
input: "loop_body/Reshape"
input: "loop_body/TopKV2/k"
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "sorted"
  value {
    b: true
  }
}

inputs: [WrappedTensor(t=<tf.Tensor 'loop_body/Reshape/pfor/Reshape:0' shape=(2, 5) dtype=float32>, is_stacked=True, is_sparse_stacked=False), WrappedTensor(t=<tf.Tensor 'loop_body/TopKV2/k:0' shape=() dtype=int32>, is_stacked=False, is_sparse_stacked=False)]. 
Either add a converter or set --op_conversion_fallback_to_while_loop=True, which may run slower

Process finished with exit code 1

在这里,我只是尝试获取索引,之后我需要处理向量以在输出中具有类似[[0,0,1,1,1], [1,1,1,0,0] ]的输出K=3(如果值在 topk 中,则为 1,否则为 0)。并且还要给出另一个形状为 [batchsize, 1] 的张量,其中包含每行的 K 参数。(我已经用 map_fn 成功了,所以我认为以后不会有问题)。

也许可以在矢量化地图中实现我自己的 topk 运算符,但我宁愿不这样做。

标签: pythonnumpytensorflow

解决方案


我终于做了类似的事情:这不使用vectorized_map,但这是我想做的。但是,如果有人可以使它与 vectorized_map 一起使用,我会看看解决方案。:)

def topk(x, k):
    """
    x : shape [B, L]
    k : shape [B, 1]
    return : final_mask of shape [B,L] with final_mask[b,i] = 0 if x[b,i] is in  
    the k[b] biggest values of x[b,:], else final_mask[b,i] = 1
    """
    B = x.shape[0]  # batchsize
    L = x.shape[1]  # list size

    # the indices sorted in descending order
    indices_des = tf.argsort(x, axis=-1, direction='DESCENDING', stable=False, name='sorting_for_topk')

    
    mask = tf.reshape(tf.range(start=0, limit=L, dtype=tf.int32), [1, L])
    mask = tf.repeat(mask, [B], axis=0)
    mask = mask<k

    one_hot = tf.one_hot(indices_des, depth=L) * tf.cast(tf.reshape(mask, [B, L, 1]), tf.float32)
    final_mask = tf.reduce_sum(one_hot, axis=1)
    
    return final_mask

推荐阅读