python - 使用 tensorflow 进行可变形状输入
问题描述
我目前正在尝试使用 tensorflow 模型/函数来进行可变长度输入的计算。
更准确地说,我有两个列表作为输入(我们称它们为L_1
和L_2
),每个列表都包含相同数量的数组,但每个数组都有不同的形状。如果可能,我需要以矢量化方式将 tensorflow 函数应用于各个元组 ( L_1[0], L_2[0]), (L_1[1], L_2[1]), ..., (L_1[n], L_2[n]
)。
这是函数,它将两个张量作为输入并返回(可变长度)结果:
def nearest_point_calculations(input):
surface_points = tf.expand_dims(input[0], axis=0)
border_centers = tf.expand_dims(input[1], axis=1)
dif = tf.math.reduce_euclidean_norm(tf.subtract(surface_points, border_centers), axis=-1)
min = tf.argmin(dif, axis=1)
return tf.RaggedTensor.from_tensor(tf.squeeze(tf.gather(surface_points, min, axis=1), axis=0))
该函数在单个输入上运行良好,但我想一次处理整个数组列表,所以我不只是想使用 while 循环按顺序处理所有内容。
我能够创建的唯一工作方法如下:
def nearest_point_wrapper(surface_points, border_centers):
return tf.map_fn(nearest_point_calculations, (surface_points, border_centers), fn_output_signature=tf.RaggedTensorSpec(shape=(None, 2)))
def calculate_nearest_point_init(surface_points, border_centers):
surface_points_tensor = tf.ragged.constant(surface_points, dtype=tf.float32)
border_centers_tensor = tf.ragged.constant(border_centers, dtype=tf.float32)
return nearest_point_wrapper(surface_points_tensor, border_centers_tensor)
因此,此方法RaggedTensors
为两个输入创建两个并map_fn
在两个上使用RaggedTensors
。但是,RaggedTensors
如果在运行中完成创建非常慢(大约需要 1.5 秒,而为每个子阵列创建单个张量总共只需要 0.17 秒)。
因此,在第二个示例中,我尝试将每个数组单独转换为张量(这要快得多)并用于map_fn
生成的输入列表:
def nearest_point_wrapper(surface_points, border_centers):
return tf.map_fn(nearest_point_calculations, (surface_points, border_centers), fn_output_signature=tf.RaggedTensorSpec(shape=(None, 2)))
def calculate_nearest_point_init(surface_points, border_centers):
surface_points_tensor = [tf.constant(x, dtype=tf.float32) for x in surface_points]
border_centers_tensor = [tf.constant(x, dtype=tf.float32) for x in border_centers]
return nearest_point_wrapper(surface_points_tensor, border_centers_tensor)
但是使用这种方法,我只是ValueError
告诉我
第一个列表中张量的形状不匹配。
我想知道为什么map_fn
还要关心这一点,因为它只是应该采用第一个列表的第一个条目和第二个列表的第一个条目并应用于它们的功能?
为什么它甚至会检查列表中的各个元素是否具有相同的形状?
解决方案
推荐阅读
- wordpress - 字体的 Cors 问题
- azure - Databricks Connect:无法连接到 azure 上的远程集群,命令:“databricks-connect test”停止
- laravel - 子类别未显示在管理面板上
- python - 函数不返回 pyspark DataFrame
- android - 为什么 continueStroke 函数不起作用
- android - 当用户单击网页中的按钮时,将数据从 Trusted Web Activity 保存在内部存储中
- android - 如何以编程方式在android中获取状态启用或禁用自动启动权限
- linux - 使用正则表达式进行 Grep 并使用组进行捕获
- java - 我的查询有什么问题??它完全没有错误
- c# - 使用 P/Invoke 传递取消标志时是否需要同步?