首页 > 解决方案 > 如何在张量流中选择特定元素?

问题描述

我需要从列表中选择点击的项目,我的输入如下:

item_arr=tf.constant([['item001', 'item010', 'item020', 'item030', 'item041', 'item051'], 
                      ['item101', 'item110', 'item120', 'item130', 'item140', 'item151']])
clicked_arr=tf.constant([[1, 0, 0, 0, 1, 1], [1, 0, 0, 0, 0, 1]])

是物品的item_arr批次数据,clicked_arr被点击标志(1是点击,0是没有点击)对应item_arr

我希望得到这样的输出来获取点击的项目:

clicked_item_arr(for example, output shape is [2,4]):
[
['item001', 'item041', 'item051', 'item_placeholder'], 
['item101', 'item151', 'item_placeholder', 'item_placeholder']
]

我还需要访问未点击的项目,但应该使用相同的点击项目解决方案。

我尝试使用gather_ndsparse_to_dense

index_arr=tf.where(tf.equal(clicked_arr, 1))
>>> array([[0, 0],
       [0, 4],
       [0, 5],
       [1, 0],
       [1, 5]])
sparse_item_value= tf.gather_nd(item_arr, index_arr)
>>> array(['item001', 'item041', 'item051', 'item101', 'item151'],
      dtype=object)

但我无法得到我想要的结果,因为我需要 sparse_indices 像:

array([[0, 0],
       [0, 1],
       [0, 2],
       [1, 0],
       [1, 1]])

这样我就可以使用:

tf.sparse_to_dense(
    sparse_indices=sparse_indices,
    sparse_values=sparse_item_value,
    output_shape=[2, 10],
    default_value='item_placeholder'
)

但我不知道如何获得 sparse_indices。希望您能给我一些建议,并提前感谢。

标签: pythontensorflow

解决方案


推荐阅读