Comparison of TensorFlow and Keras implementations (Part 1: Model)

Problem description

I am trying to rewrite a TensorFlow network in Keras. The model in TensorFlow is defined as follows:

import tensorflow as tf

def xavier_init(size):
    # Normal initializer scaled by fan-in: stddev = sqrt(2 / in_dim)
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)

def leaky_relu(x, alpha=0.2):
    # leaky_relu(x) = x for x > 0, alpha * x otherwise
    return tf.nn.relu(x) - alpha * tf.nn.relu(-x)

# Input: batches of 9x15 feature matrices (flattened to 135 inside fcn)
X = tf.placeholder(tf.float32, shape=[None, 9, 15])

# First branch: 135 -> 128 -> 256 -> 512 -> 45
W1 = tf.Variable(xavier_init([135, 128]))
b1 = tf.Variable(tf.zeros(shape=[128]))
W11 = tf.Variable(xavier_init([128, 256]))
b11 = tf.Variable(tf.zeros(shape=[256]))
W12 = tf.Variable(xavier_init([256, 512]))
b12 = tf.Variable(tf.zeros(shape=[512]))
W13 = tf.Variable(xavier_init([512, 45]))
b13 = tf.Variable(tf.zeros(shape=[45]))

# Second branch: 135 -> 128 -> 256 -> 512 -> 540
W2 = tf.Variable(xavier_init([135, 128]))
b2 = tf.Variable(tf.zeros(shape=[128]))
W21 = tf.Variable(xavier_init([128, 256]))
b21 = tf.Variable(tf.zeros(shape=[256]))
W22 = tf.Variable(xavier_init([256, 512]))
b22 = tf.Variable(tf.zeros(shape=[512]))
W23 = tf.Variable(xavier_init([512, 540]))
b23 = tf.Variable(tf.zeros(shape=[540]))

def fcn(x):
    # Branch 1: flatten (9, 15) -> 135, four dense layers, reshape to (9, 5)
    out1 = tf.reshape(x, (-1, 135))
    out1 = leaky_relu(tf.matmul(out1, W1) + b1)
    out1 = leaky_relu(tf.matmul(out1, W11) + b11)
    out1 = leaky_relu(tf.matmul(out1, W12) + b12)
    out1 = leaky_relu(tf.matmul(out1, W13) + b13)
    out1 = tf.reshape(out1, (-1, 9, 5))

    # Branch 2: same trunk, then a batched matmul of the output against itself
    out2 = tf.reshape(x, (-1, 135))
    out2 = leaky_relu(tf.matmul(out2, W2) + b2)
    out2 = leaky_relu(tf.matmul(out2, W21) + b21)
    out2 = leaky_relu(tf.matmul(out2, W22) + b22)
    out2 = leaky_relu(tf.matmul(out2, W23) + b23)
    out2 = tf.reshape(out2, [-1, 9, 4, 15])
    # (batch, 4, 9, 15) x (batch, 4, 15, 9) -> (batch, 4, 9, 9);
    # tf.matmul treats the leading axes as batch dimensions
    out2 = leaky_relu(tf.matmul(tf.transpose(out2, perm=[0, 2, 1, 3]), tf.transpose(out2, perm=[0, 2, 3, 1])))
    out2 = tf.transpose(out2, perm=[0, 2, 3, 1])  # -> (batch, 9, 9, 4)
    return [out1, out2]
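
For context, here is a minimal sketch of how the graph above can be run under TF 1.x; the session setup and the random batch are my assumptions, not part of the original code:

out1_op, out2_op = fcn(X)

import numpy as np

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.rand(2, 9, 15).astype(np.float32)  # hypothetical input batch
    o1, o2 = sess.run([out1_op, out2_op], feed_dict={X: batch})
    print(o1.shape)  # (2, 9, 5)
    print(o2.shape)  # (2, 9, 9, 4)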

I have "translated" this, and here is my Keras implementation:

from keras.models import Model
from keras.layers import Input, Dense, LeakyReLU, Reshape, Lambda, Flatten
from keras import backend as K

def identity(x):
    # Pass-through, used only to give the two outputs stable names
    return x

def keras_version():
    input = Input(shape=(135,), name='feature_input')
    out1 = Dense(128, kernel_initializer='glorot_normal', activation='linear')(input)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Dense(256, kernel_initializer='glorot_normal', activation='linear')(out1)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Dense(512, kernel_initializer='glorot_normal', activation='linear')(out1)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Dense(45, kernel_initializer='glorot_normal', activation='linear')(out1)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Reshape((9, 5))(out1)

    out2 = Dense(128, kernel_initializer='glorot_normal', activation='linear')(input)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Dense(256, kernel_initializer='glorot_normal', activation='linear')(out2)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Dense(512, kernel_initializer='glorot_normal', activation='linear')(out2)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Dense(540, kernel_initializer='glorot_normal', activation='linear')(out2)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Reshape((9, 4, 15))(out2)
    out2 = Lambda(lambda x: K.dot(K.permute_dimensions(x, (0, 2, 1, 3)), K.permute_dimensions(x, (0, 2, 3, 1))), output_shape=(4,9,9))(out2)
    out2 = Flatten()(out2)
    out2 = Dense(324, kernel_initializer='glorot_normal', activation='linear')(out2)
    # K.dot should be of size (-1, 4, 9, 9), so I set the output size to 324, and later on, reshaped data
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Reshape((4, 9, 9))(out2)
    out2 = Lambda(lambda x: K.permute_dimensions(x, (0, 2, 3, 1)))(out2)

    out1 = Lambda(identity, name='output_1')(out1)
    out2 = Lambda(identity, name='output_2')(out2)

    return Model(input, [out1, out2])
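
A minimal sketch of how this model can be instantiated and inspected (assuming standalone Keras on the TensorFlow backend; the compile settings are placeholders I chose, not from the question):

model = keras_version()
model.summary()

# Hypothetical training setup; the optimizer and per-output losses are assumptions
model.compile(optimizer='adam',
              loss={'output_1': 'mse', 'output_2': 'mse'})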

I would like to know whether this implementation is correct, in particular:

  1. The way the layer dimensions are defined.
  2. The way the weights are initialized.
  3. The way the result of the matrix multiplication is flattened and reshaped (see the tf.matmul shape sketch below).

I would appreciate it if you could point out anything that is implemented incorrectly or that I have not understood correctly.
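
For reference on point 3: tf.matmul on tensors of rank 3 or higher multiplies the last two axes and treats all leading axes as batch dimensions, which is what the out2 branch of the TensorFlow model relies on. A minimal shape-check sketch (the random data is an assumption, purely illustrative):

import numpy as np
import tensorflow as tf

a = tf.constant(np.random.rand(2, 4, 9, 15), dtype=tf.float32)
b = tf.transpose(a, perm=[0, 1, 3, 2])  # (2, 4, 15, 9)
c = tf.matmul(a, b)                     # multiplies the last two axes, batches the rest
print(c.shape)                          # (2, 4, 9, 9)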


Edit: here is the model summary:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
feature_input (InputLayer)      (None, 135)          0                                            
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 128)          17408       feature_input[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)       (None, 128)          0           dense_5[0][0]                    
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 256)          33024       leaky_re_lu_5[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU)       (None, 256)          0           dense_6[0][0]                    
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 512)          131584      leaky_re_lu_6[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU)       (None, 512)          0           dense_7[0][0]                    
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 128)          17408       feature_input[0][0]              
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 540)          277020      leaky_re_lu_7[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 128)          0           dense_1[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU)       (None, 540)          0           dense_8[0][0]                    
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 256)          33024       leaky_re_lu_1[0][0]              
__________________________________________________________________________________________________
reshape_2 (Reshape)             (None, 9, 4, 15)     0           leaky_re_lu_8[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)       (None, 256)          0           dense_2[0][0]                    
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 4, 9, 9)      0           reshape_2[0][0]                  
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 512)          131584      leaky_re_lu_2[0][0]              
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 324)          0           lambda_1[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU)       (None, 512)          0           dense_3[0][0]                    
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 324)          105300      flatten_1[0][0]                  
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 45)           23085       leaky_re_lu_3[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU)       (None, 324)          0           dense_9[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU)       (None, 45)           0           dense_4[0][0]                    
__________________________________________________________________________________________________
reshape_3 (Reshape)             (None, 4, 9, 9)      0           leaky_re_lu_9[0][0]              
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 9, 5)         0           leaky_re_lu_4[0][0]              
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 9, 9, 4)      0           reshape_3[0][0]                  
__________________________________________________________________________________________________
output_1 (Lambda)               (None, 9, 5)         0           reshape_1[0][0]                  
__________________________________________________________________________________________________
output_2 (Lambda)               (None, 9, 9, 4)      0           lambda_2[0][0]                   
==================================================================================================
Total params: 769,437
Trainable params: 769,437
Non-trainable params: 0
__________________________________________________________________________________________________
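
As a sanity check on the table, the Param # column follows the usual Dense rule in_dim * out_dim + out_dim (kernel plus bias); the LeakyReLU, Reshape, Lambda, and Flatten layers have no weights:

# Dense parameter counts from the summary above
assert 135 * 128 + 128 == 17408   # dense_1 / dense_5
assert 512 * 540 + 540 == 277020  # dense_8
assert 324 * 324 + 324 == 105300  # dense_9 (the extra Dense added after Flatten)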

Tags: python, tensorflow, keras
