python - 为什么tensorflow中的这个RNN不学习?
问题描述
我正在尝试在 Python 3.7 的 tensorflow (2) 中不使用 RNN API 来训练 RNN,因此代码非常基础。确实出了点问题,但我不确定是什么。
作为参考,我使用的是来自这个 tensorflow 教程的数据集,所以我知道错误应该大致收敛到什么程度。我的 RNN 代码如下。它试图做的是使用前 20 个时间步来预测第 21 个时间步的序列值。我正在批量训练 256 号。
虽然随着时间的推移损失会减少,但如果我遵循教程方法,上限大约是 10 倍。随着时间的推移,反向传播会不会有问题?
state_size = 20 #dimensionality of the network
BATCH_SIZE = 256
#define recurrent weights and biases. W has 1 more dimension that the state
#dimension as also processes the inputs
W = tf.Variable(np.random.rand(state_size+1, state_size), dtype=tf.float32)
b = tf.Variable(np.zeros((1,state_size)), dtype=tf.float32)
#weights and biases for the output
W2 = tf.Variable(np.random.rand(state_size, 1),dtype=tf.float32)
b2 = tf.Variable(np.zeros((1,1)), dtype=tf.float32)
init_state = tf.Variable(np.random.normal(size=[BATCH_SIZE,state_size]),dtype='float32')
optimizer = tf.keras.optimizers.Adam(1e-3)
losses = []
for epoch in range(20):
with tf.GradientTape() as tape:
loss = 0
for batch_idx in range(200):
current_state = init_state
batchx = x_train_uni[batch_idx*BATCH_SIZE:(batch_idx+1)*BATCH_SIZE].swapaxes(0,1)
batchy = y_train_uni[batch_idx*BATCH_SIZE:(batch_idx+1)*BATCH_SIZE]
#forward pass through the timesteps
for x in batchx:
inst = tf.concat([current_state,x],1) #concatenate state and inputs for that timepoint
current_state = tf.tanh(tf.matmul(inst, W) + b) #
#predict using the hidden state after the full forward pass
pred = tf.matmul(current_state,W2) + b2
loss += tf.reduce_mean(tf.abs(batchy-pred))
#get gradients with respect to parameters
gradients = tape.gradient(loss, [W,b,W2,b2])
#apply gradients
optimizer.apply_gradients(zip(gradients, [W,b,W2,b2]))
losses.append(loss)
print(loss)
解决方案
推荐阅读
- sql-server - SQL Server 2016 - 开发者版开发 - 标准版生产
- windows - Julia stdin 错误与 WIndows 中的 Vim 和 AsyncRun
- php - 带有子查询的 Doctrine QueryBuilder
- jenkins - Jenkinsfile 管道访问“全局密码”
- regex - 用于目标的 Google Analytics Reg Ex
- linux-kernel - cgroups blkio 子系统未正确计算容器应用程序的块写入字节数
- javascript - 避免回调地狱和组织 Node.js 代码
- batch-file - 使用批处理脚本打开文件
- android - android studio 在 Viewflipper 中给出 OutOfMemoryError
- java - JPA 工具正在生成具有数据库中“旧”/已删除列的实体