python - 为什么在我训练我的 RNN 网络以解决分类问题后输出总是相同的?
问题描述
神经网络由一个 LSTM 层、3 个以 relu 为激活函数的全连接层和一个具有 sigmoid 激活函数的输出层组成。我的输入数据的形状是(batch_size、time_step、hidden_units),网络的输入数据是多种多样的,而输出几乎总是相同的(差异很小)。我不知道问题出在哪里。LSTM网络如下:
class RNN_eval(object):
def __init__(
self,
cname,
n_steps,
input_size,
learning_rate,
full1_neurons,
full2_neurons,
full3_neurons,
output_size
):
self.name = cname
self.n_steps = n_steps
self.input_size = input_size
self.lr = learning_rate
self.full1_neurons = full1_neurons
self.full2_neurons = full2_neurons
self.full3_neurons = full3_neurons
self.output_size = output_size
self.w_initializer = tf.random_normal_initializer(0., 0.3)
self.b_initializer = tf.constant_initializer(0.1)
with tf.name_scope('eval_inputs'):
self.s = tf.placeholder(tf.float32, [None, self.n_steps, self.input_size], name='input')
self.q_target = tf.placeholder(tf.float32, [None, output_size], name='q_target')
self.batch_size = tf.placeholder(tf.int32, [], name='batch_size')
with tf.name_scope('eval_LSTM_cell'):
self.add_cell()
with tf.name_scope('eval_hidden_layers'):
self.add_fullconnect_layer1()
self.add_fullconnect_layer2()
self.add_fullconnect_layer3()
self.add_output_layer()
with tf.variable_scope('loss'):
self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval)) # 基于Q估计与Q现实,构造loss-function
with tf.variable_scope('train'):
self._train_op = tf.train.AdamOptimizer(self.lr).minimize(self.loss) # 进行训练
def add_cell(self):
with tf.name_scope('eval_lstm'):
eval_lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units=self.input_size, state_is_tuple=True, name='eval_lstm')
with tf.name_scope('eval_initial_state'):
self.cell_init_state = eval_lstm_cell.zero_state(self.batch_size, dtype=tf.float32)
self.cell_outputs, self.cell_final_state = tf.nn.dynamic_rnn(
eval_lstm_cell, self.s, dtype=tf.float32, time_major=False
)
def add_fullconnect_layer1(self):
l_out_x = tf.unstack(tf.transpose(self.cell_outputs, [1, 0, 2]))
w1 = tf.get_variable('w1', [self.input_size, self.full1_neurons], initializer=self.w_initializer, collections=self.name)
b1 = tf.get_variable('b1', [self.full1_neurons, ], initializer=self.b_initializer, collections=self.name)
with tf.name_scope('eval_full_connected1'):
self.full_out1 = tf.nn.relu(tf.matmul(l_out_x[-1], w1) + b1)
def add_fullconnect_layer2(self):
w2 = tf.get_variable('w2', [self.full1_neurons, self.full2_neurons], initializer=self.w_initializer,collections=self.name)
b2 = tf.get_variable('b2', [self.full2_neurons, ], initializer=self.b_initializer, collections=self.name)
with tf.name_scope('eval_full_connected2'):
self.full_out2 = tf.nn.relu(tf.matmul(self.full_out1, w2) + b2)
def add_fullconnect_layer3(self):
w3 = tf.get_variable('w3', [self.full2_neurons, self.full3_neurons], initializer=self.w_initializer, collections=self.name)
b3 = tf.get_variable('b3', [self.full3_neurons, ], initializer=self.b_initializer, collections=self.name)
with tf.name_scope('eval_full_connected3'):
self.full_out3 = tf.nn.relu(tf.matmul(self.full_out2, w3) + b3)
def add_output_layer(self):
w_out = tf.get_variable('w_out', [self.full3_neurons, self.output_size], initializer=self.w_initializer, collections=self.name)
b_out = tf.get_variable('b_out', [self.output_size, ], initializer=self.b_initializer, collections=self.name)
with tf.name_scope('eval_output'):
self.q_eval = tf.nn.sigmoid(tf.matmul(self.full_out3, w_out) + b_out, name="eval_op")
训练部分如下:(q_target可以看成是输入数据的标签q_eval是前向传播的结果)
if self.learn_step_counter == 0:
feed_dic = {
self.eval_net.s: batch_memory[:, :, :self.n_features],
self.eval_net.q_target: q_target,
self.eval_net.batch_size: 32,
self.eval_net.q_eval: q_eval
}
else:
feed_dic = {
self.eval_net.s: batch_memory[:, :, :self.n_features],
self.eval_net.q_target: q_target,
self.eval_net.batch_size: 32,
self.eval_net.q_eval: q_eval,
self.eval_net.cell_init_state: self.last_state
}
_, self.cost, self.last_state = self.sess.run([self.eval_net._train_op, self.eval_net.loss, self.eval_net.cell_final_state],
feed_dict=feed_dic)
self.cost_his.append(self.cost) # 反向训练
解决方案
推荐阅读
- javascript - 解决 Cloud Functions 中的“TypeError:无法读取未定义的属性‘数据’”
- jquery - JQuery 到 Angular-Typescript
- unit-testing - 在 buildSrc 目录中运行测试用例的问题
- github - 为什么 GitHub 检查不反映 Azure Pipelines 构建状态?
- sql - SQL Server Where 子句性能
- flutter - 当我双击flutter_console.bat时,它是在一秒钟内打开和关闭自己?
- android - 如何使 SwipeRefreshLayout wrap_content?
- spring - 可以将 Spring 指标从 Micrometer 导出到 Kafka 吗?
- javascript - 如何在使用 jquery/javascript 时打破行尾?