tensorflow - deeplab 我的自定义数据集的权重标准是什么?
问题描述
我正在通过在三个类中制作自定义数据集来训练Deeplab v3,包括背景
然后,我的班级是背景,熊猫,瓶子,有1949张图片。
我正在使用moblienetv2模型
和segmentation_dataset.py已修改如下。
_MYDATA_INFORMATION = DatasetDescriptor(
splits_to_sizes={
'train': 975, # num of samples in images/training
'trainval': 1949,
'val': 974, # num of samples in images/validation
},
num_classes=3,
ignore_label=0,
)
train.py已修改如下。
flags.DEFINE_boolean('initialize_last_layer', False,
'Initialize the last layer.')
flags.DEFINE_boolean('last_layers_contain_logits_only', True,
'Only consider logits as last layers or not.')
train_utils.py没有被修改。
not_ignore_mask = tf.to_float(tf.not_equal(scaled_labels, ignore_label)) * loss_weight
我得到了一些结果,但不是完美的。
例如,熊猫和瓶子的面具颜色相同或不同
我想要的结果是红色的熊猫和绿色的瓶子
所以,我判断重量有问题。
根据其他人的问题,train_utils.py配置如下
irgore_weight = 0
label0_weight =1
label1_weight = 10
label2_weight = 15
not_ignore_mask =
tf.to_float(tf.equal(scaled_labels, 0)) * label0_weight +
tf.to_float(tf.equal(scaled_labels, 1)) * label1_weight +
tf.to_float(tf.equal(scaled_labels, 2)) * label2_weight +
tf.to_float(tf.equal(scaled_labels, ignore_label)) * irgore_weight
tf.losses.softmax_cross_entropy(
one_hot_labels,
tf.reshape(logits, shape=[-1, num_classes]),
weights=not_ignore_mask,
scope=loss_scope)
我在这里有个问题。
重量的标准是什么?
我的数据集包括以下内容。
它是自动生成的,所以我不确切知道哪个更多,但数量差不多。
还有一件事,我正在使用 Pascal 的颜色映射类型。
这是第一个黑色背景和第二个红色第三个绿色。
我想准确地将熊猫指定为红色,将瓶子指定为绿色。我应该怎么办?
解决方案
我认为您可能混淆了标签定义。也许我可以帮你。请再次检查您的segmentation_dataset.py。在这里,您将“0”定义为被忽略的标签。这意味着所有标记为“0”的像素都被排除在训练过程之外(更具体地说,在损失函数的计算中被排除在外,因此对权重的更新没有影响)。鉴于这种情况,重要的是不要“忽略”背景类,因为它也是您想要正确预测的类。在train_utils.py 中,您为被忽略的类分配了一个权重因子,这将不起作用---> 确保不要将三个训练类 [background、panada、bottle] 与“ignored”标签混为一谈。
在您的情况下num_classes =3 应该是正确的,因为它指定了要预测的标签数量(模型自动假设这些标签是 0、1 和 2。如果您想忽略某些标签,您必须使用第四个标签类来注释它们(只需为此选择一个> 2的数字)然后将此标签分配给ignore_label。如果您没有要忽略的像素仍然设置ignore_label = 255,它不会影响您的训练;)
推荐阅读
- java - 从java中的类名实例化类
- css - 覆盖 CSS 嵌套属性
- c++ - Boost::asio::streambuf consume() 不会清空缓冲区
- r - 如何为fluidRow titlePanel R Shiny设置多行?
- r - 在 R 中使用 For 循环匹配负值和正值
- javascript - jQuery AJAX GET 请求未到达 Node.JS(使用 NGINX)
- javascript - C3中水平条形图标签中的长字符串
- google-sheets - 如何连接按第二列中的值分组的多个单元格
- javascript - 在 aws lambda 上为 python 脚本使用 child_process spawn
- go - Go:将 uuid.UUID (satori) 类型的 reflect.Value 再次转换回 uuid.UUID