python - 用于 YOLO 的 Keras 自定义损失函数
问题描述
我正在尝试在 Keras 中定义自定义损失函数
def yolo_loss(y_true, y_pred):
这里 y_true 和 y_pred 的形状是 [batch_size,19,19,5]。
对于批次中的每个图像,我想将损失计算为:
loss =
square(y_true[:,:,0] - y_pred[:,:,0])
+ square(y_true[:,:,1] - y_pred[:,:,1])
+ square(y_true[:,:,2] - y_pred[:,:,2])
+ (sqrt(y_true[:,:,3]) - sqrt(y_pred[:,:,3]))
+ (sqrt(y_true[:,:,4]) - sqrt(y_pred[:,:,4]))
我想了几种方法来做到这一点,
1)使用for循环:
def yolo_loss(y_true, y_pred):
y_ret = tf.zeros([1,y_true.shape[0]])
for i in range(0,int(y_true.shape[0])):
op1 = y_true[i,:,:,:]
op2 = y_pred[i,:,:,:]
class_error = tf.reduce_sum(tf.multiply((op1[:,:,0]-op2[:,:,0]),(op1[:,:,0]-op2[:,:,0])))
row_error = tf.reduce_sum(tf.multiply((op1[:,:,1]-op2[:,:,1]),(op1[:,:,1]-op2[:,:,1])))
col_error = tf.reduce_sum(tf.multiply((op1[:,:,2]-op2[:,:,2]),(op1[:,:,2]-op2[:,:,2])))
h_error = tf.reduce_sum(tf.abs(tf.sqrt(op1[:,:,3])-tf.sqrt(op2[:,:,3])))
w_error = tf.reduce_sum(tf.abs(tf.sqrt(op1[:,:,4])-tf.sqrt(op2[:,:,4])))
total_error = class_error + row_error + col_error + h_error + w_error
y_ret[0,i] = total_error
return y_ret
然而,这给了我一个错误:
ValueError:无法将部分已知的 TensorShape 转换为张量:(1,?)
这是因为我猜批量大小是未定义的。
2)另一种方法是将sqrt变换应用于批处理中的每个图像张量,然后将它们相减,然后应用平方变换。
例如
1) sqrt(y_true[:,:,:,3])
2) sqrt(y_pred[:,:,:,3])
3) sqrt(y_true[:,:,:,4])
4) sqrt(y_pred[:,:,:,4])
5) y_new = y_true-y_pred
6) square(y_new[:,:,:,0])
7) square(y_new[:,:,:,1])
8) square(y_new[:,:,:,2])
9) reduce_sum for each new tensor in the batch and return o/p in shape [1,batch_size]
但是我找不到在 Keras 中执行此操作的方法。
有人可以建议,实现此损失函数的最佳方法是什么。我在后端使用带有 tensorflow 的 Keras。
解决方案
你可以查看这个 git hub 页面。
https://github.com/experiencor/keras-yolo2
推荐阅读
- sql-server - UNION SELECT、并行化和 IDENTITY
- sql - ms-access中的sql
- html - 如何在水平模式下在全屏 WebApp 中运行时擦除移动 Safari 中的状态栏?
- javascript - 从 http 响应正文中获取特定值
- linux - Linux内核跟踪点:将探测函数连接到跟踪点时符号未定义
- mysql - 我如何在谷歌脚本应用程序 onEdit() 函数中使用 ajax 调用。(https://script.google.com)
- c# - 如何动态反序列化 Json 字符串
- c++ - 初始化启动对象时出现编译时错误
- ios - 可能是因为 afnetworking 而崩溃了
- mysql - 如何获取表 XYZ 的 ABC 数据库的数据库转储以及仅组织 ID“22”的数据库转储?