Can't use tf.float64 dtype with tfrecord

Problem description

I'm trying to write and read float64 examples with TFRecord. Here is the sample code:

import tensorflow as tf

# write
def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def parse(fea):
  data = {'fea' : _float_feature(fea)}
  # Create an Example wrapping the single feature
  example = tf.train.Example(features=tf.train.Features(feature=data))

  return example

feas = [1.0, 2.0]

with tf.io.TFRecordWriter('my_example.tfrecords') as writer:
    for fea in feas:
        example = parse(fea)
        writer.write(example.SerializeToString())


# read
def parse_one_tfrecord(element):
    # Describe the features to parse, mirroring the structure used when writing
    data = {
      'fea':tf.io.FixedLenFeature([], tf.float64),
    }

    content = tf.io.parse_single_example(element, data)
    fea = content["fea"]

    return fea

for sample_idx, sample in enumerate(tf.data.TFRecordDataset('my_example.tfrecords').map(parse_one_tfrecord)):
    print(sample_idx, sample)

When reading back the written TFRecords, the following exception is raised:

TypeError: in user code:

    <ipython-input-5-c24b1277299d>:7 parse_one_tfrecord  *
        content = tf.io.parse_single_example(element, data)
    c:\users\admin\appdata\local\programs\python\python37\lib\site-packages\tensorflow\python\util\dispatch.py:201 wrapper  **
        return target(*args, **kwargs)
    c:\users\admin\appdata\local\programs\python\python37\lib\site-packages\tensorflow\python\ops\parsing_ops.py:452 parse_single_example_v2
        return parse_example_v2(serialized, features, example_names, name)
    c:\users\admin\appdata\local\programs\python\python37\lib\site-packages\tensorflow\python\util\dispatch.py:201 wrapper
        return target(*args, **kwargs)
    c:\users\admin\appdata\local\programs\python\python37\lib\site-packages\tensorflow\python\ops\parsing_ops.py:314 parse_example_v2
        outputs = _parse_example_raw(serialized, example_names, params, name=name)
    c:\users\admin\appdata\local\programs\python\python37\lib\site-packages\tensorflow\python\ops\parsing_ops.py:362 _parse_example_raw
        name=name)
    c:\users\admin\appdata\local\programs\python\python37\lib\site-packages\tensorflow\python\ops\gen_parsing_ops.py:772 parse_example_v2
        dense_shapes=dense_shapes, name=name)
    c:\users\admin\appdata\local\programs\python\python37\lib\site-packages\tensorflow\python\framework\op_def_library.py:650 _apply_op_helper
        param_name=input_name)
    c:\users\admin\appdata\local\programs\python\python37\lib\site-packages\tensorflow\python\framework\op_def_library.py:63 _SatisfiesTypeConstraint
        ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))

    TypeError: Value passed to parameter 'dense_defaults' has DataType float64 not in list of allowed values: float32, int64, string
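The allowed values in that message (float32, int64, string) correspond to the three value lists a tf.train.Feature can hold: FloatList, Int64List and BytesList. FloatList is declared in TensorFlow's feature.proto as a protobuf `repeated float`, i.e. a 32-bit field, so a Python float passed to it is rounded to float32 once the Example is serialized. The small check below is a sketch assuming TensorFlow 2.x and NumPy are installed; it pushes 0.1 through a FloatList serialization round trip and compares the result with NumPy's float32 and float64 values.

```python
import numpy as np
import tensorflow as tf

# Build a FloatList holding 0.1 and run it through protobuf serialization.
original = 0.1
float_list = tf.train.FloatList(value=[original])
restored = tf.train.FloatList.FromString(float_list.SerializeToString()).value[0]

# FloatList.value is a 32-bit `repeated float` field, so what survives
# serialization is 0.1 rounded to float32, not the float64 value 0.1.
print(restored)
print("equals float32(0.1):", restored == float(np.float32(original)))
print("equals float64 0.1: ", restored == original)
```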

Can I not use the tf.float64 dtype in this case? If not, how can I use tf.float64 with TFRecord?
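If float32 precision is acceptable, the smallest change is to parse the feature with the dtype it was actually stored in, tf.float32, and cast afterwards. The sketch below assumes TensorFlow 2.x with eager execution; it repeats the question's writer unchanged and modifies only the reader (the temporary file path is illustrative). The cast cannot restore digits that were already lost when the value was written as float32.

```python
import os
import tempfile

import tensorflow as tf

path = os.path.join(tempfile.mkdtemp(), "my_example.tfrecords")

# Write the same two examples as in the question (FloatList stores float32).
with tf.io.TFRecordWriter(path) as writer:
    for fea in [1.0, 2.0]:
        feature = tf.train.Feature(float_list=tf.train.FloatList(value=[fea]))
        example = tf.train.Example(features=tf.train.Features(feature={"fea": feature}))
        writer.write(example.SerializeToString())

def parse_one_tfrecord(element):
    # FloatList features must be parsed as tf.float32 ...
    data = {"fea": tf.io.FixedLenFeature([], tf.float32)}
    content = tf.io.parse_single_example(element, data)
    # ... and can then be widened to float64 if downstream code needs it.
    return tf.cast(content["fea"], tf.float64)

samples = list(tf.data.TFRecordDataset(path).map(parse_one_tfrecord))
for sample_idx, sample in enumerate(samples):
    print(sample_idx, sample)
```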

Tags: python, tensorflow

Solution
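When full float64 precision must be preserved, one commonly used workaround is to bypass FloatList entirely: serialize the float64 tensor to bytes with tf.io.serialize_tensor, store it in a BytesList, read it back as a tf.string feature and decode it with tf.io.parse_tensor(..., out_type=tf.float64). The sketch below assumes TensorFlow 2.x with eager execution; `_float64_feature` and the file name are hypothetical, chosen only for illustration.

```python
import os
import tempfile

import tensorflow as tf

path = os.path.join(tempfile.mkdtemp(), "my_example_f64.tfrecords")

def _float64_feature(value):
    """Hypothetical helper: stores a float64 scalar as serialized tensor bytes."""
    tensor = tf.constant(value, dtype=tf.float64)
    serialized = tf.io.serialize_tensor(tensor).numpy()  # TensorProto as bytes
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[serialized]))

with tf.io.TFRecordWriter(path) as writer:
    for fea in [0.1, 2.0]:
        example = tf.train.Example(
            features=tf.train.Features(feature={"fea": _float64_feature(fea)}))
        writer.write(example.SerializeToString())

def parse_one_tfrecord(element):
    # Read the raw bytes back as a string feature ...
    data = {"fea": tf.io.FixedLenFeature([], tf.string)}
    content = tf.io.parse_single_example(element, data)
    # ... and decode them into the original float64 tensor.
    fea = tf.io.parse_tensor(content["fea"], out_type=tf.float64)
    fea.set_shape([])  # parse_tensor cannot infer the static shape inside map()
    return fea

samples = list(tf.data.TFRecordDataset(path).map(parse_one_tfrecord))
for sample_idx, sample in enumerate(samples):
    print(sample_idx, sample)
```

The trade-off is that each value carries a small TensorProto header, so the files are larger than with FloatList, but 0.1 is read back bit-for-bit as the float64 value that was written.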

