tensorflow - 是否可以根据 Keras 中批次中样本输出的差异编写自定义损失函数?
问题描述
我正在尝试在 Keras 中实现损失函数,它可以执行以下操作:
假设 y0, y1, ..., yn 是批量输入 x0, x1, ..., xn 的模型批量输出,这里 batch_size 是 n+1,每个 xi 的输出 yi 是一个标量值,我想要什么用于计算该批次的整体损失的损失函数如下:
K.log(K.sigmoid(y1-y0))+K.log(K.sigmoid(y2-y1))+...+K.log(K.sigmoid(yn-yn-1))
我正在考虑使用 Lambda 层首先将批量输出 [y0,y1,...,yn] 转换为 [y1-y0, y2-y1, ...,yn-yn-1],然后使用自定义损失对转换后的输出函数。
但是,我不确定 Keras 是否可以理解 Lambda 层中没有要更新的权重,并且我不清楚Keras 将如何通过 Lambda 层将梯度传播回来,因为 Keras 通常需要每个层/损失函数在其上运行单个样本输入,但我的层将获取一批样本的全部输出。以前有没有人解决过类似的问题?谢谢!
解决方案
像下面这样的切片对你有用吗(虽然我没有使用 keras)。
batch = 4
num_classes = 6
logits = tf.random.uniform(shape=[batch, num_classes])
logits1 = tf.slice(logits, (0, 0), [batch, num_classes-1])
logits2 = tf.slice(logits, (0, 1), [batch, num_classes-1])
delta = logits2 - logits1
loss = tf.reduce_sum(tf.log(tf.nn.sigmoid(delta)), axis=-1)
with tf.Session() as sess:
logits, logits1, logits2, delta, loss = sess.run([logits, logits1, logits2,
delta, loss])
print 'logits\n', logits
print 'logits2\n', logits2
print 'logits1\n', logits1
print 'delta\n', delta
print 'loss\n', loss
结果:
logits
[[ 0.61241663 0.70075285 0.98333454 0.4117974 0.5943476 0.84245574]
[ 0.02499413 0.22279179 0.70742595 0.34853518 0.7837007 0.88074362]
[ 0.35030317 0.36670768 0.64244425 0.87957716 0.22823489 0.45076978]
[ 0.38116801 0.39040041 0.82510674 0.64789391 0.45415008 0.03520513]]
logits2
[[ 0.70075285 0.98333454 0.4117974 0.5943476 0.84245574]
[ 0.22279179 0.70742595 0.34853518 0.7837007 0.88074362]
[ 0.36670768 0.64244425 0.87957716 0.22823489 0.45076978]
[ 0.39040041 0.82510674 0.64789391 0.45415008 0.03520513]]
logits1
[[ 0.61241663 0.70075285 0.98333454 0.4117974 0.5943476 ]
[ 0.02499413 0.22279179 0.70742595 0.34853518 0.7837007 ]
[ 0.35030317 0.36670768 0.64244425 0.87957716 0.22823489]
[ 0.38116801 0.39040041 0.82510674 0.64789391 0.45415008]]
delta
[[ 0.08833623 0.28258169 -0.57153714 0.18255019 0.24810815]
[ 0.19779766 0.48463416 -0.35889077 0.43516552 0.09704292]
[ 0.01640451 0.27573657 0.23713291 -0.65134227 0.22253489]
[ 0.0092324 0.43470633 -0.17721283 -0.19374382 -0.41894495]]
loss
[-3.41376281 -3.11249781 -3.49031925 -3.69255161]
推荐阅读
- php - 提取 7z 扩展 - PHP
- amazon-web-services - Alexa,处理您的输入时出现问题
- javascript - SyntaxError:关于从控制器传递的 Gson 字符串变量的输入意外结束
- sql-server - 如何使用仅给定名称(sql)的现有表的结构创建临时表
- c# - 尝试反序列化 PlayerStats 时获取 JsonSerializationException
- node.js - 节点 js 不允许我使用代理
- php - 从字符串中删除一个以 \U000 开头的单词
- c# - 无法从用法中推断方法“xyz”的类型参数
- polymer - 屏幕阅读器在悬停时未读取带有 aria-labelledby 参考的单选按钮
- ruby-on-rails - 如何在 Rails lib 子文件夹中修补数组?