首页 > 解决方案 > 什么是火炬的 torch.cat 与 tensorflow 等价?

问题描述

def cxcy_to_xy(cxcy):
    """
    Convert bounding boxes from center-size coordinates (c_x, c_y, w, h) to boundary coordinates (x_min, y_min, x_max, y_max).

    :param cxcy: bounding boxes in center-size coordinates, a tensor of size (n_boxes, 4)
    :return: bounding boxes in boundary coordinates, a tensor of size (n_boxes, 4)
    """
    return torch.cat([cxcy[:, :2] - (cxcy[:, 2:] / 2),  # x_min, y_min
                      cxcy[:, :2] + (cxcy[:, 2:] / 2)], 1)  # x_max, y_max

我想用 tensorflow 2.0 改变这个 torch.cat

标签: tensorflowpytorchtorch

解决方案


几个选项取决于您使用的 TF 中的 API:

  • tf.concat- 最类似于torch.cat

    tf.concat(values, axis, name='concat')
    
  • tf.keras.layers.concatenate- 如果您使用 Keras 顺序 API:

    tf.keras.layers.concatenate(values, axis=-1, **kwargs)
    
  • tf.keras.layers.Concatenate- 如果你使用 Keras 函数式 API:

    x = tf.keras.layers.Concatenate(axis=-1, **kwargs)(values)
    

如果您使用的是 Keras API,此答案有助于了解所有 Keras 连接函数之间的差异。


推荐阅读