python - 在 tensorflow(或 keras)中应用通道洗牌
问题描述
我正在尝试在 tensorflow(或 keras)中实现通道洗牌功能。我找到了这个实现,但它似乎是错误的,因为我认为它是基于这个pytorch 实现的。
我已经设法做到了,concatenate()
但我想要一个使用permute_dimensions()
. 另外,我不确定连接版本是否较慢(如果有人能回答这个问题,我将不胜感激)。
一个工作的张量流实现使用concatenate()
:
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras import layers
from tensorflow.keras import models
import numpy as np
a = tf.constant([[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]])
sess = tf.Session()
print('x', sess.run(a))
groups = 2 # separate into 2 group
h, w, in_channel = K.int_shape(a)[1:]
l = K.reshape(a, [-1, h, w, in_channel // groups, groups])
m = K.concatenate((l[..., 1], l[..., 0]))
l = K.reshape(m, [-1, h, w, in_channel])
print('y', sess.run(l))
输出:
x [[[[ 1 2]
[ 3 4]
[ 5 6]]
[[ 7 8]
[ 9 10]
[11 12]]]]
y [[[[ 2 1]
[ 4 3]
[ 6 5]]
[[ 8 7]
[10 9]
[12 11]]]]
keras 非工作实现如下:
def channel_shuffle(x):
g = 2
b, h, w, c = x.shape.as_list()
x = K.reshape(x, [-1, h, w, g, c // g])
x = K.permute_dimensions(x, (0, 1, 2, 4, 3))
x = K.reshape(x, [-1, h, w, c])
return x
input_shape = (2, 3, 2)
x = np.array([[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]])
image_input = layers.Input(input_shape)
permuted_x = channel_shuffle4(image_input)
model = models.Model(inputs=[image_input], outputs=permuted_x)
y = model.predict(x)
print('x', x)
print('y', y)
输出:
x [[[[ 1 2]
[ 3 4]
[ 5 6]]
[[ 7 8]
[ 9 10]
[11 12]]]]
y [[[[ 1. 2.]
[ 3. 4.]
[ 5. 6.]]
[[ 7. 8.]
[ 9. 10.]
[11. 12.]]]]
这显然根本不会改变输入数据。那么,我怎样才能达到预期的效果呢?基本上我应该交换哪些轴?我做了一些实验,但似乎找不到合适的。
解决方案
您需要在 之后执行最后一个通道的反转permute_dimensions
。permute_dimensions
是一样的tf.transpose
。这是一个直接作用于张量的解决方案:
import tensorflow as tf
import numpy as np
def channel_shuffle(x):
g = 2
b, h, w, c = x.shape
x = tf.reshape(x, [-1, h, w, g, c // g])
x = tf.transpose(x, perm = [0, 1, 2, 4, 3])
x = tf.reverse(x,[-1])
x = tf.reshape(x, [-1, h, w, c])
return x
x = np.ones(shape = (1,2,2,4))
for c in range(4):
x[:,:,:,c] = c
y = channel_shuffle(x)
print(tf.__version__)
print("start:")
print(x)
print("result:")
print(y)
带输出:
2.3.1
start:
[[[[0. 1. 2. 3.]
[0. 1. 2. 3.]]
[[0. 1. 2. 3.]
[0. 1. 2. 3.]]]]
result:
tf.Tensor(
[[[[2. 0. 3. 1.]
[2. 0. 3. 1.]]
[[2. 0. 3. 1.]
[2. 0. 3. 1.]]]], shape=(1, 2, 2, 4), dtype=float64)
推荐阅读
- sublimetext3 - Sublime 3 高亮问题
- vue.js - 如何使用 Vue 检测浏览器后退按钮?
- sql - Azure Synapse - 检索插入的行标识值
- react-native - 活动指示器不包含 React-Native 上的标头
- python - Python Tkinter,如何同步多个闪烁的小部件?
- python - 如果使用 jinja 匹配,则从列表创建列表
- javascript - 跨多个组件重用 React.useCallback() 函数
- apache-spark - 分区 DataFrame 时 AWS Glue Spark 作业无法扩展
- c++ - 以下 C++ 代码中实现的 DCL(双重检查锁定)是否是线程安全的?
- sql - BigQuery:对象数组中的最低时间戳和特定时间戳之间的平均值