首页 > 解决方案 > 使用 tensorflow 进行可变形状输入

问题描述

我目前正在尝试使用 tensorflow 模型/函数来进行可变长度输入的计算。
更准确地说,我有两个列表作为输入(我们称它们为L_1L_2),每个列表都包含相同数量的数组,但每个数组都有不同的形状。如果可能,我需要以矢量化方式将 tensorflow 函数应用于各个元组 ( L_1[0], L_2[0]), (L_1[1], L_2[1]), ..., (L_1[n], L_2[n])。

这是函数,它将两个张量作为输入并返回(可变长度)结果:

def nearest_point_calculations(input):
    surface_points = tf.expand_dims(input[0], axis=0)
    border_centers = tf.expand_dims(input[1], axis=1)
    dif = tf.math.reduce_euclidean_norm(tf.subtract(surface_points, border_centers), axis=-1)
    min = tf.argmin(dif, axis=1)
    return tf.RaggedTensor.from_tensor(tf.squeeze(tf.gather(surface_points, min, axis=1), axis=0))

该函数在单个输入上运行良好,但我想一次处理整个数组列表,所以我不只是想使用 while 循环按顺序处理所有内容。

我能够创建的唯一工作方法如下:

def nearest_point_wrapper(surface_points, border_centers):
    return tf.map_fn(nearest_point_calculations, (surface_points, border_centers), fn_output_signature=tf.RaggedTensorSpec(shape=(None, 2)))

def calculate_nearest_point_init(surface_points, border_centers):
    surface_points_tensor = tf.ragged.constant(surface_points, dtype=tf.float32)
    border_centers_tensor = tf.ragged.constant(border_centers, dtype=tf.float32)
    return nearest_point_wrapper(surface_points_tensor, border_centers_tensor)

因此,此方法RaggedTensors为两个输入创建两个并map_fn在两个上使用RaggedTensors。但是,RaggedTensors如果在运行中完成创建非常慢(大约需要 1.5 秒,而为每个子阵列创建单个张量总共只需要 0.17 秒)。

因此,在第二个示例中,我尝试将每个数组单独转换为张量(这要快得多)并用于map_fn生成的输入列表:

def nearest_point_wrapper(surface_points, border_centers):
    return tf.map_fn(nearest_point_calculations, (surface_points, border_centers), fn_output_signature=tf.RaggedTensorSpec(shape=(None, 2)))

def calculate_nearest_point_init(surface_points, border_centers):
    surface_points_tensor = [tf.constant(x, dtype=tf.float32) for x in surface_points]
    border_centers_tensor = [tf.constant(x, dtype=tf.float32) for x in border_centers]
    return nearest_point_wrapper(surface_points_tensor, border_centers_tensor)

但是使用这种方法,我只是ValueError告诉我 第一个列表中张量的形状不匹配

我想知道为什么map_fn还要关心这一点,因为它只是应该采用第一个列表的第一个条目和第二个列表的第一个条目并应用于它们的功能?
为什么它甚至会检查列表中的各个元素是否具有相同的形状?

标签: pythontensorflowtensorflow2.0

解决方案


推荐阅读