python-3.x - PyTorch 到 Keras 模型
问题描述
我正在尝试复制一个模型,但我在使用 Keras 时遇到了困难。这是我当前的实现:
# Hyper-parameters shared by every convolution except the final 1x1 one.
filters = 256
kernel_size = 3
strides = 1

# Head module: one L2-regularised convolution applied to the low-resolution
# input (the target size divided by `scale_fact`).
# NOTE: `input` shadows the builtin of the same name; the name is kept so
# that any later code referring to it keeps working.
input = Input(shape=(img_height//scale_fact, img_width//scale_fact, img_depth))
conv0 = Conv2D(filters, kernel_size, strides=strides, padding='same',
               kernel_regularizer=regularizers.l2(0.01))(input)

# Body module: a first residual block (conv -> ReLU -> conv, added back to
# the head output) ...
res = Conv2D(filters, kernel_size, strides=strides, padding='same')(conv0)
act = ReLU()(res)
res = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)
res_rec = Add()([conv0, res])
# ... followed by `res_blocks` more blocks, each added to the running sum.
for i in range(res_blocks):
    res1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(res_rec)
    act = ReLU()(res1)
    res2 = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)
    res_rec = Add()([res_rec, res2])
# Closing convolution of the body, then the long skip connection from the head.
conv = Conv2D(filters, kernel_size, strides=strides, padding='same',
              kernel_regularizer=regularizers.l2(0.01))(res_rec)
add = Add()([conv0, conv])

# Tail module
conv = Conv2D(filters, kernel_size, strides=strides, padding='same',
              kernel_regularizer=regularizers.l2(0.01))(add)
act = ReLU()(conv)
# Factors other than 4 are upsampled in a single step; a factor of 4 starts
# with a 2x step here and gets its second 2x step in the branch below.
up = UpSampling2D(size=scale_fact if scale_fact != 4 else 2)(act)  # TODO: try "Conv2DTranspose"
# TODO: a constant 0.1 scaling of `up` was sketched here but left disabled.
# The draft, `Multiply([np.zeros(...).fill(0.1), up])(up)`, cannot work as
# written: `ndarray.fill` returns None and `Multiply` was given the tensors
# as constructor arguments instead of being called on them.

# When it's a 4X factor, we want the upscale split in two procedures
if scale_fact == 4:
    conv = Conv2D(filters, kernel_size, strides=strides, padding='same',
                  kernel_regularizer=regularizers.l2(0.01))(up)
    act = ReLU()(conv)
    up = UpSampling2D(size=2)(act)  # TODO: try "Conv2DTranspose"

# Final 1x1 convolution reducing the feature maps to 3 output channels.
output = Conv2D(filters=3,
                kernel_size=1,
                strides=1,
                padding='same',
                kernel_regularizer=regularizers.l2(0.01))(up)
model = Model(inputs=input, outputs=output)
这是我要复制的文件的链接。我应该如何在 Keras 中复现这个自定义的 PyTorch UpSampler(它实现了自定义的 PixelShuffling 方法)?
以下是 UpSampler 中让我遇到困难的主要相关部分:
import tensorflow as tf
import tensorflow.contrib.slim as slim
"""
Method to upscale an image using
3x3 convolutions followed by the sub-pixel
phase shift (PS); conv2d transpose is not used.
Based on upscaling method defined in the paper
x: input to be upscaled
scale: scale increase of upsample
features: number of features to compute
activation: activation function
"""
def upsample(x, scale=2, features=64, activation=tf.nn.relu):
    """Upscale ``x`` by ``scale`` using convolutions and sub-pixel phase shift.

    A 3x3 convolution with ``features`` outputs is applied first. Each
    upscaling step is then a 3x3 convolution producing ``3 * r**2`` channels
    followed by ``PS(x, r, color=True)``. Scales 2 and 3 use a single step
    with ``r = scale``; scale 4 uses two consecutive steps with ``r = 2``.

    Args:
        x: input to be upscaled.
        scale: scale increase of the upsample; must be 2, 3 or 4.
        features: number of features computed by the first convolution.
        activation: activation function passed to every convolution.

    Returns:
        The result of the last ``PS`` call.

    Raises:
        AssertionError: if ``scale`` is not 2, 3 or 4.
    """
    assert scale in [2, 3, 4]
    x = slim.conv2d(x, features, [3, 3], activation_fn=activation)

    # 4x is performed as two 2x steps; 2x and 3x are done in one step.
    factor, steps = (2, 2) if scale == 4 else (scale, 1)
    ps_features = 3 * (factor ** 2)
    for _ in range(steps):
        x = slim.conv2d(x, ps_features, [3, 3], activation_fn=activation)
        x = PS(x, factor, color=True)
    return x
