首页 > 解决方案 > 在 tf.function 中使用复杂的控制流可以吗?

问题描述

我有以下 Python 函数,我想将它包装成@tf.function(最初输入参数是 numpy 数组,但为了在 GPU 上执行,将它们转换为 TF 张量不是问题)。

def reproject(before_frame, motion_vecs):
    reprojected_image = np.zeros((before_frame.shape[0], before_frame.shape[1], before_frame.shape[2]))
    for row_idx in range(before_frame.shape[0]):
        for col_idx in range(before_frame.shape[1]):
            for c_idx in range(before_frame.shape[2]):
                diff_u = int(round(
                         (before_frame.shape[1] * motion_vecs[row_idx][col_idx][0])
                     ))
                diff_v  = int(round(
                         (before_frame.shape[0] * motion_vecs[row_idx][col_idx][1])
                     ))
                before_pixel_position = (
                    row_idx + diff_v,
                    col_idx + diff_u
                )
                if before_pixel_position[0] < before_frame.shape[0] and before_pixel_position[1] < before_frame.shape[1] \
                    and before_pixel_position[0] > 0 and before_pixel_position[1] > 0:
                    reprojected_image[row_idx][col_idx][c_idx] = before_frame[
                        before_pixel_position[0]
                    ][
                        before_pixel_position[1]
                    ][c_idx]
    return reprojected_image

我可以看到在 Tensorflow 教程中人们使用vectorized_mapormap_fn代替循环,而tf.cond不是if运算符。那么使用这些函数是控制流的唯一选择吗?如果是,背后的原因是什么?

标签: tensorflowtensorflow2.0

解决方案


推荐阅读