首页 > 解决方案 > 使用 TFRecordWriter 时出现“没有这样的文件或目录”错误

问题描述

我正在按照教程使用 TensorFlow 构建自定义对象检测器,所有代码都在 Google Colab 上运行,使用的文件存放在 Google Drive(云端硬盘)中。

负责写入 TFRecord 文件的那部分代码如下:

# Base folder for the dataset: the label CSVs are read from here and the
# generated .record files are written here.
# NOTE(review): the double quotes around My Drive are literal characters of
# this Python string (part of the directory name), not shell-style quoting.
DATA_BASE_PATH = '/gdrive/"My Drive"/object_detection/data/'
# Sub-folder that is passed to create_tf_example as the image directory.
image_dir = DATA_BASE_PATH +'images/'

def class_text_to_int(row_label):
    """Map a class name to its integer label id.

    Args:
        row_label: Class name taken from the ``class`` column of a label row.

    Returns:
        ``1`` for ``'pistol'``; ``None`` for any other label.
    """
    if row_label == 'pistol':
        return 1
    # Unknown label: return None explicitly instead of relying on a bare
    # `None` expression and the implicit fall-through return.
    return None

def split(df, group):
    """Partition *df* by the *group* column into per-key records.

    Args:
        df: DataFrame holding the label rows.
        group: Column name (or other ``groupby`` key) to group on.

    Returns:
        A list of named tuples with fields ``filename`` (the group key) and
        ``object`` (the sub-DataFrame of rows belonging to that key), in the
        order the keys appear in ``groupby(...).groups``.
    """
    Record = namedtuple('data', ['filename', 'object'])
    grouped = df.groupby(group)
    records = []
    for key in grouped.groups:
        records.append(Record(key, grouped.get_group(key)))
    return records

def create_tf_example(group, path):
    """Build a ``tf.train.Example`` for one image and its annotated boxes.

    Args:
        group: Record with a ``filename`` attribute and an ``object``
            attribute whose rows provide ``xmin``, ``xmax``, ``ymin``,
            ``ymax`` and ``class`` values.
        path: Directory that contains the image named ``group.filename``.

    Returns:
        A ``tf.train.Example`` holding the raw image bytes, the image size,
        the box coordinates divided by the image width/height, and the class
        names together with their integer ids.
    """
    image_path = os.path.join(path, '{}'.format(group.filename))
    with tf.io.gfile.GFile(image_path, 'rb') as image_file:
        encoded_jpg = image_file.read()

    # The image is opened only to read its dimensions; the bytes read from
    # disk are what gets stored in the example.
    width, height = Image.open(io.BytesIO(encoded_jpg)).size

    filename = group.filename.encode('utf8')
    image_format = b'jpg'

    # Materialise the rows once so every feature list sees the same order.
    rows = [row for _, row in group.object.iterrows()]
    xmins = [row['xmin'] / width for row in rows]
    xmaxs = [row['xmax'] / width for row in rows]
    ymins = [row['ymin'] / height for row in rows]
    ymaxs = [row['ymax'] / height for row in rows]
    classes_text = [row['class'].encode('utf8') for row in rows]
    classes = [class_text_to_int(row['class']) for row in rows]

    feature = {
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

# Convert each label CSV (training and test split) into a TFRecord file
# written next to it under DATA_BASE_PATH.
for csv in ['train_labels', 'test_labels']:
  writer = tf.io.TFRecordWriter(DATA_BASE_PATH + csv + ".record")
  try:
    path = os.path.join(image_dir)
    examples = pd.read_csv(DATA_BASE_PATH + csv + '.csv')
    grouped = split(examples, 'filename')
    # One serialized tf.train.Example is written per image file.
    for group in grouped:
      tf_example = create_tf_example(group, path)
      writer.write(tf_example.SerializeToString())
  finally:
    # Always close the writer, even when reading the CSV or building an
    # example raises, so the open record file is not leaked.
    writer.close()

  print('Successfully created the TFRecords: {}'.format(DATA_BASE_PATH +csv + '.record'))

当我尝试运行它时,我收到以下消息:

---------------------------------------------------------------------------
NotFoundError                             Traceback (most recent call last)
<ipython-input-20-3f2d9f35e22c> in <module>()
     58 
     59 for csv in ['train_labels', 'test_labels']:
---> 60   writer = tf.io.TFRecordWriter(DATA_BASE_PATH + csv + ".record")
     61   path = os.path.join(image_dir)
     62   examples = pd.read_csv(DATA_BASE_PATH + csv + '.csv')

1 frames
/tensorflow-1.15.2/python3.6/tensorflow_core/python/lib/io/tf_record.py in __init__(self, path, options)
    216       # pylint: disable=protected-access
    217       self._writer = pywrap_tensorflow.PyRecordWriter_New(
--> 218           compat.as_bytes(path), options._as_record_writer_options(), status)
    219       # pylint: enable=protected-access
    220 

/tensorflow-1.15.2/python3.6/tensorflow_core/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
    554             None, None,
    555             compat.as_text(c_api.TF_Message(self.status.status)),
--> 556             c_api.TF_GetCode(self.status.status))
    557     # Delete the underlying status object from memory otherwise it stays alive
    558     # as there is a reference to status from this from the traceback due to

NotFoundError: /gdrive/"My Drive"/object_detection/data/train_labels.record; No such file or directory

我试着稍微调整了一下代码,但没有奏效。我是 TensorFlow 的初学者,不想把整个项目搞乱。有什么想法吗?

标签: python tensorflow

解决方案


问题出在 DATA_BASE_PATH 这个路径上:应该把 DATA_BASE_PATH = '/gdrive/"My Drive"/object_detection/data/' 改为 DATA_BASE_PATH = '/gdrive/My Drive/object_detection/data/'。双引号只在 shell 命令中用来处理路径里的空格;在 Python 字符串中,它们会被当作目录名的一部分,因此系统找不到名为 "My Drive"(带引号)的目录。

如果仍然不行,请告诉我。我做了上述修改后,代码可以正常运行。


推荐阅读