首页 > 解决方案 > 在 tensorflow(或 keras)中应用通道洗牌

问题描述

我正在尝试在 tensorflow(或 keras)中实现通道洗牌功能。我找到了这个实现,但它似乎是错误的,因为我认为它是基于这个pytorch 实现的。

我已经设法做到了,concatenate()但我想要一个使用permute_dimensions(). 另外,我不确定连接版本是否较慢(如果有人能回答这个问题,我将不胜感激)。

一个工作的张量流实现使用concatenate()

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras import layers
from tensorflow.keras import models
import numpy as np

a = tf.constant([[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]])
sess = tf.Session()
print('x', sess.run(a))
groups = 2  # separate into 2 group
h, w, in_channel = K.int_shape(a)[1:]
l = K.reshape(a, [-1, h, w, in_channel // groups, groups])
m = K.concatenate((l[..., 1], l[..., 0]))
l = K.reshape(m, [-1, h, w, in_channel])
print('y', sess.run(l))

输出:

x [[[[ 1  2]
   [ 3  4]
   [ 5  6]]
  [[ 7  8]
   [ 9 10]
   [11 12]]]]
y [[[[ 2  1]
   [ 4  3]
   [ 6  5]]
  [[ 8  7]
   [10  9]
   [12 11]]]]

keras 非工作实现如下:

def channel_shuffle(x):
    g = 2
    b, h, w, c = x.shape.as_list()
    x = K.reshape(x, [-1, h, w, g, c // g])
    x = K.permute_dimensions(x, (0, 1, 2, 4, 3))
    x = K.reshape(x, [-1, h, w, c])
    return x

input_shape = (2, 3, 2)
x = np.array([[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]])
image_input = layers.Input(input_shape)
permuted_x = channel_shuffle4(image_input)
model = models.Model(inputs=[image_input], outputs=permuted_x)
y = model.predict(x)
print('x', x)
print('y', y)

输出:

x [[[[ 1  2]
   [ 3  4]
   [ 5  6]]

  [[ 7  8]
   [ 9 10]
   [11 12]]]]
y [[[[ 1.  2.]
   [ 3.  4.]
   [ 5.  6.]]

  [[ 7.  8.]
   [ 9. 10.]
   [11. 12.]]]]

这显然根本不会改变输入数据。那么,我怎样才能达到预期的效果呢?基本上我应该交换哪些轴?我做了一些实验,但似乎找不到合适的。

标签: pythontensorflowkeras

解决方案


您需要在 之后执行最后一个通道的反转permute_dimensionspermute_dimensions是一样的tf.transpose。这是一个直接作用于张量的解决方案:

import tensorflow as tf 
import numpy as np

def channel_shuffle(x):
    g = 2
    b, h, w, c = x.shape
    x = tf.reshape(x, [-1, h, w, g, c // g])
    x = tf.transpose(x, perm = [0, 1, 2, 4, 3])
    x = tf.reverse(x,[-1])
    x = tf.reshape(x, [-1, h, w, c])
    return x

x = np.ones(shape = (1,2,2,4))
for c in range(4):
    x[:,:,:,c] = c

y = channel_shuffle(x)
print(tf.__version__)
print("start:")
print(x)
print("result:")
print(y)

带输出:

2.3.1
start:
[[[[0. 1. 2. 3.]
   [0. 1. 2. 3.]]

  [[0. 1. 2. 3.]
   [0. 1. 2. 3.]]]]
result:
tf.Tensor(
[[[[2. 0. 3. 1.]
   [2. 0. 3. 1.]]

  [[2. 0. 3. 1.]
   [2. 0. 3. 1.]]]], shape=(1, 2, 2, 4), dtype=float64)

推荐阅读