python - 如何在tensorflow的vectorized_map中使用TopK算子
问题描述
我试图在形状[batchsize, listsize]的张量上执行tf.vectorized_map并将tf.math.top_k运算符应用于批处理中的每一行,但没有成功。
例如,数据可能是:
[ [1,2,4,5,6], [9,5,4,2,1] ]
我想申请 topk on[1,2,4,5,6]
和 on [9,5,4,2,1]
。
然而,我成功地做同样的事情,tf.map_fn
但vectorized_map
应该跑得更快。我使用tensorflow 1.15。
import tensorflow as tf
import numpy as np
# create fake data
x = tf.convert_to_tensor([
[1,2,4,5,6],
[9,5,4,2,1],
], dtype=tf.float32)
x = tf.reshape(x, (2, -1))
B = x.shape[0] # batchsize
L = x.shape[1] # list size
print(f"B {B}, L {L}")
sess = tf.Session()
print(f"x tensor: {sess.run(x)}\n")
def fv(_x):
#_tensor = tf.reshape(_x, (L,)) # doesnt work (1)
_tensor = tf.reshape(tf.convert_to_tensor([9,5,4,2,1], dtype=tf.float32), (L,)) # work (2)
#_tensor = tf.convert_to_tensor([9,5,4,2,1], dtype=tf.float32) # work (3)
print(f"_tensor: {_tensor}")
values, indices = tf.math.top_k(_tensor, k=3)
# i just need the indices
return indices
indices = tf.vectorized_map(
fv,
x,
)
print("\nindices ")
print(sess.run(indices))
正如我们所看到的 (2) 和 (3) 运行,所以 topk 运算符应该是可用的。此外,即使 (1) 不起作用,我也可以使用 _x 并返回它,例如:
def fv(_x):
return _x * 10
所以 _x 是可用的。
因此,当我使用 (1) 运行代码时,出现错误:
ValueError: No converter defined for TopKV2
name: "loop_body/TopKV2"
op: "TopKV2"
input: "loop_body/Reshape"
input: "loop_body/TopKV2/k"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "sorted"
value {
b: true
}
}
inputs: [WrappedTensor(t=<tf.Tensor 'loop_body/Reshape/pfor/Reshape:0' shape=(2, 5) dtype=float32>, is_stacked=True, is_sparse_stacked=False), WrappedTensor(t=<tf.Tensor 'loop_body/TopKV2/k:0' shape=() dtype=int32>, is_stacked=False, is_sparse_stacked=False)].
Either add a converter or set --op_conversion_fallback_to_while_loop=True, which may run slower
Process finished with exit code 1
在这里,我只是尝试获取索引,之后我需要处理向量以在输出中具有类似[[0,0,1,1,1], [1,1,1,0,0] ]
的输出K=3
(如果值在 topk 中,则为 1,否则为 0)。并且还要给出另一个形状为 [batchsize, 1] 的张量,其中包含每行的 K 参数。(我已经用 map_fn 成功了,所以我认为以后不会有问题)。
也许可以在矢量化地图中实现我自己的 topk 运算符,但我宁愿不这样做。
解决方案
我终于做了类似的事情:这不使用vectorized_map,但这是我想做的。但是,如果有人可以使它与 vectorized_map 一起使用,我会看看解决方案。:)
def topk(x, k):
"""
x : shape [B, L]
k : shape [B, 1]
return : final_mask of shape [B,L] with final_mask[b,i] = 0 if x[b,i] is in
the k[b] biggest values of x[b,:], else final_mask[b,i] = 1
"""
B = x.shape[0] # batchsize
L = x.shape[1] # list size
# the indices sorted in descending order
indices_des = tf.argsort(x, axis=-1, direction='DESCENDING', stable=False, name='sorting_for_topk')
mask = tf.reshape(tf.range(start=0, limit=L, dtype=tf.int32), [1, L])
mask = tf.repeat(mask, [B], axis=0)
mask = mask<k
one_hot = tf.one_hot(indices_des, depth=L) * tf.cast(tf.reshape(mask, [B, L, 1]), tf.float32)
final_mask = tf.reduce_sum(one_hot, axis=1)
return final_mask
推荐阅读
- javascript - Prisma findMany 函数不返回关系数据
- sql - 比较 2 个表并跟踪更改的列和其他设置为 null
- amazon-web-services - Runtime.ImportModuleError:错误:找不到模块“onCreateRadonData”
- django - SSE DJANGO 请求数
- python-3.x - 如何在 Tkinter 文件对话框中按选定的顺序对选定的文件进行排序?
- python - 如何按顺序从不同文件夹中的文件中读取数据?
- react-native - React-Native 中的阴影或高程问题
- html - HTML Flex 项目宽度随着项目数量的增加而变化
- python - Python - 从 python 中的列表中删除逗号和方括号 - 插入排序
- kubernetes - Kubernetes with route fanout - 对服务设置的基本理解