tensorflow - 如何通过 Keras 1 和 Keras 2 使用 Conv1D 获得相同的结果
问题描述
我首先使用 keras 1.2.0 在 CPU 上运行相同的代码(具有相同的数据),然后在两个代码中使用 keras 2.0.3 keras 与 TensorFlow 后端。问题来自 Conv1D,许多参数发生变化,我想从 keras 1.2.0 重现相同的 Conv1D。因为我与 keras 2 没有相同的结果。
这是我在 Keras 1 上的代码:
def core_model_CNN(sequence_input
,sequence_length
,vocabulary_size
,n_out
,embedding_dim
,embedding_matrix
,filter_sizes = [1,2,3]
,num_filters = 100
,drop = 0.1) :
embedding = Embedding(input_dim=vocabulary_size, output_dim=embedding_dim
, input_length=sequence_length,weights=[embedding_matrix],trainable=False)
embedded_sequences = embedding(sequence_input)
filter_sizes = filter_sizes
convs = []
for fsz in filter_sizes:
conv = Conv1D(nb_filter=32,
filter_length=fsz,
border_mode='valid',
activation='relu',
subsample_length=1)(embedded_sequences)
pool = MaxPooling1D(pool_length=sequence_length-fsz+1)(conv)
flattenMax = Flatten()(pool)
convs.append(flattenMax)
l_merge = concatenate(convs, axis=1)
#flatten = Flatten()(l_merge)
dense1= Dense(300,activation='swish')(l_merge)
dense1=BatchNormalization()(dense1)
dense1= Dense(250,activation='swish')(dense1)
dense1=BatchNormalization()(dense1)
dense1= Dense(200,activation='swish')(dense1)
dense1=BatchNormalization()(dense1)
dense1= Dense(150,activation='swish')(dense1)
dense1=BatchNormalization()(dense1)
dense1= Dense(100,activation='swish')(dense1)
dense1=BatchNormalization()(dense1)
output = Dense(units=n_out, activation='softmax',kernel_regularizer=regularizers.l2(2),)(dense1)
return output
在 Keras 2 上它变成
conv = tf.keras.layers.Conv1D(filters=params['nb_filter'],
kernel_size=fsz,
padding='valid',
activation='relu',
bias_initializer='zeros',
strides=1)(embedded_sequences)
pool = tf.keras.layers.MaxPooling1D(pool_size=sequence_length-fsz+1)(conv)
我按照此链接进行了修改,但我没有相同的结果。
解决方案
推荐阅读
- java - 从groovy中的JSON数组中提取特定数据
- oracle - SQL获取列的总数或总和
- c++ - 无法访问公共功能?没有成员命名?
- python - 如何从每行中的字符串中提取年份并使用这些年份生成新行
- cookies - 与 PC 上的所有用户共享 Win10 IE 11 cookie
- r - 我需要帮助弄清楚为什么我的正则表达式与我正在寻找的内容不匹配
- javascript - 如何断言复制到剪贴板的值是正确的?
- amazon-web-services - 为什么将我的云形成设计器保存到 S3 存储桶会挂起?
- matlab - S-Function Level 2 C Simulink (R2006b) , 从参数中获取字符串
- javascript - 具有高阶函数的递归