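python - Why is the accuracy of my retrained model so poor?
Problem description
I'm trying to retrain the last layer of a pretrained model on the same dataset (the MNIST handwritten digit dataset), but the retrained model is much less accurate than the initial model. The initial model reaches about 98% accuracy, while the retrained model's accuracy varies between 40% and 80% depending on the run. I get similar results when I don't bother training the first two layers at all.
And the code: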
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
epochs1 = 150
epochs2 = 300
batch_size = 11000
learning_rate1 = 1e-3
learning_rate2 = 1e-4
# Base model
def base_model(input, reuse=False):
    with tf.variable_scope('base_model', reuse=reuse):
        layer1 = tf.contrib.layers.fully_connected(input, 300)
        features = tf.contrib.layers.fully_connected(layer1, 300)
        return features
mnist = input_data.read_data_sets('./mnist/', one_hot=True)
image = tf.placeholder(tf.float32, [None, 784])
label = tf.placeholder(tf.float32, [None, 10])
features1 = base_model(image, reuse=False)
features2 = base_model(image, reuse=True)
# Logits1 trained with the base model
with tf.variable_scope('logits1', reuse=False):
    logits1 = tf.contrib.layers.fully_connected(features1, 10, tf.nn.relu)
# Logits2 trained while the base model is frozen
with tf.variable_scope('logits2', reuse=False):
    logits2 = tf.contrib.layers.fully_connected(features2, 10, tf.nn.relu)
# Var Lists
var_list_partial1 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='logits1')
var_list_partial2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='base_model')
var_list1 = var_list_partial1 + var_list_partial2
var_list2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='logits2')
# Sanity check
print("var_list1:", var_list1)
print("var_list2:", var_list2)
# Cross Entropy Losses
loss1 = tf.nn.softmax_cross_entropy_with_logits(logits=logits1, labels=label)
loss2 = tf.nn.softmax_cross_entropy_with_logits(logits=logits2, labels=label)
# Optimizers: train1 updates base_model + logits1; train2 updates only logits2 (base_model stays frozen)
train1 = tf.train.AdamOptimizer(learning_rate1).minimize(loss1, var_list=var_list1)
train2 = tf.train.AdamOptimizer(learning_rate2).minimize(loss2, var_list=var_list2)
# Accuracy operations
correct_prediction1 = tf.equal(tf.argmax(logits1, 1), tf.argmax(label, 1))
correct_prediction2 = tf.equal(tf.argmax(logits2, 1), tf.argmax(label, 1))
accuracy1 = tf.reduce_mean(tf.cast(correct_prediction1, "float"))
accuracy2 = tf.reduce_mean(tf.cast(correct_prediction2, "float"))
with tf.Session() as sess:
    # global_variables_initializer replaces the deprecated initialize_all_variables
    sess.run(tf.global_variables_initializer())
    batches = int(len(mnist.train.images) / batch_size)
    # Train base model and logits1
    for epoch in range(epochs1):
        for batch in range(batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train1, feed_dict={image: batch_xs, label: batch_ys})
    # Train logits2 keeping the base model frozen
    for epoch in range(epochs2):
        for batch in range(batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train2, feed_dict={image: batch_xs, label: batch_ys})
    # Evaluate both models on the test set after training
    accuracy = sess.run(accuracy1, feed_dict={image: mnist.test.images, label: mnist.test.labels})
    print("Initial Model Accuracy After training final model:", accuracy)
    accuracy = sess.run(accuracy2, feed_dict={image: mnist.test.images, label: mnist.test.labels})
    print("Final Model Accuracy After Training:", accuracy)
Thanks in advance!
Solution
Try removing the nonlinearity from "logits1" and "logits2".
I changed your code to:
# Logits1 trained with the base model
with tf.variable_scope('logits1', reuse=False):
    #logits1 = tf.contrib.layers.fully_connected(features1, 10, tf.nn.relu)
    logits1 = tf.contrib.layers.fully_connected(features1, 10, None)
# Logits2 trained while the base model is frozen
with tf.variable_scope('logits2', reuse=False):
    #logits2 = tf.contrib.layers.fully_connected(features2, 10, tf.nn.relu)
    logits2 = tf.contrib.layers.fully_connected(features2, 10, None)
The results changed to:
Initial Model Accuracy After training final model: 0.9805
Final Model Accuracy After Training: 0.9658
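The answer doesn't say why the ReLU matters, so here is one plausible explanation (an editor's assumption, not from the original answer). `softmax_cross_entropy_with_logits` expects raw, unbounded scores, but a ReLU clamps every negative score to 0 and passes no gradient wherever its input is negative. When the base model is also being trained, earlier layers can shift the features. With the base frozen, only the `logits2` weights can move, so an output unit whose pre-activation starts out negative for most inputs may receive almost no learning signal. Which units end up like that depends on the random initialization, which would be consistent with the 40-80% spread between runs. The following NumPy sketch shows the effect on a single toy 3-class example. `softmax`, `head_grad` and the numbers in `pre` are hypothetical, chosen only for illustration.

```python
import numpy as np

def softmax(z):
    # Subtract the row max for numerical stability
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def head_grad(pre, labels, use_relu):
    """Gradient of mean softmax cross-entropy w.r.t. the final layer's
    pre-activations `pre`, with or without a ReLU applied on top (hypothetical helper)."""
    logits = np.maximum(pre, 0.0) if use_relu else pre
    probs = softmax(logits)
    d_logits = (probs - labels) / pre.shape[0]
    if use_relu:
        # ReLU only passes gradient where the pre-activation is positive
        d_logits = d_logits * (pre > 0)
    return probs, d_logits

# Toy example: all pre-activations are negative, the true class is index 2
pre = np.array([[-1.5, -0.3, -2.0]])
labels = np.array([[0.0, 0.0, 1.0]])

probs_relu, g_relu = head_grad(pre, labels, use_relu=True)
probs_lin, g_lin = head_grad(pre, labels, use_relu=False)

print(probs_relu)  # uniform: every logit was clamped to 0
print(g_relu)      # all zeros: this example gives no learning signal
print(g_lin)       # non-zero: a linear head can still be corrected
```

With `activation_fn=None`, as in the changed code, the last case applies: the head stays linear and every example produces a gradient for all ten output units.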
PS: Also, 300 + 300 neurons is far too many for an MNIST classifier, but I assume you're not actually trying to classify MNIST :)
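For scale, this sketch counts the trainable parameters of the question's 784-300-300-10 network. It assumes each `fully_connected` layer has a weight matrix plus one bias per output unit, which is the `tf.contrib.layers.fully_connected` default. `dense_params` is a hypothetical helper.

```python
def dense_params(n_in, n_out):
    # Weight matrix (n_in x n_out) plus one bias per output unit
    return n_in * n_out + n_out

# base_model layer1, base_model features, logits layer
layers = [(784, 300), (300, 300), (300, 10)]
counts = [dense_params(n_in, n_out) for n_in, n_out in layers]

print(counts)       # [235500, 90300, 3010]
print(sum(counts))  # 328810
```

Only the last entry, 3010 parameters, is updated when `train2` runs with `var_list` restricted to the `logits2` scope.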