首页 > 解决方案 > 如何将字符串数据保存到 TFRecord?

问题描述

保存到 TFRecord 时,我使用:

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

one_example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "image": _bytes_feature(img.tobytes()),
            "label": _bytes_feature(label.tobytes()),
            "file_name": _bytes_feature(this_city_file_name), #this line doesn't work
            "nb_rows": _int64_feature(nb_rows), 
            "nb_cols": _int64_feature(nb_cols), 
            "index_i": _int64_feature(i), 
            "index_j": _int64_feature(j),
        }
    )
)

并且this_city_file_name在我运行此代码时具有一种字符串,这会导致错误:

TypeError: 'xxxxxxx' 具有类型,但预期为以下之一:((,),)

简单地使用 bytes(this_city_file_name) 也会导致错误:

TypeError:没有编码的字符串参数

从 TFRecord 加载时,我使用

features = tf.parse_single_example(serialized_example,
                                   features={
                                       "image": tf.FixedLenFeature([], tf.string),
                                       "label": tf.FixedLenFeature([], tf.string),
                                       "file_name": tf.FixedLenFeature([], tf.string),
                                       "nb_rows": tf.FixedLenFeature([], tf.int64),
                                       "nb_cols": tf.FixedLenFeature([], tf.int64),
                                       "index_i": tf.FixedLenFeature([], tf.int64),
                                       "index_j": tf.FixedLenFeature([], tf.int64),
                                   },
                                   )

我知道如何将 int 和 np.array 类型保存到 TFRecord 并从中读取,但是如何从 TFRecord 保存和加载字符串数据?

标签: pythontensorflowtfrecord

解决方案


我知道这很旧,但是您必须转换this_city_file_name为字节对象。查看本指南

以下是相关代码:

print(_bytes_feature(b'test_string'))
print(_bytes_feature(u'test_bytes'.encode('utf-8')))

推荐阅读