How to convert a sequence of integers to .tfrecords and back into a dataset

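Problem description

I'm having trouble converting a .csv file into a .tfrecords file and then reading that file back to create a dataset. More precisely, a dataset that gives me the features in a form I can use.

I have a .csv file like this: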

Feature1,Feature2,...,Feature50,Label
    5   ,   19   ,...,    17   ,  0

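The first line is, of course, the header row. Each data row holds fifty integers followed by a label that is 0 or 1. I read the file row by row and write it to a .tfrecords file as follows: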

    with tf.python_io.TFRecordWriter(self.abs_write_train_file_path) as writer:
        for row in self.train_file:
            features, label = row[0:50], row[50]
            self.example = tf.train.Example(features=tf.train.Features(feature={
                'features': tf.train.Feature(int64_list=tf.train.Int64List(value=[item for item in features])),
                'labels': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }))
            writer.write(self.example.SerializeToString())
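
The question does not show how `self.train_file` is produced. As a minimal sketch, assuming the rows come straight from the .csv file shown above, the padded cells could be turned into plain Python integers before they are handed to `Int64List`. The helper name `iter_rows` and its `num_features` parameter are hypothetical and not part of the original code.

```python
import csv
import io


def iter_rows(csv_file, num_features=50):
    """Yield (features, label) as Python ints, skipping the header row."""
    reader = csv.reader(csv_file)
    next(reader, None)  # skip the header line: Feature1,...,Feature50,Label
    for row in reader:
        if not row:
            continue  # ignore blank lines
        # int() ignores the padding spaces around each cell
        values = [int(cell) for cell in row]
        yield values[:num_features], values[num_features]


# Tiny three-feature stand-in for the fifty-feature file in the question.
sample = io.StringIO(
    "Feature1,Feature2,Feature3,Label\n"
    "    5   ,   19   ,    17   ,  0\n"
)
print(list(iter_rows(sample, num_features=3)))  # [([5, 19, 17], 0)]
```

Each yielded pair can then be unpacked as `features, label` inside the writer loop.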

This gives me an Example in this format:

features {
  feature {
    key: "features"
    value {
      int64_list {
        value: 5
        value: 19
        ...(50 values all together)...
        value: 17
      }
    }
  }
  feature {
    key: "labels"
    value {
      int64_list {
        value: 0
      }
    }
  }
}

Now I'm trying to use this file like this:

import tensorflow as tf

COLUMNS = []
for _ in range(1, 51):
    COLUMNS.append('Feature'+str(_))
COLUMNS.append('Label')

def get_dataset(file_path):
    dataset = tf.data.TFRecordDataset([file_path])
    dataset = dataset.map(parse_function)
    return dataset

def parse_function(example_proto):
    features= {
        'Features': tf.FixedLenFeature((50), tf.int64),
        'Labels': tf.FixedLenFeature((), tf.int64)
    }
    parsed_features = tf.parse_single_example(example_proto, features)
    return parsed_features['Features'], parsed_features['Labels']

def train_input_fn():
    train_dataset = get_dataset(train_filepath)
    iterator = train_dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return {'features': [features]}, labels

feature_columns = [tf.feature_column.numeric_column(k) for k in COLUMNS]
classifier = tf.estimator.LinearClassifier(feature_columns=feature_columns)
classifier.train(input_fn=train_input_fn)

Sadly, but understandably given the dictionary in the .tfrecords file, this is the error:

Traceback (most recent call last):
  File "c:\Users\REDACTED\Neural Net\test.py", line 53, in <module>
    input_fn=train_input_fn
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\estimator.py", line 355, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\estimator.py", line 824, in _train_model
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\estimator.py", line 805, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\canned\linear.py", line 318, in _model_fn
    config=config)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\canned\linear.py", line 158, in _linear_model_fn
    logits = logit_fn(features=features)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\canned\linear.py", line 99, in linear_logit_fn
    cols_to_vars=cols_to_vars)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\feature_column\feature_column.py", line 433, in linear_model
    trainable=trainable)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\feature_column\feature_column.py", line 1710, in _create_weighted_sum
    trainable=trainable)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\feature_column\feature_column.py", line 1719, in _create_dense_column_weighted_sum
    trainable=trainable)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\feature_column\feature_column.py", line 2083, in _get_dense_tensor
    return inputs.get(self)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\feature_column\feature_column.py", line 1886, in get
    transformed = column._transform_feature(self)  # pylint: disable=protected-access
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\feature_column\feature_column.py", line 2051, in _transform_feature
    input_tensor = inputs.get(self.key)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\feature_column\feature_column.py", line 1882, in get
    raise ValueError('Feature {} is not in features dictionary.'.format(key))
ValueError: Feature Feature1 is not in features dictionary.

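The question is where the error lies. I created the .tfrecords file strictly by the book, and the data written into it is in the format prescribed for .tfrecords files. Reading from a .tfrecords file, on the other hand, should be easy and should not need much parsing, especially if you only want to use TensorFlow's high-level API. Also, for newer versions it is not uncommon to find nothing on Google about sequences with labels: nearly every tutorial uses outdated TF versions or just re-explains the official TF tutorials (whose API guides are, in my opinion, very badly documented), and those use downloaded data such as MNIST.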

Help?

Tags: python, csv, tensorflow, dataset

Solution
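
No answer text survives on this page, so what follows is only a hedged sketch based on the code and traceback in the question. Two mismatches stand out. The writer stores the keys `'features'` and `'labels'`, while the parser asks for `'Features'` and `'Labels'`. And the feature columns are named `Feature1` to `Feature50` (plus `Label`), while the input function returns a dictionary with the single key `'features'`, which matches the `ValueError` shown. The sketch below keeps one key per record and parses it as a single length-50 vector. It assumes a TensorFlow version that provides the `tf.io` aliases (roughly 1.13 or later); `write_examples`, `NUM_FEATURES` and the `batch_size` parameter are illustrative names, not part of the original code.

```python
import tensorflow as tf

NUM_FEATURES = 50  # from the question: fifty integers per row


def write_examples(path, rows):
    """Write (features, label) pairs with the same keys as the question's writer."""
    with tf.io.TFRecordWriter(path) as writer:
        for features, label in rows:
            example = tf.train.Example(features=tf.train.Features(feature={
                'features': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=list(features))),
                'labels': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[label])),
            }))
            writer.write(example.SerializeToString())


def parse_function(example_proto):
    # The keys must match the writer exactly: lowercase 'features' and 'labels'.
    spec = {
        'features': tf.io.FixedLenFeature([NUM_FEATURES], tf.int64),
        'labels': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example_proto, spec)
    # Return a dict keyed by the name that the feature column will look up.
    return {'features': parsed['features']}, parsed['labels']


def get_dataset(file_path, batch_size=32):
    """Build a batched dataset of ({'features': ...}, labels) pairs."""
    dataset = tf.data.TFRecordDataset([file_path])
    # batch() adds the leading batch dimension, so no manual [features] wrapping.
    return dataset.map(parse_function).batch(batch_size)
```

With this layout, the fifty per-column feature columns would presumably be replaced by a single `tf.feature_column.numeric_column('features', shape=(50,))`, and `Label` would not be listed as a feature column at all. In the question's setup, `train_input_fn` could then likely return `get_dataset(train_filepath)` directly, since an Estimator input function may return a `tf.data.Dataset` of `(features, labels)` pairs.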

