python - tensorflow TFRecords 无法解析序列化示例
问题描述
尝试解码 TFRecord 示例时出现以下错误:
InvalidArgumentError:名称:,功能:相关性(数据类型:float)是必需的,但找不到。[操作:解析示例]
要解码示例,我使用tf.io.parse_example
如下
example_features = tf.compat.v1.io.parse_example(
tf.reshape(serialized_list, [-1]), peritem_feature_spec)
serialized_list
看起来像哪里
[ <example_object>, ...
b'\n\xcc\x01\n\x15\n\trelevance\x12\x08\x12\x06\n\x04\x9a\xe9`D\n\xb2
\x01\n\x13encoded_clust_index\x12\x9a\00\ <more...>]
看起来peritem_feature_spec
像
peritem_feature_spec = {
'relevance':tf.FixedLenFeature([], tf.float32),
'encoded_clust_index':tf.VarLenFeature(tf.float32)
}
我很困惑为什么找不到功能“相关性”。我认为我正确编码并创建了我的 TFRecord 对象。我是否错误地创建了 feature_spec?我的想法是tf.VarLenFeature
使用不正确的功能类型,但我无法弄清楚什么是正确的。
使用tensorflow_ranking.python.data.parse_from_example_in_example
能够将 TFRecord 正确解码为其特征,但我不知道为什么会tf.io.parse_example
失败
解决方案
编辑:我的旧答案低于这个答案
正确的答案是提供default_value
功能规范
peritem_feature_spec = {
'relevance':tf.FixedLenFeature([], tf.float32, default_value=0.0),
'encoded_clust_index':tf.VarLenFeature(tf.float32, default_value=0.0)
}
我的旧(有效,但不正确)答案如下
所以问题归结为如何在 tensorflow_ranking 库中填充我的功能。它正在填充列表功能,例如:
def pad_fn():
return tf.pad(
tensor=serialized_list,
paddings=[[0, 0], [0, list_size -cur_list_size]],
constant_values="")
此方法将空字节附加到张量的末尾。解析器feature_name
在空张量中查找 a,并响应它找不到它。我的解决方法是附加序列化的 TFRecord 示例原型而不是空字节字符串。我是这样完成的:
def pad_fn():
# Create feature spec for tf.train.Example to append
pad_spec = {}
# Default values are 0 or an empty byte string depending on
# original serialized data type
dtype_map = {tf.float32:tf.train.Feature(
float_list=tf.train.FloatList(value=[0.0])),
tf.int32:tf.train.Feature(
int64_list=tf.train.Int64List(value=[0])),
tf.string:tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[bytes('', encoding='UTF-8')]))}
# Create the feature spec
for key, item in peritem_feature_spec.items():
dtype = item.dtype
pad_spec[key] = dtype_map[dtype]
# Make and serialize example to append
constant_values = tf.train.Example(
features=tf.train.Features(feature=pad_spec))
constant_val_str = constant_values.SerializeToString()
# Add serialized padding to end of list
return tf.pad(
tensor=serialized_list,
paddings=[[0, 0], [0, list_size - cur_list_size]],
constant_values=constant_val_str)
推荐阅读
- node.js - 如何从 Flutter 生成 access_token 并将其发送到节点服务器?
- python - 我的 matplotlib 脚本的性能很差
- django - 'RuntimeError: populate() 不可重入'
- stm32 - STM32F103 无法从应用程序跳转到引导加载程序
- javascript - 如何使用 JavaScript 用配置文件中的数据填充复选框?
- php - php strtotime 'now' 和 '-1 month' 显示相同的日期
- linux - 在 hadoop 集群上设置气流时面临 GCC 安装问题
- build - 生产构建中 GraphQL 和 Next JS 的问题
- python - 如何修复 Tkinter error: can't delete Tcl command 这个错误?
- java - Spock Spy/Mock 未注册调用