首页 > 解决方案 > 通过连接最后一维中的数字来减少张量

问题描述

为了识别唯一的序列,我需要使用 tensorflow 函数通过连接最后一维中的数字来将 2D int 张量减少为 1D int 张量。

例如,

[[1, 0], [2, 1], [1, 3], [2, 0], [0, 1]]

应该成为

[10, 21, 13, 20, 1]

到目前为止我所拥有的是

def reduce_concat(input):
  def join(x):
    dec = tf.range(0, x.shape[-1], 1)
    dec = tf.map_fn(lambda x: tf.math.pow(10, x), dec)
    return tf.math.reduce_sum(x * dec)
  return tf.map_fn(join, input)

这几乎可以工作,但它忽略了零并且不是很优雅。

如果有人能为这个问题提供一个优雅的解决方案,我将不胜感激 - 谢谢。

标签: pythontensorflow

解决方案


您可以尝试以下方法:

def reduce_concat(input):
  dec = 10**tf.range(input.shape[-1]-1, -1, -1)
  return tf.reduce_sum(input * dec, axis=-1)

您输入的结果:

[10, 21, 13, 20,  1]

推荐阅读