tensorflow - 在自定义损失函数中使用梯度(tensorflow+keras)
问题描述
在 tensorflow/keras 中是否有一种自然/简单的方法来实现自定义损失函数,该函数使用模型输出相对于模型输入的导数?
我想到了一些类似的东西(不要介意实际的公式——这只是一个演示):
def my_loss_function(y_desired, y_model, x):
return abs(y_desired - y_model) + abs(tf.gradients(y_model,x[0]))
这有两个问题。首先是损失函数通常无法访问输入。我的理解是,我可以通过直接引用Input
层来解决这个问题,例如作为全局变量。(当然,所有代码都只是示意图。)
input_layer = Layer(...)
def my_loss(y1, y2):
return abs(y1-y2)*input_layer[0]
input_layer
第二个问题更为严重:在执行图中似乎无法访问关于 的梯度。
这里有一个非常相似的问题,没有解决方案:Custom loss function involved gradients in Keras/Tensorflow。我已经按照相同的思路进行了尝试,但没有运气。(对我来说,这是正确的方法并不明显,而不是说,Layer
以始终跟踪导数的方式包装 s 。)
解决方案
我无法使用自动fit
方法实施培训。但是,它当然可以通过手动编写循环来完成。我将提供仅使用梯度来学习函数的示例。
for epoch in range(epochs):
print("\nStart of epoch %d" % (epoch,))
# Iterate over the batches of the dataset.
for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
# Open a GradientTape to record the operations run
# during the forward pass, which enables auto-differentiation.
with tf.GradientTape(persistent=True) as tape:
# Create tensor that you will watch
x_tensor = tf.convert_to_tensor(x_batch_train, dtype=tf.float64)
tape.watch(x_tensor)
# Feed forward
output = model(x_tensor, training=True)
# Gradient and the corresponding loss function
o_x = tape.gradient(output, x_tensor)
loss_value = loss_fn(y_batch_train, o_x)
# Use the gradient tape to automatically retrieve
# the gradients of the trainable variables with respect to the loss.
grads = tape.gradient(loss_value, model.trainable_weights)
# Run one step of gradient descent by updating
# the value of the variables to minimize the loss.
optimizer.apply_gradients(zip(grads, model.trainable_weights))
# Log every 200 batches.
if step % 200 == 0:
print(
"Training loss (for one batch) at step %d: %.4f"
% (step, float(loss_value))
)
print("Seen so far: %s samples" % ((step + 1) * 64))
loss_fn
在这种情况下很简单
loss_fn = tf.keras.losses.MeanSquaredError()
请记住,您正在使用二阶导数来训练函数,并且with tf.GradientTape(persistent=True) as tape:
会产生一个警告,这对于这种情况是可以的。同样根据我的经验,这种方法对激活函数的选择特别敏感。ReLU 的连续可微变体可能是要走的路。
推荐阅读
- terraform - Terraform 定义的任务角色对于 ECS 计划任务无法正常工作
- c# - WinAPI:模拟完整的按键(按键向下和向上)
- r - 使用数据库中的数据(不同的表)定义 mlr3 任务?
- javascript - NextJS:恢复输入字段中的占位符文本
- database - 为什么 PostgreSQL 索引不包含可见性信息?
- python - 如何在此字符串列表中保留名字和姓氏之间的空格?
- python - 使用 Python 解析未知数据类型对象 Json-like
- android-studio - Android Emulator 卡在“关机”状态
- javascript - 取消静音成员后,Discord bot 不返回角色
- python - TypeError:当从 json 计算温度平均值时,'float' 对象不可下标