tensorflow - 为什么 tf.GradientTape.jacobian 没有给出?
问题描述
我正在使用 IRIS 数据集,并且正在关注这个官方教程:自定义培训:演练
在训练循环中,我试图分别收集epoch%50==0
列表中每个模型的输出和权重m_outputs_mod50, gather_weights
:
# Keep results for plotting
train_loss_results = []
train_accuracy_results = []
m_outputs_mod50 = []
gather_weights = []
num_epochs = 201
for epoch in range(num_epochs):
epoch_loss_avg = tf.keras.metrics.Mean()
epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
# gather_kernel(model)
# Training loop - using batches of 32
for x, y in train_dataset:
# Optimize the model
loss_value, grads = grad(model, x, y)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
# Track progress
epoch_loss_avg.update_state(loss_value) # Add current batch loss
# Compare predicted label to actual label
# training=True is needed only if there are layers with different
# behavior during training versus inference (e.g. Dropout).
epoch_accuracy.update_state(y, model(x, training=True))
# End epoch
train_loss_results.append(epoch_loss_avg.result())
train_accuracy_results.append(epoch_accuracy.result())
# pred_hist.append(model.predict(x))
if epoch % 50 == 0:
m_outputs_mod50.append(model(x))
gather_weights.append(model.weights)
print("Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}".format(epoch,
epoch_loss_avg.result(),
epoch_accuracy.result()))
m_outputs_mod50[0]
运行上述程序并尝试使用 epoch 0 (使用and gather_weights[0]
)获取 jacobian
with tf.GradientTape() as tape:
print(tape.jacobian(target = m_outputs_mod50[0], sources = gather_weights[0]))`
我得到一个无列表作为输出。
为什么?
解决方案
您需要了解 GradientTape 的运作方式。为此,您可以遵循指南:渐变和自动微分简介。这是一段摘录:
TensorFlow 提供了
tf.GradientTape
自动微分的 API;也就是说,计算相对于某些输入的计算梯度,通常是tf.Variables
。TensorFlow 将在 a 上下文中执行的相关操作“记录”tf.GradientTape
到“磁带”上。然后,TensorFlow 使用该磁带来计算使用反向模式微分的“记录”计算的梯度。
要计算梯度(或雅可比),磁带需要记录在其上下文中执行的操作。然后,在其上下文之外,一旦执行了前向传递,就可以使用磁带来计算梯度/雅可比。
你可以使用类似的东西:
if epoch % 50 == 0:
with tf.GradientTape() as tape:
out = model(x)
jacobian = tape.jacobian(out, model.weights)
推荐阅读
- dart - 我如何接收未来值作为字符串
- c# - 将 ExpandoObject 转换为 T 的 AnonymousType
- php - 当我使用全新安装的 IIS 10 访问我的默认网站时,它会将我重定向到“http://localhost/installation/index.php”。为什么?
- javascript - ECharts - 将两种不同的颜色应用于同一轴上的标签
- css - 防止容器的高度随着我们添加记录而增加
- ios - 配置文件包括更新版本的签名证书
- linkedin - 通过 LinkedIn API 缩小发布图像
- jmeter - 如何在 JSR223 PostProcessor 中使用 Java 类和 JMeter API 类
- excel - Excel VBA 如何将多个控件和变量转换为通用函数
- javascript - Vujs在组件模板中迭代不起作用