python - 如何在 tensorflow.keras 模型中正确使用自定义损失(例如骰子系数)?
问题描述
我正在尝试将自定义损失或度量函数输入到 tensorflow.keras 模型编译中,但出现错误。看起来输入没有正确传递。
我已经能够使用以下张量流损失运行模型编译/拟合:
loss = tf.keras.losses.MeanSquaredError() # Breaks if I remove ()
loss = tf.nn.sigmoid_cross_entropy_with_logits # Breaks if I add ()
我对使这些工作的语法差异感到困惑(可能是不正确的?),如果我使用自定义骰子损失,它不适用于任何一种语法。
当我在下面运行自定义骰子损失时,输入标签正确传递,batch_size*height*width
但输入 logits 传递为None,None,None,None
(似乎不正确?)并且骰子损失函数出错。我正在寻找批量优化,所以损失应该由model.fit
.
def generalized_dice(labels, logits):
smooth = 1e-17
shape = tf.TensorShape(logits.shape).as_list()
depth = int(shape[-1])
labels = tf.one_hot(labels, depth, dtype=tf.float32)
logits = tf.nn.softmax(logits)
weights = 1.0 / (tf.reduce_sum(labels, axis=[0, 1, 2])**2)
numerator = tf.reduce_sum(labels * logits, axis=[0, 1, 2])
numerator = tf.reduce_sum(weights * numerator)
denominator = tf.reduce_sum(labels + logits, axis=[0, 1, 2])
denominator = tf.reduce_sum(weights * denominator)
loss = 2.0*(numerator + smooth)/(denominator + smooth)
return loss
def generalized_dice_loss(dice):
return 1-dice
model = tf.keras.Model(inputs=[input_x], outputs=[predictions])
loss = tf.keras.losses.MeanSquaredError() # Breaks if I remove (), with inputs being passed as None's
# loss = tf.nn.sigmoid_cross_entropy_with_logits
metric = generalized_dice # If no parantehses, labels passed as None's;
# Both tf.nn.sigmoid_cross_entropy_with_logits and generalized_dice don't work if I add () saying inputs and labels must be provided; since I am training in batches, these need to be provided by the model during training, not when I am defining loss?
model.compile(optimizer=tf.keras.optimizers.Adam(),
loss=loss,
metrics=[metric])
由于 dice loss 的输入作为 None 传递,
labels = tf.one_hot(labels, depth, dtype=tf.float32)
结果是:
TypeError: Value passed to parameter 'indices' has DataType float32 not in list of allowed values: uint8, int32, int64
解决方案
推荐阅读
- django - 如何在 Django 中过滤两个模型(一对一)
- ios - 当从同一目标下的目标 c 文件访问时,在 Swift 文件中设置的用户默认值返回 nil
- python - 使用 matplotlib 的毫米波动画
- c# - 在计算 Vector3.Distance 时是否应考虑帧速率依赖性?
- uber-api - 优步访问菜单获取 API 的问题
- java - 普通对象变量还是 JavaFX 对象属性?
- c++ - 在 C++ 类中初始化 C 数组
- sql - 即使有数据查询也不给出输出
- java - 如何将 XML 转换为具有二进制数据内容的 Json
- javascript - 页面转换时清除 redux 状态