python - Comparison of Tensorflow and Keras implementations (Part 1: the model)
Problem Description
I am trying to rewrite a Tensorflow network in Keras. The model in Tensorflow is defined as
def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)

def leaky_relu(x, alpha=0.2):
    return tf.nn.relu(x) - alpha * tf.nn.relu(-x)

X = tf.placeholder(tf.float32, shape=[None, 9, 15])

W1 = tf.Variable(xavier_init([135, 128]))
b1 = tf.Variable(tf.zeros(shape=[128]))
W11 = tf.Variable(xavier_init([128, 256]))
b11 = tf.Variable(tf.zeros(shape=[256]))
W12 = tf.Variable(xavier_init([256, 512]))
b12 = tf.Variable(tf.zeros(shape=[512]))
W13 = tf.Variable(xavier_init([512, 45]))
b13 = tf.Variable(tf.zeros(shape=[45]))

W2 = tf.Variable(xavier_init([135, 128]))
b2 = tf.Variable(tf.zeros(shape=[128]))
W21 = tf.Variable(xavier_init([128, 256]))
b21 = tf.Variable(tf.zeros(shape=[256]))
W22 = tf.Variable(xavier_init([256, 512]))
b22 = tf.Variable(tf.zeros(shape=[512]))
W23 = tf.Variable(xavier_init([512, 540]))
b23 = tf.Variable(tf.zeros(shape=[540]))

def fcn(x):
    out1 = tf.reshape(x, (-1, 135))
    out1 = leaky_relu(tf.matmul(out1, W1) + b1)
    out1 = leaky_relu(tf.matmul(out1, W11) + b11)
    out1 = leaky_relu(tf.matmul(out1, W12) + b12)
    out1 = leaky_relu(tf.matmul(out1, W13) + b13)
    out1 = tf.reshape(out1, (-1, 9, 5))

    out2 = tf.reshape(x, (-1, 135))
    out2 = leaky_relu(tf.matmul(out2, W2) + b2)
    out2 = leaky_relu(tf.matmul(out2, W21) + b21)
    out2 = leaky_relu(tf.matmul(out2, W22) + b22)
    out2 = leaky_relu(tf.matmul(out2, W23) + b23)
    out2 = tf.reshape(out2, [-1, 9, 4, 15])
    out2 = leaky_relu(tf.matmul(tf.transpose(out2, perm=[0, 2, 1, 3]), tf.transpose(out2, perm=[0, 2, 3, 1])))
    out2 = tf.transpose(out2, perm=[0, 2, 3, 1])

    return [out1, out2]
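To make the shape flow of fcn concrete, here is a numpy-only sketch of the same computation (biases omitted for brevity since they are zero-initialized, and weights drawn with the stddev sqrt(2/fan_in) that xavier_init produces). It is only a shape check, not a substitute for the TF graph:

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    # Same identity as the TF helper: relu(x) - alpha * relu(-x)
    return np.maximum(x, 0) - alpha * np.maximum(-x, 0)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 9, 15))  # a batch of 2 inputs of shape (9, 15)

# Branch 1: 135 -> 128 -> 256 -> 512 -> 45, reshaped to (batch, 9, 5)
out1 = x.reshape(-1, 135)
for d_in, d_out in [(135, 128), (128, 256), (256, 512), (512, 45)]:
    W = rng.normal(scale=np.sqrt(2.0 / d_in), size=(d_in, d_out))
    out1 = leaky_relu(out1 @ W)
out1 = out1.reshape(-1, 9, 5)

# Branch 2: 135 -> 128 -> 256 -> 512 -> 540, reshaped to (batch, 9, 4, 15)
out2 = x.reshape(-1, 135)
for d_in, d_out in [(135, 128), (128, 256), (256, 512), (512, 540)]:
    W = rng.normal(scale=np.sqrt(2.0 / d_in), size=(d_in, d_out))
    out2 = leaky_relu(out2 @ W)
out2 = out2.reshape(-1, 9, 4, 15)
# matmul(transpose(0,2,1,3), transpose(0,2,3,1)) -> (batch, 4, 9, 9)
out2 = leaky_relu(np.transpose(out2, (0, 2, 1, 3)) @ np.transpose(out2, (0, 2, 3, 1)))
out2 = np.transpose(out2, (0, 2, 3, 1))  # final transpose -> (batch, 9, 9, 4)

print(out1.shape, out2.shape)  # (2, 9, 5) (2, 9, 9, 4)
```

So the two outputs the Keras port has to reproduce are (None, 9, 5) and (None, 9, 9, 4).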
I have "translated" this, and here is my Keras implementation:
def keras_version():
    input = Input(shape=(135,), name='feature_input')

    out1 = Dense(128, kernel_initializer='glorot_normal', activation='linear')(input)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Dense(256, kernel_initializer='glorot_normal', activation='linear')(out1)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Dense(512, kernel_initializer='glorot_normal', activation='linear')(out1)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Dense(45, kernel_initializer='glorot_normal', activation='linear')(out1)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Reshape((9, 5))(out1)

    out2 = Dense(128, kernel_initializer='glorot_normal', activation='linear')(input)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Dense(256, kernel_initializer='glorot_normal', activation='linear')(out2)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Dense(512, kernel_initializer='glorot_normal', activation='linear')(out2)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Dense(540, kernel_initializer='glorot_normal', activation='linear')(out2)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Reshape((9, 4, 15))(out2)
    out2 = Lambda(lambda x: K.dot(K.permute_dimensions(x, (0, 2, 1, 3)), K.permute_dimensions(x, (0, 2, 3, 1))), output_shape=(4, 9, 9))(out2)
    out2 = Flatten()(out2)
    out2 = Dense(324, kernel_initializer='glorot_normal', activation='linear')(out2)
    # K.dot should be of size (-1, 4, 9, 9), so I set the output size to 324, and later on, reshaped data
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Reshape((4, 9, 9))(out2)
    out2 = Lambda(lambda x: K.permute_dimensions(x, (0, 2, 3, 1)))(out2)

    out1 = Lambda(identity, name='output_1')(out1)
    out2 = Lambda(identity, name='output_2')(out2)

    return Model(input, [out1, out2])
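One point of reference for the Lambda step: in the TF version the transposed product is a single batched tf.matmul with no trainable parameters, whereas the Keras version above inserts an extra Flatten/Dense(324)/Reshape after it. np.matmul (like tf.matmul) already maps over the leading batch dimensions, as this numpy sketch of the intended product shows; the sketch only demonstrates the shape semantics, not the full layer:

```python
import numpy as np

rng = np.random.default_rng(1)
t = rng.normal(size=(2, 9, 4, 15))  # stand-in for the (batch, 9, 4, 15) activation

a = np.transpose(t, (0, 2, 1, 3))   # (2, 4, 9, 15)
b = np.transpose(t, (0, 2, 3, 1))   # (2, 4, 15, 9)
prod = a @ b                        # batched matmul over leading dims -> (2, 4, 9, 9)

# The same product written out explicitly, to show what "batched" means here:
expected = np.einsum('bcij,bcjk->bcik', a, b)
assert np.allclose(prod, expected)
print(prod.shape)  # (2, 4, 9, 9)
```

If the goal is to mirror the TF branch exactly, one option would be a Lambda wrapping tf.matmul (or K.batch_dot) of the two permuted tensors, which yields (None, 4, 9, 9) directly without the Flatten/Dense detour.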
I would like to know whether this implementation is correct, in particular:
- the way the layer dimensions are defined;
- the way the weights are initialized;
- the way the matrix multiplication is flattened and reshaped.
I would appreciate it if you could point out any incorrect implementation or anything I have misunderstood.
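On the initialization point, one numeric fact worth checking: the TF helper's stddev 1. / sqrt(in_dim / 2.) equals sqrt(2 / fan_in), which is the stddev that Keras's he_normal initializer uses, while glorot_normal uses sqrt(2 / (fan_in + fan_out)). A quick check for the first layer's dimensions:

```python
import numpy as np

fan_in, fan_out = 135, 128  # first Dense layer of either branch

# stddev used by the TF xavier_init helper: 1/sqrt(fan_in/2) = sqrt(2/fan_in)
tf_stddev = 1.0 / np.sqrt(fan_in / 2.0)

# he_normal stddev: sqrt(2/fan_in); glorot_normal stddev: sqrt(2/(fan_in+fan_out))
he_stddev = np.sqrt(2.0 / fan_in)
glorot_stddev = np.sqrt(2.0 / (fan_in + fan_out))

print(tf_stddev, he_stddev, glorot_stddev)
assert np.isclose(tf_stddev, he_stddev)        # exact match
assert not np.isclose(tf_stddev, glorot_stddev)
```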
Edit: here is the summary:
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
feature_input (InputLayer)      (None, 135)          0
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 128)          17408       feature_input[0][0]
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)       (None, 128)          0           dense_5[0][0]
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 256)          33024       leaky_re_lu_5[0][0]
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU)       (None, 256)          0           dense_6[0][0]
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 512)          131584      leaky_re_lu_6[0][0]
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU)       (None, 512)          0           dense_7[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 128)          17408       feature_input[0][0]
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 540)          277020      leaky_re_lu_7[0][0]
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 128)          0           dense_1[0][0]
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU)       (None, 540)          0           dense_8[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 256)          33024       leaky_re_lu_1[0][0]
__________________________________________________________________________________________________
reshape_2 (Reshape)             (None, 9, 4, 15)     0           leaky_re_lu_8[0][0]
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)       (None, 256)          0           dense_2[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 4, 9, 9)      0           reshape_2[0][0]
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 512)          131584      leaky_re_lu_2[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 324)          0           lambda_1[0][0]
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU)       (None, 512)          0           dense_3[0][0]
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 324)          105300      flatten_1[0][0]
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 45)           23085       leaky_re_lu_3[0][0]
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU)       (None, 324)          0           dense_9[0][0]
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU)       (None, 45)           0           dense_4[0][0]
__________________________________________________________________________________________________
reshape_3 (Reshape)             (None, 4, 9, 9)      0           leaky_re_lu_9[0][0]
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 9, 5)         0           leaky_re_lu_4[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 9, 9, 4)      0           reshape_3[0][0]
__________________________________________________________________________________________________
output_1 (Lambda)               (None, 9, 5)         0           reshape_1[0][0]
__________________________________________________________________________________________________
output_2 (Lambda)               (None, 9, 9, 4)      0           lambda_2[0][0]
==================================================================================================
Total params: 769,437
Trainable params: 769,437
Non-trainable params: 0
__________________________________________________________________________________________________
Solution