python - Tensorflow 自定义正则化项将预测值与真实值进行比较
问题描述
您好,我需要一个自定义正则化术语来添加到我的(二进制交叉熵)损失函数中。有人可以帮助我使用 Tensorflow 语法来实现这一点吗?我尽可能地简化了一切,以便更容易帮助我。
该模型将 18 x 18 二进制配置的数据集 10000 作为输入,并将 16x16 的配置集作为输出。神经网络仅包含 2 个卷积层。
我的模型如下所示:
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
EPOCHS = 10
model = models.Sequential()
model.add(layers.Conv2D(1,2,activation='relu',input_shape=[18,18,1]))
model.add(layers.Conv2D(1,2,activation='sigmoid',input_shape=[17,17,1]))
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),loss=tf.keras.losses.BinaryCrossentropy())
model.fit(initial.reshape(10000,18,18,1),target.reshape(10000,16,16,1),batch_size = 1000, epochs=EPOCHS, verbose=1)
output = model(initial).numpy().reshape(10000,16,16)
现在我写了一个函数,我想用它作为一个附加的正则化术语来作为一个正则化术语。该函数采用真值和预测值。基本上,它将两者的每一点与它的“正确”邻居相乘。然后取差值。我假设真实和预测项是 16x16(而不是 10000x16x16)。这个对吗?
def regularization_term(prediction, true):
order = list(range(1,4))
order.append(0)
deviation = (true*true[:,order]) - (prediction*prediction[:,order])
deviation = abs(deviation)**2
return 0.2 * deviation
我真的很感激能在我的损失中添加类似这个函数的东西作为正则化项,以帮助神经网络更好地训练这种“正确的邻居”交互。我真的很努力使用可定制的 Tensorflow 功能。谢谢,非常感谢。
解决方案
这很简单。您需要指定自定义损失,在其中定义添加正则化项。像这样的东西:
# to minimize!
def regularization_term(true, prediction):
order = list(range(1,4))
order.append(0)
deviation = (true*true[:,order]) - (prediction*prediction[:,order])
deviation = abs(deviation)**2
return 0.2 * deviation
def my_custom_loss(y_true, y_pred):
return tf.keras.losses.BinaryCrossentropy()(y_true, y_pred) + regularization_term(y_true, y_pred)
model.compile(optimizer='Adam', loss=my_custom_loss)
正如keras所说:
任何带有签名 loss_fn(y_true, y_pred) 且返回损失数组(输入批次中的样本之一)的可调用函数都可以作为损失传递给 compile()。请注意,任何此类损失都会自动支持样本加权。
所以一定要返回一个损失数组(编辑:正如我现在所看到的,它也可以返回一个简单的标量。如果你使用例如reduce函数没关系)。基本上 y_true 和 y_predicted 将批量大小作为第一个维度。
推荐阅读
- python - Python:函数总是从循环中返回零
- php - 使用 XAMPP 或 Wamp 创建本地 WordPress 站点 ..但出现此错误..建立数据库连接时出错
- pytorch - 声明 PyTorch 张量时有关尺寸的 ValuerError
- json - 如何从json文件制作数据框表
- php - SSL 证书配置错误
- r - 如何按时间对 R 中的 XTS 对象进行分组?
- docker - Docker 占用所有空间 100%
- python - 如何将 2 因素身份验证与 Google Oauth Login 集成?
- javascript - 使用 ajax 加载动态数据时不显示模态窗口
- delphi - 在 Windows 10 中使用时,Delphi 7 opendialog 的文件名中有垃圾