首页 > 解决方案 > 有没有比使用 Reshape 更好的方法来重复张量内的图层?

问题描述

我想将 [Batch Size, 1] 形状的张量转换为 [Batch Size, Channel Size]。为此,我编写了以下算法:

@tf.function
def fit_channel_layers(input_tensor, target_tensor):
    final_channel_layers = tf.shape(target_tensor)[-1]
    input_tensor = tf.repeat(input_tensor, repeats=[final_channel_layers])
    input_tensor = tf.reshape(input_tensor, [local_batch_size, final_channel_layers])
    return input_tensor

input_tensor = tf.convert_to_tensor([[1.],[0.],[0.],[1.],[0.],[0.]])
target_tensor = tf.convert_to_tensor([[1.,1.,1.,1.,1.,1.],
 [0.,0.,0.,0.,0.,0.],
 [0.,0.,0.,0.,0.,0.],
 [1.,1.,1.,1.,1.,1.],
 [0.,0.,0.,0.,0.,0.],
 [0.,0.,0.,0.,0.,0.]])
fit_channel_layers(input_tensor, target_tensor)

输出合适

array([[1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.]], dtype=float32)>

但是,我不确定使用 reshape 对性能的影响。我知道使用 numpy 可以使用逗号和乘数运算符轻松扩展 numpy 数组。但是,我想知道是否有比目前用@tf.function 图形代码编写的更好的方法来完成这项任务?

标签: pythonalgorithmtensorflowtensorflow2.0

解决方案


推荐阅读