tensorflow - 分别构建在 tensorflow API 和 keras API 上的两个代码块有什么不同?我的计算结果差距很大
问题描述
我正在建立一个模型来对序列类进行分类。首先,我使用 keras API 构建模型。众所周知,keras API 打包了 tensorflow 函数,但是当我将 keras 代码转换为 tensorflow API 时,我发现两个框架的结果是不同的。下面是关键代码。
张量流代码x = tf.placeholder(tf.int32, shape=[None, time_steps], name='x_input')
y = tf.placeholder(tf.float32, shape=[None, num_classes], name='y_label')
定义网络结构
def rnn_model(x):
x = tf.one_hot(x,api_vob_size)
rnn_cell_fw = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
rnn_cell_bw = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
# 将输入送入rnn,得到输出与中间状态,输出shape为[batch_size, time_steps, rnn_size]
outputs, states = tf.nn.bidirectional_dynamic_rnn(rnn_cell_fw,rnn_cell_bw, x, dtype=tf.float32)
# 获取最后一个时刻的输出,输出shape为[batch_size, rnn_size]
outputs1 = tf.concat(outputs, 2)
output = tf.transpose(outputs1, [1, 0, 2])[-1]
# 全连接层,最终输出大小为[batch_size, num_classes]
fc_w = tf.Variable(tf.random_normal([2*rnn_size, num_classes]))
fc_b = tf.Variable(tf.random_normal([num_classes]))
return tf.matmul(output, fc_w) + fc_b `
# 构建网络
logits= rnn_model(x)
prediction = tf.nn.softmax(logits)
# 定义损失函数与优化器
loss_op = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits, name='cross_entropy'))
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
train_op = optimizer.minimize(loss_op,name='optimizer_min')
#keras API
model = Sequential()
model.add(Bidirectional(LSTM(units=150), merge_mode='concat'))
model.add(Dense(9, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, batch_size=64)
那么为什么两个代码块有不同的结果。谢谢你的答案 !!!!
解决方案
推荐阅读
- python - Python PATH 无法识别附加到它的目录
- ethereum - 什么时候需要将 ERC-1155 元数据 URI 补零到 64 个十六进制字符?
- android-studio - 如何减小 Google Play Console 的应用程序大小(我正在使用 Android Studio)
- google-cloud-platform - 缺少权限:尝试从 AI Platform (Google Cloud) 导出模型时的 storage.objects.update
- django - How to Count the number of replies from a comment in a given photo using Django Query?
- php - Symfony PHPUnit问题安装Xdebug for --coverage-html
- java - Java 安全性:在没有 pin 弹出窗口的情况下签名
- r - 将 sum 与 group_by 一起使用时行消失
- python - 什么是最好的python正则表达式,只排除一对大括号之间的一个逗号实例?
- node.js - PostgreSQL 的 TypeORM 性能问题