"""
Borrowed from https://github.com/tetrachrome/subpixel
Used for subpixel phase shifting after deconv operations
"""
def _phase_shift(I, r):
    """Rearrange the ``r * r`` channels of ``I`` into an r-times larger grid.

    Borrowed from https://github.com/tetrachrome/subpixel.

    Args:
        I: 4-D tensor. Its last dimension is reshaped to ``(r, r)``, so it
            has to hold ``r * r`` values. Dimensions 1 and 2 are read from
            the static shape and used as split counts -- presumably they
            must be known when the graph is built (TODO confirm).
        r: upscaling factor applied to dimensions 1 and 2.

    Returns:
        A tensor of shape ``(bsize, a * r, b * r, 1)``, where ``a`` and ``b``
        are dimensions 1 and 2 of ``I``.
    """
    # Only the two middle static sizes are used; the batch size is taken
    # dynamically below and the channel count is implied by `r`.
    _, a, b, _ = I.get_shape().as_list()
    bsize = tf.shape(I)[0]  # Handling Dimension(None) type for undefined batch dim
    X = tf.reshape(I, (bsize, a, b, r, r))
    X = tf.transpose(X, (0, 1, 2, 4, 3))  # bsize, a, b, r, r (last two axes swapped)
    X = tf.split(X, a, 1)  # a, [bsize, 1, b, r, r]
    X = tf.concat([tf.squeeze(x, axis=1) for x in X], 2)  # bsize, b, a*r, r
    X = tf.split(X, b, 1)  # b, [bsize, 1, a*r, r]
    X = tf.concat([tf.squeeze(x, axis=1) for x in X], 2)  # bsize, a*r, b*r
    return tf.reshape(X, (bsize, a*r, b*r, 1))
"""
Borrowed from https://github.com/tetrachrome/subpixel
Used for subpixel phase shifting after deconv operations
"""
def PS(X, r, color=False):
    """Apply ``_phase_shift`` to ``X``, optionally per colour group.

    Args:
        X: tensor to phase-shift.
        r: factor forwarded to ``_phase_shift``.
        color: when true, ``X`` is split into 3 equal parts along axis 3,
            each part is phase-shifted on its own, and the results are
            concatenated along axis 3. When false, ``X`` is phase-shifted
            as a whole.

    Returns:
        The phase-shifted tensor.
    """
    if not color:
        return _phase_shift(X, r)

    channel_groups = tf.split(X, 3, 3)
    shifted = [_phase_shift(group, r) for group in channel_groups]
    return tf.concat(shifted, 3)
解决方案
推荐阅读
- reactjs - `className` 属性(prop)和新的 MUI 系统实用工具 `sx` 属性(prop)之间的区别?
- flutter - 在 Flutter 中加载设计
- firebase - 如何在 React Native 中仅从 Firebase 实时数据库获取特定字段的数据
- flutter - 在 dart 中将秒转换为 mm:ss
- selenium-webdriver - 无论我给出什么代码,只有 google.com 才能打开 chrome
- flatbuffers - 为什么 flatbuffer 结构字段不能是向量/表/字符串?
- c# - 如何从命令行运行已编译的 Blazor Web 客户端?
- excel - 按顺序创建包含内容的电子邮件:文本、图像、文本、图像、文本、签名
- c++ - 为什么指针在循环的所有迭代中一直指向相同的地址?
- stored-procedures - Snowflake 存储过程动态列透视(pivot)