Return the top_k items per row of a sparse tensor

Problem description

For dense tensors, we can use tf.nn.top_k to find the values and indices of the k largest entries in the last dimension.

For a sparse tensor, I would like to efficiently get the top n items of each row, without converting the sparse tensor to dense.
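
For reference, the dense behavior looks like this (a minimal example for illustration; tf.math.top_k and tf.nn.top_k refer to the same op):

import tensorflow as tf

dense = tf.constant([[1, 5, 3],
                     [4, 2, 6]])
# Values and indices of the 2 largest entries along the last dimension
values, indices = tf.math.top_k(dense, k=2)
print(values.numpy())
# [[5 3]
#  [6 4]]
print(indices.numpy())
# [[1 2]
#  [2 0]]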

Tags: tensorflow, sparse-matrix

Solution


This is a bit tricky, but it is doable (assuming a 2D sparse tensor, although I think it should work the same with more outer dimensions). The idea is to first sort the whole sparse tensor (without making it dense) and then keep only the first k columns per row. To do that, I needed something like np.lexsort, which as far as I know TensorFlow does not provide as such; however, tf.sparse.reorder actually does something like a lexsort, so I built another intermediate sparse tensor to take advantage of that.
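
For intuition, this is the lexsort being emulated (a standalone NumPy illustration, not part of the solution itself):

import numpy as np

rows = np.array([1, 0, 1, 0])
keys = np.array([0, 1, 1, 0])
# np.lexsort sorts by the last key first: by row, then by key within
# each row - the same ordering tf.sparse.reorder imposes on the
# (row, column) index pairs of a sparse tensor.
order = np.lexsort((keys, rows))
print(order)
# [3 1 0 2]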

import tensorflow as tf
import numpy as np

np.random.seed(0)
# Input data
k = 3
r = np.random.randint(10, size=(6, 8))
r[np.random.rand(*r.shape) < .5] = 0
sp = tf.sparse.from_dense(r)
print(tf.sparse.to_dense(sp).numpy())
# [[0 0 0 0 0 0 3 0]
#  [2 4 0 6 8 0 0 6]
#  [7 0 0 1 5 9 8 9]
#  [4 0 0 3 0 0 0 3]
#  [8 1 0 3 3 7 0 1]
#  [0 0 0 0 7 0 0 7]]

# List of value indices
n = tf.size(sp.values, out_type=sp.indices.dtype)
r = tf.range(n)
# Sort values
s = tf.dtypes.cast(tf.argsort(sp.values, direction='DESCENDING'), sp.indices.dtype)
# Find destination index of each sorted value
si = tf.scatter_nd(tf.expand_dims(s, 1), r, [n])
# Abuse sparse tensor functionality to do lexsort with column and destination index
sp2 = tf.sparse.SparseTensor(indices=tf.stack([sp.indices[:, 0], si], axis=1),
                             values=r,
                             dense_shape=[sp.dense_shape[0], n])
sp2 = tf.sparse.reorder(sp2)
# Build top-k result
row = sp.indices[:, 0]
# Make column indices
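# (d flags positions where a new row starts; the running maximum over
# m via tf.scan gives every entry the flat index where its row begins,
# so col = r - row_start is the within-row ordinal 0, 1, 2, ...)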
d = tf.dtypes.cast(row[1:] - row[:-1] > 0, r.dtype)
m = tf.pad(r[1:] * d, [[1, 0]])
col = r - tf.scan(tf.math.maximum, m)
# Get only up to k elements per row
m = col < k
row_m = tf.boolean_mask(row, m)
col_m = tf.boolean_mask(col, m)
idx_m = tf.boolean_mask(sp2.values, m)
# Make result
scatter_idx = tf.stack([row_m, col_m], axis=-1)
scatter_shape = [sp.dense_shape[0], k]
# Use -1 for rows with less than k values
# (0 is ambiguous)
values = tf.tensor_scatter_nd_update(-tf.ones(scatter_shape, sp.values.dtype),
                                     scatter_idx, tf.gather(sp.values, idx_m))
indices = tf.tensor_scatter_nd_update(-tf.ones(scatter_shape, sp.indices.dtype),
                                      scatter_idx, tf.gather(sp.indices[:, 1], idx_m))
print(values.numpy())
# [[ 3 -1 -1]
#  [ 8  6  6]
#  [ 9  9  8]
#  [ 4  3  3]
#  [ 8  7  3]
#  [ 7  7 -1]]
print(indices.numpy())
# [[ 6 -1 -1]
#  [ 4  3  7]
#  [ 5  7  6]
#  [ 0  3  7]
#  [ 0  5  3]
#  [ 4  7 -1]]
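
As a quick sanity check (my addition; valid here because all stored values are positive, so they outrank the zeros that to_dense fills in), the non-padding entries should match a plain dense top-k:

# Compare against dense top-k, ignoring the -1 padding
dense_values, _ = tf.math.top_k(tf.sparse.to_dense(sp), k)
mask = values >= 0
tf.debugging.assert_equal(tf.boolean_mask(values, mask),
                          tf.boolean_mask(dense_values, mask))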

Edit: Here is another possibility, which may work well if your tensor is very sparse in all rows. The idea is to "condense" all of the sparse tensor's values into the first columns (using the same col computation as in the previous snippet), then turn that into a dense tensor and apply top-k as usual. The caveat is that the returned indices refer to the condensed tensor, so you have to take an extra step if you want the correct indices with respect to the initial sparse tensor.

import tensorflow as tf
import numpy as np

np.random.seed(0)
# Input data
k = 3
r = np.random.randint(10, size=(6, 8))
r[np.random.rand(*r.shape) < .8] = 0
sp = tf.sparse.from_dense(r)
print(tf.sparse.to_dense(sp).numpy())
# [[0 0 0 0 0 0 3 0]
#  [0 4 0 6 0 0 0 0]
#  [0 0 0 0 5 0 0 9]
#  [0 0 0 0 0 0 0 0]
#  [8 0 0 0 0 7 0 0]
#  [0 0 0 0 7 0 0 0]]

# Build "condensed" sparse tensor
n = tf.size(sp.values, out_type=sp.indices.dtype)
r = tf.range(n)
# Make indices
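# (same within-row ranking trick as in the first snippet)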
row = sp.indices[:, 0]
d = tf.dtypes.cast(row[1:] - row[:-1] > 0, r.dtype)
m = tf.pad(r[1:] * d, [[1, 0]])
col = r - tf.scan(tf.math.maximum, m)
# At least as many columns as k
ncols = tf.maximum(tf.math.reduce_max(col) + 1, k)
sp2 = tf.sparse.SparseTensor(indices=tf.stack([row, col], axis=1),
                             values=sp.values,
                             dense_shape=[sp.dense_shape[0], ncols])
# Get in dense form
condensed = tf.sparse.to_dense(sp2)
# Top-k (indices do not correspond to initial sparse matrix)
values, indices = tf.math.top_k(condensed, k)
print(values.numpy())
# [[3 0 0]
#  [6 4 0]
#  [9 5 0]
#  [0 0 0]
#  [8 7 0]
#  [7 0 0]]

# Now get the right indices
sp3 = tf.sparse.SparseTensor(indices=tf.stack([row, col], axis=1),
                             values=sp.indices[:, 1],
                             dense_shape=[sp.dense_shape[0], ncols])
condensed_idx = tf.sparse.to_dense(sp3)
actual_indices = tf.gather_nd(condensed_idx, tf.expand_dims(indices, axis=-1),
                              batch_dims=1)
print(actual_indices.numpy())
# [[6 0 0]
#  [3 1 0]
#  [7 4 0]
#  [0 0 0]
#  [0 5 0]
#  [4 0 0]]
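
Incidentally, the last lookup can also be written with tf.gather and its batch_dims argument (my addition; equivalent as far as I can tell):

# Per-row gather: for each row b, pick condensed_idx[b] at indices[b]
actual_indices_alt = tf.gather(condensed_idx, indices, batch_dims=1)
tf.debugging.assert_equal(actual_indices, actual_indices_alt)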

Not sure whether this one would be faster, though.

