首页 > 解决方案 > Tensorflow初始化一个只有一行/列不为零的稀疏张量?

问题描述

假设我想初始化一个稀疏矩阵但指定一行不为零(TF1.14):

喜欢

a = [[0, 0, 0, 0, 0],  
     [0, 0, 0, 0, 0],  
     [1, 1, 1, 1, 1],   
      ...,  
     [0, 0, 0, 0, 0]]  

或者

a = [[0, 0, 0, 1, 0],  
     [0, 0, 0, 1, 0],  
     [0, 0, 0, 1, 0],   
      ...,  
     [0, 0, 0, 1, 0]]  

我看到SparseTensor做了类似的事情,但问题是它需要手动指定每个不为零的元素的所有索引(我只想指定一行或一列),有没有更简单的方法来实现这一点?

标签: pythontensorflow

解决方案


即使对于一般的多维情况,创建这样的张量也不难。你可以使用这样的函数:

import tensorflow as tf

def line_tensor(shape, idx, axis=0, dtype=tf.float32):
    shape = tf.convert_to_tensor(shape)
    s_before = tf.concat([shape[:axis], [idx], shape[axis + 1:]], axis=0)
    s_line = tf.concat([shape[:axis], [1], shape[axis + 1:]], axis=0)
    s_after = tf.concat([shape[:axis], [shape[axis] - idx - 1], shape[axis + 1:]], axis=0)
    return tf.concat([tf.zeros(s_before, dtype), tf.ones(s_line, dtype),
                      tf.zeros(s_after, dtype)], axis=axis)

print(line_tensor([3, 4], 2, axis=0, dtype=tf.int32).numpy())
# [[0 0 0 0]
#  [0 0 0 0]
#  [1 1 1 1]]
print(line_tensor([3, 4], 1, axis=1, dtype=tf.int32).numpy())
# [[0 1 0 0]
#  [0 1 0 0]
#  [0 1 0 0]]

如果您更喜欢创建稀疏张量,这也不难,尽管对于多维情况来说有点棘手。这是一个仅适用于矩阵的更简单的函数:

import tensorflow as tf

# Assumes 2D shape
def line_matrix_sp(shape, idx, axis=0, dtype=tf.float32):
    shape = tf.dtypes.cast(shape, tf.int64)
    n = shape[1 - idx]
    idx1 = tf.range(n)
    idx2 = tf.fill([n], tf.dtypes.cast(idx, tf.int64))
    idx = tf.gather(tf.stack([idx1, idx2], axis=1), [1 - axis, axis], axis=1)
    return tf.SparseTensor(idx, tf.ones([n], dtype), shape)

print(tf.sparse.to_dense(line_matrix_sp([3, 4], 2, axis=0, dtype=tf.int32)).numpy())
# [[0 0 0 0]
#  [0 0 0 0]
#  [1 1 1 1]]
print(tf.sparse.to_dense(line_matrix_sp([3, 4], 1, axis=1, dtype=tf.int32)).numpy())
# [[0 1 0 0]
#  [0 1 0 0]
#  [0 1 0 0]]

推荐阅读