首页 > 解决方案 > 是否可以根据 Keras 中批次中样本输出的差异编写自定义损失函数?

问题描述

我正在尝试在 Keras 中实现损失函数,它可以执行以下操作:

假设 y0, y1, ..., yn 是批量输入 x0, x1, ..., xn 的模型批量输出,这里 batch_size 是 n+1,每个 xi 的输出 yi 是一个标量值,我想要什么用于计算该批次的整体损失的损失函数如下:

K.log(K.sigmoid(y1-y0))+K.log(K.sigmoid(y2-y1))+...+K.log(K.sigmoid(yn-yn-1))

我正在考虑使用 Lambda 层首先将批量输出 [y0,y1,...,yn] 转换为 [y1-y0, y2-y1, ...,yn-yn-1],然后使用自定义损失对转换后的输出函数。

但是,我不确定 Keras 是否可以理解 Lambda 层中没有要更新的权重,并且我不清楚Keras 将如何通过 Lambda 层将梯度传播回来,因为 Keras 通常需要每个层/损失函数在其上运行单个样本输入,但我的层将获取一批样本的全部输出。以前有没有人解决过类似的问题?谢谢!

标签: tensorflowkeras

解决方案


像下面这样的切片对你有用吗(虽然我没有使用 keras)。

batch = 4
num_classes = 6
logits = tf.random.uniform(shape=[batch, num_classes])

logits1 = tf.slice(logits, (0, 0), [batch, num_classes-1])
logits2 = tf.slice(logits, (0, 1), [batch, num_classes-1])

delta = logits2 - logits1
loss = tf.reduce_sum(tf.log(tf.nn.sigmoid(delta)), axis=-1)

with tf.Session() as sess:
  logits, logits1, logits2, delta, loss  = sess.run([logits, logits1, logits2, 
                                                     delta, loss])

  print 'logits\n', logits
  print 'logits2\n', logits2
  print 'logits1\n', logits1
  print 'delta\n', delta
  print 'loss\n', loss

结果:

logits
[[ 0.61241663  0.70075285  0.98333454  0.4117974   0.5943476   0.84245574]
 [ 0.02499413  0.22279179  0.70742595  0.34853518  0.7837007   0.88074362]
 [ 0.35030317  0.36670768  0.64244425  0.87957716  0.22823489  0.45076978]
 [ 0.38116801  0.39040041  0.82510674  0.64789391  0.45415008  0.03520513]]
logits2
[[ 0.70075285  0.98333454  0.4117974   0.5943476   0.84245574]
 [ 0.22279179  0.70742595  0.34853518  0.7837007   0.88074362]
 [ 0.36670768  0.64244425  0.87957716  0.22823489  0.45076978]
 [ 0.39040041  0.82510674  0.64789391  0.45415008  0.03520513]]
logits1
[[ 0.61241663  0.70075285  0.98333454  0.4117974   0.5943476 ]
 [ 0.02499413  0.22279179  0.70742595  0.34853518  0.7837007 ]
 [ 0.35030317  0.36670768  0.64244425  0.87957716  0.22823489]
 [ 0.38116801  0.39040041  0.82510674  0.64789391  0.45415008]]
delta
[[ 0.08833623  0.28258169 -0.57153714  0.18255019  0.24810815]
 [ 0.19779766  0.48463416 -0.35889077  0.43516552  0.09704292]
 [ 0.01640451  0.27573657  0.23713291 -0.65134227  0.22253489]
 [ 0.0092324   0.43470633 -0.17721283 -0.19374382 -0.41894495]]
loss
[-3.41376281 -3.11249781 -3.49031925 -3.69255161]

推荐阅读