首页 > 解决方案 > 使用带有对象检测的数据增强生成的图像数量

问题描述

我试图在文档、代码和此处搜索答案,但我没有运气。我想知道使用 Tensorflow 中的对象检测 API 进行数据增强生成的最终图像数量是多少。为了清楚起见,我举了一个例子:假设我有一个包含 2 个类的数据集,每个类最初都有 50 个图像。然后我应用这个配置:

  data_augmentation_options {
    ssd_random_crop {
    }
  }

  data_augmentation_options {
    random_rgb_to_gray {
    }
  }

  data_augmentation_options {
    random_distort_color {
    }
  }

  data_augmentation_options {
    ssd_random_crop_pad_fixed_aspect_ratio {
    }
  }

我如何知道为训练我的模型而生成的最终图像数量?(如果有办法)。顺便说一句,我正在使用 model_main.py 来训练我的模型。

提前致谢。

标签: tensorflowobject-detection-api

解决方案


在文件inputs.py中,可以在函数中看到augment_input_fn所有数据增强选项都传递给preprocessor.preprocess方法。详细信息都在文件preprocessor.py中,特别是在 function 中preprocess

for option in preprocess_options:
  func, params = option
  if func not in func_arg_map:
    raise ValueError('The function %s does not exist in func_arg_map' %
                   (func.__name__))
  arg_names = func_arg_map[func]
  for a in arg_names:
    if a is not None and a not in tensor_dict:
      raise ValueError('The function %s requires argument %s' %
                     (func.__name__, a))

  def get_arg(key):
    return tensor_dict[key] if key is not None else None

  args = [get_arg(a) for a in arg_names]
  if (preprocess_vars_cache is not None and
      'preprocess_vars_cache' in inspect.getargspec(func).args):
    params['preprocess_vars_cache'] = preprocess_vars_cache
  results = func(*args, **params)
  if not isinstance(results, (list, tuple)):
    results = (results,)
  # Removes None args since the return values will not contain those.
  arg_names = [arg_name for arg_name in arg_names if arg_name is not None]
  for res, arg_name in zip(results, arg_names):
    tensor_dict[arg_name] = res

请注意,在上面的代码中,arg_names包含所有原始图像名称,这意味着每个增强选项将仅对原始图像执行(而不是在先前增强选项后获得的图像)。

同样在preprocessor.py中,我们可以看到每个增强选项只会生成与原始图像形状相同的图像。

因此,在您的情况下,四个选项和 100 个原始图像,400 个增强图像将被添加到tensor_dict.


推荐阅读