python - Tensorflow:tape.gradient() 为 GRU 层返回 None
问题描述
我使用以下代码(tensorflow==1.14)构建我的模型:
class Model(tf.keras.Model):
def __init__(self):
super(Model, self).__init__()
self.embedding = tf.keras.layers.Embedding(10, 5)
self.rnn = tf.keras.layers.GRU(100) # neither GRU nor LSTM works
self.final_layer = tf.keras.layers.Dense(10)
self.loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
def call(self, inp):
inp_em = self.embedding(inp) # (batch_size, seq_len, embedding_size)
inp_enc = self.rnn(inp_em) # (batch_size, hidden_size)
logits = self.final_layer(inp_enc) # (batch_size, class_num)
return logits
model = Model()
inp = np.random.randint(0, 10, [5, 50], dtype=np.int32)
out = np.random.randint(0, 10, [5], dtype=np.int32)
with tf.GradientTape() as tape:
logits = model(inp)
loss = model.loss_obj(out, logits)
print(loss)
gradients = tape.gradient(tf.reduce_mean(loss), model.trainable_variables)
print('========== Trainable Variables ==========')
for v in model.trainable_variables:
print(v)
print('========== Gradients ==========')
for g in gradients:
print(g)
但是当我打印网格时,输出是:
Tensor("categorical_crossentropy/weighted_loss/Mul:0", shape=(5,), dtype=float32)
========== Trainable Variables ==========
<tf.Variable 'model/embedding/embeddings:0' shape=(10, 5) dtype=float32>
<tf.Variable 'model/gru/kernel:0' shape=(5, 300) dtype=float32>
<tf.Variable 'model/gru/recurrent_kernel:0' shape=(100, 300) dtype=float32>
<tf.Variable 'model/gru/bias:0' shape=(300,) dtype=float32>
<tf.Variable 'model/dense/kernel:0' shape=(100, 10) dtype=float32>
<tf.Variable 'model/dense/bias:0' shape=(10,) dtype=float32>
========== Gradients ==========
None
None
None
None
Tensor("MatMul:0", shape=(100, 10), dtype=float32)
Tensor("BiasAddGrad:0", shape=(10,), dtype=float32)
最后一层的网格效果很好,但对于 GRU 层等没有。
tf.keras.layers.LSTM
和都试过了tf.keras.layers.GRU
,同样的问题存在。
更新
最后,我替换tf.GradientTape().gradient()
为tf.graidents()
:
logits = model(inp)
loss = model.loss_obj(out, logits)
gradients = tf.gradients(tf.reduce_mean(loss), model.trainable_variables)
渐变有效。但是我仍然不知道这两个工具之间有什么区别。
解决方案
推荐阅读
- java - 在与外部库相同的包中创建一个类,以便使用包私有类
- reactjs - 添加带有反应头盔的 twitter:card
- javascript - 如何在 React 上下文 API 中调用与调用函数并行的函数
- sql - 带有命名空间的 PostgresSQL xpath
- python - 用 NaN 绘图。如何将 NaN 值设置为特定颜色和/或从热图中跳过 NaN
- ios - 通过 Azure Devops Extension 上传到 AppStore 时如何修复“Invalid Provisioning Profile”?
- regex - 如何详细说明接受除空格和特定单词以外的任何字符序列的正则表达式
- html - 在 HTML div 中显示大型 HTML 字符串的最佳方式
- excel - 选择要从 Google 工作表导入 Excel 的页面
- pact - 使用 pact-jvm 生成的合约可以通过 pact-net 或 pact-ruby 进行验证吗?