python - 通过 TensorFlow 中的段长度计算 tf.math.segment_sum 中所需的段 id
问题描述
我正在处理可变大小的顺序数据。让我们考虑像这样的数据
Y = [ [.01,.02], [.03,.04], [.05,.06], [.07,.08], [.09,.1] ]
l = [ 3, 2 ]
其中Y
是对我的数据执行的一些辅助计算的结果,并l
存储了原始序列的长度。因此,在该示例[.01,.02], [.03,.04], [.05,.06]
中,是对批次的第一个序列执行的计算的结果,并且是对批次的[.07,.08], [.09,.1]
第二个序列执行的计算的结果。现在我想对 的条目做一些进一步的计算,但按序列分组。在 Tensorflow 中,有一些函数可以按组执行。3
2
Y
tf.math.segment_sum
可以说我想使用tf.math.segment_sum
. 我会感兴趣
seq_ids = [ 0, 0, 0, 1, 1 ]
tf.math.segment_sum(Y, segment_ids=seq_ids) #returns [ [0.09 0.12], [0.16 0.18] ]
我现在面临的问题是seq_ids
从l
. 在 numpy 中,人们很容易通过
seq_ids = np.digitize( np.arange(np.sum(l)), np.cumsum(l) )
似乎有一个隐藏的(来自 python api)等价于在 Tensorflow中搜索a时提到的digitize
命名。但似乎引用的内容已从 Tensorflow 中删除,我不清楚是否仍然(并且将会)支持python api 中的函数。我必须得到类似结果的另一个想法是使用该函数。但这次尝试失败了,因为bucketize
digitize
hidden_ops.txt
tensorflow::ops::Bucketize
tf.train.piecewise_constant
seq_ids = tf.train.piecewise_constant(tf.range(tf.math.reduce_sum(l)), tf.math.cumsum(l), tf.range(BATCH_SIZE-1))
失败了object of type 'Tensor' has no len()
。似乎tf.train.piecewise_constant
没有以最通用的方式作为参数实现,boundaries
并且values
需要是列表而不是张量。在l
我的例子中是一个一维张量聚集在我的小批量中tf.data.Dataset
解决方案
这是一种方法:
import tensorflow as tf
def make_seq_ids(lens):
# Get accumulated sums (e.g. [2, 3, 1] -> [2, 5, 6])
c = tf.cumsum(lens)
# Take all but the last accumulated sum value as indices
idx = c[:-1]
# Put ones on every index
s = tf.scatter_nd(tf.expand_dims(idx, 1), tf.ones_like(idx), [c[-1]])
# Use accumulated sums to generate ids for every segment
return tf.cumsum(s)
with tf.Graph().as_default(), tf.Session() as sess:
print(sess.run(make_seq_ids([2, 3, 1])))
# [0 0 1 1 1 2]
编辑:
tf.searchsorted
您也可以使用以与您为 NumPy 建议的方式更相似的方式实现相同的功能:
import tensorflow as tf
def make_seq_ids(lens):
c = tf.cumsum(lens)
return tf.searchsorted(c, tf.range(c[-1]), side='right')
这些实现都不应该成为 TensorFlow 模型中的瓶颈,因此对于任何实际目的,您选择哪一个都无关紧要。然而,有趣的是,在我的特定机器(Win 10、TF 1.12、Core i7 7700K、Titan V)中,第二个实现在 CPU 上运行时慢约 1.5 倍,在 GPU 上运行时快约 3.5 倍。
推荐阅读
- swift - 在父窗口的中心显示工作表
- javascript - 在javascript中从protobuf解码序列化数据
- java - 使用 IntelliJ 从 PropertiesLoader 加载的 Spring Boot 模块时出现 NoClassDefFoundError
- java - 将帧(android 中的 mat 数据)从 android 传递到本机 c++ 并检测人脸
- latex - 有没有办法减少乳胶中标题和副标题之间紧凑外观的上边距?
- spring - 方法级别的弹簧注释建议顺序
- pandas - 将分类数值数据编码到不同的列
- python-3.x - 将 3d 数据输入到 lstm
- gmail - 仅为 Gmail 用户隐藏内容
- c# - 如何按另一个列表对对象列表进行排序