首页 > 解决方案 > 分配给 TensorFlow 中变量的扁平视图的切片

问题描述

在 TensorFlow 中,我错过了一种直接将某些内容分配给变量扁平视图的切片的方法。

这是一个以迂回方式实现相同结果的示例:

var = tf.Variable(tf.reshape(tf.range(12), [4,3]))

# <tf.Variable 'Variable:0' shape=(4, 3) dtype=int32, numpy=
# array([[ 0,  1,  2],
#        [ 3,  4,  5],
#        [ 6,  7,  8],
#        [ 9, 10, 11]], dtype=int32)>

flat_indices = tf.range(4, 8)
multi_dim_indices = tf.transpose(tf.unravel_index(flat_indices, dims=[4,3]))

# <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
# array([[1, 1],
#        [1, 2],
#        [2, 0],
#        [2, 1]], dtype=int32)>

update = [40, 50, 60, 70]

var.scatter_nd_update(multi_dim_indices, update)

# <tf.Variable 'UnreadVariable' shape=(4, 3) dtype=int32, numpy=
# array([[ 0,  1,  2],
#        [ 3, 40, 50],
#        [60, 70,  8],
#        [ 9, 10, 11]], dtype=int32)>

但这不是大张量的有效解决方案。建筑multi_dim_indices应该是不必要的。scatter_nd_update是一个稀疏操作,但我正在寻找的是对连续内存的密集分配。

使用类似 numpy 的 API,我可以编写:

var.flat[4:8] = update

有没有一种有效的方法可以在 TensorFlow 中实现相同的结果,也许使用更丑的 API?

标签: pythontensorflow

解决方案


中做类似的操作,如下:

var = tf.Variable(tf.reshape(tf.range(12), [4,3]))
var = tf.Variable(tf.reshape(var, [-1]))  # flatten the vector 
var[4:8].assign([40, 50, 60, 70]) # update / assign item in particular indices 
var = tf.Variable(tf.reshape(var, [4,3])) # reshape back to original shape
var

<tf.Variable 'Variable:0' shape=(4, 3) dtype=int32, numpy=
array([[ 0,  1,  2],
       [ 3, 40, 50],
       [60, 70,  8],
       [ 9, 10, 11]], dtype=int32)>

推荐阅读