tensorflow - 为稀疏张量返回每行的 top_k 项
问题描述
对于稠密张量,我们可以使用 tf.nn.topk 来查找最后一维的 k 个最大条目的值和索引。
对于稀疏张量,我想有效地获取每行的前 n 项,而不会将稀疏张量转换为密集。
解决方案
这有点棘手,但这是可行的(假设 2D 稀疏张量,尽管我认为对于更多外部维度应该同样有效)。这个想法是首先对整个稀疏张量进行排序(不使其密集),然后对第一列进行切片。为此,我需要类似的东西np.lexsort
,据我所知,TensorFlow 中没有提供这样的东西 - 但是,tf.sparse.reorder
实际上做了类似 lexsort 的东西,所以我制作了另一个中间稀疏张量来利用它。
import tensorflow as tf
import numpy as np
np.random.seed(0)
# Input data
k = 3
r = np.random.randint(10, size=(6, 8))
r[np.random.rand(*r.shape) < .5] = 0
sp = tf.sparse.from_dense(r)
print(tf.sparse.to_dense(sp).numpy())
# [[0 0 0 0 0 0 3 0]
# [2 4 0 6 8 0 0 6]
# [7 0 0 1 5 9 8 9]
# [4 0 0 3 0 0 0 3]
# [8 1 0 3 3 7 0 1]
# [0 0 0 0 7 0 0 7]]
# List of value indices
n = tf.size(sp.values, out_type=sp.indices.dtype)
r = tf.range(n)
# Sort values
s = tf.dtypes.cast(tf.argsort(sp.values, direction='DESCENDING'), sp.indices.dtype)
# Find destination index of each sorted value
si = tf.scatter_nd(tf.expand_dims(s, 1), r, [n])
# Abuse sparse tensor functionality to do lexsort with column and destination index
sp2 = tf.sparse.SparseTensor(indices=tf.stack([sp.indices[:, 0], si], axis=1),
values=r,
dense_shape=[sp.dense_shape[0], n])
sp2 = tf.sparse.reorder(sp2)
# Build top-k result
row = sp.indices[:, 0]
# Make column indices
d = tf.dtypes.cast(row[1:] - row[:-1] > 0, r.dtype)
m = tf.pad(r[1:] * d, [[1, 0]])
col = r - tf.scan(tf.math.maximum, m)
# Get only up to k elements per row
m = col < k
row_m = tf.boolean_mask(row, m)
col_m = tf.boolean_mask(col, m)
idx_m = tf.boolean_mask(sp2.values, m)
# Make result
scatter_idx = tf.stack([row_m, col_m], axis=-1)
scatter_shape = [sp.dense_shape[0], k]
# Use -1 for rows with less than k values
# (0 is ambiguous)
values = tf.tensor_scatter_nd_update(-tf.ones(scatter_shape, sp.values.dtype),
scatter_idx, tf.gather(sp.values, idx_m))
indices = tf.tensor_scatter_nd_update(-tf.ones(scatter_shape, sp.indices.dtype),
scatter_idx, tf.gather(sp.indices[:, 1], idx_m))
print(values.numpy())
# [[ 3 -1 -1]
# [ 8 6 6]
# [ 9 9 8]
# [ 4 3 3]
# [ 8 7 3]
# [ 7 7 -1]]
print(indices.numpy())
# [[ 6 -1 -1]
# [ 4 3 7]
# [ 5 7 6]
# [ 0 3 7]
# [ 0 5 3]
# [ 4 7 -1]]
编辑:这是另一种可能性,如果您的张量在所有行中都非常稀疏,它可能会很好。这个想法是将所有稀疏张量值“压缩”到第一列中(就像前面的代码片段已经为 所做的那样sp3
),然后将其变成密集张量并像往常一样应用 top-k。需要注意的是,索引将被称为压缩张量,所以如果你想获得关于初始稀疏张量的正确索引,你必须采取另一个步骤。
import tensorflow as tf
import numpy as np
np.random.seed(0)
# Input data
k = 3
r = np.random.randint(10, size=(6, 8))
r[np.random.rand(*r.shape) < .8] = 0
sp = tf.sparse.from_dense(r)
print(tf.sparse.to_dense(sp).numpy())
# [[0 0 0 0 0 0 3 0]
# [0 4 0 6 0 0 0 0]
# [0 0 0 0 5 0 0 9]
# [0 0 0 0 0 0 0 0]
# [8 0 0 0 0 7 0 0]
# [0 0 0 0 7 0 0 0]]
# Build "condensed" sparse tensor
n = tf.size(sp.values, out_type=sp.indices.dtype)
r = tf.range(n)
# Make indices
row = sp.indices[:, 0]
d = tf.dtypes.cast(row[1:] - row[:-1] > 0, r.dtype)
m = tf.pad(r[1:] * d, [[1, 0]])
col = r - tf.scan(tf.math.maximum, m)
# At least as many columns as k
ncols = tf.maximum(tf.math.reduce_max(col) + 1, k)
sp2 = tf.sparse.SparseTensor(indices=tf.stack([row, col], axis=1),
values=sp.values,
dense_shape=[sp.dense_shape[0], ncols])
# Get in dense form
condensed = tf.sparse.to_dense(sp2)
# Top-k (indices do not correspond to initial sparse matrix)
values, indices = tf.math.top_k(condensed, k)
print(values.numpy())
# [[3 0 0]
# [6 4 0]
# [9 5 0]
# [0 0 0]
# [8 7 0]
# [7 0 0]]
# Now get the right indices
sp3 = tf.sparse.SparseTensor(indices=tf.stack([row, col], axis=1),
values=sp.indices[:, 1],
dense_shape=[sp.dense_shape[0], ncols])
condensed_idx = tf.sparse.to_dense(sp3)
actual_indices = tf.gather_nd(condensed_idx, tf.expand_dims(indices, axis=-1),
batch_dims=1)
print(actual_indices.numpy())
# [[6 0 0]
# [3 1 0]
# [7 4 0]
# [0 0 0]
# [0 5 0]
# [4 0 0]]
不确定这是否会更快。
推荐阅读
- reactjs - 如何动态更新数组的一个元素
- reactjs - 是否可以将clevertap推送通知与expo reactnative集成
- arrays - 使用 C 查找 N 元素数组中所有可能对的总和
- android - 为什么我的 ViewModel 在 Fragment 中为 null 而不是 Fragment 的绑定布局?
- qt - Qt Designer 使用 Python 将用户输入从行编辑存储和解析到机器人操作系统 (ROS)
- java - 运行 Cucumber JUnit 时出错:initializationError
- java - 计数器在程序中关闭
- video.js - 在 videojs 中将 currentTime 设置为任何值都会将时间设置回 0
- javascript - Javascript:如何在响应字符串解析中处理 syntaxError
- c - CS50 PSET4 Sobel 过滤器 - 边缘(边界像素)