首页 > 解决方案 > Mypy:Tensorflow TFRecords `Bytes` 对象的 `mypy` 类型

问题描述

这是一个奇怪的问题。所以我正在创建一些 Tensorflow TFRecords 文件来编码数据。我想检查mypy单个记录文件的类型,该文件被编码为二进制字符串。

现在,当我运行下面的代码并检查type()它所指示的字符串时<class 'bytes'>。但是当我使用mypy reveal_type()它时,它表示error: Revealed type is 'Any',所以它看起来像是mypy不识别字节类型。这有意义吗?我真的不想编码一些东西,Any因为这并不能真正帮助捕捉我想要捕捉的错误类型mypy.

这是我用来生成错误的示例代码。我从新的 tensorflow TFRecords 指南中获取了代码,但最后几行是我自己的。

import tensorflow as tf
import numpy as np
import typing

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


# the number of observations in the dataset
n_observations = int(1e4)

# boolean feature, encoded as False or True
feature0 = np.random.choice([False, True], n_observations)

# integer feature, random from 0 .. 4
feature1 = np.random.randint(0, 5, n_observations)

# string feature
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]

# float feature, from a standard normal distribution
feature3 = np.random.randn(n_observations)


def serialize_example(feature0, feature1, feature2, feature3):
  """
  Creates a tf.Example message ready to be written to a file.
  """

  # Create a dictionary mapping the feature name to the tf.Example-compatible
  # data type.

  feature = {
      'feature0': _int64_feature(feature0),
      'feature1': _int64_feature(feature1),
      'feature2': _bytes_feature(feature2),
      'feature3': _float_feature(feature3),
  }

  # Create a Features message using tf.train.Example.

  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
  return example_proto.SerializeToString()

serialized_example = serialize_example(False, 4, b'goat', 0.9876)
reveal_type(serialized_example)

print(type(serialized_example)) 

标签: pythonpython-3.xtensorflowbinarymypy

解决方案


推荐阅读